feat: add Campaign/Compliance domain entities and extract users/audit/data_sources to services (LP-2 through LP-6)
This commit is contained in:
@@ -1,3 +1,15 @@
|
|||||||
|
from app.domain.entities.campaign import CampaignEntity
|
||||||
|
from app.domain.entities.compliance import (
|
||||||
|
ComplianceControlEntity,
|
||||||
|
ComplianceFrameworkEntity,
|
||||||
|
ControlCoverageStatus,
|
||||||
|
)
|
||||||
from app.domain.entities.technique import TechniqueEntity
|
from app.domain.entities.technique import TechniqueEntity
|
||||||
|
|
||||||
__all__ = ["TechniqueEntity"]
|
__all__ = [
|
||||||
|
"CampaignEntity",
|
||||||
|
"ComplianceControlEntity",
|
||||||
|
"ComplianceFrameworkEntity",
|
||||||
|
"ControlCoverageStatus",
|
||||||
|
"TechniqueEntity",
|
||||||
|
]
|
||||||
|
|||||||
103
backend/app/domain/entities/campaign.py
Normal file
103
backend/app/domain/entities/campaign.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"""Campaign domain entity with lifecycle validation.
|
||||||
|
|
||||||
|
Pure domain logic — no framework imports.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
|
||||||
|
|
||||||
|
|
||||||
|
class CampaignStatus(str, enum.Enum):
|
||||||
|
draft = "draft"
|
||||||
|
active = "active"
|
||||||
|
completed = "completed"
|
||||||
|
archived = "archived"
|
||||||
|
|
||||||
|
|
||||||
|
class CampaignType(str, enum.Enum):
|
||||||
|
custom = "custom"
|
||||||
|
apt_emulation = "apt_emulation"
|
||||||
|
kill_chain = "kill_chain"
|
||||||
|
compliance = "compliance"
|
||||||
|
|
||||||
|
|
||||||
|
VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
|
||||||
|
CampaignStatus.draft: [CampaignStatus.active],
|
||||||
|
CampaignStatus.active: [CampaignStatus.completed],
|
||||||
|
CampaignStatus.completed: [CampaignStatus.archived],
|
||||||
|
CampaignStatus.archived: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CampaignEntity:
|
||||||
|
name: str
|
||||||
|
type: CampaignType = CampaignType.custom
|
||||||
|
status: CampaignStatus = CampaignStatus.draft
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
description: str | None = None
|
||||||
|
threat_actor_id: uuid.UUID | None = None
|
||||||
|
created_by: uuid.UUID | None = None
|
||||||
|
target_platform: str | None = None
|
||||||
|
tags: list[str] = field(default_factory=list)
|
||||||
|
test_count: int = 0
|
||||||
|
|
||||||
|
def can_transition_to(self, target: CampaignStatus) -> bool:
|
||||||
|
return target in VALID_TRANSITIONS.get(self.status, [])
|
||||||
|
|
||||||
|
def activate(self) -> None:
|
||||||
|
if not self.can_transition_to(CampaignStatus.active):
|
||||||
|
raise InvalidStateTransition(
|
||||||
|
self.status.value, CampaignStatus.active.value,
|
||||||
|
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||||
|
)
|
||||||
|
if self.test_count == 0:
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
"Campaign must have at least one test to activate"
|
||||||
|
)
|
||||||
|
self.status = CampaignStatus.active
|
||||||
|
|
||||||
|
def complete(self) -> None:
|
||||||
|
if not self.can_transition_to(CampaignStatus.completed):
|
||||||
|
raise InvalidStateTransition(
|
||||||
|
self.status.value, CampaignStatus.completed.value,
|
||||||
|
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||||
|
)
|
||||||
|
self.status = CampaignStatus.completed
|
||||||
|
|
||||||
|
def archive(self) -> None:
|
||||||
|
if not self.can_transition_to(CampaignStatus.archived):
|
||||||
|
raise InvalidStateTransition(
|
||||||
|
self.status.value, CampaignStatus.archived.value,
|
||||||
|
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||||
|
)
|
||||||
|
self.status = CampaignStatus.archived
|
||||||
|
|
||||||
|
def ensure_modifiable(self) -> None:
|
||||||
|
if self.status not in (CampaignStatus.draft, CampaignStatus.active):
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Cannot modify campaign in '{self.status.value}' state"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_orm(cls, orm: Any) -> CampaignEntity:
|
||||||
|
"""Build a CampaignEntity from a SQLAlchemy Campaign model."""
|
||||||
|
test_count = len(getattr(orm, "campaign_tests", None) or [])
|
||||||
|
return cls(
|
||||||
|
id=orm.id,
|
||||||
|
name=orm.name,
|
||||||
|
type=CampaignType(orm.type) if orm.type else CampaignType.custom,
|
||||||
|
status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft,
|
||||||
|
description=orm.description,
|
||||||
|
threat_actor_id=orm.threat_actor_id,
|
||||||
|
created_by=orm.created_by,
|
||||||
|
target_platform=orm.target_platform,
|
||||||
|
tags=orm.tags or [],
|
||||||
|
test_count=test_count,
|
||||||
|
)
|
||||||
71
backend/app/domain/entities/compliance.py
Normal file
71
backend/app/domain/entities/compliance.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""Compliance domain entities with coverage calculation logic.
|
||||||
|
|
||||||
|
Pure domain logic — no framework imports.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
class ControlCoverageStatus(str, enum.Enum):
|
||||||
|
covered = "covered"
|
||||||
|
partially_covered = "partially_covered"
|
||||||
|
not_covered = "not_covered"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ComplianceControlEntity:
|
||||||
|
control_id: str
|
||||||
|
title: str
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
description: str | None = None
|
||||||
|
category: str | None = None
|
||||||
|
technique_statuses: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def coverage_status(self) -> ControlCoverageStatus:
|
||||||
|
if not self.technique_statuses:
|
||||||
|
return ControlCoverageStatus.not_covered
|
||||||
|
covered_statuses = {"validated", "partial"}
|
||||||
|
covered = [s for s in self.technique_statuses if s in covered_statuses]
|
||||||
|
if len(covered) == len(self.technique_statuses):
|
||||||
|
return ControlCoverageStatus.covered
|
||||||
|
elif len(covered) > 0:
|
||||||
|
return ControlCoverageStatus.partially_covered
|
||||||
|
return ControlCoverageStatus.not_covered
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ComplianceFrameworkEntity:
|
||||||
|
name: str
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
version: str | None = None
|
||||||
|
description: str | None = None
|
||||||
|
is_active: bool = True
|
||||||
|
controls: list[ComplianceControlEntity] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_controls(self) -> int:
|
||||||
|
return len(self.controls)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def covered_controls(self) -> int:
|
||||||
|
return sum(
|
||||||
|
1 for c in self.controls
|
||||||
|
if c.coverage_status == ControlCoverageStatus.covered
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def coverage_pct(self) -> float:
|
||||||
|
if self.total_controls == 0:
|
||||||
|
return 0.0
|
||||||
|
return round(self.covered_controls / self.total_controls * 100, 1)
|
||||||
|
|
||||||
|
def get_gap_controls(self) -> list[ComplianceControlEntity]:
|
||||||
|
return [
|
||||||
|
c for c in self.controls
|
||||||
|
if c.coverage_status != ControlCoverageStatus.covered
|
||||||
|
]
|
||||||
@@ -4,14 +4,17 @@ from datetime import datetime
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy import func
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy.orm import Session, joinedload
|
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import require_role
|
from app.dependencies.auth import require_role
|
||||||
from app.models.audit import AuditLog
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.audit import AuditLogOut, AuditLogPage
|
from app.schemas.audit import AuditLogOut, AuditLogPage
|
||||||
|
from app.services.audit_query_service import (
|
||||||
|
list_distinct_actions,
|
||||||
|
list_distinct_entity_types,
|
||||||
|
list_logs,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
||||||
|
|
||||||
@@ -32,53 +35,22 @@ def list_audit_logs(
|
|||||||
|
|
||||||
**Requires admin role.**
|
**Requires admin role.**
|
||||||
"""
|
"""
|
||||||
query = db.query(AuditLog).options(joinedload(AuditLog.user))
|
result = list_logs(
|
||||||
|
db,
|
||||||
# Apply filters
|
user_id=user_id,
|
||||||
if user_id:
|
action=action,
|
||||||
query = query.filter(AuditLog.user_id == user_id)
|
entity_type=entity_type,
|
||||||
if action:
|
start_date=start_date,
|
||||||
query = query.filter(AuditLog.action == action)
|
end_date=end_date,
|
||||||
if entity_type:
|
|
||||||
query = query.filter(AuditLog.entity_type == entity_type)
|
|
||||||
if start_date:
|
|
||||||
query = query.filter(AuditLog.timestamp >= start_date)
|
|
||||||
if end_date:
|
|
||||||
query = query.filter(AuditLog.timestamp <= end_date)
|
|
||||||
|
|
||||||
# Get total count
|
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
# Get paginated results
|
|
||||||
logs = (
|
|
||||||
query
|
|
||||||
.order_by(AuditLog.timestamp.desc())
|
|
||||||
.offset(offset)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to response format with username
|
|
||||||
items = []
|
|
||||||
for log in logs:
|
|
||||||
item = AuditLogOut(
|
|
||||||
id=log.id,
|
|
||||||
user_id=log.user_id,
|
|
||||||
username=log.user.username if log.user else None,
|
|
||||||
action=log.action,
|
|
||||||
entity_type=log.entity_type,
|
|
||||||
entity_id=log.entity_id,
|
|
||||||
timestamp=log.timestamp,
|
|
||||||
details=log.details,
|
|
||||||
)
|
|
||||||
items.append(item)
|
|
||||||
|
|
||||||
return AuditLogPage(
|
|
||||||
items=items,
|
|
||||||
total=total,
|
|
||||||
offset=offset,
|
offset=offset,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
return AuditLogPage(
|
||||||
|
items=[AuditLogOut(**item) for item in result["items"]],
|
||||||
|
total=result["total"],
|
||||||
|
offset=result["offset"],
|
||||||
|
limit=result["limit"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/actions", response_model=list[str])
|
@router.get("/actions", response_model=list[str])
|
||||||
@@ -90,13 +62,7 @@ def list_actions(
|
|||||||
|
|
||||||
**Requires admin role.**
|
**Requires admin role.**
|
||||||
"""
|
"""
|
||||||
actions = (
|
return list_distinct_actions(db)
|
||||||
db.query(AuditLog.action)
|
|
||||||
.distinct()
|
|
||||||
.order_by(AuditLog.action)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [a[0] for a in actions]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/entity-types", response_model=list[str])
|
@router.get("/entity-types", response_model=list[str])
|
||||||
@@ -108,11 +74,4 @@ def list_entity_types(
|
|||||||
|
|
||||||
**Requires admin role.**
|
**Requires admin role.**
|
||||||
"""
|
"""
|
||||||
types = (
|
return list_distinct_entity_types(db)
|
||||||
db.query(AuditLog.entity_type)
|
|
||||||
.filter(AuditLog.entity_type.isnot(None))
|
|
||||||
.distinct()
|
|
||||||
.order_by(AuditLog.entity_type)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [t[0] for t in types]
|
|
||||||
|
|||||||
@@ -5,19 +5,23 @@ Provides a centralized panel for managing all external data sources
|
|||||||
including sync triggers, enable/disable toggles, and statistics.
|
including sync triggers, enable/disable toggles, and statistics.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
from fastapi import APIRouter, Depends
|
||||||
from datetime import datetime
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import require_role
|
from app.dependencies.auth import require_role
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.data_source import DataSource
|
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
from app.services.data_source_service import (
|
||||||
|
get_source_stats,
|
||||||
|
list_sources,
|
||||||
|
sync_all_sources,
|
||||||
|
sync_source,
|
||||||
|
update_source,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -30,41 +34,10 @@ class DataSourceUpdate(BaseModel):
|
|||||||
sync_frequency: Optional[str] = None
|
sync_frequency: Optional[str] = None
|
||||||
config: Optional[dict] = None
|
config: Optional[dict] = None
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Sync dispatcher — maps source name → import function
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _get_sync_handler(source_name: str):
|
|
||||||
"""Lazily import and return the sync function for *source_name*.
|
|
||||||
|
|
||||||
We import lazily to avoid circular imports and to only load the
|
|
||||||
modules that are actually needed.
|
|
||||||
"""
|
|
||||||
handlers = {
|
|
||||||
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
|
|
||||||
"sigma": ("app.services.sigma_import_service", "sync"),
|
|
||||||
"lolbas": ("app.services.lolbas_import_service", "sync"),
|
|
||||||
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
|
|
||||||
"caldera": ("app.services.caldera_import_service", "sync"),
|
|
||||||
"elastic_rules": ("app.services.elastic_import_service", "sync"),
|
|
||||||
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
|
|
||||||
"d3fend": ("app.services.d3fend_import_service", "sync"),
|
|
||||||
}
|
|
||||||
|
|
||||||
if source_name not in handlers:
|
|
||||||
return None
|
|
||||||
|
|
||||||
module_path, func_name = handlers[source_name]
|
|
||||||
import importlib
|
|
||||||
mod = importlib.import_module(module_path)
|
|
||||||
return getattr(mod, func_name)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Endpoints
|
# Endpoints
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -79,25 +52,7 @@ def list_data_sources(
|
|||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
sources = db.query(DataSource).order_by(DataSource.name).all()
|
return list_sources(db)
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": str(s.id),
|
|
||||||
"name": s.name,
|
|
||||||
"display_name": s.display_name,
|
|
||||||
"type": s.type,
|
|
||||||
"url": s.url,
|
|
||||||
"description": s.description,
|
|
||||||
"is_enabled": s.is_enabled,
|
|
||||||
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
|
|
||||||
"last_sync_status": s.last_sync_status,
|
|
||||||
"last_sync_stats": s.last_sync_stats,
|
|
||||||
"sync_frequency": s.sync_frequency,
|
|
||||||
"config": s.config,
|
|
||||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
|
||||||
}
|
|
||||||
for s in sources
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/{source_id}")
|
@router.patch("/{source_id}")
|
||||||
@@ -111,31 +66,21 @@ def update_data_source(
|
|||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
|
||||||
if not ds:
|
|
||||||
raise HTTPException(status_code=404, detail="Data source not found")
|
|
||||||
|
|
||||||
update_data = body.model_dump(exclude_unset=True)
|
update_data = body.model_dump(exclude_unset=True)
|
||||||
|
update_source(db, source_id, **update_data)
|
||||||
if "is_enabled" in update_data:
|
with UnitOfWork(db) as uow:
|
||||||
ds.is_enabled = update_data["is_enabled"]
|
uow.commit()
|
||||||
if "sync_frequency" in update_data:
|
|
||||||
ds.sync_frequency = update_data["sync_frequency"]
|
|
||||||
if "config" in update_data:
|
|
||||||
ds.config = update_data["config"]
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
action="update_data_source",
|
action="update_data_source",
|
||||||
entity_type="data_source",
|
entity_type="data_source",
|
||||||
entity_id=str(ds.id),
|
entity_id=source_id,
|
||||||
details={"updates": update_data},
|
details={"updates": update_data},
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"message": "Data source updated", "id": str(ds.id)}
|
return {"message": "Data source updated", "id": source_id}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{source_id}/sync")
|
@router.post("/{source_id}/sync")
|
||||||
@@ -148,46 +93,7 @@ def sync_data_source(
|
|||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
return sync_source(db, source_id)
|
||||||
if not ds:
|
|
||||||
raise HTTPException(status_code=404, detail="Data source not found")
|
|
||||||
|
|
||||||
handler = _get_sync_handler(ds.name)
|
|
||||||
if handler is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"No sync handler available for '{ds.name}'",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark as in_progress
|
|
||||||
ds.last_sync_status = "in_progress"
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
try:
|
|
||||||
summary = handler(db)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
|
||||||
ds.last_sync_status = "error"
|
|
||||||
ds.last_sync_at = datetime.utcnow()
|
|
||||||
ds.last_sync_stats = {"error": str(exc)}
|
|
||||||
db.commit()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail=f"Sync failed for '{ds.display_name}'. Check server logs for details.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update DS record (the handler may already have done this,
|
|
||||||
# but we ensure it here as well)
|
|
||||||
ds.last_sync_at = datetime.utcnow()
|
|
||||||
ds.last_sync_status = "success"
|
|
||||||
ds.last_sync_stats = summary
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"message": f"Sync complete for {ds.display_name}",
|
|
||||||
"source": ds.name,
|
|
||||||
"stats": summary,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sync-all")
|
@router.post("/sync-all")
|
||||||
@@ -199,49 +105,7 @@ def sync_all_data_sources(
|
|||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
enabled_sources = (
|
results = sync_all_sources(db)
|
||||||
db.query(DataSource)
|
|
||||||
.filter(DataSource.is_enabled == True)
|
|
||||||
.order_by(DataSource.name)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for ds in enabled_sources:
|
|
||||||
handler = _get_sync_handler(ds.name)
|
|
||||||
if handler is None:
|
|
||||||
results.append({
|
|
||||||
"source": ds.name,
|
|
||||||
"status": "skipped",
|
|
||||||
"detail": "No sync handler available",
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
|
|
||||||
ds.last_sync_status = "in_progress"
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
try:
|
|
||||||
summary = handler(db)
|
|
||||||
ds.last_sync_at = datetime.utcnow()
|
|
||||||
ds.last_sync_status = "success"
|
|
||||||
ds.last_sync_stats = summary
|
|
||||||
db.commit()
|
|
||||||
results.append({
|
|
||||||
"source": ds.name,
|
|
||||||
"status": "success",
|
|
||||||
"stats": summary,
|
|
||||||
})
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
|
||||||
ds.last_sync_status = "error"
|
|
||||||
ds.last_sync_at = datetime.utcnow()
|
|
||||||
ds.last_sync_stats = {"error": str(exc)}
|
|
||||||
db.commit()
|
|
||||||
results.append({
|
|
||||||
"source": ds.name,
|
|
||||||
"status": "error",
|
|
||||||
"detail": "Sync failed. Check server logs for details.",
|
|
||||||
})
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
@@ -265,39 +129,4 @@ def get_data_source_stats(
|
|||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
return get_source_stats(db, source_id)
|
||||||
if not ds:
|
|
||||||
raise HTTPException(status_code=404, detail="Data source not found")
|
|
||||||
|
|
||||||
# Count items from this source
|
|
||||||
from app.models.test_template import TestTemplate
|
|
||||||
from app.models.detection_rule import DetectionRule
|
|
||||||
|
|
||||||
template_count = 0
|
|
||||||
rule_count = 0
|
|
||||||
|
|
||||||
if ds.type == "attack_procedure":
|
|
||||||
template_count = (
|
|
||||||
db.query(TestTemplate)
|
|
||||||
.filter(TestTemplate.source == ds.name)
|
|
||||||
.count()
|
|
||||||
)
|
|
||||||
elif ds.type == "detection_rule":
|
|
||||||
rule_count = (
|
|
||||||
db.query(DetectionRule)
|
|
||||||
.filter(DetectionRule.source == ds.name)
|
|
||||||
.count()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": str(ds.id),
|
|
||||||
"name": ds.name,
|
|
||||||
"display_name": ds.display_name,
|
|
||||||
"type": ds.type,
|
|
||||||
"is_enabled": ds.is_enabled,
|
|
||||||
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
|
|
||||||
"last_sync_status": ds.last_sync_status,
|
|
||||||
"last_sync_stats": ds.last_sync_stats,
|
|
||||||
"total_templates": template_count,
|
|
||||||
"total_rules": rule_count,
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,20 +2,24 @@
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import require_role
|
from app.dependencies.auth import require_role
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.user import UserCreate, UserUpdate, UserOut
|
from app.schemas.user import UserCreate, UserUpdate, UserOut
|
||||||
from app.auth import hash_password
|
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
from app.services.user_service import (
|
||||||
|
create_user,
|
||||||
|
get_user_or_raise,
|
||||||
|
list_users,
|
||||||
|
update_user,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix="/users", tags=["users"])
|
router = APIRouter(prefix="/users", tags=["users"])
|
||||||
|
|
||||||
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GET /users — list all users
|
# GET /users — list all users
|
||||||
@@ -23,12 +27,12 @@ VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewe
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=list[UserOut])
|
@router.get("", response_model=list[UserOut])
|
||||||
def list_users(
|
def list_users_route(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Return a list of all users. **Requires admin role.**"""
|
"""Return a list of all users. **Requires admin role.**"""
|
||||||
return db.query(User).order_by(User.username).all()
|
return list_users(db)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -37,36 +41,21 @@ def list_users(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
||||||
def create_user(
|
def create_user_route(
|
||||||
payload: UserCreate,
|
payload: UserCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Create a new user. **Requires admin role.**"""
|
"""Create a new user. **Requires admin role.**"""
|
||||||
|
user = create_user(
|
||||||
# Check if username already exists
|
db,
|
||||||
existing = db.query(User).filter(User.username == payload.username).first()
|
|
||||||
if existing:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail=f"Username '{payload.username}' already exists",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate role
|
|
||||||
if payload.role not in VALID_ROLES:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Invalid role '{payload.role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
|
|
||||||
)
|
|
||||||
|
|
||||||
user = User(
|
|
||||||
username=payload.username,
|
username=payload.username,
|
||||||
email=payload.email,
|
email=payload.email,
|
||||||
hashed_password=hash_password(payload.password),
|
password=payload.password,
|
||||||
role=payload.role,
|
role=payload.role,
|
||||||
)
|
)
|
||||||
db.add(user)
|
with UnitOfWork(db) as uow:
|
||||||
db.commit()
|
uow.commit()
|
||||||
db.refresh(user)
|
db.refresh(user)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -93,13 +82,7 @@ def get_user(
|
|||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Return a single user by ID. **Requires admin role.**"""
|
"""Return a single user by ID. **Requires admin role.**"""
|
||||||
user = db.query(User).filter(User.id == user_id).first()
|
return get_user_or_raise(db, user_id)
|
||||||
if user is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="User not found",
|
|
||||||
)
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -108,37 +91,17 @@ def get_user(
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{user_id}", response_model=UserOut)
|
@router.patch("/{user_id}", response_model=UserOut)
|
||||||
def update_user(
|
def update_user_route(
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
payload: UserUpdate,
|
payload: UserUpdate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Update one or more fields of an existing user. **Requires admin role.**"""
|
"""Update one or more fields of an existing user. **Requires admin role.**"""
|
||||||
user = db.query(User).filter(User.id == user_id).first()
|
|
||||||
if user is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="User not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
|
user = update_user(db, user_id, **update_data)
|
||||||
# Validate role if being updated
|
with UnitOfWork(db) as uow:
|
||||||
if "role" in update_data and update_data["role"] not in VALID_ROLES:
|
uow.commit()
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hash password if being updated
|
|
||||||
if "password" in update_data:
|
|
||||||
update_data["hashed_password"] = hash_password(update_data.pop("password"))
|
|
||||||
|
|
||||||
for field, value in update_data.items():
|
|
||||||
setattr(user, field, value)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
db.refresh(user)
|
db.refresh(user)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -147,7 +110,7 @@ def update_user(
|
|||||||
action="update_user",
|
action="update_user",
|
||||||
entity_type="user",
|
entity_type="user",
|
||||||
entity_id=user.id,
|
entity_id=user.id,
|
||||||
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
|
details={"updated_fields": list(update_data.keys())},
|
||||||
)
|
)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|||||||
93
backend/app/services/audit_query_service.py
Normal file
93
backend/app/services/audit_query_service.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""Audit log query service — framework-agnostic query logic for audit logs.
|
||||||
|
|
||||||
|
Provides paginated logs and distinct action/entity-type lists.
|
||||||
|
No FastAPI imports.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from app.models.audit import AuditLog
|
||||||
|
|
||||||
|
|
||||||
|
def list_logs(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
|
action: str | None = None,
|
||||||
|
entity_type: str | None = None,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> dict:
|
||||||
|
"""Return paginated audit logs with optional filters.
|
||||||
|
|
||||||
|
Returns a dict with keys: items, total, offset, limit.
|
||||||
|
Each item is a dict with: id, user_id, username, action, entity_type,
|
||||||
|
entity_id, timestamp, details.
|
||||||
|
"""
|
||||||
|
query = db.query(AuditLog).options(joinedload(AuditLog.user))
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
query = query.filter(AuditLog.user_id == user_id)
|
||||||
|
if action:
|
||||||
|
query = query.filter(AuditLog.action == action)
|
||||||
|
if entity_type:
|
||||||
|
query = query.filter(AuditLog.entity_type == entity_type)
|
||||||
|
if start_date:
|
||||||
|
query = query.filter(AuditLog.timestamp >= start_date)
|
||||||
|
if end_date:
|
||||||
|
query = query.filter(AuditLog.timestamp <= end_date)
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
|
||||||
|
logs = (
|
||||||
|
query
|
||||||
|
.order_by(AuditLog.timestamp.desc())
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
items = [
|
||||||
|
{
|
||||||
|
"id": log.id,
|
||||||
|
"user_id": log.user_id,
|
||||||
|
"username": log.user.username if log.user else None,
|
||||||
|
"action": log.action,
|
||||||
|
"entity_type": log.entity_type,
|
||||||
|
"entity_id": log.entity_id,
|
||||||
|
"timestamp": log.timestamp,
|
||||||
|
"details": log.details,
|
||||||
|
}
|
||||||
|
for log in logs
|
||||||
|
]
|
||||||
|
|
||||||
|
return {"items": items, "total": total, "offset": offset, "limit": limit}
|
||||||
|
|
||||||
|
|
||||||
|
def list_distinct_actions(db: Session) -> list[str]:
|
||||||
|
"""Return a list of distinct action types in the audit log."""
|
||||||
|
actions = (
|
||||||
|
db.query(AuditLog.action)
|
||||||
|
.distinct()
|
||||||
|
.order_by(AuditLog.action)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [a[0] for a in actions]
|
||||||
|
|
||||||
|
|
||||||
|
def list_distinct_entity_types(db: Session) -> list[str]:
|
||||||
|
"""Return a list of distinct entity types in the audit log."""
|
||||||
|
types = (
|
||||||
|
db.query(AuditLog.entity_type)
|
||||||
|
.filter(AuditLog.entity_type.isnot(None))
|
||||||
|
.distinct()
|
||||||
|
.order_by(AuditLog.entity_type)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [t[0] for t in types]
|
||||||
222
backend/app/services/data_source_service.py
Normal file
222
backend/app/services/data_source_service.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""Data source management service — framework-agnostic query and sync logic.
|
||||||
|
|
||||||
|
Provides list, update, sync, and stats. Sync operations commit internally
|
||||||
|
since they are long-running and self-contained.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
|
||||||
|
from app.models.data_source import DataSource
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_sync_handler(source_name: str):
|
||||||
|
"""Lazily import and return the sync function for *source_name*.
|
||||||
|
|
||||||
|
We import lazily to avoid circular imports and to only load the
|
||||||
|
modules that are actually needed.
|
||||||
|
"""
|
||||||
|
handlers = {
|
||||||
|
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
|
||||||
|
"sigma": ("app.services.sigma_import_service", "sync"),
|
||||||
|
"lolbas": ("app.services.lolbas_import_service", "sync"),
|
||||||
|
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
|
||||||
|
"caldera": ("app.services.caldera_import_service", "sync"),
|
||||||
|
"elastic_rules": ("app.services.elastic_import_service", "sync"),
|
||||||
|
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
|
||||||
|
"d3fend": ("app.services.d3fend_import_service", "sync"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if source_name not in handlers:
|
||||||
|
return None
|
||||||
|
|
||||||
|
module_path, func_name = handlers[source_name]
|
||||||
|
mod = importlib.import_module(module_path)
|
||||||
|
return getattr(mod, func_name)
|
||||||
|
|
||||||
|
|
||||||
|
def list_sources(db: Session) -> list[dict]:
|
||||||
|
"""Return all registered data sources as a list of dicts."""
|
||||||
|
sources = db.query(DataSource).order_by(DataSource.name).all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": str(s.id),
|
||||||
|
"name": s.name,
|
||||||
|
"display_name": s.display_name,
|
||||||
|
"type": s.type,
|
||||||
|
"url": s.url,
|
||||||
|
"description": s.description,
|
||||||
|
"is_enabled": s.is_enabled,
|
||||||
|
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
|
||||||
|
"last_sync_status": s.last_sync_status,
|
||||||
|
"last_sync_stats": s.last_sync_stats,
|
||||||
|
"sync_frequency": s.sync_frequency,
|
||||||
|
"config": s.config,
|
||||||
|
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||||
|
}
|
||||||
|
for s in sources
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def update_source(db: Session, source_id: str, **fields: object) -> None:
|
||||||
|
"""Update a data source's fields (is_enabled, sync_frequency, config).
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if source does not exist.
|
||||||
|
Does not commit; the router handles that.
|
||||||
|
"""
|
||||||
|
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||||
|
if not ds:
|
||||||
|
raise EntityNotFoundError("Data source", source_id)
|
||||||
|
|
||||||
|
if "is_enabled" in fields:
|
||||||
|
ds.is_enabled = fields["is_enabled"]
|
||||||
|
if "sync_frequency" in fields:
|
||||||
|
ds.sync_frequency = fields["sync_frequency"]
|
||||||
|
if "config" in fields:
|
||||||
|
ds.config = fields["config"]
|
||||||
|
|
||||||
|
|
||||||
|
def sync_source(db: Session, source_id: str) -> dict:
|
||||||
|
"""Trigger sync for a specific data source.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if source does not exist.
|
||||||
|
Raises BusinessRuleViolation if no sync handler is available.
|
||||||
|
Commits internally (long-running, self-contained operation).
|
||||||
|
Returns dict with message, source, stats.
|
||||||
|
"""
|
||||||
|
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||||
|
if not ds:
|
||||||
|
raise EntityNotFoundError("Data source", source_id)
|
||||||
|
|
||||||
|
handler = _get_sync_handler(ds.name)
|
||||||
|
if handler is None:
|
||||||
|
raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'")
|
||||||
|
|
||||||
|
ds.last_sync_status = "in_progress"
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
summary = handler(db)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||||
|
ds.last_sync_status = "error"
|
||||||
|
ds.last_sync_at = datetime.utcnow()
|
||||||
|
ds.last_sync_stats = {"error": str(exc)}
|
||||||
|
db.commit()
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Sync failed for '{ds.display_name}'. Check server logs for details."
|
||||||
|
)
|
||||||
|
|
||||||
|
ds.last_sync_at = datetime.utcnow()
|
||||||
|
ds.last_sync_status = "success"
|
||||||
|
ds.last_sync_stats = summary
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": f"Sync complete for {ds.display_name}",
|
||||||
|
"source": ds.name,
|
||||||
|
"stats": summary,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def sync_all_sources(db: Session) -> list[dict]:
|
||||||
|
"""Trigger sync for all enabled data sources (sequentially).
|
||||||
|
|
||||||
|
Commits internally (long-running, self-contained operation).
|
||||||
|
Returns list of result dicts with source, status, stats/detail.
|
||||||
|
"""
|
||||||
|
enabled_sources = (
|
||||||
|
db.query(DataSource)
|
||||||
|
.filter(DataSource.is_enabled == True)
|
||||||
|
.order_by(DataSource.name)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for ds in enabled_sources:
|
||||||
|
handler = _get_sync_handler(ds.name)
|
||||||
|
if handler is None:
|
||||||
|
results.append({
|
||||||
|
"source": ds.name,
|
||||||
|
"status": "skipped",
|
||||||
|
"detail": "No sync handler available",
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
ds.last_sync_status = "in_progress"
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
summary = handler(db)
|
||||||
|
ds.last_sync_at = datetime.utcnow()
|
||||||
|
ds.last_sync_status = "success"
|
||||||
|
ds.last_sync_stats = summary
|
||||||
|
db.commit()
|
||||||
|
results.append({
|
||||||
|
"source": ds.name,
|
||||||
|
"status": "success",
|
||||||
|
"stats": summary,
|
||||||
|
})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||||
|
ds.last_sync_status = "error"
|
||||||
|
ds.last_sync_at = datetime.utcnow()
|
||||||
|
ds.last_sync_stats = {"error": str(exc)}
|
||||||
|
db.commit()
|
||||||
|
results.append({
|
||||||
|
"source": ds.name,
|
||||||
|
"status": "error",
|
||||||
|
"detail": "Sync failed. Check server logs for details.",
|
||||||
|
})
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def get_source_stats(db: Session, source_id: str) -> dict:
|
||||||
|
"""Return detailed statistics for a data source.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if source does not exist.
|
||||||
|
"""
|
||||||
|
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||||
|
if not ds:
|
||||||
|
raise EntityNotFoundError("Data source", source_id)
|
||||||
|
|
||||||
|
from app.models.test_template import TestTemplate
|
||||||
|
from app.models.detection_rule import DetectionRule
|
||||||
|
|
||||||
|
template_count = 0
|
||||||
|
rule_count = 0
|
||||||
|
|
||||||
|
if ds.type == "attack_procedure":
|
||||||
|
template_count = (
|
||||||
|
db.query(TestTemplate)
|
||||||
|
.filter(TestTemplate.source == ds.name)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
elif ds.type == "detection_rule":
|
||||||
|
rule_count = (
|
||||||
|
db.query(DetectionRule)
|
||||||
|
.filter(DetectionRule.source == ds.name)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(ds.id),
|
||||||
|
"name": ds.name,
|
||||||
|
"display_name": ds.display_name,
|
||||||
|
"type": ds.type,
|
||||||
|
"is_enabled": ds.is_enabled,
|
||||||
|
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
|
||||||
|
"last_sync_status": ds.last_sync_status,
|
||||||
|
"last_sync_stats": ds.last_sync_stats,
|
||||||
|
"total_templates": template_count,
|
||||||
|
"total_rules": rule_count,
|
||||||
|
}
|
||||||
88
backend/app/services/user_service.py
Normal file
88
backend/app/services/user_service.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""User management service — framework-agnostic CRUD for users.
|
||||||
|
|
||||||
|
Uses domain exceptions from app.domain.errors. The router handles
|
||||||
|
HTTP concerns, auth, audit logging, and commit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth import hash_password
|
||||||
|
from app.domain.errors import BusinessRuleViolation, DuplicateEntityError, EntityNotFoundError
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
|
||||||
|
|
||||||
|
|
||||||
|
def list_users(db: Session) -> list[User]:
|
||||||
|
"""Return a list of all users ordered by username."""
|
||||||
|
return db.query(User).order_by(User.username).all()
|
||||||
|
|
||||||
|
|
||||||
|
def create_user(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
username: str,
|
||||||
|
email: str | None,
|
||||||
|
password: str,
|
||||||
|
role: str,
|
||||||
|
) -> User:
|
||||||
|
"""Create a new user.
|
||||||
|
|
||||||
|
Raises DuplicateEntityError if username already exists.
|
||||||
|
Raises BusinessRuleViolation if role is invalid.
|
||||||
|
Does not commit; the router handles that.
|
||||||
|
"""
|
||||||
|
existing = db.query(User).filter(User.username == username).first()
|
||||||
|
if existing:
|
||||||
|
raise DuplicateEntityError("User", "username", username)
|
||||||
|
|
||||||
|
if role not in VALID_ROLES:
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Invalid role '{role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username=username,
|
||||||
|
email=email,
|
||||||
|
hashed_password=hash_password(password),
|
||||||
|
role=role,
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_or_raise(db: Session, user_id: uuid.UUID) -> User:
|
||||||
|
"""Return a user by ID or raise EntityNotFoundError."""
|
||||||
|
user = db.query(User).filter(User.id == user_id).first()
|
||||||
|
if user is None:
|
||||||
|
raise EntityNotFoundError("User", str(user_id))
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def update_user(db: Session, user_id: uuid.UUID, **fields: object) -> User:
|
||||||
|
"""Update one or more fields of an existing user.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if user does not exist.
|
||||||
|
Raises BusinessRuleViolation if role is invalid.
|
||||||
|
Handles 'password' by hashing and storing as 'hashed_password'.
|
||||||
|
Does not commit; the router handles that.
|
||||||
|
"""
|
||||||
|
user = get_user_or_raise(db, user_id)
|
||||||
|
update_data = dict(fields)
|
||||||
|
|
||||||
|
if "role" in update_data and update_data["role"] not in VALID_ROLES:
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if "password" in update_data:
|
||||||
|
update_data["hashed_password"] = hash_password(str(update_data.pop("password")))
|
||||||
|
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(user, field, value)
|
||||||
|
|
||||||
|
return user
|
||||||
175
backend/tests/test_campaign_entity.py
Normal file
175
backend/tests/test_campaign_entity.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""Tests for CampaignEntity — pure domain logic, no DB."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
if backend_dir not in sys.path:
|
||||||
|
sys.path.insert(0, backend_dir)
|
||||||
|
|
||||||
|
from app.domain.entities.campaign import (
|
||||||
|
CampaignEntity,
|
||||||
|
CampaignStatus,
|
||||||
|
CampaignType,
|
||||||
|
)
|
||||||
|
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _entity(status: str = "draft", test_count: int = 0, **overrides) -> CampaignEntity:
|
||||||
|
defaults = dict(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Campaign",
|
||||||
|
type=CampaignType.custom,
|
||||||
|
status=CampaignStatus(status),
|
||||||
|
description=None,
|
||||||
|
threat_actor_id=None,
|
||||||
|
created_by=None,
|
||||||
|
target_platform=None,
|
||||||
|
tags=[],
|
||||||
|
test_count=test_count,
|
||||||
|
)
|
||||||
|
defaults.update(overrides)
|
||||||
|
return CampaignEntity(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_orm(status: str = "draft", test_count: int = 0, **overrides) -> MagicMock:
|
||||||
|
m = MagicMock()
|
||||||
|
m.id = uuid.uuid4()
|
||||||
|
m.name = "Test Campaign"
|
||||||
|
m.type = "custom"
|
||||||
|
m.status = status
|
||||||
|
m.description = None
|
||||||
|
m.threat_actor_id = None
|
||||||
|
m.created_by = None
|
||||||
|
m.target_platform = None
|
||||||
|
m.tags = []
|
||||||
|
m.campaign_tests = [MagicMock()] * test_count if test_count else []
|
||||||
|
for k, v in overrides.items():
|
||||||
|
setattr(m, k, v)
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
# ── 1. Test activation from draft with tests → success ───────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_activate_from_draft_with_tests_success():
|
||||||
|
e = _entity("draft", test_count=1)
|
||||||
|
e.activate()
|
||||||
|
assert e.status == CampaignStatus.active
|
||||||
|
|
||||||
|
|
||||||
|
def test_activate_from_draft_with_multiple_tests_success():
|
||||||
|
e = _entity("draft", test_count=3)
|
||||||
|
e.activate()
|
||||||
|
assert e.status == CampaignStatus.active
|
||||||
|
|
||||||
|
|
||||||
|
# ── 2. Test activation from draft with 0 tests → BusinessRuleViolation ───
|
||||||
|
|
||||||
|
|
||||||
|
def test_activate_from_draft_with_zero_tests_raises():
|
||||||
|
e = _entity("draft", test_count=0)
|
||||||
|
with pytest.raises(BusinessRuleViolation, match="at least one test"):
|
||||||
|
e.activate()
|
||||||
|
assert e.status == CampaignStatus.draft
|
||||||
|
|
||||||
|
|
||||||
|
# ── 3. Test activation from active → InvalidStateTransition ────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_activate_from_active_raises():
|
||||||
|
e = _entity("active", test_count=2)
|
||||||
|
with pytest.raises(InvalidStateTransition) as exc_info:
|
||||||
|
e.activate()
|
||||||
|
assert exc_info.value.current_state == "active"
|
||||||
|
assert exc_info.value.target_state == "active"
|
||||||
|
assert "completed" in exc_info.value.valid_transitions
|
||||||
|
|
||||||
|
|
||||||
|
# ── 4. Test complete from active → success ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_complete_from_active_success():
|
||||||
|
e = _entity("active", test_count=2)
|
||||||
|
e.complete()
|
||||||
|
assert e.status == CampaignStatus.completed
|
||||||
|
|
||||||
|
|
||||||
|
# ── 5. Test complete from draft → InvalidStateTransition ────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_complete_from_draft_raises():
|
||||||
|
e = _entity("draft", test_count=1)
|
||||||
|
with pytest.raises(InvalidStateTransition) as exc_info:
|
||||||
|
e.complete()
|
||||||
|
assert exc_info.value.current_state == "draft"
|
||||||
|
assert exc_info.value.target_state == "completed"
|
||||||
|
assert "active" in exc_info.value.valid_transitions
|
||||||
|
|
||||||
|
|
||||||
|
# ── 6. Test ensure_modifiable in draft/active → ok ──────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_modifiable_draft_ok():
|
||||||
|
e = _entity("draft")
|
||||||
|
e.ensure_modifiable() # no raise
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_modifiable_active_ok():
|
||||||
|
e = _entity("active", test_count=1)
|
||||||
|
e.ensure_modifiable() # no raise
|
||||||
|
|
||||||
|
|
||||||
|
# ── 7. Test ensure_modifiable in completed → BusinessRuleViolation ──────
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_modifiable_completed_raises():
|
||||||
|
e = _entity("completed", test_count=1)
|
||||||
|
with pytest.raises(BusinessRuleViolation, match="Cannot modify"):
|
||||||
|
e.ensure_modifiable()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_modifiable_archived_raises():
|
||||||
|
e = _entity("archived", test_count=1)
|
||||||
|
with pytest.raises(BusinessRuleViolation, match="Cannot modify"):
|
||||||
|
e.ensure_modifiable()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 8. Test from_orm conversion ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_orm_basic():
|
||||||
|
orm = _fake_orm("draft", test_count=0)
|
||||||
|
e = CampaignEntity.from_orm(orm)
|
||||||
|
assert e.name == "Test Campaign"
|
||||||
|
assert e.type == CampaignType.custom
|
||||||
|
assert e.status == CampaignStatus.draft
|
||||||
|
assert e.id == orm.id
|
||||||
|
assert e.test_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_orm_with_tests():
|
||||||
|
orm = _fake_orm("draft", test_count=3)
|
||||||
|
e = CampaignEntity.from_orm(orm)
|
||||||
|
assert e.test_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_orm_coerces_type_and_status():
|
||||||
|
orm = _fake_orm(status="active", type="apt_emulation", test_count=1)
|
||||||
|
e = CampaignEntity.from_orm(orm)
|
||||||
|
assert e.status == CampaignStatus.active
|
||||||
|
assert e.type == CampaignType.apt_emulation
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_orm_handles_none_tags():
|
||||||
|
orm = _fake_orm("draft", test_count=0)
|
||||||
|
orm.tags = None
|
||||||
|
e = CampaignEntity.from_orm(orm)
|
||||||
|
assert e.tags == []
|
||||||
105
backend/tests/test_compliance_entity.py
Normal file
105
backend/tests/test_compliance_entity.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""Tests for compliance domain entities."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.domain.entities.compliance import (
|
||||||
|
ComplianceControlEntity,
|
||||||
|
ComplianceFrameworkEntity,
|
||||||
|
ControlCoverageStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Control coverage status ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_all_techniques_validated_covered():
|
||||||
|
"""All techniques validated → covered."""
|
||||||
|
control = ComplianceControlEntity(
|
||||||
|
control_id="AC-2",
|
||||||
|
title="Account Management",
|
||||||
|
technique_statuses=["validated", "validated"],
|
||||||
|
)
|
||||||
|
assert control.coverage_status == ControlCoverageStatus.covered
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_all_techniques_partial_covered():
|
||||||
|
"""All techniques partial → covered."""
|
||||||
|
control = ComplianceControlEntity(
|
||||||
|
control_id="AC-2",
|
||||||
|
title="Account Management",
|
||||||
|
technique_statuses=["partial"],
|
||||||
|
)
|
||||||
|
assert control.coverage_status == ControlCoverageStatus.covered
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_mixed_statuses_partially_covered():
|
||||||
|
"""Mixed statuses (some validated/partial, some not) → partially_covered."""
|
||||||
|
control = ComplianceControlEntity(
|
||||||
|
control_id="AC-2",
|
||||||
|
title="Account Management",
|
||||||
|
technique_statuses=["validated", "not_evaluated"],
|
||||||
|
)
|
||||||
|
assert control.coverage_status == ControlCoverageStatus.partially_covered
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_no_validated_techniques_not_covered():
|
||||||
|
"""No validated/partial techniques → not_covered."""
|
||||||
|
control = ComplianceControlEntity(
|
||||||
|
control_id="AC-2",
|
||||||
|
title="Account Management",
|
||||||
|
technique_statuses=["not_evaluated", "not_covered"],
|
||||||
|
)
|
||||||
|
assert control.coverage_status == ControlCoverageStatus.not_covered
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_empty_techniques_not_covered():
|
||||||
|
"""Empty technique_statuses → not_covered."""
|
||||||
|
control = ComplianceControlEntity(
|
||||||
|
control_id="AC-2",
|
||||||
|
title="Account Management",
|
||||||
|
technique_statuses=[],
|
||||||
|
)
|
||||||
|
assert control.coverage_status == ControlCoverageStatus.not_covered
|
||||||
|
|
||||||
|
|
||||||
|
# ── Framework coverage ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_framework_coverage_pct_calculation():
|
||||||
|
"""Framework coverage_pct = (covered_controls / total_controls) * 100."""
|
||||||
|
controls = [
|
||||||
|
ComplianceControlEntity("AC-1", "Title 1", technique_statuses=["validated"]),
|
||||||
|
ComplianceControlEntity("AC-2", "Title 2", technique_statuses=["not_evaluated"]),
|
||||||
|
ComplianceControlEntity("AC-3", "Title 3", technique_statuses=["validated", "partial"]),
|
||||||
|
ComplianceControlEntity("AC-4", "Title 4", technique_statuses=["partial"]),
|
||||||
|
ComplianceControlEntity("AC-5", "Title 5", technique_statuses=[]),
|
||||||
|
]
|
||||||
|
framework = ComplianceFrameworkEntity(name="NIST 800-53", controls=controls)
|
||||||
|
# AC-1: covered, AC-2: not_covered, AC-3: covered, AC-4: covered, AC-5: not_covered
|
||||||
|
assert framework.total_controls == 5
|
||||||
|
assert framework.covered_controls == 3
|
||||||
|
assert framework.coverage_pct == 60.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_framework_get_gap_controls():
|
||||||
|
"""get_gap_controls returns only uncovered and partially_covered controls."""
|
||||||
|
controls = [
|
||||||
|
ComplianceControlEntity("AC-1", "Covered", technique_statuses=["validated"]),
|
||||||
|
ComplianceControlEntity("AC-2", "Partial", technique_statuses=["validated", "not_evaluated"]),
|
||||||
|
ComplianceControlEntity("AC-3", "Not Covered", technique_statuses=["not_evaluated"]),
|
||||||
|
ComplianceControlEntity("AC-4", "Empty", technique_statuses=[]),
|
||||||
|
]
|
||||||
|
framework = ComplianceFrameworkEntity(name="Test", controls=controls)
|
||||||
|
gaps = framework.get_gap_controls()
|
||||||
|
assert len(gaps) == 3
|
||||||
|
assert gaps[0].control_id == "AC-2"
|
||||||
|
assert gaps[1].control_id == "AC-3"
|
||||||
|
assert gaps[2].control_id == "AC-4"
|
||||||
|
|
||||||
|
|
||||||
|
def test_framework_no_controls_coverage_pct_zero():
|
||||||
|
"""Framework with no controls → coverage_pct is 0."""
|
||||||
|
framework = ComplianceFrameworkEntity(name="Empty", controls=[])
|
||||||
|
assert framework.total_controls == 0
|
||||||
|
assert framework.covered_controls == 0
|
||||||
|
assert framework.coverage_pct == 0.0
|
||||||
Reference in New Issue
Block a user