feat(refactor): PEP8, type annotations, docstrings and PyJWT security fix
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""FastAPI router modules — one router per feature domain."""
|
||||
|
||||
@@ -1,50 +1,81 @@
|
||||
"""Advanced metrics endpoints — coverage by tactic, never-tested, avg validation time."""
|
||||
|
||||
# Import APIRouter, Depends from fastapi
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import advanced_metrics_service from app.services
|
||||
from app.services import advanced_metrics_service
|
||||
|
||||
# Assign router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
|
||||
router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/coverage-by-tactic")
|
||||
# Define function coverage_by_tactic
|
||||
def coverage_by_tactic(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
|
||||
# Return advanced_metrics_service.get_coverage_by_tactic(db)
|
||||
return advanced_metrics_service.get_coverage_by_tactic(db)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/never-tested")
|
||||
# Define function never_tested_techniques
|
||||
def never_tested_techniques(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Techniques that have never had a test created."""
|
||||
# Return advanced_metrics_service.get_never_tested_techniques(db)
|
||||
return advanced_metrics_service.get_never_tested_techniques(db)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/avg-validation-time")
|
||||
# Define function avg_validation_time
|
||||
def avg_validation_time(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Average time from test creation to validation, computed from audit logs.
|
||||
|
||||
Returns overall average and per-phase averages where data is available.
|
||||
"""
|
||||
# Return advanced_metrics_service.get_avg_validation_time(db)
|
||||
return advanced_metrics_service.get_avg_validation_time(db)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/detection-rate-trend")
|
||||
# Define function detection_rate_trend
|
||||
def detection_rate_trend(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Monthly detection rate trend for the last 12 months."""
|
||||
# Return advanced_metrics_service.get_detection_rate_trend(db)
|
||||
return advanced_metrics_service.get_detection_rate_trend(db)
|
||||
|
||||
@@ -4,52 +4,85 @@ Returns complete datasets without pagination so BI tools can ingest
|
||||
directly from URL. All endpoints require authentication.
|
||||
"""
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import analytics_service from app.services
|
||||
from app.services import analytics_service
|
||||
|
||||
# Assign router = APIRouter(prefix="/analytics", tags=["analytics"])
|
||||
router = APIRouter(prefix="/analytics", tags=["analytics"])
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/coverage")
|
||||
# Define function analytics_coverage
|
||||
def analytics_coverage(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Coverage per technique — flat format for BI dashboards."""
|
||||
# Return analytics_service.get_coverage_analytics(db)
|
||||
return analytics_service.get_coverage_analytics(db)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/tests")
|
||||
# Define function analytics_tests
|
||||
def analytics_tests(
|
||||
# Entry: date_from
|
||||
date_from: str = Query(None, description="ISO date filter (>=)"),
|
||||
# Entry: date_to
|
||||
date_to: str = Query(None, description="ISO date filter (<=)"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""All tests with timestamps — flat format for BI dashboards."""
|
||||
# Return analytics_service.get_tests_analytics(
|
||||
return analytics_service.get_tests_analytics(
|
||||
db, date_from=date_from, date_to=date_to
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/trends")
|
||||
# Define function analytics_trends
|
||||
def analytics_trends(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Historical coverage snapshots for trend visualization."""
|
||||
# Return analytics_service.get_trends_analytics(db)
|
||||
return analytics_service.get_trends_analytics(db)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/operators")
|
||||
# Define function analytics_operators
|
||||
def analytics_operators(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> list:
|
||||
"""Per-operator metrics — for workload management dashboards."""
|
||||
# Return analytics_service.get_operators_analytics(db)
|
||||
return analytics_service.get_operators_analytics(db)
|
||||
|
||||
@@ -1,77 +1,127 @@
|
||||
"""Audit log viewer router (admin only)."""
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import require_role
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import AuditLogOut, AuditLogPage from app.schemas.audit
|
||||
from app.schemas.audit import AuditLogOut, AuditLogPage
|
||||
|
||||
# Import from app.services.audit_query_service
|
||||
from app.services.audit_query_service import (
|
||||
list_distinct_actions,
|
||||
list_distinct_entity_types,
|
||||
list_logs,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
||||
router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("", response_model=AuditLogPage)
|
||||
# Define function list_audit_logs
|
||||
def list_audit_logs(
|
||||
# Entry: user_id
|
||||
user_id: Optional[str] = Query(None, description="Filter by user ID"),
|
||||
# Entry: action
|
||||
action: Optional[str] = Query(None, description="Filter by action type"),
|
||||
# Entry: entity_type
|
||||
entity_type: Optional[str] = Query(None, description="Filter by entity type"),
|
||||
# Entry: start_date
|
||||
start_date: Optional[datetime] = Query(None, description="Filter by start date"),
|
||||
# Entry: end_date
|
||||
end_date: Optional[datetime] = Query(None, description="Filter by end date"),
|
||||
# Entry: offset
|
||||
offset: int = Query(0, ge=0, description="Number of records to skip"),
|
||||
# Entry: limit
|
||||
limit: int = Query(50, ge=1, le=100, description="Max records to return"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> AuditLogPage:
|
||||
"""Return paginated audit logs with optional filters.
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
# Assign result = list_logs(
|
||||
result = list_logs(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user_id,
|
||||
# Keyword argument: action
|
||||
action=action,
|
||||
# Keyword argument: entity_type
|
||||
entity_type=entity_type,
|
||||
# Keyword argument: start_date
|
||||
start_date=start_date,
|
||||
# Keyword argument: end_date
|
||||
end_date=end_date,
|
||||
# Keyword argument: offset
|
||||
offset=offset,
|
||||
# Keyword argument: limit
|
||||
limit=limit,
|
||||
)
|
||||
# Return AuditLogPage(
|
||||
return AuditLogPage(
|
||||
# Keyword argument: items
|
||||
items=[AuditLogOut(**item) for item in result["items"]],
|
||||
# Keyword argument: total
|
||||
total=result["total"],
|
||||
# Keyword argument: offset
|
||||
offset=result["offset"],
|
||||
# Keyword argument: limit
|
||||
limit=result["limit"],
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/actions", response_model=list[str])
|
||||
# Define function list_actions
|
||||
def list_actions(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> list[str]:
|
||||
"""Return a list of distinct action types in the audit log.
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
# Return list_distinct_actions(db)
|
||||
return list_distinct_actions(db)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/entity-types", response_model=list[str])
|
||||
# Define function list_entity_types
|
||||
def list_entity_types(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> list[str]:
|
||||
"""Return a list of distinct entity types in the audit log.
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
# Return list_distinct_entity_types(db)
|
||||
return list_distinct_entity_types(db)
|
||||
|
||||
+133
-12
@@ -7,31 +7,68 @@ the token in the body for backwards compatibility and for clients that
|
||||
cannot use cookies (e.g. Swagger UI).
|
||||
"""
|
||||
|
||||
# Import os
|
||||
import os
|
||||
|
||||
# Import APIRouter, Cookie, Depends, Request, Response from fastapi
|
||||
from fastapi import APIRouter, Cookie, Depends, Request, Response
|
||||
|
||||
# Import OAuth2PasswordRequestForm from fastapi.security
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
# Import jwt (PyJWT)
|
||||
import jwt
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from jose import jwt, JWTError
|
||||
# Import blacklist_token, create_access_token, verify_pa... from app.auth
|
||||
from app.auth import blacklist_token, create_access_token, verify_password
|
||||
|
||||
from app.auth import create_access_token, blacklist_token, verify_password
|
||||
# Import settings from app.config
|
||||
from app.config import settings
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user
|
||||
|
||||
# Import BusinessRuleViolation, PermissionViolation from app.domain.errors
|
||||
from app.domain.errors import BusinessRuleViolation, PermissionViolation
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import limiter from app.limiter
|
||||
from app.limiter import limiter
|
||||
|
||||
# Import resolve_client_ip from app.middleware.request_context
|
||||
from app.middleware.request_context import resolve_client_ip
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import (
|
||||
_DUMMY_HASH,
|
||||
change_password as auth_change_password,
|
||||
)
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import TokenResponse, UserOut from app.schemas.auth
|
||||
from app.schemas.auth import TokenResponse, UserOut
|
||||
|
||||
# Import PasswordChange from app.schemas.user
|
||||
from app.schemas.user import PasswordChange
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import from app.services.auth_service
|
||||
from app.services.auth_service import (
|
||||
_DUMMY_HASH,
|
||||
)
|
||||
|
||||
# Import from app.services.auth_service
|
||||
from app.services.auth_service import (
|
||||
change_password as auth_change_password,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
# SECURE_COOKIES desacopla la seguridad de la cookie del entorno de ejecucion.
|
||||
@@ -47,111 +84,182 @@ else: # "auto" — activo solo si AEGIS_ENV=production
|
||||
_COOKIE_NAME = "aegis_token"
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
# Apply the @limiter.limit decorator
|
||||
@limiter.limit("5/minute")
|
||||
# Define function login
|
||||
def login(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: response
|
||||
response: Response,
|
||||
# Entry: form_data
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
) -> TokenResponse:
|
||||
"""Authenticate a user and return a JWT access token.
|
||||
|
||||
Rate-limited to **5 attempts per minute per IP**. Failed and successful
|
||||
logins are recorded in the audit log (SEC-009).
|
||||
"""
|
||||
# Assign user = db.query(User).filter(User.username == form_data.username).first()
|
||||
user = db.query(User).filter(User.username == form_data.username).first()
|
||||
# Assign target_hash = user.hashed_password if user else _DUMMY_HASH
|
||||
target_hash = user.hashed_password if user else _DUMMY_HASH
|
||||
# Assign password_valid = verify_password(form_data.password, target_hash)
|
||||
password_valid = verify_password(form_data.password, target_hash)
|
||||
# Assign ip = resolve_client_ip(request)
|
||||
ip = resolve_client_ip(request)
|
||||
|
||||
# Check: user is None or not password_valid
|
||||
if user is None or not password_valid:
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
user.id if user else None,
|
||||
# Literal argument value
|
||||
"LOGIN_FAILED",
|
||||
# Literal argument value
|
||||
"auth",
|
||||
# Literal argument value
|
||||
None,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"username": form_data.username,
|
||||
# Literal argument value
|
||||
"ip": ip,
|
||||
# Literal argument value
|
||||
"reason": "invalid_credentials",
|
||||
},
|
||||
# Keyword argument: ip_address
|
||||
ip_address=ip,
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Incorrect username or password")
|
||||
|
||||
# Check: not user.is_active
|
||||
if not user.is_active:
|
||||
# Raise PermissionViolation
|
||||
raise PermissionViolation("Account is disabled. Contact an administrator.")
|
||||
|
||||
# Assign access_token = create_access_token(data={"sub": user.username})
|
||||
access_token = create_access_token(data={"sub": user.username})
|
||||
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
user.id,
|
||||
# Literal argument value
|
||||
"LOGIN_SUCCESS",
|
||||
# Literal argument value
|
||||
"auth",
|
||||
str(user.id),
|
||||
# Keyword argument: details
|
||||
details={"username": user.username, "ip": ip},
|
||||
# Keyword argument: ip_address
|
||||
ip_address=ip,
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Call response.set_cookie()
|
||||
response.set_cookie(
|
||||
# Keyword argument: key
|
||||
key=_COOKIE_NAME,
|
||||
# Keyword argument: value
|
||||
value=access_token,
|
||||
# Keyword argument: httponly
|
||||
httponly=True,
|
||||
# Keyword argument: secure
|
||||
secure=_IS_HTTPS,
|
||||
# Keyword argument: samesite
|
||||
samesite="strict",
|
||||
# Keyword argument: max_age
|
||||
max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
# Keyword argument: path
|
||||
path="/",
|
||||
)
|
||||
|
||||
# Return TokenResponse(access_token=access_token)
|
||||
return TokenResponse(access_token=access_token)
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/logout")
|
||||
# Define function logout
|
||||
def logout(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: response
|
||||
response: Response,
|
||||
# Entry: aegis_token
|
||||
aegis_token: str | None = Cookie(None),
|
||||
):
|
||||
) -> dict:
|
||||
"""Clear the authentication cookie and revoke the current token."""
|
||||
# Assign bearer = (
|
||||
bearer = (
|
||||
request.headers.get("Authorization")
|
||||
or request.headers.get("authorization")
|
||||
or ""
|
||||
)
|
||||
# Assign bearer = bearer.removeprefix("Bearer ").removeprefix("bearer ").strip()
|
||||
bearer = bearer.removeprefix("Bearer ").removeprefix("bearer ").strip()
|
||||
|
||||
# Assign seen = set()
|
||||
seen: set[str] = set()
|
||||
# Iterate over (aegis_token, bearer)
|
||||
for raw in (aegis_token, bearer):
|
||||
# Check: not raw or raw in seen
|
||||
if not raw or raw in seen:
|
||||
# Skip to the next loop iteration
|
||||
continue
|
||||
# Call seen.add()
|
||||
seen.add(raw)
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign payload = jwt.decode(
|
||||
payload = jwt.decode(
|
||||
raw,
|
||||
settings.SECRET_KEY,
|
||||
# Keyword argument: algorithms
|
||||
algorithms=[settings.ALGORITHM],
|
||||
)
|
||||
# Assign jti = payload.get("jti")
|
||||
jti = payload.get("jti")
|
||||
# Assign exp = payload.get("exp", 0)
|
||||
exp = payload.get("exp", 0)
|
||||
# Check: jti
|
||||
if jti:
|
||||
# Call blacklist_token()
|
||||
blacklist_token(jti, float(exp))
|
||||
except JWTError:
|
||||
# Handle any JWT validation error during logout (token may be expired or malformed)
|
||||
except jwt.exceptions.InvalidTokenError:
|
||||
# Intentional no-op placeholder
|
||||
pass
|
||||
|
||||
# Call response.delete_cookie()
|
||||
response.delete_cookie(
|
||||
# Keyword argument: key
|
||||
key=_COOKIE_NAME,
|
||||
# Keyword argument: httponly
|
||||
httponly=True,
|
||||
# Keyword argument: secure
|
||||
secure=_IS_HTTPS,
|
||||
# Keyword argument: samesite
|
||||
samesite="strict",
|
||||
# Keyword argument: path
|
||||
path="/",
|
||||
)
|
||||
# Return {"detail": "Logged out"}
|
||||
return {"detail": "Logged out"}
|
||||
|
||||
|
||||
@@ -207,25 +315,38 @@ def refresh_token(
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
def read_current_user(current_user: User = Depends(get_current_user)):
|
||||
# Define function read_current_user
|
||||
def read_current_user(current_user: User = Depends(get_current_user)) -> UserOut:
|
||||
"""Return the profile of the currently authenticated user."""
|
||||
# Return current_user
|
||||
return current_user
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/change-password")
|
||||
# Define function change_password
|
||||
def change_password(
|
||||
# Entry: body
|
||||
body: PasswordChange,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Change the current user's password."""
|
||||
# Call auth_change_password()
|
||||
auth_change_password(
|
||||
db,
|
||||
current_user,
|
||||
# Keyword argument: current_password
|
||||
current_password=body.current_password,
|
||||
# Keyword argument: new_password
|
||||
new_password=body.new_password,
|
||||
)
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return {"detail": "Password changed successfully"}
|
||||
return {"detail": "Password changed successfully"}
|
||||
|
||||
@@ -1,80 +1,169 @@
|
||||
"""Campaign endpoints — CRUD, test management, activation, and auto-generation.
|
||||
|
||||
Provides comprehensive campaign lifecycle management including
|
||||
test ordering, progress tracking, and threat actor integration.
|
||||
Provides comprehensive campaign lifecycle management including test ordering,
|
||||
progress tracking, and threat actor integration.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import BaseModel, Field from pydantic
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_any_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.test import Test
|
||||
from app.services.campaign_service import generate_campaign_from_threat_actor
|
||||
from app.services.campaign_crud_service import (
|
||||
add_test_to_campaign as crud_add_test,
|
||||
activate_campaign as crud_activate,
|
||||
)
|
||||
|
||||
# Import from app.services.campaign_crud_service
|
||||
from app.services.campaign_crud_service import (
|
||||
complete_campaign as crud_complete,
|
||||
)
|
||||
|
||||
# Import from app.services.campaign_crud_service
|
||||
from app.services.campaign_crud_service import (
|
||||
create_campaign as crud_create,
|
||||
delete_campaign as crud_delete,
|
||||
get_campaign_detail as crud_get_detail,
|
||||
)
|
||||
|
||||
# Import from app.services.campaign_crud_service
|
||||
from app.services.campaign_crud_service import (
|
||||
get_campaign_history as crud_get_history,
|
||||
)
|
||||
|
||||
# Import from app.services.campaign_crud_service
|
||||
from app.services.campaign_crud_service import (
|
||||
get_campaign_progress_data as crud_get_progress,
|
||||
)
|
||||
|
||||
# Import from app.services.campaign_crud_service
|
||||
from app.services.campaign_crud_service import (
|
||||
list_campaigns as crud_list,
|
||||
)
|
||||
|
||||
# Import from app.services.campaign_crud_service
|
||||
from app.services.campaign_crud_service import (
|
||||
remove_test_from_campaign as crud_remove_test,
|
||||
)
|
||||
|
||||
# Import from app.services.campaign_crud_service
|
||||
from app.services.campaign_crud_service import (
|
||||
schedule_campaign as crud_schedule,
|
||||
)
|
||||
|
||||
# Import from app.services.campaign_crud_service
|
||||
from app.services.campaign_crud_service import (
|
||||
serialize_campaign,
|
||||
)
|
||||
|
||||
# Import from app.services.campaign_crud_service
|
||||
from app.services.campaign_crud_service import (
|
||||
update_campaign as crud_update,
|
||||
)
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import generate_campaign_from_threat_actor from app.services.campaign_service
|
||||
from app.services.campaign_service import generate_campaign_from_threat_actor
|
||||
|
||||
# Import notify_role from app.services.notification_service
|
||||
from app.services.notification_service import notify_role
|
||||
from app.services.webhook_service import dispatch_webhook
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign router = APIRouter(prefix="/campaigns", tags=["campaigns"])
|
||||
router = APIRouter(prefix="/campaigns", tags=["campaigns"])
|
||||
|
||||
|
||||
# ── Pydantic schemas ─────────────────────────────────────────────────
|
||||
|
||||
class CampaignCreate(BaseModel):
|
||||
"""Payload for creating a new campaign."""
|
||||
|
||||
# name: str
|
||||
name: str
|
||||
# Assign description = None
|
||||
description: Optional[str] = None
|
||||
# Assign type = "custom"
|
||||
type: str = "custom"
|
||||
# Assign threat_actor_id = None
|
||||
threat_actor_id: Optional[str] = None
|
||||
# Assign target_platform = None
|
||||
target_platform: Optional[str] = None
|
||||
# Assign tags = Field(default_factory=list)
|
||||
tags: Optional[list[str]] = Field(default_factory=list)
|
||||
# Assign scheduled_at = None
|
||||
scheduled_at: Optional[str] = None
|
||||
start_date: Optional[str] = None # ISO date — campaign won't activate before this
|
||||
|
||||
|
||||
# Define class CampaignUpdate
|
||||
class CampaignUpdate(BaseModel):
|
||||
"""Payload for updating an existing campaign's metadata."""
|
||||
|
||||
# Assign name = None
|
||||
name: Optional[str] = None
|
||||
# Assign description = None
|
||||
description: Optional[str] = None
|
||||
# Assign type = None
|
||||
type: Optional[str] = None
|
||||
# Assign target_platform = None
|
||||
target_platform: Optional[str] = None
|
||||
# Assign tags = None
|
||||
tags: Optional[list[str]] = None
|
||||
# Assign scheduled_at = None
|
||||
scheduled_at: Optional[str] = None
|
||||
start_date: Optional[str] = None # ISO date — can be updated while still in draft
|
||||
|
||||
|
||||
# Define class AddTestPayload
|
||||
class AddTestPayload(BaseModel):
|
||||
"""Payload for adding a test to a campaign."""
|
||||
|
||||
# test_id: str
|
||||
test_id: str
|
||||
# Assign order_index = None
|
||||
order_index: Optional[int] = None
|
||||
# Assign depends_on = None
|
||||
depends_on: Optional[str] = None
|
||||
# Assign phase = None
|
||||
phase: Optional[str] = None
|
||||
|
||||
|
||||
# Define class SchedulePayload
|
||||
class SchedulePayload(BaseModel):
|
||||
"""Payload for scheduling or rescheduling a campaign run."""
|
||||
|
||||
# is_recurring: bool
|
||||
is_recurring: bool
|
||||
# Assign recurrence_pattern = None # weekly, monthly, quarterly
|
||||
recurrence_pattern: Optional[str] = None # weekly, monthly, quarterly
|
||||
# Assign next_run_at = None
|
||||
next_run_at: Optional[str] = None
|
||||
|
||||
|
||||
@@ -83,24 +172,54 @@ class SchedulePayload(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("")
|
||||
# Define function list_campaigns
|
||||
def list_campaigns(
|
||||
# Entry: type
|
||||
type: Optional[str] = Query(None),
|
||||
# Entry: status
|
||||
status: Optional[str] = Query(None),
|
||||
# Entry: threat_actor_id
|
||||
threat_actor_id: Optional[str] = Query(None),
|
||||
# Entry: search
|
||||
search: Optional[str] = Query(None),
|
||||
# Entry: offset
|
||||
offset: int = Query(0, ge=0),
|
||||
# Entry: limit
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List campaigns with optional filters and pagination."""
|
||||
) -> list:
|
||||
"""List campaigns with optional filters and pagination.
|
||||
|
||||
Args:
|
||||
type (Optional[str]): Filter by campaign type (e.g. ``custom``, ``threat_actor``).
|
||||
status (Optional[str]): Filter by campaign status (e.g. ``draft``, ``active``).
|
||||
threat_actor_id (Optional[str]): Filter campaigns linked to a specific threat actor.
|
||||
search (Optional[str]): Free-text search against campaign name.
|
||||
offset (int): Number of records to skip for pagination.
|
||||
limit (int): Maximum number of records to return.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
list: Serialised list of campaign summary dicts.
|
||||
"""
|
||||
# Return crud_list(
|
||||
return crud_list(
|
||||
db,
|
||||
# Keyword argument: type
|
||||
type=type,
|
||||
# Keyword argument: status
|
||||
status=status,
|
||||
# Keyword argument: threat_actor_id
|
||||
threat_actor_id=threat_actor_id,
|
||||
# Keyword argument: search
|
||||
search=search,
|
||||
# Keyword argument: offset
|
||||
offset=offset,
|
||||
# Keyword argument: limit
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@@ -110,36 +229,64 @@ def list_campaigns(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("", status_code=201)
|
||||
# Define function create_campaign
|
||||
def create_campaign(
|
||||
# Entry: payload
|
||||
payload: CampaignCreate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Create a new campaign."""
|
||||
) -> dict:
|
||||
"""Create a new campaign.
|
||||
|
||||
Args:
|
||||
payload (CampaignCreate): Fields for the new campaign (name, type, threat actor, etc.).
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or blue_lead creating the campaign.
|
||||
|
||||
Returns:
|
||||
dict: Serialised representation of the newly created campaign.
|
||||
"""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign result = crud_create(
|
||||
result = crud_create(
|
||||
db,
|
||||
# Keyword argument: creator_id
|
||||
creator_id=current_user.id,
|
||||
# Keyword argument: name
|
||||
name=payload.name,
|
||||
# Keyword argument: description
|
||||
description=payload.description,
|
||||
# Keyword argument: type
|
||||
type=payload.type,
|
||||
# Keyword argument: threat_actor_id
|
||||
threat_actor_id=payload.threat_actor_id,
|
||||
# Keyword argument: target_platform
|
||||
target_platform=payload.target_platform,
|
||||
# Keyword argument: tags
|
||||
tags=payload.tags,
|
||||
# Keyword argument: scheduled_at
|
||||
scheduled_at=payload.scheduled_at,
|
||||
start_date=payload.start_date,
|
||||
)
|
||||
campaign_id = result["id"]
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="create_campaign",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="campaign",
|
||||
entity_id=campaign_id,
|
||||
details={"name": payload.name, "type": payload.type},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return result
|
||||
return result
|
||||
|
||||
|
||||
@@ -148,12 +295,26 @@ def create_campaign(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/{campaign_id}")
|
||||
# Define function get_campaign
|
||||
def get_campaign(
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get detailed campaign info including tests and progress."""
|
||||
) -> dict:
|
||||
"""Get detailed campaign info including tests and progress.
|
||||
|
||||
Args:
|
||||
campaign_id (str): UUID string of the campaign to retrieve.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
dict: Campaign detail including associated tests and progress metrics.
|
||||
"""
|
||||
# Return crud_get_detail(db, campaign_id)
|
||||
return crud_get_detail(db, campaign_id)
|
||||
|
||||
|
||||
@@ -162,32 +323,60 @@ def get_campaign(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.patch("/{campaign_id}")
|
||||
# Define function update_campaign
|
||||
def update_campaign(
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
# Entry: payload
|
||||
payload: CampaignUpdate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Update a campaign. Only allowed in draft or active state."""
|
||||
) -> dict:
|
||||
"""Update a campaign. Only allowed in draft or active state.
|
||||
|
||||
Args:
|
||||
campaign_id (str): UUID string of the campaign to update.
|
||||
payload (CampaignUpdate): Partial update payload; only set fields are applied.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or blue_lead performing the update.
|
||||
|
||||
Returns:
|
||||
dict: Serialised representation of the updated campaign.
|
||||
"""
|
||||
# Assign update_data = payload.model_dump(exclude_unset=True)
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign result = crud_update(
|
||||
result = crud_update(
|
||||
db,
|
||||
campaign_id,
|
||||
# Keyword argument: updater_id
|
||||
updater_id=current_user.id,
|
||||
# Keyword argument: updater_role
|
||||
updater_role=current_user.role,
|
||||
**update_data,
|
||||
)
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="update_campaign",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="campaign",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=campaign_id,
|
||||
# Keyword argument: details
|
||||
details={"updated_fields": list(update_data.keys())},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return result
|
||||
return result
|
||||
|
||||
|
||||
@@ -227,22 +416,44 @@ def delete_campaign(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/{campaign_id}/tests")
|
||||
# Define function add_test_to_campaign
|
||||
def add_test_to_campaign(
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
# Entry: payload
|
||||
payload: AddTestPayload,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Add a test to a campaign with optional ordering and dependency."""
|
||||
) -> dict:
|
||||
"""Add a test to a campaign with optional ordering and dependency.
|
||||
|
||||
Args:
|
||||
campaign_id (str): UUID string of the target campaign.
|
||||
payload (AddTestPayload): Test ID plus optional order index, dependency, and phase.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or blue_lead adding the test.
|
||||
|
||||
Returns:
|
||||
dict: The created campaign-test association record.
|
||||
"""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign result = crud_add_test(
|
||||
result = crud_add_test(
|
||||
db,
|
||||
campaign_id,
|
||||
# Keyword argument: test_id
|
||||
test_id=payload.test_id,
|
||||
# Keyword argument: order_index
|
||||
order_index=payload.order_index,
|
||||
# Keyword argument: depends_on
|
||||
depends_on=payload.depends_on,
|
||||
# Keyword argument: phase
|
||||
phase=payload.phase,
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
return result
|
||||
@@ -253,16 +464,35 @@ def add_test_to_campaign(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.delete("/{campaign_id}/tests/{campaign_test_id}")
|
||||
# Define function remove_test_from_campaign
|
||||
def remove_test_from_campaign(
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
# Entry: campaign_test_id
|
||||
campaign_test_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Remove a test from a campaign."""
|
||||
) -> dict:
|
||||
"""Remove a test from a campaign.
|
||||
|
||||
Args:
|
||||
campaign_id (str): UUID string of the campaign.
|
||||
campaign_test_id (str): UUID string of the campaign-test association to remove.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or blue_lead removing the test.
|
||||
|
||||
Returns:
|
||||
dict: Confirmation message with key ``detail``.
|
||||
"""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call crud_remove_test()
|
||||
crud_remove_test(db, campaign_id, campaign_test_id)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Return {"detail": "Test removed from campaign"}
|
||||
return {"detail": "Test removed from campaign"}
|
||||
|
||||
|
||||
@@ -271,10 +501,13 @@ def remove_test_from_campaign(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/{campaign_id}/activate")
|
||||
# Define function activate_campaign
|
||||
def activate_campaign(
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
force: bool = Query(False, description="Activate even if start_date is in the future"),
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Activate a campaign, moving it from draft to active.
|
||||
@@ -303,25 +536,41 @@ def activate_campaign(
|
||||
)
|
||||
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign campaign = crud_activate(db, campaign_id)
|
||||
campaign = crud_activate(db, campaign_id)
|
||||
# Call notify_role()
|
||||
notify_role(
|
||||
db,
|
||||
# Keyword argument: role
|
||||
role="red_tech",
|
||||
# Keyword argument: type
|
||||
type="campaign_activated",
|
||||
# Keyword argument: title
|
||||
title="Campaign activated",
|
||||
# Keyword argument: message
|
||||
message=f'Campaign "{campaign.name}" has been activated.',
|
||||
# Keyword argument: entity_type
|
||||
entity_type="campaign",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=campaign.id,
|
||||
)
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="activate_campaign",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="campaign",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=campaign.id,
|
||||
# Keyword argument: details
|
||||
details={"name": campaign.name},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(campaign)
|
||||
|
||||
# Create Jira tickets for campaign and tests at activation time (non-fatal).
|
||||
@@ -359,26 +608,50 @@ def activate_campaign(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/{campaign_id}/complete")
|
||||
# Define function complete_campaign
|
||||
def complete_campaign(
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "admin")),
|
||||
):
|
||||
"""Mark a campaign as completed."""
|
||||
) -> dict:
|
||||
"""Mark a campaign as completed.
|
||||
|
||||
Args:
|
||||
campaign_id (str): UUID string of the campaign to complete.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or admin completing the campaign.
|
||||
|
||||
Returns:
|
||||
dict: Serialised representation of the completed campaign.
|
||||
"""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign campaign = crud_complete(db, campaign_id)
|
||||
campaign = crud_complete(db, campaign_id)
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="complete_campaign",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="campaign",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=campaign.id,
|
||||
# Keyword argument: details
|
||||
details={"name": campaign.name},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(campaign)
|
||||
dispatch_webhook("campaign.completed", {"campaign_id": str(campaign.id), "name": campaign.name})
|
||||
|
||||
# Return serialize_campaign(db, campaign)
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
@@ -387,12 +660,26 @@ def complete_campaign(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/{campaign_id}/progress")
|
||||
# Define function get_campaign_progress_endpoint
|
||||
def get_campaign_progress_endpoint(
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get progress statistics for a campaign."""
|
||||
) -> dict:
|
||||
"""Get progress statistics for a campaign.
|
||||
|
||||
Args:
|
||||
campaign_id (str): UUID string of the campaign.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
dict: Progress breakdown including counts by test state and overall percentage.
|
||||
"""
|
||||
# Return crud_get_progress(db, campaign_id)
|
||||
return crud_get_progress(db, campaign_id)
|
||||
|
||||
|
||||
@@ -405,16 +692,27 @@ class GenerateFromActorPayload(BaseModel):
|
||||
|
||||
|
||||
@router.post("/from-threat-actor/{actor_id}", status_code=201)
|
||||
# Define function generate_campaign_from_actor
|
||||
def generate_campaign_from_actor(
|
||||
# Entry: actor_id
|
||||
actor_id: str,
|
||||
payload: GenerateFromActorPayload = GenerateFromActorPayload(),
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Auto-generate a campaign from a threat actor's uncovered techniques.
|
||||
|
||||
Creates tests from the best available templates and orders them
|
||||
by kill chain phase.
|
||||
|
||||
Args:
|
||||
actor_id (str): UUID string of the threat actor to generate a campaign for.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or blue_lead requesting the generation.
|
||||
|
||||
Returns:
|
||||
dict: Serialised representation of the newly generated campaign.
|
||||
"""
|
||||
start_date_parsed = (
|
||||
datetime.fromisoformat(payload.start_date) if payload.start_date else None
|
||||
@@ -426,17 +724,26 @@ def generate_campaign_from_actor(
|
||||
start_date=start_date_parsed,
|
||||
)
|
||||
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="generate_campaign",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="campaign",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=campaign.id,
|
||||
# Keyword argument: details
|
||||
details={"actor_id": actor_id, "campaign_name": campaign.name},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return serialize_campaign(db, campaign)
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
@@ -445,41 +752,74 @@ def generate_campaign_from_actor(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.patch("/{campaign_id}/schedule")
|
||||
# Define function schedule_campaign
|
||||
def schedule_campaign(
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
# Entry: payload
|
||||
payload: SchedulePayload,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Configure or update the recurrence schedule for a campaign.
|
||||
|
||||
Only the campaign creator or admin can change scheduling.
|
||||
|
||||
Args:
|
||||
campaign_id (str): UUID string of the campaign to schedule.
|
||||
payload (SchedulePayload): Recurrence flag, pattern, and next run timestamp.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or blue_lead (must be owner or admin).
|
||||
|
||||
Returns:
|
||||
dict: Serialised representation of the campaign with updated schedule fields.
|
||||
"""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign campaign = crud_schedule(
|
||||
campaign = crud_schedule(
|
||||
db,
|
||||
campaign_id,
|
||||
# Keyword argument: owner_id
|
||||
owner_id=current_user.id,
|
||||
# Keyword argument: owner_role
|
||||
owner_role=current_user.role,
|
||||
# Keyword argument: is_recurring
|
||||
is_recurring=payload.is_recurring,
|
||||
# Keyword argument: recurrence_pattern
|
||||
recurrence_pattern=payload.recurrence_pattern,
|
||||
# Keyword argument: next_run_at
|
||||
next_run_at=payload.next_run_at,
|
||||
)
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="schedule_campaign",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="campaign",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=campaign.id,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"is_recurring": campaign.is_recurring,
|
||||
# Literal argument value
|
||||
"recurrence_pattern": campaign.recurrence_pattern,
|
||||
# Literal argument value
|
||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||
},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(campaign)
|
||||
|
||||
# Return serialize_campaign(db, campaign)
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
@@ -488,12 +828,26 @@ def schedule_campaign(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/{campaign_id}/history")
|
||||
# Define function get_campaign_history
|
||||
def get_campaign_history(
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List all child campaigns (execution history) of a recurring campaign."""
|
||||
) -> list:
|
||||
"""List all child campaigns (execution history) of a recurring campaign.
|
||||
|
||||
Args:
|
||||
campaign_id (str): UUID string of the parent recurring campaign.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
list: Serialised list of child campaign dicts ordered by creation date.
|
||||
"""
|
||||
# Return crud_get_history(db, campaign_id)
|
||||
return crud_get_history(db, campaign_id)
|
||||
|
||||
|
||||
|
||||
@@ -1,32 +1,45 @@
|
||||
"""Compliance endpoints — framework status, reports, and gap analysis.
|
||||
|
||||
Thin HTTP adapter: delegates all data logic to compliance_service.
|
||||
|
||||
Thin HTTP adapter that delegates all data logic to compliance_service.
|
||||
Provides compliance posture assessment by mapping MITRE ATT&CK technique
|
||||
coverage to compliance framework controls.
|
||||
"""
|
||||
|
||||
# Import APIRouter, Depends from fastapi
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
# Import StreamingResponse from fastapi.responses
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
from app.services.compliance_service import (
|
||||
list_frameworks,
|
||||
get_framework_status,
|
||||
build_framework_report_csv,
|
||||
get_framework_gaps,
|
||||
)
|
||||
|
||||
# Import from app.services.compliance_import_service
|
||||
from app.services.compliance_import_service import (
|
||||
import_nist_800_53_mappings,
|
||||
import_cis_controls_v8_mappings,
|
||||
import_dora_mappings,
|
||||
import_iso_27001_mappings,
|
||||
import_iso_42001_mappings,
|
||||
)
|
||||
|
||||
# Import from app.services.compliance_service
|
||||
from app.services.compliance_service import (
|
||||
build_framework_report_csv,
|
||||
get_framework_gaps,
|
||||
get_framework_status,
|
||||
list_frameworks,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/compliance", tags=["compliance"])
|
||||
router = APIRouter(prefix="/compliance", tags=["compliance"])
|
||||
|
||||
|
||||
@@ -34,11 +47,23 @@ router = APIRouter(prefix="/compliance", tags=["compliance"])
|
||||
|
||||
|
||||
@router.get("/frameworks")
|
||||
# Define function list_frameworks_endpoint
|
||||
def list_frameworks_endpoint(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List all available compliance frameworks."""
|
||||
) -> list:
|
||||
"""List all available compliance frameworks.
|
||||
|
||||
Args:
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
list: List of framework summary dicts containing id, name, and control counts.
|
||||
"""
|
||||
# Return list_frameworks(db)
|
||||
return list_frameworks(db)
|
||||
|
||||
|
||||
@@ -46,12 +71,26 @@ def list_frameworks_endpoint(
|
||||
|
||||
|
||||
@router.get("/frameworks/{framework_id}/status")
|
||||
# Define function framework_status
|
||||
def framework_status(
|
||||
# Entry: framework_id
|
||||
framework_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get compliance status for each control in a framework."""
|
||||
) -> dict:
|
||||
"""Get compliance status for each control in a framework.
|
||||
|
||||
Args:
|
||||
framework_id (str): Identifier of the compliance framework (e.g. ``nist-800-53``).
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
dict: Mapping of control IDs to their coverage status and linked techniques.
|
||||
"""
|
||||
# Return get_framework_status(db, framework_id)
|
||||
return get_framework_status(db, framework_id)
|
||||
|
||||
|
||||
@@ -59,12 +98,26 @@ def framework_status(
|
||||
|
||||
|
||||
@router.get("/frameworks/{framework_id}/report")
|
||||
# Define function framework_report
|
||||
def framework_report(
|
||||
# Entry: framework_id
|
||||
framework_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get the full compliance report (same as status but marked as report)."""
|
||||
) -> dict:
|
||||
"""Get the full compliance report (same as status but marked as report).
|
||||
|
||||
Args:
|
||||
framework_id (str): Identifier of the compliance framework.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
dict: Full compliance report with per-control coverage details.
|
||||
"""
|
||||
# Return get_framework_status(db, framework_id)
|
||||
return get_framework_status(db, framework_id)
|
||||
|
||||
|
||||
@@ -72,17 +125,35 @@ def framework_report(
|
||||
|
||||
|
||||
@router.get("/frameworks/{framework_id}/report/csv")
|
||||
# Define function framework_report_csv
|
||||
def framework_report_csv(
|
||||
# Entry: framework_id
|
||||
framework_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Export compliance report as CSV."""
|
||||
) -> StreamingResponse:
|
||||
"""Export compliance report as CSV.
|
||||
|
||||
Args:
|
||||
framework_id (str): Identifier of the compliance framework to export.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
StreamingResponse: CSV file attachment with compliance coverage data.
|
||||
"""
|
||||
# csv_bytes, filename = build_framework_report_csv(db, framework_id)
|
||||
csv_bytes, filename = build_framework_report_csv(db, framework_id)
|
||||
# Return StreamingResponse(
|
||||
return StreamingResponse(
|
||||
iter([csv_bytes]),
|
||||
# Keyword argument: media_type
|
||||
media_type="text/csv",
|
||||
# Keyword argument: headers
|
||||
headers={
|
||||
# Literal argument value
|
||||
"Content-Disposition": f"attachment; filename={filename}",
|
||||
},
|
||||
)
|
||||
@@ -92,12 +163,26 @@ def framework_report_csv(
|
||||
|
||||
|
||||
@router.get("/frameworks/{framework_id}/gaps")
|
||||
# Define function framework_gaps
|
||||
def framework_gaps(
|
||||
# Entry: framework_id
|
||||
framework_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get controls with techniques that are not adequately covered."""
|
||||
) -> dict:
|
||||
"""Get controls with techniques that are not adequately covered.
|
||||
|
||||
Args:
|
||||
framework_id (str): Identifier of the compliance framework to analyse.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
dict: Controls flagged as gaps, with linked technique IDs and coverage ratios.
|
||||
"""
|
||||
# Return get_framework_gaps(db, framework_id)
|
||||
return get_framework_gaps(db, framework_id)
|
||||
|
||||
|
||||
@@ -105,22 +190,49 @@ def framework_gaps(
|
||||
|
||||
|
||||
@router.post("/import/nist-800-53")
|
||||
# Define function import_nist
|
||||
def import_nist(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import NIST 800-53 Rev 5 mappings (admin only)."""
|
||||
) -> dict:
|
||||
"""Import NIST 800-53 Rev 5 mappings (admin only).
|
||||
|
||||
Args:
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated admin user.
|
||||
|
||||
Returns:
|
||||
dict: Import result with counts of created and updated control mappings.
|
||||
"""
|
||||
# Assign result = import_nist_800_53_mappings(db)
|
||||
result = import_nist_800_53_mappings(db)
|
||||
# Return result
|
||||
return result
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/import/cis-controls-v8")
|
||||
# Define function import_cis
|
||||
def import_cis(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Import CIS Controls v8 mappings (admin only)."""
|
||||
) -> dict:
|
||||
"""Import CIS Controls v8 mappings (admin only).
|
||||
|
||||
Args:
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated admin user.
|
||||
|
||||
Returns:
|
||||
dict: Import result with counts of created and updated control mappings.
|
||||
"""
|
||||
# Assign result = import_cis_controls_v8_mappings(db)
|
||||
result = import_cis_controls_v8_mappings(db)
|
||||
# Return result
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -1,26 +1,47 @@
|
||||
"""D3FEND endpoints — defensive technique listings, mappings, and import trigger."""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import from app.services.d3fend_import_service
|
||||
from app.services.d3fend_import_service import (
|
||||
import_d3fend_techniques,
|
||||
import_d3fend_mappings,
|
||||
import_d3fend_techniques,
|
||||
)
|
||||
|
||||
# Import from app.services.d3fend_query_service
|
||||
from app.services.d3fend_query_service import (
|
||||
get_defenses_for_attack_technique,
|
||||
list_d3fend_tactics,
|
||||
)
|
||||
|
||||
# Import from app.services.d3fend_query_service
|
||||
from app.services.d3fend_query_service import (
|
||||
list_defensive_techniques as list_defensive_techniques_svc,
|
||||
list_d3fend_tactics,
|
||||
get_defenses_for_attack_technique,
|
||||
)
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign router = APIRouter(prefix="/d3fend", tags=["d3fend"])
|
||||
router = APIRouter(prefix="/d3fend", tags=["d3fend"])
|
||||
|
||||
|
||||
@@ -29,15 +50,23 @@ router = APIRouter(prefix="/d3fend", tags=["d3fend"])
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("")
|
||||
# Define function list_defensive_techniques
|
||||
def list_defensive_techniques(
|
||||
# Entry: tactic
|
||||
tactic: Optional[str] = Query(None),
|
||||
# Entry: search
|
||||
search: Optional[str] = Query(None),
|
||||
# Entry: offset
|
||||
offset: int = Query(0, ge=0),
|
||||
# Entry: limit
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""List all D3FEND defensive techniques with optional filters."""
|
||||
# Return list_defensive_techniques_svc(
|
||||
return list_defensive_techniques_svc(
|
||||
db, tactic=tactic, search=search, offset=offset, limit=limit
|
||||
)
|
||||
@@ -48,11 +77,15 @@ def list_defensive_techniques(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/tactics")
|
||||
# Define function list_d3fend_tactics_endpoint
|
||||
def list_d3fend_tactics_endpoint(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Return a list of all D3FEND tactics with counts."""
|
||||
# Return list_d3fend_tactics(db)
|
||||
return list_d3fend_tactics(db)
|
||||
|
||||
|
||||
@@ -61,12 +94,17 @@ def list_d3fend_tactics_endpoint(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/for-technique/{mitre_id}")
|
||||
# Define function get_defenses_for_attack_technique_endpoint
|
||||
def get_defenses_for_attack_technique_endpoint(
|
||||
# Entry: mitre_id
|
||||
mitre_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
||||
# Return get_defenses_for_attack_technique(db, mitre_id)
|
||||
return get_defenses_for_attack_technique(db, mitre_id)
|
||||
|
||||
|
||||
@@ -75,15 +113,23 @@ def get_defenses_for_attack_technique_endpoint(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/import")
|
||||
# Define function trigger_d3fend_import
|
||||
def trigger_d3fend_import(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Import D3FEND techniques and ATT&CK mappings. Admin only."""
|
||||
# Assign tech_result = import_d3fend_techniques(db)
|
||||
tech_result = import_d3fend_techniques(db)
|
||||
# Assign mapping_result = import_d3fend_mappings(db)
|
||||
mapping_result = import_d3fend_mappings(db)
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"techniques": tech_result,
|
||||
# Literal argument value
|
||||
"mappings": mapping_result,
|
||||
}
|
||||
|
||||
@@ -5,16 +5,34 @@ Provides a centralized panel for managing all external data sources
|
||||
including sync triggers, enable/disable toggles, and statistics.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends from fastapi
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
# Import BaseModel from pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import require_role
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import from app.services.data_source_service
|
||||
from app.services.data_source_service import (
|
||||
get_source_stats,
|
||||
list_sources,
|
||||
@@ -23,18 +41,21 @@ from app.services.data_source_service import (
|
||||
update_source,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic schemas for request validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DataSourceUpdate(BaseModel):
|
||||
"""Payload for updating a data source — only allowed fields."""
|
||||
# Assign is_enabled = None
|
||||
is_enabled: Optional[bool] = None
|
||||
# Assign sync_frequency = None
|
||||
sync_frequency: Optional[str] = None
|
||||
# Assign config = None
|
||||
config: Optional[dict] = None
|
||||
|
||||
|
||||
# Assign router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
||||
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
||||
|
||||
|
||||
@@ -44,90 +65,137 @@ router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
# Define function list_data_sources
|
||||
def list_data_sources(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> list:
|
||||
"""List all registered data sources.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
# Return list_sources(db)
|
||||
return list_sources(db)
|
||||
|
||||
|
||||
# Apply the @router.patch decorator
|
||||
@router.patch("/{source_id}")
|
||||
# Define function update_data_source
|
||||
def update_data_source(
|
||||
# Entry: source_id
|
||||
source_id: str,
|
||||
# Entry: body
|
||||
body: DataSourceUpdate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Update a data source (enable/disable, change config).
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
# Assign update_data = body.model_dump(exclude_unset=True)
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call update_source()
|
||||
update_source(db, source_id, **update_data)
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="update_data_source",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="data_source",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=source_id,
|
||||
# Keyword argument: details
|
||||
details={"updates": update_data},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return {"message": "Data source updated", "id": source_id}
|
||||
return {"message": "Data source updated", "id": source_id}
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/{source_id}/sync")
|
||||
# Define function sync_data_source
|
||||
def sync_data_source(
|
||||
# Entry: source_id
|
||||
source_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Trigger sync/import for a specific data source.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
# Return sync_source(db, source_id)
|
||||
return sync_source(db, source_id)
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/sync-all")
|
||||
# Define function sync_all_data_sources
|
||||
def sync_all_data_sources(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Trigger sync for all enabled data sources (sequentially).
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
# Assign results = sync_all_sources(db)
|
||||
results = sync_all_sources(db)
|
||||
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="sync_all_data_sources",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="data_source",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=None,
|
||||
# Keyword argument: details
|
||||
details={"results": results},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return {"message": "Sync all complete", "results": results}
|
||||
return {"message": "Sync all complete", "results": results}
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/{source_id}/stats")
|
||||
# Define function get_data_source_stats
|
||||
def get_data_source_stats(
|
||||
# Entry: source_id
|
||||
source_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Get detailed statistics for a specific data source.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
# Return get_source_stats(db, source_id)
|
||||
return get_source_stats(db, source_id)
|
||||
|
||||
@@ -6,36 +6,55 @@ Provides endpoints for browsing detection rules, querying rules by technique,
|
||||
and managing the template ↔ detection rule associations.
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import BaseModel from pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_role, require_any_role
|
||||
from app.models.user import User
|
||||
from app.services.detection_rule_service import (
|
||||
list_rules,
|
||||
get_rules_for_template,
|
||||
auto_associate_rules,
|
||||
get_rules_for_test,
|
||||
evaluate_rule,
|
||||
)
|
||||
|
||||
# Import get_current_user, require_any_role, require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import from app.services.detection_rule_service
|
||||
from app.services.detection_rule_service import (
|
||||
auto_associate_rules,
|
||||
evaluate_rule,
|
||||
get_rules_for_template,
|
||||
get_rules_for_test,
|
||||
list_rules,
|
||||
)
|
||||
|
||||
# ── Pydantic schemas for request validation ────────────────────────────
|
||||
|
||||
|
||||
class DetectionRuleEvaluate(BaseModel):
|
||||
"""Payload for evaluating a detection rule against a test."""
|
||||
# test_id: uuid.UUID
|
||||
test_id: uuid.UUID
|
||||
# detection_rule_id: uuid.UUID
|
||||
detection_rule_id: uuid.UUID
|
||||
# Assign triggered = None
|
||||
triggered: Optional[bool] = None
|
||||
# Assign notes = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
# Assign router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
|
||||
router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
|
||||
|
||||
|
||||
@@ -43,24 +62,40 @@ router = APIRouter(prefix="/detection-rules", tags=["detection-rules"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
# Define function list_detection_rules
|
||||
def list_detection_rules(
|
||||
# Entry: technique
|
||||
technique: Optional[str] = Query(None, description="Filter by MITRE technique ID"),
|
||||
# Entry: source
|
||||
source: Optional[str] = Query(None, description="Filter by source (sigma, elastic, splunk, custom)"),
|
||||
# Entry: severity
|
||||
severity: Optional[str] = Query(None),
|
||||
# Entry: search
|
||||
search: Optional[str] = Query(None),
|
||||
# Entry: offset
|
||||
offset: int = Query(0, ge=0),
|
||||
# Entry: limit
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""List detection rules with optional filters and pagination."""
|
||||
# Return list_rules(
|
||||
return list_rules(
|
||||
db,
|
||||
# Keyword argument: technique
|
||||
technique=technique,
|
||||
# Keyword argument: source
|
||||
source=source,
|
||||
# Keyword argument: severity
|
||||
severity=severity,
|
||||
# Keyword argument: search
|
||||
search=search,
|
||||
# Keyword argument: offset
|
||||
offset=offset,
|
||||
# Keyword argument: limit
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@@ -69,12 +104,17 @@ def list_detection_rules(
|
||||
|
||||
|
||||
@router.get("/for-template/{template_id}")
|
||||
# Define function get_detection_rules_for_template
|
||||
def get_detection_rules_for_template(
|
||||
# Entry: template_id
|
||||
template_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Get detection rules associated with a test template."""
|
||||
# Return get_rules_for_template(db, template_id)
|
||||
return get_rules_for_template(db, template_id)
|
||||
|
||||
|
||||
@@ -82,16 +122,20 @@ def get_detection_rules_for_template(
|
||||
|
||||
|
||||
@router.post("/auto-associate")
|
||||
# Define function auto_associate_detection_rules
|
||||
def auto_associate_detection_rules(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Auto-associate test templates with detection rules by MITRE technique ID.
|
||||
|
||||
For each active template, find all active detection rules for the same
|
||||
technique and create associations. Rules with severity >= high are marked
|
||||
as primary.
|
||||
"""
|
||||
# Return auto_associate_rules(db)
|
||||
return auto_associate_rules(db)
|
||||
|
||||
|
||||
@@ -99,16 +143,21 @@ def auto_associate_detection_rules(
|
||||
|
||||
|
||||
@router.get("/for-test/{test_id}")
|
||||
# Define function get_detection_rules_for_test
|
||||
def get_detection_rules_for_test(
|
||||
# Entry: test_id
|
||||
test_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Get detection rules relevant to a test, along with their evaluation results.
|
||||
|
||||
Finds rules by matching the test's technique_id to detection rules,
|
||||
and returns any existing evaluation results.
|
||||
"""
|
||||
# Return get_rules_for_test(db, test_id)
|
||||
return get_rules_for_test(db, test_id)
|
||||
|
||||
|
||||
@@ -116,17 +165,27 @@ def get_detection_rules_for_test(
|
||||
|
||||
|
||||
@router.post("/evaluate")
|
||||
# Define function evaluate_detection_rule
|
||||
def evaluate_detection_rule(
|
||||
# Entry: payload
|
||||
payload: DetectionRuleEvaluate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Save or update the evaluation result for a detection rule on a test."""
|
||||
# Return evaluate_rule(
|
||||
return evaluate_rule(
|
||||
db,
|
||||
# Keyword argument: test_id
|
||||
test_id=payload.test_id,
|
||||
# Keyword argument: detection_rule_id
|
||||
detection_rule_id=payload.detection_rule_id,
|
||||
# Keyword argument: triggered
|
||||
triggered=payload.triggered,
|
||||
# Keyword argument: notes
|
||||
notes=payload.notes,
|
||||
# Keyword argument: evaluator_id
|
||||
evaluator_id=current_user.id,
|
||||
)
|
||||
|
||||
@@ -20,30 +20,54 @@ Access Control
|
||||
``validated``, or ``rejected``.
|
||||
"""
|
||||
|
||||
# Import hashlib
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
# Import uuid
|
||||
import uuid as _uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, File, Form, Query, Request,... from fastapi
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import get_current_user from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import limiter from app.limiter
|
||||
from app.limiter import limiter
|
||||
|
||||
# Import TeamSide from app.models.enums
|
||||
from app.models.enums import TeamSide
|
||||
|
||||
# Import Evidence from app.models.evidence
|
||||
from app.models.evidence import Evidence
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import EvidenceOut from app.schemas.evidence
|
||||
from app.schemas.evidence import EvidenceOut
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import from app.services.evidence_service
|
||||
from app.services.evidence_service import (
|
||||
MAX_UPLOAD_SIZE,
|
||||
get_evidence_or_raise,
|
||||
get_test_or_raise,
|
||||
list_evidence_for_test,
|
||||
MAX_UPLOAD_SIZE,
|
||||
validate_delete_permission,
|
||||
validate_file,
|
||||
validate_upload_permission,
|
||||
@@ -53,6 +77,7 @@ from app.storage import download_file, upload_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign router = APIRouter(tags=["evidence"])
|
||||
router = APIRouter(tags=["evidence"])
|
||||
|
||||
|
||||
@@ -67,13 +92,21 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
||||
never needs direct access to MinIO.
|
||||
"""
|
||||
return EvidenceOut(
|
||||
# Keyword argument: id
|
||||
id=evidence.id,
|
||||
# Keyword argument: test_id
|
||||
test_id=evidence.test_id,
|
||||
# Keyword argument: file_name
|
||||
file_name=evidence.file_name,
|
||||
# Keyword argument: sha256_hash
|
||||
sha256_hash=evidence.sha256_hash,
|
||||
# Keyword argument: uploaded_by
|
||||
uploaded_by=evidence.uploaded_by,
|
||||
# Keyword argument: uploaded_at
|
||||
uploaded_at=evidence.uploaded_at,
|
||||
# Keyword argument: team
|
||||
team=evidence.team,
|
||||
# Keyword argument: notes
|
||||
notes=evidence.notes,
|
||||
download_url=f"/api/v1/evidence/{evidence.id}/file",
|
||||
)
|
||||
@@ -85,30 +118,47 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
||||
|
||||
|
||||
@router.post(
|
||||
# Literal argument value
|
||||
"/tests/{test_id}/evidence",
|
||||
# Keyword argument: response_model
|
||||
response_model=EvidenceOut,
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
# Apply the @limiter.limit decorator
|
||||
@limiter.limit("10/minute")
|
||||
# Define async function upload_evidence
|
||||
async def upload_evidence(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: test_id
|
||||
test_id: _uuid.UUID,
|
||||
# Entry: file
|
||||
file: UploadFile = File(...),
|
||||
# Entry: team
|
||||
team: TeamSide = Form(TeamSide.red),
|
||||
# Entry: notes
|
||||
notes: Optional[str] = Form(None),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> EvidenceOut:
|
||||
"""Upload a file as evidence for the given test.
|
||||
|
||||
The ``team`` field (sent as form data) determines whether this is
|
||||
Red Team (attack) or Blue Team (detection) evidence.
|
||||
"""
|
||||
# Assign test = get_test_or_raise(db, test_id)
|
||||
test = get_test_or_raise(db, test_id)
|
||||
# Call validate_upload_permission()
|
||||
validate_upload_permission(test, team, current_user.role)
|
||||
|
||||
# Assign file_name = file.filename or "unnamed"
|
||||
file_name = file.filename or "unnamed"
|
||||
# Assign content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||
# Call validate_file()
|
||||
validate_file(file_name, len(content))
|
||||
|
||||
# Hash
|
||||
@@ -116,6 +166,7 @@ async def upload_evidence(
|
||||
|
||||
# 4. Object key (sanitise filename to prevent path traversal in storage)
|
||||
safe_name = os.path.basename(file_name)
|
||||
# Assign key = f"{test_id}/{_uuid.uuid4()}_{safe_name}"
|
||||
key = f"{test_id}/{_uuid.uuid4()}_{safe_name}"
|
||||
|
||||
# 5. Upload to MinIO
|
||||
@@ -123,32 +174,53 @@ async def upload_evidence(
|
||||
|
||||
# 6. Persist metadata and audit
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign evidence = Evidence(
|
||||
evidence = Evidence(
|
||||
# Keyword argument: test_id
|
||||
test_id=test_id,
|
||||
# Keyword argument: file_name
|
||||
file_name=safe_name,
|
||||
# Keyword argument: file_path
|
||||
file_path=key,
|
||||
# Keyword argument: sha256_hash
|
||||
sha256_hash=sha256,
|
||||
# Keyword argument: uploaded_by
|
||||
uploaded_by=current_user.id,
|
||||
uploaded_at=datetime.utcnow(), # set explicitly — DB column has no server default
|
||||
team=team,
|
||||
# Keyword argument: notes
|
||||
notes=notes,
|
||||
)
|
||||
# Stage new record(s) for database insertion
|
||||
db.add(evidence)
|
||||
# Flush changes to DB without committing the transaction
|
||||
db.flush() # Get evidence.id for audit
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="upload_evidence",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="evidence",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=evidence.id,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"file_name": safe_name,
|
||||
# Literal argument value
|
||||
"sha256": sha256,
|
||||
# Literal argument value
|
||||
"test_id": str(test_id),
|
||||
# Literal argument value
|
||||
"team": team.value,
|
||||
},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(evidence)
|
||||
|
||||
# 7. Attach to Jira ticket if one exists (non-fatal)
|
||||
@@ -194,15 +266,23 @@ def _attach_evidence_to_jira(
|
||||
|
||||
|
||||
@router.get("/tests/{test_id}/evidence", response_model=list[EvidenceOut])
|
||||
# Define function list_evidence
|
||||
def list_evidence(
|
||||
# Entry: test_id
|
||||
test_id: _uuid.UUID,
|
||||
# Entry: team
|
||||
team: Optional[str] = Query(None, description="Filter by team: red or blue"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list[EvidenceOut]:
|
||||
"""List all evidences for a test, optionally filtered by team."""
|
||||
# Call get_test_or_raise()
|
||||
get_test_or_raise(db, test_id)
|
||||
# Assign evidences = list_evidence_for_test(db, test_id, team=team)
|
||||
evidences = list_evidence_for_test(db, test_id, team=team)
|
||||
# Return [_evidence_to_out(e) for e in evidences]
|
||||
return [_evidence_to_out(e) for e in evidences]
|
||||
|
||||
|
||||
@@ -212,13 +292,18 @@ def list_evidence(
|
||||
|
||||
|
||||
@router.get("/evidence/{evidence_id}", response_model=EvidenceOut)
|
||||
# Define function get_evidence
|
||||
def get_evidence(
|
||||
# Entry: evidence_id
|
||||
evidence_id: _uuid.UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return evidence metadata. ``download_url`` is a backend proxy URL."""
|
||||
evidence = get_evidence_or_raise(db, evidence_id)
|
||||
# Return _evidence_to_out(evidence)
|
||||
return _evidence_to_out(evidence)
|
||||
|
||||
|
||||
@@ -265,11 +350,15 @@ def download_evidence_file(
|
||||
|
||||
|
||||
@router.delete("/evidence/{evidence_id}", status_code=status.HTTP_200_OK)
|
||||
# Define function delete_evidence
|
||||
def delete_evidence(
|
||||
# Entry: evidence_id
|
||||
evidence_id: _uuid.UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Delete an evidence record.
|
||||
|
||||
Only allowed in editable states:
|
||||
@@ -277,24 +366,40 @@ def delete_evidence(
|
||||
- Blue evidence: ``blue_evaluating``
|
||||
- No deletions in ``in_review``, ``validated``, ``rejected``
|
||||
"""
|
||||
# Assign evidence = get_evidence_or_raise(db, evidence_id)
|
||||
evidence = get_evidence_or_raise(db, evidence_id)
|
||||
# Assign test = get_test_or_raise(db, evidence.test_id)
|
||||
test = get_test_or_raise(db, evidence.test_id)
|
||||
# Call validate_delete_permission()
|
||||
validate_delete_permission(test, evidence, current_user.role, current_user.id)
|
||||
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="delete_evidence",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="evidence",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=evidence.id,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"file_name": evidence.file_name,
|
||||
# Literal argument value
|
||||
"test_id": str(evidence.test_id),
|
||||
# Literal argument value
|
||||
"team": evidence.team.value if evidence.team else None,
|
||||
},
|
||||
)
|
||||
# Mark record for deletion on next commit
|
||||
db.delete(evidence)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return {"detail": "Evidence deleted"}
|
||||
return {"detail": "Evidence deleted"}
|
||||
|
||||
@@ -5,101 +5,169 @@ No business logic lives here — only request validation and response
|
||||
formatting.
|
||||
"""
|
||||
|
||||
# Import io
|
||||
import io
|
||||
|
||||
# Import json
|
||||
import json
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import StreamingResponse from fastapi.responses
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import heatmap_service from app.services
|
||||
from app.services import heatmap_service
|
||||
|
||||
# Assign router = APIRouter(prefix="/heatmap", tags=["heatmap"])
|
||||
router = APIRouter(prefix="/heatmap", tags=["heatmap"])
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/coverage")
|
||||
# Define function heatmap_coverage
|
||||
def heatmap_coverage(
|
||||
# Entry: platforms
|
||||
platforms: Optional[str] = Query(None, description="Comma-separated platforms"),
|
||||
# Entry: tactics
|
||||
tactics: Optional[str] = Query(None, description="Comma-separated tactics"),
|
||||
# Entry: min_score
|
||||
min_score: int = Query(0, ge=0, le=100),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Coverage layer — score based on status_global of each technique."""
|
||||
# Return heatmap_service.build_coverage_layer(
|
||||
return heatmap_service.build_coverage_layer(
|
||||
db, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/threat-actor/{actor_id}")
|
||||
# Define function heatmap_threat_actor
|
||||
def heatmap_threat_actor(
|
||||
# Entry: actor_id
|
||||
actor_id: str,
|
||||
# Entry: platforms
|
||||
platforms: Optional[str] = Query(None),
|
||||
# Entry: tactics
|
||||
tactics: Optional[str] = Query(None),
|
||||
# Entry: min_score
|
||||
min_score: int = Query(0, ge=0, le=100),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Threat actor layer — techniques used by an actor with coverage color."""
|
||||
# Return heatmap_service.build_threat_actor_layer(
|
||||
return heatmap_service.build_threat_actor_layer(
|
||||
db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/detection-rules")
|
||||
# Define function heatmap_detection_rules
|
||||
def heatmap_detection_rules(
|
||||
# Entry: platforms
|
||||
platforms: Optional[str] = Query(None),
|
||||
# Entry: tactics
|
||||
tactics: Optional[str] = Query(None),
|
||||
# Entry: min_score
|
||||
min_score: int = Query(0, ge=0, le=100),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Detection rules layer — score based on ratio of rules available vs total."""
|
||||
# Return heatmap_service.build_detection_rules_layer(
|
||||
return heatmap_service.build_detection_rules_layer(
|
||||
db, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/campaign/{campaign_id}")
|
||||
# Define function heatmap_campaign
|
||||
def heatmap_campaign(
|
||||
# Entry: campaign_id
|
||||
campaign_id: str,
|
||||
# Entry: platforms
|
||||
platforms: Optional[str] = Query(None),
|
||||
# Entry: tactics
|
||||
tactics: Optional[str] = Query(None),
|
||||
# Entry: min_score
|
||||
min_score: int = Query(0, ge=0, le=100),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Campaign layer — only techniques in the campaign, colored by test state."""
|
||||
# Return heatmap_service.build_campaign_layer(
|
||||
return heatmap_service.build_campaign_layer(
|
||||
db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score,
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/export-navigator")
|
||||
# Define function export_navigator
|
||||
def export_navigator(
|
||||
# Entry: layer
|
||||
layer: str = Query(..., description="Layer type: coverage, threat-actor, detection-rules, campaign"),
|
||||
# Entry: layer_id
|
||||
layer_id: Optional[str] = Query(None, description="Actor ID or Campaign ID (if applicable)"),
|
||||
# Entry: platforms
|
||||
platforms: Optional[str] = Query(None),
|
||||
# Entry: tactics
|
||||
tactics: Optional[str] = Query(None),
|
||||
# Entry: min_score
|
||||
min_score: int = Query(0, ge=0, le=100),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> StreamingResponse:
|
||||
"""Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator."""
|
||||
# Assign data = heatmap_service.build_navigator_export(
|
||||
data = heatmap_service.build_navigator_export(
|
||||
db, layer, layer_id=layer_id,
|
||||
# Keyword argument: platforms
|
||||
platforms=platforms, tactics=tactics, min_score=min_score,
|
||||
)
|
||||
|
||||
# Assign json_content = json.dumps(data, indent=2, default=str)
|
||||
json_content = json.dumps(data, indent=2, default=str)
|
||||
# Assign buffer = io.BytesIO(json_content.encode("utf-8"))
|
||||
buffer = io.BytesIO(json_content.encode("utf-8"))
|
||||
|
||||
# Return StreamingResponse(
|
||||
return StreamingResponse(
|
||||
buffer,
|
||||
# Keyword argument: media_type
|
||||
media_type="application/json",
|
||||
# Keyword argument: headers
|
||||
headers={"Content-Disposition": f"attachment; filename=aegis_{layer}_layer.json"},
|
||||
)
|
||||
|
||||
+103
-6
@@ -1,138 +1,235 @@
|
||||
"""Jira integration router — link, search, sync, create issues."""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import UUID from uuid
|
||||
from uuid import UUID
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import JiraLinkEntityType from app.models.jira_link
|
||||
from app.models.jira_link import JiraLinkEntityType
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import from app.schemas.jira_schema
|
||||
from app.schemas.jira_schema import (
|
||||
JiraIssueResult,
|
||||
JiraLinkCreate,
|
||||
JiraLinkOut,
|
||||
)
|
||||
from app.services import jira_service, audit_service
|
||||
|
||||
# Import audit_service, jira_service from app.services
|
||||
from app.services import audit_service, jira_service
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign router = APIRouter(prefix="/jira", tags=["jira"])
|
||||
router = APIRouter(prefix="/jira", tags=["jira"])
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/search", response_model=list[JiraIssueResult])
|
||||
# Define function search_issues
|
||||
def search_issues(
|
||||
# Entry: q
|
||||
q: str = Query(..., min_length=2),
|
||||
# Entry: max_results
|
||||
max_results: int = Query(10, le=50),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list[JiraIssueResult]:
|
||||
"""Search Jira issues by JQL or free text."""
|
||||
# Return jira_service.search_jira_issues(q, max_results)
|
||||
return jira_service.search_jira_issues(q, max_results)
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/links", response_model=JiraLinkOut, status_code=201)
|
||||
# Define function create_link
|
||||
def create_link(
|
||||
# Entry: body
|
||||
body: JiraLinkCreate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> JiraLinkOut:
|
||||
"""Associate an Aegis entity with a Jira issue."""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign link = jira_service.create_link(
|
||||
link = jira_service.create_link(
|
||||
db,
|
||||
# Keyword argument: entity_type
|
||||
entity_type=body.entity_type,
|
||||
# Keyword argument: entity_id
|
||||
entity_id=body.entity_id,
|
||||
# Keyword argument: jira_issue_key
|
||||
jira_issue_key=body.jira_issue_key,
|
||||
# Keyword argument: sync_direction
|
||||
sync_direction=body.sync_direction,
|
||||
# Keyword argument: created_by
|
||||
created_by=user.id,
|
||||
)
|
||||
# Call audit_service.log_action()
|
||||
audit_service.log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: action
|
||||
action="JIRA_LINK_CREATED",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="jira_link",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=str(link.id),
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"linked_entity_type": body.entity_type.value,
|
||||
# Literal argument value
|
||||
"linked_entity_id": str(body.entity_id),
|
||||
# Literal argument value
|
||||
"jira_issue_key": body.jira_issue_key,
|
||||
},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(link)
|
||||
|
||||
# Return link
|
||||
return link
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/links", response_model=list[JiraLinkOut])
|
||||
# Define function list_links
|
||||
def list_links(
|
||||
# Entry: entity_type
|
||||
entity_type: Optional[JiraLinkEntityType] = None,
|
||||
# Entry: entity_id
|
||||
entity_id: Optional[UUID] = None,
|
||||
entity_ids: Optional[list[UUID]] = Query(default=None, description="Filter by multiple entity IDs"),
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List Jira links, optionally filtered by entity or a list of entity IDs."""
|
||||
return jira_service.list_links(
|
||||
db,
|
||||
# Keyword argument: entity_type
|
||||
entity_type=entity_type,
|
||||
# Keyword argument: entity_id
|
||||
entity_id=entity_id,
|
||||
entity_ids=entity_ids,
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/links/{link_id}/sync")
|
||||
# Define function sync_link
|
||||
def sync_link(
|
||||
# Entry: link_id
|
||||
link_id: UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Force bidirectional sync for a specific Jira link."""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign link = jira_service.get_link_or_raise(db, link_id)
|
||||
link = jira_service.get_link_or_raise(db, link_id)
|
||||
# Call jira_service.sync_jira_to_aegis()
|
||||
jira_service.sync_jira_to_aegis(db, link)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Return {"message": "Sync completed", "jira_status": link.jira_status}
|
||||
return {"message": "Sync completed", "jira_status": link.jira_status}
|
||||
|
||||
|
||||
# Apply the @router.delete decorator
|
||||
@router.delete("/links/{link_id}", status_code=204)
|
||||
# Define function delete_link
|
||||
def delete_link(
|
||||
# Entry: link_id
|
||||
link_id: UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> None:
|
||||
"""Remove a Jira link."""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign link = jira_service.delete_link(db, link_id)
|
||||
link = jira_service.delete_link(db, link_id)
|
||||
# Call audit_service.log_action()
|
||||
audit_service.log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: action
|
||||
action="jira_link_deleted",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="jira_link",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=str(link_id),
|
||||
# Keyword argument: details
|
||||
details={"jira_issue_key": link.jira_issue_key},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/create-issue")
|
||||
# Define function create_issue_from_entity
|
||||
def create_issue_from_entity(
|
||||
# Entry: entity_type
|
||||
entity_type: JiraLinkEntityType,
|
||||
# Entry: entity_id
|
||||
entity_id: UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Auto-create a Jira issue from an Aegis entity and link them."""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign result = jira_service.create_issue_and_link(
|
||||
result = jira_service.create_issue_and_link(
|
||||
db,
|
||||
# Keyword argument: entity_type
|
||||
entity_type=entity_type,
|
||||
# Keyword argument: entity_id
|
||||
entity_id=entity_id,
|
||||
# Keyword argument: created_by
|
||||
created_by=user.id,
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Return result
|
||||
return result
|
||||
|
||||
@@ -7,12 +7,22 @@ validation-rate endpoints for the Red/Blue workflow.
|
||||
Thin HTTP adapter: delegates all data logic to metrics_query_service.
|
||||
"""
|
||||
|
||||
# Import APIRouter, Depends from fastapi
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import from app.schemas.metrics
|
||||
from app.schemas.metrics import (
|
||||
CoverageSummary,
|
||||
RecentTestItem,
|
||||
@@ -21,6 +31,8 @@ from app.schemas.metrics import (
|
||||
TestPipelineCounts,
|
||||
ValidationRate,
|
||||
)
|
||||
|
||||
# Import from app.services.metrics_query_service
|
||||
from app.services.metrics_query_service import (
|
||||
get_coverage_by_tactic,
|
||||
get_coverage_summary,
|
||||
@@ -30,6 +42,7 @@ from app.services.metrics_query_service import (
|
||||
get_validation_rate,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/metrics", tags=["metrics"])
|
||||
router = APIRouter(prefix="/metrics", tags=["metrics"])
|
||||
|
||||
|
||||
@@ -39,11 +52,15 @@ router = APIRouter(prefix="/metrics", tags=["metrics"])
|
||||
|
||||
|
||||
@router.get("/summary", response_model=CoverageSummary)
|
||||
# Define function coverage_summary
|
||||
def coverage_summary(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> CoverageSummary:
|
||||
"""Return a global coverage summary across all techniques."""
|
||||
# Return get_coverage_summary(db)
|
||||
return get_coverage_summary(db)
|
||||
|
||||
|
||||
@@ -53,11 +70,15 @@ def coverage_summary(
|
||||
|
||||
|
||||
@router.get("/by-tactic", response_model=list[TacticCoverage])
|
||||
# Define function coverage_by_tactic
|
||||
def coverage_by_tactic(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list[TacticCoverage]:
|
||||
"""Return coverage breakdown grouped by tactic."""
|
||||
# Return get_coverage_by_tactic(db)
|
||||
return get_coverage_by_tactic(db)
|
||||
|
||||
|
||||
@@ -67,11 +88,15 @@ def coverage_by_tactic(
|
||||
|
||||
|
||||
@router.get("/test-pipeline", response_model=TestPipelineCounts)
|
||||
# Define function test_pipeline
|
||||
def test_pipeline(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> TestPipelineCounts:
|
||||
"""Return how many tests are in each pipeline state."""
|
||||
# Return get_test_pipeline_counts(db)
|
||||
return get_test_pipeline_counts(db)
|
||||
|
||||
|
||||
@@ -81,11 +106,15 @@ def test_pipeline(
|
||||
|
||||
|
||||
@router.get("/team-activity", response_model=list[TeamActivity])
|
||||
# Define function team_activity
|
||||
def team_activity(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list[TeamActivity]:
|
||||
"""Return activity summary for Red and Blue teams."""
|
||||
# Return get_team_activity(db)
|
||||
return get_team_activity(db)
|
||||
|
||||
|
||||
@@ -95,11 +124,15 @@ def team_activity(
|
||||
|
||||
|
||||
@router.get("/validation-rate", response_model=list[ValidationRate])
|
||||
# Define function validation_rate
|
||||
def validation_rate(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list[ValidationRate]:
|
||||
"""Return approval and rejection rates for Red Lead and Blue Lead."""
|
||||
# Return get_validation_rate(db)
|
||||
return get_validation_rate(db)
|
||||
|
||||
|
||||
@@ -109,9 +142,13 @@ def validation_rate(
|
||||
|
||||
|
||||
@router.get("/recent-tests", response_model=list[RecentTestItem])
|
||||
# Define function recent_tests
|
||||
def recent_tests(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list[RecentTestItem]:
|
||||
"""Return the 10 most recently created tests."""
|
||||
# Return get_recent_tests(db, limit=10)
|
||||
return get_recent_tests(db, limit=10)
|
||||
|
||||
@@ -8,23 +8,39 @@ PATCH /notifications/{id}/read — mark one notification as read
|
||||
POST /notifications/read-all — mark all as read
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import NotificationOut, UnreadCountOut from app.schemas.notification
|
||||
from app.schemas.notification import NotificationOut, UnreadCountOut
|
||||
|
||||
# Import from app.services.notification_service
|
||||
from app.services.notification_service import (
|
||||
list_notifications,
|
||||
mark_as_read,
|
||||
mark_all_as_read,
|
||||
get_unread_count,
|
||||
list_notifications,
|
||||
mark_all_as_read,
|
||||
mark_as_read,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||
router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||
|
||||
|
||||
@@ -34,13 +50,19 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[NotificationOut])
|
||||
# Define function list_notifications_endpoint
|
||||
def list_notifications_endpoint(
|
||||
# Entry: offset
|
||||
offset: int = Query(0, ge=0),
|
||||
# Entry: limit
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list[NotificationOut]:
|
||||
"""Return paginated notifications for the current user, newest first."""
|
||||
# Return list_notifications(db, current_user.id, offset=offset, limit=limit)
|
||||
return list_notifications(db, current_user.id, offset=offset, limit=limit)
|
||||
|
||||
|
||||
@@ -50,12 +72,17 @@ def list_notifications_endpoint(
|
||||
|
||||
|
||||
@router.get("/unread-count", response_model=UnreadCountOut)
|
||||
# Define function unread_count
|
||||
def unread_count(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> UnreadCountOut:
|
||||
"""Return the number of unread notifications for the current user."""
|
||||
# Assign count = get_unread_count(db, current_user.id)
|
||||
count = get_unread_count(db, current_user.id)
|
||||
# Return UnreadCountOut(unread_count=count)
|
||||
return UnreadCountOut(unread_count=count)
|
||||
|
||||
|
||||
@@ -65,15 +92,23 @@ def unread_count(
|
||||
|
||||
|
||||
@router.patch("/{notification_id}/read", response_model=NotificationOut)
|
||||
# Define function read_notification
|
||||
def read_notification(
|
||||
# Entry: notification_id
|
||||
notification_id: uuid.UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> NotificationOut:
|
||||
"""Mark a single notification as read."""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign notif = mark_as_read(db, notification_id, current_user.id)
|
||||
notif = mark_as_read(db, notification_id, current_user.id)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Return notif
|
||||
return notif
|
||||
|
||||
|
||||
@@ -83,12 +118,19 @@ def read_notification(
|
||||
|
||||
|
||||
@router.post("/read-all")
|
||||
# Define function read_all_notifications
|
||||
def read_all_notifications(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Mark all notifications for the current user as read."""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign count = mark_all_as_read(db, current_user.id)
|
||||
count = mark_all_as_read(db, current_user.id)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Return {"detail": f"Marked {count} notifications as read"}
|
||||
return {"detail": f"Marked {count} notifications as read"}
|
||||
|
||||
@@ -4,18 +4,28 @@ Provides operational KPIs for security teams with trend analysis
|
||||
and team-level breakdowns.
|
||||
"""
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import from app.services.operational_metrics_service
|
||||
from app.services.operational_metrics_service import (
|
||||
get_all_operational_metrics,
|
||||
get_operational_trend,
|
||||
get_metrics_by_team,
|
||||
get_operational_trend,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"])
|
||||
router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"])
|
||||
|
||||
|
||||
@@ -23,13 +33,18 @@ router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
# Define function operational_metrics
|
||||
def operational_metrics(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Get all operational metrics (MTTD, MTTR, etc.) — cached for 5 min."""
|
||||
# Import get_operational_metrics_cached from app.services.score_cache
|
||||
from app.services.score_cache import get_operational_metrics_cached
|
||||
|
||||
# Return get_operational_metrics_cached(db)
|
||||
return get_operational_metrics_cached(db)
|
||||
|
||||
|
||||
@@ -37,12 +52,17 @@ def operational_metrics(
|
||||
|
||||
|
||||
@router.get("/trend")
|
||||
# Define function operational_trend
|
||||
def operational_trend(
|
||||
# Entry: period
|
||||
period: str = Query("90d", pattern="^(30d|90d|1y)$"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Get weekly trend data for operational metrics."""
|
||||
# Return get_operational_trend(db, period)
|
||||
return get_operational_trend(db, period)
|
||||
|
||||
|
||||
@@ -50,9 +70,13 @@ def operational_trend(
|
||||
|
||||
|
||||
@router.get("/by-team")
|
||||
# Define function metrics_by_team
|
||||
def metrics_by_team(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Get metrics broken down by Red Team vs Blue Team."""
|
||||
# Return get_metrics_by_team(db)
|
||||
return get_metrics_by_team(db)
|
||||
|
||||
+162
-15
@@ -1,26 +1,44 @@
|
||||
"""OSINT enrichment endpoints — view, review, and trigger enrichment of
|
||||
OSINT items (CVEs, advisories, etc.) linked to techniques.
|
||||
"""
|
||||
"""OSINT enrichment endpoints — view, review, and trigger enrichment of OSINT items linked to techniques."""
|
||||
|
||||
# Import UUID from uuid
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||
# Import APIRouter, Depends, HTTPException, Query, status from fastapi
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
|
||||
# Import BaseModel from pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_any_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import from app.services.osint_enrichment_service
|
||||
from app.services.osint_enrichment_service import (
|
||||
enrich_technique_with_cves,
|
||||
get_osint_items_for_technique,
|
||||
get_osint_summary,
|
||||
get_technique_or_raise,
|
||||
list_osint_items as service_list_osint_items,
|
||||
mark_osint_reviewed,
|
||||
)
|
||||
|
||||
# Import from app.services.osint_enrichment_service
|
||||
from app.services.osint_enrichment_service import (
|
||||
list_osint_items as service_list_osint_items,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/osint", tags=["osint"])
|
||||
router = APIRouter(prefix="/osint", tags=["osint"])
|
||||
|
||||
|
||||
@@ -28,18 +46,34 @@ router = APIRouter(prefix="/osint", tags=["osint"])
|
||||
|
||||
|
||||
class OsintItemOut(BaseModel):
|
||||
"""Serialized OSINT item returned by the API."""
|
||||
|
||||
# id: str
|
||||
id: str
|
||||
# technique_id: str
|
||||
technique_id: str
|
||||
# source_type: str
|
||||
source_type: str
|
||||
# source_url: str
|
||||
source_url: str
|
||||
# title: str
|
||||
title: str
|
||||
# description: str | None
|
||||
description: str | None
|
||||
# severity: str | None
|
||||
severity: str | None
|
||||
# discovered_at: str | None
|
||||
discovered_at: str | None
|
||||
# reviewed: bool
|
||||
reviewed: bool
|
||||
# Assign metadata_ = None
|
||||
metadata_: dict | None = None
|
||||
|
||||
# Define class Config
|
||||
class Config:
|
||||
"""ORM mode configuration for SQLAlchemy model mapping."""
|
||||
|
||||
# Assign from_attributes = True
|
||||
from_attributes = True
|
||||
|
||||
|
||||
@@ -47,94 +81,207 @@ class OsintItemOut(BaseModel):
|
||||
|
||||
|
||||
@router.get("/items")
|
||||
# Define function list_osint_items
|
||||
def list_osint_items(
|
||||
# Entry: technique_id
|
||||
technique_id: UUID | None = Query(None),
|
||||
# Entry: source_type
|
||||
source_type: str | None = Query(None),
|
||||
# Entry: reviewed
|
||||
reviewed: bool | None = Query(None),
|
||||
# Entry: offset
|
||||
offset: int = Query(0, ge=0),
|
||||
# Entry: limit
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List OSINT items with optional filters."""
|
||||
) -> list:
|
||||
"""List OSINT items with optional filters.
|
||||
|
||||
Args:
|
||||
technique_id (UUID | None): Filter by the technique's UUID.
|
||||
source_type (str | None): Filter by source type (e.g. ``nvd_cve``, ``advisory``).
|
||||
reviewed (bool | None): Filter by review status; ``None`` returns all.
|
||||
offset (int): Number of records to skip for pagination.
|
||||
limit (int): Maximum number of records to return.
|
||||
db (Session): SQLAlchemy database session.
|
||||
user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
list: Serialised list of OSINT item dicts matching the filters.
|
||||
"""
|
||||
# Return service_list_osint_items(
|
||||
return service_list_osint_items(
|
||||
db,
|
||||
# Keyword argument: technique_id
|
||||
technique_id=technique_id,
|
||||
# Keyword argument: source_type
|
||||
source_type=source_type,
|
||||
# Keyword argument: reviewed
|
||||
reviewed=reviewed,
|
||||
# Keyword argument: offset
|
||||
offset=offset,
|
||||
# Keyword argument: limit
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/summary")
|
||||
# Define function osint_summary
|
||||
def osint_summary(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Summary statistics for OSINT items."""
|
||||
) -> dict:
|
||||
"""Return summary statistics for OSINT items.
|
||||
|
||||
Args:
|
||||
db (Session): SQLAlchemy database session.
|
||||
user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
dict: Counts of total, reviewed, and unreviewed items broken down by source type.
|
||||
"""
|
||||
# Return get_osint_summary(db)
|
||||
return get_osint_summary(db)
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/items/{item_id}/review")
|
||||
# Define function review_osint_item
|
||||
def review_osint_item(
|
||||
# Entry: item_id
|
||||
item_id: UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Mark an OSINT item as reviewed."""
|
||||
) -> dict:
|
||||
"""Mark an OSINT item as reviewed.
|
||||
|
||||
Args:
|
||||
item_id (UUID): Primary key of the OSINT item to mark reviewed.
|
||||
db (Session): SQLAlchemy database session.
|
||||
user (User): Authenticated user performing the review.
|
||||
|
||||
Returns:
|
||||
dict: Contains ``id`` (str) and ``reviewed`` (bool ``True``).
|
||||
"""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign item = mark_osint_reviewed(db, str(item_id))
|
||||
item = mark_osint_reviewed(db, str(item_id))
|
||||
# Check: not item
|
||||
if not item:
|
||||
# Raise HTTPException
|
||||
raise HTTPException(
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
# Keyword argument: detail
|
||||
detail="OSINT item not found",
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Return {"id": str(item.id), "reviewed": True}
|
||||
return {"id": str(item.id), "reviewed": True}
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/enrich/{technique_id}")
|
||||
# Define function trigger_technique_enrichment
|
||||
def trigger_technique_enrichment(
|
||||
# Entry: technique_id
|
||||
technique_id: UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Manually trigger OSINT enrichment for a single technique."""
|
||||
) -> dict:
|
||||
"""Manually trigger OSINT enrichment for a single technique.
|
||||
|
||||
Args:
|
||||
technique_id (UUID): Primary key of the technique to enrich.
|
||||
db (Session): SQLAlchemy database session.
|
||||
user (User): Authenticated red_lead or blue_lead requesting enrichment.
|
||||
|
||||
Returns:
|
||||
dict: Contains ``technique_id`` (str), ``mitre_id`` (str), and ``new_items`` (int).
|
||||
"""
|
||||
# Assign technique = get_technique_or_raise(db, technique_id)
|
||||
technique = get_technique_or_raise(db, technique_id)
|
||||
# Assign count = enrich_technique_with_cves(db, technique)
|
||||
count = enrich_technique_with_cves(db, technique)
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"technique_id": str(technique.id),
|
||||
# Literal argument value
|
||||
"mitre_id": technique.mitre_id,
|
||||
# Literal argument value
|
||||
"new_items": count,
|
||||
}
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/technique/{technique_id}")
|
||||
# Define function get_technique_osint
|
||||
def get_technique_osint(
|
||||
# Entry: technique_id
|
||||
technique_id: UUID,
|
||||
# Entry: source_type
|
||||
source_type: str | None = Query(None),
|
||||
# Entry: reviewed
|
||||
reviewed: bool | None = Query(None),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get all OSINT items for a specific technique."""
|
||||
) -> list:
|
||||
"""Get all OSINT items for a specific technique.
|
||||
|
||||
Args:
|
||||
technique_id (UUID): Primary key of the technique.
|
||||
source_type (str | None): Filter by source type (e.g. ``nvd_cve``).
|
||||
reviewed (bool | None): Filter by review status; ``None`` returns all.
|
||||
db (Session): SQLAlchemy database session.
|
||||
user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
list: Dicts with OSINT item fields including source URL, severity, and review status.
|
||||
"""
|
||||
# Assign items = get_osint_items_for_technique(
|
||||
items = get_osint_items_for_technique(
|
||||
db,
|
||||
str(technique_id),
|
||||
# Keyword argument: source_type
|
||||
source_type=source_type,
|
||||
# Keyword argument: reviewed
|
||||
reviewed=reviewed,
|
||||
)
|
||||
# Return [
|
||||
return [
|
||||
{
|
||||
# Literal argument value
|
||||
"id": str(item.id),
|
||||
# Literal argument value
|
||||
"source_type": item.source_type,
|
||||
# Literal argument value
|
||||
"source_url": item.source_url,
|
||||
# Literal argument value
|
||||
"title": item.title,
|
||||
# Literal argument value
|
||||
"description": item.description,
|
||||
# Literal argument value
|
||||
"severity": item.severity,
|
||||
# Literal argument value
|
||||
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
|
||||
# Literal argument value
|
||||
"reviewed": item.reviewed,
|
||||
# Literal argument value
|
||||
"metadata": item.metadata_,
|
||||
}
|
||||
for item in items
|
||||
|
||||
@@ -1,118 +1,195 @@
|
||||
"""Professional report generation endpoints — PDF, DOCX, HTML output."""
|
||||
|
||||
# Import UUID from uuid
|
||||
from uuid import UUID
|
||||
|
||||
# Import APIRouter, Depends, Query, Request from fastapi
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
# Import FileResponse from fastapi.responses
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_any_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.user import User
|
||||
|
||||
# Import limiter from app.limiter
|
||||
from app.limiter import limiter
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import report_generation_service from app.services
|
||||
from app.services import report_generation_service
|
||||
|
||||
# Assign router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
|
||||
router = APIRouter(prefix="/reports/generate", tags=["professional-reports"])
|
||||
|
||||
# Assign _MEDIA_TYPES = {
|
||||
_MEDIA_TYPES = {
|
||||
# Literal argument value
|
||||
"pdf": "application/pdf",
|
||||
# Literal argument value
|
||||
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
# Literal argument value
|
||||
"html": "text/html",
|
||||
}
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/purple-campaign/{campaign_id}")
|
||||
# Apply the @limiter.limit decorator
|
||||
@limiter.limit("5/minute")
|
||||
# Define function generate_purple_report
|
||||
def generate_purple_report(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: campaign_id
|
||||
campaign_id: UUID,
|
||||
# Entry: format
|
||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||
):
|
||||
) -> FileResponse:
|
||||
"""Generate a Purple Team campaign assessment report."""
|
||||
# Assign filepath = report_generation_service.generate_purple_campaign_report(
|
||||
filepath = report_generation_service.generate_purple_campaign_report(
|
||||
db, str(campaign_id), output_format=format,
|
||||
)
|
||||
# Return FileResponse(
|
||||
return FileResponse(
|
||||
filepath,
|
||||
# Keyword argument: media_type
|
||||
media_type=_MEDIA_TYPES[format],
|
||||
# Keyword argument: filename
|
||||
filename=f"purple_report.{format}",
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/coverage-summary")
|
||||
# Apply the @limiter.limit decorator
|
||||
@limiter.limit("5/minute")
|
||||
# Define function generate_coverage_report
|
||||
def generate_coverage_report(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: format
|
||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||
):
|
||||
) -> FileResponse:
|
||||
"""Generate an organization-wide MITRE ATT&CK coverage report."""
|
||||
# Assign filepath = report_generation_service.generate_coverage_report(
|
||||
filepath = report_generation_service.generate_coverage_report(
|
||||
db, output_format=format,
|
||||
)
|
||||
# Return FileResponse(
|
||||
return FileResponse(
|
||||
filepath,
|
||||
# Keyword argument: media_type
|
||||
media_type=_MEDIA_TYPES[format],
|
||||
# Keyword argument: filename
|
||||
filename=f"coverage_report.{format}",
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/executive-summary")
|
||||
# Apply the @limiter.limit decorator
|
||||
@limiter.limit("5/minute")
|
||||
# Define function generate_executive_report
|
||||
def generate_executive_report(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: format
|
||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||
):
|
||||
) -> FileResponse:
|
||||
"""Generate an executive security summary report."""
|
||||
# Assign filepath = report_generation_service.generate_executive_summary(
|
||||
filepath = report_generation_service.generate_executive_summary(
|
||||
db, output_format=format,
|
||||
)
|
||||
# Return FileResponse(
|
||||
return FileResponse(
|
||||
filepath,
|
||||
# Keyword argument: media_type
|
||||
media_type=_MEDIA_TYPES[format],
|
||||
# Keyword argument: filename
|
||||
filename=f"executive_summary.{format}",
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/quarterly-summary")
|
||||
# Apply the @limiter.limit decorator
|
||||
@limiter.limit("5/minute")
|
||||
# Define function generate_quarterly_report
|
||||
def generate_quarterly_report(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: format
|
||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
|
||||
):
|
||||
) -> FileResponse:
|
||||
"""Generate a quarterly security summary report."""
|
||||
# Assign filepath = report_generation_service.generate_quarterly_summary(
|
||||
filepath = report_generation_service.generate_quarterly_summary(
|
||||
db, output_format=format,
|
||||
)
|
||||
# Return FileResponse(
|
||||
return FileResponse(
|
||||
filepath,
|
||||
# Keyword argument: media_type
|
||||
media_type=_MEDIA_TYPES[format],
|
||||
# Keyword argument: filename
|
||||
filename=f"quarterly_summary.{format}",
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/technique/{technique_id}")
|
||||
# Apply the @limiter.limit decorator
|
||||
@limiter.limit("5/minute")
|
||||
# Define function generate_technique_report
|
||||
def generate_technique_report(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: technique_id
|
||||
technique_id: UUID,
|
||||
# Entry: format
|
||||
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> FileResponse:
|
||||
"""Generate a detailed report for one MITRE technique."""
|
||||
# Assign filepath = report_generation_service.generate_technique_detail_report(
|
||||
filepath = report_generation_service.generate_technique_detail_report(
|
||||
db, str(technique_id), output_format=format,
|
||||
)
|
||||
# Return FileResponse(
|
||||
return FileResponse(
|
||||
filepath,
|
||||
# Keyword argument: media_type
|
||||
media_type=_MEDIA_TYPES[format],
|
||||
# Keyword argument: filename
|
||||
filename=f"technique_{technique_id}.{format}",
|
||||
)
|
||||
|
||||
@@ -10,18 +10,37 @@ GET /reports/test-results — test results report (JSON)
|
||||
GET /reports/remediation-status — remediation status report (JSON)
|
||||
"""
|
||||
|
||||
# Import csv
|
||||
import csv
|
||||
|
||||
# Import io
|
||||
import io
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import StreamingResponse from fastapi.responses
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import from app.services.coverage_report_service
|
||||
from app.services.coverage_report_service import (
|
||||
build_coverage_csv_rows,
|
||||
build_coverage_summary,
|
||||
@@ -29,61 +48,99 @@ from app.services.coverage_report_service import (
|
||||
build_test_results_report,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/reports", tags=["reports"])
|
||||
router = APIRouter(prefix="/reports", tags=["reports"])
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/coverage-summary")
|
||||
# Define function coverage_summary
|
||||
def coverage_summary(
|
||||
# Entry: tactic
|
||||
tactic: Optional[str] = Query(None, description="Filter by tactic"),
|
||||
# Entry: platform
|
||||
platform: Optional[str] = Query(None, description="Filter by platform (in techniques)"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Full coverage report as JSON — technique-by-technique with test counts."""
|
||||
# Return build_coverage_summary(db, tactic=tactic, platform=platform)
|
||||
return build_coverage_summary(db, tactic=tactic, platform=platform)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/coverage-csv")
|
||||
# Define function coverage_csv
|
||||
def coverage_csv(
|
||||
# Entry: tactic
|
||||
tactic: Optional[str] = Query(None),
|
||||
# Entry: platform
|
||||
platform: Optional[str] = Query(None),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> StreamingResponse:
|
||||
"""Export coverage as a downloadable CSV."""
|
||||
# Assign rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform)
|
||||
rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform)
|
||||
|
||||
# Assign output = io.StringIO()
|
||||
output = io.StringIO()
|
||||
# Assign writer = csv.writer(output)
|
||||
writer = csv.writer(output)
|
||||
# Iterate over rows
|
||||
for row in rows:
|
||||
# Call writer.writerow()
|
||||
writer.writerow(row)
|
||||
|
||||
# Call output.seek()
|
||||
output.seek(0)
|
||||
# Assign filename = f"aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"
|
||||
filename = f"aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"
|
||||
# Return StreamingResponse(
|
||||
return StreamingResponse(
|
||||
iter([output.getvalue()]),
|
||||
# Keyword argument: media_type
|
||||
media_type="text/csv",
|
||||
# Keyword argument: headers
|
||||
headers={"Content-Disposition": f"attachment; filename={filename}"},
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/test-results")
|
||||
# Define function test_results
|
||||
def test_results(
|
||||
# Entry: state
|
||||
state: Optional[str] = Query(None),
|
||||
# Entry: date_from
|
||||
date_from: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
|
||||
# Entry: date_to
|
||||
date_to: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Report of test results with optional filters."""
|
||||
# Return build_test_results_report(db, state=state, date_from=date_from, dat...
|
||||
return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/remediation-status")
|
||||
# Define function remediation_status
|
||||
def remediation_status(
|
||||
# Entry: status
|
||||
status: Optional[str] = Query(None, description="Filter by remediation status"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Report of remediation status across all tests."""
|
||||
# Return build_remediation_status_report(db, status=status)
|
||||
return build_remediation_status_report(db, status=status)
|
||||
|
||||
+154
-20
@@ -3,28 +3,45 @@
|
||||
Provides granular scoring with breakdowns and configurable weights.
|
||||
"""
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import BaseModel from pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
from app.services.scoring_service import (
|
||||
score_technique_by_mitre_id,
|
||||
score_actor_by_id,
|
||||
calculate_tactic_score,
|
||||
calculate_organization_score,
|
||||
get_score_history,
|
||||
)
|
||||
|
||||
# Import from app.services.scoring_config_service
|
||||
from app.services.scoring_config_service import (
|
||||
get_weights_dict,
|
||||
update_scoring_weights,
|
||||
)
|
||||
|
||||
# Import from app.services.scoring_service
|
||||
from app.services.scoring_service import (
|
||||
calculate_tactic_score,
|
||||
get_score_history,
|
||||
score_actor_by_id,
|
||||
score_technique_by_mitre_id,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/scores", tags=["scores"])
|
||||
router = APIRouter(prefix="/scores", tags=["scores"])
|
||||
|
||||
|
||||
@@ -32,12 +49,26 @@ router = APIRouter(prefix="/scores", tags=["scores"])
|
||||
|
||||
|
||||
@router.get("/technique/{mitre_id}")
|
||||
# Define function score_technique
|
||||
def score_technique(
|
||||
# Entry: mitre_id
|
||||
mitre_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get detailed score with breakdown for a specific technique."""
|
||||
) -> dict:
|
||||
"""Get detailed score with breakdown for a specific technique.
|
||||
|
||||
Args:
|
||||
mitre_id (str): MITRE ATT&CK technique ID (e.g. ``T1059``).
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
dict: Score value and component breakdown (tests, detection rules, recency, etc.).
|
||||
"""
|
||||
# Return score_technique_by_mitre_id(db, mitre_id)
|
||||
return score_technique_by_mitre_id(db, mitre_id)
|
||||
|
||||
|
||||
@@ -45,12 +76,26 @@ def score_technique(
|
||||
|
||||
|
||||
@router.get("/tactic/{tactic}")
|
||||
# Define function score_tactic
|
||||
def score_tactic(
|
||||
# Entry: tactic
|
||||
tactic: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get average score for a tactic."""
|
||||
) -> dict:
|
||||
"""Get average score for a tactic.
|
||||
|
||||
Args:
|
||||
tactic (str): MITRE ATT&CK tactic slug (e.g. ``initial-access``).
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
dict: Average score and per-technique breakdown for the tactic.
|
||||
"""
|
||||
# Return calculate_tactic_score(tactic, db)
|
||||
return calculate_tactic_score(tactic, db)
|
||||
|
||||
|
||||
@@ -58,12 +103,26 @@ def score_tactic(
|
||||
|
||||
|
||||
@router.get("/threat-actor/{actor_id}")
|
||||
# Define function score_threat_actor
|
||||
def score_threat_actor(
|
||||
# Entry: actor_id
|
||||
actor_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get coverage score against a specific threat actor."""
|
||||
) -> dict:
|
||||
"""Get coverage score against a specific threat actor.
|
||||
|
||||
Args:
|
||||
actor_id (str): UUID string of the threat actor to score against.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
dict: Coverage score and per-technique breakdown for the threat actor.
|
||||
"""
|
||||
# Return score_actor_by_id(db, actor_id)
|
||||
return score_actor_by_id(db, actor_id)
|
||||
|
||||
|
||||
@@ -71,13 +130,26 @@ def score_threat_actor(
|
||||
|
||||
|
||||
@router.get("/organization")
|
||||
# Define function score_organization
|
||||
def score_organization(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get the overall organization security score (cached for 5 min)."""
|
||||
) -> dict:
|
||||
"""Get the overall organization security score (cached for 5 min).
|
||||
|
||||
Args:
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
dict: Aggregate organization score with tactic-level breakdowns.
|
||||
"""
|
||||
# Import get_organization_score_cached from app.services.score_cache
|
||||
from app.services.score_cache import get_organization_score_cached
|
||||
|
||||
# Return get_organization_score_cached(db)
|
||||
return get_organization_score_cached(db)
|
||||
|
||||
|
||||
@@ -85,12 +157,26 @@ def score_organization(
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
# Define function score_history
|
||||
def score_history(
|
||||
# Entry: period
|
||||
period: str = Query("90d", pattern="^(30d|90d|1y)$"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get historical score data points (weekly)."""
|
||||
) -> dict:
|
||||
"""Get historical score data points (weekly).
|
||||
|
||||
Args:
|
||||
period (str): Time window for history — one of ``30d``, ``90d``, or ``1y``.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
dict: Weekly score data points for the requested period.
|
||||
"""
|
||||
# Return get_score_history(db, period)
|
||||
return get_score_history(db, period)
|
||||
|
||||
|
||||
@@ -98,11 +184,23 @@ def score_history(
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
# Define function get_scoring_config
|
||||
def get_scoring_config(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Get current scoring weights (admin only)."""
|
||||
) -> dict:
|
||||
"""Get current scoring weights (admin only).
|
||||
|
||||
Args:
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated admin user.
|
||||
|
||||
Returns:
|
||||
dict: Current weight values for each scoring component.
|
||||
"""
|
||||
# Return get_weights_dict(db)
|
||||
return get_weights_dict(db)
|
||||
|
||||
|
||||
@@ -110,41 +208,77 @@ def get_scoring_config(
|
||||
|
||||
|
||||
class ScoringConfigUpdate(BaseModel):
|
||||
"""Partial update payload for the scoring weight configuration."""
|
||||
|
||||
# Assign tests = None
|
||||
tests: Optional[float] = None
|
||||
# Assign detection_rules = None
|
||||
detection_rules: Optional[float] = None
|
||||
# Assign d3fend = None
|
||||
d3fend: Optional[float] = None
|
||||
# Assign recency = None
|
||||
recency: Optional[float] = None
|
||||
# Assign severity = None
|
||||
severity: Optional[float] = None
|
||||
# Assign freshness = None
|
||||
freshness: Optional[float] = None
|
||||
# Assign platform_diversity = None
|
||||
platform_diversity: Optional[float] = None
|
||||
|
||||
|
||||
# Apply the @router.patch decorator
|
||||
@router.patch("/config")
|
||||
# Define function update_scoring_config
|
||||
def update_scoring_config(
|
||||
# Entry: payload
|
||||
payload: ScoringConfigUpdate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Update scoring weights (admin only).
|
||||
|
||||
Weights are persisted in the database and survive restarts.
|
||||
Validation enforces that all weights are non-negative and sum to 100.
|
||||
|
||||
Args:
|
||||
payload (ScoringConfigUpdate): Partial weight update; only set fields are changed.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated admin user.
|
||||
|
||||
Returns:
|
||||
dict: Confirmation message plus the full updated weight configuration.
|
||||
"""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign result = update_scoring_weights(
|
||||
result = update_scoring_weights(
|
||||
db,
|
||||
# Keyword argument: tests
|
||||
tests=payload.tests,
|
||||
# Keyword argument: detection_rules
|
||||
detection_rules=payload.detection_rules,
|
||||
# Keyword argument: d3fend
|
||||
d3fend=payload.d3fend,
|
||||
# Keyword argument: recency
|
||||
recency=payload.recency,
|
||||
# Keyword argument: severity
|
||||
severity=payload.severity,
|
||||
# Keyword argument: freshness
|
||||
freshness=payload.freshness,
|
||||
# Keyword argument: platform_diversity
|
||||
platform_diversity=payload.platform_diversity,
|
||||
# Keyword argument: updated_by
|
||||
updated_by=current_user.id,
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Import invalidate from app.services.score_cache
|
||||
from app.services.score_cache import invalidate
|
||||
# Call invalidate()
|
||||
invalidate()
|
||||
|
||||
# Return {"message": "Scoring config updated", **result}
|
||||
return {"message": "Scoring config updated", **result}
|
||||
|
||||
@@ -4,40 +4,71 @@ Provides periodic and manual snapshots of the organisation's coverage
|
||||
state, plus temporal comparison between any two snapshots.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import BaseModel from pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_any_role, require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||
|
||||
# Import BusinessRuleViolation from app.domain.errors
|
||||
from app.domain.errors import BusinessRuleViolation
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
from app.services.snapshot_service import (
|
||||
create_snapshot,
|
||||
compare_snapshots,
|
||||
cleanup_old_snapshots,
|
||||
get_coverage_evolution,
|
||||
serialize_snapshot_summary,
|
||||
list_snapshots as list_snapshots_svc,
|
||||
get_snapshot_or_raise,
|
||||
get_snapshot_detail,
|
||||
delete_snapshot,
|
||||
)
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import from app.services.snapshot_service
|
||||
from app.services.snapshot_service import (
|
||||
compare_snapshots,
|
||||
create_snapshot,
|
||||
delete_snapshot,
|
||||
get_coverage_evolution,
|
||||
get_snapshot_detail,
|
||||
get_snapshot_or_raise,
|
||||
serialize_snapshot_summary,
|
||||
)
|
||||
|
||||
# Import from app.services.snapshot_service
|
||||
from app.services.snapshot_service import (
|
||||
list_snapshots as list_snapshots_svc,
|
||||
)
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign router = APIRouter(prefix="/snapshots", tags=["snapshots"])
|
||||
router = APIRouter(prefix="/snapshots", tags=["snapshots"])
|
||||
|
||||
|
||||
# ── Pydantic schemas ─────────────────────────────────────────────────
|
||||
|
||||
class SnapshotCreate(BaseModel):
|
||||
"""Payload for creating a new coverage snapshot."""
|
||||
|
||||
# Assign name = None
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
@@ -46,13 +77,19 @@ class SnapshotCreate(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("")
|
||||
# Define function list_snapshots
|
||||
def list_snapshots(
|
||||
# Entry: offset
|
||||
offset: int = Query(0, ge=0),
|
||||
# Entry: limit
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""List coverage snapshots ordered by creation date (newest first)."""
|
||||
# Return list_snapshots_svc(db, offset=offset, limit=limit)
|
||||
return list_snapshots_svc(db, offset=offset, limit=limit)
|
||||
|
||||
|
||||
@@ -61,25 +98,39 @@ def list_snapshots(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("", status_code=201)
|
||||
# Define function create_snapshot_endpoint
|
||||
def create_snapshot_endpoint(
|
||||
# Entry: payload
|
||||
payload: SnapshotCreate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Create a manual coverage snapshot with an optional name."""
|
||||
# Assign snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
|
||||
snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
|
||||
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="create_snapshot",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="snapshot",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=snapshot.id,
|
||||
# Keyword argument: details
|
||||
details={"name": snapshot.name, "score": snapshot.organization_score},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return serialize_snapshot_summary(snapshot)
|
||||
return serialize_snapshot_summary(snapshot)
|
||||
|
||||
|
||||
@@ -89,12 +140,17 @@ def create_snapshot_endpoint(
|
||||
|
||||
|
||||
@router.get("/evolution")
|
||||
# Define function coverage_evolution
|
||||
def coverage_evolution(
|
||||
# Entry: months
|
||||
months: int = Query(12, ge=1, le=36),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Return coverage snapshots for trend charts (last *months* months)."""
|
||||
# Return get_coverage_evolution(db, months=months)
|
||||
return get_coverage_evolution(db, months=months)
|
||||
|
||||
|
||||
@@ -103,19 +159,30 @@ def coverage_evolution(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/compare")
|
||||
# Define function compare_snapshots_endpoint
|
||||
def compare_snapshots_endpoint(
|
||||
# Entry: a
|
||||
a: str = Query(..., description="Snapshot A ID"),
|
||||
# Entry: b
|
||||
b: str = Query(..., description="Snapshot B ID"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Compare two snapshots showing improved, worsened, and unchanged techniques."""
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign a_id = uuid.UUID(a)
|
||||
a_id = uuid.UUID(a)
|
||||
# Assign b_id = uuid.UUID(b)
|
||||
b_id = uuid.UUID(b)
|
||||
# Handle ValueError
|
||||
except ValueError:
|
||||
# Raise BusinessRuleViolation
|
||||
raise BusinessRuleViolation("Invalid snapshot ID format")
|
||||
|
||||
# Return compare_snapshots(db, a_id, b_id)
|
||||
return compare_snapshots(db, a_id, b_id)
|
||||
|
||||
|
||||
@@ -124,12 +191,17 @@ def compare_snapshots_endpoint(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/{snapshot_id}")
|
||||
# Define function get_snapshot
|
||||
def get_snapshot(
|
||||
# Entry: snapshot_id
|
||||
snapshot_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Get detailed snapshot information including per-technique states."""
|
||||
# Return get_snapshot_detail(db, snapshot_id)
|
||||
return get_snapshot_detail(db, snapshot_id)
|
||||
|
||||
|
||||
@@ -138,24 +210,39 @@ def get_snapshot(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.delete("/{snapshot_id}")
|
||||
# Define function delete_snapshot_endpoint
|
||||
def delete_snapshot_endpoint(
|
||||
# Entry: snapshot_id
|
||||
snapshot_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Delete a snapshot (admin only)."""
|
||||
# Assign snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="delete_snapshot",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="snapshot",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=snapshot.id,
|
||||
# Keyword argument: details
|
||||
details={"name": snapshot.name},
|
||||
)
|
||||
# Call delete_snapshot()
|
||||
delete_snapshot(db, snapshot_id)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return {"detail": "Snapshot deleted"}
|
||||
return {"detail": "Snapshot deleted"}
|
||||
|
||||
@@ -8,6 +8,7 @@ Also exposes email configuration CRUD (admin only) that writes to the
|
||||
system_configs table so settings survive container restarts.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@@ -22,10 +23,26 @@ from app.services.mitre_sync_service import sync_mitre
|
||||
from app.services.intel_service import scan_intel
|
||||
from app.services.atomic_import_service import import_atomic_red_team
|
||||
from app.jobs.mitre_sync_job import scheduler
|
||||
|
||||
# Import limiter from app.limiter
|
||||
from app.limiter import limiter
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import import_atomic_red_team from app.services.atomic_import_service
|
||||
from app.services.atomic_import_service import import_atomic_red_team
|
||||
|
||||
# Import scan_intel from app.services.intel_service
|
||||
from app.services.intel_service import scan_intel
|
||||
|
||||
# Import sync_mitre from app.services.mitre_sync_service
|
||||
from app.services.mitre_sync_service import sync_mitre
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign router = APIRouter(prefix="/system", tags=["system"])
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
|
||||
|
||||
@@ -105,8 +122,11 @@ def _bg_mitre_sync() -> None:
|
||||
|
||||
|
||||
@router.post("/sync-mitre")
|
||||
# Apply the @limiter.limit decorator
|
||||
@limiter.limit("2/hour")
|
||||
# Define function trigger_mitre_sync
|
||||
def trigger_mitre_sync(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
@@ -127,11 +147,15 @@ def trigger_mitre_sync(
|
||||
}
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/run-intel-scan")
|
||||
# Define function trigger_intel_scan
|
||||
def trigger_intel_scan(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Manually trigger a threat-intelligence scan.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
@@ -139,20 +163,30 @@ def trigger_intel_scan(
|
||||
Returns a JSON object with the scan summary including the count of
|
||||
new intel items found.
|
||||
"""
|
||||
# Assign summary = scan_intel(db)
|
||||
summary = scan_intel(db)
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"message": "Intel scan completed",
|
||||
# Literal argument value
|
||||
"new_items": summary["new_items"],
|
||||
}
|
||||
|
||||
|
||||
# Apply the @router.post decorator
|
||||
@router.post("/import-atomic-tests")
|
||||
# Apply the @limiter.limit decorator
|
||||
@limiter.limit("2/hour")
|
||||
# Define function trigger_atomic_import
|
||||
def trigger_atomic_import(
|
||||
# Entry: request
|
||||
request: Request,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Trigger an import of Atomic Red Team tests as TestTemplates.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
@@ -163,37 +197,58 @@ def trigger_atomic_import(
|
||||
|
||||
Returns a JSON object with import statistics.
|
||||
"""
|
||||
# Attempt the following; catch errors below
|
||||
try:
|
||||
# Assign summary = import_atomic_red_team(db)
|
||||
summary = import_atomic_red_team(db)
|
||||
# Handle Exception
|
||||
except Exception as exc:
|
||||
# Log error: "Atomic Red Team import failed: %s", exc, exc_info
|
||||
logger.error("Atomic Red Team import failed: %s", exc, exc_info=True)
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"message": "Import failed. Check server logs for details.",
|
||||
}
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"message": "Import completed",
|
||||
# Literal argument value
|
||||
"imported": summary["created"],
|
||||
# Literal argument value
|
||||
"skipped": summary["skipped_existing"],
|
||||
# Literal argument value
|
||||
"total_parsed": summary["total_tests_parsed"],
|
||||
}
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/scheduler-status")
|
||||
# Define function scheduler_status
|
||||
def scheduler_status(
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> dict:
|
||||
"""Return the current state of the background scheduler.
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
# Assign jobs = scheduler.get_jobs()
|
||||
jobs = scheduler.get_jobs()
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"running": scheduler.running,
|
||||
# Literal argument value
|
||||
"jobs": [
|
||||
{
|
||||
# Literal argument value
|
||||
"id": job.id,
|
||||
# Literal argument value
|
||||
"name": job.name,
|
||||
# Literal argument value
|
||||
"next_run_time": str(job.next_run_time) if job.next_run_time else None,
|
||||
}
|
||||
for job in jobs
|
||||
|
||||
@@ -5,29 +5,56 @@ for error signaling. The error_handler middleware maps domain
|
||||
exceptions to HTTP responses automatically.
|
||||
"""
|
||||
|
||||
# Import APIRouter, Depends, Query, status from fastapi
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_role, require_any_role
|
||||
|
||||
# Import get_current_user, require_any_role, require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||
|
||||
# Import get_technique_repository from app.dependencies.repositories
|
||||
from app.dependencies.repositories import get_technique_repository
|
||||
|
||||
# Import TechniqueEntity from app.domain.entities.technique
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
from app.domain.errors import DuplicateEntityError, EntityNotFoundError
|
||||
|
||||
# Import TechniqueStatus from app.domain.enums
|
||||
from app.domain.enums import TechniqueStatus
|
||||
|
||||
# Import DuplicateEntityError, EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import DuplicateEntityError, EntityNotFoundError
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import from app.infrastructure.persistence.repositories.sa_technique_repository
|
||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||
SATechniqueRepository,
|
||||
)
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import from app.schemas.technique
|
||||
from app.schemas.technique import (
|
||||
TechniqueCreate,
|
||||
TechniqueOut,
|
||||
TechniqueSummary,
|
||||
TechniqueUpdate,
|
||||
)
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import get_technique_detail from app.services.technique_query_service
|
||||
from app.services.technique_query_service import get_technique_detail
|
||||
|
||||
# Assign router = APIRouter(prefix="/techniques", tags=["techniques"])
|
||||
router = APIRouter(prefix="/techniques", tags=["techniques"])
|
||||
|
||||
|
||||
@@ -37,19 +64,29 @@ router = APIRouter(prefix="/techniques", tags=["techniques"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[TechniqueSummary])
|
||||
# Define function list_techniques
|
||||
def list_techniques(
|
||||
# Entry: tactic
|
||||
tactic: str | None = Query(None, description="Filter by tactic name"),
|
||||
# Entry: status_global
|
||||
status_global: TechniqueStatus | None = Query(
|
||||
None, alias="status", description="Filter by global status"
|
||||
),
|
||||
# Entry: review_required
|
||||
review_required: bool | None = Query(None, description="Filter by review flag"),
|
||||
# Entry: repo
|
||||
repo: SATechniqueRepository = Depends(get_technique_repository),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Return a lightweight list of techniques, optionally filtered."""
|
||||
# Return repo.list_all(
|
||||
return repo.list_all(
|
||||
# Keyword argument: tactic
|
||||
tactic=tactic,
|
||||
# Keyword argument: status
|
||||
status=status_global,
|
||||
# Keyword argument: review_required
|
||||
review_required=review_required,
|
||||
)
|
||||
|
||||
@@ -60,12 +97,17 @@ def list_techniques(
|
||||
|
||||
|
||||
@router.get("/{mitre_id}")
|
||||
# Define function get_technique
|
||||
def get_technique(
|
||||
# Entry: mitre_id
|
||||
mitre_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Return full details for a single technique, including its tests and D3FEND defenses."""
|
||||
# Return get_technique_detail(db, mitre_id)
|
||||
return get_technique_detail(db, mitre_id)
|
||||
|
||||
|
||||
@@ -75,40 +117,66 @@ def get_technique(
|
||||
|
||||
|
||||
@router.post(
|
||||
# Literal argument value
|
||||
"",
|
||||
# Keyword argument: response_model
|
||||
response_model=TechniqueOut,
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
# Define function create_technique
|
||||
def create_technique(
|
||||
# Entry: payload
|
||||
payload: TechniqueCreate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: repo
|
||||
repo: SATechniqueRepository = Depends(get_technique_repository),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> TechniqueOut:
|
||||
"""Create a new technique manually."""
|
||||
# Check: repo.exists_by_mitre_id(payload.mitre_id)
|
||||
if repo.exists_by_mitre_id(payload.mitre_id):
|
||||
# Raise DuplicateEntityError
|
||||
raise DuplicateEntityError("Technique", "mitre_id", payload.mitre_id)
|
||||
|
||||
# Assign entity = TechniqueEntity.create(
|
||||
entity = TechniqueEntity.create(
|
||||
# Keyword argument: mitre_id
|
||||
mitre_id=payload.mitre_id,
|
||||
# Keyword argument: name
|
||||
name=payload.name,
|
||||
# Keyword argument: description
|
||||
description=payload.description,
|
||||
# Keyword argument: tactic
|
||||
tactic=payload.tactic,
|
||||
# Keyword argument: platforms
|
||||
platforms=payload.platforms,
|
||||
)
|
||||
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign saved = repo.save(entity)
|
||||
saved = repo.save(entity)
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="create_technique",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="technique",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=saved.id,
|
||||
# Keyword argument: details
|
||||
details={"mitre_id": saved.mitre_id, "name": saved.name},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return saved
|
||||
return saved
|
||||
|
||||
|
||||
@@ -118,34 +186,56 @@ def create_technique(
|
||||
|
||||
|
||||
@router.patch("/{mitre_id}", response_model=TechniqueOut)
|
||||
# Define function update_technique
|
||||
def update_technique(
|
||||
# Entry: mitre_id
|
||||
mitre_id: str,
|
||||
# Entry: payload
|
||||
payload: TechniqueUpdate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: repo
|
||||
repo: SATechniqueRepository = Depends(get_technique_repository),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
) -> TechniqueOut:
|
||||
"""Update one or more fields of an existing technique."""
|
||||
# Assign entity = repo.find_by_mitre_id(mitre_id)
|
||||
entity = repo.find_by_mitre_id(mitre_id)
|
||||
# Check: entity is None
|
||||
if entity is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Technique", mitre_id)
|
||||
|
||||
# Assign update_data = payload.model_dump(exclude_unset=True)
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
# Iterate over update_data.items()
|
||||
for field, value in update_data.items():
|
||||
# Call setattr()
|
||||
setattr(entity, field, value)
|
||||
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign saved = repo.save(entity)
|
||||
saved = repo.save(entity)
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="update_technique",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="technique",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=saved.id,
|
||||
# Keyword argument: details
|
||||
details={"mitre_id": mitre_id, "updated_fields": list(update_data.keys())},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return saved
|
||||
return saved
|
||||
|
||||
|
||||
@@ -155,33 +245,52 @@ def update_technique(
|
||||
|
||||
|
||||
@router.patch("/{mitre_id}/review", response_model=TechniqueOut)
|
||||
# Define function review_technique
|
||||
def review_technique(
|
||||
# Entry: mitre_id
|
||||
mitre_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: repo
|
||||
repo: SATechniqueRepository = Depends(get_technique_repository),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
) -> TechniqueOut:
|
||||
"""Mark a technique as reviewed.
|
||||
|
||||
Sets ``review_required`` to *False* and records the current timestamp
|
||||
in ``last_review_date``.
|
||||
"""
|
||||
# Assign entity = repo.find_by_mitre_id(mitre_id)
|
||||
entity = repo.find_by_mitre_id(mitre_id)
|
||||
# Check: entity is None
|
||||
if entity is None:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Technique", mitre_id)
|
||||
|
||||
# Call entity.mark_reviewed()
|
||||
entity.mark_reviewed()
|
||||
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign saved = repo.save(entity)
|
||||
saved = repo.save(entity)
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="review_technique",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="technique",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=saved.id,
|
||||
# Keyword argument: details
|
||||
details={"mitre_id": mitre_id},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return saved
|
||||
return saved
|
||||
|
||||
@@ -22,35 +22,69 @@ Filters (GET /test-templates)
|
||||
- offset / limit: pagination (default limit=50)
|
||||
"""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query, status from fastapi
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_any_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.technique import Technique
|
||||
from app.models.user import User
|
||||
|
||||
# Import from app.schemas.test_template
|
||||
from app.schemas.test_template import (
|
||||
TestTemplateCreate,
|
||||
TestTemplateOut,
|
||||
TestTemplateSummary,
|
||||
)
|
||||
|
||||
# Import log_action from app.services.audit_service
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import from app.services.test_template_service
|
||||
from app.services.test_template_service import (
|
||||
bulk_activate,
|
||||
create_template as create_template_svc,
|
||||
get_template_or_raise,
|
||||
get_template_stats,
|
||||
get_templates_by_technique as templates_by_technique,
|
||||
list_templates,
|
||||
soft_delete_template,
|
||||
)
|
||||
|
||||
# Import from app.services.test_template_service
|
||||
from app.services.test_template_service import (
|
||||
create_template as create_template_svc,
|
||||
)
|
||||
|
||||
# Import from app.services.test_template_service
|
||||
from app.services.test_template_service import (
|
||||
get_templates_by_technique as templates_by_technique,
|
||||
)
|
||||
|
||||
# Import from app.services.test_template_service
|
||||
from app.services.test_template_service import (
|
||||
toggle_template_active as toggle_template_active_svc,
|
||||
)
|
||||
|
||||
# Import from app.services.test_template_service
|
||||
from app.services.test_template_service import (
|
||||
update_template as update_template_svc,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/test-templates", tags=["test-templates"])
|
||||
router = APIRouter(prefix="/test-templates", tags=["test-templates"])
|
||||
|
||||
|
||||
@@ -60,28 +94,64 @@ router = APIRouter(prefix="/test-templates", tags=["test-templates"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[TestTemplateSummary])
|
||||
# Define function _list_templates_handler
|
||||
def _list_templates_handler(
|
||||
# Entry: source
|
||||
source: Optional[str] = Query(None, description="Filter by source (atomic_red_team, mitre, custom)"),
|
||||
# Entry: platform
|
||||
platform: Optional[str] = Query(None, description="Filter by platform (windows, linux, macos)"),
|
||||
# Entry: severity
|
||||
severity: Optional[str] = Query(None, description="Filter by severity (low, medium, high, critical)"),
|
||||
# Entry: mitre_technique_id
|
||||
mitre_technique_id: Optional[str] = Query(None, description="Filter by MITRE technique ID"),
|
||||
# Entry: search
|
||||
search: Optional[str] = Query(None, description="Search in name and description"),
|
||||
# Entry: is_active
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status (true/false). Omit to return all."),
|
||||
# Entry: offset
|
||||
offset: int = Query(0, ge=0),
|
||||
# Entry: limit
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return a paginated, filterable list of test templates."""
|
||||
) -> list:
|
||||
"""Return a paginated, filterable list of test templates.
|
||||
|
||||
Args:
|
||||
source (Optional[str]): Filter by source (``atomic_red_team``, ``mitre``, ``custom``).
|
||||
platform (Optional[str]): Filter by platform (``windows``, ``linux``, ``macos``).
|
||||
severity (Optional[str]): Filter by severity (``low``, ``medium``, ``high``, ``critical``).
|
||||
mitre_technique_id (Optional[str]): Filter by MITRE technique ID string.
|
||||
search (Optional[str]): Full-text search across name and description.
|
||||
is_active (Optional[bool]): Filter by active status; omit to return all.
|
||||
offset (int): Number of records to skip for pagination.
|
||||
limit (int): Maximum number of records to return.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
list: Serialised list of :class:`TestTemplateSummary` objects.
|
||||
"""
|
||||
# Return list_templates(
|
||||
return list_templates(
|
||||
db,
|
||||
# Keyword argument: source
|
||||
source=source,
|
||||
# Keyword argument: platform
|
||||
platform=platform,
|
||||
# Keyword argument: severity
|
||||
severity=severity,
|
||||
# Keyword argument: mitre_technique_id
|
||||
mitre_technique_id=mitre_technique_id,
|
||||
# Keyword argument: search
|
||||
search=search,
|
||||
# Keyword argument: is_active
|
||||
is_active=is_active,
|
||||
# Keyword argument: offset
|
||||
offset=offset,
|
||||
# Keyword argument: limit
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@@ -92,11 +162,23 @@ def _list_templates_handler(
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
# Define function template_stats
|
||||
def template_stats(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Return catalog statistics: active, by_source, by_platform."""
|
||||
) -> dict:
|
||||
"""Return catalog statistics: active, by_source, by_platform.
|
||||
|
||||
Args:
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or blue_lead.
|
||||
|
||||
Returns:
|
||||
dict: Counts of active templates broken down by source and platform.
|
||||
"""
|
||||
# Return get_template_stats(db)
|
||||
return get_template_stats(db)
|
||||
|
||||
|
||||
@@ -106,27 +188,53 @@ def template_stats(
|
||||
|
||||
|
||||
@router.patch("/bulk-activate")
|
||||
# Define function bulk_activate_templates
|
||||
def bulk_activate_templates(
|
||||
# Entry: activate
|
||||
activate: bool = Query(True, description="True to activate all, False to deactivate all"),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Set all templates to active or inactive."""
|
||||
) -> dict:
|
||||
"""Set all templates to active or inactive.
|
||||
|
||||
Args:
|
||||
activate (bool): ``True`` to activate all templates, ``False`` to deactivate all.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or blue_lead.
|
||||
|
||||
Returns:
|
||||
dict: Confirmation message with ``affected`` count and the applied ``is_active`` flag.
|
||||
"""
|
||||
# Assign count = bulk_activate(db, activate=activate)
|
||||
count = bulk_activate(db, activate=activate)
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="bulk_activate_templates" if activate else "bulk_deactivate_templates",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test_template",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=None,
|
||||
# Keyword argument: details
|
||||
details={"affected": count, "is_active": activate},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"detail": f"{'Activated' if activate else 'Deactivated'} {count} templates",
|
||||
# Literal argument value
|
||||
"affected": count,
|
||||
# Literal argument value
|
||||
"is_active": activate,
|
||||
}
|
||||
|
||||
@@ -137,12 +245,26 @@ def bulk_activate_templates(
|
||||
|
||||
|
||||
@router.get("/by-technique/{mitre_id}", response_model=list[TestTemplateSummary])
|
||||
# Define function _templates_by_technique_handler
|
||||
def _templates_by_technique_handler(
|
||||
# Entry: mitre_id
|
||||
mitre_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return all active templates mapped to a specific MITRE technique."""
|
||||
) -> list:
|
||||
"""Return all active templates mapped to a specific MITRE technique.
|
||||
|
||||
Args:
|
||||
mitre_id (str): MITRE ATT&CK technique ID (e.g. ``T1059.001``).
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
list: Serialised list of :class:`TestTemplateSummary` objects for the technique.
|
||||
"""
|
||||
# Return templates_by_technique(db, mitre_id)
|
||||
return templates_by_technique(db, mitre_id)
|
||||
|
||||
|
||||
@@ -152,12 +274,26 @@ def _templates_by_technique_handler(
|
||||
|
||||
|
||||
@router.get("/{template_id}", response_model=TestTemplateOut)
|
||||
# Define function get_template
|
||||
def get_template(
|
||||
# Entry: template_id
|
||||
template_id: uuid.UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return full details for a single test template."""
|
||||
) -> TestTemplateOut:
|
||||
"""Return full details for a single test template.
|
||||
|
||||
Args:
|
||||
template_id (uuid.UUID): Primary key of the template to retrieve.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated user making the request.
|
||||
|
||||
Returns:
|
||||
TestTemplateOut: Full template detail including all fields.
|
||||
"""
|
||||
# Return get_template_or_raise(db, template_id)
|
||||
return get_template_or_raise(db, template_id)
|
||||
|
||||
|
||||
@@ -167,17 +303,35 @@ def get_template(
|
||||
|
||||
|
||||
@router.post(
|
||||
# Literal argument value
|
||||
"",
|
||||
# Keyword argument: response_model
|
||||
response_model=TestTemplateOut,
|
||||
# Keyword argument: status_code
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
# Define function create_template
|
||||
def create_template(
|
||||
# Entry: payload
|
||||
payload: TestTemplateCreate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Create a custom test template."""
|
||||
) -> TestTemplateOut:
|
||||
"""Create a custom test template.
|
||||
|
||||
Args:
|
||||
payload (TestTemplateCreate): All fields for the new template.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or blue_lead creating the template.
|
||||
|
||||
Returns:
|
||||
TestTemplateOut: The newly created template with all fields populated.
|
||||
"""
|
||||
# Assign template = create_template_svc(db, **payload.model_dump())
|
||||
template = create_template_svc(db, **payload.model_dump())
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Flag the associated technique for review — new template available
|
||||
if template.mitre_technique_id:
|
||||
@@ -190,19 +344,30 @@ def create_template(
|
||||
technique.review_required = True
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="create_test_template",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test_template",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=template.id,
|
||||
# Keyword argument: details
|
||||
details={
|
||||
# Literal argument value
|
||||
"name": template.name,
|
||||
# Literal argument value
|
||||
"source": template.source,
|
||||
# Literal argument value
|
||||
"mitre_technique_id": template.mitre_technique_id,
|
||||
},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(template)
|
||||
|
||||
# Return template
|
||||
return template
|
||||
|
||||
|
||||
@@ -212,26 +377,52 @@ def create_template(
|
||||
|
||||
|
||||
@router.patch("/{template_id}", response_model=TestTemplateOut)
|
||||
# Define function update_template
|
||||
def update_template(
|
||||
# Entry: template_id
|
||||
template_id: uuid.UUID,
|
||||
# Entry: payload
|
||||
payload: TestTemplateCreate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Update fields of an existing test template."""
|
||||
) -> TestTemplateOut:
|
||||
"""Update fields of an existing test template.
|
||||
|
||||
Args:
|
||||
template_id (uuid.UUID): Primary key of the template to update.
|
||||
payload (TestTemplateCreate): Fields to update; only set fields are applied.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or blue_lead updating the template.
|
||||
|
||||
Returns:
|
||||
TestTemplateOut: The updated template with refreshed field values.
|
||||
"""
|
||||
# Assign template = update_template_svc(db, template_id, **payload.model_dump(exclude_u...
|
||||
template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True))
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="update_test_template",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test_template",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=template.id,
|
||||
# Keyword argument: details
|
||||
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(template)
|
||||
|
||||
# Return template
|
||||
return template
|
||||
|
||||
|
||||
@@ -241,25 +432,49 @@ def update_template(
|
||||
|
||||
|
||||
@router.patch("/{template_id}/toggle-active", response_model=TestTemplateOut)
|
||||
# Define function toggle_template_active
|
||||
def toggle_template_active(
|
||||
# Entry: template_id
|
||||
template_id: uuid.UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Toggle a template between active and inactive (is_active = not is_active)."""
|
||||
) -> TestTemplateOut:
|
||||
"""Toggle a template between active and inactive (is_active = not is_active).
|
||||
|
||||
Args:
|
||||
template_id (uuid.UUID): Primary key of the template to toggle.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or blue_lead.
|
||||
|
||||
Returns:
|
||||
TestTemplateOut: The template with the updated ``is_active`` flag.
|
||||
"""
|
||||
# Assign template = toggle_template_active_svc(db, template_id)
|
||||
template = toggle_template_active_svc(db, template_id)
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="toggle_test_template",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test_template",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=template.id,
|
||||
# Keyword argument: details
|
||||
details={"name": template.name, "is_active": template.is_active},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(template)
|
||||
|
||||
# Return template
|
||||
return template
|
||||
|
||||
|
||||
@@ -269,23 +484,47 @@ def toggle_template_active(
|
||||
|
||||
|
||||
@router.delete("/{template_id}", status_code=status.HTTP_200_OK)
|
||||
# Define function delete_template
|
||||
def delete_template(
|
||||
# Entry: template_id
|
||||
template_id: uuid.UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Soft-delete a test template by setting ``is_active=False``."""
|
||||
) -> dict:
|
||||
"""Soft-delete a test template by setting ``is_active=False``.
|
||||
|
||||
Args:
|
||||
template_id (uuid.UUID): Primary key of the template to delete.
|
||||
db (Session): SQLAlchemy database session.
|
||||
current_user (User): Authenticated red_lead or blue_lead.
|
||||
|
||||
Returns:
|
||||
dict: Confirmation message with key ``detail``.
|
||||
"""
|
||||
# Assign template = get_template_or_raise(db, template_id)
|
||||
template = get_template_or_raise(db, template_id)
|
||||
# Call soft_delete_template()
|
||||
soft_delete_template(db, template_id)
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="delete_test_template",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="test_template",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=template.id,
|
||||
# Keyword argument: details
|
||||
details={"name": template.name},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
|
||||
# Return {"detail": "Test template deactivated"}
|
||||
return {"detail": "Test template deactivated"}
|
||||
|
||||
+590
-40
File diff suppressed because it is too large
Load Diff
@@ -4,15 +4,28 @@ Provides listing, detail, coverage analysis, and gap analysis for
|
||||
threat actor profiles imported from MITRE CTI.
|
||||
"""
|
||||
|
||||
# Import logging
|
||||
import logging
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import APIRouter, Depends, Query from fastapi
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import from app.services.threat_actor_service
|
||||
from app.services.threat_actor_service import (
|
||||
get_actor_coverage,
|
||||
get_actor_detail,
|
||||
@@ -20,58 +33,90 @@ from app.services.threat_actor_service import (
|
||||
list_actors,
|
||||
)
|
||||
|
||||
# Assign logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Assign router = APIRouter(prefix="/threat-actors", tags=["threat-actors"])
|
||||
router = APIRouter(prefix="/threat-actors", tags=["threat-actors"])
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("")
|
||||
# Define function list_threat_actors
|
||||
def list_threat_actors(
|
||||
# Entry: search
|
||||
search: Optional[str] = Query(None),
|
||||
# Entry: country
|
||||
country: Optional[str] = Query(None),
|
||||
# Entry: motivation
|
||||
motivation: Optional[str] = Query(None),
|
||||
# Entry: sophistication
|
||||
sophistication: Optional[str] = Query(None),
|
||||
# Entry: target_sectors
|
||||
target_sectors: Optional[str] = Query(None),
|
||||
# Entry: offset
|
||||
offset: int = Query(0, ge=0),
|
||||
# Entry: limit
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""List threat actors with optional filters and pagination.
|
||||
|
||||
**Requires** authentication (any role).
|
||||
"""
|
||||
# Return list_actors(
|
||||
return list_actors(
|
||||
db,
|
||||
# Keyword argument: search
|
||||
search=search,
|
||||
# Keyword argument: country
|
||||
country=country,
|
||||
# Keyword argument: motivation
|
||||
motivation=motivation,
|
||||
# Keyword argument: sophistication
|
||||
sophistication=sophistication,
|
||||
# Keyword argument: target_sectors
|
||||
target_sectors=target_sectors,
|
||||
# Keyword argument: offset
|
||||
offset=offset,
|
||||
# Keyword argument: limit
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/{actor_id}")
|
||||
# Define function get_threat_actor
|
||||
def get_threat_actor(
|
||||
# Entry: actor_id
|
||||
actor_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Get detailed info about a threat actor including techniques.
|
||||
|
||||
**Requires** authentication (any role).
|
||||
"""
|
||||
# Return get_actor_detail(db, actor_id)
|
||||
return get_actor_detail(db, actor_id)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/{actor_id}/coverage")
|
||||
# Define function get_threat_actor_coverage
|
||||
def get_threat_actor_coverage(
|
||||
# Entry: actor_id
|
||||
actor_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> dict:
|
||||
"""Calculate coverage percentage against a specific threat actor.
|
||||
|
||||
**Requires** authentication (any role).
|
||||
@@ -79,19 +124,26 @@ def get_threat_actor_coverage(
|
||||
Returns the percentage of the actor's techniques that have been
|
||||
validated or partially validated, along with a breakdown.
|
||||
"""
|
||||
# Return get_actor_coverage(db, actor_id)
|
||||
return get_actor_coverage(db, actor_id)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/{actor_id}/gaps")
|
||||
# Define function get_threat_actor_gaps
|
||||
def get_threat_actor_gaps(
|
||||
# Entry: actor_id
|
||||
actor_id: str,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> list:
|
||||
"""Identify techniques of this actor that are NOT fully validated.
|
||||
|
||||
**Requires** authentication (any role).
|
||||
|
||||
Returns list of gap techniques with available templates.
|
||||
"""
|
||||
# Return get_actor_gaps(db, actor_id)
|
||||
return get_actor_gaps(db, actor_id)
|
||||
|
||||
@@ -1,17 +1,30 @@
|
||||
"""User management router (admin only)."""
|
||||
|
||||
# Import uuid
|
||||
import uuid
|
||||
|
||||
# Import APIRouter, Depends, status from fastapi
|
||||
from fastapi import APIRouter, Depends, status
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import require_role from app.dependencies.auth
|
||||
from app.dependencies.auth import require_role
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.schemas.user import UserCreate, UserUpdate, UserOut, UserPreferencesUpdate
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
# Import from app.services.user_service
|
||||
from app.services.user_service import (
|
||||
create_user,
|
||||
get_user_or_raise,
|
||||
@@ -19,6 +32,7 @@ from app.services.user_service import (
|
||||
update_user,
|
||||
)
|
||||
|
||||
# Assign router = APIRouter(prefix="/users", tags=["users"])
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
@@ -69,11 +83,15 @@ def get_me(
|
||||
|
||||
|
||||
@router.get("", response_model=list[UserOut])
|
||||
# Define function list_users_route
|
||||
def list_users_route(
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return a list of all users. **Requires admin role.**"""
|
||||
) -> list[UserOut]:
|
||||
"""Return a list of all users. **Requires admin role.**."""
|
||||
# Return list_users(db)
|
||||
return list_users(db)
|
||||
|
||||
|
||||
@@ -83,31 +101,50 @@ def list_users_route(
|
||||
|
||||
|
||||
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
||||
# Define function create_user_route
|
||||
def create_user_route(
|
||||
# Entry: payload
|
||||
payload: UserCreate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Create a new user. **Requires admin role.**"""
|
||||
) -> UserOut:
|
||||
"""Create a new user. **Requires admin role.**."""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign user = create_user(
|
||||
user = create_user(
|
||||
db,
|
||||
# Keyword argument: username
|
||||
username=payload.username,
|
||||
# Keyword argument: email
|
||||
email=payload.email,
|
||||
# Keyword argument: password
|
||||
password=payload.password,
|
||||
# Keyword argument: role
|
||||
role=payload.role,
|
||||
)
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="create_user",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="user",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=user.id,
|
||||
# Keyword argument: details
|
||||
details={"username": user.username, "role": user.role},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(user)
|
||||
|
||||
# Return user
|
||||
return user
|
||||
|
||||
|
||||
@@ -117,12 +154,17 @@ def create_user_route(
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserOut)
|
||||
# Define function get_user
|
||||
def get_user(
|
||||
# Entry: user_id
|
||||
user_id: uuid.UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return a single user by ID. **Requires admin role.**"""
|
||||
) -> UserOut:
|
||||
"""Return a single user by ID. **Requires admin role.**."""
|
||||
# Return get_user_or_raise(db, user_id)
|
||||
return get_user_or_raise(db, user_id)
|
||||
|
||||
|
||||
@@ -132,25 +174,42 @@ def get_user(
|
||||
|
||||
|
||||
@router.patch("/{user_id}", response_model=UserOut)
|
||||
# Define function update_user_route
|
||||
def update_user_route(
|
||||
# Entry: user_id
|
||||
user_id: uuid.UUID,
|
||||
# Entry: payload
|
||||
payload: UserUpdate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: current_user
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Update one or more fields of an existing user. **Requires admin role.**"""
|
||||
) -> UserOut:
|
||||
"""Update one or more fields of an existing user. **Requires admin role.**."""
|
||||
# Assign update_data = payload.model_dump(exclude_unset=True)
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign user = update_user(db, user_id, **update_data)
|
||||
user = update_user(db, user_id, **update_data)
|
||||
# Call log_action()
|
||||
log_action(
|
||||
db,
|
||||
# Keyword argument: user_id
|
||||
user_id=current_user.id,
|
||||
# Keyword argument: action
|
||||
action="update_user",
|
||||
# Keyword argument: entity_type
|
||||
entity_type="user",
|
||||
# Keyword argument: entity_id
|
||||
entity_id=user.id,
|
||||
# Keyword argument: details
|
||||
details={"updated_fields": list(update_data.keys())},
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(user)
|
||||
|
||||
# Return user
|
||||
return user
|
||||
|
||||
@@ -1,19 +1,39 @@
|
||||
"""Worklog router — internal time-tracking records with integrity verification."""
|
||||
|
||||
# Import datetime from datetime
|
||||
from datetime import datetime
|
||||
|
||||
# Import Optional from typing
|
||||
from typing import Optional
|
||||
|
||||
# Import UUID from uuid
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
# Import APIRouter, Depends from fastapi
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
# Import BaseModel, Field from pydantic
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import get_db from app.database
|
||||
from app.database import get_db
|
||||
|
||||
# Import get_current_user, require_any_role from app.dependencies.auth
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
|
||||
# Import UnitOfWork from app.domain.unit_of_work
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
|
||||
# Import User from app.models.user
|
||||
from app.models.user import User
|
||||
|
||||
# Import worklog_service from app.services
|
||||
from app.services import worklog_service
|
||||
|
||||
# Assign router = APIRouter(prefix="/worklogs", tags=["worklogs"])
|
||||
router = APIRouter(prefix="/worklogs", tags=["worklogs"])
|
||||
|
||||
|
||||
@@ -21,30 +41,58 @@ router = APIRouter(prefix="/worklogs", tags=["worklogs"])
|
||||
|
||||
|
||||
class WorklogCreate(BaseModel):
|
||||
"""Payload for logging a work session against an entity."""
|
||||
|
||||
# Assign entity_type = Field(..., max_length=50)
|
||||
entity_type: str = Field(..., max_length=50)
|
||||
# entity_id: UUID
|
||||
entity_id: UUID
|
||||
# Assign activity_type = Field(..., max_length=100)
|
||||
activity_type: str = Field(..., max_length=100)
|
||||
# started_at: datetime
|
||||
started_at: datetime
|
||||
# Assign ended_at = None
|
||||
ended_at: Optional[datetime] = None
|
||||
# Assign duration_seconds = Field(..., gt=0)
|
||||
duration_seconds: int = Field(..., gt=0)
|
||||
# Assign description = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
# Define class WorklogOut
|
||||
class WorklogOut(BaseModel):
|
||||
"""Serialized worklog entry returned by the API."""
|
||||
|
||||
# id: UUID
|
||||
id: UUID
|
||||
# entity_type: str
|
||||
entity_type: str
|
||||
# entity_id: UUID
|
||||
entity_id: UUID
|
||||
# user_id: UUID
|
||||
user_id: UUID
|
||||
# activity_type: str
|
||||
activity_type: str
|
||||
# started_at: datetime
|
||||
started_at: datetime
|
||||
# Assign ended_at = None
|
||||
ended_at: Optional[datetime] = None
|
||||
# duration_seconds: int
|
||||
duration_seconds: int
|
||||
# Assign description = None
|
||||
description: Optional[str] = None
|
||||
# Assign tempo_synced = None
|
||||
tempo_synced: Optional[datetime] = None
|
||||
# Assign integrity_hash = None
|
||||
integrity_hash: Optional[str] = None
|
||||
# created_at: datetime
|
||||
created_at: datetime
|
||||
|
||||
# Define class Config
|
||||
class Config:
|
||||
"""ORM mode configuration for SQLAlchemy model mapping."""
|
||||
|
||||
# Assign from_attributes = True
|
||||
from_attributes = True
|
||||
|
||||
|
||||
@@ -52,65 +100,146 @@ class WorklogOut(BaseModel):
|
||||
|
||||
|
||||
@router.post("", response_model=WorklogOut, status_code=201)
|
||||
# Define function create
|
||||
def create(
|
||||
# Entry: body
|
||||
body: WorklogCreate,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: user
|
||||
user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Create a manually-logged worklog entry."""
|
||||
) -> WorklogOut:
|
||||
"""Create a manually-logged worklog entry.
|
||||
|
||||
Args:
|
||||
body (WorklogCreate): Worklog fields including entity, activity type, and duration.
|
||||
db (Session): SQLAlchemy database session.
|
||||
user (User): Authenticated team member creating the worklog.
|
||||
|
||||
Returns:
|
||||
WorklogOut: The newly created worklog with integrity hash and all fields.
|
||||
"""
|
||||
# Open context manager
|
||||
with UnitOfWork(db) as uow:
|
||||
# Assign wl = worklog_service.create_worklog(
|
||||
wl = worklog_service.create_worklog(
|
||||
db,
|
||||
# Keyword argument: entity_type
|
||||
entity_type=body.entity_type,
|
||||
# Keyword argument: entity_id
|
||||
entity_id=body.entity_id,
|
||||
# Keyword argument: user_id
|
||||
user_id=user.id,
|
||||
# Keyword argument: activity_type
|
||||
activity_type=body.activity_type,
|
||||
# Keyword argument: started_at
|
||||
started_at=body.started_at,
|
||||
# Keyword argument: ended_at
|
||||
ended_at=body.ended_at,
|
||||
# Keyword argument: duration_seconds
|
||||
duration_seconds=body.duration_seconds,
|
||||
# Keyword argument: description
|
||||
description=body.description,
|
||||
)
|
||||
# Call uow.commit()
|
||||
uow.commit()
|
||||
# Reload ORM object attributes from the database
|
||||
db.refresh(wl)
|
||||
# Return wl
|
||||
return wl
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("", response_model=list[WorklogOut])
|
||||
# Define function list_all
|
||||
def list_all(
|
||||
# Entry: entity_type
|
||||
entity_type: Optional[str] = None,
|
||||
# Entry: entity_id
|
||||
entity_id: Optional[UUID] = None,
|
||||
# Entry: user_id
|
||||
user_id: Optional[UUID] = None,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: _user
|
||||
_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List worklogs with optional filters."""
|
||||
) -> list[WorklogOut]:
|
||||
"""List worklogs with optional filters.
|
||||
|
||||
Args:
|
||||
entity_type (Optional[str]): Filter by entity type (e.g. ``test``, ``campaign``).
|
||||
entity_id (Optional[UUID]): Filter by the UUID of the associated entity.
|
||||
user_id (Optional[UUID]): Filter by the UUID of the worklog author.
|
||||
db (Session): SQLAlchemy database session.
|
||||
_user (User): Authenticated user making the request (unused, enforces auth).
|
||||
|
||||
Returns:
|
||||
list[WorklogOut]: Serialised list of worklog entries matching the filters.
|
||||
"""
|
||||
# Return worklog_service.list_worklogs(
|
||||
return worklog_service.list_worklogs(
|
||||
db,
|
||||
# Keyword argument: entity_type
|
||||
entity_type=entity_type,
|
||||
# Keyword argument: entity_id
|
||||
entity_id=entity_id,
|
||||
# Keyword argument: user_id
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/{worklog_id}", response_model=WorklogOut)
|
||||
# Define function get_one
|
||||
def get_one(
|
||||
# Entry: worklog_id
|
||||
worklog_id: UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: _user
|
||||
_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get a single worklog by ID."""
|
||||
) -> WorklogOut:
|
||||
"""Get a single worklog by ID.
|
||||
|
||||
Args:
|
||||
worklog_id (UUID): Primary key of the worklog to retrieve.
|
||||
db (Session): SQLAlchemy database session.
|
||||
_user (User): Authenticated user making the request (unused, enforces auth).
|
||||
|
||||
Returns:
|
||||
WorklogOut: Full worklog detail including integrity hash.
|
||||
"""
|
||||
# Return worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||
return worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||
|
||||
|
||||
# Apply the @router.get decorator
|
||||
@router.get("/{worklog_id}/verify")
|
||||
# Define function verify_integrity
|
||||
def verify_integrity(
|
||||
# Entry: worklog_id
|
||||
worklog_id: UUID,
|
||||
# Entry: db
|
||||
db: Session = Depends(get_db),
|
||||
# Entry: _user
|
||||
_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Check whether a worklog's integrity hash is still valid."""
|
||||
) -> dict:
|
||||
"""Check whether a worklog's integrity hash is still valid.
|
||||
|
||||
Args:
|
||||
worklog_id (UUID): Primary key of the worklog to verify.
|
||||
db (Session): SQLAlchemy database session.
|
||||
_user (User): Authenticated user making the request (unused, enforces auth).
|
||||
|
||||
Returns:
|
||||
dict: Contains ``worklog_id`` (str) and ``integrity_valid`` (bool).
|
||||
"""
|
||||
# Assign wl = worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||
wl = worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"worklog_id": str(wl.id),
|
||||
# Literal argument value
|
||||
"integrity_valid": worklog_service.verify_worklog_integrity(wl),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user