feat: add Campaign/Compliance domain entities and extract users/audit/data_sources to services (LP-2 through LP-6)
This commit is contained in:
@@ -4,14 +4,17 @@ from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import require_role
|
||||
from app.models.audit import AuditLog
|
||||
from app.models.user import User
|
||||
from app.schemas.audit import AuditLogOut, AuditLogPage
|
||||
from app.services.audit_query_service import (
|
||||
list_distinct_actions,
|
||||
list_distinct_entity_types,
|
||||
list_logs,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
||||
|
||||
@@ -29,56 +32,25 @@ def list_audit_logs(
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return paginated audit logs with optional filters.
|
||||
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
query = db.query(AuditLog).options(joinedload(AuditLog.user))
|
||||
|
||||
# Apply filters
|
||||
if user_id:
|
||||
query = query.filter(AuditLog.user_id == user_id)
|
||||
if action:
|
||||
query = query.filter(AuditLog.action == action)
|
||||
if entity_type:
|
||||
query = query.filter(AuditLog.entity_type == entity_type)
|
||||
if start_date:
|
||||
query = query.filter(AuditLog.timestamp >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(AuditLog.timestamp <= end_date)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Get paginated results
|
||||
logs = (
|
||||
query
|
||||
.order_by(AuditLog.timestamp.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Convert to response format with username
|
||||
items = []
|
||||
for log in logs:
|
||||
item = AuditLogOut(
|
||||
id=log.id,
|
||||
user_id=log.user_id,
|
||||
username=log.user.username if log.user else None,
|
||||
action=log.action,
|
||||
entity_type=log.entity_type,
|
||||
entity_id=log.entity_id,
|
||||
timestamp=log.timestamp,
|
||||
details=log.details,
|
||||
)
|
||||
items.append(item)
|
||||
|
||||
return AuditLogPage(
|
||||
items=items,
|
||||
total=total,
|
||||
result = list_logs(
|
||||
db,
|
||||
user_id=user_id,
|
||||
action=action,
|
||||
entity_type=entity_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
return AuditLogPage(
|
||||
items=[AuditLogOut(**item) for item in result["items"]],
|
||||
total=result["total"],
|
||||
offset=result["offset"],
|
||||
limit=result["limit"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/actions", response_model=list[str])
|
||||
@@ -87,16 +59,10 @@ def list_actions(
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return a list of distinct action types in the audit log.
|
||||
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
actions = (
|
||||
db.query(AuditLog.action)
|
||||
.distinct()
|
||||
.order_by(AuditLog.action)
|
||||
.all()
|
||||
)
|
||||
return [a[0] for a in actions]
|
||||
return list_distinct_actions(db)
|
||||
|
||||
|
||||
@router.get("/entity-types", response_model=list[str])
|
||||
@@ -105,14 +71,7 @@ def list_entity_types(
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return a list of distinct entity types in the audit log.
|
||||
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
types = (
|
||||
db.query(AuditLog.entity_type)
|
||||
.filter(AuditLog.entity_type.isnot(None))
|
||||
.distinct()
|
||||
.order_by(AuditLog.entity_type)
|
||||
.all()
|
||||
)
|
||||
return [t[0] for t in types]
|
||||
return list_distinct_entity_types(db)
|
||||
|
||||
@@ -5,19 +5,23 @@ Provides a centralized panel for managing all external data sources
|
||||
including sync triggers, enable/disable toggles, and statistics.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import require_role
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.models.data_source import DataSource
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.data_source_service import (
|
||||
get_source_stats,
|
||||
list_sources,
|
||||
sync_all_sources,
|
||||
sync_source,
|
||||
update_source,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -30,41 +34,10 @@ class DataSourceUpdate(BaseModel):
|
||||
sync_frequency: Optional[str] = None
|
||||
config: Optional[dict] = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync dispatcher — maps source name → import function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_sync_handler(source_name: str):
|
||||
"""Lazily import and return the sync function for *source_name*.
|
||||
|
||||
We import lazily to avoid circular imports and to only load the
|
||||
modules that are actually needed.
|
||||
"""
|
||||
handlers = {
|
||||
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
|
||||
"sigma": ("app.services.sigma_import_service", "sync"),
|
||||
"lolbas": ("app.services.lolbas_import_service", "sync"),
|
||||
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
|
||||
"caldera": ("app.services.caldera_import_service", "sync"),
|
||||
"elastic_rules": ("app.services.elastic_import_service", "sync"),
|
||||
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
|
||||
"d3fend": ("app.services.d3fend_import_service", "sync"),
|
||||
}
|
||||
|
||||
if source_name not in handlers:
|
||||
return None
|
||||
|
||||
module_path, func_name = handlers[source_name]
|
||||
import importlib
|
||||
mod = importlib.import_module(module_path)
|
||||
return getattr(mod, func_name)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -79,25 +52,7 @@ def list_data_sources(
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
sources = db.query(DataSource).order_by(DataSource.name).all()
|
||||
return [
|
||||
{
|
||||
"id": str(s.id),
|
||||
"name": s.name,
|
||||
"display_name": s.display_name,
|
||||
"type": s.type,
|
||||
"url": s.url,
|
||||
"description": s.description,
|
||||
"is_enabled": s.is_enabled,
|
||||
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
|
||||
"last_sync_status": s.last_sync_status,
|
||||
"last_sync_stats": s.last_sync_stats,
|
||||
"sync_frequency": s.sync_frequency,
|
||||
"config": s.config,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
for s in sources
|
||||
]
|
||||
return list_sources(db)
|
||||
|
||||
|
||||
@router.patch("/{source_id}")
|
||||
@@ -111,31 +66,21 @@ def update_data_source(
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
|
||||
if "is_enabled" in update_data:
|
||||
ds.is_enabled = update_data["is_enabled"]
|
||||
if "sync_frequency" in update_data:
|
||||
ds.sync_frequency = update_data["sync_frequency"]
|
||||
if "config" in update_data:
|
||||
ds.config = update_data["config"]
|
||||
|
||||
db.commit()
|
||||
update_source(db, source_id, **update_data)
|
||||
with UnitOfWork(db) as uow:
|
||||
uow.commit()
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="update_data_source",
|
||||
entity_type="data_source",
|
||||
entity_id=str(ds.id),
|
||||
entity_id=source_id,
|
||||
details={"updates": update_data},
|
||||
)
|
||||
|
||||
return {"message": "Data source updated", "id": str(ds.id)}
|
||||
return {"message": "Data source updated", "id": source_id}
|
||||
|
||||
|
||||
@router.post("/{source_id}/sync")
|
||||
@@ -148,46 +93,7 @@ def sync_data_source(
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
handler = _get_sync_handler(ds.name)
|
||||
if handler is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"No sync handler available for '{ds.name}'",
|
||||
)
|
||||
|
||||
# Mark as in_progress
|
||||
ds.last_sync_status = "in_progress"
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
summary = handler(db)
|
||||
except Exception as exc:
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||
ds.last_sync_status = "error"
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_stats = {"error": str(exc)}
|
||||
db.commit()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Sync failed for '{ds.display_name}'. Check server logs for details.",
|
||||
)
|
||||
|
||||
# Update DS record (the handler may already have done this,
|
||||
# but we ensure it here as well)
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_status = "success"
|
||||
ds.last_sync_stats = summary
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"message": f"Sync complete for {ds.display_name}",
|
||||
"source": ds.name,
|
||||
"stats": summary,
|
||||
}
|
||||
return sync_source(db, source_id)
|
||||
|
||||
|
||||
@router.post("/sync-all")
|
||||
@@ -199,49 +105,7 @@ def sync_all_data_sources(
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
enabled_sources = (
|
||||
db.query(DataSource)
|
||||
.filter(DataSource.is_enabled == True)
|
||||
.order_by(DataSource.name)
|
||||
.all()
|
||||
)
|
||||
|
||||
results = []
|
||||
for ds in enabled_sources:
|
||||
handler = _get_sync_handler(ds.name)
|
||||
if handler is None:
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
"status": "skipped",
|
||||
"detail": "No sync handler available",
|
||||
})
|
||||
continue
|
||||
|
||||
ds.last_sync_status = "in_progress"
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
summary = handler(db)
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_status = "success"
|
||||
ds.last_sync_stats = summary
|
||||
db.commit()
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
"status": "success",
|
||||
"stats": summary,
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||
ds.last_sync_status = "error"
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_stats = {"error": str(exc)}
|
||||
db.commit()
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
"status": "error",
|
||||
"detail": "Sync failed. Check server logs for details.",
|
||||
})
|
||||
results = sync_all_sources(db)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
@@ -265,39 +129,4 @@ def get_data_source_stats(
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
# Count items from this source
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.detection_rule import DetectionRule
|
||||
|
||||
template_count = 0
|
||||
rule_count = 0
|
||||
|
||||
if ds.type == "attack_procedure":
|
||||
template_count = (
|
||||
db.query(TestTemplate)
|
||||
.filter(TestTemplate.source == ds.name)
|
||||
.count()
|
||||
)
|
||||
elif ds.type == "detection_rule":
|
||||
rule_count = (
|
||||
db.query(DetectionRule)
|
||||
.filter(DetectionRule.source == ds.name)
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"id": str(ds.id),
|
||||
"name": ds.name,
|
||||
"display_name": ds.display_name,
|
||||
"type": ds.type,
|
||||
"is_enabled": ds.is_enabled,
|
||||
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
|
||||
"last_sync_status": ds.last_sync_status,
|
||||
"last_sync_stats": ds.last_sync_stats,
|
||||
"total_templates": template_count,
|
||||
"total_rules": rule_count,
|
||||
}
|
||||
return get_source_stats(db, source_id)
|
||||
|
||||
@@ -2,20 +2,24 @@
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import require_role
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.schemas.user import UserCreate, UserUpdate, UserOut
|
||||
from app.auth import hash_password
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.user_service import (
|
||||
create_user,
|
||||
get_user_or_raise,
|
||||
list_users,
|
||||
update_user,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /users — list all users
|
||||
@@ -23,12 +27,12 @@ VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewe
|
||||
|
||||
|
||||
@router.get("", response_model=list[UserOut])
|
||||
def list_users(
|
||||
def list_users_route(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return a list of all users. **Requires admin role.**"""
|
||||
return db.query(User).order_by(User.username).all()
|
||||
return list_users(db)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -37,38 +41,23 @@ def list_users(
|
||||
|
||||
|
||||
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
||||
def create_user(
|
||||
def create_user_route(
|
||||
payload: UserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Create a new user. **Requires admin role.**"""
|
||||
|
||||
# Check if username already exists
|
||||
existing = db.query(User).filter(User.username == payload.username).first()
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Username '{payload.username}' already exists",
|
||||
)
|
||||
|
||||
# Validate role
|
||||
if payload.role not in VALID_ROLES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid role '{payload.role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
|
||||
)
|
||||
|
||||
user = User(
|
||||
user = create_user(
|
||||
db,
|
||||
username=payload.username,
|
||||
email=payload.email,
|
||||
hashed_password=hash_password(payload.password),
|
||||
password=payload.password,
|
||||
role=payload.role,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
with UnitOfWork(db) as uow:
|
||||
uow.commit()
|
||||
db.refresh(user)
|
||||
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
@@ -77,7 +66,7 @@ def create_user(
|
||||
entity_id=user.id,
|
||||
details={"username": user.username, "role": user.role},
|
||||
)
|
||||
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@@ -93,13 +82,7 @@ def get_user(
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return a single user by ID. **Requires admin role.**"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
return user
|
||||
return get_user_or_raise(db, user_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -108,46 +91,26 @@ def get_user(
|
||||
|
||||
|
||||
@router.patch("/{user_id}", response_model=UserOut)
|
||||
def update_user(
|
||||
def update_user_route(
|
||||
user_id: uuid.UUID,
|
||||
payload: UserUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Update one or more fields of an existing user. **Requires admin role.**"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
|
||||
# Validate role if being updated
|
||||
if "role" in update_data and update_data["role"] not in VALID_ROLES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
|
||||
)
|
||||
|
||||
# Hash password if being updated
|
||||
if "password" in update_data:
|
||||
update_data["hashed_password"] = hash_password(update_data.pop("password"))
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(user, field, value)
|
||||
|
||||
db.commit()
|
||||
user = update_user(db, user_id, **update_data)
|
||||
with UnitOfWork(db) as uow:
|
||||
uow.commit()
|
||||
db.refresh(user)
|
||||
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="update_user",
|
||||
entity_type="user",
|
||||
entity_id=user.id,
|
||||
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
|
||||
details={"updated_fields": list(update_data.keys())},
|
||||
)
|
||||
|
||||
|
||||
return user
|
||||
|
||||
Reference in New Issue
Block a user