feat(phase-9): implement MVP polishing and closure
T-032: User management admin panel - backend users router with CRUD, frontend UsersPage with modals T-033: Audit log viewer - backend audit router with filters/pagination, frontend AuditLogPage T-034: Global error handling - ErrorBoundary, LoadingSpinner, ErrorMessage, Toast components T-035: Backend tests - pytest setup with SQLite, tests for health/auth/techniques/tests T-036: Documentation - Updated README with testing section, created docs/API.md
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from app.routers import auth as auth_router
|
||||
from app.routers import techniques as techniques_router
|
||||
@@ -10,6 +13,8 @@ from app.routers import tests as tests_router
|
||||
from app.routers import evidence as evidence_router
|
||||
from app.routers import system as system_router
|
||||
from app.routers import metrics as metrics_router
|
||||
from app.routers import users as users_router
|
||||
from app.routers import audit as audit_router
|
||||
from app.storage import ensure_bucket_exists
|
||||
from app.jobs.mitre_sync_job import start_scheduler, scheduler
|
||||
|
||||
@@ -47,8 +52,52 @@ app.include_router(tests_router.router, prefix="/api/v1")
|
||||
app.include_router(evidence_router.router, prefix="/api/v1")
|
||||
app.include_router(system_router.router, prefix="/api/v1")
|
||||
app.include_router(metrics_router.router, prefix="/api/v1")
|
||||
app.include_router(users_router.router, prefix="/api/v1")
|
||||
app.include_router(audit_router.router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ── Exception Handlers ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
"""Handle validation errors with consistent format."""
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
"detail": "Validation error",
|
||||
"code": "VALIDATION_ERROR",
|
||||
"errors": exc.errors(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(SQLAlchemyError)
|
||||
async def sqlalchemy_exception_handler(request: Request, exc: SQLAlchemyError):
|
||||
"""Handle database errors."""
|
||||
logging.error(f"Database error: {exc}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
"detail": "Database error occurred",
|
||||
"code": "DATABASE_ERROR",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
"""Handle all unhandled exceptions."""
|
||||
logging.error(f"Unhandled exception: {exc}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
"detail": "An internal server error occurred",
|
||||
"code": "INTERNAL_ERROR",
|
||||
},
|
||||
)
|
||||
|
||||
118
backend/app/routers/audit.py
Normal file
118
backend/app/routers/audit.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Audit log viewer router (admin only)."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import require_role
|
||||
from app.models.audit import AuditLog
|
||||
from app.models.user import User
|
||||
from app.schemas.audit import AuditLogOut, AuditLogPage
|
||||
|
||||
router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
||||
|
||||
|
||||
@router.get("", response_model=AuditLogPage)
|
||||
def list_audit_logs(
|
||||
user_id: Optional[str] = Query(None, description="Filter by user ID"),
|
||||
action: Optional[str] = Query(None, description="Filter by action type"),
|
||||
entity_type: Optional[str] = Query(None, description="Filter by entity type"),
|
||||
start_date: Optional[datetime] = Query(None, description="Filter by start date"),
|
||||
end_date: Optional[datetime] = Query(None, description="Filter by end date"),
|
||||
offset: int = Query(0, ge=0, description="Number of records to skip"),
|
||||
limit: int = Query(50, ge=1, le=100, description="Max records to return"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return paginated audit logs with optional filters.
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
query = db.query(AuditLog).options(joinedload(AuditLog.user))
|
||||
|
||||
# Apply filters
|
||||
if user_id:
|
||||
query = query.filter(AuditLog.user_id == user_id)
|
||||
if action:
|
||||
query = query.filter(AuditLog.action == action)
|
||||
if entity_type:
|
||||
query = query.filter(AuditLog.entity_type == entity_type)
|
||||
if start_date:
|
||||
query = query.filter(AuditLog.timestamp >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(AuditLog.timestamp <= end_date)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Get paginated results
|
||||
logs = (
|
||||
query
|
||||
.order_by(AuditLog.timestamp.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Convert to response format with username
|
||||
items = []
|
||||
for log in logs:
|
||||
item = AuditLogOut(
|
||||
id=log.id,
|
||||
user_id=log.user_id,
|
||||
username=log.user.username if log.user else None,
|
||||
action=log.action,
|
||||
entity_type=log.entity_type,
|
||||
entity_id=log.entity_id,
|
||||
timestamp=log.timestamp,
|
||||
details=log.details,
|
||||
)
|
||||
items.append(item)
|
||||
|
||||
return AuditLogPage(
|
||||
items=items,
|
||||
total=total,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/actions", response_model=list[str])
|
||||
def list_actions(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return a list of distinct action types in the audit log.
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
actions = (
|
||||
db.query(AuditLog.action)
|
||||
.distinct()
|
||||
.order_by(AuditLog.action)
|
||||
.all()
|
||||
)
|
||||
return [a[0] for a in actions]
|
||||
|
||||
|
||||
@router.get("/entity-types", response_model=list[str])
|
||||
def list_entity_types(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return a list of distinct entity types in the audit log.
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
types = (
|
||||
db.query(AuditLog.entity_type)
|
||||
.filter(AuditLog.entity_type.isnot(None))
|
||||
.distinct()
|
||||
.order_by(AuditLog.entity_type)
|
||||
.all()
|
||||
)
|
||||
return [t[0] for t in types]
|
||||
153
backend/app/routers/users.py
Normal file
153
backend/app/routers/users.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""User management router (admin only)."""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import require_role
|
||||
from app.models.user import User
|
||||
from app.schemas.user import UserCreate, UserUpdate, UserOut
|
||||
from app.auth import hash_password
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /users — list all users
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("", response_model=list[UserOut])
|
||||
def list_users(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return a list of all users. **Requires admin role.**"""
|
||||
return db.query(User).order_by(User.username).all()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /users — create a new user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
||||
def create_user(
|
||||
payload: UserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Create a new user. **Requires admin role.**"""
|
||||
|
||||
# Check if username already exists
|
||||
existing = db.query(User).filter(User.username == payload.username).first()
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Username '{payload.username}' already exists",
|
||||
)
|
||||
|
||||
# Validate role
|
||||
if payload.role not in VALID_ROLES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid role '{payload.role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
|
||||
)
|
||||
|
||||
user = User(
|
||||
username=payload.username,
|
||||
email=payload.email,
|
||||
hashed_password=hash_password(payload.password),
|
||||
role=payload.role,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="create_user",
|
||||
entity_type="user",
|
||||
entity_id=user.id,
|
||||
details={"username": user.username, "role": user.role},
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /users/{id} — get a single user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserOut)
|
||||
def get_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return a single user by ID. **Requires admin role.**"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /users/{id} — update a user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.patch("/{user_id}", response_model=UserOut)
|
||||
def update_user(
|
||||
user_id: uuid.UUID,
|
||||
payload: UserUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Update one or more fields of an existing user. **Requires admin role.**"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
|
||||
# Validate role if being updated
|
||||
if "role" in update_data and update_data["role"] not in VALID_ROLES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
|
||||
)
|
||||
|
||||
# Hash password if being updated
|
||||
if "password" in update_data:
|
||||
update_data["hashed_password"] = hash_password(update_data.pop("password"))
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(user, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="update_user",
|
||||
entity_type="user",
|
||||
entity_id=user.id,
|
||||
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
|
||||
)
|
||||
|
||||
return user
|
||||
31
backend/app/schemas/audit.py
Normal file
31
backend/app/schemas/audit.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Pydantic schemas for Audit Log endpoints."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class AuditLogOut(BaseModel):
|
||||
"""Complete representation of an audit log entry."""
|
||||
|
||||
id: uuid.UUID
|
||||
user_id: uuid.UUID | None = None
|
||||
username: str | None = None # Populated from user relationship
|
||||
action: str
|
||||
entity_type: str | None = None
|
||||
entity_id: str | None = None
|
||||
timestamp: datetime
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AuditLogPage(BaseModel):
|
||||
"""Paginated response for audit logs."""
|
||||
|
||||
items: list[AuditLogOut]
|
||||
total: int
|
||||
offset: int
|
||||
limit: int
|
||||
45
backend/app/schemas/user.py
Normal file
45
backend/app/schemas/user.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Pydantic schemas for User management endpoints."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr
|
||||
|
||||
|
||||
# ── Create ──────────────────────────────────────────────────────────
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
"""Payload for creating a new user."""
|
||||
|
||||
username: str
|
||||
email: str | None = None
|
||||
password: str
|
||||
role: str = "viewer"
|
||||
|
||||
|
||||
# ── Update ──────────────────────────────────────────────────────────
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
"""Payload for partially updating an existing user.
|
||||
Every field is optional so callers send only what changed."""
|
||||
|
||||
email: str | None = None
|
||||
role: str | None = None
|
||||
is_active: bool | None = None
|
||||
password: str | None = None
|
||||
|
||||
|
||||
# ── Read (full) ─────────────────────────────────────────────────────
|
||||
|
||||
class UserOut(BaseModel):
|
||||
"""Complete representation returned by the API."""
|
||||
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
email: str | None = None
|
||||
role: str
|
||||
is_active: bool
|
||||
created_at: datetime | None = None
|
||||
last_login: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
5
backend/pytest.ini
Normal file
5
backend/pytest.ini
Normal file
@@ -0,0 +1,5 @@
|
||||
[pytest]
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_functions = test_*
|
||||
addopts = -v --tb=short
|
||||
@@ -12,3 +12,8 @@ requests
|
||||
taxii2-client
|
||||
python-multipart
|
||||
pydantic-settings
|
||||
|
||||
# Testing
|
||||
pytest
|
||||
pytest-asyncio
|
||||
httpx
|
||||
|
||||
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests package
|
||||
118
backend/tests/conftest.py
Normal file
118
backend/tests/conftest.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Pytest fixtures and configuration for backend tests."""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.main import app
|
||||
from app.database import Base, get_db
|
||||
from app.auth import hash_password
|
||||
from app.models.user import User
|
||||
|
||||
# Use in-memory SQLite for tests
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def override_get_db():
|
||||
"""Override the database dependency for testing."""
|
||||
db = TestingSessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db():
|
||||
"""Create a fresh database for each test."""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = TestingSessionLocal()
|
||||
yield db
|
||||
db.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(db):
|
||||
"""Create a test client with database override."""
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def admin_user(db):
|
||||
"""Create an admin user for testing."""
|
||||
user = User(
|
||||
username="admin",
|
||||
email="admin@test.com",
|
||||
hashed_password=hash_password("admin123"),
|
||||
role="admin",
|
||||
is_active=True,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def red_tech_user(db):
|
||||
"""Create a red_tech user for testing."""
|
||||
user = User(
|
||||
username="redtech",
|
||||
email="redtech@test.com",
|
||||
hashed_password=hash_password("redtech123"),
|
||||
role="red_tech",
|
||||
is_active=True,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def admin_token(client, admin_user):
|
||||
"""Get an auth token for the admin user."""
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={"username": "admin", "password": "admin123"},
|
||||
)
|
||||
return response.json()["access_token"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def red_tech_token(client, red_tech_user):
|
||||
"""Get an auth token for the red_tech user."""
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={"username": "redtech", "password": "redtech123"},
|
||||
)
|
||||
return response.json()["access_token"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def auth_headers(admin_token):
|
||||
"""Return authorization headers for admin user."""
|
||||
return {"Authorization": f"Bearer {admin_token}"}
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def red_tech_headers(red_tech_token):
|
||||
"""Return authorization headers for red_tech user."""
|
||||
return {"Authorization": f"Bearer {red_tech_token}"}
|
||||
81
backend/tests/test_auth.py
Normal file
81
backend/tests/test_auth.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Tests for authentication endpoints."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_login_success(client, admin_user):
|
||||
"""Test successful login returns a token."""
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={"username": "admin", "password": "admin123"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
|
||||
def test_login_wrong_password(client, admin_user):
|
||||
"""Test login with wrong password returns 400."""
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={"username": "admin", "password": "wrongpassword"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_login_nonexistent_user(client):
|
||||
"""Test login with non-existent user returns 400."""
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={"username": "nobody", "password": "password"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_login_inactive_user(client, db):
|
||||
"""Test login with inactive user returns 400."""
|
||||
from app.auth import hash_password
|
||||
from app.models.user import User
|
||||
|
||||
user = User(
|
||||
username="inactive",
|
||||
hashed_password=hash_password("password"),
|
||||
role="viewer",
|
||||
is_active=False,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={"username": "inactive", "password": "password"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_get_me_with_token(client, admin_user, admin_token):
|
||||
"""Test /auth/me returns current user with valid token."""
|
||||
response = client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["username"] == "admin"
|
||||
assert data["role"] == "admin"
|
||||
|
||||
|
||||
def test_get_me_without_token(client):
|
||||
"""Test /auth/me returns 401 without token."""
|
||||
response = client.get("/api/v1/auth/me")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_get_me_invalid_token(client):
|
||||
"""Test /auth/me returns 401 with invalid token."""
|
||||
response = client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": "Bearer invalidtoken"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
8
backend/tests/test_health.py
Normal file
8
backend/tests/test_health.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Tests for the health endpoint."""
|
||||
|
||||
|
||||
def test_health_endpoint(client):
|
||||
"""Test that the health endpoint returns 200 OK."""
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
128
backend/tests/test_techniques.py
Normal file
128
backend/tests/test_techniques.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Tests for technique endpoints."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_list_techniques_requires_auth(client):
|
||||
"""Test that listing techniques requires authentication."""
|
||||
response = client.get("/api/v1/techniques")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_list_techniques_empty(client, auth_headers):
|
||||
"""Test listing techniques when none exist."""
|
||||
response = client.get("/api/v1/techniques", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
def test_create_technique_requires_admin(client, red_tech_headers):
|
||||
"""Test that creating techniques requires admin role."""
|
||||
response = client.post(
|
||||
"/api/v1/techniques",
|
||||
json={"mitre_id": "T1001", "name": "Test Technique"},
|
||||
headers=red_tech_headers,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_create_technique_success(client, auth_headers):
|
||||
"""Test successful technique creation."""
|
||||
response = client.post(
|
||||
"/api/v1/techniques",
|
||||
json={
|
||||
"mitre_id": "T1059",
|
||||
"name": "Command and Scripting Interpreter",
|
||||
"description": "Test description",
|
||||
"tactic": "execution",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["mitre_id"] == "T1059"
|
||||
assert data["name"] == "Command and Scripting Interpreter"
|
||||
assert data["status_global"] == "not_evaluated"
|
||||
|
||||
|
||||
def test_create_duplicate_technique(client, auth_headers):
|
||||
"""Test creating a technique with duplicate mitre_id fails."""
|
||||
# Create first technique
|
||||
client.post(
|
||||
"/api/v1/techniques",
|
||||
json={"mitre_id": "T1001", "name": "First"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
# Try to create duplicate
|
||||
response = client.post(
|
||||
"/api/v1/techniques",
|
||||
json={"mitre_id": "T1001", "name": "Second"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
def test_get_technique_by_mitre_id(client, auth_headers):
|
||||
"""Test getting a single technique by mitre_id."""
|
||||
# Create a technique first
|
||||
client.post(
|
||||
"/api/v1/techniques",
|
||||
json={"mitre_id": "T1059", "name": "Test Technique"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
# Get it by mitre_id
|
||||
response = client.get("/api/v1/techniques/T1059", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["mitre_id"] == "T1059"
|
||||
|
||||
|
||||
def test_get_nonexistent_technique(client, auth_headers):
|
||||
"""Test getting a non-existent technique returns 404."""
|
||||
response = client.get("/api/v1/techniques/T9999", headers=auth_headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_update_technique(client, auth_headers):
|
||||
"""Test updating a technique."""
|
||||
# Create technique
|
||||
client.post(
|
||||
"/api/v1/techniques",
|
||||
json={"mitre_id": "T1059", "name": "Original Name"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
# Update it
|
||||
response = client.patch(
|
||||
"/api/v1/techniques/T1059",
|
||||
json={"name": "Updated Name", "description": "New description"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Updated Name"
|
||||
|
||||
|
||||
def test_filter_techniques_by_tactic(client, auth_headers):
|
||||
"""Test filtering techniques by tactic."""
|
||||
# Create techniques in different tactics
|
||||
client.post(
|
||||
"/api/v1/techniques",
|
||||
json={"mitre_id": "T1001", "name": "Exec", "tactic": "execution"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
client.post(
|
||||
"/api/v1/techniques",
|
||||
json={"mitre_id": "T1002", "name": "Persist", "tactic": "persistence"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
# Filter by execution
|
||||
response = client.get(
|
||||
"/api/v1/techniques?tactic=execution",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
techniques = response.json()
|
||||
assert len(techniques) == 1
|
||||
assert techniques[0]["mitre_id"] == "T1001"
|
||||
165
backend/tests/test_tests.py
Normal file
165
backend/tests/test_tests.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Tests for security test endpoints."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def technique(client, auth_headers):
|
||||
"""Create a technique for test association."""
|
||||
response = client.post(
|
||||
"/api/v1/techniques",
|
||||
json={"mitre_id": "T1059", "name": "Test Technique"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
return response.json()
|
||||
|
||||
|
||||
def test_create_test_requires_auth(client, technique):
|
||||
"""Test that creating a test requires authentication."""
|
||||
response = client.post(
|
||||
"/api/v1/tests",
|
||||
json={
|
||||
"technique_id": technique["id"],
|
||||
"name": "Test Name",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_create_test_success(client, red_tech_headers, technique):
|
||||
"""Test successful test creation."""
|
||||
response = client.post(
|
||||
"/api/v1/tests",
|
||||
json={
|
||||
"technique_id": technique["id"],
|
||||
"name": "My Security Test",
|
||||
"description": "Test description",
|
||||
"platform": "windows",
|
||||
},
|
||||
headers=red_tech_headers,
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == "My Security Test"
|
||||
assert data["state"] == "draft"
|
||||
assert data["technique_id"] == technique["id"]
|
||||
|
||||
|
||||
def test_create_test_nonexistent_technique(client, red_tech_headers):
|
||||
"""Test creating a test with non-existent technique fails."""
|
||||
response = client.post(
|
||||
"/api/v1/tests",
|
||||
json={
|
||||
"technique_id": "00000000-0000-0000-0000-000000000000",
|
||||
"name": "Test",
|
||||
},
|
||||
headers=red_tech_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_get_test_by_id(client, red_tech_headers, technique):
|
||||
"""Test getting a test by ID."""
|
||||
# Create a test
|
||||
create_response = client.post(
|
||||
"/api/v1/tests",
|
||||
json={"technique_id": technique["id"], "name": "Test"},
|
||||
headers=red_tech_headers,
|
||||
)
|
||||
test_id = create_response.json()["id"]
|
||||
|
||||
# Get it
|
||||
response = client.get(f"/api/v1/tests/{test_id}", headers=red_tech_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == test_id
|
||||
|
||||
|
||||
def test_validate_test(client, auth_headers, red_tech_headers, technique):
|
||||
"""Test validating a test updates status correctly."""
|
||||
# Create a test
|
||||
create_response = client.post(
|
||||
"/api/v1/tests",
|
||||
json={"technique_id": technique["id"], "name": "Test"},
|
||||
headers=red_tech_headers,
|
||||
)
|
||||
test_id = create_response.json()["id"]
|
||||
|
||||
# Validate it (requires lead/admin)
|
||||
response = client.post(
|
||||
f"/api/v1/tests/{test_id}/validate",
|
||||
json={"result": "detected"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["state"] == "validated"
|
||||
assert data["result"] == "detected"
|
||||
assert data["validated_by"] is not None
|
||||
|
||||
|
||||
def test_validate_test_updates_technique_status(client, auth_headers, red_tech_headers, technique):
|
||||
"""Test that validating a test recalculates technique status."""
|
||||
# Create and validate a test
|
||||
create_response = client.post(
|
||||
"/api/v1/tests",
|
||||
json={"technique_id": technique["id"], "name": "Test"},
|
||||
headers=red_tech_headers,
|
||||
)
|
||||
test_id = create_response.json()["id"]
|
||||
|
||||
client.post(
|
||||
f"/api/v1/tests/{test_id}/validate",
|
||||
json={"result": "detected"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
# Check technique status was updated
|
||||
response = client.get(
|
||||
f"/api/v1/techniques/{technique['mitre_id']}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.json()["status_global"] == "validated"
|
||||
|
||||
|
||||
def test_reject_test(client, auth_headers, red_tech_headers, technique):
|
||||
"""Test rejecting a test."""
|
||||
# Create a test
|
||||
create_response = client.post(
|
||||
"/api/v1/tests",
|
||||
json={"technique_id": technique["id"], "name": "Test"},
|
||||
headers=red_tech_headers,
|
||||
)
|
||||
test_id = create_response.json()["id"]
|
||||
|
||||
# Reject it
|
||||
response = client.post(
|
||||
f"/api/v1/tests/{test_id}/reject",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["state"] == "rejected"
|
||||
|
||||
|
||||
def test_update_test_only_in_draft(client, auth_headers, red_tech_headers, technique):
|
||||
"""Test that tests can only be updated when in draft/rejected state."""
|
||||
# Create and validate a test
|
||||
create_response = client.post(
|
||||
"/api/v1/tests",
|
||||
json={"technique_id": technique["id"], "name": "Test"},
|
||||
headers=red_tech_headers,
|
||||
)
|
||||
test_id = create_response.json()["id"]
|
||||
|
||||
client.post(
|
||||
f"/api/v1/tests/{test_id}/validate",
|
||||
json={"result": "detected"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
# Try to update validated test
|
||||
response = client.patch(
|
||||
f"/api/v1/tests/{test_id}",
|
||||
json={"name": "New Name"},
|
||||
headers=red_tech_headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
Reference in New Issue
Block a user