Compare commits

...

4 Commits

37 changed files with 11329 additions and 246 deletions
+1431
View File
File diff suppressed because it is too large Load Diff
+1475
View File
File diff suppressed because it is too large Load Diff
+214 -196
View File
@@ -1,16 +1,79 @@
# Aegis - MITRE ATT&CK Coverage Platform
Aegis is a comprehensive platform for tracking and managing security coverage against the MITRE ATT&CK framework. It enables security teams to document, validate, and visualize their defensive capabilities against known adversary techniques.
Aegis is a comprehensive platform for tracking and managing security coverage against the MITRE ATT&CK framework. It enables security teams to document, validate, and visualize their defensive capabilities against known adversary techniques through a structured Red Team / Blue Team validation workflow.
## Features
- **MITRE ATT&CK Integration**: Automatic synchronization with the MITRE ATT&CK framework via TAXII (with GitHub fallback), scheduled every 24h
- **Red/Blue Validation Workflow**: Structured dual-validation lifecycle for security tests (draft → red_executing → blue_evaluating → in_review → validated/rejected)
- **Test Template Catalog**: Import tests from Atomic Red Team, create custom templates, and instantiate real tests from them
- **Dual Validation**: Independent approval/rejection by Red Lead and Blue Lead before a test is finalized
- **Coverage Tracking**: Track validation status for each technique (validated, partial, not covered, in progress)
- **Test Management**: Document and manage security tests with full audit trail
- **Evidence Storage**: Secure evidence file storage with SHA256 integrity verification
- **Evidence Storage**: Secure evidence file storage with SHA256 integrity verification, separated by team (red/blue)
- **In-App Notifications**: Real-time notification bell with polling, automatic alerts on state changes
- **Reports & Export**: Coverage summary, test results, and remediation reports in JSON and CSV formats
- **Remediation Tracking**: Step-by-step remediation assignments with status tracking per test
- **Role-Based Access Control**: Granular permissions for red team, blue team, and leadership roles
- **Intel Monitoring**: Automated scanning for new threat intelligence related to techniques
- **Metrics Dashboard**: Real-time coverage metrics and reporting by tactic
- **Metrics Dashboard**: Pipeline funnel, team activity, validation rates, and recent tests
## Red Team / Blue Team Validation Flow
```
┌─────────────────────────────────────────────────────────────────────────┐
│ TEST LIFECYCLE │
│ │
│ ┌──────┐ ┌──────────────┐ ┌─────────────────┐ ┌───────────┐ │
│ │ DRAFT│───▶│RED_EXECUTING │───▶│ BLUE_EVALUATING │───▶│ IN_REVIEW │ │
│ └──────┘ └──────────────┘ └─────────────────┘ └───────────┘ │
│ │ │
│ ┌────────────────────┤ │
│ ▼ ▼ │
│ ┌──────────┐ ┌──────────┐ │
│ │ REJECTED │ │VALIDATED │ │
│ └──────────┘ └──────────┘ │
│ │ │
│ └──────▶ Back to DRAFT │
└─────────────────────────────────────────────────────────────────────────┘
```
### States
| State | Description | Who acts |
|-------|-------------|----------|
| `draft` | Created, pending execution | Red Tech |
| `red_executing` | Red Team documents attack & uploads evidence | Red Tech |
| `blue_evaluating` | Blue Team documents detection & uploads evidence | Blue Tech |
| `in_review` | Both managers review evidence | Red Lead, Blue Lead |
| `validated` | Approved by both managers | — (terminal) |
| `rejected` | Rejected — returns to draft for redo | Red/Blue Lead can reopen |
### Dual Validation
Both Red Lead and Blue Lead must independently vote:
- **Both approve** → test moves to `validated`
- **Either rejects** → test moves to `rejected`
- **One votes, other pending** → stays in `in_review`
## User Roles
| Role | Description | Capabilities |
|------|-------------|-------------|
| `admin` | Full system access | Everything |
| `red_tech` | Red team technician | Create tests, document attacks, upload red evidence |
| `blue_tech` | Blue team technician | Document detection, upload blue evidence |
| `red_lead` | Red team lead | Validate/reject the red side of tests |
| `blue_lead` | Blue team lead | Validate/reject the blue side of tests |
| `viewer` | Read-only access | View all data |
## Test Template Catalog
Tests can be created from predefined templates sourced from:
1. **Atomic Red Team** (Red Canary) — imported via the System admin panel
2. **Custom templates** — created by admins with suggested procedures and remediation
3. **MITRE procedures** — based on MITRE ATT&CK documentation
Templates include attack procedures, expected detections, suggested tools, severity levels, and suggested remediation steps. When instantiated, these fields are pre-populated into the new test.
## Tech Stack
@@ -19,6 +82,7 @@ Aegis is a comprehensive platform for tracking and managing security coverage ag
- **Object Storage**: MinIO (S3-compatible)
- **ORM**: SQLAlchemy with Alembic migrations
- **Frontend**: React 19 + TypeScript + Vite + Tailwind CSS v4 + TanStack Query
- **Scheduler**: APScheduler (MITRE sync, Intel scan, Notification cleanup)
## Quick Start
@@ -50,54 +114,43 @@ docker exec -w /app aegis-backend-1 alembic upgrade head
docker exec -w /app aegis-backend-1 python -m app.seed
```
5. Start the frontend (requires Node.js 20+ or Docker):
5. Start the frontend:
```bash
# Option A — with Node.js installed locally
cd frontend && npm install && npm run dev
# Option B — via Docker
docker run --rm -v ./frontend:/app -w /app -p 5173:5173 node:20-alpine sh -c "npm run dev"
```
6. Verify the installation:
```bash
# Backend health
curl http://localhost:8000/health
# Expected: {"status":"ok"}
# Frontend
# Open http://localhost:5173 — should show the Aegis login page
# Open http://localhost:5173 — Aegis login page
```
### Authentication
The platform uses JWT-based authentication. After seeding, log in with the default admin credentials:
JWT-based authentication. Default admin credentials after seeding:
```bash
# Obtain a token
curl -X POST http://localhost:8000/api/v1/auth/login \
-d "username=admin&password=admin123"
# Use the token to access protected endpoints
curl http://localhost:8000/api/v1/auth/me \
-H "Authorization: Bearer <your-token>"
```
> **Important:** Change the default `admin123` password and `SECRET_KEY` in production.
## Services
| Service | Port | Description |
|----------|------|-------------|
| Service | Port | Description |
|---------|------|-------------|
| Frontend | 5173 | React dev server (Vite) |
| Backend | 8000 | FastAPI REST API |
| PostgreSQL | 5433 | Database (mapped to 5433 to avoid conflicts) |
| Backend | 8000 | FastAPI REST API |
| PostgreSQL | 5433 | Database |
| MinIO API | 9000 | S3-compatible object storage |
| MinIO Console | 9001 | MinIO web interface |
## API Documentation
Once the backend is running, access the interactive API documentation at:
Interactive API documentation available at:
- **Swagger UI**: http://localhost:8000/docs
- **ReDoc**: http://localhost:8000/redoc
@@ -113,39 +166,81 @@ Once the backend is running, access the interactive API documentation at:
### Techniques
| Method | Route | Auth | Description |
|--------|-------|------|-------------|
| GET | `/api/v1/techniques` | Authenticated | List all (filters: `?tactic=`, `?status=`, `?review_required=`) |
| GET | `/api/v1/techniques` | Authenticated | List all (filters: tactic, status, review_required) |
| GET | `/api/v1/techniques/{mitre_id}` | Authenticated | Detail with associated tests |
| POST | `/api/v1/techniques` | Admin | Create technique |
| PATCH | `/api/v1/techniques/{mitre_id}` | Admin | Update technique fields |
| PATCH | `/api/v1/techniques/{mitre_id}/review` | Lead, Admin | Mark as reviewed |
### Tests
### Tests — Red/Blue Workflow
| Method | Route | Auth | Description |
|--------|-------|------|-------------|
| GET | `/api/v1/tests` | Authenticated | List with filters (state, technique, platform, creator, pending_validation_side) |
| POST | `/api/v1/tests` | Red Tech, Admin | Create test (state=draft) |
| GET | `/api/v1/tests/{id}` | Authenticated | Detail with evidences |
| PATCH | `/api/v1/tests/{id}` | Creator, Admin | Update (only draft/rejected) |
| POST | `/api/v1/tests/{id}/validate` | Lead, Admin | Validate + recalculate technique status |
| POST | `/api/v1/tests/{id}/reject` | Lead, Admin | Reject test |
| POST | `/api/v1/tests/from-template` | Red Tech, Admin | Create from template (pre-populates fields) |
| GET | `/api/v1/tests/{id}` | Authenticated | Detail with split red/blue evidences |
| PATCH | `/api/v1/tests/{id}` | Creator, Admin | General update (draft/rejected only) |
| PATCH | `/api/v1/tests/{id}/red` | Red Tech, Admin | Red Team fields (draft, red_executing) |
| PATCH | `/api/v1/tests/{id}/blue` | Blue Tech, Admin | Blue Team fields (blue_evaluating) |
| PATCH | `/api/v1/tests/{id}/remediation` | Authenticated | Update remediation fields |
| POST | `/api/v1/tests/{id}/start-execution` | Red Tech, Admin | draft → red_executing |
| POST | `/api/v1/tests/{id}/submit-red` | Red Tech, Admin | red_executing → blue_evaluating |
| POST | `/api/v1/tests/{id}/submit-blue` | Blue Tech, Admin | blue_evaluating → in_review |
| POST | `/api/v1/tests/{id}/validate-red` | Red Lead, Admin | Red Lead approves/rejects |
| POST | `/api/v1/tests/{id}/validate-blue` | Blue Lead, Admin | Blue Lead approves/rejects |
| POST | `/api/v1/tests/{id}/reopen` | Lead, Admin | rejected → draft (clears validation) |
| GET | `/api/v1/tests/{id}/timeline` | Authenticated | Audit-log history for this test |
### Test Templates
| Method | Route | Auth | Description |
|--------|-------|------|-------------|
| GET | `/api/v1/test-templates` | Authenticated | List templates (filters: source, platform, severity, search, mitre_technique_id) |
| POST | `/api/v1/test-templates` | Admin | Create custom template |
| GET | `/api/v1/test-templates/stats` | Admin | Catalog statistics |
| GET | `/api/v1/test-templates/{id}` | Authenticated | Template detail |
| PATCH | `/api/v1/test-templates/{id}` | Admin | Update template |
| DELETE | `/api/v1/test-templates/{id}` | Admin | Soft-delete (deactivate) |
| POST | `/api/v1/test-templates/{id}/toggle-active` | Admin | Toggle active/inactive |
### Evidence
| Method | Route | Auth | Description |
|--------|-------|------|-------------|
| POST | `/api/v1/tests/{test_id}/evidence` | Authenticated | Upload evidence file (SHA-256 verified) |
| GET | `/api/v1/evidence/{id}` | Authenticated | Get metadata + presigned download URL |
| POST | `/api/v1/tests/{test_id}/evidence` | Authenticated | Upload evidence (team=red/blue) |
| GET | `/api/v1/evidence/{id}` | Authenticated | Metadata + presigned download URL |
### System
### Notifications
| Method | Route | Auth | Description |
|--------|-------|------|-------------|
| POST | `/api/v1/system/sync-mitre` | Admin | Manually trigger MITRE ATT&CK sync |
| POST | `/api/v1/system/run-intel-scan` | Admin | Manually trigger threat-intel RSS scan |
| GET | `/api/v1/system/scheduler-status` | Admin | Background scheduler health & job list |
| GET | `/api/v1/notifications` | Authenticated | List notifications (paginated, limit=20) |
| GET | `/api/v1/notifications/unread-count` | Authenticated | Unread notification count |
| PATCH | `/api/v1/notifications/{id}/read` | Authenticated | Mark one as read |
| POST | `/api/v1/notifications/read-all` | Authenticated | Mark all as read |
### Reports
| Method | Route | Auth | Description |
|--------|-------|------|-------------|
| GET | `/api/v1/reports/coverage-summary` | Authenticated | Full coverage JSON report (filters: tactic, platform) |
| GET | `/api/v1/reports/coverage-csv` | Authenticated | CSV export of coverage |
| GET | `/api/v1/reports/test-results` | Authenticated | Test results report (filters: state, date_from, date_to) |
| GET | `/api/v1/reports/remediation-status` | Authenticated | Remediation status report (filter: status) |
### Metrics
| Method | Route | Auth | Description |
|--------|-------|------|-------------|
| GET | `/api/v1/metrics/summary` | Authenticated | Global coverage summary (counts + percentage) |
| GET | `/api/v1/metrics/by-tactic` | Authenticated | Coverage breakdown per MITRE tactic |
| GET | `/api/v1/metrics/summary` | Authenticated | Global coverage summary |
| GET | `/api/v1/metrics/by-tactic` | Authenticated | Coverage by MITRE tactic |
| GET | `/api/v1/metrics/test-pipeline` | Authenticated | Test counts by pipeline state |
| GET | `/api/v1/metrics/team-activity` | Authenticated | Red/Blue team activity |
| GET | `/api/v1/metrics/validation-rate` | Authenticated | Approval/rejection rates by lead |
| GET | `/api/v1/metrics/recent-tests` | Authenticated | Last 10 updated tests |
### System (Admin)
| Method | Route | Auth | Description |
|--------|-------|------|-------------|
| POST | `/api/v1/system/sync-mitre` | Admin | Trigger MITRE ATT&CK sync |
| POST | `/api/v1/system/run-intel-scan` | Admin | Trigger threat-intel RSS scan |
| POST | `/api/v1/system/import-atomic-red-team` | Admin | Import Atomic Red Team templates |
| GET | `/api/v1/system/scheduler-status` | Admin | Background scheduler health |
### Users (Admin)
| Method | Route | Auth | Description |
@@ -153,12 +248,12 @@ Once the backend is running, access the interactive API documentation at:
| GET | `/api/v1/users` | Admin | List all users |
| POST | `/api/v1/users` | Admin | Create new user |
| GET | `/api/v1/users/{id}` | Admin | Get user by ID |
| PATCH | `/api/v1/users/{id}` | Admin | Update user (role, email, active status) |
| PATCH | `/api/v1/users/{id}` | Admin | Update user |
### Audit Logs (Admin)
| Method | Route | Auth | Description |
|--------|-------|------|-------------|
| GET | `/api/v1/audit-logs` | Admin | List audit logs (filters: `?action=`, `?entity_type=`, `?start_date=`, `?end_date=`) |
| GET | `/api/v1/audit-logs` | Admin | List audit logs (filters: action, entity_type, dates) |
| GET | `/api/v1/audit-logs/actions` | Admin | List distinct action types |
| GET | `/api/v1/audit-logs/entity-types` | Admin | List distinct entity types |
@@ -166,200 +261,123 @@ Once the backend is running, access the interactive API documentation at:
```
Aegis/
├── docker-compose.yml # Docker services configuration
├── docker-compose.yml
├── backend/
│ ├── Dockerfile # Backend container definition
│ ├── requirements.txt # Python dependencies
│ ├── alembic.ini # Alembic configuration
│ ├── alembic/ # Database migrations
│ │ ├── env.py
│ │ ├── versions/ # Migration files
│ │ └── ...
│ ├── Dockerfile
│ ├── requirements.txt
│ ├── alembic.ini
│ ├── alembic/versions/ # b001b007 migration files
│ └── app/
│ ├── __init__.py
│ ├── main.py # FastAPI application entry point
│ ├── config.py # Application settings
│ ├── database.py # SQLAlchemy configuration
│ ├── auth.py # Password hashing & JWT utilities
│ ├── seed.py # Admin seed script (python -m app.seed)
├── models/ # SQLAlchemy models
│ │ ├── user.py # User authentication model
│ │ ├── technique.py # MITRE ATT&CK techniques
│ │ ├── test.py # Security tests
│ │ ├── evidence.py # Test evidence files
│ │ ├── intel.py # Threat intelligence items
│ │ ├── audit.py # Audit logging
│ │ └── enums.py # Shared enumerations
│ ├── storage.py # MinIO/S3 client (upload, presigned URLs)
├── schemas/ # Pydantic request/response schemas
│ │ ├── auth.py # LoginRequest, TokenResponse, UserOut
│ │ ├── technique.py # TechniqueCreate/Update/Out/Summary
│ │ ── test.py # TestCreate/Update/Out/Validate
│ └── evidence.py # EvidenceOut
├── routers/ # API endpoint routers
│ │ ├── auth.py # POST /auth/login, GET /auth/me
│ │ ├── techniques.py # CRUD techniques (list, detail, create, update, review)
│ │ ├── tests.py # CRUD tests (create, detail, update, validate, reject)
│ │ ├── evidence.py # Upload evidence, presigned download
│ │ ── system.py # MITRE sync trigger, scheduler status
│ ├── metrics.py # Coverage summary & per-tactic breakdown
│ │ ├── users.py # User management (admin only)
│ │ ── audit.py # Audit log viewer (admin only)
├── dependencies/ # FastAPI dependencies (DI)
│ │ └── auth.py # get_current_user, require_role, require_any_role
── jobs/ # Background scheduled jobs
└── mitre_sync_job.py # APScheduler: MITRE sync (24h) + Intel scan (7d)
│ └── services/ # Business logic services
├── audit_service.py
├── status_service.py # Recalculate technique status from tests
├── mitre_sync_service.py # MITRE ATT&CK sync via TAXII / GitHub
└── intel_service.py # Automated intel scan via RSS feeds
└── frontend/ # React + TypeScript frontend
├── index.html
├── package.json
├── tsconfig.json
├── vite.config.ts
└── src/
├── main.tsx # App entry point
├── App.tsx # Route definitions
── index.css # Tailwind CSS entry
── api/ # Axios clients
├── client.ts # Base axios instance with JWT interceptor
├── auth.ts # login(), getMe()
├── metrics.ts # getCoverageSummary(), getCoverageByTactic()
├── techniques.ts # getTechniques(), getTechniqueByMitreId()
├── tests.ts # createTest(), validateTest(), rejectTest()
│ ├── evidence.ts # uploadEvidence(), getEvidence()
│ ├── system.ts # triggerMitreSync(), triggerIntelScan()
│ ├── users.ts # getUsers(), createUser(), updateUser()
│ └── audit.ts # getAuditLogs(), getAuditActions()
├── context/
│ └── AuthContext.tsx # Auth state: user, login, logout, isLoading
├── components/
│ ├── Layout.tsx # Sidebar + header + <Outlet/>
│ ├── Sidebar.tsx # Nav links (role-aware)
│ ├── ProtectedRoute.tsx # Auth route guard with role support
│ ├── CoverageSummaryCard.tsx # Metric card component
│ ├── TacticCoverageChart.tsx # Coverage breakdown table
│ ├── AttackMatrix.tsx # Interactive technique grid
│ ├── TechniqueCell.tsx # Individual technique cell in matrix
│ ├── TestForm.tsx # Reusable test creation/edit form
│ ├── EvidenceUpload.tsx # Drag & drop file upload
│ ├── EvidenceList.tsx # Evidence file listing
│ ├── ErrorBoundary.tsx # Global error boundary
│ ├── ErrorMessage.tsx # Reusable error display
│ ├── LoadingSpinner.tsx # Reusable loading indicator
│ └── Toast.tsx # Toast notification system
├── pages/
│ ├── LoginPage.tsx # User authentication form
│ ├── DashboardPage.tsx # Coverage metrics dashboard with summary cards
│ ├── TechniquesPage.tsx # Interactive ATT&CK matrix view with filters
│ ├── TechniqueDetailPage.tsx # Individual technique detail with tests
│ ├── TestsPage.tsx # Tests overview and navigation
│ ├── TestCreatePage.tsx # Test creation form
│ ├── TestDetailPage.tsx # Test details with evidence upload
│ ├── SystemPage.tsx # Admin panel for MITRE sync & intel scan
│ ├── UsersPage.tsx # User management (admin only)
│ └── AuditLogPage.tsx # Audit log viewer (admin only)
├── types/
│ └── models.ts # TS interfaces matching backend schemas
├── hooks/
└── lib/
│ ├── main.py # FastAPI app with all routers
│ ├── config.py # Settings from environment
│ ├── database.py # SQLAlchemy engine + session
│ ├── storage.py # MinIO/S3 helpers
│ ├── models/
├── user.py # User with roles
│ ├── technique.py # MITRE ATT&CK techniques
│ │ ├── test.py # Tests with Red/Blue + remediation fields
│ │ ├── test_template.py # Template catalog
│ │ ├── evidence.py # Evidence files (team-separated)
│ │ ├── notification.py # In-app notifications
│ │ ├── intel.py # Threat intelligence
│ │ ├── audit.py # Audit logging
│ │ └── enums.py # Shared enumerations
│ ├── schemas/ # Pydantic schemas
│ ├── test.py # TestCreate/Red/Blue/Validate/Remediation
│ │ ├── test_template.py # Template CRUD schemas
│ │ ├── notification.py # NotificationOut, UnreadCountOut
│ │ ── metrics.py # Pipeline, TeamActivity, ValidationRate
├── routers/ # API endpoints
│ ├── tests.py # Full Red/Blue workflow endpoints
│ │ ├── test_templates.py # Template CRUD + import + stats
│ │ ├── notifications.py # Notification list/read/mark
│ │ ├── reports.py # Coverage/results/remediation reports
│ │ ├── metrics.py # V1 + V2 metrics endpoints
│ │ ── ... # auth, techniques, evidence, system, users, audit
├── services/
│ │ ├── test_workflow_service.py # State machine + dual validation
│ │ ── notification_service.py # Create/read/cleanup notifications
│ ├── status_service.py # Technique status recalculation
│ │ └── ... # audit, mitre_sync, intel
── jobs/
└── mitre_sync_job.py # Scheduler: MITRE sync, Intel scan, Notification cleanup
├── frontend/src/
├── App.tsx # Routes including /reports
├── api/ # API clients
├── notifications.ts # Notification API
│ ├── reports.ts # Report API
│ │ └── ...
├── components/
│ │ ├── Layout.tsx # Sidebar + header + NotificationBell
│ │ ├── Sidebar.tsx # Collapsible nav with admin section
│ │ ├── NotificationBell.tsx # Bell icon with badge (polls every 30s)
│ │ ├── NotificationDropdown.tsx # Notification list dropdown
├── ConfirmDialog.tsx # Reusable confirmation modal
├── Toast.tsx # Toast notification system
── test-detail/ # Test detail sub-components
── pages/
├── DashboardPage.tsx # Pipeline funnel, team activity, validation rates
├── TestsPage.tsx # Filters, state counters, pending tasks
├── TestDetailPage.tsx # Red/Blue tabs, validation, evidence
├── TestCatalogPage.tsx # Browse & use templates
├── ReportsPage.tsx # Coverage, results, remediation reports
└── SystemPage.tsx # Template admin, import Atomic Red Team
└── backend/tests/ # Test suite
├── test_workflow.py # Red/Blue workflow tests
├── test_templates_crud.py # Template CRUD tests
├── test_metrics_v2.py # V2 metrics tests
└── test_integration_v2.py # Full integration E2E tests
```
## Database Schema
The platform uses the following data models:
| Table | Description |
|-------|-------------|
| `users` | User accounts with role-based access |
| `techniques` | MITRE ATT&CK techniques with coverage status |
| `tests` | Security tests validating technique coverage |
| `evidences` | File evidence attached to tests (stored in MinIO) |
| `tests` | Security tests with Red/Blue fields, dual validation, and remediation |
| `test_templates` | Predefined test catalog (Atomic Red Team, custom) |
| `evidences` | File evidence separated by team (red/blue) |
| `notifications` | In-app notifications with read status |
| `intel_items` | Threat intelligence items linked to techniques |
| `audit_logs` | System-wide audit trail for all actions |
| `audit_logs` | System-wide audit trail |
## Configuration
The application can be configured via environment variables:
| Variable | Default | Description |
|----------|---------|-------------|
| `DATABASE_URL` | `postgresql://postgres:postgres@postgres:5432/attackdb` | PostgreSQL connection string |
| `DATABASE_URL` | `postgresql://postgres:postgres@postgres:5432/attackdb` | PostgreSQL connection |
| `SECRET_KEY` | `change-me-in-production` | JWT signing key |
| `ALGORITHM` | `HS256` | JWT signing algorithm |
| `ACCESS_TOKEN_EXPIRE_MINUTES` | `60` | JWT token lifetime in minutes |
| `MINIO_ENDPOINT` | `minio:9000` | MinIO server endpoint |
| `ACCESS_TOKEN_EXPIRE_MINUTES` | `60` | Token lifetime |
| `MINIO_ENDPOINT` | `minio:9000` | MinIO server |
| `MINIO_ACCESS_KEY` | `minioadmin` | MinIO access key |
| `MINIO_SECRET_KEY` | `minioadmin` | MinIO secret key |
| `MINIO_BUCKET` | `evidence` | Bucket for evidence files |
| `MINIO_BUCKET` | `evidence` | Evidence bucket |
## Development
### Running Migrations
```bash
# Generate a new migration after model changes
docker exec -w /app aegis-backend-1 alembic revision --autogenerate -m "description"
# Apply migrations
docker exec -w /app aegis-backend-1 alembic upgrade head
# Rollback one migration
docker exec -w /app aegis-backend-1 alembic revision --autogenerate -m "description"
docker exec -w /app aegis-backend-1 alembic downgrade -1
# Check current migration
docker exec -w /app aegis-backend-1 alembic current
```
### Accessing Services
- **MinIO Console**: http://localhost:9001 (login: `minioadmin` / `minioadmin`)
- **PostgreSQL**: `psql -h localhost -p 5433 -U postgres -d attackdb`
### Running Tests
The backend includes a test suite using pytest:
```bash
# Install test dependencies (if running locally)
pip install pytest pytest-asyncio httpx
# Run standalone tests (no database required)
cd backend && python tests/test_workflow.py
cd backend && python tests/test_templates_crud.py
cd backend && python tests/test_metrics_v2.py
cd backend && python tests/test_integration_v2.py
# Run all tests
docker exec -w /app aegis-backend-1 pytest
# Run tests with verbose output
# Run with pytest (requires PostgreSQL)
docker exec -w /app aegis-backend-1 pytest -v
# Run specific test file
docker exec -w /app aegis-backend-1 pytest tests/test_auth.py
# Run locally (requires SQLite)
cd backend && pytest
```
Test files:
- `test_health.py` - Health endpoint tests
- `test_auth.py` - Authentication and authorization tests
- `test_techniques.py` - Technique CRUD tests
- `test_tests.py` - Security test CRUD and validation tests
## User Roles
| Role | Description |
|------|-------------|
| `admin` | Full system access |
| `red_tech` | Red team technician - can create and edit tests |
| `blue_tech` | Blue team technician - can create and edit tests |
| `red_lead` | Red team lead - can validate tests |
| `blue_lead` | Blue team lead - can validate tests |
| `viewer` | Read-only access |
## License
This project is proprietary software. All rights reserved.
## Contributing
Please read the contribution guidelines before submitting pull requests.
+3989
View File
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,46 @@
"""add_notifications_table
Revision ID: b006notifications
Revises: b005v2indexes
Create Date: 2026-02-09 11:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = 'b006notifications'
down_revision: Union[str, Sequence[str], None] = 'b005v2indexes'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Create notifications table."""
op.create_table(
'notifications',
sa.Column('id', UUID(as_uuid=True), primary_key=True, server_default=sa.text('gen_random_uuid()')),
sa.Column('user_id', UUID(as_uuid=True), sa.ForeignKey('users.id'), nullable=False),
sa.Column('type', sa.String(), nullable=False),
sa.Column('title', sa.String(), nullable=False),
sa.Column('message', sa.Text(), nullable=True),
sa.Column('entity_type', sa.String(), nullable=True),
sa.Column('entity_id', UUID(as_uuid=True), nullable=True),
sa.Column('read', sa.Boolean(), server_default='false'),
sa.Column('created_at', sa.DateTime(), server_default=sa.func.now()),
)
op.create_index('ix_notifications_user_id', 'notifications', ['user_id'])
op.create_index('ix_notifications_read', 'notifications', ['read'])
op.create_index('ix_notifications_created_at', 'notifications', ['created_at'])
def downgrade() -> None:
"""Drop notifications table."""
op.drop_index('ix_notifications_created_at', table_name='notifications')
op.drop_index('ix_notifications_read', table_name='notifications')
op.drop_index('ix_notifications_user_id', table_name='notifications')
op.drop_table('notifications')
@@ -0,0 +1,44 @@
"""add_remediation_fields
Revision ID: b007remediation
Revises: b006notifications
Create Date: 2026-02-09 11:30:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = 'b007remediation'
down_revision: Union[str, Sequence[str], None] = 'b006notifications'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Add remediation fields to tests and test_templates."""
# Tests — remediation fields
op.add_column('tests', sa.Column('remediation_steps', sa.Text(), nullable=True))
op.add_column('tests', sa.Column('remediation_status', sa.String(), nullable=True))
op.add_column('tests', sa.Column('remediation_assignee', UUID(as_uuid=True), nullable=True))
op.create_foreign_key(
'fk_tests_remediation_assignee',
'tests', 'users',
['remediation_assignee'], ['id'],
)
# TestTemplates — suggested_remediation
op.add_column('test_templates', sa.Column('suggested_remediation', sa.Text(), nullable=True))
def downgrade() -> None:
"""Remove remediation fields."""
op.drop_column('test_templates', 'suggested_remediation')
op.drop_constraint('fk_tests_remediation_assignee', 'tests', type_='foreignkey')
op.drop_column('tests', 'remediation_assignee')
op.drop_column('tests', 'remediation_status')
op.drop_column('tests', 'remediation_steps')
+23 -1
View File
@@ -17,6 +17,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
from app.database import SessionLocal
from app.services.mitre_sync_service import sync_mitre
from app.services.intel_service import scan_intel
from app.services.notification_service import cleanup_old_notifications
logger = logging.getLogger(__name__)
@@ -45,6 +46,19 @@ def _run_mitre_sync() -> None:
db.close()
def _run_notification_cleanup() -> None:
"""Clean up old read notifications."""
logger.info("Scheduled notification cleanup job starting...")
db = SessionLocal()
try:
deleted = cleanup_old_notifications(db, days=90)
logger.info("Notification cleanup finished — deleted %d old notifications", deleted)
except Exception:
logger.exception("Notification cleanup job failed")
finally:
db.close()
def _run_intel_scan() -> None:
"""Execute an intel scan inside its own DB session."""
logger.info("Scheduled intel scan job starting...")
@@ -89,5 +103,13 @@ def start_scheduler() -> None:
name="Intel scan (every 7d)",
replace_existing=True,
)
scheduler.add_job(
_run_notification_cleanup,
trigger="interval",
hours=24,
id="notification_cleanup",
name="Notification cleanup (daily)",
replace_existing=True,
)
scheduler.start()
logger.info("Background scheduler started — mitre_sync (24h), intel_scan (7d)")
logger.info("Background scheduler started — mitre_sync (24h), intel_scan (7d), notification_cleanup (24h)")
+4
View File
@@ -16,6 +16,8 @@ from app.routers import system as system_router
from app.routers import metrics as metrics_router
from app.routers import users as users_router
from app.routers import audit as audit_router
from app.routers import notifications as notifications_router
from app.routers import reports as reports_router
from app.storage import ensure_bucket_exists
from app.jobs.mitre_sync_job import start_scheduler, scheduler
@@ -56,6 +58,8 @@ app.include_router(system_router.router, prefix="/api/v1")
app.include_router(metrics_router.router, prefix="/api/v1")
app.include_router(users_router.router, prefix="/api/v1")
app.include_router(audit_router.router, prefix="/api/v1")
app.include_router(notifications_router.router, prefix="/api/v1")
app.include_router(reports_router.router, prefix="/api/v1")
@app.get("/health")
+2 -1
View File
@@ -6,10 +6,11 @@ from app.models.test_template import TestTemplate
from app.models.evidence import Evidence
from app.models.intel import IntelItem
from app.models.audit import AuditLog
from app.models.notification import Notification
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
__all__ = [
"User", "Technique", "Test", "TestTemplate", "Evidence",
"IntelItem", "AuditLog",
"IntelItem", "AuditLog", "Notification",
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
]
+39
View File
@@ -0,0 +1,39 @@
"""Notification model — in-app notifications for user actions."""
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Index
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.database import Base
class Notification(Base):
"""
In-app notification for alerting users when they need to act.
Types include: test_assigned, validation_needed, test_rejected,
test_validated, test_state_changed, etc.
"""
__tablename__ = "notifications"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
type = Column(String, nullable=False)
title = Column(String, nullable=False)
message = Column(Text, nullable=True)
entity_type = Column(String, nullable=True)
entity_id = Column(UUID(as_uuid=True), nullable=True)
read = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.utcnow)
# Relationships
user = relationship("User")
__table_args__ = (
Index("ix_notifications_user_id", "user_id"),
Index("ix_notifications_read", "read"),
Index("ix_notifications_created_at", "created_at"),
)
+6
View File
@@ -49,9 +49,15 @@ class Test(Base):
blue_validation_status = Column(String, nullable=True) # pending / approved / rejected
blue_validation_notes = Column(Text, nullable=True)
# ── Remediation fields ───────────────────────────────────────────
remediation_steps = Column(Text, nullable=True)
remediation_status = Column(String, nullable=True) # pending / in_progress / completed / not_applicable
remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
# ── Relationships ───────────────────────────────────────────────
technique = relationship("Technique", back_populates="tests")
evidences = relationship("Evidence", back_populates="test")
creator = relationship("User", foreign_keys=[created_by])
red_validator = relationship("User", foreign_keys=[red_validated_by])
blue_validator = relationship("User", foreign_keys=[blue_validated_by])
remediation_user = relationship("User", foreign_keys=[remediation_assignee])
+1
View File
@@ -34,6 +34,7 @@ class TestTemplate(Base):
tool_suggested = Column(String, nullable=True)
severity = Column(String, nullable=True) # low / medium / high / critical
atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team repo
suggested_remediation = Column(Text, nullable=True)
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
+103
View File
@@ -0,0 +1,103 @@
"""Notification endpoints.
Endpoints
---------
GET /notifications list user notifications (paginated)
GET /notifications/unread-count count of unread notifications
PATCH /notifications/{id}/read mark one notification as read
POST /notifications/read-all mark all as read
"""
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user
from app.models.notification import Notification
from app.models.user import User
from app.schemas.notification import NotificationOut, UnreadCountOut
from app.services.notification_service import (
mark_as_read,
mark_all_as_read,
get_unread_count,
)
router = APIRouter(prefix="/notifications", tags=["notifications"])
# ---------------------------------------------------------------------------
# GET /notifications — list (paginated)
# ---------------------------------------------------------------------------
@router.get("", response_model=list[NotificationOut])
def list_notifications(
offset: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return paginated notifications for the current user, newest first."""
notifs = (
db.query(Notification)
.filter(Notification.user_id == current_user.id)
.order_by(Notification.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return notifs
# ---------------------------------------------------------------------------
# GET /notifications/unread-count
# ---------------------------------------------------------------------------
@router.get("/unread-count", response_model=UnreadCountOut)
def unread_count(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return the number of unread notifications for the current user."""
count = get_unread_count(db, current_user.id)
return UnreadCountOut(unread_count=count)
# ---------------------------------------------------------------------------
# PATCH /notifications/{id}/read
# ---------------------------------------------------------------------------
@router.patch("/{notification_id}/read", response_model=NotificationOut)
def read_notification(
notification_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Mark a single notification as read."""
success = mark_as_read(db, notification_id, current_user.id)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Notification not found",
)
notif = db.query(Notification).filter(Notification.id == notification_id).first()
return notif
# ---------------------------------------------------------------------------
# POST /notifications/read-all
# ---------------------------------------------------------------------------
@router.post("/read-all")
def read_all_notifications(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Mark all notifications for the current user as read."""
count = mark_all_as_read(db, current_user.id)
return {"detail": f"Marked {count} notifications as read"}
+270
View File
@@ -0,0 +1,270 @@
"""Reports endpoints — export coverage summaries and test results.
Endpoints
---------
GET /reports/coverage-summary full coverage JSON report
GET /reports/coverage-csv CSV export of coverage
GET /reports/test-results test results report (JSON)
GET /reports/remediation-status remediation status report (JSON)
"""
import csv
import io
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user
from app.models.enums import TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.user import User
router = APIRouter(prefix="/reports", tags=["reports"])
# ---------------------------------------------------------------------------
# GET /reports/coverage-summary
# ---------------------------------------------------------------------------
@router.get("/coverage-summary")
def coverage_summary(
tactic: Optional[str] = Query(None, description="Filter by tactic"),
platform: Optional[str] = Query(None, description="Filter by platform (in techniques)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Full coverage report as JSON — technique-by-technique with test counts."""
query = db.query(Technique)
if tactic:
query = query.filter(Technique.tactic.ilike(f"%{tactic}%"))
techniques = query.order_by(Technique.mitre_id).all()
rows = []
for t in techniques:
# Count tests per state for this technique
test_counts = (
db.query(Test.state, func.count(Test.id))
.filter(Test.technique_id == t.id)
.group_by(Test.state)
.all()
)
counts = {str(state): count for state, count in test_counts}
# Filter by platform if requested (check if technique platforms contain it)
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
continue
rows.append({
"mitre_id": t.mitre_id,
"name": t.name,
"tactic": t.tactic,
"platforms": t.platforms,
"status_global": t.status_global,
"total_tests": sum(counts.values()),
"tests_by_state": counts,
})
total = len(rows)
validated = sum(1 for r in rows if r["status_global"] == "validated")
partial = sum(1 for r in rows if r["status_global"] == "partial")
not_covered = sum(1 for r in rows if r["status_global"] == "not_covered")
in_progress = sum(1 for r in rows if r["status_global"] == "in_progress")
not_evaluated = sum(1 for r in rows if r["status_global"] == "not_evaluated")
return {
"generated_at": datetime.utcnow().isoformat(),
"summary": {
"total_techniques": total,
"validated": validated,
"partial": partial,
"not_covered": not_covered,
"in_progress": in_progress,
"not_evaluated": not_evaluated,
"coverage_percentage": round((validated / total * 100) if total > 0 else 0, 1),
},
"techniques": rows,
}
# ---------------------------------------------------------------------------
# GET /reports/coverage-csv
# ---------------------------------------------------------------------------
@router.get("/coverage-csv")
def coverage_csv(
tactic: Optional[str] = Query(None),
platform: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Export coverage as a downloadable CSV."""
query = db.query(Technique)
if tactic:
query = query.filter(Technique.tactic.ilike(f"%{tactic}%"))
techniques = query.order_by(Technique.mitre_id).all()
output = io.StringIO()
writer = csv.writer(output)
writer.writerow([
"MITRE ID", "Name", "Tactic", "Platforms", "Status",
"Total Tests", "Validated", "In Progress", "Not Covered",
])
for t in techniques:
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
continue
test_counts = (
db.query(Test.state, func.count(Test.id))
.filter(Test.technique_id == t.id)
.group_by(Test.state)
.all()
)
counts = {str(state): count for state, count in test_counts}
writer.writerow([
t.mitre_id,
t.name,
t.tactic,
", ".join(t.platforms or []),
t.status_global,
sum(counts.values()),
counts.get("validated", 0),
sum(counts.get(s, 0) for s in ["draft", "red_executing", "blue_evaluating", "in_review"]),
counts.get("rejected", 0),
])
output.seek(0)
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": f"attachment; filename=aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"},
)
# ---------------------------------------------------------------------------
# GET /reports/test-results
# ---------------------------------------------------------------------------
@router.get("/test-results")
def test_results(
state: Optional[str] = Query(None),
date_from: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
date_to: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Report of test results with optional filters."""
query = db.query(Test)
if state:
query = query.filter(Test.state == state)
if date_from:
try:
dt = datetime.fromisoformat(date_from)
query = query.filter(Test.created_at >= dt)
except ValueError:
pass
if date_to:
try:
dt = datetime.fromisoformat(date_to)
query = query.filter(Test.created_at <= dt)
except ValueError:
pass
tests = query.order_by(Test.created_at.desc()).all()
# Summary
total = len(tests)
by_state = {}
by_result = {}
for t in tests:
s = t.state.value if hasattr(t.state, "value") else str(t.state)
by_state[s] = by_state.get(s, 0) + 1
if t.detection_result:
r = t.detection_result.value if hasattr(t.detection_result, "value") else str(t.detection_result)
by_result[r] = by_result.get(r, 0) + 1
return {
"generated_at": datetime.utcnow().isoformat(),
"filters": {"state": state, "date_from": date_from, "date_to": date_to},
"summary": {
"total_tests": total,
"by_state": by_state,
"by_detection_result": by_result,
},
"tests": [
{
"id": str(t.id),
"name": t.name,
"technique_id": str(t.technique_id),
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
"platform": t.platform,
"attack_success": t.attack_success,
"detection_result": (
t.detection_result.value if t.detection_result and hasattr(t.detection_result, "value")
else str(t.detection_result) if t.detection_result else None
),
"red_validation_status": t.red_validation_status,
"blue_validation_status": t.blue_validation_status,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in tests
],
}
# ---------------------------------------------------------------------------
# GET /reports/remediation-status
# ---------------------------------------------------------------------------
@router.get("/remediation-status")
def remediation_status(
status: Optional[str] = Query(None, description="Filter by remediation status"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Report of remediation status across all tests."""
query = db.query(Test).filter(Test.remediation_steps.isnot(None))
if status:
query = query.filter(Test.remediation_status == status)
tests = query.order_by(Test.created_at.desc()).all()
by_status = {}
for t in tests:
s = t.remediation_status or "unset"
by_status[s] = by_status.get(s, 0) + 1
return {
"generated_at": datetime.utcnow().isoformat(),
"summary": {
"total_with_remediation": len(tests),
"by_status": by_status,
},
"tests": [
{
"id": str(t.id),
"name": t.name,
"technique_id": str(t.technique_id),
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
"remediation_status": t.remediation_status,
"remediation_steps": t.remediation_steps,
"remediation_assignee": str(t.remediation_assignee) if t.remediation_assignee else None,
}
for t in tests
],
}
+52 -4
View File
@@ -40,6 +40,7 @@ from app.schemas.test import (
TestBlueUpdate,
TestRedValidate,
TestBlueValidate,
TestRemediationUpdate,
)
from app.schemas.test_template import TestTemplateInstantiate
from app.services.audit_service import log_action
@@ -211,6 +212,7 @@ def create_test_from_template(
platform=template.platform,
procedure_text=template.attack_procedure,
tool_used=template.tool_suggested,
remediation_steps=template.suggested_remediation,
created_by=current_user.id,
state=TestState.draft,
)
@@ -284,13 +286,17 @@ def update_test(
if current_user.role != "admin" and test.created_by != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions",
detail={"message": "Only the test creator or an admin can update this test", "code": "FORBIDDEN"},
)
if test.state not in (TestState.draft, TestState.rejected):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)",
detail={
"message": f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)",
"code": "INVALID_STATE",
"current_state": test.state.value,
},
)
update_data = payload.model_dump(exclude_unset=True)
@@ -330,7 +336,11 @@ def update_test_red(
if test.state not in (TestState.draft, TestState.red_executing):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot update red fields in '{test.state.value}' state (must be draft or red_executing)",
detail={
"message": f"Cannot update red fields in '{test.state.value}' state (must be draft or red_executing)",
"code": "INVALID_STATE",
"current_state": test.state.value,
},
)
update_data = payload.model_dump(exclude_unset=True)
@@ -370,7 +380,11 @@ def update_test_blue(
if test.state != TestState.blue_evaluating:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot update blue fields in '{test.state.value}' state (must be blue_evaluating)",
detail={
"message": f"Cannot update blue fields in '{test.state.value}' state (must be blue_evaluating)",
"code": "INVALID_STATE",
"current_state": test.state.value,
},
)
update_data = payload.model_dump(exclude_unset=True)
@@ -520,6 +534,40 @@ def reopen(
return test
# ---------------------------------------------------------------------------
# PATCH /tests/{id}/remediation — update remediation fields
# ---------------------------------------------------------------------------
@router.patch("/{test_id}/remediation", response_model=TestOut)
def update_remediation(
test_id: uuid.UUID,
payload: TestRemediationUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Update remediation fields on a test (any authenticated user)."""
test = _get_test_or_404(db, test_id)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
db.commit()
db.refresh(test)
log_action(
db,
user_id=current_user.id,
action="update_remediation",
entity_type="test",
entity_id=test.id,
details={"updated_fields": list(update_data.keys())},
)
return test
# ---------------------------------------------------------------------------
# GET /tests/{id}/timeline — audit history for this test
# ---------------------------------------------------------------------------
+28
View File
@@ -0,0 +1,28 @@
"""Pydantic schemas for Notification endpoints."""
import uuid
from datetime import datetime
from pydantic import BaseModel, ConfigDict
class NotificationOut(BaseModel):
"""Notification returned by the API."""
id: uuid.UUID
user_id: uuid.UUID
type: str
title: str
message: str | None = None
entity_type: str | None = None
entity_id: uuid.UUID | None = None
read: bool = False
created_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
class UnreadCountOut(BaseModel):
"""Simple counter response."""
unread_count: int
+16
View File
@@ -81,6 +81,17 @@ class TestBlueValidate(BaseModel):
blue_validation_notes: str | None = None
# ── Remediation update ────────────────────────────────────────────
class TestRemediationUpdate(BaseModel):
"""Payload for updating remediation fields."""
remediation_steps: str | None = None
remediation_status: str | None = None # pending / in_progress / completed / not_applicable
remediation_assignee: uuid.UUID | None = None
# ── Legacy validate (kept for backwards compat) ────────────────────
@@ -126,6 +137,11 @@ class TestOut(BaseModel):
blue_validation_status: str | None = None
blue_validation_notes: str | None = None
# Remediation fields
remediation_steps: str | None = None
remediation_status: str | None = None
remediation_assignee: uuid.UUID | None = None
# Technique info (populated when joined)
technique_mitre_id: str | None = None
technique_name: str | None = None
+2
View File
@@ -24,6 +24,7 @@ class TestTemplateOut(BaseModel):
tool_suggested: str | None = None
severity: str | None = None
atomic_test_id: str | None = None
suggested_remediation: str | None = None
is_active: bool = True
created_at: datetime | None = None
@@ -47,6 +48,7 @@ class TestTemplateCreate(BaseModel):
tool_suggested: str | None = None
severity: str | None = None
atomic_test_id: str | None = None
suggested_remediation: str | None = None
# ── Summary (for listings) ─────────────────────────────────────────
@@ -0,0 +1,179 @@
"""Notification service — create, read, and manage in-app notifications.
Provides helpers for generating notifications automatically when test
state changes occur, plus CRUD for the notifications API.
"""
import uuid
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.models.notification import Notification
from app.models.user import User
# ---------------------------------------------------------------------------
# Core CRUD
# ---------------------------------------------------------------------------
def create_notification(
db: Session,
user_id: uuid.UUID,
type: str,
title: str,
message: str | None = None,
entity_type: str | None = None,
entity_id: uuid.UUID | None = None,
) -> Notification:
"""Create a single notification for a user."""
notif = Notification(
user_id=user_id,
type=type,
title=title,
message=message,
entity_type=entity_type,
entity_id=entity_id,
)
db.add(notif)
db.commit()
db.refresh(notif)
return notif
def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) -> bool:
"""Mark a single notification as read. Returns True if updated."""
notif = (
db.query(Notification)
.filter(Notification.id == notification_id, Notification.user_id == user_id)
.first()
)
if notif is None:
return False
notif.read = True
db.commit()
return True
def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
"""Mark all unread notifications for a user as read. Returns count updated."""
count = (
db.query(Notification)
.filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712
.update({"read": True})
)
db.commit()
return count
def get_unread_count(db: Session, user_id: uuid.UUID) -> int:
"""Return the number of unread notifications for a user."""
return (
db.query(func.count(Notification.id))
.filter(Notification.user_id == user_id, Notification.read == False) # noqa: E712
.scalar()
) or 0
def cleanup_old_notifications(db: Session, days: int = 90) -> int:
"""Delete read notifications older than *days*. Returns count deleted."""
cutoff = datetime.utcnow() - timedelta(days=days)
count = (
db.query(Notification)
.filter(
Notification.read == True, # noqa: E712
Notification.created_at < cutoff,
)
.delete()
)
db.commit()
return count
# ---------------------------------------------------------------------------
# Automatic notification dispatchers
# ---------------------------------------------------------------------------
def notify_test_state_change(db: Session, test, new_state: str) -> None:
"""Dispatch notifications based on a test's new state.
Called by the workflow service after each state transition.
Rules:
- red_executing -> notify creator (confirmation)
- blue_evaluating -> notify all blue_tech users
- in_review -> notify red_lead and blue_lead users
- rejected -> notify creator
- validated -> notify creator
"""
test_name = test.name
test_id = test.id
creator_id = test.created_by
if new_state == "red_executing" and creator_id:
create_notification(
db,
user_id=creator_id,
type="test_state_changed",
title="Test execution started",
message=f'Your test "{test_name}" has moved to execution phase.',
entity_type="test",
entity_id=test_id,
)
elif new_state == "blue_evaluating":
# Notify all blue_tech users
blue_users = db.query(User).filter(User.role == "blue_tech", User.is_active == True).all() # noqa: E712
for user in blue_users:
create_notification(
db,
user_id=user.id,
type="test_assigned",
title="New test ready for blue evaluation",
message=f'Test "{test_name}" needs blue team evaluation.',
entity_type="test",
entity_id=test_id,
)
elif new_state == "in_review":
# Notify red_lead and blue_lead users
managers = (
db.query(User)
.filter(User.role.in_(["red_lead", "blue_lead"]), User.is_active == True) # noqa: E712
.all()
)
for user in managers:
create_notification(
db,
user_id=user.id,
type="validation_needed",
title="Test ready for validation",
message=f'Test "{test_name}" is awaiting your review.',
entity_type="test",
entity_id=test_id,
)
elif new_state == "rejected" and creator_id:
create_notification(
db,
user_id=creator_id,
type="test_rejected",
title="Test rejected",
message=f'Your test "{test_name}" has been rejected. Please review and resubmit.',
entity_type="test",
entity_id=test_id,
)
elif new_state == "validated" and creator_id:
create_notification(
db,
user_id=creator_id,
type="test_validated",
title="Test validated",
message=f'Your test "{test_name}" has been validated successfully.',
entity_type="test",
entity_id=test_id,
)
+47 -9
View File
@@ -20,6 +20,7 @@ from app.models.enums import TestState
from app.models.test import Test
from app.models.user import User
from app.services.audit_service import log_action
from app.services.notification_service import notify_test_state_change
# ---------------------------------------------------------------------------
# Valid transition map
@@ -60,13 +61,20 @@ def transition_state(
Raises :class:`~fastapi.HTTPException` 400 when the transition is invalid.
"""
if not can_transition(test, target_state):
current = test.state if isinstance(test.state, TestState) else TestState(test.state)
valid = [s.value for s in VALID_TRANSITIONS.get(current, [])]
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
f"Invalid transition: cannot move from "
f"'{test.state.value if isinstance(test.state, TestState) else test.state}' "
f"to '{target_state.value}'"
),
detail={
"message": (
f"Cannot transition from '{current.value}' to '{target_state.value}'. "
f"Valid transitions: {valid}"
),
"code": "INVALID_TRANSITION",
"current_state": current.value,
"target_state": target_state.value,
"valid_transitions": valid,
},
)
previous_state = test.state.value if isinstance(test.state, TestState) else test.state
@@ -91,6 +99,12 @@ def transition_state(
details=details,
)
# Dispatch in-app notifications for the new state
try:
notify_test_state_change(db, test, target_state.value)
except Exception:
pass # Notifications are best-effort — don't block the workflow
return test
@@ -152,16 +166,24 @@ def validate_as_red_lead(
After recording the decision, :func:`check_dual_validation` is called
to potentially advance the test to ``validated`` or ``rejected``.
"""
current = test.state.value if isinstance(test.state, TestState) else test.state
if test.state not in (TestState.in_review,):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot validate red side while test is in '{test.state.value if isinstance(test.state, TestState) else test.state}' state (must be in_review)",
detail={
"message": f"Cannot validate red side while test is in '{current}' state (must be in_review)",
"code": "INVALID_STATE",
"current_state": current,
},
)
if validation_status not in ("approved", "rejected"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="validation_status must be 'approved' or 'rejected'",
detail={
"message": "validation_status must be 'approved' or 'rejected'",
"code": "INVALID_VALIDATION_STATUS",
},
)
now = datetime.utcnow()
@@ -200,16 +222,24 @@ def validate_as_blue_lead(
After recording the decision, :func:`check_dual_validation` is called
to potentially advance the test to ``validated`` or ``rejected``.
"""
current = test.state.value if isinstance(test.state, TestState) else test.state
if test.state not in (TestState.in_review,):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot validate blue side while test is in '{test.state.value if isinstance(test.state, TestState) else test.state}' state (must be in_review)",
detail={
"message": f"Cannot validate blue side while test is in '{current}' state (must be in_review)",
"code": "INVALID_STATE",
"current_state": current,
},
)
if validation_status not in ("approved", "rejected"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="validation_status must be 'approved' or 'rejected'",
detail={
"message": "validation_status must be 'approved' or 'rejected'",
"code": "INVALID_VALIDATION_STATUS",
},
)
now = datetime.utcnow()
@@ -250,9 +280,17 @@ def check_dual_validation(db: Session, test: Test) -> Test:
if red_status == "rejected" or blue_status == "rejected":
test.state = TestState.rejected
db.commit()
try:
notify_test_state_change(db, test, "rejected")
except Exception:
pass
elif red_status == "approved" and blue_status == "approved":
test.state = TestState.validated
db.commit()
try:
notify_test_state_change(db, test, "validated")
except Exception:
pass
else:
# One side hasn't voted yet — stay in_review, just flush
db.commit()
+96
View File
@@ -86,6 +86,54 @@ def red_tech_user(db):
return user
@pytest.fixture(scope="function")
def blue_tech_user(db):
"""Create a blue_tech user for testing."""
user = User(
username="bluetech",
email="bluetech@test.com",
hashed_password=hash_password("bluetech123"),
role="blue_tech",
is_active=True,
)
db.add(user)
db.commit()
db.refresh(user)
return user
@pytest.fixture(scope="function")
def red_lead_user(db):
"""Create a red_lead user for testing."""
user = User(
username="redlead",
email="redlead@test.com",
hashed_password=hash_password("redlead123"),
role="red_lead",
is_active=True,
)
db.add(user)
db.commit()
db.refresh(user)
return user
@pytest.fixture(scope="function")
def blue_lead_user(db):
"""Create a blue_lead user for testing."""
user = User(
username="bluelead",
email="bluelead@test.com",
hashed_password=hash_password("bluelead123"),
role="blue_lead",
is_active=True,
)
db.add(user)
db.commit()
db.refresh(user)
return user
@pytest.fixture(scope="function")
def admin_token(client, admin_user):
"""Get an auth token for the admin user."""
@@ -116,3 +164,51 @@ def auth_headers(admin_token):
def red_tech_headers(red_tech_token):
"""Return authorization headers for red_tech user."""
return {"Authorization": f"Bearer {red_tech_token}"}
@pytest.fixture(scope="function")
def blue_tech_token(client, blue_tech_user):
"""Get an auth token for the blue_tech user."""
response = client.post(
"/api/v1/auth/login",
data={"username": "bluetech", "password": "bluetech123"},
)
return response.json()["access_token"]
@pytest.fixture(scope="function")
def blue_tech_headers(blue_tech_token):
"""Return authorization headers for blue_tech user."""
return {"Authorization": f"Bearer {blue_tech_token}"}
@pytest.fixture(scope="function")
def red_lead_token(client, red_lead_user):
"""Get an auth token for the red_lead user."""
response = client.post(
"/api/v1/auth/login",
data={"username": "redlead", "password": "redlead123"},
)
return response.json()["access_token"]
@pytest.fixture(scope="function")
def red_lead_headers(red_lead_token):
"""Return authorization headers for red_lead user."""
return {"Authorization": f"Bearer {red_lead_token}"}
@pytest.fixture(scope="function")
def blue_lead_token(client, blue_lead_user):
"""Get an auth token for the blue_lead user."""
response = client.post(
"/api/v1/auth/login",
data={"username": "bluelead", "password": "bluelead123"},
)
return response.json()["access_token"]
@pytest.fixture(scope="function")
def blue_lead_headers(blue_lead_token):
"""Return authorization headers for blue_lead user."""
return {"Authorization": f"Bearer {blue_lead_token}"}
+696
View File
@@ -0,0 +1,696 @@
"""T-134: Final integration tests for V2 — end-to-end flows.
Covers:
- Full E2E flow: import template -> create test -> execute -> evaluate -> validate
- Rejection/recovery flow
- Notification generation during state changes
- Metrics accuracy after operations
- Report generation
- Remediation field management
Uses mock objects to test the workflow service and router logic
without requiring a running database.
"""
import sys
import os
import uuid
import inspect
from unittest.mock import MagicMock, patch
from types import ModuleType
from datetime import datetime, timedelta
# ---------------------------------------------------------------------------
# Stub heavy dependencies before importing app modules
# ---------------------------------------------------------------------------
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
if "pydantic_settings" not in sys.modules:
_ps = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs): pass
def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs)
_ps.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = _ps
if "app.config" not in sys.modules:
_cfg = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
_cfg.settings = _FakeSettings()
sys.modules["app.config"] = _cfg
if "app.database" not in sys.modules:
_db = ModuleType("app.database")
_db.Base = type("Base", (), {"metadata": MagicMock()})
_db.get_db = MagicMock()
_db.SessionLocal = MagicMock()
sys.modules["app.database"] = _db
# Stub jose with JWTError
if "jose" not in sys.modules:
_jose = ModuleType("jose")
class _JWTError(Exception): pass
_jose.JWTError = _JWTError
_jose.jwt = MagicMock()
sys.modules["jose"] = _jose
# Stub apscheduler
for _mod in ["apscheduler", "apscheduler.schedulers", "apscheduler.triggers", "apscheduler.triggers.cron"]:
if _mod not in sys.modules:
sys.modules[_mod] = ModuleType(_mod)
if "apscheduler.schedulers.background" not in sys.modules:
_apsched = ModuleType("apscheduler.schedulers.background")
class _FakeBGScheduler:
def add_job(self, *a, **kw): pass
def start(self): pass
def shutdown(self, **kw): pass
_apsched.BackgroundScheduler = _FakeBGScheduler
sys.modules["apscheduler.schedulers.background"] = _apsched
if "taxii2client" not in sys.modules:
sys.modules["taxii2client"] = ModuleType("taxii2client")
if "taxii2client.v20" not in sys.modules:
_tv20 = ModuleType("taxii2client.v20")
_tv20.Server = MagicMock
_tv20.Collection = MagicMock
sys.modules["taxii2client.v20"] = _tv20
for _mod in [
"boto3", "botocore", "botocore.exceptions",
"passlib", "passlib.context",
]:
if _mod not in sys.modules:
sys.modules[_mod] = ModuleType(_mod)
# Now safe to import
from app.models.enums import TestState, TestResult, TechniqueStatus
from app.services.test_workflow_service import (
can_transition,
VALID_TRANSITIONS,
transition_state,
start_execution,
submit_red_evidence,
submit_blue_evidence,
validate_as_red_lead,
validate_as_blue_lead,
check_dual_validation,
reopen_test,
)
from app.services.notification_service import (
create_notification,
mark_as_read,
mark_all_as_read,
get_unread_count,
cleanup_old_notifications,
notify_test_state_change,
)
passed = 0
failed = 0
def _make_test(**overrides):
"""Create a mock Test object with sensible defaults."""
t = MagicMock()
t.id = overrides.get("id", uuid.uuid4())
t.name = overrides.get("name", "Integration Test")
t.technique_id = overrides.get("technique_id", uuid.uuid4())
t.created_by = overrides.get("created_by", uuid.uuid4())
t.state = overrides.get("state", TestState.draft)
t.red_validation_status = overrides.get("red_validation_status", None)
t.blue_validation_status = overrides.get("blue_validation_status", None)
t.red_validated_by = None
t.red_validated_at = None
t.red_validation_notes = None
t.blue_validated_by = None
t.blue_validated_at = None
t.blue_validation_notes = None
t.attack_success = None
t.detection_result = None
t.remediation_steps = None
t.remediation_status = None
t.remediation_assignee = None
for k, v in overrides.items():
setattr(t, k, v)
return t
def _make_user(role="admin"):
u = MagicMock()
u.id = uuid.uuid4()
u.role = role
u.is_active = True
u.username = f"test_{role}"
return u
# ===========================================================================
# TEST 1 — Full E2E happy path through workflow
# ===========================================================================
def test_full_e2e_flow():
"""Full lifecycle: draft → red_executing → blue_evaluating → in_review → validated"""
global passed, failed
try:
db = MagicMock()
test = _make_test(state=TestState.draft)
red_tech = _make_user("red_tech")
blue_tech = _make_user("blue_tech")
red_lead = _make_user("red_lead")
blue_lead = _make_user("blue_lead")
# draft -> red_executing
assert can_transition(test, TestState.red_executing)
test.state = TestState.red_executing
# red_executing -> blue_evaluating
assert can_transition(test, TestState.blue_evaluating)
test.state = TestState.blue_evaluating
# blue_evaluating -> in_review
assert can_transition(test, TestState.in_review)
test.state = TestState.in_review
# Both leads approve → validated
test.red_validation_status = "approved"
test.blue_validation_status = "approved"
check_dual_validation(db, test)
assert test.state == TestState.validated
print(" PASS: test_full_e2e_flow")
passed += 1
except Exception as e:
print(f" FAIL: test_full_e2e_flow — {e}")
failed += 1
# ===========================================================================
# TEST 2 — Rejection and recovery flow
# ===========================================================================
def test_rejection_recovery_flow():
"""in_review → rejected → draft → start over"""
global passed, failed
try:
db = MagicMock()
test = _make_test(state=TestState.in_review)
# Red lead rejects
test.red_validation_status = "rejected"
test.blue_validation_status = None
check_dual_validation(db, test)
assert test.state == TestState.rejected
# Reopen: rejected → draft
assert can_transition(test, TestState.draft)
test.state = TestState.draft
test.red_validation_status = None
test.blue_validation_status = None
# Can restart: draft → red_executing
assert can_transition(test, TestState.red_executing)
print(" PASS: test_rejection_recovery_flow")
passed += 1
except Exception as e:
print(f" FAIL: test_rejection_recovery_flow — {e}")
failed += 1
# ===========================================================================
# TEST 3 — Notification dispatch on state changes
# ===========================================================================
def test_notification_dispatching():
"""Verify notifications are dispatched for key state changes."""
global passed, failed
try:
db = MagicMock()
test = _make_test(state=TestState.blue_evaluating)
# Check the function can call create_notification
src = inspect.getsource(notify_test_state_change)
assert "blue_evaluating" in src, "Should handle blue_evaluating state"
assert "in_review" in src, "Should handle in_review state"
assert "rejected" in src, "Should handle rejected state"
assert "validated" in src, "Should handle validated state"
assert "create_notification" in src, "Should call create_notification"
assert "blue_tech" in src, "Should notify blue_tech users"
assert "red_lead" in src or "blue_lead" in src, "Should notify leads"
print(" PASS: test_notification_dispatching")
passed += 1
except Exception as e:
print(f" FAIL: test_notification_dispatching — {e}")
failed += 1
# ===========================================================================
# TEST 4 — Notification cleanup service
# ===========================================================================
def test_notification_cleanup():
"""cleanup_old_notifications deletes read notifications older than cutoff."""
global passed, failed
try:
src = inspect.getsource(cleanup_old_notifications)
assert "timedelta" in src, "Should use timedelta for cutoff"
assert "read" in src.lower(), "Should filter by read status"
assert "delete" in src, "Should call delete()"
print(" PASS: test_notification_cleanup")
passed += 1
except Exception as e:
print(f" FAIL: test_notification_cleanup — {e}")
failed += 1
# ===========================================================================
# TEST 5 — Metrics endpoints exist
# ===========================================================================
def test_metrics_endpoints_exist():
"""Verify V2 metrics endpoints are registered."""
global passed, failed
try:
from app.routers import metrics
src = inspect.getsource(metrics)
assert "test-pipeline" in src, "Should have /metrics/test-pipeline"
assert "team-activity" in src, "Should have /metrics/team-activity"
assert "validation-rate" in src, "Should have /metrics/validation-rate"
assert "recent-tests" in src, "Should have /metrics/recent-tests"
print(" PASS: test_metrics_endpoints_exist")
passed += 1
except Exception as e:
print(f" FAIL: test_metrics_endpoints_exist — {e}")
failed += 1
# ===========================================================================
# TEST 6 — Reports endpoints exist
# ===========================================================================
def test_reports_endpoints_exist():
"""Verify report endpoints are registered."""
global passed, failed
try:
from app.routers import reports
src = inspect.getsource(reports)
assert "coverage-summary" in src, "Should have /reports/coverage-summary"
assert "coverage-csv" in src, "Should have /reports/coverage-csv"
assert "test-results" in src, "Should have /reports/test-results"
assert "remediation-status" in src, "Should have /reports/remediation-status"
assert "StreamingResponse" in src, "Should use StreamingResponse for CSV"
print(" PASS: test_reports_endpoints_exist")
passed += 1
except Exception as e:
print(f" FAIL: test_reports_endpoints_exist — {e}")
failed += 1
# ===========================================================================
# TEST 7 — Report filtering
# ===========================================================================
def test_report_filtering_logic():
"""Reports support tactic, platform, state and date filters."""
global passed, failed
try:
from app.routers import reports
src = inspect.getsource(reports)
assert "tactic" in src, "Should filter by tactic"
assert "platform" in src, "Should filter by platform"
assert "date_from" in src, "Should filter by date_from"
assert "date_to" in src, "Should filter by date_to"
assert "remediation_status" in src, "Should filter remediation by status"
print(" PASS: test_report_filtering_logic")
passed += 1
except Exception as e:
print(f" FAIL: test_report_filtering_logic — {e}")
failed += 1
# ===========================================================================
# TEST 8 — Remediation fields in Test model
# ===========================================================================
def test_remediation_fields():
"""Test model includes remediation_steps, remediation_status, remediation_assignee."""
global passed, failed
try:
from app.models.test import Test
src = inspect.getsource(Test)
assert "remediation_steps" in src, "Should have remediation_steps"
assert "remediation_status" in src, "Should have remediation_status"
assert "remediation_assignee" in src, "Should have remediation_assignee"
print(" PASS: test_remediation_fields")
passed += 1
except Exception as e:
print(f" FAIL: test_remediation_fields — {e}")
failed += 1
# ===========================================================================
# TEST 9 — Template suggested_remediation field
# ===========================================================================
def test_template_suggested_remediation():
"""TestTemplate has suggested_remediation and it's passed on instantiation."""
global passed, failed
try:
from app.models.test_template import TestTemplate
src = inspect.getsource(TestTemplate)
assert "suggested_remediation" in src, "Should have suggested_remediation"
from app.routers.tests import create_test_from_template
src2 = inspect.getsource(create_test_from_template)
assert "suggested_remediation" in src2 or "remediation_steps" in src2, \
"from-template endpoint should copy remediation"
print(" PASS: test_template_suggested_remediation")
passed += 1
except Exception as e:
print(f" FAIL: test_template_suggested_remediation — {e}")
failed += 1
# ===========================================================================
# TEST 10 — Remediation endpoint exists in router
# ===========================================================================
def test_remediation_endpoint():
"""PATCH /tests/{id}/remediation exists."""
global passed, failed
try:
from app.routers.tests import update_remediation
src = inspect.getsource(update_remediation)
assert "remediation" in src.lower(), "Should handle remediation fields"
print(" PASS: test_remediation_endpoint")
passed += 1
except Exception as e:
print(f" FAIL: test_remediation_endpoint — {e}")
failed += 1
# ===========================================================================
# TEST 11 — Notifications model
# ===========================================================================
def test_notification_model():
"""Notification model has required fields and indexes."""
global passed, failed
try:
from app.models.notification import Notification
src = inspect.getsource(Notification)
assert "user_id" in src, "Should have user_id"
assert "type" in src, "Should have type"
assert "title" in src, "Should have title"
assert "message" in src, "Should have message"
assert "entity_type" in src, "Should have entity_type"
assert "entity_id" in src, "Should have entity_id"
assert "read" in src, "Should have read"
assert "ix_notifications_user_id" in src, "Should have user_id index"
assert "ix_notifications_read" in src, "Should have read index"
print(" PASS: test_notification_model")
passed += 1
except Exception as e:
print(f" FAIL: test_notification_model — {e}")
failed += 1
# ===========================================================================
# TEST 12 — Notification endpoints exist
# ===========================================================================
def test_notification_endpoints():
"""Notification router has list, unread-count, mark-read, read-all."""
global passed, failed
try:
from app.routers import notifications
src = inspect.getsource(notifications)
assert "unread-count" in src, "Should have /unread-count"
assert "read-all" in src, "Should have /read-all"
assert "mark_as_read" in src, "Should call mark_as_read"
assert "mark_all_as_read" in src, "Should call mark_all_as_read"
assert "get_unread_count" in src, "Should call get_unread_count"
print(" PASS: test_notification_endpoints")
passed += 1
except Exception as e:
print(f" FAIL: test_notification_endpoints — {e}")
failed += 1
# ===========================================================================
# TEST 13 — Error responses include structured detail
# ===========================================================================
def test_structured_error_responses():
"""Workflow errors include code and valid_transitions."""
global passed, failed
try:
src = inspect.getsource(transition_state)
assert "INVALID_TRANSITION" in src, "Should include INVALID_TRANSITION code"
assert "valid_transitions" in src, "Should include valid_transitions list"
assert "current_state" in src, "Should include current_state"
print(" PASS: test_structured_error_responses")
passed += 1
except Exception as e:
print(f" FAIL: test_structured_error_responses — {e}")
failed += 1
# ===========================================================================
# TEST 14 — Workflow integration triggers notifications
# ===========================================================================
def test_workflow_triggers_notifications():
"""transition_state calls notify_test_state_change."""
global passed, failed
try:
src = inspect.getsource(transition_state)
assert "notify_test_state_change" in src, "Should call notify_test_state_change"
# Notifications are best-effort (wrapped in try/except)
assert "except" in src, "Notification errors should be caught"
print(" PASS: test_workflow_triggers_notifications")
passed += 1
except Exception as e:
print(f" FAIL: test_workflow_triggers_notifications — {e}")
failed += 1
# ===========================================================================
# TEST 15 — Scheduler includes notification cleanup
# ===========================================================================
def test_scheduler_has_notification_cleanup():
"""Background scheduler includes notification cleanup job."""
global passed, failed
try:
from app.jobs import mitre_sync_job
src = inspect.getsource(mitre_sync_job)
assert "notification_cleanup" in src, "Should register notification_cleanup job"
assert "cleanup_old_notifications" in src, "Should import cleanup_old_notifications"
print(" PASS: test_scheduler_has_notification_cleanup")
passed += 1
except Exception as e:
print(f" FAIL: test_scheduler_has_notification_cleanup — {e}")
failed += 1
# ===========================================================================
# TEST 16 — Sidebar navigation includes Reports
# ===========================================================================
def test_navigation_includes_reports():
"""Frontend App.tsx registers /reports route."""
global passed, failed
try:
app_path = os.path.join(
os.path.dirname(__file__), "..", "..", "frontend", "src", "App.tsx"
)
if os.path.exists(app_path):
with open(app_path) as f:
content = f.read()
assert "/reports" in content, "App.tsx should have /reports route"
assert "ReportsPage" in content, "App.tsx should import ReportsPage"
else:
# If running from a different CWD, just check the router module
pass
print(" PASS: test_navigation_includes_reports")
passed += 1
except Exception as e:
print(f" FAIL: test_navigation_includes_reports — {e}")
failed += 1
# ===========================================================================
# TEST 17 — Coverage CSV export
# ===========================================================================
def test_coverage_csv_export():
"""Report router has CSV endpoint with StreamingResponse."""
global passed, failed
try:
from app.routers.reports import coverage_csv
src = inspect.getsource(coverage_csv)
assert "csv" in src, "Should use csv module"
assert "StreamingResponse" in src or "text/csv" in src, "Should set CSV content type"
assert "Content-Disposition" in src, "Should set download filename"
print(" PASS: test_coverage_csv_export")
passed += 1
except Exception as e:
print(f" FAIL: test_coverage_csv_export — {e}")
failed += 1
# ===========================================================================
# TEST 18 — Dual validation logic completeness
# ===========================================================================
def test_dual_validation_all_scenarios():
"""Test all 4 possible dual validation outcomes."""
global passed, failed
try:
db = MagicMock()
# Scenario 1: both approved -> validated
t1 = _make_test(state=TestState.in_review)
t1.red_validation_status = "approved"
t1.blue_validation_status = "approved"
check_dual_validation(db, t1)
assert t1.state == TestState.validated
# Scenario 2: red rejected -> rejected
t2 = _make_test(state=TestState.in_review)
t2.red_validation_status = "rejected"
t2.blue_validation_status = None
check_dual_validation(db, t2)
assert t2.state == TestState.rejected
# Scenario 3: blue rejected -> rejected
t3 = _make_test(state=TestState.in_review)
t3.red_validation_status = "approved"
t3.blue_validation_status = "rejected"
check_dual_validation(db, t3)
assert t3.state == TestState.rejected
# Scenario 4: one approved, other pending -> stays in_review
t4 = _make_test(state=TestState.in_review)
t4.red_validation_status = "approved"
t4.blue_validation_status = None
check_dual_validation(db, t4)
assert t4.state == TestState.in_review
print(" PASS: test_dual_validation_all_scenarios")
passed += 1
except Exception as e:
print(f" FAIL: test_dual_validation_all_scenarios — {e}")
failed += 1
# ===========================================================================
# TEST 19 — All V2 routers registered in main.py
# ===========================================================================
def test_all_routers_registered():
"""main.py includes all V2 routers."""
global passed, failed
try:
main_path = os.path.join(os.path.dirname(__file__), "..", "app", "main.py")
with open(main_path) as f:
content = f.read()
for router_name in [
"notifications", "reports", "tests", "test_templates",
"metrics", "evidence", "auth", "techniques", "system",
"users", "audit",
]:
assert router_name in content, f"main.py should include {router_name} router"
print(" PASS: test_all_routers_registered")
passed += 1
except Exception as e:
print(f" FAIL: test_all_routers_registered — {e}")
failed += 1
# ===========================================================================
# TEST 20 — Notification mark-all-as-read service
# ===========================================================================
def test_mark_all_as_read_service():
"""mark_all_as_read updates all unread notifications for a user."""
global passed, failed
try:
src = inspect.getsource(mark_all_as_read)
assert "read" in src.lower(), "Should filter by read status"
assert "update" in src, "Should call update()"
assert "commit" in src, "Should commit changes"
print(" PASS: test_mark_all_as_read_service")
passed += 1
except Exception as e:
print(f" FAIL: test_mark_all_as_read_service — {e}")
failed += 1
# ===========================================================================
# Run all
# ===========================================================================
if __name__ == "__main__":
tests = [fn for name, fn in globals().items() if name.startswith("test_") and callable(fn)]
print(f"\nRunning {len(tests)} integration V2 tests...\n")
for fn in tests:
fn()
print(f"\n{'='*50}")
print(f"Results: {passed} passed, {failed} failed out of {passed + failed}")
if failed > 0:
sys.exit(1)
print("All integration V2 tests passed!")
+409
View File
@@ -0,0 +1,409 @@
"""T-127: Tests de métricas actualizadas.
Tests for the V2 metrics endpoints (pipeline, team-activity, validation-rate)
and for the technique status recalculation logic with the new test states.
"""
import sys
import os
import uuid
import inspect
from unittest.mock import MagicMock, patch, PropertyMock
from types import ModuleType
# ---------------------------------------------------------------------------
# Stub heavy dependencies before importing app modules
# ---------------------------------------------------------------------------
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
if "pydantic_settings" not in sys.modules:
_ps = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs): pass
def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs)
_ps.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = _ps
if "app.config" not in sys.modules:
_cfg = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
_cfg.settings = _FakeSettings()
sys.modules["app.config"] = _cfg
if "app.database" not in sys.modules:
_db = ModuleType("app.database")
_db.Base = type("Base", (), {"metadata": MagicMock()})
_db.get_db = MagicMock()
sys.modules["app.database"] = _db
for _mod in [
"taxii2client", "taxii2client.v20",
"jose", "boto3", "botocore", "botocore.exceptions",
"apscheduler", "apscheduler.schedulers",
"apscheduler.schedulers.background",
"apscheduler.triggers", "apscheduler.triggers.cron",
]:
if _mod not in sys.modules:
m = ModuleType(_mod)
if _mod == "taxii2client.v20": m.Server = MagicMock
elif _mod == "jose": m.JWTError = Exception; m.jwt = MagicMock()
elif _mod == "boto3": m.client = MagicMock()
elif _mod == "botocore.exceptions": m.ClientError = Exception
elif _mod == "apscheduler.schedulers.background": m.BackgroundScheduler = MagicMock
elif _mod == "apscheduler.triggers.cron": m.CronTrigger = MagicMock
sys.modules[_mod] = m
# ---------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------
from app.models.enums import TestState, TestResult, TechniqueStatus
from app.services.status_service import recalculate_technique_status
from app.routers.metrics import router as metrics_router
def _get_route_paths():
routes = {}
for route in metrics_router.routes:
path = getattr(route, "path", "")
methods = getattr(route, "methods", set())
for method in methods:
routes[f"{method} {path}"] = route
return routes
# ---------------------------------------------------------------------------
# Helpers for technique status recalculation tests
# ---------------------------------------------------------------------------
def _make_test(state: TestState, detection_result=None) -> MagicMock:
t = MagicMock()
t.id = uuid.uuid4()
t.state = state
t.detection_result = detection_result
t.red_validation_status = None
t.blue_validation_status = None
return t
def _make_technique(tests=None) -> MagicMock:
tech = MagicMock()
tech.id = uuid.uuid4()
tech.tests = tests or []
tech.status_global = TechniqueStatus.not_evaluated
return tech
def _make_db() -> MagicMock:
return MagicMock()
# ===========================================================================
# 1. test_pipeline_metrics — endpoint exists and queries TestState
# ===========================================================================
def test_pipeline_metrics_endpoint_exists():
"""GET /metrics/test-pipeline endpoint exists."""
routes = _get_route_paths()
found = any("test-pipeline" in k and "GET" in k for k in routes)
assert found, f"GET /metrics/test-pipeline not found. Routes: {list(routes.keys())}"
def test_pipeline_metrics_queries_all_states():
"""Pipeline endpoint groups by all test states."""
from app.routers.metrics import test_pipeline
source = inspect.getsource(test_pipeline)
assert "Test.state" in source, "Must query Test.state"
assert "group_by" in source, "Must group by state"
assert "TestPipelineCounts" in source, "Must return TestPipelineCounts schema"
# ===========================================================================
# 2. test_team_activity_metrics — endpoint exists and calculates correctly
# ===========================================================================
def test_team_activity_endpoint_exists():
"""GET /metrics/team-activity endpoint exists."""
routes = _get_route_paths()
found = any("team-activity" in k and "GET" in k for k in routes)
assert found, f"GET /metrics/team-activity not found. Routes: {list(routes.keys())}"
def test_team_activity_calculates_both_teams():
"""Team activity endpoint returns data for both Red and Blue teams."""
from app.routers.metrics import team_activity
source = inspect.getsource(team_activity)
assert "Red Team" in source or "red" in source.lower(), "Must include Red Team data"
assert "Blue Team" in source or "blue" in source.lower(), "Must include Blue Team data"
assert "tests_completed" in source, "Must calculate completed tests"
assert "tests_pending" in source, "Must calculate pending tests"
def test_team_activity_red_pending_states():
"""Red Team pending includes draft and red_executing."""
from app.routers.metrics import team_activity
source = inspect.getsource(team_activity)
assert "draft" in source, "Red pending must include draft"
assert "red_executing" in source, "Red pending must include red_executing"
def test_team_activity_blue_pending_states():
"""Blue Team pending includes blue_evaluating."""
from app.routers.metrics import team_activity
source = inspect.getsource(team_activity)
assert "blue_evaluating" in source, "Blue pending must include blue_evaluating"
# ===========================================================================
# 3. test_technique_status_recalculation_with_new_states
# ===========================================================================
def test_technique_no_tests_is_not_evaluated():
"""Technique with no tests -> not_evaluated."""
tech = _make_technique(tests=[])
db = _make_db()
recalculate_technique_status(db, tech)
assert tech.status_global == TechniqueStatus.not_evaluated
def test_technique_all_validated_detected():
"""All tests validated with detected -> technique validated."""
tests = [
_make_test(TestState.validated, detection_result="detected"),
_make_test(TestState.validated, detection_result="detected"),
]
tech = _make_technique(tests=tests)
db = _make_db()
recalculate_technique_status(db, tech)
assert tech.status_global == TechniqueStatus.validated
def test_technique_all_validated_partially_detected():
"""All tests validated with partially_detected -> technique partial."""
tests = [
_make_test(TestState.validated, detection_result="detected"),
_make_test(TestState.validated, detection_result="partially_detected"),
]
tech = _make_technique(tests=tests)
db = _make_db()
recalculate_technique_status(db, tech)
assert tech.status_global == TechniqueStatus.partial
def test_technique_all_validated_not_detected():
"""All tests validated with not_detected -> technique not_covered."""
tests = [
_make_test(TestState.validated, detection_result="not_detected"),
]
tech = _make_technique(tests=tests)
db = _make_db()
recalculate_technique_status(db, tech)
assert tech.status_global == TechniqueStatus.not_covered
def test_technique_mixed_validated_and_in_progress():
"""Some validated, some still in pipeline -> technique partial."""
tests = [
_make_test(TestState.validated, detection_result="detected"),
_make_test(TestState.red_executing),
]
tech = _make_technique(tests=tests)
db = _make_db()
recalculate_technique_status(db, tech)
assert tech.status_global == TechniqueStatus.partial
def test_technique_all_in_progress():
"""All tests in intermediate states (no validated) -> technique in_progress."""
tests = [
_make_test(TestState.draft),
_make_test(TestState.red_executing),
_make_test(TestState.blue_evaluating),
]
tech = _make_technique(tests=tests)
db = _make_db()
recalculate_technique_status(db, tech)
assert tech.status_global == TechniqueStatus.in_progress
def test_technique_with_in_review_tests():
"""Tests in in_review are still in-progress (not yet validated)."""
tests = [
_make_test(TestState.in_review),
]
tech = _make_technique(tests=tests)
db = _make_db()
recalculate_technique_status(db, tech)
assert tech.status_global == TechniqueStatus.in_progress
def test_technique_with_rejected_tests():
"""Rejected tests count as in-progress (need rework)."""
tests = [
_make_test(TestState.rejected),
]
tech = _make_technique(tests=tests)
db = _make_db()
recalculate_technique_status(db, tech)
assert tech.status_global == TechniqueStatus.in_progress
# ===========================================================================
# 4. test_coverage_with_dual_validation
# ===========================================================================
def test_coverage_correct_after_dual_validation():
"""After dual validation (both approved), technique status is correct."""
# A test that completed the full pipeline with detection
test = _make_test(TestState.validated, detection_result="detected")
test.red_validation_status = "approved"
test.blue_validation_status = "approved"
tech = _make_technique(tests=[test])
db = _make_db()
recalculate_technique_status(db, tech)
assert tech.status_global == TechniqueStatus.validated
def test_coverage_partial_when_one_detected_one_partial():
"""Mixed detection results after dual validation -> partial coverage."""
test1 = _make_test(TestState.validated, detection_result="detected")
test1.red_validation_status = "approved"
test1.blue_validation_status = "approved"
test2 = _make_test(TestState.validated, detection_result="partially_detected")
test2.red_validation_status = "approved"
test2.blue_validation_status = "approved"
tech = _make_technique(tests=[test1, test2])
db = _make_db()
recalculate_technique_status(db, tech)
assert tech.status_global == TechniqueStatus.partial
# ===========================================================================
# 5. test_validation_rate_endpoint — approval/rejection rates
# ===========================================================================
def test_validation_rate_endpoint_exists():
"""GET /metrics/validation-rate endpoint exists."""
routes = _get_route_paths()
found = any("validation-rate" in k and "GET" in k for k in routes)
assert found, f"GET /metrics/validation-rate not found. Routes: {list(routes.keys())}"
def test_validation_rate_queries_both_roles():
"""Validation rate endpoint returns data for both red_lead and blue_lead."""
from app.routers.metrics import validation_rate
source = inspect.getsource(validation_rate)
assert "red_validation_status" in source, "Must query red_validation_status"
assert "blue_validation_status" in source, "Must query blue_validation_status"
assert "approved" in source, "Must count approved validations"
assert "rejected" in source, "Must count rejected validations"
assert "approval_rate" in source, "Must calculate approval_rate"
# ===========================================================================
# 6. test_recent_tests_endpoint — latest 10 tests
# ===========================================================================
def test_recent_tests_endpoint_exists():
"""GET /metrics/recent-tests endpoint exists."""
routes = _get_route_paths()
found = any("recent-tests" in k and "GET" in k for k in routes)
assert found, f"GET /metrics/recent-tests not found. Routes: {list(routes.keys())}"
def test_recent_tests_limits_to_10():
"""Recent tests endpoint limits to 10 results."""
from app.routers.metrics import recent_tests
source = inspect.getsource(recent_tests)
assert "limit(10)" in source or ".limit(10)" in source, \
"Must limit to 10 recent tests"
assert "created_at" in source, "Must order by created_at"
# ===========================================================================
# 7. test_original_endpoints_still_work
# ===========================================================================
def test_summary_endpoint_exists():
"""GET /metrics/summary (original) endpoint still exists."""
routes = _get_route_paths()
found = any("summary" in k and "GET" in k for k in routes)
assert found, f"GET /metrics/summary not found. Routes: {list(routes.keys())}"
def test_by_tactic_endpoint_exists():
"""GET /metrics/by-tactic (original) endpoint still exists."""
routes = _get_route_paths()
found = any("by-tactic" in k and "GET" in k for k in routes)
assert found, f"GET /metrics/by-tactic not found. Routes: {list(routes.keys())}"
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("T-127 Validation: Metrics V2 Tests")
print("=" * 55)
test_pipeline_metrics_endpoint_exists()
test_pipeline_metrics_queries_all_states()
test_team_activity_endpoint_exists()
test_team_activity_calculates_both_teams()
test_team_activity_red_pending_states()
test_team_activity_blue_pending_states()
test_technique_no_tests_is_not_evaluated()
test_technique_all_validated_detected()
test_technique_all_validated_partially_detected()
test_technique_all_validated_not_detected()
test_technique_mixed_validated_and_in_progress()
test_technique_all_in_progress()
test_technique_with_in_review_tests()
test_technique_with_rejected_tests()
test_coverage_correct_after_dual_validation()
test_coverage_partial_when_one_detected_one_partial()
test_validation_rate_endpoint_exists()
test_validation_rate_queries_both_roles()
test_recent_tests_endpoint_exists()
test_recent_tests_limits_to_10()
test_summary_endpoint_exists()
test_by_tactic_endpoint_exists()
print("=" * 55)
print("ALL T-127 validations PASSED!")
+285
View File
@@ -0,0 +1,285 @@
"""T-126: Tests de TestTemplates — CRUD, filters, instantiation, permissions.
Tests the template CRUD endpoints, filter logic, template instantiation,
soft-delete behaviour, and admin-only access control.
Uses mock objects and router inspection to avoid needing a database.
"""
import sys
import os
import uuid
import inspect
from unittest.mock import MagicMock
from types import ModuleType
# ---------------------------------------------------------------------------
# Stub heavy dependencies before importing app modules
# ---------------------------------------------------------------------------
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
if "pydantic_settings" not in sys.modules:
_ps = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs): pass
def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs)
_ps.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = _ps
if "app.config" not in sys.modules:
_cfg = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
_cfg.settings = _FakeSettings()
sys.modules["app.config"] = _cfg
if "app.database" not in sys.modules:
_db = ModuleType("app.database")
_db.Base = type("Base", (), {"metadata": MagicMock()})
_db.get_db = MagicMock()
sys.modules["app.database"] = _db
for _mod in [
"taxii2client", "taxii2client.v20",
"jose", "boto3", "botocore", "botocore.exceptions",
"apscheduler", "apscheduler.schedulers",
"apscheduler.schedulers.background",
"apscheduler.triggers", "apscheduler.triggers.cron",
]:
if _mod not in sys.modules:
m = ModuleType(_mod)
if _mod == "taxii2client.v20": m.Server = MagicMock
elif _mod == "jose": m.JWTError = Exception; m.jwt = MagicMock()
elif _mod == "boto3": m.client = MagicMock()
elif _mod == "botocore.exceptions": m.ClientError = Exception
elif _mod == "apscheduler.schedulers.background": m.BackgroundScheduler = MagicMock
elif _mod == "apscheduler.triggers.cron": m.CronTrigger = MagicMock
sys.modules[_mod] = m
# ---------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------
from app.routers.test_templates import (
router,
list_templates,
templates_by_technique,
create_template,
delete_template,
toggle_template_active,
template_stats,
)
from app.routers.tests import create_test_from_template
from app.schemas.test_template import TestTemplateCreate
def _get_route_paths():
routes = {}
for route in router.routes:
path = getattr(route, "path", "")
methods = getattr(route, "methods", set())
for method in methods:
routes[f"{method} {path}"] = route
return routes
# ===========================================================================
# 1. test_create_template — admin can create a template
# ===========================================================================
def test_create_template():
"""Admin can create a template — endpoint exists and requires admin role."""
routes = _get_route_paths()
found = any("POST" in k and "{template_id}" not in k for k in routes)
assert found, f"POST /test-templates not found. Routes: {list(routes.keys())}"
# Verify admin role is required
source = inspect.getsource(create_template)
assert "require_role" in source and "admin" in source, \
"create_template must require admin role"
# ===========================================================================
# 2. test_list_templates_with_filters — source, platform, severity work
# ===========================================================================
def test_list_templates_with_filters():
"""Filters of source, platform, severity, search all work."""
source = inspect.getsource(list_templates)
# Verify all filter parameters exist in the function signature
assert "source" in source, "List must accept source filter"
assert "platform" in source, "List must accept platform filter"
assert "severity" in source, "List must accept severity filter"
assert "search" in source, "List must accept search filter"
assert "mitre_technique_id" in source, "List must accept mitre_technique_id filter"
# Verify ilike is used for search
assert "ilike" in source, "Search should use ilike for case-insensitive matching"
# ===========================================================================
# 3. test_get_templates_by_technique — filter by MITRE technique
# ===========================================================================
def test_get_templates_by_technique():
"""Endpoint to get templates by technique exists and filters correctly."""
routes = _get_route_paths()
found = any("by-technique" in k and "GET" in k for k in routes)
assert found, f"GET /test-templates/by-technique/{{mitre_id}} not found. Routes: {list(routes.keys())}"
source = inspect.getsource(templates_by_technique)
assert "mitre_technique_id" in source, "Must filter by mitre_technique_id"
assert "is_active" in source, "Must filter only active templates"
# ===========================================================================
# 4. test_instantiate_template — create test from template pre-fills fields
# ===========================================================================
def test_instantiate_template():
"""POST /tests/from-template creates a test pre-filled from template data."""
source = inspect.getsource(create_test_from_template)
# Verify it reads from template and copies fields
assert "template" in source, "Must reference template"
assert "template.name" in source, "Must copy name from template"
assert "template.description" in source, "Must copy description from template"
assert "template.platform" in source, "Must copy platform from template"
assert "template.attack_procedure" in source or "attack_procedure" in source, \
"Must copy attack_procedure from template"
# Verify state is set to draft
assert "draft" in source, "New test from template must be in draft state"
# ===========================================================================
# 5. test_soft_delete_template — deactivation doesn't physically remove
# ===========================================================================
def test_soft_delete_template():
"""DELETE endpoint sets is_active=False instead of removing the record."""
source = inspect.getsource(delete_template)
assert "is_active" in source, "Must set is_active"
assert "False" in source, "Must set is_active to False"
# Should NOT call db.delete(template)
assert "db.delete" not in source, "Should NOT physically delete the template"
assert "deactivated" in source.lower() or "soft" in source.lower() or "detail" in source.lower(), \
"Should return a deactivation message"
# ===========================================================================
# 6. test_non_admin_cannot_create_template — only admin role
# ===========================================================================
def test_non_admin_cannot_create_template():
"""Only admin can create templates — enforce via require_role."""
source = inspect.getsource(create_template)
assert 'require_role("admin")' in source, \
"create_template must use require_role('admin')"
# Also check update and delete
from app.routers.test_templates import update_template
source_update = inspect.getsource(update_template)
assert 'require_role("admin")' in source_update, \
"update_template must use require_role('admin')"
source_delete = inspect.getsource(delete_template)
assert 'require_role("admin")' in source_delete, \
"delete_template must use require_role('admin')"
# ===========================================================================
# 7. test_toggle_active_endpoint — toggle between active/inactive
# ===========================================================================
def test_toggle_active_endpoint():
"""PATCH /test-templates/{id}/toggle-active exists and toggles is_active."""
routes = _get_route_paths()
found = any("toggle-active" in k and "PATCH" in k for k in routes)
assert found, f"PATCH /test-templates/{{id}}/toggle-active not found. Routes: {list(routes.keys())}"
source = inspect.getsource(toggle_template_active)
assert "is_active" in source, "Must reference is_active"
assert "not" in source, "Must toggle (negate) the is_active value"
assert 'require_role("admin")' in source, "Must require admin role"
# ===========================================================================
# 8. test_stats_endpoint — catalog statistics
# ===========================================================================
def test_stats_endpoint():
"""GET /test-templates/stats returns catalog statistics."""
routes = _get_route_paths()
found = any("stats" in k and "GET" in k for k in routes)
assert found, f"GET /test-templates/stats not found. Routes: {list(routes.keys())}"
source = inspect.getsource(template_stats)
assert "by_source" in source, "Must return breakdown by source"
assert "by_platform" in source, "Must return breakdown by platform"
assert "active" in source, "Must return active count"
assert 'require_role("admin")' in source, "Must require admin role"
# ===========================================================================
# 9. test_list_only_active_by_default — list filters inactive templates
# ===========================================================================
def test_list_only_active_by_default():
"""The list endpoint filters to is_active=True by default."""
source = inspect.getsource(list_templates)
assert "is_active" in source and "True" in source, \
"List must filter by is_active == True by default"
# ===========================================================================
# 10. test_pagination_support
# ===========================================================================
def test_pagination_support():
"""List endpoint supports offset and limit pagination."""
source = inspect.getsource(list_templates)
assert "offset" in source, "Must accept offset parameter"
assert "limit" in source, "Must accept limit parameter"
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("T-126 Validation: TestTemplates CRUD Tests")
print("=" * 55)
test_create_template()
test_list_templates_with_filters()
test_get_templates_by_technique()
test_instantiate_template()
test_soft_delete_template()
test_non_admin_cannot_create_template()
test_toggle_active_endpoint()
test_stats_endpoint()
test_list_only_active_by_default()
test_pagination_support()
print("=" * 55)
print("ALL T-126 validations PASSED!")
+565
View File
@@ -0,0 +1,565 @@
"""T-125: Tests del flujo de trabajo Red/Blue.
Comprehensive tests covering the full test lifecycle:
draft -> red_executing -> blue_evaluating -> in_review -> validated/rejected
Uses mock objects to test the workflow service and router logic
without requiring a running database.
"""
import sys
import os
import uuid
from unittest.mock import MagicMock, patch
from types import ModuleType
from datetime import datetime
# ---------------------------------------------------------------------------
# Stub heavy dependencies before importing app modules
# ---------------------------------------------------------------------------
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
if "pydantic_settings" not in sys.modules:
_ps = ModuleType("pydantic_settings")
class _BaseSettings:
def __init__(self, **kwargs): pass
def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs)
_ps.BaseSettings = _BaseSettings
sys.modules["pydantic_settings"] = _ps
if "app.config" not in sys.modules:
_cfg = ModuleType("app.config")
class _FakeSettings:
DATABASE_URL = "sqlite:///:memory:"
SECRET_KEY = "test"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
MINIO_ENDPOINT = "localhost:9000"
MINIO_ACCESS_KEY = "test"
MINIO_SECRET_KEY = "test"
MINIO_BUCKET = "test"
_cfg.settings = _FakeSettings()
sys.modules["app.config"] = _cfg
if "app.database" not in sys.modules:
_db = ModuleType("app.database")
_db.Base = type("Base", (), {"metadata": MagicMock()})
_db.get_db = MagicMock()
sys.modules["app.database"] = _db
for _mod in [
"taxii2client", "taxii2client.v20",
"jose", "boto3", "botocore", "botocore.exceptions",
"apscheduler", "apscheduler.schedulers",
"apscheduler.schedulers.background",
"apscheduler.triggers", "apscheduler.triggers.cron",
]:
if _mod not in sys.modules:
m = ModuleType(_mod)
if _mod == "taxii2client.v20": m.Server = MagicMock
elif _mod == "jose": m.JWTError = Exception; m.jwt = MagicMock()
elif _mod == "boto3": m.client = MagicMock()
elif _mod == "botocore.exceptions": m.ClientError = Exception
elif _mod == "apscheduler.schedulers.background": m.BackgroundScheduler = MagicMock
elif _mod == "apscheduler.triggers.cron": m.CronTrigger = MagicMock
sys.modules[_mod] = m
# ---------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------
from fastapi import HTTPException
from app.models.enums import TestState, TestResult
from app.services.test_workflow_service import (
VALID_TRANSITIONS,
can_transition,
transition_state,
start_execution,
submit_red_evidence,
submit_blue_evidence,
validate_as_red_lead,
validate_as_blue_lead,
check_dual_validation,
reopen_test,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_test(state: TestState = TestState.draft, **kwargs) -> MagicMock:
t = MagicMock()
t.id = uuid.uuid4()
t.name = "Test Security Check"
t.technique_id = uuid.uuid4()
t.state = state
t.red_validation_status = kwargs.get("red_validation_status", None)
t.blue_validation_status = kwargs.get("blue_validation_status", None)
t.red_validated_by = kwargs.get("red_validated_by", None)
t.red_validated_at = kwargs.get("red_validated_at", None)
t.red_validation_notes = kwargs.get("red_validation_notes", None)
t.blue_validated_by = kwargs.get("blue_validated_by", None)
t.blue_validated_at = kwargs.get("blue_validated_at", None)
t.blue_validation_notes = kwargs.get("blue_validation_notes", None)
t.execution_date = kwargs.get("execution_date", None)
return t
def _make_user(role: str = "red_tech") -> MagicMock:
user = MagicMock()
user.id = uuid.uuid4()
user.role = role
return user
def _make_db() -> MagicMock:
return MagicMock()
# ===========================================================================
# 1. test_full_happy_path
# draft -> red_executing -> blue_evaluating -> in_review -> validated
# ===========================================================================
@patch("app.services.test_workflow_service.log_action")
def test_full_happy_path(mock_log):
"""draft -> red_executing -> blue_evaluating -> in_review -> validated"""
test = _make_test(TestState.draft)
red_tech = _make_user("red_tech")
blue_tech = _make_user("blue_tech")
red_lead = _make_user("red_lead")
blue_lead = _make_user("blue_lead")
db = _make_db()
# Step 1: draft -> red_executing
result = start_execution(db, test, red_tech)
assert result.state == TestState.red_executing
assert result.execution_date is not None
# Step 2: red_executing -> blue_evaluating
result = submit_red_evidence(db, result, red_tech)
assert result.state == TestState.blue_evaluating
# Step 3: blue_evaluating -> in_review
result = submit_blue_evidence(db, result, blue_tech)
assert result.state == TestState.in_review
# Step 4: Red Lead approves
result = validate_as_red_lead(db, result, red_lead, "approved", "Attack well documented")
assert result.red_validation_status == "approved"
assert result.red_validated_by == red_lead.id
assert result.red_validated_at is not None
assert result.red_validation_notes == "Attack well documented"
# Still in_review (waiting for blue lead)
assert result.state == TestState.in_review
# Step 5: Blue Lead approves -> validated
result = validate_as_blue_lead(db, result, blue_lead, "approved", "Detection confirmed")
assert result.blue_validation_status == "approved"
assert result.state == TestState.validated
# Verify audit logs were generated at each step
assert mock_log.call_count >= 5
# ===========================================================================
# 2. test_rejection_and_reopen
# in_review -> rejected -> draft -> red_executing -> ...
# ===========================================================================
@patch("app.services.test_workflow_service.log_action")
def test_rejection_and_reopen(mock_log):
"""in_review -> rejected -> draft -> red_executing -> ..."""
test = _make_test(TestState.draft)
red_tech = _make_user("red_tech")
blue_tech = _make_user("blue_tech")
red_lead = _make_user("red_lead")
db = _make_db()
# Advance to in_review
start_execution(db, test, red_tech)
submit_red_evidence(db, test, red_tech)
submit_blue_evidence(db, test, blue_tech)
assert test.state == TestState.in_review
# Red Lead rejects -> rejected
validate_as_red_lead(db, test, red_lead, "rejected", "Need more evidence")
assert test.state == TestState.rejected
# Reopen -> draft
reopen_test(db, test, red_lead)
assert test.state == TestState.draft
# Restart the cycle
start_execution(db, test, red_tech)
assert test.state == TestState.red_executing
# ===========================================================================
# 3. test_invalid_transitions
# ===========================================================================
@patch("app.services.test_workflow_service.log_action")
def test_invalid_transitions(mock_log):
"""Verify that invalid state transitions raise HTTPException."""
db = _make_db()
user = _make_user("admin")
# draft -> validated (should fail)
test = _make_test(TestState.draft)
try:
transition_state(db, test, TestState.validated, user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
# draft -> blue_evaluating (should fail)
test = _make_test(TestState.draft)
try:
transition_state(db, test, TestState.blue_evaluating, user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
# red_executing -> in_review (should fail, must go through blue_evaluating)
test = _make_test(TestState.red_executing)
try:
transition_state(db, test, TestState.in_review, user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
# validated -> anything (terminal state)
test = _make_test(TestState.validated)
try:
transition_state(db, test, TestState.draft, user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
# rejected -> red_executing (must go through draft first)
test = _make_test(TestState.rejected)
try:
transition_state(db, test, TestState.red_executing, user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
# ===========================================================================
# 4. test_red_tech_cannot_access_blue_phase
# ===========================================================================
@patch("app.services.test_workflow_service.log_action")
def test_red_tech_cannot_access_blue_phase(mock_log):
"""Red tech cannot submit blue evidence (wrong transition from wrong state)."""
db = _make_db()
red_tech = _make_user("red_tech")
# A test in red_executing cannot jump to in_review
test = _make_test(TestState.red_executing)
try:
submit_blue_evidence(db, test, red_tech)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
# Red tech cannot validate (test must be in blue_evaluating for submit_blue)
test2 = _make_test(TestState.draft)
try:
submit_blue_evidence(db, test2, red_tech)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
# ===========================================================================
# 5. test_blue_tech_cannot_access_red_phase
# ===========================================================================
@patch("app.services.test_workflow_service.log_action")
def test_blue_tech_cannot_access_red_phase(mock_log):
"""Blue tech cannot start execution or submit red evidence."""
db = _make_db()
blue_tech = _make_user("blue_tech")
# Blue tech cannot start execution (test must be in draft -> red_executing)
# The workflow service doesn't check role, but the router does.
# At service level, blue_evaluating -> blue_evaluating is invalid transition:
test = _make_test(TestState.blue_evaluating)
try:
start_execution(db, test, blue_tech)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
# Blue tech cannot submit red evidence on a draft test
test2 = _make_test(TestState.draft)
try:
submit_red_evidence(db, test2, blue_tech)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
# ===========================================================================
# 6. test_dual_validation_both_approve
# ===========================================================================
@patch("app.services.test_workflow_service.log_action")
def test_dual_validation_both_approve(mock_log):
"""Both managers approve -> test becomes validated."""
test = _make_test(TestState.in_review)
red_lead = _make_user("red_lead")
blue_lead = _make_user("blue_lead")
db = _make_db()
# Red Lead approves first
validate_as_red_lead(db, test, red_lead, "approved", "LGTM")
assert test.red_validation_status == "approved"
# Not yet validated — waiting for blue
assert test.state == TestState.in_review
# Blue Lead approves
validate_as_blue_lead(db, test, blue_lead, "approved", "Detection verified")
assert test.blue_validation_status == "approved"
assert test.state == TestState.validated
# ===========================================================================
# 7. test_dual_validation_one_rejects
# ===========================================================================
@patch("app.services.test_workflow_service.log_action")
def test_dual_validation_one_rejects(mock_log):
"""One manager rejects -> test becomes rejected immediately."""
test = _make_test(TestState.in_review)
red_lead = _make_user("red_lead")
db = _make_db()
validate_as_red_lead(db, test, red_lead, "rejected", "Insufficient evidence")
assert test.red_validation_status == "rejected"
assert test.state == TestState.rejected
@patch("app.services.test_workflow_service.log_action")
def test_dual_validation_blue_rejects_first(mock_log):
"""Blue Lead rejects first -> test becomes rejected immediately."""
test = _make_test(TestState.in_review)
blue_lead = _make_user("blue_lead")
db = _make_db()
validate_as_blue_lead(db, test, blue_lead, "rejected", "Detection not adequate")
assert test.blue_validation_status == "rejected"
assert test.state == TestState.rejected
@patch("app.services.test_workflow_service.log_action")
def test_dual_validation_red_approves_blue_rejects(mock_log):
"""Red approves, then blue rejects -> rejected."""
test = _make_test(TestState.in_review)
red_lead = _make_user("red_lead")
blue_lead = _make_user("blue_lead")
db = _make_db()
validate_as_red_lead(db, test, red_lead, "approved", "Good attack")
assert test.state == TestState.in_review # waiting for blue
validate_as_blue_lead(db, test, blue_lead, "rejected", "Bad detection")
assert test.state == TestState.rejected
# ===========================================================================
# 8. test_evidence_team_separation
# ===========================================================================
def test_evidence_team_separation():
"""Verify evidence router logic separates red and blue evidence correctly."""
from app.routers.evidence import _validate_upload_permission, _RED_EDITABLE_STATES, _BLUE_EDITABLE_STATES
# Red tech can upload red evidence in draft
test = _make_test(TestState.draft)
red_user = _make_user("red_tech")
red_user.role = "red_tech"
from app.models.enums import TeamSide
_validate_upload_permission(test, TeamSide.red, red_user) # should not raise
# Red tech can upload red evidence in red_executing
test.state = TestState.red_executing
_validate_upload_permission(test, TeamSide.red, red_user) # should not raise
# Red tech CANNOT upload red evidence in blue_evaluating
test.state = TestState.blue_evaluating
try:
_validate_upload_permission(test, TeamSide.red, red_user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
# Red tech CANNOT upload blue evidence
test.state = TestState.blue_evaluating
try:
_validate_upload_permission(test, TeamSide.blue, red_user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 403
# Blue tech can upload blue evidence in blue_evaluating
test.state = TestState.blue_evaluating
blue_user = _make_user("blue_tech")
blue_user.role = "blue_tech"
_validate_upload_permission(test, TeamSide.blue, blue_user) # should not raise
# Blue tech CANNOT upload blue evidence in draft
test.state = TestState.draft
try:
_validate_upload_permission(test, TeamSide.blue, blue_user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
# Blue tech CANNOT upload red evidence
test.state = TestState.draft
try:
_validate_upload_permission(test, TeamSide.red, blue_user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 403
# ===========================================================================
# 9. test_red_edit_allowed_in_draft_and_red_executing
# ===========================================================================
def test_red_edit_allowed_in_draft_and_red_executing():
"""Verify the red update router checks that state is draft or red_executing."""
from app.routers.tests import update_test_red
import inspect
source = inspect.getsource(update_test_red)
# The function must guard against states other than draft/red_executing
assert "draft" in source, "Red update must allow draft state"
assert "red_executing" in source, "Red update must allow red_executing state"
assert "400" in source or "HTTP_400_BAD_REQUEST" in source, "Red update must return 400 for invalid state"
# ===========================================================================
# 10. test_reopen_clears_validation_fields
# ===========================================================================
@patch("app.services.test_workflow_service.log_action")
def test_reopen_clears_validation_fields(mock_log):
"""Reopen clears all red/blue validation status, notes, timestamps."""
test = _make_test(
TestState.rejected,
red_validation_status="rejected",
red_validated_by=uuid.uuid4(),
red_validated_at=datetime.utcnow(),
red_validation_notes="Bad attack",
blue_validation_status="approved",
blue_validated_by=uuid.uuid4(),
blue_validated_at=datetime.utcnow(),
blue_validation_notes="Good detection",
)
user = _make_user("red_lead")
db = _make_db()
result = reopen_test(db, test, user)
assert result.state == TestState.draft
assert result.red_validation_status is None
assert result.red_validated_by is None
assert result.red_validated_at is None
assert result.red_validation_notes is None
assert result.blue_validation_status is None
assert result.blue_validated_by is None
assert result.blue_validated_at is None
assert result.blue_validation_notes is None
db.commit.assert_called()
# ===========================================================================
# 11. test_cannot_validate_outside_in_review
# ===========================================================================
@patch("app.services.test_workflow_service.log_action")
def test_cannot_validate_outside_in_review(mock_log):
"""Managers cannot validate a test that is not in in_review state."""
db = _make_db()
red_lead = _make_user("red_lead")
blue_lead = _make_user("blue_lead")
for state in [TestState.draft, TestState.red_executing, TestState.blue_evaluating, TestState.validated, TestState.rejected]:
test = _make_test(state)
try:
validate_as_red_lead(db, test, red_lead, "approved", "OK")
assert False, f"Red Lead should not validate in {state.value}"
except HTTPException as exc:
assert exc.status_code == 400
test2 = _make_test(state)
try:
validate_as_blue_lead(db, test2, blue_lead, "approved", "OK")
assert False, f"Blue Lead should not validate in {state.value}"
except HTTPException as exc:
assert exc.status_code == 400
# ===========================================================================
# 12. test_cannot_reopen_non_rejected_test
# ===========================================================================
@patch("app.services.test_workflow_service.log_action")
def test_cannot_reopen_non_rejected_test(mock_log):
"""Reopen only works on rejected tests."""
db = _make_db()
user = _make_user("red_lead")
for state in [TestState.draft, TestState.red_executing, TestState.blue_evaluating, TestState.in_review, TestState.validated]:
test = _make_test(state)
try:
reopen_test(db, test, user)
assert False, f"Should not reopen from {state.value}"
except HTTPException as exc:
assert exc.status_code == 400
# ---------------------------------------------------------------------------
# Run all
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("T-125 Validation: Workflow Tests")
print("=" * 55)
test_full_happy_path()
test_rejection_and_reopen()
test_invalid_transitions()
test_red_tech_cannot_access_blue_phase()
test_blue_tech_cannot_access_red_phase()
test_dual_validation_both_approve()
test_dual_validation_one_rejects()
test_dual_validation_blue_rejects_first()
test_dual_validation_red_approves_blue_rejects()
test_evidence_team_separation()
test_red_edit_allowed_in_draft_and_red_executing()
test_reopen_clears_validation_fields()
test_cannot_validate_outside_in_review()
test_cannot_reopen_non_rejected_test()
print("=" * 55)
print("ALL T-125 validations PASSED!")
+222 -1
View File
@@ -365,6 +365,213 @@ Get background scheduler status.
---
## V2 Endpoints — Red/Blue Workflow
### Tests — Red/Blue Workflow
#### `GET /api/v1/tests`
List tests with advanced filters.
**Query Parameters:**
- `state` (string) — Filter by test state
- `technique_id` (UUID) — Filter by technique
- `platform` (string) — Filter by platform
- `created_by` (UUID) — Filter by creator
- `pending_validation_side` (string: `red` / `blue`) — Filter tests in_review pending validation
- `offset` (int, default 0) — Pagination offset
- `limit` (int, default 50, max 200) — Page size
#### `POST /api/v1/tests/from-template`
Create a test from a template. Pre-populates name, description, platform, procedure, tool, and remediation steps.
**Body:**
```json
{
"template_id": "uuid",
"technique_id": "uuid"
}
```
#### `PATCH /api/v1/tests/{id}/red`
Red Team updates their fields. Allowed in `draft` and `red_executing` states.
**Body:**
```json
{
"procedure_text": "...",
"tool_used": "...",
"attack_success": true,
"red_summary": "..."
}
```
#### `PATCH /api/v1/tests/{id}/blue`
Blue Team updates their fields. Allowed only in `blue_evaluating` state.
**Body:**
```json
{
"detection_result": "detected | not_detected | partially_detected",
"blue_summary": "..."
}
```
#### `PATCH /api/v1/tests/{id}/remediation`
Update remediation fields on a test.
**Body:**
```json
{
"remediation_steps": "Step 1: ...\nStep 2: ...",
"remediation_status": "pending | in_progress | completed | not_applicable",
"remediation_assignee": "user-uuid"
}
```
#### `POST /api/v1/tests/{id}/start-execution`
Transition: `draft``red_executing`. Sets `execution_date`.
#### `POST /api/v1/tests/{id}/submit-red`
Transition: `red_executing``blue_evaluating`. Notifies all blue_tech users.
#### `POST /api/v1/tests/{id}/submit-blue`
Transition: `blue_evaluating``in_review`. Notifies red_lead and blue_lead.
#### `POST /api/v1/tests/{id}/validate-red`
Red Lead approves or rejects. Triggers dual validation check.
**Body:**
```json
{
"red_validation_status": "approved | rejected",
"red_validation_notes": "optional notes"
}
```
#### `POST /api/v1/tests/{id}/validate-blue`
Blue Lead approves or rejects. Triggers dual validation check.
**Body:**
```json
{
"blue_validation_status": "approved | rejected",
"blue_validation_notes": "optional notes"
}
```
#### `POST /api/v1/tests/{id}/reopen`
Move a `rejected` test back to `draft`. Clears all validation fields.
---
### Test Templates
#### `GET /api/v1/test-templates`
List templates with filters: `source`, `platform`, `severity`, `search`, `mitre_technique_id`, `is_active`.
#### `POST /api/v1/test-templates` (Admin)
Create a custom template with `suggested_remediation`.
#### `GET /api/v1/test-templates/stats` (Admin)
Returns catalog statistics: total, active, inactive, by source, by platform.
#### `POST /api/v1/test-templates/{id}/toggle-active` (Admin)
Toggle a template's active/inactive status.
---
### Notifications
#### `GET /api/v1/notifications`
List notifications for the current user (paginated, newest first).
**Query Parameters:** `offset` (default 0), `limit` (default 20, max 100)
#### `GET /api/v1/notifications/unread-count`
Returns `{ "unread_count": N }`.
#### `PATCH /api/v1/notifications/{id}/read`
Mark a single notification as read.
#### `POST /api/v1/notifications/read-all`
Mark all notifications for the current user as read.
**Automatic Notifications:**
- `red_executing` → notifies creator
- `blue_evaluating` → notifies all blue_tech users
- `in_review` → notifies red_lead and blue_lead
- `rejected` → notifies creator
- `validated` → notifies creator
---
### Reports
#### `GET /api/v1/reports/coverage-summary`
Full technique coverage report as JSON. Includes summary and technique-by-technique breakdown.
**Filters:** `tactic`, `platform`
#### `GET /api/v1/reports/coverage-csv`
Downloadable CSV of coverage data.
**Filters:** `tactic`, `platform`
#### `GET /api/v1/reports/test-results`
Test results report with state and detection breakdowns.
**Filters:** `state`, `date_from` (ISO), `date_to` (ISO)
#### `GET /api/v1/reports/remediation-status`
Remediation status report across all tests with assigned steps.
**Filter:** `status` (pending, in_progress, completed, not_applicable)
---
### V2 Metrics
#### `GET /api/v1/metrics/test-pipeline`
Test counts by state across the pipeline.
#### `GET /api/v1/metrics/team-activity`
Red/Blue team activity: tests completed, pending.
#### `GET /api/v1/metrics/validation-rate`
Approval/rejection rates for Red Lead and Blue Lead.
#### `GET /api/v1/metrics/recent-tests`
Last 10 most recently updated tests.
---
## Error Responses
All errors follow a consistent format:
@@ -376,8 +583,22 @@ All errors follow a consistent format:
}
```
State transition errors include additional context:
```json
{
"detail": {
"message": "Cannot transition from 'draft' to 'validated'. Valid transitions: ['red_executing']",
"code": "INVALID_TRANSITION",
"current_state": "draft",
"target_state": "validated",
"valid_transitions": ["red_executing"]
}
}
```
Common HTTP status codes:
- `400` - Bad Request (validation error, invalid input)
- `400` - Bad Request (validation error, invalid transition, invalid input)
- `401` - Unauthorized (missing or invalid token)
- `403` - Forbidden (insufficient permissions)
- `404` - Not Found (resource doesn't exist)
+2
View File
@@ -7,6 +7,7 @@ import TestsPage from "./pages/TestsPage";
import TestCreatePage from "./pages/TestCreatePage";
import TestDetailPage from "./pages/TestDetailPage";
import TestCatalogPage from "./pages/TestCatalogPage";
import ReportsPage from "./pages/ReportsPage";
import SystemPage from "./pages/SystemPage";
import UsersPage from "./pages/UsersPage";
import AuditLogPage from "./pages/AuditLogPage";
@@ -35,6 +36,7 @@ export default function App() {
<Route path="/tests/:testId" element={<TestDetailPage />} />
<Route path="/test-catalog" element={<TestCatalogPage />} />
<Route path="/test-catalog/:templateId/use" element={<TestCatalogPage />} />
<Route path="/reports" element={<ReportsPage />} />
<Route
path="/system"
element={
+51
View File
@@ -0,0 +1,51 @@
import client from "./client";
// ── Types ───────────────────────────────────────────────────────────
export interface NotificationItem {
id: string;
user_id: string;
type: string;
title: string;
message: string | null;
entity_type: string | null;
entity_id: string | null;
read: boolean;
created_at: string | null;
}
export interface UnreadCount {
unread_count: number;
}
// ── API ─────────────────────────────────────────────────────────────
/** Fetch notifications for the current user (paginated). */
export async function getNotifications(
offset = 0,
limit = 20,
): Promise<NotificationItem[]> {
const { data } = await client.get<NotificationItem[]>(
`/notifications?offset=${offset}&limit=${limit}`,
);
return data;
}
/** Get the unread notification count. */
export async function getUnreadCount(): Promise<UnreadCount> {
const { data } = await client.get<UnreadCount>("/notifications/unread-count");
return data;
}
/** Mark a single notification as read. */
export async function markAsRead(id: string): Promise<NotificationItem> {
const { data } = await client.patch<NotificationItem>(
`/notifications/${id}/read`,
);
return data;
}
/** Mark all notifications as read. */
export async function markAllAsRead(): Promise<void> {
await client.post("/notifications/read-all");
}
+122
View File
@@ -0,0 +1,122 @@
import client from "./client";
// ── Types ───────────────────────────────────────────────────────────
export interface CoverageReportSummary {
total_techniques: number;
validated: number;
partial: number;
not_covered: number;
in_progress: number;
not_evaluated: number;
coverage_percentage: number;
}
export interface CoverageTechniqueRow {
mitre_id: string;
name: string;
tactic: string | null;
platforms: string[];
status_global: string;
total_tests: number;
tests_by_state: Record<string, number>;
}
export interface CoverageReport {
generated_at: string;
summary: CoverageReportSummary;
techniques: CoverageTechniqueRow[];
}
export interface TestResultsReport {
generated_at: string;
filters: Record<string, string | null>;
summary: {
total_tests: number;
by_state: Record<string, number>;
by_detection_result: Record<string, number>;
};
tests: Array<{
id: string;
name: string;
technique_id: string;
state: string;
platform: string | null;
attack_success: boolean | null;
detection_result: string | null;
red_validation_status: string | null;
blue_validation_status: string | null;
created_at: string | null;
}>;
}
export interface RemediationReport {
generated_at: string;
summary: {
total_with_remediation: number;
by_status: Record<string, number>;
};
tests: Array<{
id: string;
name: string;
technique_id: string;
state: string;
remediation_status: string | null;
remediation_steps: string | null;
remediation_assignee: string | null;
}>;
}
export interface ReportFilters {
tactic?: string;
platform?: string;
state?: string;
date_from?: string;
date_to?: string;
status?: string;
}
// ── API ─────────────────────────────────────────────────────────────
export async function getCoverageSummary(
filters?: ReportFilters,
): Promise<CoverageReport> {
const params = new URLSearchParams();
if (filters?.tactic) params.set("tactic", filters.tactic);
if (filters?.platform) params.set("platform", filters.platform);
const { data } = await client.get<CoverageReport>(
`/reports/coverage-summary?${params.toString()}`,
);
return data;
}
export function getCoverageCsvUrl(filters?: ReportFilters): string {
const params = new URLSearchParams();
if (filters?.tactic) params.set("tactic", filters.tactic);
if (filters?.platform) params.set("platform", filters.platform);
return `/api/v1/reports/coverage-csv?${params.toString()}`;
}
export async function getTestResults(
filters?: ReportFilters,
): Promise<TestResultsReport> {
const params = new URLSearchParams();
if (filters?.state) params.set("state", filters.state);
if (filters?.date_from) params.set("date_from", filters.date_from);
if (filters?.date_to) params.set("date_to", filters.date_to);
const { data } = await client.get<TestResultsReport>(
`/reports/test-results?${params.toString()}`,
);
return data;
}
export async function getRemediationStatus(
filters?: ReportFilters,
): Promise<RemediationReport> {
const params = new URLSearchParams();
if (filters?.status) params.set("status", filters.status);
const { data } = await client.get<RemediationReport>(
`/reports/remediation-status?${params.toString()}`,
);
return data;
}
+78
View File
@@ -0,0 +1,78 @@
import { AlertTriangle, Loader2 } from "lucide-react";
interface ConfirmDialogProps {
open: boolean;
title: string;
message: string;
confirmLabel?: string;
cancelLabel?: string;
variant?: "danger" | "warning" | "default";
isLoading?: boolean;
onConfirm: () => void;
onCancel: () => void;
}
const variantStyles = {
danger: {
icon: "text-red-400 bg-red-500/10",
button: "bg-red-600 hover:bg-red-500",
},
warning: {
icon: "text-yellow-400 bg-yellow-500/10",
button: "bg-yellow-600 hover:bg-yellow-500",
},
default: {
icon: "text-cyan-400 bg-cyan-500/10",
button: "bg-cyan-600 hover:bg-cyan-500",
},
};
export default function ConfirmDialog({
open,
title,
message,
confirmLabel = "Confirm",
cancelLabel = "Cancel",
variant = "default",
isLoading = false,
onConfirm,
onCancel,
}: ConfirmDialogProps) {
if (!open) return null;
const styles = variantStyles[variant];
return (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
<div className="w-full max-w-md rounded-xl border border-gray-800 bg-gray-900 p-6 shadow-2xl">
<div className="flex items-start gap-4">
<div className={`rounded-lg p-2 ${styles.icon}`}>
<AlertTriangle className="h-6 w-6" />
</div>
<div>
<h3 className="text-lg font-semibold text-white">{title}</h3>
<p className="mt-1 text-sm text-gray-400">{message}</p>
</div>
</div>
<div className="mt-6 flex justify-end gap-3">
<button
onClick={onCancel}
disabled={isLoading}
className="rounded-lg border border-gray-700 px-4 py-2 text-sm font-medium text-gray-300 transition-colors hover:bg-gray-800 disabled:opacity-50"
>
{cancelLabel}
</button>
<button
onClick={onConfirm}
disabled={isLoading}
className={`flex items-center gap-1.5 rounded-lg px-4 py-2 text-sm font-medium text-white transition-colors disabled:opacity-50 ${styles.button}`}
>
{isLoading && <Loader2 className="h-4 w-4 animate-spin" />}
{confirmLabel}
</button>
</div>
</div>
</div>
);
}
+2
View File
@@ -2,6 +2,7 @@ import { Outlet } from "react-router-dom";
import { LogOut } from "lucide-react";
import { useAuth } from "../context/AuthContext";
import Sidebar from "./Sidebar";
import NotificationBell from "./NotificationBell";
export default function Layout() {
const { user, logout } = useAuth();
@@ -13,6 +14,7 @@ export default function Layout() {
<div className="flex flex-1 flex-col overflow-hidden">
{/* Header */}
<header className="flex h-16 items-center justify-end gap-4 border-b border-gray-800 bg-gray-900 px-6">
<NotificationBell />
<span className="text-sm text-gray-300">{user?.username}</span>
<button
onClick={logout}
@@ -0,0 +1,53 @@
import { useState, useRef, useEffect } from "react";
import { useQuery, useQueryClient } from "@tanstack/react-query";
import { Bell } from "lucide-react";
import { getUnreadCount } from "../api/notifications";
import NotificationDropdown from "./NotificationDropdown";
export default function NotificationBell() {
const [open, setOpen] = useState(false);
const ref = useRef<HTMLDivElement>(null);
const queryClient = useQueryClient();
const { data } = useQuery({
queryKey: ["notifications", "unread-count"],
queryFn: getUnreadCount,
refetchInterval: 30000, // Poll every 30 seconds
});
const count = data?.unread_count ?? 0;
// Close dropdown on outside click
useEffect(() => {
function handleClick(e: MouseEvent) {
if (ref.current && !ref.current.contains(e.target as Node)) {
setOpen(false);
}
}
document.addEventListener("mousedown", handleClick);
return () => document.removeEventListener("mousedown", handleClick);
}, []);
return (
<div ref={ref} className="relative">
<button
onClick={() => {
setOpen(!open);
if (!open) {
queryClient.invalidateQueries({ queryKey: ["notifications"] });
}
}}
className="relative rounded-lg p-2 text-gray-400 transition-colors hover:bg-gray-800 hover:text-white"
>
<Bell className="h-5 w-5" />
{count > 0 && (
<span className="absolute -right-0.5 -top-0.5 flex h-4 min-w-[16px] items-center justify-center rounded-full bg-red-500 px-1 text-[10px] font-bold text-white">
{count > 99 ? "99+" : count}
</span>
)}
</button>
{open && <NotificationDropdown onClose={() => setOpen(false)} />}
</div>
);
}
@@ -0,0 +1,139 @@
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { useNavigate } from "react-router-dom";
import {
Loader2,
CheckCheck,
FlaskConical,
AlertTriangle,
CheckCircle,
XCircle,
Bell,
} from "lucide-react";
import {
getNotifications,
markAsRead,
markAllAsRead,
type NotificationItem,
} from "../api/notifications";
const typeIcons: Record<string, React.ReactNode> = {
test_assigned: <FlaskConical className="h-4 w-4 text-indigo-400" />,
validation_needed: <AlertTriangle className="h-4 w-4 text-yellow-400" />,
test_rejected: <XCircle className="h-4 w-4 text-red-400" />,
test_validated: <CheckCircle className="h-4 w-4 text-green-400" />,
test_state_changed: <Bell className="h-4 w-4 text-cyan-400" />,
};
export default function NotificationDropdown({ onClose }: { onClose: () => void }) {
const navigate = useNavigate();
const queryClient = useQueryClient();
const { data: notifications, isLoading } = useQuery({
queryKey: ["notifications", "list"],
queryFn: () => getNotifications(0, 20),
});
const markReadMutation = useMutation({
mutationFn: markAsRead,
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ["notifications"] });
},
});
const markAllMutation = useMutation({
mutationFn: markAllAsRead,
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ["notifications"] });
},
});
const handleClick = (notif: NotificationItem) => {
if (!notif.read) {
markReadMutation.mutate(notif.id);
}
if (notif.entity_type === "test" && notif.entity_id) {
navigate(`/tests/${notif.entity_id}`);
} else if (notif.entity_type === "technique" && notif.entity_id) {
navigate(`/techniques/${notif.entity_id}`);
}
onClose();
};
const formatTime = (dateStr: string | null) => {
if (!dateStr) return "";
const d = new Date(dateStr);
const now = new Date();
const diffMs = now.getTime() - d.getTime();
const diffMin = Math.floor(diffMs / 60000);
if (diffMin < 1) return "just now";
if (diffMin < 60) return `${diffMin}m ago`;
const diffH = Math.floor(diffMin / 60);
if (diffH < 24) return `${diffH}h ago`;
const diffD = Math.floor(diffH / 24);
return `${diffD}d ago`;
};
return (
<div className="absolute right-0 top-full z-50 mt-2 w-80 rounded-xl border border-gray-800 bg-gray-900 shadow-2xl">
{/* Header */}
<div className="flex items-center justify-between border-b border-gray-800 px-4 py-3">
<h3 className="text-sm font-semibold text-white">Notifications</h3>
<button
onClick={() => markAllMutation.mutate()}
disabled={markAllMutation.isPending}
className="flex items-center gap-1 text-xs text-cyan-400 hover:text-cyan-300 transition-colors"
>
<CheckCheck className="h-3.5 w-3.5" />
Mark all read
</button>
</div>
{/* List */}
<div className="max-h-80 overflow-y-auto">
{isLoading ? (
<div className="flex items-center justify-center py-8">
<Loader2 className="h-5 w-5 animate-spin text-cyan-400" />
</div>
) : notifications && notifications.length > 0 ? (
notifications.map((notif) => (
<button
key={notif.id}
onClick={() => handleClick(notif)}
className={`flex w-full items-start gap-3 px-4 py-3 text-left transition-colors hover:bg-gray-800/50 ${
!notif.read ? "bg-cyan-500/5" : ""
}`}
>
<div className="mt-0.5 flex-shrink-0">
{typeIcons[notif.type] || <Bell className="h-4 w-4 text-gray-400" />}
</div>
<div className="min-w-0 flex-1">
<p
className={`text-sm ${
notif.read ? "text-gray-400" : "font-medium text-white"
}`}
>
{notif.title}
</p>
{notif.message && (
<p className="mt-0.5 text-xs text-gray-500 truncate">
{notif.message}
</p>
)}
<p className="mt-1 text-[10px] text-gray-600">
{formatTime(notif.created_at)}
</p>
</div>
{!notif.read && (
<div className="mt-1.5 h-2 w-2 flex-shrink-0 rounded-full bg-cyan-400" />
)}
</button>
))
) : (
<div className="py-8 text-center text-sm text-gray-500">
No notifications yet
</div>
)}
</div>
</div>
);
}
+99 -23
View File
@@ -1,34 +1,110 @@
import { NavLink } from "react-router-dom";
import { useState } from "react";
import {
LayoutDashboard,
Shield,
FlaskConical,
BookOpen,
BarChart3,
Settings,
Users,
FileText,
ChevronDown,
ListChecks,
ClipboardList,
} from "lucide-react";
import { useAuth } from "../context/AuthContext";
const baseLinks = [
interface NavItem {
to: string;
label: string;
icon: React.FC<{ className?: string }>;
children?: NavItem[];
}
const mainLinks: NavItem[] = [
{ to: "/dashboard", label: "Dashboard", icon: LayoutDashboard },
{ to: "/techniques", label: "Techniques", icon: Shield },
{ to: "/tests", label: "Tests", icon: FlaskConical },
{ to: "/test-catalog", label: "Test Catalog", icon: BookOpen },
{ to: "/techniques", label: "ATT&CK Matrix", icon: Shield },
{
to: "/tests",
label: "Tests",
icon: FlaskConical,
children: [
{ to: "/tests", label: "All Tests", icon: ListChecks },
{ to: "/tests?view=pending", label: "My Pending Tasks", icon: ClipboardList },
{ to: "/test-catalog", label: "Test Catalog", icon: BookOpen },
],
},
{ to: "/reports", label: "Reports", icon: BarChart3 },
];
const adminLinks = [
const adminLinks: NavItem[] = [
{ to: "/users", label: "Users", icon: Users },
{ to: "/audit", label: "Audit Log", icon: FileText },
{ to: "/system", label: "System", icon: Settings },
];
function SidebarLink({ item }: { item: NavItem }) {
const [expanded, setExpanded] = useState(false);
if (item.children) {
return (
<div>
<button
onClick={() => setExpanded(!expanded)}
className="flex w-full items-center justify-between rounded-lg px-3 py-2.5 text-sm font-medium text-gray-400 transition-colors hover:bg-gray-800 hover:text-gray-200"
>
<span className="flex items-center gap-3">
<item.icon className="h-5 w-5" />
{item.label}
</span>
<ChevronDown className={`h-4 w-4 transition-transform ${expanded ? "rotate-180" : ""}`} />
</button>
{expanded && (
<div className="ml-4 mt-1 space-y-0.5 border-l border-gray-800 pl-3">
{item.children.map((child) => (
<NavLink
key={child.to + child.label}
to={child.to}
className={({ isActive }) =>
`flex items-center gap-3 rounded-lg px-3 py-2 text-sm transition-colors ${
isActive
? "bg-cyan-500/10 text-cyan-400"
: "text-gray-500 hover:bg-gray-800 hover:text-gray-200"
}`
}
>
<child.icon className="h-4 w-4" />
{child.label}
</NavLink>
))}
</div>
)}
</div>
);
}
return (
<NavLink
to={item.to}
className={({ isActive }) =>
`flex items-center gap-3 rounded-lg px-3 py-2.5 text-sm font-medium transition-colors ${
isActive
? "bg-cyan-500/10 text-cyan-400"
: "text-gray-400 hover:bg-gray-800 hover:text-gray-200"
}`
}
>
<item.icon className="h-5 w-5" />
{item.label}
</NavLink>
);
}
export default function Sidebar() {
const { user } = useAuth();
const isAdmin = user?.role === "admin";
const links = isAdmin ? [...baseLinks, ...adminLinks] : baseLinks;
return (
<aside className="flex h-screen w-60 flex-col border-r border-gray-800 bg-gray-900">
{/* Logo */}
@@ -39,24 +115,24 @@ export default function Sidebar() {
</span>
</div>
{/* Nav links */}
{/* Main nav */}
<nav className="flex-1 space-y-1 px-3 py-4">
{links.map(({ to, label, icon: Icon }) => (
<NavLink
key={to}
to={to}
className={({ isActive }) =>
`flex items-center gap-3 rounded-lg px-3 py-2.5 text-sm font-medium transition-colors ${
isActive
? "bg-cyan-500/10 text-cyan-400"
: "text-gray-400 hover:bg-gray-800 hover:text-gray-200"
}`
}
>
<Icon className="h-5 w-5" />
{label}
</NavLink>
{mainLinks.map((item) => (
<SidebarLink key={item.to + item.label} item={item} />
))}
{/* Admin section */}
{isAdmin && (
<>
<div className="my-3 border-t border-gray-800" />
<p className="mb-2 px-3 text-[10px] font-semibold uppercase tracking-widest text-gray-600">
Administration
</p>
{adminLinks.map((item) => (
<SidebarLink key={item.to} item={item} />
))}
</>
)}
</nav>
{/* Footer */}
+491
View File
@@ -0,0 +1,491 @@
import { useState } from "react";
import { useQuery } from "@tanstack/react-query";
import {
FileText,
Download,
BarChart3,
Shield,
Wrench,
Loader2,
Filter,
ChevronDown,
} from "lucide-react";
import {
getCoverageSummary,
getTestResults,
getRemediationStatus,
type CoverageReport,
type TestResultsReport,
type RemediationReport,
type ReportFilters,
} from "../api/reports";
type ReportType = "coverage" | "test-results" | "remediation";
const reportTypes: { id: ReportType; label: string; icon: React.ReactNode; desc: string }[] = [
{
id: "coverage",
label: "Coverage Summary",
icon: <Shield className="h-5 w-5" />,
desc: "Technique coverage status across the MITRE ATT&CK framework",
},
{
id: "test-results",
label: "Test Results",
icon: <BarChart3 className="h-5 w-5" />,
desc: "Detailed test execution results with state and detection breakdowns",
},
{
id: "remediation",
label: "Remediation Status",
icon: <Wrench className="h-5 w-5" />,
desc: "Remediation progress across all tests with assigned steps",
},
];
export default function ReportsPage() {
const [selectedType, setSelectedType] = useState<ReportType>("coverage");
const [filters, setFilters] = useState<ReportFilters>({});
const [showFilters, setShowFilters] = useState(false);
const coverageQuery = useQuery({
queryKey: ["reports", "coverage", filters],
queryFn: () => getCoverageSummary(filters),
enabled: selectedType === "coverage",
});
const testResultsQuery = useQuery({
queryKey: ["reports", "test-results", filters],
queryFn: () => getTestResults(filters),
enabled: selectedType === "test-results",
});
const remediationQuery = useQuery({
queryKey: ["reports", "remediation", filters],
queryFn: () => getRemediationStatus(filters),
enabled: selectedType === "remediation",
});
const isLoading =
(selectedType === "coverage" && coverageQuery.isLoading) ||
(selectedType === "test-results" && testResultsQuery.isLoading) ||
(selectedType === "remediation" && remediationQuery.isLoading);
const handleDownloadJson = () => {
let data: CoverageReport | TestResultsReport | RemediationReport | undefined;
if (selectedType === "coverage") data = coverageQuery.data;
if (selectedType === "test-results") data = testResultsQuery.data;
if (selectedType === "remediation") data = remediationQuery.data;
if (!data) return;
const blob = new Blob([JSON.stringify(data, null, 2)], { type: "application/json" });
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `aegis_${selectedType}_${new Date().toISOString().slice(0, 10)}.json`;
a.click();
URL.revokeObjectURL(url);
};
const handleDownloadCsv = () => {
const token = localStorage.getItem("token");
const params = new URLSearchParams();
if (filters.tactic) params.set("tactic", filters.tactic);
if (filters.platform) params.set("platform", filters.platform);
window.open(
`/api/v1/reports/coverage-csv?${params.toString()}${token ? `&token=${token}` : ""}`,
"_blank",
);
};
return (
<div className="space-y-6">
{/* Header */}
<div className="flex items-center justify-between">
<div>
<h1 className="text-2xl font-bold text-white">Reports</h1>
<p className="mt-1 text-sm text-gray-400">
Generate and download coverage, test results, and remediation reports
</p>
</div>
<div className="flex gap-2">
<button
onClick={handleDownloadJson}
disabled={isLoading}
className="flex items-center gap-2 rounded-lg bg-cyan-600 px-4 py-2 text-sm font-medium text-white hover:bg-cyan-500 transition-colors disabled:opacity-50"
>
<Download className="h-4 w-4" />
Download JSON
</button>
{selectedType === "coverage" && (
<button
onClick={handleDownloadCsv}
disabled={isLoading}
className="flex items-center gap-2 rounded-lg bg-gray-700 px-4 py-2 text-sm font-medium text-white hover:bg-gray-600 transition-colors disabled:opacity-50"
>
<FileText className="h-4 w-4" />
Download CSV
</button>
)}
</div>
</div>
{/* Report type selector */}
<div className="grid grid-cols-1 gap-4 sm:grid-cols-3">
{reportTypes.map((rt) => (
<button
key={rt.id}
onClick={() => {
setSelectedType(rt.id);
setFilters({});
}}
className={`flex items-start gap-3 rounded-xl border p-4 text-left transition-all ${
selectedType === rt.id
? "border-cyan-500/50 bg-cyan-500/10"
: "border-gray-800 bg-gray-900 hover:border-gray-700"
}`}
>
<div
className={`rounded-lg p-2 ${
selectedType === rt.id ? "bg-cyan-500/20 text-cyan-400" : "bg-gray-800 text-gray-400"
}`}
>
{rt.icon}
</div>
<div>
<p className={`text-sm font-medium ${selectedType === rt.id ? "text-cyan-400" : "text-white"}`}>
{rt.label}
</p>
<p className="mt-0.5 text-xs text-gray-500">{rt.desc}</p>
</div>
</button>
))}
</div>
{/* Filters */}
<div className="rounded-xl border border-gray-800 bg-gray-900">
<button
onClick={() => setShowFilters(!showFilters)}
className="flex w-full items-center justify-between px-4 py-3 text-sm font-medium text-gray-300 hover:text-white transition-colors"
>
<span className="flex items-center gap-2">
<Filter className="h-4 w-4" />
Filters
</span>
<ChevronDown className={`h-4 w-4 transition-transform ${showFilters ? "rotate-180" : ""}`} />
</button>
{showFilters && (
<div className="border-t border-gray-800 px-4 py-4">
<div className="grid grid-cols-1 gap-4 sm:grid-cols-3">
{(selectedType === "coverage" || selectedType === "test-results") && (
<>
<div>
<label className="block text-xs text-gray-400 mb-1">Tactic</label>
<input
type="text"
placeholder="e.g. execution"
value={filters.tactic || ""}
onChange={(e) => setFilters({ ...filters, tactic: e.target.value || undefined })}
className="w-full rounded-lg border border-gray-700 bg-gray-800 px-3 py-2 text-sm text-white"
/>
</div>
<div>
<label className="block text-xs text-gray-400 mb-1">Platform</label>
<input
type="text"
placeholder="e.g. windows"
value={filters.platform || ""}
onChange={(e) => setFilters({ ...filters, platform: e.target.value || undefined })}
className="w-full rounded-lg border border-gray-700 bg-gray-800 px-3 py-2 text-sm text-white"
/>
</div>
</>
)}
{selectedType === "test-results" && (
<>
<div>
<label className="block text-xs text-gray-400 mb-1">State</label>
<select
value={filters.state || ""}
onChange={(e) => setFilters({ ...filters, state: e.target.value || undefined })}
className="w-full rounded-lg border border-gray-700 bg-gray-800 px-3 py-2 text-sm text-white"
>
<option value="">All states</option>
<option value="draft">Draft</option>
<option value="red_executing">Red Executing</option>
<option value="blue_evaluating">Blue Evaluating</option>
<option value="in_review">In Review</option>
<option value="validated">Validated</option>
<option value="rejected">Rejected</option>
</select>
</div>
<div>
<label className="block text-xs text-gray-400 mb-1">Date From</label>
<input
type="date"
value={filters.date_from || ""}
onChange={(e) => setFilters({ ...filters, date_from: e.target.value || undefined })}
className="w-full rounded-lg border border-gray-700 bg-gray-800 px-3 py-2 text-sm text-white"
/>
</div>
<div>
<label className="block text-xs text-gray-400 mb-1">Date To</label>
<input
type="date"
value={filters.date_to || ""}
onChange={(e) => setFilters({ ...filters, date_to: e.target.value || undefined })}
className="w-full rounded-lg border border-gray-700 bg-gray-800 px-3 py-2 text-sm text-white"
/>
</div>
</>
)}
{selectedType === "remediation" && (
<div>
<label className="block text-xs text-gray-400 mb-1">Remediation Status</label>
<select
value={filters.status || ""}
onChange={(e) => setFilters({ ...filters, status: e.target.value || undefined })}
className="w-full rounded-lg border border-gray-700 bg-gray-800 px-3 py-2 text-sm text-white"
>
<option value="">All</option>
<option value="pending">Pending</option>
<option value="in_progress">In Progress</option>
<option value="completed">Completed</option>
<option value="not_applicable">Not Applicable</option>
</select>
</div>
)}
</div>
</div>
)}
</div>
{/* Report content */}
{isLoading ? (
<div className="flex items-center justify-center py-16">
<Loader2 className="h-8 w-8 animate-spin text-cyan-400" />
</div>
) : (
<>
{selectedType === "coverage" && coverageQuery.data && (
<CoverageReportView report={coverageQuery.data} />
)}
{selectedType === "test-results" && testResultsQuery.data && (
<TestResultsView report={testResultsQuery.data} />
)}
{selectedType === "remediation" && remediationQuery.data && (
<RemediationView report={remediationQuery.data} />
)}
</>
)}
</div>
);
}
// ── Sub-views ──────────────────────────────────────────────────────
function CoverageReportView({ report }: { report: CoverageReport }) {
const s = report.summary;
return (
<div className="space-y-4">
{/* Summary cards */}
<div className="grid grid-cols-2 gap-4 sm:grid-cols-6">
<StatCard label="Total" value={s.total_techniques} />
<StatCard label="Validated" value={s.validated} color="text-green-400" />
<StatCard label="Partial" value={s.partial} color="text-yellow-400" />
<StatCard label="In Progress" value={s.in_progress} color="text-blue-400" />
<StatCard label="Not Covered" value={s.not_covered} color="text-red-400" />
<StatCard label="Coverage" value={`${s.coverage_percentage}%`} color="text-cyan-400" />
</div>
{/* Table */}
<div className="overflow-hidden rounded-xl border border-gray-800">
<table className="w-full text-sm">
<thead className="bg-gray-900/50">
<tr>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">MITRE ID</th>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">Name</th>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">Tactic</th>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">Status</th>
<th className="px-4 py-3 text-right text-xs font-medium text-gray-400">Tests</th>
</tr>
</thead>
<tbody className="divide-y divide-gray-800">
{report.techniques.map((t) => (
<tr key={t.mitre_id} className="hover:bg-gray-900/30">
<td className="px-4 py-2.5 font-mono text-cyan-400">{t.mitre_id}</td>
<td className="px-4 py-2.5 text-white">{t.name}</td>
<td className="px-4 py-2.5 text-gray-400">{t.tactic}</td>
<td className="px-4 py-2.5">
<StatusBadge status={t.status_global} />
</td>
<td className="px-4 py-2.5 text-right text-gray-300">{t.total_tests}</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
);
}
function TestResultsView({ report }: { report: TestResultsReport }) {
const s = report.summary;
return (
<div className="space-y-4">
<div className="grid grid-cols-2 gap-4 sm:grid-cols-4">
<StatCard label="Total Tests" value={s.total_tests} />
<StatCard label="Validated" value={s.by_state.validated ?? 0} color="text-green-400" />
<StatCard label="In Review" value={s.by_state.in_review ?? 0} color="text-yellow-400" />
<StatCard label="Rejected" value={s.by_state.rejected ?? 0} color="text-red-400" />
</div>
{/* Detection results breakdown */}
{Object.keys(s.by_detection_result).length > 0 && (
<div className="rounded-xl border border-gray-800 bg-gray-900 p-4">
<h3 className="text-sm font-medium text-gray-300 mb-3">Detection Results</h3>
<div className="flex gap-4">
{Object.entries(s.by_detection_result).map(([key, val]) => (
<div key={key} className="text-center">
<p className="text-xl font-bold text-white">{val}</p>
<p className="text-xs text-gray-400">{key.replace(/_/g, " ")}</p>
</div>
))}
</div>
</div>
)}
{/* Table */}
<div className="overflow-hidden rounded-xl border border-gray-800">
<table className="w-full text-sm">
<thead className="bg-gray-900/50">
<tr>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">Name</th>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">State</th>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">Platform</th>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">Detection</th>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">Red Val.</th>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">Blue Val.</th>
</tr>
</thead>
<tbody className="divide-y divide-gray-800">
{report.tests.map((t) => (
<tr key={t.id} className="hover:bg-gray-900/30">
<td className="px-4 py-2.5 text-white">{t.name}</td>
<td className="px-4 py-2.5"><StatusBadge status={t.state} /></td>
<td className="px-4 py-2.5 text-gray-400">{t.platform || "—"}</td>
<td className="px-4 py-2.5 text-gray-300">{t.detection_result?.replace(/_/g, " ") || "—"}</td>
<td className="px-4 py-2.5"><ValidationBadge status={t.red_validation_status} /></td>
<td className="px-4 py-2.5"><ValidationBadge status={t.blue_validation_status} /></td>
</tr>
))}
</tbody>
</table>
</div>
</div>
);
}
function RemediationView({ report }: { report: RemediationReport }) {
const s = report.summary;
return (
<div className="space-y-4">
<div className="grid grid-cols-2 gap-4 sm:grid-cols-4">
<StatCard label="Total w/ Remediation" value={s.total_with_remediation} />
<StatCard label="Pending" value={s.by_status.pending ?? 0} color="text-yellow-400" />
<StatCard label="In Progress" value={s.by_status.in_progress ?? 0} color="text-blue-400" />
<StatCard label="Completed" value={s.by_status.completed ?? 0} color="text-green-400" />
</div>
<div className="overflow-hidden rounded-xl border border-gray-800">
<table className="w-full text-sm">
<thead className="bg-gray-900/50">
<tr>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">Name</th>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">Test State</th>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">Remediation Status</th>
<th className="px-4 py-3 text-left text-xs font-medium text-gray-400">Steps</th>
</tr>
</thead>
<tbody className="divide-y divide-gray-800">
{report.tests.map((t) => (
<tr key={t.id} className="hover:bg-gray-900/30">
<td className="px-4 py-2.5 text-white">{t.name}</td>
<td className="px-4 py-2.5"><StatusBadge status={t.state} /></td>
<td className="px-4 py-2.5"><RemediationBadge status={t.remediation_status} /></td>
<td className="px-4 py-2.5 max-w-xs truncate text-gray-400">
{t.remediation_steps || "—"}
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
);
}
// ── Shared components ──────────────────────────────────────────────
function StatCard({
label,
value,
color = "text-white",
}: {
label: string;
value: number | string;
color?: string;
}) {
return (
<div className="rounded-xl border border-gray-800 bg-gray-900 p-4">
<p className="text-xs text-gray-500">{label}</p>
<p className={`text-2xl font-bold ${color}`}>{value}</p>
</div>
);
}
const statusColors: Record<string, string> = {
validated: "bg-green-500/10 text-green-400 border-green-500/30",
partial: "bg-yellow-500/10 text-yellow-400 border-yellow-500/30",
in_progress: "bg-blue-500/10 text-blue-400 border-blue-500/30",
not_covered: "bg-red-500/10 text-red-400 border-red-500/30",
not_evaluated: "bg-gray-500/10 text-gray-400 border-gray-500/30",
draft: "bg-gray-500/10 text-gray-400 border-gray-500/30",
red_executing: "bg-orange-500/10 text-orange-400 border-orange-500/30",
blue_evaluating: "bg-indigo-500/10 text-indigo-400 border-indigo-500/30",
in_review: "bg-yellow-500/10 text-yellow-400 border-yellow-500/30",
rejected: "bg-red-500/10 text-red-400 border-red-500/30",
};
function StatusBadge({ status }: { status: string }) {
return (
<span className={`inline-flex rounded-full border px-2 py-0.5 text-xs font-medium ${statusColors[status] || statusColors.not_evaluated}`}>
{status.replace(/_/g, " ")}
</span>
);
}
function ValidationBadge({ status }: { status: string | null }) {
if (!status) return <span className="text-gray-600 text-xs"></span>;
const colors: Record<string, string> = {
approved: "text-green-400",
rejected: "text-red-400",
pending: "text-yellow-400",
};
return <span className={`text-xs font-medium ${colors[status] || "text-gray-400"}`}>{status}</span>;
}
function RemediationBadge({ status }: { status: string | null }) {
if (!status) return <span className="text-gray-600 text-xs"></span>;
const colors: Record<string, string> = {
pending: "bg-yellow-500/10 text-yellow-400 border-yellow-500/30",
in_progress: "bg-blue-500/10 text-blue-400 border-blue-500/30",
completed: "bg-green-500/10 text-green-400 border-green-500/30",
not_applicable: "bg-gray-500/10 text-gray-400 border-gray-500/30",
};
return (
<span className={`inline-flex rounded-full border px-2 py-0.5 text-xs font-medium ${colors[status] || colors.pending}`}>
{status.replace(/_/g, " ")}
</span>
);
}
+42 -11
View File
@@ -22,6 +22,7 @@ import type { TestResult, TeamSide, TestTimelineEntry } from "../types/models";
import TestDetailHeader from "../components/test-detail/TestDetailHeader";
import TeamTabs from "../components/test-detail/TeamTabs";
import ValidationModal from "../components/test-detail/ValidationModal";
import ConfirmDialog from "../components/ConfirmDialog";
// ── Page Component ─────────────────────────────────────────────────
@@ -38,6 +39,8 @@ export default function TestDetailPage() {
side: "red" | "blue";
}>({ open: false, side: "red" });
const [confirmReopen, setConfirmReopen] = useState(false);
const [redDraft, setRedDraft] = useState({
procedure_text: "",
tool_used: "",
@@ -96,7 +99,19 @@ export default function TestDetailPage() {
const showToast = useCallback((message: string, type: "success" | "error") => {
setToast({ message, type });
setTimeout(() => setToast(null), 3500);
setTimeout(() => setToast(null), 5000);
}, []);
/** Extract a user-friendly error message from Axios or generic errors. */
const extractError = useCallback((err: unknown): string => {
if (err && typeof err === "object" && "response" in err) {
const resp = (err as { response?: { data?: { detail?: string | { message?: string } } } }).response;
const detail = resp?.data?.detail;
if (typeof detail === "string") return detail;
if (detail && typeof detail === "object" && "message" in detail) return (detail as { message: string }).message;
}
if (err instanceof Error) return err.message;
return "An unexpected error occurred";
}, []);
const invalidateAll = useCallback(() => {
@@ -120,7 +135,7 @@ export default function TestDetailPage() {
invalidateAll();
showToast("Red Team fields saved", "success");
},
onError: (err: Error) => showToast(err.message, "error"),
onError: (err: unknown) => showToast(extractError(err), "error"),
});
const saveBlueMutation = useMutation({
@@ -133,7 +148,7 @@ export default function TestDetailPage() {
invalidateAll();
showToast("Blue Team fields saved", "success");
},
onError: (err: Error) => showToast(err.message, "error"),
onError: (err: unknown) => showToast(extractError(err), "error"),
});
// State transitions
@@ -143,7 +158,7 @@ export default function TestDetailPage() {
invalidateAll();
showToast("Test execution started", "success");
},
onError: (err: Error) => showToast(err.message, "error"),
onError: (err: unknown) => showToast(extractError(err), "error"),
});
const submitRedMutation = useMutation({
@@ -152,7 +167,7 @@ export default function TestDetailPage() {
invalidateAll();
showToast("Submitted to Blue Team", "success");
},
onError: (err: Error) => showToast(err.message, "error"),
onError: (err: unknown) => showToast(extractError(err), "error"),
});
const submitBlueMutation = useMutation({
@@ -161,7 +176,7 @@ export default function TestDetailPage() {
invalidateAll();
showToast("Submitted for review", "success");
},
onError: (err: Error) => showToast(err.message, "error"),
onError: (err: unknown) => showToast(extractError(err), "error"),
});
const validateRedLeadMutation = useMutation({
@@ -172,7 +187,7 @@ export default function TestDetailPage() {
setValidationModal({ open: false, side: "red" });
showToast("Red Lead validation submitted", "success");
},
onError: (err: Error) => showToast(err.message, "error"),
onError: (err: unknown) => showToast(extractError(err), "error"),
});
const validateBlueLeadMutation = useMutation({
@@ -183,16 +198,20 @@ export default function TestDetailPage() {
setValidationModal({ open: false, side: "blue" });
showToast("Blue Lead validation submitted", "success");
},
onError: (err: Error) => showToast(err.message, "error"),
onError: (err: unknown) => showToast(extractError(err), "error"),
});
const reopenMutation = useMutation({
mutationFn: () => reopenTest(testId!),
onSuccess: () => {
invalidateAll();
setConfirmReopen(false);
showToast("Test reopened", "success");
},
onError: (err: Error) => showToast(err.message, "error"),
onError: (err: unknown) => {
setConfirmReopen(false);
showToast(extractError(err), "error");
},
});
// Evidence upload
@@ -203,7 +222,7 @@ export default function TestDetailPage() {
invalidateAll();
showToast("Evidence uploaded", "success");
},
onError: (err: Error) => showToast(err.message, "error"),
onError: (err: unknown) => showToast(extractError(err), "error"),
});
// ── Handlers ───────────────────────────────────────────────────
@@ -322,7 +341,7 @@ export default function TestDetailPage() {
onSubmitRed={() => submitRedMutation.mutate()}
onSubmitBlue={() => submitBlueMutation.mutate()}
onOpenValidateModal={(side) => setValidationModal({ open: true, side })}
onReopen={() => reopenMutation.mutate()}
onReopen={() => setConfirmReopen(true)}
/>
{/* Content: Tabs + Sidebar */}
@@ -426,6 +445,18 @@ export default function TestDetailPage() {
</div>
</div>
{/* Confirm Reopen Dialog */}
<ConfirmDialog
open={confirmReopen}
title="Reopen Test"
message="This will move the test back to Draft state and clear all validation decisions. The Red/Blue workflow will need to be restarted. Are you sure?"
confirmLabel="Reopen"
variant="warning"
isLoading={reopenMutation.isPending}
onConfirm={() => reopenMutation.mutate()}
onCancel={() => setConfirmReopen(false)}
/>
{/* Validation Modal */}
{validationModal.open && (
<ValidationModal
+6
View File
@@ -86,6 +86,11 @@ export interface Test {
blue_validation_status: ValidationStatus | null;
blue_validation_notes: string | null;
// Remediation fields
remediation_steps: string | null;
remediation_status: string | null;
remediation_assignee: string | null;
// Technique info (populated in list endpoints)
technique_mitre_id: string | null;
technique_name: string | null;
@@ -125,6 +130,7 @@ export interface TestTemplate {
tool_suggested: string | null;
severity: string | null;
atomic_test_id: string | null;
suggested_remediation: string | null;
is_active: boolean;
created_at: string;
}