From c0c6cda11df2f995e9f9bf377d2d17277a3af5b2 Mon Sep 17 00:00:00 2001 From: Kitos Date: Fri, 20 Feb 2026 13:28:14 +0100 Subject: [PATCH] feat: add Campaign/Compliance domain entities and extract users/audit/data_sources to services (LP-2 through LP-6) --- backend/app/domain/entities/__init__.py | 14 +- backend/app/domain/entities/campaign.py | 103 +++++++++ backend/app/domain/entities/compliance.py | 71 +++++++ backend/app/routers/audit.py | 89 +++----- backend/app/routers/data_sources.py | 209 ++---------------- backend/app/routers/users.py | 89 +++----- backend/app/services/audit_query_service.py | 93 ++++++++ backend/app/services/data_source_service.py | 222 ++++++++++++++++++++ backend/app/services/user_service.py | 88 ++++++++ backend/tests/test_campaign_entity.py | 175 +++++++++++++++ backend/tests/test_compliance_entity.py | 105 +++++++++ 11 files changed, 939 insertions(+), 319 deletions(-) create mode 100644 backend/app/domain/entities/campaign.py create mode 100644 backend/app/domain/entities/compliance.py create mode 100644 backend/app/services/audit_query_service.py create mode 100644 backend/app/services/data_source_service.py create mode 100644 backend/app/services/user_service.py create mode 100644 backend/tests/test_campaign_entity.py create mode 100644 backend/tests/test_compliance_entity.py diff --git a/backend/app/domain/entities/__init__.py b/backend/app/domain/entities/__init__.py index 79c7772..63969bf 100644 --- a/backend/app/domain/entities/__init__.py +++ b/backend/app/domain/entities/__init__.py @@ -1,3 +1,15 @@ +from app.domain.entities.campaign import CampaignEntity +from app.domain.entities.compliance import ( + ComplianceControlEntity, + ComplianceFrameworkEntity, + ControlCoverageStatus, +) from app.domain.entities.technique import TechniqueEntity -__all__ = ["TechniqueEntity"] +__all__ = [ + "CampaignEntity", + "ComplianceControlEntity", + "ComplianceFrameworkEntity", + "ControlCoverageStatus", + "TechniqueEntity", +] diff --git a/backend/app/domain/entities/campaign.py b/backend/app/domain/entities/campaign.py new file mode 100644 index 0000000..02c1487 --- /dev/null +++ b/backend/app/domain/entities/campaign.py @@ -0,0 +1,103 @@ +"""Campaign domain entity with lifecycle validation. + +Pure domain logic — no framework imports. +""" + +from __future__ import annotations + +import enum +import uuid +from dataclasses import dataclass, field +from typing import Any + +from app.domain.errors import BusinessRuleViolation, InvalidStateTransition + + +class CampaignStatus(str, enum.Enum): + draft = "draft" + active = "active" + completed = "completed" + archived = "archived" + + +class CampaignType(str, enum.Enum): + custom = "custom" + apt_emulation = "apt_emulation" + kill_chain = "kill_chain" + compliance = "compliance" + + +VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = { + CampaignStatus.draft: [CampaignStatus.active], + CampaignStatus.active: [CampaignStatus.completed], + CampaignStatus.completed: [CampaignStatus.archived], + CampaignStatus.archived: [], +} + + +@dataclass +class CampaignEntity: + name: str + type: CampaignType = CampaignType.custom + status: CampaignStatus = CampaignStatus.draft + id: uuid.UUID | None = None + description: str | None = None + threat_actor_id: uuid.UUID | None = None + created_by: uuid.UUID | None = None + target_platform: str | None = None + tags: list[str] = field(default_factory=list) + test_count: int = 0 + + def can_transition_to(self, target: CampaignStatus) -> bool: + return target in VALID_TRANSITIONS.get(self.status, []) + + def activate(self) -> None: + if not self.can_transition_to(CampaignStatus.active): + raise InvalidStateTransition( + self.status.value, CampaignStatus.active.value, + [s.value for s in VALID_TRANSITIONS[self.status]], + ) + if self.test_count == 0: + raise BusinessRuleViolation( + "Campaign must have at least one test to activate" + ) + self.status = CampaignStatus.active + + def complete(self) -> None: + if not self.can_transition_to(CampaignStatus.completed): + raise InvalidStateTransition( + self.status.value, CampaignStatus.completed.value, + [s.value for s in VALID_TRANSITIONS[self.status]], + ) + self.status = CampaignStatus.completed + + def archive(self) -> None: + if not self.can_transition_to(CampaignStatus.archived): + raise InvalidStateTransition( + self.status.value, CampaignStatus.archived.value, + [s.value for s in VALID_TRANSITIONS[self.status]], + ) + self.status = CampaignStatus.archived + + def ensure_modifiable(self) -> None: + if self.status not in (CampaignStatus.draft, CampaignStatus.active): + raise BusinessRuleViolation( + f"Cannot modify campaign in '{self.status.value}' state" + ) + + @classmethod + def from_orm(cls, orm: Any) -> CampaignEntity: + """Build a CampaignEntity from a SQLAlchemy Campaign model.""" + test_count = len(getattr(orm, "campaign_tests", None) or []) + return cls( + id=orm.id, + name=orm.name, + type=CampaignType(orm.type) if orm.type else CampaignType.custom, + status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft, + description=orm.description, + threat_actor_id=orm.threat_actor_id, + created_by=orm.created_by, + target_platform=orm.target_platform, + tags=orm.tags or [], + test_count=test_count, + ) diff --git a/backend/app/domain/entities/compliance.py b/backend/app/domain/entities/compliance.py new file mode 100644 index 0000000..549eb6b --- /dev/null +++ b/backend/app/domain/entities/compliance.py @@ -0,0 +1,71 @@ +"""Compliance domain entities with coverage calculation logic. + +Pure domain logic — no framework imports. +""" + +from __future__ import annotations + +import enum +import uuid +from dataclasses import dataclass, field + + +class ControlCoverageStatus(str, enum.Enum): + covered = "covered" + partially_covered = "partially_covered" + not_covered = "not_covered" + + +@dataclass +class ComplianceControlEntity: + control_id: str + title: str + id: uuid.UUID | None = None + description: str | None = None + category: str | None = None + technique_statuses: list[str] = field(default_factory=list) + + @property + def coverage_status(self) -> ControlCoverageStatus: + if not self.technique_statuses: + return ControlCoverageStatus.not_covered + covered_statuses = {"validated", "partial"} + covered = [s for s in self.technique_statuses if s in covered_statuses] + if len(covered) == len(self.technique_statuses): + return ControlCoverageStatus.covered + elif len(covered) > 0: + return ControlCoverageStatus.partially_covered + return ControlCoverageStatus.not_covered + + +@dataclass +class ComplianceFrameworkEntity: + name: str + id: uuid.UUID | None = None + version: str | None = None + description: str | None = None + is_active: bool = True + controls: list[ComplianceControlEntity] = field(default_factory=list) + + @property + def total_controls(self) -> int: + return len(self.controls) + + @property + def covered_controls(self) -> int: + return sum( + 1 for c in self.controls + if c.coverage_status == ControlCoverageStatus.covered + ) + + @property + def coverage_pct(self) -> float: + if self.total_controls == 0: + return 0.0 + return round(self.covered_controls / self.total_controls * 100, 1) + + def get_gap_controls(self) -> list[ComplianceControlEntity]: + return [ + c for c in self.controls + if c.coverage_status != ControlCoverageStatus.covered + ] diff --git a/backend/app/routers/audit.py b/backend/app/routers/audit.py index 033a86f..0dd257b 100644 --- a/backend/app/routers/audit.py +++ b/backend/app/routers/audit.py @@ -4,14 +4,17 @@ from datetime import datetime from typing import Optional from fastapi import APIRouter, Depends, Query -from sqlalchemy import func -from sqlalchemy.orm import Session, joinedload +from sqlalchemy.orm import Session from app.database import get_db from app.dependencies.auth import require_role -from app.models.audit import AuditLog from app.models.user import User from app.schemas.audit import AuditLogOut, AuditLogPage +from app.services.audit_query_service import ( + list_distinct_actions, + list_distinct_entity_types, + list_logs, +) router = APIRouter(prefix="/audit-logs", tags=["audit"]) @@ -29,56 +32,25 @@ def list_audit_logs( current_user: User = Depends(require_role("admin")), ): """Return paginated audit logs with optional filters. - + **Requires admin role.** """ - query = db.query(AuditLog).options(joinedload(AuditLog.user)) - - # Apply filters - if user_id: - query = query.filter(AuditLog.user_id == user_id) - if action: - query = query.filter(AuditLog.action == action) - if entity_type: - query = query.filter(AuditLog.entity_type == entity_type) - if start_date: - query = query.filter(AuditLog.timestamp >= start_date) - if end_date: - query = query.filter(AuditLog.timestamp <= end_date) - - # Get total count - total = query.count() - - # Get paginated results - logs = ( - query - .order_by(AuditLog.timestamp.desc()) - .offset(offset) - .limit(limit) - .all() - ) - - # Convert to response format with username - items = [] - for log in logs: - item = AuditLogOut( - id=log.id, - user_id=log.user_id, - username=log.user.username if log.user else None, - action=log.action, - entity_type=log.entity_type, - entity_id=log.entity_id, - timestamp=log.timestamp, - details=log.details, - ) - items.append(item) - - return AuditLogPage( - items=items, - total=total, + result = list_logs( + db, + user_id=user_id, + action=action, + entity_type=entity_type, + start_date=start_date, + end_date=end_date, offset=offset, limit=limit, ) + return AuditLogPage( + items=[AuditLogOut(**item) for item in result["items"]], + total=result["total"], + offset=result["offset"], + limit=result["limit"], + ) @router.get("/actions", response_model=list[str]) @@ -87,16 +59,10 @@ def list_actions( current_user: User = Depends(require_role("admin")), ): """Return a list of distinct action types in the audit log. - + **Requires admin role.** """ - actions = ( - db.query(AuditLog.action) - .distinct() - .order_by(AuditLog.action) - .all() - ) - return [a[0] for a in actions] + return list_distinct_actions(db) @router.get("/entity-types", response_model=list[str]) @@ -105,14 +71,7 @@ def list_entity_types( current_user: User = Depends(require_role("admin")), ): """Return a list of distinct entity types in the audit log. - + **Requires admin role.** """ - types = ( - db.query(AuditLog.entity_type) - .filter(AuditLog.entity_type.isnot(None)) - .distinct() - .order_by(AuditLog.entity_type) - .all() - ) - return [t[0] for t in types] + return list_distinct_entity_types(db) diff --git a/backend/app/routers/data_sources.py b/backend/app/routers/data_sources.py index 8809237..5ef7b57 100644 --- a/backend/app/routers/data_sources.py +++ b/backend/app/routers/data_sources.py @@ -5,19 +5,23 @@ Provides a centralized panel for managing all external data sources including sync triggers, enable/disable toggles, and statistics. """ -import logging -from datetime import datetime -from typing import Optional - -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends from pydantic import BaseModel from sqlalchemy.orm import Session +from typing import Optional from app.database import get_db from app.dependencies.auth import require_role +from app.domain.unit_of_work import UnitOfWork from app.models.user import User -from app.models.data_source import DataSource from app.services.audit_service import log_action +from app.services.data_source_service import ( + get_source_stats, + list_sources, + sync_all_sources, + sync_source, + update_source, +) # --------------------------------------------------------------------------- @@ -30,41 +34,10 @@ class DataSourceUpdate(BaseModel): sync_frequency: Optional[str] = None config: Optional[dict] = None -logger = logging.getLogger(__name__) router = APIRouter(prefix="/data-sources", tags=["data-sources"]) -# --------------------------------------------------------------------------- -# Sync dispatcher — maps source name → import function -# --------------------------------------------------------------------------- - -def _get_sync_handler(source_name: str): - """Lazily import and return the sync function for *source_name*. - - We import lazily to avoid circular imports and to only load the - modules that are actually needed. - """ - handlers = { - "atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"), - "sigma": ("app.services.sigma_import_service", "sync"), - "lolbas": ("app.services.lolbas_import_service", "sync"), - "gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"), - "caldera": ("app.services.caldera_import_service", "sync"), - "elastic_rules": ("app.services.elastic_import_service", "sync"), - "mitre_cti": ("app.services.threat_actor_import_service", "sync"), - "d3fend": ("app.services.d3fend_import_service", "sync"), - } - - if source_name not in handlers: - return None - - module_path, func_name = handlers[source_name] - import importlib - mod = importlib.import_module(module_path) - return getattr(mod, func_name) - - # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @@ -79,25 +52,7 @@ def list_data_sources( **Requires** the ``admin`` role. """ - sources = db.query(DataSource).order_by(DataSource.name).all() - return [ - { - "id": str(s.id), - "name": s.name, - "display_name": s.display_name, - "type": s.type, - "url": s.url, - "description": s.description, - "is_enabled": s.is_enabled, - "last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None, - "last_sync_status": s.last_sync_status, - "last_sync_stats": s.last_sync_stats, - "sync_frequency": s.sync_frequency, - "config": s.config, - "created_at": s.created_at.isoformat() if s.created_at else None, - } - for s in sources - ] + return list_sources(db) @router.patch("/{source_id}") @@ -111,31 +66,21 @@ def update_data_source( **Requires** the ``admin`` role. """ - ds = db.query(DataSource).filter(DataSource.id == source_id).first() - if not ds: - raise HTTPException(status_code=404, detail="Data source not found") - update_data = body.model_dump(exclude_unset=True) - - if "is_enabled" in update_data: - ds.is_enabled = update_data["is_enabled"] - if "sync_frequency" in update_data: - ds.sync_frequency = update_data["sync_frequency"] - if "config" in update_data: - ds.config = update_data["config"] - - db.commit() + update_source(db, source_id, **update_data) + with UnitOfWork(db) as uow: + uow.commit() log_action( db, user_id=current_user.id, action="update_data_source", entity_type="data_source", - entity_id=str(ds.id), + entity_id=source_id, details={"updates": update_data}, ) - return {"message": "Data source updated", "id": str(ds.id)} + return {"message": "Data source updated", "id": source_id} @router.post("/{source_id}/sync") @@ -148,46 +93,7 @@ def sync_data_source( **Requires** the ``admin`` role. """ - ds = db.query(DataSource).filter(DataSource.id == source_id).first() - if not ds: - raise HTTPException(status_code=404, detail="Data source not found") - - handler = _get_sync_handler(ds.name) - if handler is None: - raise HTTPException( - status_code=400, - detail=f"No sync handler available for '{ds.name}'", - ) - - # Mark as in_progress - ds.last_sync_status = "in_progress" - db.commit() - - try: - summary = handler(db) - except Exception as exc: - logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True) - ds.last_sync_status = "error" - ds.last_sync_at = datetime.utcnow() - ds.last_sync_stats = {"error": str(exc)} - db.commit() - raise HTTPException( - status_code=500, - detail=f"Sync failed for '{ds.display_name}'. Check server logs for details.", - ) - - # Update DS record (the handler may already have done this, - # but we ensure it here as well) - ds.last_sync_at = datetime.utcnow() - ds.last_sync_status = "success" - ds.last_sync_stats = summary - db.commit() - - return { - "message": f"Sync complete for {ds.display_name}", - "source": ds.name, - "stats": summary, - } + return sync_source(db, source_id) @router.post("/sync-all") @@ -199,49 +105,7 @@ def sync_all_data_sources( **Requires** the ``admin`` role. """ - enabled_sources = ( - db.query(DataSource) - .filter(DataSource.is_enabled == True) - .order_by(DataSource.name) - .all() - ) - - results = [] - for ds in enabled_sources: - handler = _get_sync_handler(ds.name) - if handler is None: - results.append({ - "source": ds.name, - "status": "skipped", - "detail": "No sync handler available", - }) - continue - - ds.last_sync_status = "in_progress" - db.commit() - - try: - summary = handler(db) - ds.last_sync_at = datetime.utcnow() - ds.last_sync_status = "success" - ds.last_sync_stats = summary - db.commit() - results.append({ - "source": ds.name, - "status": "success", - "stats": summary, - }) - except Exception as exc: - logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True) - ds.last_sync_status = "error" - ds.last_sync_at = datetime.utcnow() - ds.last_sync_stats = {"error": str(exc)} - db.commit() - results.append({ - "source": ds.name, - "status": "error", - "detail": "Sync failed. Check server logs for details.", - }) + results = sync_all_sources(db) log_action( db, @@ -265,39 +129,4 @@ def get_data_source_stats( **Requires** the ``admin`` role. """ - ds = db.query(DataSource).filter(DataSource.id == source_id).first() - if not ds: - raise HTTPException(status_code=404, detail="Data source not found") - - # Count items from this source - from app.models.test_template import TestTemplate - from app.models.detection_rule import DetectionRule - - template_count = 0 - rule_count = 0 - - if ds.type == "attack_procedure": - template_count = ( - db.query(TestTemplate) - .filter(TestTemplate.source == ds.name) - .count() - ) - elif ds.type == "detection_rule": - rule_count = ( - db.query(DetectionRule) - .filter(DetectionRule.source == ds.name) - .count() - ) - - return { - "id": str(ds.id), - "name": ds.name, - "display_name": ds.display_name, - "type": ds.type, - "is_enabled": ds.is_enabled, - "last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None, - "last_sync_status": ds.last_sync_status, - "last_sync_stats": ds.last_sync_stats, - "total_templates": template_count, - "total_rules": rule_count, - } + return get_source_stats(db, source_id) diff --git a/backend/app/routers/users.py b/backend/app/routers/users.py index 28dd279..15fc09c 100644 --- a/backend/app/routers/users.py +++ b/backend/app/routers/users.py @@ -2,20 +2,24 @@ import uuid -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, status from sqlalchemy.orm import Session from app.database import get_db from app.dependencies.auth import require_role +from app.domain.unit_of_work import UnitOfWork from app.models.user import User from app.schemas.user import UserCreate, UserUpdate, UserOut -from app.auth import hash_password from app.services.audit_service import log_action +from app.services.user_service import ( + create_user, + get_user_or_raise, + list_users, + update_user, +) router = APIRouter(prefix="/users", tags=["users"]) -VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"} - # --------------------------------------------------------------------------- # GET /users — list all users @@ -23,12 +27,12 @@ VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewe @router.get("", response_model=list[UserOut]) -def list_users( +def list_users_route( db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), ): """Return a list of all users. **Requires admin role.**""" - return db.query(User).order_by(User.username).all() + return list_users(db) # --------------------------------------------------------------------------- @@ -37,38 +41,23 @@ def list_users( @router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED) -def create_user( +def create_user_route( payload: UserCreate, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), ): """Create a new user. **Requires admin role.**""" - - # Check if username already exists - existing = db.query(User).filter(User.username == payload.username).first() - if existing: - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=f"Username '{payload.username}' already exists", - ) - - # Validate role - if payload.role not in VALID_ROLES: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid role '{payload.role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}", - ) - - user = User( + user = create_user( + db, username=payload.username, email=payload.email, - hashed_password=hash_password(payload.password), + password=payload.password, role=payload.role, ) - db.add(user) - db.commit() + with UnitOfWork(db) as uow: + uow.commit() db.refresh(user) - + log_action( db, user_id=current_user.id, @@ -77,7 +66,7 @@ def create_user( entity_id=user.id, details={"username": user.username, "role": user.role}, ) - + return user @@ -93,13 +82,7 @@ def get_user( current_user: User = Depends(require_role("admin")), ): """Return a single user by ID. **Requires admin role.**""" - user = db.query(User).filter(User.id == user_id).first() - if user is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found", - ) - return user + return get_user_or_raise(db, user_id) # --------------------------------------------------------------------------- @@ -108,46 +91,26 @@ def get_user( @router.patch("/{user_id}", response_model=UserOut) -def update_user( +def update_user_route( user_id: uuid.UUID, payload: UserUpdate, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), ): """Update one or more fields of an existing user. **Requires admin role.**""" - user = db.query(User).filter(User.id == user_id).first() - if user is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found", - ) - update_data = payload.model_dump(exclude_unset=True) - - # Validate role if being updated - if "role" in update_data and update_data["role"] not in VALID_ROLES: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}", - ) - - # Hash password if being updated - if "password" in update_data: - update_data["hashed_password"] = hash_password(update_data.pop("password")) - - for field, value in update_data.items(): - setattr(user, field, value) - - db.commit() + user = update_user(db, user_id, **update_data) + with UnitOfWork(db) as uow: + uow.commit() db.refresh(user) - + log_action( db, user_id=current_user.id, action="update_user", entity_type="user", entity_id=user.id, - details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())}, + details={"updated_fields": list(update_data.keys())}, ) - + return user diff --git a/backend/app/services/audit_query_service.py b/backend/app/services/audit_query_service.py new file mode 100644 index 0000000..6b3faaa --- /dev/null +++ b/backend/app/services/audit_query_service.py @@ -0,0 +1,93 @@ +"""Audit log query service — framework-agnostic query logic for audit logs. + +Provides paginated logs and distinct action/entity-type lists. +No FastAPI imports. +""" + +from __future__ import annotations + +from datetime import datetime + +from sqlalchemy.orm import Session, joinedload + +from app.models.audit import AuditLog + + +def list_logs( + db: Session, + *, + user_id: str | None = None, + action: str | None = None, + entity_type: str | None = None, + start_date: datetime | None = None, + end_date: datetime | None = None, + offset: int = 0, + limit: int = 50, +) -> dict: + """Return paginated audit logs with optional filters. + + Returns a dict with keys: items, total, offset, limit. + Each item is a dict with: id, user_id, username, action, entity_type, + entity_id, timestamp, details. + """ + query = db.query(AuditLog).options(joinedload(AuditLog.user)) + + if user_id: + query = query.filter(AuditLog.user_id == user_id) + if action: + query = query.filter(AuditLog.action == action) + if entity_type: + query = query.filter(AuditLog.entity_type == entity_type) + if start_date: + query = query.filter(AuditLog.timestamp >= start_date) + if end_date: + query = query.filter(AuditLog.timestamp <= end_date) + + total = query.count() + + logs = ( + query + .order_by(AuditLog.timestamp.desc()) + .offset(offset) + .limit(limit) + .all() + ) + + items = [ + { + "id": log.id, + "user_id": log.user_id, + "username": log.user.username if log.user else None, + "action": log.action, + "entity_type": log.entity_type, + "entity_id": log.entity_id, + "timestamp": log.timestamp, + "details": log.details, + } + for log in logs + ] + + return {"items": items, "total": total, "offset": offset, "limit": limit} + + +def list_distinct_actions(db: Session) -> list[str]: + """Return a list of distinct action types in the audit log.""" + actions = ( + db.query(AuditLog.action) + .distinct() + .order_by(AuditLog.action) + .all() + ) + return [a[0] for a in actions] + + +def list_distinct_entity_types(db: Session) -> list[str]: + """Return a list of distinct entity types in the audit log.""" + types = ( + db.query(AuditLog.entity_type) + .filter(AuditLog.entity_type.isnot(None)) + .distinct() + .order_by(AuditLog.entity_type) + .all() + ) + return [t[0] for t in types] diff --git a/backend/app/services/data_source_service.py b/backend/app/services/data_source_service.py new file mode 100644 index 0000000..6de64ef --- /dev/null +++ b/backend/app/services/data_source_service.py @@ -0,0 +1,222 @@ +"""Data source management service — framework-agnostic query and sync logic. + +Provides list, update, sync, and stats. Sync operations commit internally +since they are long-running and self-contained. +""" + +from __future__ import annotations + +import importlib +import logging +from datetime import datetime + +from sqlalchemy.orm import Session + +from app.domain.errors import BusinessRuleViolation, EntityNotFoundError +from app.models.data_source import DataSource + +logger = logging.getLogger(__name__) + + +def _get_sync_handler(source_name: str): + """Lazily import and return the sync function for *source_name*. + + We import lazily to avoid circular imports and to only load the + modules that are actually needed. + """ + handlers = { + "atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"), + "sigma": ("app.services.sigma_import_service", "sync"), + "lolbas": ("app.services.lolbas_import_service", "sync"), + "gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"), + "caldera": ("app.services.caldera_import_service", "sync"), + "elastic_rules": ("app.services.elastic_import_service", "sync"), + "mitre_cti": ("app.services.threat_actor_import_service", "sync"), + "d3fend": ("app.services.d3fend_import_service", "sync"), + } + + if source_name not in handlers: + return None + + module_path, func_name = handlers[source_name] + mod = importlib.import_module(module_path) + return getattr(mod, func_name) + + +def list_sources(db: Session) -> list[dict]: + """Return all registered data sources as a list of dicts.""" + sources = db.query(DataSource).order_by(DataSource.name).all() + return [ + { + "id": str(s.id), + "name": s.name, + "display_name": s.display_name, + "type": s.type, + "url": s.url, + "description": s.description, + "is_enabled": s.is_enabled, + "last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None, + "last_sync_status": s.last_sync_status, + "last_sync_stats": s.last_sync_stats, + "sync_frequency": s.sync_frequency, + "config": s.config, + "created_at": s.created_at.isoformat() if s.created_at else None, + } + for s in sources + ] + + +def update_source(db: Session, source_id: str, **fields: object) -> None: + """Update a data source's fields (is_enabled, sync_frequency, config). + + Raises EntityNotFoundError if source does not exist. + Does not commit; the router handles that. + """ + ds = db.query(DataSource).filter(DataSource.id == source_id).first() + if not ds: + raise EntityNotFoundError("Data source", source_id) + + if "is_enabled" in fields: + ds.is_enabled = fields["is_enabled"] + if "sync_frequency" in fields: + ds.sync_frequency = fields["sync_frequency"] + if "config" in fields: + ds.config = fields["config"] + + +def sync_source(db: Session, source_id: str) -> dict: + """Trigger sync for a specific data source. + + Raises EntityNotFoundError if source does not exist. + Raises BusinessRuleViolation if no sync handler is available. + Commits internally (long-running, self-contained operation). + Returns dict with message, source, stats. + """ + ds = db.query(DataSource).filter(DataSource.id == source_id).first() + if not ds: + raise EntityNotFoundError("Data source", source_id) + + handler = _get_sync_handler(ds.name) + if handler is None: + raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'") + + ds.last_sync_status = "in_progress" + db.commit() + + try: + summary = handler(db) + except Exception as exc: + logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True) + ds.last_sync_status = "error" + ds.last_sync_at = datetime.utcnow() + ds.last_sync_stats = {"error": str(exc)} + db.commit() + raise BusinessRuleViolation( + f"Sync failed for '{ds.display_name}'. Check server logs for details." + ) + + ds.last_sync_at = datetime.utcnow() + ds.last_sync_status = "success" + ds.last_sync_stats = summary + db.commit() + + return { + "message": f"Sync complete for {ds.display_name}", + "source": ds.name, + "stats": summary, + } + + +def sync_all_sources(db: Session) -> list[dict]: + """Trigger sync for all enabled data sources (sequentially). + + Commits internally (long-running, self-contained operation). + Returns list of result dicts with source, status, stats/detail. + """ + enabled_sources = ( + db.query(DataSource) + .filter(DataSource.is_enabled == True) + .order_by(DataSource.name) + .all() + ) + + results = [] + for ds in enabled_sources: + handler = _get_sync_handler(ds.name) + if handler is None: + results.append({ + "source": ds.name, + "status": "skipped", + "detail": "No sync handler available", + }) + continue + + ds.last_sync_status = "in_progress" + db.commit() + + try: + summary = handler(db) + ds.last_sync_at = datetime.utcnow() + ds.last_sync_status = "success" + ds.last_sync_stats = summary + db.commit() + results.append({ + "source": ds.name, + "status": "success", + "stats": summary, + }) + except Exception as exc: + logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True) + ds.last_sync_status = "error" + ds.last_sync_at = datetime.utcnow() + ds.last_sync_stats = {"error": str(exc)} + db.commit() + results.append({ + "source": ds.name, + "status": "error", + "detail": "Sync failed. Check server logs for details.", + }) + + return results + + +def get_source_stats(db: Session, source_id: str) -> dict: + """Return detailed statistics for a data source. + + Raises EntityNotFoundError if source does not exist. + """ + ds = db.query(DataSource).filter(DataSource.id == source_id).first() + if not ds: + raise EntityNotFoundError("Data source", source_id) + + from app.models.test_template import TestTemplate + from app.models.detection_rule import DetectionRule + + template_count = 0 + rule_count = 0 + + if ds.type == "attack_procedure": + template_count = ( + db.query(TestTemplate) + .filter(TestTemplate.source == ds.name) + .count() + ) + elif ds.type == "detection_rule": + rule_count = ( + db.query(DetectionRule) + .filter(DetectionRule.source == ds.name) + .count() + ) + + return { + "id": str(ds.id), + "name": ds.name, + "display_name": ds.display_name, + "type": ds.type, + "is_enabled": ds.is_enabled, + "last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None, + "last_sync_status": ds.last_sync_status, + "last_sync_stats": ds.last_sync_stats, + "total_templates": template_count, + "total_rules": rule_count, + } diff --git a/backend/app/services/user_service.py b/backend/app/services/user_service.py new file mode 100644 index 0000000..004da4b --- /dev/null +++ b/backend/app/services/user_service.py @@ -0,0 +1,88 @@ +"""User management service — framework-agnostic CRUD for users. + +Uses domain exceptions from app.domain.errors. The router handles +HTTP concerns, auth, audit logging, and commit. +""" + +from __future__ import annotations + +import uuid + +from sqlalchemy.orm import Session + +from app.auth import hash_password +from app.domain.errors import BusinessRuleViolation, DuplicateEntityError, EntityNotFoundError +from app.models.user import User + +VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"} + + +def list_users(db: Session) -> list[User]: + """Return a list of all users ordered by username.""" + return db.query(User).order_by(User.username).all() + + +def create_user( + db: Session, + *, + username: str, + email: str | None, + password: str, + role: str, +) -> User: + """Create a new user. + + Raises DuplicateEntityError if username already exists. + Raises BusinessRuleViolation if role is invalid. + Does not commit; the router handles that. + """ + existing = db.query(User).filter(User.username == username).first() + if existing: + raise DuplicateEntityError("User", "username", username) + + if role not in VALID_ROLES: + raise BusinessRuleViolation( + f"Invalid role '{role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}" + ) + + user = User( + username=username, + email=email, + hashed_password=hash_password(password), + role=role, + ) + db.add(user) + return user + + +def get_user_or_raise(db: Session, user_id: uuid.UUID) -> User: + """Return a user by ID or raise EntityNotFoundError.""" + user = db.query(User).filter(User.id == user_id).first() + if user is None: + raise EntityNotFoundError("User", str(user_id)) + return user + + +def update_user(db: Session, user_id: uuid.UUID, **fields: object) -> User: + """Update one or more fields of an existing user. + + Raises EntityNotFoundError if user does not exist. + Raises BusinessRuleViolation if role is invalid. + Handles 'password' by hashing and storing as 'hashed_password'. + Does not commit; the router handles that. + """ + user = get_user_or_raise(db, user_id) + update_data = dict(fields) + + if "role" in update_data and update_data["role"] not in VALID_ROLES: + raise BusinessRuleViolation( + f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}" + ) + + if "password" in update_data: + update_data["hashed_password"] = hash_password(str(update_data.pop("password"))) + + for field, value in update_data.items(): + setattr(user, field, value) + + return user diff --git a/backend/tests/test_campaign_entity.py b/backend/tests/test_campaign_entity.py new file mode 100644 index 0000000..3b5b9eb --- /dev/null +++ b/backend/tests/test_campaign_entity.py @@ -0,0 +1,175 @@ +"""Tests for CampaignEntity — pure domain logic, no DB.""" + +import sys +import os +import uuid +from unittest.mock import MagicMock + +import pytest + +backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if backend_dir not in sys.path: + sys.path.insert(0, backend_dir) + +from app.domain.entities.campaign import ( + CampaignEntity, + CampaignStatus, + CampaignType, +) +from app.domain.errors import BusinessRuleViolation, InvalidStateTransition + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def _entity(status: str = "draft", test_count: int = 0, **overrides) -> CampaignEntity: + defaults = dict( + id=uuid.uuid4(), + name="Test Campaign", + type=CampaignType.custom, + status=CampaignStatus(status), + description=None, + threat_actor_id=None, + created_by=None, + target_platform=None, + tags=[], + test_count=test_count, + ) + defaults.update(overrides) + return CampaignEntity(**defaults) + + +def _fake_orm(status: str = "draft", test_count: int = 0, **overrides) -> MagicMock: + m = MagicMock() + m.id = uuid.uuid4() + m.name = "Test Campaign" + m.type = "custom" + m.status = status + m.description = None + m.threat_actor_id = None + m.created_by = None + m.target_platform = None + m.tags = [] + m.campaign_tests = [MagicMock()] * test_count if test_count else [] + for k, v in overrides.items(): + setattr(m, k, v) + return m + + +# ── 1. Test activation from draft with tests → success ─────────────── + + +def test_activate_from_draft_with_tests_success(): + e = _entity("draft", test_count=1) + e.activate() + assert e.status == CampaignStatus.active + + +def test_activate_from_draft_with_multiple_tests_success(): + e = _entity("draft", test_count=3) + e.activate() + assert e.status == CampaignStatus.active + + +# ── 2. Test activation from draft with 0 tests → BusinessRuleViolation ─── + + +def test_activate_from_draft_with_zero_tests_raises(): + e = _entity("draft", test_count=0) + with pytest.raises(BusinessRuleViolation, match="at least one test"): + e.activate() + assert e.status == CampaignStatus.draft + + +# ── 3. Test activation from active → InvalidStateTransition ──────────── + + +def test_activate_from_active_raises(): + e = _entity("active", test_count=2) + with pytest.raises(InvalidStateTransition) as exc_info: + e.activate() + assert exc_info.value.current_state == "active" + assert exc_info.value.target_state == "active" + assert "completed" in exc_info.value.valid_transitions + + +# ── 4. Test complete from active → success ────────────────────────────── + + +def test_complete_from_active_success(): + e = _entity("active", test_count=2) + e.complete() + assert e.status == CampaignStatus.completed + + +# ── 5. Test complete from draft → InvalidStateTransition ──────────────── + + +def test_complete_from_draft_raises(): + e = _entity("draft", test_count=1) + with pytest.raises(InvalidStateTransition) as exc_info: + e.complete() + assert exc_info.value.current_state == "draft" + assert exc_info.value.target_state == "completed" + assert "active" in exc_info.value.valid_transitions + + +# ── 6. Test ensure_modifiable in draft/active → ok ────────────────────── + + +def test_ensure_modifiable_draft_ok(): + e = _entity("draft") + e.ensure_modifiable() # no raise + + +def test_ensure_modifiable_active_ok(): + e = _entity("active", test_count=1) + e.ensure_modifiable() # no raise + + +# ── 7. Test ensure_modifiable in completed → BusinessRuleViolation ────── + + +def test_ensure_modifiable_completed_raises(): + e = _entity("completed", test_count=1) + with pytest.raises(BusinessRuleViolation, match="Cannot modify"): + e.ensure_modifiable() + + +def test_ensure_modifiable_archived_raises(): + e = _entity("archived", test_count=1) + with pytest.raises(BusinessRuleViolation, match="Cannot modify"): + e.ensure_modifiable() + + +# ── 8. Test from_orm conversion ────────────────────────────────────────── + + +def test_from_orm_basic(): + orm = _fake_orm("draft", test_count=0) + e = CampaignEntity.from_orm(orm) + assert e.name == "Test Campaign" + assert e.type == CampaignType.custom + assert e.status == CampaignStatus.draft + assert e.id == orm.id + assert e.test_count == 0 + + +def test_from_orm_with_tests(): + orm = _fake_orm("draft", test_count=3) + e = CampaignEntity.from_orm(orm) + assert e.test_count == 3 + + +def test_from_orm_coerces_type_and_status(): + orm = _fake_orm(status="active", type="apt_emulation", test_count=1) + e = CampaignEntity.from_orm(orm) + assert e.status == CampaignStatus.active + assert e.type == CampaignType.apt_emulation + + +def test_from_orm_handles_none_tags(): + orm = _fake_orm("draft", test_count=0) + orm.tags = None + e = CampaignEntity.from_orm(orm) + assert e.tags == [] diff --git a/backend/tests/test_compliance_entity.py b/backend/tests/test_compliance_entity.py new file mode 100644 index 0000000..76cd085 --- /dev/null +++ b/backend/tests/test_compliance_entity.py @@ -0,0 +1,105 @@ +"""Tests for compliance domain entities.""" + +import pytest + +from app.domain.entities.compliance import ( + ComplianceControlEntity, + ComplianceFrameworkEntity, + ControlCoverageStatus, +) + + +# ── Control coverage status ─────────────────────────────────────────────── + + +def test_control_all_techniques_validated_covered(): + """All techniques validated → covered.""" + control = ComplianceControlEntity( + control_id="AC-2", + title="Account Management", + technique_statuses=["validated", "validated"], + ) + assert control.coverage_status == ControlCoverageStatus.covered + + +def test_control_all_techniques_partial_covered(): + """All techniques partial → covered.""" + control = ComplianceControlEntity( + control_id="AC-2", + title="Account Management", + technique_statuses=["partial"], + ) + assert control.coverage_status == ControlCoverageStatus.covered + + +def test_control_mixed_statuses_partially_covered(): + """Mixed statuses (some validated/partial, some not) → partially_covered.""" + control = ComplianceControlEntity( + control_id="AC-2", + title="Account Management", + technique_statuses=["validated", "not_evaluated"], + ) + assert control.coverage_status == ControlCoverageStatus.partially_covered + + +def test_control_no_validated_techniques_not_covered(): + """No validated/partial techniques → not_covered.""" + control = ComplianceControlEntity( + control_id="AC-2", + title="Account Management", + technique_statuses=["not_evaluated", "not_covered"], + ) + assert control.coverage_status == ControlCoverageStatus.not_covered + + +def test_control_empty_techniques_not_covered(): + """Empty technique_statuses → not_covered.""" + control = ComplianceControlEntity( + control_id="AC-2", + title="Account Management", + technique_statuses=[], + ) + assert control.coverage_status == ControlCoverageStatus.not_covered + + +# ── Framework coverage ───────────────────────────────────────────────────── + + +def test_framework_coverage_pct_calculation(): + """Framework coverage_pct = (covered_controls / total_controls) * 100.""" + controls = [ + ComplianceControlEntity("AC-1", "Title 1", technique_statuses=["validated"]), + ComplianceControlEntity("AC-2", "Title 2", technique_statuses=["not_evaluated"]), + ComplianceControlEntity("AC-3", "Title 3", technique_statuses=["validated", "partial"]), + ComplianceControlEntity("AC-4", "Title 4", technique_statuses=["partial"]), + ComplianceControlEntity("AC-5", "Title 5", technique_statuses=[]), + ] + framework = ComplianceFrameworkEntity(name="NIST 800-53", controls=controls) + # AC-1: covered, AC-2: not_covered, AC-3: covered, AC-4: covered, AC-5: not_covered + assert framework.total_controls == 5 + assert framework.covered_controls == 3 + assert framework.coverage_pct == 60.0 + + +def test_framework_get_gap_controls(): + """get_gap_controls returns only uncovered and partially_covered controls.""" + controls = [ + ComplianceControlEntity("AC-1", "Covered", technique_statuses=["validated"]), + ComplianceControlEntity("AC-2", "Partial", technique_statuses=["validated", "not_evaluated"]), + ComplianceControlEntity("AC-3", "Not Covered", technique_statuses=["not_evaluated"]), + ComplianceControlEntity("AC-4", "Empty", technique_statuses=[]), + ] + framework = ComplianceFrameworkEntity(name="Test", controls=controls) + gaps = framework.get_gap_controls() + assert len(gaps) == 3 + assert gaps[0].control_id == "AC-2" + assert gaps[1].control_id == "AC-3" + assert gaps[2].control_id == "AC-4" + + +def test_framework_no_controls_coverage_pct_zero(): + """Framework with no controls → coverage_pct is 0.""" + framework = ComplianceFrameworkEntity(name="Empty", controls=[]) + assert framework.total_controls == 0 + assert framework.covered_controls == 0 + assert framework.coverage_pct == 0.0