feat: add Campaign/Compliance domain entities and extract users/audit/data_sources to services (LP-2 through LP-6)

This commit is contained in:
2026-02-20 13:28:14 +01:00
parent 44621364be
commit c0c6cda11d
11 changed files with 939 additions and 319 deletions

View File

@@ -1,3 +1,15 @@
from app.domain.entities.campaign import CampaignEntity
from app.domain.entities.compliance import (
ComplianceControlEntity,
ComplianceFrameworkEntity,
ControlCoverageStatus,
)
from app.domain.entities.technique import TechniqueEntity
__all__ = ["TechniqueEntity"]
__all__ = [
"CampaignEntity",
"ComplianceControlEntity",
"ComplianceFrameworkEntity",
"ControlCoverageStatus",
"TechniqueEntity",
]

View File

@@ -0,0 +1,103 @@
"""Campaign domain entity with lifecycle validation.
Pure domain logic — no framework imports.
"""
from __future__ import annotations
import enum
import uuid
from dataclasses import dataclass, field
from typing import Any
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
class CampaignStatus(str, enum.Enum):
draft = "draft"
active = "active"
completed = "completed"
archived = "archived"
class CampaignType(str, enum.Enum):
custom = "custom"
apt_emulation = "apt_emulation"
kill_chain = "kill_chain"
compliance = "compliance"
VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
CampaignStatus.draft: [CampaignStatus.active],
CampaignStatus.active: [CampaignStatus.completed],
CampaignStatus.completed: [CampaignStatus.archived],
CampaignStatus.archived: [],
}
@dataclass
class CampaignEntity:
name: str
type: CampaignType = CampaignType.custom
status: CampaignStatus = CampaignStatus.draft
id: uuid.UUID | None = None
description: str | None = None
threat_actor_id: uuid.UUID | None = None
created_by: uuid.UUID | None = None
target_platform: str | None = None
tags: list[str] = field(default_factory=list)
test_count: int = 0
def can_transition_to(self, target: CampaignStatus) -> bool:
return target in VALID_TRANSITIONS.get(self.status, [])
def activate(self) -> None:
if not self.can_transition_to(CampaignStatus.active):
raise InvalidStateTransition(
self.status.value, CampaignStatus.active.value,
[s.value for s in VALID_TRANSITIONS[self.status]],
)
if self.test_count == 0:
raise BusinessRuleViolation(
"Campaign must have at least one test to activate"
)
self.status = CampaignStatus.active
def complete(self) -> None:
if not self.can_transition_to(CampaignStatus.completed):
raise InvalidStateTransition(
self.status.value, CampaignStatus.completed.value,
[s.value for s in VALID_TRANSITIONS[self.status]],
)
self.status = CampaignStatus.completed
def archive(self) -> None:
if not self.can_transition_to(CampaignStatus.archived):
raise InvalidStateTransition(
self.status.value, CampaignStatus.archived.value,
[s.value for s in VALID_TRANSITIONS[self.status]],
)
self.status = CampaignStatus.archived
def ensure_modifiable(self) -> None:
if self.status not in (CampaignStatus.draft, CampaignStatus.active):
raise BusinessRuleViolation(
f"Cannot modify campaign in '{self.status.value}' state"
)
@classmethod
def from_orm(cls, orm: Any) -> CampaignEntity:
"""Build a CampaignEntity from a SQLAlchemy Campaign model."""
test_count = len(getattr(orm, "campaign_tests", None) or [])
return cls(
id=orm.id,
name=orm.name,
type=CampaignType(orm.type) if orm.type else CampaignType.custom,
status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft,
description=orm.description,
threat_actor_id=orm.threat_actor_id,
created_by=orm.created_by,
target_platform=orm.target_platform,
tags=orm.tags or [],
test_count=test_count,
)

View File

@@ -0,0 +1,71 @@
"""Compliance domain entities with coverage calculation logic.
Pure domain logic — no framework imports.
"""
from __future__ import annotations
import enum
import uuid
from dataclasses import dataclass, field
class ControlCoverageStatus(str, enum.Enum):
covered = "covered"
partially_covered = "partially_covered"
not_covered = "not_covered"
@dataclass
class ComplianceControlEntity:
control_id: str
title: str
id: uuid.UUID | None = None
description: str | None = None
category: str | None = None
technique_statuses: list[str] = field(default_factory=list)
@property
def coverage_status(self) -> ControlCoverageStatus:
if not self.technique_statuses:
return ControlCoverageStatus.not_covered
covered_statuses = {"validated", "partial"}
covered = [s for s in self.technique_statuses if s in covered_statuses]
if len(covered) == len(self.technique_statuses):
return ControlCoverageStatus.covered
elif len(covered) > 0:
return ControlCoverageStatus.partially_covered
return ControlCoverageStatus.not_covered
@dataclass
class ComplianceFrameworkEntity:
name: str
id: uuid.UUID | None = None
version: str | None = None
description: str | None = None
is_active: bool = True
controls: list[ComplianceControlEntity] = field(default_factory=list)
@property
def total_controls(self) -> int:
return len(self.controls)
@property
def covered_controls(self) -> int:
return sum(
1 for c in self.controls
if c.coverage_status == ControlCoverageStatus.covered
)
@property
def coverage_pct(self) -> float:
if self.total_controls == 0:
return 0.0
return round(self.covered_controls / self.total_controls * 100, 1)
def get_gap_controls(self) -> list[ComplianceControlEntity]:
return [
c for c in self.controls
if c.coverage_status != ControlCoverageStatus.covered
]

View File

@@ -4,14 +4,17 @@ from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy import func
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import require_role
from app.models.audit import AuditLog
from app.models.user import User
from app.schemas.audit import AuditLogOut, AuditLogPage
from app.services.audit_query_service import (
list_distinct_actions,
list_distinct_entity_types,
list_logs,
)
router = APIRouter(prefix="/audit-logs", tags=["audit"])
@@ -29,56 +32,25 @@ def list_audit_logs(
current_user: User = Depends(require_role("admin")),
):
"""Return paginated audit logs with optional filters.
**Requires admin role.**
"""
query = db.query(AuditLog).options(joinedload(AuditLog.user))
# Apply filters
if user_id:
query = query.filter(AuditLog.user_id == user_id)
if action:
query = query.filter(AuditLog.action == action)
if entity_type:
query = query.filter(AuditLog.entity_type == entity_type)
if start_date:
query = query.filter(AuditLog.timestamp >= start_date)
if end_date:
query = query.filter(AuditLog.timestamp <= end_date)
# Get total count
total = query.count()
# Get paginated results
logs = (
query
.order_by(AuditLog.timestamp.desc())
.offset(offset)
.limit(limit)
.all()
)
# Convert to response format with username
items = []
for log in logs:
item = AuditLogOut(
id=log.id,
user_id=log.user_id,
username=log.user.username if log.user else None,
action=log.action,
entity_type=log.entity_type,
entity_id=log.entity_id,
timestamp=log.timestamp,
details=log.details,
)
items.append(item)
return AuditLogPage(
items=items,
total=total,
result = list_logs(
db,
user_id=user_id,
action=action,
entity_type=entity_type,
start_date=start_date,
end_date=end_date,
offset=offset,
limit=limit,
)
return AuditLogPage(
items=[AuditLogOut(**item) for item in result["items"]],
total=result["total"],
offset=result["offset"],
limit=result["limit"],
)
@router.get("/actions", response_model=list[str])
@@ -87,16 +59,10 @@ def list_actions(
current_user: User = Depends(require_role("admin")),
):
"""Return a list of distinct action types in the audit log.
**Requires admin role.**
"""
actions = (
db.query(AuditLog.action)
.distinct()
.order_by(AuditLog.action)
.all()
)
return [a[0] for a in actions]
return list_distinct_actions(db)
@router.get("/entity-types", response_model=list[str])
@@ -105,14 +71,7 @@ def list_entity_types(
current_user: User = Depends(require_role("admin")),
):
"""Return a list of distinct entity types in the audit log.
**Requires admin role.**
"""
types = (
db.query(AuditLog.entity_type)
.filter(AuditLog.entity_type.isnot(None))
.distinct()
.order_by(AuditLog.entity_type)
.all()
)
return [t[0] for t in types]
return list_distinct_entity_types(db)

View File

@@ -5,19 +5,23 @@ Provides a centralized panel for managing all external data sources
including sync triggers, enable/disable toggles, and statistics.
"""
import logging
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session
from typing import Optional
from app.database import get_db
from app.dependencies.auth import require_role
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.models.data_source import DataSource
from app.services.audit_service import log_action
from app.services.data_source_service import (
get_source_stats,
list_sources,
sync_all_sources,
sync_source,
update_source,
)
# ---------------------------------------------------------------------------
@@ -30,41 +34,10 @@ class DataSourceUpdate(BaseModel):
sync_frequency: Optional[str] = None
config: Optional[dict] = None
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
# ---------------------------------------------------------------------------
# Sync dispatcher — maps source name → import function
# ---------------------------------------------------------------------------
def _get_sync_handler(source_name: str):
"""Lazily import and return the sync function for *source_name*.
We import lazily to avoid circular imports and to only load the
modules that are actually needed.
"""
handlers = {
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
"sigma": ("app.services.sigma_import_service", "sync"),
"lolbas": ("app.services.lolbas_import_service", "sync"),
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
"caldera": ("app.services.caldera_import_service", "sync"),
"elastic_rules": ("app.services.elastic_import_service", "sync"),
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
"d3fend": ("app.services.d3fend_import_service", "sync"),
}
if source_name not in handlers:
return None
module_path, func_name = handlers[source_name]
import importlib
mod = importlib.import_module(module_path)
return getattr(mod, func_name)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@@ -79,25 +52,7 @@ def list_data_sources(
**Requires** the ``admin`` role.
"""
sources = db.query(DataSource).order_by(DataSource.name).all()
return [
{
"id": str(s.id),
"name": s.name,
"display_name": s.display_name,
"type": s.type,
"url": s.url,
"description": s.description,
"is_enabled": s.is_enabled,
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
"last_sync_status": s.last_sync_status,
"last_sync_stats": s.last_sync_stats,
"sync_frequency": s.sync_frequency,
"config": s.config,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in sources
]
return list_sources(db)
@router.patch("/{source_id}")
@@ -111,31 +66,21 @@ def update_data_source(
**Requires** the ``admin`` role.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise HTTPException(status_code=404, detail="Data source not found")
update_data = body.model_dump(exclude_unset=True)
if "is_enabled" in update_data:
ds.is_enabled = update_data["is_enabled"]
if "sync_frequency" in update_data:
ds.sync_frequency = update_data["sync_frequency"]
if "config" in update_data:
ds.config = update_data["config"]
db.commit()
update_source(db, source_id, **update_data)
with UnitOfWork(db) as uow:
uow.commit()
log_action(
db,
user_id=current_user.id,
action="update_data_source",
entity_type="data_source",
entity_id=str(ds.id),
entity_id=source_id,
details={"updates": update_data},
)
return {"message": "Data source updated", "id": str(ds.id)}
return {"message": "Data source updated", "id": source_id}
@router.post("/{source_id}/sync")
@@ -148,46 +93,7 @@ def sync_data_source(
**Requires** the ``admin`` role.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise HTTPException(status_code=404, detail="Data source not found")
handler = _get_sync_handler(ds.name)
if handler is None:
raise HTTPException(
status_code=400,
detail=f"No sync handler available for '{ds.name}'",
)
# Mark as in_progress
ds.last_sync_status = "in_progress"
db.commit()
try:
summary = handler(db)
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
raise HTTPException(
status_code=500,
detail=f"Sync failed for '{ds.display_name}'. Check server logs for details.",
)
# Update DS record (the handler may already have done this,
# but we ensure it here as well)
ds.last_sync_at = datetime.utcnow()
ds.last_sync_status = "success"
ds.last_sync_stats = summary
db.commit()
return {
"message": f"Sync complete for {ds.display_name}",
"source": ds.name,
"stats": summary,
}
return sync_source(db, source_id)
@router.post("/sync-all")
@@ -199,49 +105,7 @@ def sync_all_data_sources(
**Requires** the ``admin`` role.
"""
enabled_sources = (
db.query(DataSource)
.filter(DataSource.is_enabled == True)
.order_by(DataSource.name)
.all()
)
results = []
for ds in enabled_sources:
handler = _get_sync_handler(ds.name)
if handler is None:
results.append({
"source": ds.name,
"status": "skipped",
"detail": "No sync handler available",
})
continue
ds.last_sync_status = "in_progress"
db.commit()
try:
summary = handler(db)
ds.last_sync_at = datetime.utcnow()
ds.last_sync_status = "success"
ds.last_sync_stats = summary
db.commit()
results.append({
"source": ds.name,
"status": "success",
"stats": summary,
})
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
results.append({
"source": ds.name,
"status": "error",
"detail": "Sync failed. Check server logs for details.",
})
results = sync_all_sources(db)
log_action(
db,
@@ -265,39 +129,4 @@ def get_data_source_stats(
**Requires** the ``admin`` role.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise HTTPException(status_code=404, detail="Data source not found")
# Count items from this source
from app.models.test_template import TestTemplate
from app.models.detection_rule import DetectionRule
template_count = 0
rule_count = 0
if ds.type == "attack_procedure":
template_count = (
db.query(TestTemplate)
.filter(TestTemplate.source == ds.name)
.count()
)
elif ds.type == "detection_rule":
rule_count = (
db.query(DetectionRule)
.filter(DetectionRule.source == ds.name)
.count()
)
return {
"id": str(ds.id),
"name": ds.name,
"display_name": ds.display_name,
"type": ds.type,
"is_enabled": ds.is_enabled,
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
"last_sync_status": ds.last_sync_status,
"last_sync_stats": ds.last_sync_stats,
"total_templates": template_count,
"total_rules": rule_count,
}
return get_source_stats(db, source_id)

View File

@@ -2,20 +2,24 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import require_role
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate, UserOut
from app.auth import hash_password
from app.services.audit_service import log_action
from app.services.user_service import (
create_user,
get_user_or_raise,
list_users,
update_user,
)
router = APIRouter(prefix="/users", tags=["users"])
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
# ---------------------------------------------------------------------------
# GET /users — list all users
@@ -23,12 +27,12 @@ VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewe
@router.get("", response_model=list[UserOut])
def list_users(
def list_users_route(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Return a list of all users. **Requires admin role.**"""
return db.query(User).order_by(User.username).all()
return list_users(db)
# ---------------------------------------------------------------------------
@@ -37,38 +41,23 @@ def list_users(
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
def create_user(
def create_user_route(
payload: UserCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Create a new user. **Requires admin role.**"""
# Check if username already exists
existing = db.query(User).filter(User.username == payload.username).first()
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Username '{payload.username}' already exists",
)
# Validate role
if payload.role not in VALID_ROLES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid role '{payload.role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
)
user = User(
user = create_user(
db,
username=payload.username,
email=payload.email,
hashed_password=hash_password(payload.password),
password=payload.password,
role=payload.role,
)
db.add(user)
db.commit()
with UnitOfWork(db) as uow:
uow.commit()
db.refresh(user)
log_action(
db,
user_id=current_user.id,
@@ -77,7 +66,7 @@ def create_user(
entity_id=user.id,
details={"username": user.username, "role": user.role},
)
return user
@@ -93,13 +82,7 @@ def get_user(
current_user: User = Depends(require_role("admin")),
):
"""Return a single user by ID. **Requires admin role.**"""
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
return user
return get_user_or_raise(db, user_id)
# ---------------------------------------------------------------------------
@@ -108,46 +91,26 @@ def get_user(
@router.patch("/{user_id}", response_model=UserOut)
def update_user(
def update_user_route(
user_id: uuid.UUID,
payload: UserUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Update one or more fields of an existing user. **Requires admin role.**"""
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
update_data = payload.model_dump(exclude_unset=True)
# Validate role if being updated
if "role" in update_data and update_data["role"] not in VALID_ROLES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
)
# Hash password if being updated
if "password" in update_data:
update_data["hashed_password"] = hash_password(update_data.pop("password"))
for field, value in update_data.items():
setattr(user, field, value)
db.commit()
user = update_user(db, user_id, **update_data)
with UnitOfWork(db) as uow:
uow.commit()
db.refresh(user)
log_action(
db,
user_id=current_user.id,
action="update_user",
entity_type="user",
entity_id=user.id,
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
details={"updated_fields": list(update_data.keys())},
)
return user

View File

@@ -0,0 +1,93 @@
"""Audit log query service — framework-agnostic query logic for audit logs.
Provides paginated logs and distinct action/entity-type lists.
No FastAPI imports.
"""
from __future__ import annotations
from datetime import datetime
from sqlalchemy.orm import Session, joinedload
from app.models.audit import AuditLog
def list_logs(
db: Session,
*,
user_id: str | None = None,
action: str | None = None,
entity_type: str | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
offset: int = 0,
limit: int = 50,
) -> dict:
"""Return paginated audit logs with optional filters.
Returns a dict with keys: items, total, offset, limit.
Each item is a dict with: id, user_id, username, action, entity_type,
entity_id, timestamp, details.
"""
query = db.query(AuditLog).options(joinedload(AuditLog.user))
if user_id:
query = query.filter(AuditLog.user_id == user_id)
if action:
query = query.filter(AuditLog.action == action)
if entity_type:
query = query.filter(AuditLog.entity_type == entity_type)
if start_date:
query = query.filter(AuditLog.timestamp >= start_date)
if end_date:
query = query.filter(AuditLog.timestamp <= end_date)
total = query.count()
logs = (
query
.order_by(AuditLog.timestamp.desc())
.offset(offset)
.limit(limit)
.all()
)
items = [
{
"id": log.id,
"user_id": log.user_id,
"username": log.user.username if log.user else None,
"action": log.action,
"entity_type": log.entity_type,
"entity_id": log.entity_id,
"timestamp": log.timestamp,
"details": log.details,
}
for log in logs
]
return {"items": items, "total": total, "offset": offset, "limit": limit}
def list_distinct_actions(db: Session) -> list[str]:
"""Return a list of distinct action types in the audit log."""
actions = (
db.query(AuditLog.action)
.distinct()
.order_by(AuditLog.action)
.all()
)
return [a[0] for a in actions]
def list_distinct_entity_types(db: Session) -> list[str]:
"""Return a list of distinct entity types in the audit log."""
types = (
db.query(AuditLog.entity_type)
.filter(AuditLog.entity_type.isnot(None))
.distinct()
.order_by(AuditLog.entity_type)
.all()
)
return [t[0] for t in types]

View File

@@ -0,0 +1,222 @@
"""Data source management service — framework-agnostic query and sync logic.
Provides list, update, sync, and stats. Sync operations commit internally
since they are long-running and self-contained.
"""
from __future__ import annotations
import importlib
import logging
from datetime import datetime
from sqlalchemy.orm import Session
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
from app.models.data_source import DataSource
logger = logging.getLogger(__name__)
def _get_sync_handler(source_name: str):
"""Lazily import and return the sync function for *source_name*.
We import lazily to avoid circular imports and to only load the
modules that are actually needed.
"""
handlers = {
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
"sigma": ("app.services.sigma_import_service", "sync"),
"lolbas": ("app.services.lolbas_import_service", "sync"),
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
"caldera": ("app.services.caldera_import_service", "sync"),
"elastic_rules": ("app.services.elastic_import_service", "sync"),
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
"d3fend": ("app.services.d3fend_import_service", "sync"),
}
if source_name not in handlers:
return None
module_path, func_name = handlers[source_name]
mod = importlib.import_module(module_path)
return getattr(mod, func_name)
def list_sources(db: Session) -> list[dict]:
"""Return all registered data sources as a list of dicts."""
sources = db.query(DataSource).order_by(DataSource.name).all()
return [
{
"id": str(s.id),
"name": s.name,
"display_name": s.display_name,
"type": s.type,
"url": s.url,
"description": s.description,
"is_enabled": s.is_enabled,
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
"last_sync_status": s.last_sync_status,
"last_sync_stats": s.last_sync_stats,
"sync_frequency": s.sync_frequency,
"config": s.config,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in sources
]
def update_source(db: Session, source_id: str, **fields: object) -> None:
"""Update a data source's fields (is_enabled, sync_frequency, config).
Raises EntityNotFoundError if source does not exist.
Does not commit; the router handles that.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise EntityNotFoundError("Data source", source_id)
if "is_enabled" in fields:
ds.is_enabled = fields["is_enabled"]
if "sync_frequency" in fields:
ds.sync_frequency = fields["sync_frequency"]
if "config" in fields:
ds.config = fields["config"]
def sync_source(db: Session, source_id: str) -> dict:
"""Trigger sync for a specific data source.
Raises EntityNotFoundError if source does not exist.
Raises BusinessRuleViolation if no sync handler is available.
Commits internally (long-running, self-contained operation).
Returns dict with message, source, stats.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise EntityNotFoundError("Data source", source_id)
handler = _get_sync_handler(ds.name)
if handler is None:
raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'")
ds.last_sync_status = "in_progress"
db.commit()
try:
summary = handler(db)
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
raise BusinessRuleViolation(
f"Sync failed for '{ds.display_name}'. Check server logs for details."
)
ds.last_sync_at = datetime.utcnow()
ds.last_sync_status = "success"
ds.last_sync_stats = summary
db.commit()
return {
"message": f"Sync complete for {ds.display_name}",
"source": ds.name,
"stats": summary,
}
def sync_all_sources(db: Session) -> list[dict]:
"""Trigger sync for all enabled data sources (sequentially).
Commits internally (long-running, self-contained operation).
Returns list of result dicts with source, status, stats/detail.
"""
enabled_sources = (
db.query(DataSource)
.filter(DataSource.is_enabled == True)
.order_by(DataSource.name)
.all()
)
results = []
for ds in enabled_sources:
handler = _get_sync_handler(ds.name)
if handler is None:
results.append({
"source": ds.name,
"status": "skipped",
"detail": "No sync handler available",
})
continue
ds.last_sync_status = "in_progress"
db.commit()
try:
summary = handler(db)
ds.last_sync_at = datetime.utcnow()
ds.last_sync_status = "success"
ds.last_sync_stats = summary
db.commit()
results.append({
"source": ds.name,
"status": "success",
"stats": summary,
})
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
results.append({
"source": ds.name,
"status": "error",
"detail": "Sync failed. Check server logs for details.",
})
return results
def get_source_stats(db: Session, source_id: str) -> dict:
"""Return detailed statistics for a data source.
Raises EntityNotFoundError if source does not exist.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise EntityNotFoundError("Data source", source_id)
from app.models.test_template import TestTemplate
from app.models.detection_rule import DetectionRule
template_count = 0
rule_count = 0
if ds.type == "attack_procedure":
template_count = (
db.query(TestTemplate)
.filter(TestTemplate.source == ds.name)
.count()
)
elif ds.type == "detection_rule":
rule_count = (
db.query(DetectionRule)
.filter(DetectionRule.source == ds.name)
.count()
)
return {
"id": str(ds.id),
"name": ds.name,
"display_name": ds.display_name,
"type": ds.type,
"is_enabled": ds.is_enabled,
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
"last_sync_status": ds.last_sync_status,
"last_sync_stats": ds.last_sync_stats,
"total_templates": template_count,
"total_rules": rule_count,
}

View File

@@ -0,0 +1,88 @@
"""User management service — framework-agnostic CRUD for users.
Uses domain exceptions from app.domain.errors. The router handles
HTTP concerns, auth, audit logging, and commit.
"""
from __future__ import annotations
import uuid
from sqlalchemy.orm import Session
from app.auth import hash_password
from app.domain.errors import BusinessRuleViolation, DuplicateEntityError, EntityNotFoundError
from app.models.user import User
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
def list_users(db: Session) -> list[User]:
"""Return a list of all users ordered by username."""
return db.query(User).order_by(User.username).all()
def create_user(
db: Session,
*,
username: str,
email: str | None,
password: str,
role: str,
) -> User:
"""Create a new user.
Raises DuplicateEntityError if username already exists.
Raises BusinessRuleViolation if role is invalid.
Does not commit; the router handles that.
"""
existing = db.query(User).filter(User.username == username).first()
if existing:
raise DuplicateEntityError("User", "username", username)
if role not in VALID_ROLES:
raise BusinessRuleViolation(
f"Invalid role '{role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
)
user = User(
username=username,
email=email,
hashed_password=hash_password(password),
role=role,
)
db.add(user)
return user
def get_user_or_raise(db: Session, user_id: uuid.UUID) -> User:
"""Return a user by ID or raise EntityNotFoundError."""
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise EntityNotFoundError("User", str(user_id))
return user
def update_user(db: Session, user_id: uuid.UUID, **fields: object) -> User:
"""Update one or more fields of an existing user.
Raises EntityNotFoundError if user does not exist.
Raises BusinessRuleViolation if role is invalid.
Handles 'password' by hashing and storing as 'hashed_password'.
Does not commit; the router handles that.
"""
user = get_user_or_raise(db, user_id)
update_data = dict(fields)
if "role" in update_data and update_data["role"] not in VALID_ROLES:
raise BusinessRuleViolation(
f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
)
if "password" in update_data:
update_data["hashed_password"] = hash_password(str(update_data.pop("password")))
for field, value in update_data.items():
setattr(user, field, value)
return user