feat: add Campaign/Compliance domain entities and extract users/audit/data_sources to services (LP-2 through LP-6)

This commit is contained in:
2026-02-20 13:28:14 +01:00
parent 44621364be
commit c0c6cda11d
11 changed files with 939 additions and 319 deletions

View File

@@ -1,3 +1,15 @@
from app.domain.entities.campaign import CampaignEntity
from app.domain.entities.compliance import (
ComplianceControlEntity,
ComplianceFrameworkEntity,
ControlCoverageStatus,
)
from app.domain.entities.technique import TechniqueEntity
__all__ = ["TechniqueEntity"]
__all__ = [
"CampaignEntity",
"ComplianceControlEntity",
"ComplianceFrameworkEntity",
"ControlCoverageStatus",
"TechniqueEntity",
]

View File

@@ -0,0 +1,103 @@
"""Campaign domain entity with lifecycle validation.
Pure domain logic — no framework imports.
"""
from __future__ import annotations
import enum
import uuid
from dataclasses import dataclass, field
from typing import Any
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
class CampaignStatus(str, enum.Enum):
draft = "draft"
active = "active"
completed = "completed"
archived = "archived"
class CampaignType(str, enum.Enum):
custom = "custom"
apt_emulation = "apt_emulation"
kill_chain = "kill_chain"
compliance = "compliance"
VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
CampaignStatus.draft: [CampaignStatus.active],
CampaignStatus.active: [CampaignStatus.completed],
CampaignStatus.completed: [CampaignStatus.archived],
CampaignStatus.archived: [],
}
@dataclass
class CampaignEntity:
name: str
type: CampaignType = CampaignType.custom
status: CampaignStatus = CampaignStatus.draft
id: uuid.UUID | None = None
description: str | None = None
threat_actor_id: uuid.UUID | None = None
created_by: uuid.UUID | None = None
target_platform: str | None = None
tags: list[str] = field(default_factory=list)
test_count: int = 0
def can_transition_to(self, target: CampaignStatus) -> bool:
return target in VALID_TRANSITIONS.get(self.status, [])
def activate(self) -> None:
if not self.can_transition_to(CampaignStatus.active):
raise InvalidStateTransition(
self.status.value, CampaignStatus.active.value,
[s.value for s in VALID_TRANSITIONS[self.status]],
)
if self.test_count == 0:
raise BusinessRuleViolation(
"Campaign must have at least one test to activate"
)
self.status = CampaignStatus.active
def complete(self) -> None:
if not self.can_transition_to(CampaignStatus.completed):
raise InvalidStateTransition(
self.status.value, CampaignStatus.completed.value,
[s.value for s in VALID_TRANSITIONS[self.status]],
)
self.status = CampaignStatus.completed
def archive(self) -> None:
if not self.can_transition_to(CampaignStatus.archived):
raise InvalidStateTransition(
self.status.value, CampaignStatus.archived.value,
[s.value for s in VALID_TRANSITIONS[self.status]],
)
self.status = CampaignStatus.archived
def ensure_modifiable(self) -> None:
if self.status not in (CampaignStatus.draft, CampaignStatus.active):
raise BusinessRuleViolation(
f"Cannot modify campaign in '{self.status.value}' state"
)
@classmethod
def from_orm(cls, orm: Any) -> CampaignEntity:
"""Build a CampaignEntity from a SQLAlchemy Campaign model."""
test_count = len(getattr(orm, "campaign_tests", None) or [])
return cls(
id=orm.id,
name=orm.name,
type=CampaignType(orm.type) if orm.type else CampaignType.custom,
status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft,
description=orm.description,
threat_actor_id=orm.threat_actor_id,
created_by=orm.created_by,
target_platform=orm.target_platform,
tags=orm.tags or [],
test_count=test_count,
)

View File

@@ -0,0 +1,71 @@
"""Compliance domain entities with coverage calculation logic.
Pure domain logic — no framework imports.
"""
from __future__ import annotations
import enum
import uuid
from dataclasses import dataclass, field
class ControlCoverageStatus(str, enum.Enum):
covered = "covered"
partially_covered = "partially_covered"
not_covered = "not_covered"
@dataclass
class ComplianceControlEntity:
control_id: str
title: str
id: uuid.UUID | None = None
description: str | None = None
category: str | None = None
technique_statuses: list[str] = field(default_factory=list)
@property
def coverage_status(self) -> ControlCoverageStatus:
if not self.technique_statuses:
return ControlCoverageStatus.not_covered
covered_statuses = {"validated", "partial"}
covered = [s for s in self.technique_statuses if s in covered_statuses]
if len(covered) == len(self.technique_statuses):
return ControlCoverageStatus.covered
elif len(covered) > 0:
return ControlCoverageStatus.partially_covered
return ControlCoverageStatus.not_covered
@dataclass
class ComplianceFrameworkEntity:
name: str
id: uuid.UUID | None = None
version: str | None = None
description: str | None = None
is_active: bool = True
controls: list[ComplianceControlEntity] = field(default_factory=list)
@property
def total_controls(self) -> int:
return len(self.controls)
@property
def covered_controls(self) -> int:
return sum(
1 for c in self.controls
if c.coverage_status == ControlCoverageStatus.covered
)
@property
def coverage_pct(self) -> float:
if self.total_controls == 0:
return 0.0
return round(self.covered_controls / self.total_controls * 100, 1)
def get_gap_controls(self) -> list[ComplianceControlEntity]:
return [
c for c in self.controls
if c.coverage_status != ControlCoverageStatus.covered
]

View File

@@ -4,14 +4,17 @@ from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy import func
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import require_role
from app.models.audit import AuditLog
from app.models.user import User
from app.schemas.audit import AuditLogOut, AuditLogPage
from app.services.audit_query_service import (
list_distinct_actions,
list_distinct_entity_types,
list_logs,
)
router = APIRouter(prefix="/audit-logs", tags=["audit"])
@@ -32,53 +35,22 @@ def list_audit_logs(
**Requires admin role.**
"""
query = db.query(AuditLog).options(joinedload(AuditLog.user))
# Apply filters
if user_id:
query = query.filter(AuditLog.user_id == user_id)
if action:
query = query.filter(AuditLog.action == action)
if entity_type:
query = query.filter(AuditLog.entity_type == entity_type)
if start_date:
query = query.filter(AuditLog.timestamp >= start_date)
if end_date:
query = query.filter(AuditLog.timestamp <= end_date)
# Get total count
total = query.count()
# Get paginated results
logs = (
query
.order_by(AuditLog.timestamp.desc())
.offset(offset)
.limit(limit)
.all()
)
# Convert to response format with username
items = []
for log in logs:
item = AuditLogOut(
id=log.id,
user_id=log.user_id,
username=log.user.username if log.user else None,
action=log.action,
entity_type=log.entity_type,
entity_id=log.entity_id,
timestamp=log.timestamp,
details=log.details,
)
items.append(item)
return AuditLogPage(
items=items,
total=total,
result = list_logs(
db,
user_id=user_id,
action=action,
entity_type=entity_type,
start_date=start_date,
end_date=end_date,
offset=offset,
limit=limit,
)
return AuditLogPage(
items=[AuditLogOut(**item) for item in result["items"]],
total=result["total"],
offset=result["offset"],
limit=result["limit"],
)
@router.get("/actions", response_model=list[str])
@@ -90,13 +62,7 @@ def list_actions(
**Requires admin role.**
"""
actions = (
db.query(AuditLog.action)
.distinct()
.order_by(AuditLog.action)
.all()
)
return [a[0] for a in actions]
return list_distinct_actions(db)
@router.get("/entity-types", response_model=list[str])
@@ -108,11 +74,4 @@ def list_entity_types(
**Requires admin role.**
"""
types = (
db.query(AuditLog.entity_type)
.filter(AuditLog.entity_type.isnot(None))
.distinct()
.order_by(AuditLog.entity_type)
.all()
)
return [t[0] for t in types]
return list_distinct_entity_types(db)

View File

@@ -5,19 +5,23 @@ Provides a centralized panel for managing all external data sources
including sync triggers, enable/disable toggles, and statistics.
"""
import logging
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session
from typing import Optional
from app.database import get_db
from app.dependencies.auth import require_role
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.models.data_source import DataSource
from app.services.audit_service import log_action
from app.services.data_source_service import (
get_source_stats,
list_sources,
sync_all_sources,
sync_source,
update_source,
)
# ---------------------------------------------------------------------------
@@ -30,41 +34,10 @@ class DataSourceUpdate(BaseModel):
sync_frequency: Optional[str] = None
config: Optional[dict] = None
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
# ---------------------------------------------------------------------------
# Sync dispatcher — maps source name → import function
# ---------------------------------------------------------------------------
def _get_sync_handler(source_name: str):
"""Lazily import and return the sync function for *source_name*.
We import lazily to avoid circular imports and to only load the
modules that are actually needed.
"""
handlers = {
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
"sigma": ("app.services.sigma_import_service", "sync"),
"lolbas": ("app.services.lolbas_import_service", "sync"),
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
"caldera": ("app.services.caldera_import_service", "sync"),
"elastic_rules": ("app.services.elastic_import_service", "sync"),
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
"d3fend": ("app.services.d3fend_import_service", "sync"),
}
if source_name not in handlers:
return None
module_path, func_name = handlers[source_name]
import importlib
mod = importlib.import_module(module_path)
return getattr(mod, func_name)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@@ -79,25 +52,7 @@ def list_data_sources(
**Requires** the ``admin`` role.
"""
sources = db.query(DataSource).order_by(DataSource.name).all()
return [
{
"id": str(s.id),
"name": s.name,
"display_name": s.display_name,
"type": s.type,
"url": s.url,
"description": s.description,
"is_enabled": s.is_enabled,
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
"last_sync_status": s.last_sync_status,
"last_sync_stats": s.last_sync_stats,
"sync_frequency": s.sync_frequency,
"config": s.config,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in sources
]
return list_sources(db)
@router.patch("/{source_id}")
@@ -111,31 +66,21 @@ def update_data_source(
**Requires** the ``admin`` role.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise HTTPException(status_code=404, detail="Data source not found")
update_data = body.model_dump(exclude_unset=True)
if "is_enabled" in update_data:
ds.is_enabled = update_data["is_enabled"]
if "sync_frequency" in update_data:
ds.sync_frequency = update_data["sync_frequency"]
if "config" in update_data:
ds.config = update_data["config"]
db.commit()
update_source(db, source_id, **update_data)
with UnitOfWork(db) as uow:
uow.commit()
log_action(
db,
user_id=current_user.id,
action="update_data_source",
entity_type="data_source",
entity_id=str(ds.id),
entity_id=source_id,
details={"updates": update_data},
)
return {"message": "Data source updated", "id": str(ds.id)}
return {"message": "Data source updated", "id": source_id}
@router.post("/{source_id}/sync")
@@ -148,46 +93,7 @@ def sync_data_source(
**Requires** the ``admin`` role.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise HTTPException(status_code=404, detail="Data source not found")
handler = _get_sync_handler(ds.name)
if handler is None:
raise HTTPException(
status_code=400,
detail=f"No sync handler available for '{ds.name}'",
)
# Mark as in_progress
ds.last_sync_status = "in_progress"
db.commit()
try:
summary = handler(db)
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
raise HTTPException(
status_code=500,
detail=f"Sync failed for '{ds.display_name}'. Check server logs for details.",
)
# Update DS record (the handler may already have done this,
# but we ensure it here as well)
ds.last_sync_at = datetime.utcnow()
ds.last_sync_status = "success"
ds.last_sync_stats = summary
db.commit()
return {
"message": f"Sync complete for {ds.display_name}",
"source": ds.name,
"stats": summary,
}
return sync_source(db, source_id)
@router.post("/sync-all")
@@ -199,49 +105,7 @@ def sync_all_data_sources(
**Requires** the ``admin`` role.
"""
enabled_sources = (
db.query(DataSource)
.filter(DataSource.is_enabled == True)
.order_by(DataSource.name)
.all()
)
results = []
for ds in enabled_sources:
handler = _get_sync_handler(ds.name)
if handler is None:
results.append({
"source": ds.name,
"status": "skipped",
"detail": "No sync handler available",
})
continue
ds.last_sync_status = "in_progress"
db.commit()
try:
summary = handler(db)
ds.last_sync_at = datetime.utcnow()
ds.last_sync_status = "success"
ds.last_sync_stats = summary
db.commit()
results.append({
"source": ds.name,
"status": "success",
"stats": summary,
})
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
results.append({
"source": ds.name,
"status": "error",
"detail": "Sync failed. Check server logs for details.",
})
results = sync_all_sources(db)
log_action(
db,
@@ -265,39 +129,4 @@ def get_data_source_stats(
**Requires** the ``admin`` role.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise HTTPException(status_code=404, detail="Data source not found")
# Count items from this source
from app.models.test_template import TestTemplate
from app.models.detection_rule import DetectionRule
template_count = 0
rule_count = 0
if ds.type == "attack_procedure":
template_count = (
db.query(TestTemplate)
.filter(TestTemplate.source == ds.name)
.count()
)
elif ds.type == "detection_rule":
rule_count = (
db.query(DetectionRule)
.filter(DetectionRule.source == ds.name)
.count()
)
return {
"id": str(ds.id),
"name": ds.name,
"display_name": ds.display_name,
"type": ds.type,
"is_enabled": ds.is_enabled,
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
"last_sync_status": ds.last_sync_status,
"last_sync_stats": ds.last_sync_stats,
"total_templates": template_count,
"total_rules": rule_count,
}
return get_source_stats(db, source_id)

View File

@@ -2,20 +2,24 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import require_role
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate, UserOut
from app.auth import hash_password
from app.services.audit_service import log_action
from app.services.user_service import (
create_user,
get_user_or_raise,
list_users,
update_user,
)
router = APIRouter(prefix="/users", tags=["users"])
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
# ---------------------------------------------------------------------------
# GET /users — list all users
@@ -23,12 +27,12 @@ VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewe
@router.get("", response_model=list[UserOut])
def list_users(
def list_users_route(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Return a list of all users. **Requires admin role.**"""
return db.query(User).order_by(User.username).all()
return list_users(db)
# ---------------------------------------------------------------------------
@@ -37,36 +41,21 @@ def list_users(
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
def create_user(
def create_user_route(
payload: UserCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Create a new user. **Requires admin role.**"""
# Check if username already exists
existing = db.query(User).filter(User.username == payload.username).first()
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Username '{payload.username}' already exists",
)
# Validate role
if payload.role not in VALID_ROLES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid role '{payload.role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
)
user = User(
user = create_user(
db,
username=payload.username,
email=payload.email,
hashed_password=hash_password(payload.password),
password=payload.password,
role=payload.role,
)
db.add(user)
db.commit()
with UnitOfWork(db) as uow:
uow.commit()
db.refresh(user)
log_action(
@@ -93,13 +82,7 @@ def get_user(
current_user: User = Depends(require_role("admin")),
):
"""Return a single user by ID. **Requires admin role.**"""
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
return user
return get_user_or_raise(db, user_id)
# ---------------------------------------------------------------------------
@@ -108,37 +91,17 @@ def get_user(
@router.patch("/{user_id}", response_model=UserOut)
def update_user(
def update_user_route(
user_id: uuid.UUID,
payload: UserUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Update one or more fields of an existing user. **Requires admin role.**"""
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
update_data = payload.model_dump(exclude_unset=True)
# Validate role if being updated
if "role" in update_data and update_data["role"] not in VALID_ROLES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
)
# Hash password if being updated
if "password" in update_data:
update_data["hashed_password"] = hash_password(update_data.pop("password"))
for field, value in update_data.items():
setattr(user, field, value)
db.commit()
user = update_user(db, user_id, **update_data)
with UnitOfWork(db) as uow:
uow.commit()
db.refresh(user)
log_action(
@@ -147,7 +110,7 @@ def update_user(
action="update_user",
entity_type="user",
entity_id=user.id,
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
details={"updated_fields": list(update_data.keys())},
)
return user

View File

@@ -0,0 +1,93 @@
"""Audit log query service — framework-agnostic query logic for audit logs.
Provides paginated logs and distinct action/entity-type lists.
No FastAPI imports.
"""
from __future__ import annotations
from datetime import datetime
from sqlalchemy.orm import Session, joinedload
from app.models.audit import AuditLog
def list_logs(
db: Session,
*,
user_id: str | None = None,
action: str | None = None,
entity_type: str | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
offset: int = 0,
limit: int = 50,
) -> dict:
"""Return paginated audit logs with optional filters.
Returns a dict with keys: items, total, offset, limit.
Each item is a dict with: id, user_id, username, action, entity_type,
entity_id, timestamp, details.
"""
query = db.query(AuditLog).options(joinedload(AuditLog.user))
if user_id:
query = query.filter(AuditLog.user_id == user_id)
if action:
query = query.filter(AuditLog.action == action)
if entity_type:
query = query.filter(AuditLog.entity_type == entity_type)
if start_date:
query = query.filter(AuditLog.timestamp >= start_date)
if end_date:
query = query.filter(AuditLog.timestamp <= end_date)
total = query.count()
logs = (
query
.order_by(AuditLog.timestamp.desc())
.offset(offset)
.limit(limit)
.all()
)
items = [
{
"id": log.id,
"user_id": log.user_id,
"username": log.user.username if log.user else None,
"action": log.action,
"entity_type": log.entity_type,
"entity_id": log.entity_id,
"timestamp": log.timestamp,
"details": log.details,
}
for log in logs
]
return {"items": items, "total": total, "offset": offset, "limit": limit}
def list_distinct_actions(db: Session) -> list[str]:
"""Return a list of distinct action types in the audit log."""
actions = (
db.query(AuditLog.action)
.distinct()
.order_by(AuditLog.action)
.all()
)
return [a[0] for a in actions]
def list_distinct_entity_types(db: Session) -> list[str]:
"""Return a list of distinct entity types in the audit log."""
types = (
db.query(AuditLog.entity_type)
.filter(AuditLog.entity_type.isnot(None))
.distinct()
.order_by(AuditLog.entity_type)
.all()
)
return [t[0] for t in types]

View File

@@ -0,0 +1,222 @@
"""Data source management service — framework-agnostic query and sync logic.
Provides list, update, sync, and stats. Sync operations commit internally
since they are long-running and self-contained.
"""
from __future__ import annotations
import importlib
import logging
from datetime import datetime
from sqlalchemy.orm import Session
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
from app.models.data_source import DataSource
logger = logging.getLogger(__name__)
def _get_sync_handler(source_name: str):
"""Lazily import and return the sync function for *source_name*.
We import lazily to avoid circular imports and to only load the
modules that are actually needed.
"""
handlers = {
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
"sigma": ("app.services.sigma_import_service", "sync"),
"lolbas": ("app.services.lolbas_import_service", "sync"),
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
"caldera": ("app.services.caldera_import_service", "sync"),
"elastic_rules": ("app.services.elastic_import_service", "sync"),
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
"d3fend": ("app.services.d3fend_import_service", "sync"),
}
if source_name not in handlers:
return None
module_path, func_name = handlers[source_name]
mod = importlib.import_module(module_path)
return getattr(mod, func_name)
def list_sources(db: Session) -> list[dict]:
"""Return all registered data sources as a list of dicts."""
sources = db.query(DataSource).order_by(DataSource.name).all()
return [
{
"id": str(s.id),
"name": s.name,
"display_name": s.display_name,
"type": s.type,
"url": s.url,
"description": s.description,
"is_enabled": s.is_enabled,
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
"last_sync_status": s.last_sync_status,
"last_sync_stats": s.last_sync_stats,
"sync_frequency": s.sync_frequency,
"config": s.config,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in sources
]
def update_source(db: Session, source_id: str, **fields: object) -> None:
"""Update a data source's fields (is_enabled, sync_frequency, config).
Raises EntityNotFoundError if source does not exist.
Does not commit; the router handles that.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise EntityNotFoundError("Data source", source_id)
if "is_enabled" in fields:
ds.is_enabled = fields["is_enabled"]
if "sync_frequency" in fields:
ds.sync_frequency = fields["sync_frequency"]
if "config" in fields:
ds.config = fields["config"]
def sync_source(db: Session, source_id: str) -> dict:
"""Trigger sync for a specific data source.
Raises EntityNotFoundError if source does not exist.
Raises BusinessRuleViolation if no sync handler is available.
Commits internally (long-running, self-contained operation).
Returns dict with message, source, stats.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise EntityNotFoundError("Data source", source_id)
handler = _get_sync_handler(ds.name)
if handler is None:
raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'")
ds.last_sync_status = "in_progress"
db.commit()
try:
summary = handler(db)
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
raise BusinessRuleViolation(
f"Sync failed for '{ds.display_name}'. Check server logs for details."
)
ds.last_sync_at = datetime.utcnow()
ds.last_sync_status = "success"
ds.last_sync_stats = summary
db.commit()
return {
"message": f"Sync complete for {ds.display_name}",
"source": ds.name,
"stats": summary,
}
def sync_all_sources(db: Session) -> list[dict]:
"""Trigger sync for all enabled data sources (sequentially).
Commits internally (long-running, self-contained operation).
Returns list of result dicts with source, status, stats/detail.
"""
enabled_sources = (
db.query(DataSource)
.filter(DataSource.is_enabled == True)
.order_by(DataSource.name)
.all()
)
results = []
for ds in enabled_sources:
handler = _get_sync_handler(ds.name)
if handler is None:
results.append({
"source": ds.name,
"status": "skipped",
"detail": "No sync handler available",
})
continue
ds.last_sync_status = "in_progress"
db.commit()
try:
summary = handler(db)
ds.last_sync_at = datetime.utcnow()
ds.last_sync_status = "success"
ds.last_sync_stats = summary
db.commit()
results.append({
"source": ds.name,
"status": "success",
"stats": summary,
})
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
results.append({
"source": ds.name,
"status": "error",
"detail": "Sync failed. Check server logs for details.",
})
return results
def get_source_stats(db: Session, source_id: str) -> dict:
"""Return detailed statistics for a data source.
Raises EntityNotFoundError if source does not exist.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise EntityNotFoundError("Data source", source_id)
from app.models.test_template import TestTemplate
from app.models.detection_rule import DetectionRule
template_count = 0
rule_count = 0
if ds.type == "attack_procedure":
template_count = (
db.query(TestTemplate)
.filter(TestTemplate.source == ds.name)
.count()
)
elif ds.type == "detection_rule":
rule_count = (
db.query(DetectionRule)
.filter(DetectionRule.source == ds.name)
.count()
)
return {
"id": str(ds.id),
"name": ds.name,
"display_name": ds.display_name,
"type": ds.type,
"is_enabled": ds.is_enabled,
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
"last_sync_status": ds.last_sync_status,
"last_sync_stats": ds.last_sync_stats,
"total_templates": template_count,
"total_rules": rule_count,
}

View File

@@ -0,0 +1,88 @@
"""User management service — framework-agnostic CRUD for users.
Uses domain exceptions from app.domain.errors. The router handles
HTTP concerns, auth, audit logging, and commit.
"""
from __future__ import annotations
import uuid
from sqlalchemy.orm import Session
from app.auth import hash_password
from app.domain.errors import BusinessRuleViolation, DuplicateEntityError, EntityNotFoundError
from app.models.user import User
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
def list_users(db: Session) -> list[User]:
"""Return a list of all users ordered by username."""
return db.query(User).order_by(User.username).all()
def create_user(
db: Session,
*,
username: str,
email: str | None,
password: str,
role: str,
) -> User:
"""Create a new user.
Raises DuplicateEntityError if username already exists.
Raises BusinessRuleViolation if role is invalid.
Does not commit; the router handles that.
"""
existing = db.query(User).filter(User.username == username).first()
if existing:
raise DuplicateEntityError("User", "username", username)
if role not in VALID_ROLES:
raise BusinessRuleViolation(
f"Invalid role '{role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
)
user = User(
username=username,
email=email,
hashed_password=hash_password(password),
role=role,
)
db.add(user)
return user
def get_user_or_raise(db: Session, user_id: uuid.UUID) -> User:
"""Return a user by ID or raise EntityNotFoundError."""
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise EntityNotFoundError("User", str(user_id))
return user
def update_user(db: Session, user_id: uuid.UUID, **fields: object) -> User:
"""Update one or more fields of an existing user.
Raises EntityNotFoundError if user does not exist.
Raises BusinessRuleViolation if role is invalid.
Handles 'password' by hashing and storing as 'hashed_password'.
Does not commit; the router handles that.
"""
user = get_user_or_raise(db, user_id)
update_data = dict(fields)
if "role" in update_data and update_data["role"] not in VALID_ROLES:
raise BusinessRuleViolation(
f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
)
if "password" in update_data:
update_data["hashed_password"] = hash_password(str(update_data.pop("password")))
for field, value in update_data.items():
setattr(user, field, value)
return user

View File

@@ -0,0 +1,175 @@
"""Tests for CampaignEntity — pure domain logic, no DB."""
import sys
import os
import uuid
from unittest.mock import MagicMock
import pytest
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
from app.domain.entities.campaign import (
CampaignEntity,
CampaignStatus,
CampaignType,
)
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
# ── Helpers ──────────────────────────────────────────────────────────
def _entity(status: str = "draft", test_count: int = 0, **overrides) -> CampaignEntity:
defaults = dict(
id=uuid.uuid4(),
name="Test Campaign",
type=CampaignType.custom,
status=CampaignStatus(status),
description=None,
threat_actor_id=None,
created_by=None,
target_platform=None,
tags=[],
test_count=test_count,
)
defaults.update(overrides)
return CampaignEntity(**defaults)
def _fake_orm(status: str = "draft", test_count: int = 0, **overrides) -> MagicMock:
m = MagicMock()
m.id = uuid.uuid4()
m.name = "Test Campaign"
m.type = "custom"
m.status = status
m.description = None
m.threat_actor_id = None
m.created_by = None
m.target_platform = None
m.tags = []
m.campaign_tests = [MagicMock()] * test_count if test_count else []
for k, v in overrides.items():
setattr(m, k, v)
return m
# ── 1. Test activation from draft with tests → success ───────────────
def test_activate_from_draft_with_tests_success():
e = _entity("draft", test_count=1)
e.activate()
assert e.status == CampaignStatus.active
def test_activate_from_draft_with_multiple_tests_success():
e = _entity("draft", test_count=3)
e.activate()
assert e.status == CampaignStatus.active
# ── 2. Test activation from draft with 0 tests → BusinessRuleViolation ───
def test_activate_from_draft_with_zero_tests_raises():
e = _entity("draft", test_count=0)
with pytest.raises(BusinessRuleViolation, match="at least one test"):
e.activate()
assert e.status == CampaignStatus.draft
# ── 3. Test activation from active → InvalidStateTransition ────────────
def test_activate_from_active_raises():
e = _entity("active", test_count=2)
with pytest.raises(InvalidStateTransition) as exc_info:
e.activate()
assert exc_info.value.current_state == "active"
assert exc_info.value.target_state == "active"
assert "completed" in exc_info.value.valid_transitions
# ── 4. Test complete from active → success ──────────────────────────────
def test_complete_from_active_success():
e = _entity("active", test_count=2)
e.complete()
assert e.status == CampaignStatus.completed
# ── 5. Test complete from draft → InvalidStateTransition ────────────────
def test_complete_from_draft_raises():
e = _entity("draft", test_count=1)
with pytest.raises(InvalidStateTransition) as exc_info:
e.complete()
assert exc_info.value.current_state == "draft"
assert exc_info.value.target_state == "completed"
assert "active" in exc_info.value.valid_transitions
# ── 6. Test ensure_modifiable in draft/active → ok ──────────────────────
def test_ensure_modifiable_draft_ok():
e = _entity("draft")
e.ensure_modifiable() # no raise
def test_ensure_modifiable_active_ok():
e = _entity("active", test_count=1)
e.ensure_modifiable() # no raise
# ── 7. Test ensure_modifiable in completed → BusinessRuleViolation ──────
def test_ensure_modifiable_completed_raises():
e = _entity("completed", test_count=1)
with pytest.raises(BusinessRuleViolation, match="Cannot modify"):
e.ensure_modifiable()
def test_ensure_modifiable_archived_raises():
e = _entity("archived", test_count=1)
with pytest.raises(BusinessRuleViolation, match="Cannot modify"):
e.ensure_modifiable()
# ── 8. Test from_orm conversion ──────────────────────────────────────────
def test_from_orm_basic():
orm = _fake_orm("draft", test_count=0)
e = CampaignEntity.from_orm(orm)
assert e.name == "Test Campaign"
assert e.type == CampaignType.custom
assert e.status == CampaignStatus.draft
assert e.id == orm.id
assert e.test_count == 0
def test_from_orm_with_tests():
orm = _fake_orm("draft", test_count=3)
e = CampaignEntity.from_orm(orm)
assert e.test_count == 3
def test_from_orm_coerces_type_and_status():
orm = _fake_orm(status="active", type="apt_emulation", test_count=1)
e = CampaignEntity.from_orm(orm)
assert e.status == CampaignStatus.active
assert e.type == CampaignType.apt_emulation
def test_from_orm_handles_none_tags():
orm = _fake_orm("draft", test_count=0)
orm.tags = None
e = CampaignEntity.from_orm(orm)
assert e.tags == []

View File

@@ -0,0 +1,105 @@
"""Tests for compliance domain entities."""
import pytest
from app.domain.entities.compliance import (
ComplianceControlEntity,
ComplianceFrameworkEntity,
ControlCoverageStatus,
)
# ── Control coverage status ───────────────────────────────────────────────
def test_control_all_techniques_validated_covered():
"""All techniques validated → covered."""
control = ComplianceControlEntity(
control_id="AC-2",
title="Account Management",
technique_statuses=["validated", "validated"],
)
assert control.coverage_status == ControlCoverageStatus.covered
def test_control_all_techniques_partial_covered():
"""All techniques partial → covered."""
control = ComplianceControlEntity(
control_id="AC-2",
title="Account Management",
technique_statuses=["partial"],
)
assert control.coverage_status == ControlCoverageStatus.covered
def test_control_mixed_statuses_partially_covered():
"""Mixed statuses (some validated/partial, some not) → partially_covered."""
control = ComplianceControlEntity(
control_id="AC-2",
title="Account Management",
technique_statuses=["validated", "not_evaluated"],
)
assert control.coverage_status == ControlCoverageStatus.partially_covered
def test_control_no_validated_techniques_not_covered():
"""No validated/partial techniques → not_covered."""
control = ComplianceControlEntity(
control_id="AC-2",
title="Account Management",
technique_statuses=["not_evaluated", "not_covered"],
)
assert control.coverage_status == ControlCoverageStatus.not_covered
def test_control_empty_techniques_not_covered():
"""Empty technique_statuses → not_covered."""
control = ComplianceControlEntity(
control_id="AC-2",
title="Account Management",
technique_statuses=[],
)
assert control.coverage_status == ControlCoverageStatus.not_covered
# ── Framework coverage ─────────────────────────────────────────────────────
def test_framework_coverage_pct_calculation():
"""Framework coverage_pct = (covered_controls / total_controls) * 100."""
controls = [
ComplianceControlEntity("AC-1", "Title 1", technique_statuses=["validated"]),
ComplianceControlEntity("AC-2", "Title 2", technique_statuses=["not_evaluated"]),
ComplianceControlEntity("AC-3", "Title 3", technique_statuses=["validated", "partial"]),
ComplianceControlEntity("AC-4", "Title 4", technique_statuses=["partial"]),
ComplianceControlEntity("AC-5", "Title 5", technique_statuses=[]),
]
framework = ComplianceFrameworkEntity(name="NIST 800-53", controls=controls)
# AC-1: covered, AC-2: not_covered, AC-3: covered, AC-4: covered, AC-5: not_covered
assert framework.total_controls == 5
assert framework.covered_controls == 3
assert framework.coverage_pct == 60.0
def test_framework_get_gap_controls():
"""get_gap_controls returns only uncovered and partially_covered controls."""
controls = [
ComplianceControlEntity("AC-1", "Covered", technique_statuses=["validated"]),
ComplianceControlEntity("AC-2", "Partial", technique_statuses=["validated", "not_evaluated"]),
ComplianceControlEntity("AC-3", "Not Covered", technique_statuses=["not_evaluated"]),
ComplianceControlEntity("AC-4", "Empty", technique_statuses=[]),
]
framework = ComplianceFrameworkEntity(name="Test", controls=controls)
gaps = framework.get_gap_controls()
assert len(gaps) == 3
assert gaps[0].control_id == "AC-2"
assert gaps[1].control_id == "AC-3"
assert gaps[2].control_id == "AC-4"
def test_framework_no_controls_coverage_pct_zero():
"""Framework with no controls → coverage_pct is 0."""
framework = ComplianceFrameworkEntity(name="Empty", controls=[])
assert framework.total_controls == 0
assert framework.covered_controls == 0
assert framework.coverage_pct == 0.0