Compare commits

...

8 Commits

44 changed files with 2668 additions and 1252 deletions
+16 -1
View File
@@ -1,3 +1,18 @@
from app.domain.entities.campaign import CampaignEntity
from app.domain.entities.compliance import (
ComplianceControlEntity,
ComplianceFrameworkEntity,
ControlCoverageStatus,
)
from app.domain.entities.technique import TechniqueEntity
from app.domain.entities.threat_actor import ThreatActorEntity, ThreatActorTechniqueRef
__all__ = ["TechniqueEntity"]
__all__ = [
"CampaignEntity",
"ComplianceControlEntity",
"ComplianceFrameworkEntity",
"ControlCoverageStatus",
"TechniqueEntity",
"ThreatActorEntity",
"ThreatActorTechniqueRef",
]
+103
View File
@@ -0,0 +1,103 @@
"""Campaign domain entity with lifecycle validation.
Pure domain logic — no framework imports.
"""
from __future__ import annotations
import enum
import uuid
from dataclasses import dataclass, field
from typing import Any
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
class CampaignStatus(str, enum.Enum):
draft = "draft"
active = "active"
completed = "completed"
archived = "archived"
class CampaignType(str, enum.Enum):
custom = "custom"
apt_emulation = "apt_emulation"
kill_chain = "kill_chain"
compliance = "compliance"
VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
CampaignStatus.draft: [CampaignStatus.active],
CampaignStatus.active: [CampaignStatus.completed],
CampaignStatus.completed: [CampaignStatus.archived],
CampaignStatus.archived: [],
}
@dataclass
class CampaignEntity:
name: str
type: CampaignType = CampaignType.custom
status: CampaignStatus = CampaignStatus.draft
id: uuid.UUID | None = None
description: str | None = None
threat_actor_id: uuid.UUID | None = None
created_by: uuid.UUID | None = None
target_platform: str | None = None
tags: list[str] = field(default_factory=list)
test_count: int = 0
def can_transition_to(self, target: CampaignStatus) -> bool:
return target in VALID_TRANSITIONS.get(self.status, [])
def activate(self) -> None:
if not self.can_transition_to(CampaignStatus.active):
raise InvalidStateTransition(
self.status.value, CampaignStatus.active.value,
[s.value for s in VALID_TRANSITIONS[self.status]],
)
if self.test_count == 0:
raise BusinessRuleViolation(
"Campaign must have at least one test to activate"
)
self.status = CampaignStatus.active
def complete(self) -> None:
if not self.can_transition_to(CampaignStatus.completed):
raise InvalidStateTransition(
self.status.value, CampaignStatus.completed.value,
[s.value for s in VALID_TRANSITIONS[self.status]],
)
self.status = CampaignStatus.completed
def archive(self) -> None:
if not self.can_transition_to(CampaignStatus.archived):
raise InvalidStateTransition(
self.status.value, CampaignStatus.archived.value,
[s.value for s in VALID_TRANSITIONS[self.status]],
)
self.status = CampaignStatus.archived
def ensure_modifiable(self) -> None:
if self.status not in (CampaignStatus.draft, CampaignStatus.active):
raise BusinessRuleViolation(
f"Cannot modify campaign in '{self.status.value}' state"
)
@classmethod
def from_orm(cls, orm: Any) -> CampaignEntity:
"""Build a CampaignEntity from a SQLAlchemy Campaign model."""
test_count = len(getattr(orm, "campaign_tests", None) or [])
return cls(
id=orm.id,
name=orm.name,
type=CampaignType(orm.type) if orm.type else CampaignType.custom,
status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft,
description=orm.description,
threat_actor_id=orm.threat_actor_id,
created_by=orm.created_by,
target_platform=orm.target_platform,
tags=orm.tags or [],
test_count=test_count,
)
+71
View File
@@ -0,0 +1,71 @@
"""Compliance domain entities with coverage calculation logic.
Pure domain logic — no framework imports.
"""
from __future__ import annotations
import enum
import uuid
from dataclasses import dataclass, field
class ControlCoverageStatus(str, enum.Enum):
covered = "covered"
partially_covered = "partially_covered"
not_covered = "not_covered"
@dataclass
class ComplianceControlEntity:
control_id: str
title: str
id: uuid.UUID | None = None
description: str | None = None
category: str | None = None
technique_statuses: list[str] = field(default_factory=list)
@property
def coverage_status(self) -> ControlCoverageStatus:
if not self.technique_statuses:
return ControlCoverageStatus.not_covered
covered_statuses = {"validated", "partial"}
covered = [s for s in self.technique_statuses if s in covered_statuses]
if len(covered) == len(self.technique_statuses):
return ControlCoverageStatus.covered
elif len(covered) > 0:
return ControlCoverageStatus.partially_covered
return ControlCoverageStatus.not_covered
@dataclass
class ComplianceFrameworkEntity:
name: str
id: uuid.UUID | None = None
version: str | None = None
description: str | None = None
is_active: bool = True
controls: list[ComplianceControlEntity] = field(default_factory=list)
@property
def total_controls(self) -> int:
return len(self.controls)
@property
def covered_controls(self) -> int:
return sum(
1 for c in self.controls
if c.coverage_status == ControlCoverageStatus.covered
)
@property
def coverage_pct(self) -> float:
if self.total_controls == 0:
return 0.0
return round(self.covered_controls / self.total_controls * 100, 1)
def get_gap_controls(self) -> list[ComplianceControlEntity]:
return [
c for c in self.controls
if c.coverage_status != ControlCoverageStatus.covered
]
@@ -0,0 +1,96 @@
"""Threat actor domain entity with coverage analysis logic.
Pure domain logic — no framework imports.
"""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from typing import Any
@dataclass
class ThreatActorTechniqueRef:
"""Lightweight reference to a technique used by an actor."""
technique_id: uuid.UUID
mitre_id: str | None = None
name: str | None = None
status: str | None = None
usage_description: str | None = None
@dataclass
class ThreatActorEntity:
name: str
id: uuid.UUID | None = None
mitre_id: str | None = None
aliases: list[str] = field(default_factory=list)
description: str | None = None
country: str | None = None
target_sectors: list[str] = field(default_factory=list)
target_regions: list[str] = field(default_factory=list)
motivation: str | None = None
sophistication: str | None = None
first_seen: str | None = None
last_seen: str | None = None
is_active: bool = True
techniques: list[ThreatActorTechniqueRef] = field(default_factory=list)
@property
def technique_count(self) -> int:
return len(self.techniques)
@property
def covered_techniques(self) -> list[ThreatActorTechniqueRef]:
return [
t for t in self.techniques
if t.status in ("validated", "partial")
]
@property
def uncovered_techniques(self) -> list[ThreatActorTechniqueRef]:
return [
t for t in self.techniques
if t.status not in ("validated", "partial")
]
@property
def coverage_pct(self) -> float:
if not self.techniques:
return 0.0
return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
@classmethod
def from_orm(cls, orm: Any) -> ThreatActorEntity:
techs: list[ThreatActorTechniqueRef] = []
for tat in getattr(orm, "techniques", None) or []:
technique = getattr(tat, "technique", None)
techs.append(ThreatActorTechniqueRef(
technique_id=tat.technique_id,
mitre_id=getattr(technique, "mitre_id", None) if technique else None,
name=getattr(technique, "name", None) if technique else None,
status=(
technique.status_global.value
if technique and hasattr(technique.status_global, "value")
else getattr(technique, "status_global", None) if technique else None
),
usage_description=tat.usage_description,
))
return cls(
id=orm.id,
name=orm.name,
mitre_id=orm.mitre_id,
aliases=orm.aliases or [],
description=orm.description,
country=orm.country,
target_sectors=orm.target_sectors or [],
target_regions=orm.target_regions or [],
motivation=orm.motivation,
sophistication=orm.sophistication,
first_seen=orm.first_seen,
last_seen=orm.last_seen,
is_active=orm.is_active if orm.is_active is not None else True,
techniques=techs,
)
@@ -0,0 +1,88 @@
"""Port defining the common interface for data import services.
All import services (Atomic Red Team, Sigma, CALDERA, etc.) follow the
same contract: they receive a database session and return a summary dict
with import statistics.
New import sources can be added by:
1. Implementing the ``ImportService`` protocol in a new module
2. Registering the handler in the ``IMPORT_REGISTRY``
This satisfies the Open/Closed Principle — the system is open for new
import sources without modifying existing code.
"""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
from sqlalchemy.orm import Session
@runtime_checkable
class ImportService(Protocol):
"""Contract for any data-import operation.
Each implementation is a callable ``(Session) -> dict`` that
downloads, parses, and upserts records from an external source.
"""
def __call__(self, db: Session) -> dict[str, Any]: ...
class ImportServiceEntry:
"""Lazy-loading wrapper that resolves a module-level function on first call."""
__slots__ = ("_module_path", "_func_name", "_resolved")
def __init__(self, module_path: str, func_name: str) -> None:
self._module_path = module_path
self._func_name = func_name
self._resolved: ImportService | None = None
def __call__(self, db: Session) -> dict[str, Any]:
if self._resolved is None:
import importlib
mod = importlib.import_module(self._module_path)
self._resolved = getattr(mod, self._func_name)
return self._resolved(db)
@property
def source_info(self) -> str:
return f"{self._module_path}.{self._func_name}"
IMPORT_REGISTRY: dict[str, ImportServiceEntry] = {
"atomic_red_team": ImportServiceEntry(
"app.services.atomic_import_service", "import_atomic_red_team",
),
"sigma": ImportServiceEntry(
"app.services.sigma_import_service", "sync",
),
"lolbas": ImportServiceEntry(
"app.services.lolbas_import_service", "sync",
),
"gtfobins": ImportServiceEntry(
"app.services.lolbas_import_service", "sync_gtfobins",
),
"caldera": ImportServiceEntry(
"app.services.caldera_import_service", "sync",
),
"elastic_rules": ImportServiceEntry(
"app.services.elastic_import_service", "sync",
),
"mitre_cti": ImportServiceEntry(
"app.services.threat_actor_import_service", "sync",
),
"d3fend": ImportServiceEntry(
"app.services.d3fend_import_service", "sync",
),
}
def get_import_handler(source_name: str) -> ImportServiceEntry | None:
"""Look up the import handler for *source_name*.
Returns ``None`` when no handler is registered.
"""
return IMPORT_REGISTRY.get(source_name)
+10
View File
@@ -10,6 +10,16 @@ Usage in routers::
If an exception propagates, ``__exit__`` issues a rollback automatically.
Services should **never** call ``db.commit()``; they use ``db.add()`` /
``db.flush()`` to stage work and let the caller decide when to commit.
**Documented exceptions** (services that may commit internally):
- ``audit_service.log_action`` — called from 15+ routers; commits to ensure
audit records persist even when callers do not.
- Import services (atomic_import, sigma_import, etc.) — self-contained sync ops.
- Background jobs (campaign_scheduler, intel_service, stale_detection,
mitre_sync) — self-contained operations.
- Self-contained batch ops (e.g. detection_rule_service.auto_associate_rules,
snapshot_service.create_snapshot, campaign_service.generate_campaign_from_*,
osint_enrichment_service.enrich_technique_with_cves).
"""
from __future__ import annotations
+5 -139
View File
@@ -1,17 +1,12 @@
"""Advanced metrics endpoints — coverage by tactic, never-tested, avg validation time."""
from datetime import datetime
from fastapi import APIRouter, Depends
from sqlalchemy import func, case
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user
from app.models.audit import AuditLog
from app.models.technique import Technique
from app.models.test import Test
from app.models.user import User
from app.services import advanced_metrics_service
router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
@@ -22,39 +17,7 @@ def coverage_by_tactic(
user: User = Depends(get_current_user),
):
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
results = (
db.query(
Technique.tactic,
func.count(Technique.id).label("total"),
func.sum(
case((Technique.status_global == "validated", 1), else_=0)
).label("validated"),
func.sum(
case((Technique.status_global == "partial", 1), else_=0)
).label("partial"),
func.sum(
case((Technique.status_global == "not_covered", 1), else_=0)
).label("not_covered"),
func.sum(
case((Technique.status_global == "in_progress", 1), else_=0)
).label("in_progress"),
)
.group_by(Technique.tactic)
.order_by(Technique.tactic)
.all()
)
return [
{
"tactic": r[0] or "Unknown",
"total": r[1],
"validated": int(r[2]),
"partial": int(r[3]),
"not_covered": int(r[4]),
"in_progress": int(r[5]),
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
}
for r in results
]
return advanced_metrics_service.get_coverage_by_tactic(db)
@router.get("/never-tested")
@@ -63,24 +26,7 @@ def never_tested_techniques(
user: User = Depends(get_current_user),
):
"""Techniques that have never had a test created."""
tested_technique_ids = (
db.query(Test.technique_id).distinct().subquery()
)
techniques = (
db.query(Technique)
.filter(~Technique.id.in_(db.query(tested_technique_ids)))
.order_by(Technique.mitre_id)
.all()
)
return [
{
"mitre_id": t.mitre_id,
"name": t.name,
"tactic": t.tactic,
"is_subtechnique": t.is_subtechnique,
}
for t in techniques
]
return advanced_metrics_service.get_never_tested_techniques(db)
@router.get("/avg-validation-time")
@@ -92,50 +38,7 @@ def avg_validation_time(
Returns overall average and per-phase averages where data is available.
"""
validated_tests = (
db.query(Test)
.filter(Test.state == "validated")
.all()
)
if not validated_tests:
return {
"total_validated": 0,
"avg_total_hours": 0,
"avg_red_phase_hours": 0,
"avg_blue_phase_hours": 0,
}
total_durations = []
red_durations = []
blue_durations = []
for test in validated_tests:
if test.created_at and test.red_validated_at:
total_seconds = (test.red_validated_at - test.created_at).total_seconds()
total_durations.append(total_seconds)
if test.red_started_at and test.blue_started_at:
red_sec = (test.blue_started_at - test.red_started_at).total_seconds()
red_paused = test.red_paused_seconds or 0
red_durations.append(max(red_sec - red_paused, 0))
if test.blue_started_at and test.blue_validated_at:
blue_sec = (test.blue_validated_at - test.blue_started_at).total_seconds()
blue_paused = test.blue_paused_seconds or 0
blue_durations.append(max(blue_sec - blue_paused, 0))
def avg_hours(durations: list[float]) -> float:
if not durations:
return 0
return round(sum(durations) / len(durations) / 3600, 2)
return {
"total_validated": len(validated_tests),
"avg_total_hours": avg_hours(total_durations),
"avg_red_phase_hours": avg_hours(red_durations),
"avg_blue_phase_hours": avg_hours(blue_durations),
}
return advanced_metrics_service.get_avg_validation_time(db)
@router.get("/detection-rate-trend")
@@ -144,41 +47,4 @@ def detection_rate_trend(
user: User = Depends(get_current_user),
):
"""Monthly detection rate trend for the last 12 months."""
from datetime import timedelta
now = datetime.utcnow()
months = []
for i in range(11, -1, -1):
month_start = datetime(now.year, now.month, 1) - timedelta(days=i * 30)
month_end = month_start + timedelta(days=30)
validated = (
db.query(func.count(Test.id))
.filter(
Test.state == "validated",
Test.created_at >= month_start,
Test.created_at < month_end,
)
.scalar() or 0
)
detected = (
db.query(func.count(Test.id))
.filter(
Test.state == "validated",
Test.detection_result == "detected",
Test.created_at >= month_start,
Test.created_at < month_end,
)
.scalar() or 0
)
months.append({
"month": month_start.strftime("%Y-%m"),
"validated": validated,
"detected": detected,
"detection_rate": round((detected / validated) * 100, 1) if validated > 0 else 0,
})
return months
return advanced_metrics_service.get_detection_rate_trend(db)
+7 -79
View File
@@ -5,15 +5,12 @@ directly from URL. All endpoints require authentication.
"""
from fastapi import APIRouter, Depends, Query
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.models.coverage_snapshot import CoverageSnapshot
from app.models.technique import Technique
from app.models.test import Test
from app.models.user import User
from app.services import analytics_service
router = APIRouter(prefix="/analytics", tags=["analytics"])
@@ -24,22 +21,7 @@ def analytics_coverage(
user: User = Depends(get_current_user),
):
"""Coverage per technique — flat format for BI dashboards."""
techniques = db.query(Technique).all()
return [
{
"mitre_id": t.mitre_id,
"name": t.name,
"tactic": t.tactic,
"status": t.status_global.value if t.status_global else "not_evaluated",
"is_subtechnique": t.is_subtechnique,
"test_count": len(t.tests) if t.tests else 0,
"review_required": t.review_required,
"last_review_date": (
t.last_review_date.isoformat() if t.last_review_date else None
),
}
for t in techniques
]
return analytics_service.get_coverage_analytics(db)
@router.get("/tests")
@@ -50,34 +32,9 @@ def analytics_tests(
user: User = Depends(get_current_user),
):
"""All tests with timestamps — flat format for BI dashboards."""
query = db.query(Test)
if date_from:
query = query.filter(Test.created_at >= date_from)
if date_to:
query = query.filter(Test.created_at <= date_to)
tests = query.all()
return [
{
"id": str(t.id),
"technique_id": str(t.technique_id),
"name": t.name,
"state": t.state.value if t.state else None,
"result": t.result.value if t.result else None,
"detection_result": (
t.detection_result.value if t.detection_result else None
),
"created_at": t.created_at.isoformat() if t.created_at else None,
"execution_date": (
t.execution_date.isoformat() if t.execution_date else None
),
"platform": t.platform,
"tool_used": t.tool_used,
"attack_success": t.attack_success,
"remediation_status": t.remediation_status,
}
for t in tests
]
return analytics_service.get_tests_analytics(
db, date_from=date_from, date_to=date_to
)
@router.get("/trends")
@@ -86,23 +43,7 @@ def analytics_trends(
user: User = Depends(get_current_user),
):
"""Historical coverage snapshots for trend visualization."""
snapshots = (
db.query(CoverageSnapshot)
.order_by(CoverageSnapshot.created_at)
.all()
)
return [
{
"date": s.created_at.isoformat() if s.created_at else None,
"name": s.name,
"total_techniques": s.total_techniques,
"validated_count": s.validated_count,
"partial_count": s.partial_count,
"not_covered_count": s.not_covered_count,
"organization_score": s.organization_score,
}
for s in snapshots
]
return analytics_service.get_trends_analytics(db)
@router.get("/operators")
@@ -111,17 +52,4 @@ def analytics_operators(
user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Per-operator metrics — for workload management dashboards."""
results = (
db.query(
User.username,
User.role,
func.count(Test.id).label("test_count"),
)
.outerjoin(Test, Test.created_by == User.id)
.group_by(User.id, User.username, User.role)
.all()
)
return [
{"username": r[0], "role": r[1], "test_count": r[2]}
for r in results
]
return analytics_service.get_operators_analytics(db)
+21 -62
View File
@@ -4,14 +4,17 @@ from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy import func
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import require_role
from app.models.audit import AuditLog
from app.models.user import User
from app.schemas.audit import AuditLogOut, AuditLogPage
from app.services.audit_query_service import (
list_distinct_actions,
list_distinct_entity_types,
list_logs,
)
router = APIRouter(prefix="/audit-logs", tags=["audit"])
@@ -32,53 +35,22 @@ def list_audit_logs(
**Requires admin role.**
"""
query = db.query(AuditLog).options(joinedload(AuditLog.user))
# Apply filters
if user_id:
query = query.filter(AuditLog.user_id == user_id)
if action:
query = query.filter(AuditLog.action == action)
if entity_type:
query = query.filter(AuditLog.entity_type == entity_type)
if start_date:
query = query.filter(AuditLog.timestamp >= start_date)
if end_date:
query = query.filter(AuditLog.timestamp <= end_date)
# Get total count
total = query.count()
# Get paginated results
logs = (
query
.order_by(AuditLog.timestamp.desc())
.offset(offset)
.limit(limit)
.all()
)
# Convert to response format with username
items = []
for log in logs:
item = AuditLogOut(
id=log.id,
user_id=log.user_id,
username=log.user.username if log.user else None,
action=log.action,
entity_type=log.entity_type,
entity_id=log.entity_id,
timestamp=log.timestamp,
details=log.details,
)
items.append(item)
return AuditLogPage(
items=items,
total=total,
result = list_logs(
db,
user_id=user_id,
action=action,
entity_type=entity_type,
start_date=start_date,
end_date=end_date,
offset=offset,
limit=limit,
)
return AuditLogPage(
items=[AuditLogOut(**item) for item in result["items"]],
total=result["total"],
offset=result["offset"],
limit=result["limit"],
)
@router.get("/actions", response_model=list[str])
@@ -90,13 +62,7 @@ def list_actions(
**Requires admin role.**
"""
actions = (
db.query(AuditLog.action)
.distinct()
.order_by(AuditLog.action)
.all()
)
return [a[0] for a in actions]
return list_distinct_actions(db)
@router.get("/entity-types", response_model=list[str])
@@ -108,11 +74,4 @@ def list_entity_types(
**Requires admin role.**
"""
types = (
db.query(AuditLog.entity_type)
.filter(AuditLog.entity_type.isnot(None))
.distinct()
.order_by(AuditLog.entity_type)
.all()
)
return [t[0] for t in types]
return list_distinct_entity_types(db)
+17 -30
View File
@@ -9,7 +9,7 @@ cannot use cookies (e.g. Swagger UI).
import os
from fastapi import APIRouter, Cookie, Depends, HTTPException, Request, Response, status
from fastapi import APIRouter, Cookie, Depends, Request, Response
from fastapi.security import OAuth2PasswordRequestForm
from slowapi import Limiter
from slowapi.util import get_remote_address
@@ -17,11 +17,13 @@ from sqlalchemy.orm import Session
from jose import jwt, JWTError
from app.auth import verify_password, hash_password, create_access_token, blacklist_token
from app.auth import create_access_token, blacklist_token
from app.config import settings
from app.database import get_db
from app.dependencies.auth import get_current_user
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.services.auth_service import authenticate_user, change_password as auth_change_password
from app.schemas.auth import TokenResponse, UserOut
from app.schemas.user import PasswordChange
@@ -56,25 +58,11 @@ def login(
attacks. The token is set as an HttpOnly cookie **and** returned in the
JSON body for API/Swagger compatibility.
"""
user = db.query(User).filter(User.username == form_data.username).first()
# Constant-time comparison: always run bcrypt verify to prevent
# timing-based user enumeration (SEC-005).
_DUMMY_HASH = "$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy"
hashed = user.hashed_password if user else _DUMMY_HASH
password_valid = verify_password(form_data.password, hashed)
if user is None or not password_valid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Incorrect username or password",
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account is disabled. Contact an administrator.",
)
user = authenticate_user(
db,
username=form_data.username,
password=form_data.password,
)
access_token = create_access_token(data={"sub": user.username})
@@ -163,14 +151,13 @@ def change_password(
``must_change_password`` flag is cleared so the user can proceed
normally.
"""
if not verify_password(body.current_password, current_user.hashed_password):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is incorrect",
)
current_user.hashed_password = hash_password(body.new_password)
current_user.must_change_password = False
db.commit()
auth_change_password(
db,
current_user,
current_password=body.current_password,
new_password=body.new_password,
)
with UnitOfWork(db) as uow:
uow.commit()
return {"detail": "Password changed successfully"}
+10 -12
View File
@@ -30,7 +30,7 @@ from app.services.campaign_crud_service import (
serialize_campaign,
update_campaign as crud_update,
)
from app.services.notification_service import create_notification
from app.services.notification_service import notify_role
from app.services.audit_service import log_action
logger = logging.getLogger(__name__)
@@ -237,17 +237,15 @@ def activate_campaign(
db.commit()
db.refresh(campaign)
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
for user in red_techs:
create_notification(
db,
user_id=user.id,
type="campaign_activated",
title="Campaign activated",
message=f'Campaign "{campaign.name}" has been activated.',
entity_type="campaign",
entity_id=campaign.id,
)
notify_role(
db,
role="red_tech",
type="campaign_activated",
title="Campaign activated",
message=f'Campaign "{campaign.name}" has been activated.',
entity_type="campaign",
entity_id=campaign.id,
)
log_action(
db,
+13 -60
View File
@@ -3,18 +3,20 @@
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role
from app.models.user import User
from app.models.technique import Technique
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
from app.services.d3fend_import_service import (
import_d3fend_techniques,
import_d3fend_mappings,
get_defenses_for_technique,
)
from app.services.d3fend_query_service import (
list_defensive_techniques as list_defensive_techniques_svc,
list_d3fend_tactics,
get_defenses_for_attack_technique,
)
logger = logging.getLogger(__name__)
@@ -36,38 +38,9 @@ def list_defensive_techniques(
current_user: User = Depends(get_current_user),
):
"""List all D3FEND defensive techniques with optional filters."""
query = db.query(DefensiveTechnique)
if tactic:
query = query.filter(DefensiveTechnique.tactic == tactic)
if search:
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
DefensiveTechnique.name.ilike(pattern)
| DefensiveTechnique.d3fend_id.ilike(pattern)
)
total = query.count()
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [
{
"id": str(dt.id),
"d3fend_id": dt.d3fend_id,
"name": dt.name,
"description": dt.description,
"tactic": dt.tactic,
"d3fend_url": dt.d3fend_url,
}
for dt in items
],
}
return list_defensive_techniques_svc(
db, tactic=tactic, search=search, offset=offset, limit=limit
)
# ---------------------------------------------------------------------------
@@ -75,21 +48,12 @@ def list_defensive_techniques(
# ---------------------------------------------------------------------------
@router.get("/tactics")
def list_d3fend_tactics(
def list_d3fend_tactics_endpoint(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return a list of all D3FEND tactics with counts."""
from sqlalchemy import func
rows = (
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
.group_by(DefensiveTechnique.tactic)
.order_by(DefensiveTechnique.tactic)
.all()
)
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
return list_d3fend_tactics(db)
# ---------------------------------------------------------------------------
@@ -97,24 +61,13 @@ def list_d3fend_tactics(
# ---------------------------------------------------------------------------
@router.get("/for-technique/{mitre_id}")
def get_defenses_for_attack_technique(
def get_defenses_for_attack_technique_endpoint(
mitre_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
if not technique:
raise HTTPException(status_code=404, detail=f"Technique {mitre_id} not found")
defenses = get_defenses_for_technique(db, technique.id)
return {
"mitre_id": mitre_id,
"technique_name": technique.name,
"defenses": defenses,
"total": len(defenses),
}
return get_defenses_for_attack_technique(db, mitre_id)
# ---------------------------------------------------------------------------
+19 -190
View File
@@ -5,19 +5,23 @@ Provides a centralized panel for managing all external data sources
including sync triggers, enable/disable toggles, and statistics.
"""
import logging
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session
from typing import Optional
from app.database import get_db
from app.dependencies.auth import require_role
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.models.data_source import DataSource
from app.services.audit_service import log_action
from app.services.data_source_service import (
get_source_stats,
list_sources,
sync_all_sources,
sync_source,
update_source,
)
# ---------------------------------------------------------------------------
@@ -30,41 +34,10 @@ class DataSourceUpdate(BaseModel):
sync_frequency: Optional[str] = None
config: Optional[dict] = None
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
# ---------------------------------------------------------------------------
# Sync dispatcher — maps source name → import function
# ---------------------------------------------------------------------------
def _get_sync_handler(source_name: str):
"""Lazily import and return the sync function for *source_name*.
We import lazily to avoid circular imports and to only load the
modules that are actually needed.
"""
handlers = {
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
"sigma": ("app.services.sigma_import_service", "sync"),
"lolbas": ("app.services.lolbas_import_service", "sync"),
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
"caldera": ("app.services.caldera_import_service", "sync"),
"elastic_rules": ("app.services.elastic_import_service", "sync"),
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
"d3fend": ("app.services.d3fend_import_service", "sync"),
}
if source_name not in handlers:
return None
module_path, func_name = handlers[source_name]
import importlib
mod = importlib.import_module(module_path)
return getattr(mod, func_name)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@@ -79,25 +52,7 @@ def list_data_sources(
**Requires** the ``admin`` role.
"""
sources = db.query(DataSource).order_by(DataSource.name).all()
return [
{
"id": str(s.id),
"name": s.name,
"display_name": s.display_name,
"type": s.type,
"url": s.url,
"description": s.description,
"is_enabled": s.is_enabled,
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
"last_sync_status": s.last_sync_status,
"last_sync_stats": s.last_sync_stats,
"sync_frequency": s.sync_frequency,
"config": s.config,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in sources
]
return list_sources(db)
@router.patch("/{source_id}")
@@ -111,31 +66,21 @@ def update_data_source(
**Requires** the ``admin`` role.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise HTTPException(status_code=404, detail="Data source not found")
update_data = body.model_dump(exclude_unset=True)
if "is_enabled" in update_data:
ds.is_enabled = update_data["is_enabled"]
if "sync_frequency" in update_data:
ds.sync_frequency = update_data["sync_frequency"]
if "config" in update_data:
ds.config = update_data["config"]
db.commit()
update_source(db, source_id, **update_data)
with UnitOfWork(db) as uow:
uow.commit()
log_action(
db,
user_id=current_user.id,
action="update_data_source",
entity_type="data_source",
entity_id=str(ds.id),
entity_id=source_id,
details={"updates": update_data},
)
return {"message": "Data source updated", "id": str(ds.id)}
return {"message": "Data source updated", "id": source_id}
@router.post("/{source_id}/sync")
@@ -148,46 +93,7 @@ def sync_data_source(
**Requires** the ``admin`` role.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise HTTPException(status_code=404, detail="Data source not found")
handler = _get_sync_handler(ds.name)
if handler is None:
raise HTTPException(
status_code=400,
detail=f"No sync handler available for '{ds.name}'",
)
# Mark as in_progress
ds.last_sync_status = "in_progress"
db.commit()
try:
summary = handler(db)
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
raise HTTPException(
status_code=500,
detail=f"Sync failed for '{ds.display_name}'. Check server logs for details.",
)
# Update DS record (the handler may already have done this,
# but we ensure it here as well)
ds.last_sync_at = datetime.utcnow()
ds.last_sync_status = "success"
ds.last_sync_stats = summary
db.commit()
return {
"message": f"Sync complete for {ds.display_name}",
"source": ds.name,
"stats": summary,
}
return sync_source(db, source_id)
@router.post("/sync-all")
@@ -199,49 +105,7 @@ def sync_all_data_sources(
**Requires** the ``admin`` role.
"""
enabled_sources = (
db.query(DataSource)
.filter(DataSource.is_enabled == True)
.order_by(DataSource.name)
.all()
)
results = []
for ds in enabled_sources:
handler = _get_sync_handler(ds.name)
if handler is None:
results.append({
"source": ds.name,
"status": "skipped",
"detail": "No sync handler available",
})
continue
ds.last_sync_status = "in_progress"
db.commit()
try:
summary = handler(db)
ds.last_sync_at = datetime.utcnow()
ds.last_sync_status = "success"
ds.last_sync_stats = summary
db.commit()
results.append({
"source": ds.name,
"status": "success",
"stats": summary,
})
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
results.append({
"source": ds.name,
"status": "error",
"detail": "Sync failed. Check server logs for details.",
})
results = sync_all_sources(db)
log_action(
db,
@@ -265,39 +129,4 @@ def get_data_source_stats(
**Requires** the ``admin`` role.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise HTTPException(status_code=404, detail="Data source not found")
# Count items from this source
from app.models.test_template import TestTemplate
from app.models.detection_rule import DetectionRule
template_count = 0
rule_count = 0
if ds.type == "attack_procedure":
template_count = (
db.query(TestTemplate)
.filter(TestTemplate.source == ds.name)
.count()
)
elif ds.type == "detection_rule":
rule_count = (
db.query(DetectionRule)
.filter(DetectionRule.source == ds.name)
.count()
)
return {
"id": str(ds.id),
"name": ds.name,
"display_name": ds.display_name,
"type": ds.type,
"is_enabled": ds.is_enabled,
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
"last_sync_status": ds.last_sync_status,
"last_sync_stats": ds.last_sync_stats,
"total_templates": template_count,
"total_rules": rule_count,
}
return get_source_stats(db, source_id)
+13 -83
View File
@@ -7,14 +7,9 @@ from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.config import settings
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role
from app.domain.exceptions import EntityNotFoundError
from app.models.jira_link import JiraLink, JiraLinkEntityType
from app.models.test import Test
from app.models.technique import Technique
from app.models.campaign import Campaign
from app.models.jira_link import JiraLinkEntityType
from app.models.user import User
from app.schemas.jira_schema import (
JiraIssueResult,
@@ -45,23 +40,14 @@ def create_link(
user: User = Depends(get_current_user),
):
"""Associate an Aegis entity with a Jira issue."""
link = JiraLink(
link = jira_service.create_link(
db,
entity_type=body.entity_type,
entity_id=body.entity_id,
jira_issue_key=body.jira_issue_key,
sync_direction=body.sync_direction,
created_by=user.id,
)
db.add(link)
db.flush()
# Pull initial data from Jira if enabled
if settings.JIRA_ENABLED:
try:
jira_service.sync_jira_to_aegis(db, link)
except Exception as e:
logger.warning("Initial Jira sync failed for %s: %s", body.jira_issue_key, e)
db.commit()
db.refresh(link)
@@ -88,12 +74,11 @@ def list_links(
user: User = Depends(get_current_user),
):
"""List Jira links, optionally filtered by entity."""
query = db.query(JiraLink)
if entity_type:
query = query.filter(JiraLink.entity_type == entity_type)
if entity_id:
query = query.filter(JiraLink.entity_id == entity_id)
return query.order_by(JiraLink.created_at.desc()).all()
return jira_service.list_links(
db,
entity_type=entity_type,
entity_id=entity_id,
)
@router.post("/links/{link_id}/sync")
@@ -103,9 +88,7 @@ def sync_link(
user: User = Depends(require_role("admin")),
):
"""Force bidirectional sync for a specific Jira link."""
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
if not link:
raise EntityNotFoundError("JiraLink", str(link_id))
link = jira_service.get_link_or_raise(db, link_id)
jira_service.sync_jira_to_aegis(db, link)
db.commit()
return {"message": "Sync completed", "jira_status": link.jira_status}
@@ -118,10 +101,7 @@ def delete_link(
user: User = Depends(get_current_user),
):
"""Remove a Jira link."""
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
if not link:
raise EntityNotFoundError("JiraLink", str(link_id))
db.delete(link)
link = jira_service.delete_link(db, link_id)
db.commit()
audit_service.log_action(
db,
@@ -141,61 +121,11 @@ def create_issue_from_entity(
user: User = Depends(get_current_user),
):
"""Auto-create a Jira issue from an Aegis entity and link them."""
summary, description = _build_issue_data(db, entity_type, entity_id)
result = jira_service.create_jira_issue(
project_key=settings.JIRA_DEFAULT_PROJECT,
summary=summary,
description=description,
labels=["aegis", entity_type.value],
)
link = JiraLink(
result = jira_service.create_issue_and_link(
db,
entity_type=entity_type,
entity_id=entity_id,
jira_issue_key=result["issue_key"],
jira_issue_id=result["issue_id"],
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
created_by=user.id,
)
db.add(link)
db.commit()
return {"issue_key": result["issue_key"], "link_id": str(link.id)}
def _build_issue_data(
db: Session,
entity_type: JiraLinkEntityType,
entity_id: UUID,
) -> tuple[str, str]:
"""Build Jira issue summary + description from an Aegis entity."""
if entity_type == JiraLinkEntityType.test:
entity = db.query(Test).filter(Test.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Test", str(entity_id))
return (
f"[Aegis Test] {entity.name}",
f"Test: {entity.name}\n"
f"State: {entity.state.value if entity.state else 'draft'}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.campaign:
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Campaign", str(entity_id))
return (
f"[Aegis Campaign] {entity.name}",
f"Campaign: {entity.name}\n"
f"Type: {entity.type}\nStatus: {entity.status}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.technique:
entity = db.query(Technique).filter(Technique.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Technique", str(entity_id))
return (
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
f"Tactic: {entity.tactic or 'N/A'}\n"
f"Description: {entity.description or 'N/A'}",
)
else:
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
return result
+5 -19
View File
@@ -10,16 +10,16 @@ POST /notifications/read-all — mark all as read
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user
from app.domain.unit_of_work import UnitOfWork
from app.models.notification import Notification
from app.models.user import User
from app.schemas.notification import NotificationOut, UnreadCountOut
from app.services.notification_service import (
list_notifications,
mark_as_read,
mark_all_as_read,
get_unread_count,
@@ -34,22 +34,14 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
@router.get("", response_model=list[NotificationOut])
def list_notifications(
def list_notifications_endpoint(
offset: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return paginated notifications for the current user, newest first."""
notifs = (
db.query(Notification)
.filter(Notification.user_id == current_user.id)
.order_by(Notification.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return notifs
return list_notifications(db, current_user.id, offset=offset, limit=limit)
# ---------------------------------------------------------------------------
@@ -80,14 +72,8 @@ def read_notification(
):
"""Mark a single notification as read."""
with UnitOfWork(db) as uow:
success = mark_as_read(db, notification_id, current_user.id)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Notification not found",
)
notif = mark_as_read(db, notification_id, current_user.id)
uow.commit()
notif = db.query(Notification).filter(Notification.id == notification_id).first()
return notif
+21 -77
View File
@@ -10,14 +10,15 @@ from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.models.osint_item import OsintItem
from app.models.technique import Technique
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.services.osint_enrichment_service import (
enrich_technique_with_cves,
get_osint_items_for_technique,
get_osint_summary,
get_technique_or_raise,
list_osint_items as service_list_osint_items,
mark_osint_reviewed,
get_unreviewed_count,
)
router = APIRouter(prefix="/osint", tags=["osint"])
@@ -56,41 +57,15 @@ def list_osint_items(
user: User = Depends(get_current_user),
):
"""List OSINT items with optional filters."""
query = db.query(OsintItem)
if technique_id:
query = query.filter(OsintItem.technique_id == technique_id)
if source_type:
query = query.filter(OsintItem.source_type == source_type)
if reviewed is not None:
query = query.filter(OsintItem.reviewed == reviewed)
total = query.count()
items = (
query.order_by(OsintItem.discovered_at.desc())
.offset(offset)
.limit(limit)
.all()
return service_list_osint_items(
db,
technique_id=technique_id,
source_type=source_type,
reviewed=reviewed,
offset=offset,
limit=limit,
)
return {
"total": total,
"items": [
{
"id": str(item.id),
"technique_id": str(item.technique_id),
"source_type": item.source_type,
"source_url": item.source_url,
"title": item.title,
"description": item.description,
"severity": item.severity,
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
"reviewed": item.reviewed,
"metadata": item.metadata_,
}
for item in items
],
}
@router.get("/summary")
def osint_summary(
@@ -98,34 +73,7 @@ def osint_summary(
user: User = Depends(get_current_user),
):
"""Summary statistics for OSINT items."""
from sqlalchemy import func
total = db.query(func.count(OsintItem.id)).scalar() or 0
unreviewed = get_unreviewed_count(db)
by_severity = dict(
db.query(OsintItem.severity, func.count(OsintItem.id))
.group_by(OsintItem.severity)
.all()
)
by_type = dict(
db.query(OsintItem.source_type, func.count(OsintItem.id))
.group_by(OsintItem.source_type)
.all()
)
techniques_with_items = (
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
)
return {
"total_items": total,
"unreviewed": unreviewed,
"techniques_with_items": techniques_with_items,
"by_severity": by_severity,
"by_type": by_type,
}
return get_osint_summary(db)
@router.post("/items/{item_id}/review")
@@ -135,12 +83,14 @@ def review_osint_item(
user: User = Depends(get_current_user),
):
"""Mark an OSINT item as reviewed."""
item = mark_osint_reviewed(db, str(item_id))
if not item:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OSINT item not found",
)
with UnitOfWork(db) as uow:
item = mark_osint_reviewed(db, str(item_id))
if not item:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OSINT item not found",
)
uow.commit()
return {"id": str(item.id), "reviewed": True}
@@ -151,13 +101,7 @@ def trigger_technique_enrichment(
user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Manually trigger OSINT enrichment for a single technique."""
technique = db.query(Technique).filter(Technique.id == technique_id).first()
if not technique:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Technique not found",
)
technique = get_technique_or_raise(db, technique_id)
count = enrich_technique_with_cves(db, technique)
return {
"technique_id": str(technique.id),
+16 -35
View File
@@ -5,19 +5,18 @@ Provides granular scoring with breakdowns and configurable weights.
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.models.technique import Technique
from app.models.threat_actor import ThreatActor
from app.services.scoring_service import (
calculate_technique_score,
score_technique_by_mitre_id,
score_actor_by_id,
calculate_tactic_score,
calculate_actor_coverage_score,
calculate_organization_score,
get_score_history,
)
@@ -39,23 +38,7 @@ def score_technique(
current_user: User = Depends(get_current_user),
):
"""Get detailed score with breakdown for a specific technique."""
technique = (
db.query(Technique)
.filter(Technique.mitre_id == mitre_id)
.first()
)
if not technique:
raise HTTPException(status_code=404, detail="Technique not found")
result = calculate_technique_score(technique, db)
return {
"mitre_id": technique.mitre_id,
"name": technique.name,
"tactic": technique.tactic,
"status_global": technique.status_global.value if technique.status_global else None,
**result,
}
return score_technique_by_mitre_id(db, mitre_id)
# ── GET /scores/tactic/{tactic} ──────────────────────────────────────
@@ -81,11 +64,7 @@ def score_threat_actor(
current_user: User = Depends(get_current_user),
):
"""Get coverage score against a specific threat actor."""
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
if not actor:
raise HTTPException(status_code=404, detail="Threat actor not found")
return calculate_actor_coverage_score(actor_id, db)
return score_actor_by_id(db, actor_id)
# ── GET /scores/organization ─────────────────────────────────────────
@@ -149,14 +128,16 @@ def update_scoring_config(
Weights are persisted in the database and survive restarts.
Validation enforces that all weights are non-negative and sum to 100.
"""
result = update_scoring_weights(
db,
tests=payload.tests,
detection_rules=payload.detection_rules,
d3fend=payload.d3fend,
freshness=payload.freshness,
platform_diversity=payload.platform_diversity,
)
with UnitOfWork(db) as uow:
result = update_scoring_weights(
db,
tests=payload.tests,
detection_rules=payload.detection_rules,
d3fend=payload.d3fend,
freshness=payload.freshness,
platform_diversity=payload.platform_diversity,
)
uow.commit()
from app.services.score_cache import invalidate
invalidate()
+18 -79
View File
@@ -8,18 +8,24 @@ import logging
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role, require_role
from app.domain.errors import BusinessRuleViolation
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.services.snapshot_service import (
create_snapshot,
compare_snapshots,
cleanup_old_snapshots,
serialize_snapshot_summary,
list_snapshots as list_snapshots_svc,
get_snapshot_or_raise,
get_snapshot_detail,
delete_snapshot,
)
from app.services.audit_service import log_action
@@ -34,48 +40,6 @@ class SnapshotCreate(BaseModel):
name: Optional[str] = None
# ── Helpers ──────────────────────────────────────────────────────────
def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"""Lightweight serialization for list views."""
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_by": str(snap.created_by) if snap.created_by else None,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
"""Full serialization including technique states."""
base = _serialize_snapshot_summary(snap)
technique_states = (
db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
.order_by(SnapshotTechniqueState.mitre_id)
.all()
)
base["technique_states"] = [
{
"mitre_id": s.mitre_id,
"technique_id": str(s.technique_id),
"status": s.status,
"score": s.score,
}
for s in technique_states
]
return base
# ---------------------------------------------------------------------------
# GET /snapshots — List snapshots (paginated)
# ---------------------------------------------------------------------------
@@ -88,23 +52,7 @@ def list_snapshots(
current_user: User = Depends(get_current_user),
):
"""List coverage snapshots ordered by creation date (newest first)."""
query = db.query(CoverageSnapshot)
total = query.count()
snapshots = (
query
.order_by(CoverageSnapshot.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [_serialize_snapshot_summary(s) for s in snapshots],
}
return list_snapshots_svc(db, offset=offset, limit=limit)
# ---------------------------------------------------------------------------
@@ -129,7 +77,7 @@ def create_snapshot_endpoint(
details={"name": snapshot.name, "score": snapshot.organization_score},
)
return _serialize_snapshot_summary(snapshot)
return serialize_snapshot_summary(snapshot)
# ---------------------------------------------------------------------------
@@ -148,13 +96,9 @@ def compare_snapshots_endpoint(
a_id = uuid.UUID(a)
b_id = uuid.UUID(b)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid snapshot ID format")
raise BusinessRuleViolation("Invalid snapshot ID format")
result = compare_snapshots(db, a_id, b_id)
if "error" in result:
raise HTTPException(status_code=404, detail=result["error"])
return result
return compare_snapshots(db, a_id, b_id)
# ---------------------------------------------------------------------------
@@ -168,11 +112,7 @@ def get_snapshot(
current_user: User = Depends(get_current_user),
):
"""Get detailed snapshot information including per-technique states."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
return _serialize_snapshot_detail(db, snapshot)
return get_snapshot_detail(db, snapshot_id)
# ---------------------------------------------------------------------------
@@ -180,15 +120,13 @@ def get_snapshot(
# ---------------------------------------------------------------------------
@router.delete("/{snapshot_id}")
def delete_snapshot(
def delete_snapshot_endpoint(
snapshot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Delete a snapshot (admin only)."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
snapshot = get_snapshot_or_raise(db, snapshot_id)
log_action(
db,
@@ -199,7 +137,8 @@ def delete_snapshot(
details={"name": snapshot.name},
)
db.delete(snapshot)
db.commit()
with UnitOfWork(db) as uow:
delete_snapshot(db, snapshot_id)
uow.commit()
return {"detail": "Snapshot deleted"}
+3 -42
View File
@@ -6,7 +6,7 @@ exceptions to HTTP responses automatically.
"""
from fastapi import APIRouter, Depends, Query, status
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role, require_any_role
@@ -18,7 +18,6 @@ from app.domain.unit_of_work import UnitOfWork
from app.infrastructure.persistence.repositories.sa_technique_repository import (
SATechniqueRepository,
)
from app.models.technique import Technique
from app.models.user import User
from app.schemas.technique import (
TechniqueCreate,
@@ -27,7 +26,7 @@ from app.schemas.technique import (
TechniqueUpdate,
)
from app.services.audit_service import log_action
from app.services.d3fend_import_service import get_defenses_for_technique
from app.services.technique_query_service import get_technique_detail
router = APIRouter(prefix="/techniques", tags=["techniques"])
@@ -67,45 +66,7 @@ def get_technique(
current_user: User = Depends(get_current_user),
):
"""Return full details for a single technique, including its tests and D3FEND defenses."""
technique = (
db.query(Technique)
.options(joinedload(Technique.tests))
.filter(Technique.mitre_id == mitre_id)
.first()
)
if technique is None:
raise EntityNotFoundError("Technique", mitre_id)
defenses = get_defenses_for_technique(db, technique.id)
return {
"id": str(technique.id),
"mitre_id": technique.mitre_id,
"name": technique.name,
"description": technique.description,
"tactic": technique.tactic,
"platforms": technique.platforms or [],
"mitre_version": technique.mitre_version,
"mitre_last_modified": technique.mitre_last_modified,
"is_subtechnique": technique.is_subtechnique,
"parent_mitre_id": technique.parent_mitre_id,
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
"review_required": technique.review_required,
"last_review_date": technique.last_review_date,
"tests": [
{
"id": str(t.id),
"name": t.name,
"state": t.state.value if t.state else None,
"result": t.result.value if t.result else None,
"platform": t.platform,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in technique.tests
],
"d3fend_defenses": defenses,
}
return get_technique_detail(db, mitre_id)
# ---------------------------------------------------------------------------
+88 -174
View File
@@ -25,13 +25,12 @@ Filters (GET /test-templates)
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, or_
from fastapi import APIRouter, Depends, Query, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role, require_any_role
from app.models.test_template import TestTemplate
from app.dependencies.auth import get_current_user, require_any_role
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.schemas.test_template import (
TestTemplateCreate,
@@ -39,6 +38,17 @@ from app.schemas.test_template import (
TestTemplateSummary,
)
from app.services.audit_service import log_action
from app.services.test_template_service import (
bulk_activate,
create_template as create_template_svc,
get_template_or_raise,
get_template_stats,
get_templates_by_technique as templates_by_technique,
list_templates,
soft_delete_template,
toggle_template_active as toggle_template_active_svc,
update_template as update_template_svc,
)
router = APIRouter(prefix="/test-templates", tags=["test-templates"])
@@ -49,7 +59,7 @@ router = APIRouter(prefix="/test-templates", tags=["test-templates"])
@router.get("", response_model=list[TestTemplateSummary])
def list_templates(
def _list_templates_handler(
source: Optional[str] = Query(None, description="Filter by source (atomic_red_team, mitre, custom)"),
platform: Optional[str] = Query(None, description="Filter by platform (windows, linux, macos)"),
severity: Optional[str] = Query(None, description="Filter by severity (low, medium, high, critical)"),
@@ -62,37 +72,17 @@ def list_templates(
current_user: User = Depends(get_current_user),
):
"""Return a paginated, filterable list of test templates."""
query = db.query(TestTemplate)
if is_active is not None:
query = query.filter(TestTemplate.is_active == is_active) # noqa: E712
if source:
query = query.filter(TestTemplate.source == source)
if platform:
from app.utils import escape_like
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
if severity:
query = query.filter(TestTemplate.severity == severity)
if mitre_technique_id:
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
if search:
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
or_(
TestTemplate.name.ilike(pattern),
TestTemplate.description.ilike(pattern),
)
)
templates = (
query
.order_by(TestTemplate.mitre_technique_id, TestTemplate.name)
.offset(offset)
.limit(limit)
.all()
return list_templates(
db,
source=source,
platform=platform,
severity=severity,
mitre_technique_id=mitre_technique_id,
search=search,
is_active=is_active,
offset=offset,
limit=limit,
)
return templates
# ---------------------------------------------------------------------------
@@ -105,41 +95,8 @@ def template_stats(
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Return catalog statistics: totals by source, platform, active/inactive."""
total = db.query(func.count(TestTemplate.id)).scalar() or 0
active = (
db.query(func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.scalar()
) or 0
inactive = total - active
# By source
source_rows = (
db.query(TestTemplate.source, func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.group_by(TestTemplate.source)
.all()
)
by_source = {source: cnt for source, cnt in source_rows}
# By platform
platform_rows = (
db.query(TestTemplate.platform, func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.group_by(TestTemplate.platform)
.all()
)
by_platform = {(platform or "unspecified"): cnt for platform, cnt in platform_rows}
return {
"total": total,
"active": active,
"inactive": inactive,
"by_source": by_source,
"by_platform": by_platform,
}
"""Return catalog statistics: active, by_source, by_platform."""
return get_template_stats(db)
# ---------------------------------------------------------------------------
@@ -154,21 +111,17 @@ def bulk_activate_templates(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Set all templates to active or inactive."""
count = (
db.query(TestTemplate)
.filter(TestTemplate.is_active != activate)
.update({TestTemplate.is_active: activate})
)
db.commit()
log_action(
db,
user_id=current_user.id,
action="bulk_activate_templates" if activate else "bulk_deactivate_templates",
entity_type="test_template",
entity_id=None,
details={"affected": count, "is_active": activate},
)
count = bulk_activate(db, activate=activate)
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
action="bulk_activate_templates" if activate else "bulk_deactivate_templates",
entity_type="test_template",
entity_id=None,
details={"affected": count, "is_active": activate},
)
uow.commit()
return {
"detail": f"{'Activated' if activate else 'Deactivated'} {count} templates",
@@ -183,22 +136,13 @@ def bulk_activate_templates(
@router.get("/by-technique/{mitre_id}", response_model=list[TestTemplateSummary])
def templates_by_technique(
def _templates_by_technique_handler(
mitre_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return all active templates mapped to a specific MITRE technique."""
templates = (
db.query(TestTemplate)
.filter(
TestTemplate.mitre_technique_id == mitre_id,
TestTemplate.is_active == True, # noqa: E712
)
.order_by(TestTemplate.name)
.all()
)
return templates
return templates_by_technique(db, mitre_id)
# ---------------------------------------------------------------------------
@@ -213,13 +157,7 @@ def get_template(
current_user: User = Depends(get_current_user),
):
"""Return full details for a single test template."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if template is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test template not found",
)
return template
return get_template_or_raise(db, template_id)
# ---------------------------------------------------------------------------
@@ -238,24 +176,23 @@ def create_template(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Create a custom test template."""
template = TestTemplate(**payload.model_dump())
db.add(template)
db.commit()
template = create_template_svc(db, **payload.model_dump())
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
action="create_test_template",
entity_type="test_template",
entity_id=template.id,
details={
"name": template.name,
"source": template.source,
"mitre_technique_id": template.mitre_technique_id,
},
)
uow.commit()
db.refresh(template)
log_action(
db,
user_id=current_user.id,
action="create_test_template",
entity_type="test_template",
entity_id=template.id,
details={
"name": template.name,
"source": template.source,
"mitre_technique_id": template.mitre_technique_id,
},
)
return template
@@ -272,29 +209,19 @@ def update_template(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Update fields of an existing test template."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if template is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test template not found",
template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True))
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
action="update_test_template",
entity_type="test_template",
entity_id=template.id,
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(template, field, value)
db.commit()
uow.commit()
db.refresh(template)
log_action(
db,
user_id=current_user.id,
action="update_test_template",
entity_type="test_template",
entity_id=template.id,
details={"updated_fields": list(update_data.keys())},
)
return template
@@ -309,27 +236,20 @@ def toggle_template_active(
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Toggle a template between active and inactive."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if template is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test template not found",
"""Toggle a template between active and inactive (is_active = not is_active)."""
template = toggle_template_active_svc(db, template_id)
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
action="toggle_test_template",
entity_type="test_template",
entity_id=template.id,
details={"name": template.name, "is_active": template.is_active},
)
template.is_active = not template.is_active
db.commit()
uow.commit()
db.refresh(template)
log_action(
db,
user_id=current_user.id,
action="toggle_test_template",
entity_type="test_template",
entity_id=template.id,
details={"name": template.name, "is_active": template.is_active},
)
return template
@@ -345,23 +265,17 @@ def delete_template(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Soft-delete a test template by setting ``is_active=False``."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if template is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test template not found",
template = get_template_or_raise(db, template_id)
soft_delete_template(db, template_id)
with UnitOfWork(db) as uow:
log_action(
db,
user_id=current_user.id,
action="delete_test_template",
entity_type="test_template",
entity_id=template.id,
details={"name": template.name},
)
template.is_active = False
db.commit()
log_action(
db,
user_id=current_user.id,
action="delete_test_template",
entity_type="test_template",
entity_id=template.id,
details={"name": template.name},
)
uow.commit()
return {"detail": "Test template deactivated"}
+22 -59
View File
@@ -2,20 +2,24 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import require_role
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate, UserOut
from app.auth import hash_password
from app.services.audit_service import log_action
from app.services.user_service import (
create_user,
get_user_or_raise,
list_users,
update_user,
)
router = APIRouter(prefix="/users", tags=["users"])
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
# ---------------------------------------------------------------------------
# GET /users — list all users
@@ -23,12 +27,12 @@ VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewe
@router.get("", response_model=list[UserOut])
def list_users(
def list_users_route(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Return a list of all users. **Requires admin role.**"""
return db.query(User).order_by(User.username).all()
return list_users(db)
# ---------------------------------------------------------------------------
@@ -37,36 +41,21 @@ def list_users(
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
def create_user(
def create_user_route(
payload: UserCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Create a new user. **Requires admin role.**"""
# Check if username already exists
existing = db.query(User).filter(User.username == payload.username).first()
if existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Username '{payload.username}' already exists",
)
# Validate role
if payload.role not in VALID_ROLES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid role '{payload.role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
)
user = User(
user = create_user(
db,
username=payload.username,
email=payload.email,
hashed_password=hash_password(payload.password),
password=payload.password,
role=payload.role,
)
db.add(user)
db.commit()
with UnitOfWork(db) as uow:
uow.commit()
db.refresh(user)
log_action(
@@ -93,13 +82,7 @@ def get_user(
current_user: User = Depends(require_role("admin")),
):
"""Return a single user by ID. **Requires admin role.**"""
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
return user
return get_user_or_raise(db, user_id)
# ---------------------------------------------------------------------------
@@ -108,37 +91,17 @@ def get_user(
@router.patch("/{user_id}", response_model=UserOut)
def update_user(
def update_user_route(
user_id: uuid.UUID,
payload: UserUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Update one or more fields of an existing user. **Requires admin role.**"""
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
update_data = payload.model_dump(exclude_unset=True)
# Validate role if being updated
if "role" in update_data and update_data["role"] not in VALID_ROLES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
)
# Hash password if being updated
if "password" in update_data:
update_data["hashed_password"] = hash_password(update_data.pop("password"))
for field, value in update_data.items():
setattr(user, field, value)
db.commit()
user = update_user(db, user_id, **update_data)
with UnitOfWork(db) as uow:
uow.commit()
db.refresh(user)
log_action(
@@ -147,7 +110,7 @@ def update_user(
action="update_user",
entity_type="user",
entity_id=user.id,
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
details={"updated_fields": list(update_data.keys())},
)
return user
+17 -20
View File
@@ -10,9 +10,8 @@ from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.domain.exceptions import EntityNotFoundError
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.models.worklog import Worklog
from app.services import worklog_service
router = APIRouter(prefix="/worklogs", tags=["worklogs"])
@@ -59,17 +58,20 @@ def create(
user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
):
"""Create a manually-logged worklog entry."""
wl = worklog_service.create_worklog(
db,
entity_type=body.entity_type,
entity_id=body.entity_id,
user_id=user.id,
activity_type=body.activity_type,
started_at=body.started_at,
ended_at=body.ended_at,
duration_seconds=body.duration_seconds,
description=body.description,
)
with UnitOfWork(db) as uow:
wl = worklog_service.create_worklog(
db,
entity_type=body.entity_type,
entity_id=body.entity_id,
user_id=user.id,
activity_type=body.activity_type,
started_at=body.started_at,
ended_at=body.ended_at,
duration_seconds=body.duration_seconds,
description=body.description,
)
uow.commit()
db.refresh(wl)
return wl
@@ -97,10 +99,7 @@ def get_one(
_user: User = Depends(get_current_user),
):
"""Get a single worklog by ID."""
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
if not wl:
raise EntityNotFoundError("Worklog", str(worklog_id))
return wl
return worklog_service.get_worklog_or_raise(db, worklog_id)
@router.get("/{worklog_id}/verify")
@@ -110,9 +109,7 @@ def verify_integrity(
_user: User = Depends(get_current_user),
):
"""Check whether a worklog's integrity hash is still valid."""
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
if not wl:
raise EntityNotFoundError("Worklog", str(worklog_id))
wl = worklog_service.get_worklog_or_raise(db, worklog_id)
return {
"worklog_id": str(wl.id),
"integrity_valid": worklog_service.verify_worklog_integrity(wl),
@@ -0,0 +1,160 @@
"""Advanced metrics service — coverage by tactic, never-tested, avg validation time, detection trend."""
from __future__ import annotations
from datetime import datetime, timedelta
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from app.models.technique import Technique
from app.models.test import Test
from app.models.enums import TestResult
def get_coverage_by_tactic(db: Session) -> list[dict]:
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
results = (
db.query(
Technique.tactic,
func.count(Technique.id).label("total"),
func.sum(
case((Technique.status_global == "validated", 1), else_=0)
).label("validated"),
func.sum(
case((Technique.status_global == "partial", 1), else_=0)
).label("partial"),
func.sum(
case((Technique.status_global == "not_covered", 1), else_=0)
).label("not_covered"),
func.sum(
case((Technique.status_global == "in_progress", 1), else_=0)
).label("in_progress"),
)
.group_by(Technique.tactic)
.order_by(Technique.tactic)
.all()
)
return [
{
"tactic": r[0] or "Unknown",
"total": r[1],
"validated": int(r[2]),
"partial": int(r[3]),
"not_covered": int(r[4]),
"in_progress": int(r[5]),
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
}
for r in results
]
def get_never_tested_techniques(db: Session) -> list[dict]:
"""Techniques that have never had a test created."""
tested_technique_ids = db.query(Test.technique_id).distinct().subquery()
techniques = (
db.query(Technique)
.filter(~Technique.id.in_(db.query(tested_technique_ids)))
.order_by(Technique.mitre_id)
.all()
)
return [
{
"mitre_id": t.mitre_id,
"name": t.name,
"tactic": t.tactic,
"is_subtechnique": t.is_subtechnique,
}
for t in techniques
]
def get_avg_validation_time(db: Session) -> dict:
"""Average time from test creation to validation, computed from validated tests.
Returns overall average and per-phase averages where data is available.
"""
validated_tests = (
db.query(Test)
.filter(Test.state == "validated")
.all()
)
if not validated_tests:
return {
"total_validated": 0,
"avg_total_hours": 0,
"avg_red_phase_hours": 0,
"avg_blue_phase_hours": 0,
}
total_durations = []
red_durations = []
blue_durations = []
for test in validated_tests:
if test.created_at and test.red_validated_at:
total_seconds = (test.red_validated_at - test.created_at).total_seconds()
total_durations.append(total_seconds)
if test.red_started_at and test.blue_started_at:
red_sec = (test.blue_started_at - test.red_started_at).total_seconds()
red_paused = test.red_paused_seconds or 0
red_durations.append(max(red_sec - red_paused, 0))
if test.blue_started_at and test.blue_validated_at:
blue_sec = (test.blue_validated_at - test.blue_started_at).total_seconds()
blue_paused = test.blue_paused_seconds or 0
blue_durations.append(max(blue_sec - blue_paused, 0))
def avg_hours(durations: list[float]) -> float:
if not durations:
return 0
return round(sum(durations) / len(durations) / 3600, 2)
return {
"total_validated": len(validated_tests),
"avg_total_hours": avg_hours(total_durations),
"avg_red_phase_hours": avg_hours(red_durations),
"avg_blue_phase_hours": avg_hours(blue_durations),
}
def get_detection_rate_trend(db: Session) -> list[dict]:
"""Monthly detection rate trend for the last 12 months."""
now = datetime.utcnow()
months = []
for i in range(11, -1, -1):
month_start = datetime(now.year, now.month, 1) - timedelta(days=i * 30)
month_end = month_start + timedelta(days=30)
validated = (
db.query(func.count(Test.id))
.filter(
Test.state == "validated",
Test.created_at >= month_start,
Test.created_at < month_end,
)
.scalar() or 0
)
detected = (
db.query(func.count(Test.id))
.filter(
Test.state == "validated",
Test.detection_result == TestResult.detected,
Test.created_at >= month_start,
Test.created_at < month_end,
)
.scalar() or 0
)
months.append({
"month": month_start.strftime("%Y-%m"),
"validated": validated,
"detected": detected,
"detection_rate": round((detected / validated) * 100, 1) if validated > 0 else 0,
})
return months
+107
View File
@@ -0,0 +1,107 @@
"""Analytics service — flat JSON optimized for PowerBI / BI tools."""
from __future__ import annotations
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.models.coverage_snapshot import CoverageSnapshot
from app.models.technique import Technique
from app.models.test import Test
from app.models.user import User
def get_coverage_analytics(db: Session) -> list[dict]:
"""Coverage per technique — flat format for BI dashboards."""
techniques = db.query(Technique).all()
return [
{
"mitre_id": t.mitre_id,
"name": t.name,
"tactic": t.tactic,
"status": t.status_global.value if t.status_global else "not_evaluated",
"is_subtechnique": t.is_subtechnique,
"test_count": len(t.tests) if t.tests else 0,
"review_required": t.review_required,
"last_review_date": (
t.last_review_date.isoformat() if t.last_review_date else None
),
}
for t in techniques
]
def get_tests_analytics(
db: Session,
*,
date_from: str | None = None,
date_to: str | None = None,
) -> list[dict]:
"""All tests with timestamps — flat format for BI dashboards."""
query = db.query(Test)
if date_from:
query = query.filter(Test.created_at >= date_from)
if date_to:
query = query.filter(Test.created_at <= date_to)
tests = query.all()
return [
{
"id": str(t.id),
"technique_id": str(t.technique_id),
"name": t.name,
"state": t.state.value if t.state else None,
"result": t.result.value if t.result else None,
"detection_result": (
t.detection_result.value if t.detection_result else None
),
"created_at": t.created_at.isoformat() if t.created_at else None,
"execution_date": (
t.execution_date.isoformat() if t.execution_date else None
),
"platform": t.platform,
"tool_used": t.tool_used,
"attack_success": t.attack_success,
"remediation_status": t.remediation_status,
}
for t in tests
]
def get_trends_analytics(db: Session) -> list[dict]:
"""Historical coverage snapshots for trend visualization."""
snapshots = (
db.query(CoverageSnapshot)
.order_by(CoverageSnapshot.created_at)
.all()
)
return [
{
"date": s.created_at.isoformat() if s.created_at else None,
"name": s.name,
"total_techniques": s.total_techniques,
"validated_count": s.validated_count,
"partial_count": s.partial_count,
"not_covered_count": s.not_covered_count,
"organization_score": s.organization_score,
}
for s in snapshots
]
def get_operators_analytics(db: Session) -> list[dict]:
"""Per-operator metrics — for workload management dashboards."""
results = (
db.query(
User.username,
User.role,
func.count(Test.id).label("test_count"),
)
.outerjoin(Test, Test.created_by == User.id)
.group_by(User.id, User.username, User.role)
.all()
)
return [
{"username": r[0], "role": r[1], "test_count": r[2]}
for r in results
]
@@ -0,0 +1,93 @@
"""Audit log query service — framework-agnostic query logic for audit logs.
Provides paginated logs and distinct action/entity-type lists.
No FastAPI imports.
"""
from __future__ import annotations
from datetime import datetime
from sqlalchemy.orm import Session, joinedload
from app.models.audit import AuditLog
def list_logs(
db: Session,
*,
user_id: str | None = None,
action: str | None = None,
entity_type: str | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
offset: int = 0,
limit: int = 50,
) -> dict:
"""Return paginated audit logs with optional filters.
Returns a dict with keys: items, total, offset, limit.
Each item is a dict with: id, user_id, username, action, entity_type,
entity_id, timestamp, details.
"""
query = db.query(AuditLog).options(joinedload(AuditLog.user))
if user_id:
query = query.filter(AuditLog.user_id == user_id)
if action:
query = query.filter(AuditLog.action == action)
if entity_type:
query = query.filter(AuditLog.entity_type == entity_type)
if start_date:
query = query.filter(AuditLog.timestamp >= start_date)
if end_date:
query = query.filter(AuditLog.timestamp <= end_date)
total = query.count()
logs = (
query
.order_by(AuditLog.timestamp.desc())
.offset(offset)
.limit(limit)
.all()
)
items = [
{
"id": log.id,
"user_id": log.user_id,
"username": log.user.username if log.user else None,
"action": log.action,
"entity_type": log.entity_type,
"entity_id": log.entity_id,
"timestamp": log.timestamp,
"details": log.details,
}
for log in logs
]
return {"items": items, "total": total, "offset": offset, "limit": limit}
def list_distinct_actions(db: Session) -> list[str]:
"""Return a list of distinct action types in the audit log."""
actions = (
db.query(AuditLog.action)
.distinct()
.order_by(AuditLog.action)
.all()
)
return [a[0] for a in actions]
def list_distinct_entity_types(db: Session) -> list[str]:
"""Return a list of distinct entity types in the audit log."""
types = (
db.query(AuditLog.entity_type)
.filter(AuditLog.entity_type.isnot(None))
.distinct()
.order_by(AuditLog.entity_type)
.all()
)
return [t[0] for t in types]
+45
View File
@@ -0,0 +1,45 @@
"""Authentication service — credential validation and password management."""
from __future__ import annotations
from sqlalchemy.orm import Session
from app.auth import hash_password, verify_password
from app.domain.errors import BusinessRuleViolation, PermissionViolation
from app.models.user import User
_DUMMY_HASH = "$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy"
def authenticate_user(db: Session, *, username: str, password: str) -> User:
"""Validate credentials and return the User.
Raises BusinessRuleViolation for invalid credentials.
Raises PermissionViolation for disabled account.
Uses constant-time comparison to prevent timing attacks.
"""
user = db.query(User).filter(User.username == username).first()
hashed = user.hashed_password if user else _DUMMY_HASH
password_valid = verify_password(password, hashed)
if user is None or not password_valid:
raise BusinessRuleViolation("Incorrect username or password")
if not user.is_active:
raise PermissionViolation("Account is disabled. Contact an administrator.")
return user
def change_password(
db: Session,
user: User,
*,
current_password: str,
new_password: str,
) -> None:
"""Change a user's password. Does NOT commit.
Raises BusinessRuleViolation if current password is wrong.
"""
if not verify_password(current_password, user.hashed_password):
raise BusinessRuleViolation("Current password is incorrect")
user.hashed_password = hash_password(new_password)
user.must_change_password = False
@@ -0,0 +1,82 @@
"""D3FEND query service — framework-agnostic queries for defensive techniques."""
from __future__ import annotations
from typing import Optional
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.defensive_technique import DefensiveTechnique
from app.models.technique import Technique
from app.services.d3fend_import_service import get_defenses_for_technique
from app.utils import escape_like
def list_defensive_techniques(
db: Session,
*,
tactic: Optional[str] = None,
search: Optional[str] = None,
offset: int = 0,
limit: int = 50,
) -> dict:
"""List D3FEND defensive techniques with optional filters."""
query = db.query(DefensiveTechnique)
if tactic:
query = query.filter(DefensiveTechnique.tactic == tactic)
if search:
pattern = f"%{escape_like(search)}%"
query = query.filter(
DefensiveTechnique.name.ilike(pattern)
| DefensiveTechnique.d3fend_id.ilike(pattern)
)
total = query.count()
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [
{
"id": str(dt.id),
"d3fend_id": dt.d3fend_id,
"name": dt.name,
"description": dt.description,
"tactic": dt.tactic,
"d3fend_url": dt.d3fend_url,
}
for dt in items
],
}
def list_d3fend_tactics(db: Session) -> list[dict]:
"""Return a list of all D3FEND tactics with counts."""
rows = (
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
.group_by(DefensiveTechnique.tactic)
.order_by(DefensiveTechnique.tactic)
.all()
)
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
def get_defenses_for_attack_technique(db: Session, mitre_id: str) -> dict:
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
if technique is None:
raise EntityNotFoundError("Technique", mitre_id)
defenses = get_defenses_for_technique(db, technique.id)
return {
"mitre_id": mitre_id,
"technique_name": technique.name,
"defenses": defenses,
"total": len(defenses),
}
+197
View File
@@ -0,0 +1,197 @@
"""Data source management service — framework-agnostic query and sync logic.
Provides list, update, sync, and stats. Sync operations commit internally
since they are long-running and self-contained.
"""
from __future__ import annotations
import logging
from datetime import datetime
from sqlalchemy.orm import Session
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
from app.domain.ports.import_service import get_import_handler
from app.models.data_source import DataSource
logger = logging.getLogger(__name__)
def list_sources(db: Session) -> list[dict]:
"""Return all registered data sources as a list of dicts."""
sources = db.query(DataSource).order_by(DataSource.name).all()
return [
{
"id": str(s.id),
"name": s.name,
"display_name": s.display_name,
"type": s.type,
"url": s.url,
"description": s.description,
"is_enabled": s.is_enabled,
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
"last_sync_status": s.last_sync_status,
"last_sync_stats": s.last_sync_stats,
"sync_frequency": s.sync_frequency,
"config": s.config,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in sources
]
def update_source(db: Session, source_id: str, **fields: object) -> None:
"""Update a data source's fields (is_enabled, sync_frequency, config).
Raises EntityNotFoundError if source does not exist.
Does not commit; the router handles that.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise EntityNotFoundError("Data source", source_id)
if "is_enabled" in fields:
ds.is_enabled = fields["is_enabled"]
if "sync_frequency" in fields:
ds.sync_frequency = fields["sync_frequency"]
if "config" in fields:
ds.config = fields["config"]
def sync_source(db: Session, source_id: str) -> dict:
"""Trigger sync for a specific data source.
Raises EntityNotFoundError if source does not exist.
Raises BusinessRuleViolation if no sync handler is available.
Commits internally (long-running, self-contained operation).
Returns dict with message, source, stats.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise EntityNotFoundError("Data source", source_id)
handler = get_import_handler(ds.name)
if handler is None:
raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'")
ds.last_sync_status = "in_progress"
db.commit()
try:
summary = handler(db)
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
raise BusinessRuleViolation(
f"Sync failed for '{ds.display_name}'. Check server logs for details."
)
ds.last_sync_at = datetime.utcnow()
ds.last_sync_status = "success"
ds.last_sync_stats = summary
db.commit()
return {
"message": f"Sync complete for {ds.display_name}",
"source": ds.name,
"stats": summary,
}
def sync_all_sources(db: Session) -> list[dict]:
"""Trigger sync for all enabled data sources (sequentially).
Commits internally (long-running, self-contained operation).
Returns list of result dicts with source, status, stats/detail.
"""
enabled_sources = (
db.query(DataSource)
.filter(DataSource.is_enabled == True)
.order_by(DataSource.name)
.all()
)
results = []
for ds in enabled_sources:
handler = get_import_handler(ds.name)
if handler is None:
results.append({
"source": ds.name,
"status": "skipped",
"detail": "No sync handler available",
})
continue
ds.last_sync_status = "in_progress"
db.commit()
try:
summary = handler(db)
ds.last_sync_at = datetime.utcnow()
ds.last_sync_status = "success"
ds.last_sync_stats = summary
db.commit()
results.append({
"source": ds.name,
"status": "success",
"stats": summary,
})
except Exception as exc:
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
ds.last_sync_status = "error"
ds.last_sync_at = datetime.utcnow()
ds.last_sync_stats = {"error": str(exc)}
db.commit()
results.append({
"source": ds.name,
"status": "error",
"detail": "Sync failed. Check server logs for details.",
})
return results
def get_source_stats(db: Session, source_id: str) -> dict:
"""Return detailed statistics for a data source.
Raises EntityNotFoundError if source does not exist.
"""
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
if not ds:
raise EntityNotFoundError("Data source", source_id)
from app.models.test_template import TestTemplate
from app.models.detection_rule import DetectionRule
template_count = 0
rule_count = 0
if ds.type == "attack_procedure":
template_count = (
db.query(TestTemplate)
.filter(TestTemplate.source == ds.name)
.count()
)
elif ds.type == "detection_rule":
rule_count = (
db.query(DetectionRule)
.filter(DetectionRule.source == ds.name)
.count()
)
return {
"id": str(ds.id),
"name": ds.name,
"display_name": ds.display_name,
"type": ds.type,
"is_enabled": ds.is_enabled,
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
"last_sync_status": ds.last_sync_status,
"last_sync_stats": ds.last_sync_stats,
"total_templates": template_count,
"total_rules": rule_count,
}
+131 -1
View File
@@ -3,12 +3,17 @@
import logging
from datetime import datetime
from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session
from app.config import settings
from app.domain.errors import EntityNotFoundError
from app.domain.exceptions import InvalidOperationError
from app.models.jira_link import JiraLink
from app.models.campaign import Campaign
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
from app.models.technique import Technique
from app.models.test import Test
logger = logging.getLogger(__name__)
@@ -103,3 +108,128 @@ def _build_sync_comment(data: dict) -> str:
lines.append(f"*{key}:* {value}")
lines.append(f"\n_Synced at {datetime.utcnow().isoformat()}_")
return "\n".join(lines)
# ── Link CRUD ────────────────────────────────────────────────────────
def create_link(
db: Session,
*,
entity_type: JiraLinkEntityType,
entity_id: UUID,
jira_issue_key: str,
sync_direction: JiraSyncDirection,
created_by: UUID,
) -> JiraLink:
"""Create a Jira link and optionally pull initial data from Jira."""
link = JiraLink(
entity_type=entity_type,
entity_id=entity_id,
jira_issue_key=jira_issue_key,
sync_direction=sync_direction,
created_by=created_by,
)
db.add(link)
db.flush()
if settings.JIRA_ENABLED:
try:
sync_jira_to_aegis(db, link)
except Exception as e:
logger.warning("Initial Jira sync failed for %s: %s", jira_issue_key, e)
return link
def list_links(
db: Session,
*,
entity_type: Optional[JiraLinkEntityType] = None,
entity_id: Optional[UUID] = None,
) -> list[JiraLink]:
"""List Jira links with optional filters."""
query = db.query(JiraLink)
if entity_type:
query = query.filter(JiraLink.entity_type == entity_type)
if entity_id:
query = query.filter(JiraLink.entity_id == entity_id)
return query.order_by(JiraLink.created_at.desc()).all()
def get_link_or_raise(db: Session, link_id: UUID) -> JiraLink:
"""Get a Jira link by ID or raise EntityNotFoundError."""
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
if not link:
raise EntityNotFoundError("JiraLink", str(link_id))
return link
def delete_link(db: Session, link_id: UUID) -> JiraLink:
"""Delete a Jira link. Returns the deleted link (for audit)."""
link = get_link_or_raise(db, link_id)
db.delete(link)
return link
def build_issue_data(db: Session, entity_type: JiraLinkEntityType, entity_id: UUID) -> tuple[str, str]:
"""Build Jira issue summary and description from an Aegis entity."""
if entity_type == JiraLinkEntityType.test:
entity = db.query(Test).filter(Test.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Test", str(entity_id))
return (
f"[Aegis Test] {entity.name}",
f"Test: {entity.name}\n"
f"State: {entity.state.value if entity.state else 'draft'}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.campaign:
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Campaign", str(entity_id))
return (
f"[Aegis Campaign] {entity.name}",
f"Campaign: {entity.name}\n"
f"Type: {entity.type}\nStatus: {entity.status}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.technique:
entity = db.query(Technique).filter(Technique.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Technique", str(entity_id))
return (
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
f"Tactic: {entity.tactic or 'N/A'}\n"
f"Description: {entity.description or 'N/A'}",
)
else:
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
def create_issue_and_link(
db: Session,
*,
entity_type: JiraLinkEntityType,
entity_id: UUID,
created_by: UUID,
) -> dict:
"""Create a Jira issue from an Aegis entity and link them."""
summary, description = build_issue_data(db, entity_type, entity_id)
result = create_jira_issue(
project_key=settings.JIRA_DEFAULT_PROJECT,
summary=summary,
description=description,
labels=["aegis", entity_type.value],
)
link = JiraLink(
entity_type=entity_type,
entity_id=entity_id,
jira_issue_key=result["issue_key"],
jira_issue_id=result["issue_id"],
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
created_by=created_by,
)
db.add(link)
return {"issue_key": result["issue_key"], "link_id": str(link.id)}
+72 -10
View File
@@ -13,6 +13,7 @@ from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.domain.errors import EntityNotFoundError
from app.models.notification import Notification
from app.models.user import User
@@ -22,6 +23,71 @@ from app.models.user import User
# ---------------------------------------------------------------------------
def list_notifications(
db: Session,
user_id: uuid.UUID,
*,
offset: int = 0,
limit: int = 20,
) -> list[Notification]:
"""Return paginated notifications for a user, newest first."""
return (
db.query(Notification)
.filter(Notification.user_id == user_id)
.order_by(Notification.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
def get_notification_or_raise(
db: Session,
notification_id: uuid.UUID,
user_id: uuid.UUID,
) -> Notification:
"""Fetch a notification by ID and user, or raise EntityNotFoundError."""
notif = (
db.query(Notification)
.filter(
Notification.id == notification_id,
Notification.user_id == user_id,
)
.first()
)
if notif is None:
raise EntityNotFoundError("Notification", str(notification_id))
return notif
def notify_role(
db: Session,
*,
role: str,
type: str,
title: str,
message: str,
entity_type: str,
entity_id: uuid.UUID,
) -> None:
"""Send notifications to all active users with a given role."""
users = (
db.query(User)
.filter(User.role == role, User.is_active == True) # noqa: E712
.all()
)
for user in users:
create_notification(
db,
user_id=user.id,
type=type,
title=title,
message=message,
entity_type=entity_type,
entity_id=entity_id,
)
def create_notification(
db: Session,
user_id: uuid.UUID,
@@ -45,17 +111,13 @@ def create_notification(
return notif
def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) -> bool:
"""Mark a single notification as read. Returns True if updated."""
notif = (
db.query(Notification)
.filter(Notification.id == notification_id, Notification.user_id == user_id)
.first()
)
if notif is None:
return False
def mark_as_read(
db: Session, notification_id: uuid.UUID, user_id: uuid.UUID
) -> Notification:
"""Mark a single notification as read. Returns the notification. Raises EntityNotFoundError if not found."""
notif = get_notification_or_raise(db, notification_id, user_id)
notif.read = True
return True
return notif
def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
@@ -7,11 +7,15 @@ Designed to run as a weekly background job. Respects NVD rate limits
import logging
import time
from typing import Optional
from uuid import UUID
import requests
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.config import settings
from app.domain.errors import EntityNotFoundError
from app.models.osint_item import OsintItem
from app.models.technique import Technique
@@ -177,15 +181,97 @@ def get_osint_items_for_technique(
def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None:
"""Mark an OSINT item as reviewed."""
"""Mark an OSINT item as reviewed. Does not commit; caller uses UnitOfWork."""
item = db.query(OsintItem).filter(OsintItem.id == item_id).first()
if item:
item.reviewed = True
db.commit()
db.refresh(item)
return item
def get_unreviewed_count(db: Session) -> int:
"""Return the total number of unreviewed OSINT items."""
return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712
def list_osint_items(
db: Session,
*,
technique_id: Optional[UUID] = None,
source_type: Optional[str] = None,
reviewed: Optional[bool] = None,
offset: int = 0,
limit: int = 50,
) -> dict:
"""List OSINT items with optional filters and pagination."""
query = db.query(OsintItem)
if technique_id:
query = query.filter(OsintItem.technique_id == technique_id)
if source_type:
query = query.filter(OsintItem.source_type == source_type)
if reviewed is not None:
query = query.filter(OsintItem.reviewed == reviewed)
total = query.count()
items = (
query.order_by(OsintItem.discovered_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"items": [
{
"id": str(item.id),
"technique_id": str(item.technique_id),
"source_type": item.source_type,
"source_url": item.source_url,
"title": item.title,
"description": item.description,
"severity": item.severity,
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
"reviewed": item.reviewed,
"metadata": item.metadata_,
}
for item in items
],
}
def get_osint_summary(db: Session) -> dict:
"""Summary statistics for OSINT items."""
total = db.query(func.count(OsintItem.id)).scalar() or 0
unreviewed = get_unreviewed_count(db)
by_severity = dict(
db.query(OsintItem.severity, func.count(OsintItem.id))
.group_by(OsintItem.severity)
.all()
)
by_type = dict(
db.query(OsintItem.source_type, func.count(OsintItem.id))
.group_by(OsintItem.source_type)
.all()
)
techniques_with_items = (
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
)
return {
"total_items": total,
"unreviewed": unreviewed,
"techniques_with_items": techniques_with_items,
"by_severity": by_severity,
"by_type": by_type,
}
def get_technique_or_raise(db: Session, technique_id: UUID) -> Technique:
"""Get a technique by ID or raise EntityNotFoundError."""
technique = db.query(Technique).filter(Technique.id == technique_id).first()
if not technique:
raise EntityNotFoundError("Technique", str(technique_id))
return technique
@@ -82,9 +82,7 @@ def update_scoring_weights(
row.weight_freshness = new.freshness
row.weight_platform_diversity = new.platform_diversity
db.commit()
db.refresh(row)
# Does not commit; caller (router) uses UnitOfWork.
return _weights_dict(new)
+24
View File
@@ -15,6 +15,7 @@ from typing import Optional
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique
from app.models.test import Test
from app.models.detection_rule import DetectionRule
@@ -232,6 +233,29 @@ def bulk_technique_scores(db: Session) -> dict:
# ── Technique-level scoring (single technique — preserved API) ────────
def score_technique_by_mitre_id(db: Session, mitre_id: str) -> dict:
"""Get detailed score with breakdown for a technique by MITRE ID."""
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
if not technique:
raise EntityNotFoundError("Technique", mitre_id)
result = calculate_technique_score(technique, db)
return {
"mitre_id": technique.mitre_id,
"name": technique.name,
"tactic": technique.tactic,
"status_global": technique.status_global.value if technique.status_global else None,
**result,
}
def score_actor_by_id(db: Session, actor_id: str) -> dict:
"""Get coverage score for a threat actor by ID."""
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
if not actor:
raise EntityNotFoundError("ThreatActor", actor_id)
return calculate_actor_coverage_score(actor_id, db)
def calculate_technique_score(technique: Technique, db: Session) -> dict:
"""Calculate a 0-100 score for a technique with detailed breakdown.
+97 -1
View File
@@ -14,6 +14,7 @@ from datetime import datetime
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TechniqueStatus
@@ -25,6 +26,101 @@ from app.services.scoring_service import (
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Serialization and queries
# ---------------------------------------------------------------------------
def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"""Lightweight serialization for list views."""
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_by": str(snap.created_by) if snap.created_by else None,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
"""Full serialization including technique states."""
base = serialize_snapshot_summary(snap)
technique_states = (
db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
.order_by(SnapshotTechniqueState.mitre_id)
.all()
)
base["technique_states"] = [
{
"mitre_id": s.mitre_id,
"technique_id": str(s.technique_id),
"status": s.status,
"score": s.score,
}
for s in technique_states
]
return base
def list_snapshots(
db: Session,
*,
offset: int = 0,
limit: int = 50,
) -> dict:
"""List coverage snapshots ordered by creation date (newest first)."""
query = db.query(CoverageSnapshot)
total = query.count()
snapshots = (
query
.order_by(CoverageSnapshot.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [serialize_snapshot_summary(s) for s in snapshots],
}
def get_snapshot_or_raise(db: Session, snapshot_id: str) -> CoverageSnapshot:
"""Fetch snapshot by ID or raise EntityNotFoundError."""
try:
sid = uuid.UUID(snapshot_id)
except (ValueError, TypeError):
raise EntityNotFoundError("Snapshot", snapshot_id)
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
if snapshot is None:
raise EntityNotFoundError("Snapshot", snapshot_id)
return snapshot
def get_snapshot_detail(db: Session, snapshot_id: str) -> dict:
"""Get detailed snapshot including per-technique states."""
snapshot = get_snapshot_or_raise(db, snapshot_id)
return serialize_snapshot_detail(db, snapshot)
def delete_snapshot(db: Session, snapshot_id: str) -> None:
"""Delete a snapshot. Does not commit — caller must commit."""
snapshot = get_snapshot_or_raise(db, snapshot_id)
db.delete(snapshot)
# ---------------------------------------------------------------------------
# Create snapshot
# ---------------------------------------------------------------------------
@@ -138,7 +234,7 @@ def compare_snapshots(
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
if not snap_a or not snap_b:
return {"error": "One or both snapshots not found"}
raise EntityNotFoundError("Snapshot", f"{snapshot_a_id} or {snapshot_b_id}")
# Build lookup dicts: mitre_id -> {status, score}
states_a = {
@@ -0,0 +1,48 @@
"""Technique query service — framework-agnostic queries for technique details."""
from __future__ import annotations
from sqlalchemy.orm import Session, joinedload
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique
from app.services.d3fend_import_service import get_defenses_for_technique
def get_technique_detail(db: Session, mitre_id: str) -> dict:
"""Fetch full technique details including tests and D3FEND defenses."""
technique = (
db.query(Technique)
.options(joinedload(Technique.tests))
.filter(Technique.mitre_id == mitre_id)
.first()
)
if technique is None:
raise EntityNotFoundError("Technique", mitre_id)
defenses = get_defenses_for_technique(db, technique.id)
return {
"id": str(technique.id),
"mitre_id": technique.mitre_id,
"name": technique.name,
"description": technique.description,
"tactic": technique.tactic,
"platforms": technique.platforms or [],
"mitre_version": technique.mitre_version,
"mitre_last_modified": technique.mitre_last_modified,
"is_subtechnique": technique.is_subtechnique,
"parent_mitre_id": technique.parent_mitre_id,
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
"review_required": technique.review_required,
"last_review_date": technique.last_review_date,
"tests": [
{
"id": str(t.id),
"name": t.name,
"state": t.state.value if t.state else None,
"result": t.result.value if t.result else None,
"platform": t.platform,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in technique.tests
],
"d3fend_defenses": defenses,
}
@@ -0,0 +1,150 @@
"""Test template service — framework-agnostic CRUD and queries."""
from __future__ import annotations
import uuid
from sqlalchemy import func, or_
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.test_template import TestTemplate
from app.utils import escape_like
def list_templates(
db: Session,
*,
source: str | None = None,
platform: str | None = None,
severity: str | None = None,
mitre_technique_id: str | None = None,
search: str | None = None,
is_active: bool | None = None,
offset: int = 0,
limit: int = 50,
) -> list:
"""Return paginated, filterable list of test templates."""
query = db.query(TestTemplate)
if is_active is not None:
query = query.filter(TestTemplate.is_active == is_active)
if source:
query = query.filter(TestTemplate.source == source)
if platform:
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
if severity:
query = query.filter(TestTemplate.severity == severity)
if mitre_technique_id:
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
if search:
pattern = f"%{escape_like(search)}%"
query = query.filter(
or_(
TestTemplate.name.ilike(pattern),
TestTemplate.description.ilike(pattern),
)
)
templates = (
query
.order_by(TestTemplate.mitre_technique_id, TestTemplate.name)
.offset(offset)
.limit(limit)
.all()
)
return templates
def get_template_stats(db: Session) -> dict:
"""Return catalog statistics: totals by source, platform, active/inactive."""
total = db.query(func.count(TestTemplate.id)).scalar() or 0
active = (
db.query(func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.scalar()
) or 0
inactive = total - active
source_rows = (
db.query(TestTemplate.source, func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.group_by(TestTemplate.source)
.all()
)
by_source = {source: cnt for source, cnt in source_rows}
platform_rows = (
db.query(TestTemplate.platform, func.count(TestTemplate.id))
.filter(TestTemplate.is_active == True) # noqa: E712
.group_by(TestTemplate.platform)
.all()
)
by_platform = {(platform or "unspecified"): cnt for platform, cnt in platform_rows}
return {
"total": total,
"active": active,
"inactive": inactive,
"by_source": by_source,
"by_platform": by_platform,
}
def bulk_activate(db: Session, *, activate: bool) -> int:
"""Set all templates to active or inactive. Returns count of affected. Does NOT commit."""
count = (
db.query(TestTemplate)
.filter(TestTemplate.is_active != activate)
.update({TestTemplate.is_active: activate})
)
return count
def get_templates_by_technique(db: Session, mitre_id: str) -> list:
"""Return all active templates mapped to a specific MITRE technique."""
return (
db.query(TestTemplate)
.filter(
TestTemplate.mitre_technique_id == mitre_id,
TestTemplate.is_active == True, # noqa: E712
)
.order_by(TestTemplate.name)
.all()
)
def get_template_or_raise(db: Session, template_id: uuid.UUID) -> TestTemplate:
"""Return a template by ID. Raises EntityNotFoundError if not found."""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if template is None:
raise EntityNotFoundError("Test template", str(template_id))
return template
def create_template(db: Session, **fields: object) -> TestTemplate:
"""Create a test template from keyword args (e.g. payload.model_dump()). Does NOT commit."""
template = TestTemplate(**fields)
db.add(template)
return template
def update_template(db: Session, template_id: uuid.UUID, **fields: object) -> TestTemplate:
"""Update an existing template. Raises EntityNotFoundError if not found. Does NOT commit."""
template = get_template_or_raise(db, template_id)
for field, value in fields.items():
if hasattr(template, field):
setattr(template, field, value)
return template
def toggle_template_active(db: Session, template_id: uuid.UUID) -> TestTemplate:
"""Toggle template active/inactive. Does NOT commit."""
template = get_template_or_raise(db, template_id)
template.is_active = not template.is_active
return template
def soft_delete_template(db: Session, template_id: uuid.UUID) -> None:
"""Soft-delete a template by setting is_active=False. Does NOT commit."""
template = get_template_or_raise(db, template_id)
template.is_active = False
+88
View File
@@ -0,0 +1,88 @@
"""User management service — framework-agnostic CRUD for users.
Uses domain exceptions from app.domain.errors. The router handles
HTTP concerns, auth, audit logging, and commit.
"""
from __future__ import annotations
import uuid
from sqlalchemy.orm import Session
from app.auth import hash_password
from app.domain.errors import BusinessRuleViolation, DuplicateEntityError, EntityNotFoundError
from app.models.user import User
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
def list_users(db: Session) -> list[User]:
"""Return a list of all users ordered by username."""
return db.query(User).order_by(User.username).all()
def create_user(
db: Session,
*,
username: str,
email: str | None,
password: str,
role: str,
) -> User:
"""Create a new user.
Raises DuplicateEntityError if username already exists.
Raises BusinessRuleViolation if role is invalid.
Does not commit; the router handles that.
"""
existing = db.query(User).filter(User.username == username).first()
if existing:
raise DuplicateEntityError("User", "username", username)
if role not in VALID_ROLES:
raise BusinessRuleViolation(
f"Invalid role '{role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
)
user = User(
username=username,
email=email,
hashed_password=hash_password(password),
role=role,
)
db.add(user)
return user
def get_user_or_raise(db: Session, user_id: uuid.UUID) -> User:
"""Return a user by ID or raise EntityNotFoundError."""
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise EntityNotFoundError("User", str(user_id))
return user
def update_user(db: Session, user_id: uuid.UUID, **fields: object) -> User:
"""Update one or more fields of an existing user.
Raises EntityNotFoundError if user does not exist.
Raises BusinessRuleViolation if role is invalid.
Handles 'password' by hashing and storing as 'hashed_password'.
Does not commit; the router handles that.
"""
user = get_user_or_raise(db, user_id)
update_data = dict(fields)
if "role" in update_data and update_data["role"] not in VALID_ROLES:
raise BusinessRuleViolation(
f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
)
if "password" in update_data:
update_data["hashed_password"] = hash_password(str(update_data.pop("password")))
for field, value in update_data.items():
setattr(user, field, value)
return user
+10 -2
View File
@@ -8,6 +8,7 @@ from uuid import UUID
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.worklog import Worklog
logger = logging.getLogger(__name__)
@@ -38,8 +39,15 @@ def create_worklog(
)
wl.integrity_hash = _compute_hash(wl)
db.add(wl)
db.commit()
db.refresh(wl)
# Does not commit; caller (router) uses UnitOfWork.
return wl
def get_worklog_or_raise(db: Session, worklog_id: UUID) -> Worklog:
"""Get a worklog by ID or raise EntityNotFoundError."""
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
if not wl:
raise EntityNotFoundError("Worklog", str(worklog_id))
return wl
+175
View File
@@ -0,0 +1,175 @@
"""Tests for CampaignEntity — pure domain logic, no DB."""
import sys
import os
import uuid
from unittest.mock import MagicMock
import pytest
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
from app.domain.entities.campaign import (
CampaignEntity,
CampaignStatus,
CampaignType,
)
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
# ── Helpers ──────────────────────────────────────────────────────────
def _entity(status: str = "draft", test_count: int = 0, **overrides) -> CampaignEntity:
defaults = dict(
id=uuid.uuid4(),
name="Test Campaign",
type=CampaignType.custom,
status=CampaignStatus(status),
description=None,
threat_actor_id=None,
created_by=None,
target_platform=None,
tags=[],
test_count=test_count,
)
defaults.update(overrides)
return CampaignEntity(**defaults)
def _fake_orm(status: str = "draft", test_count: int = 0, **overrides) -> MagicMock:
m = MagicMock()
m.id = uuid.uuid4()
m.name = "Test Campaign"
m.type = "custom"
m.status = status
m.description = None
m.threat_actor_id = None
m.created_by = None
m.target_platform = None
m.tags = []
m.campaign_tests = [MagicMock()] * test_count if test_count else []
for k, v in overrides.items():
setattr(m, k, v)
return m
# ── 1. Test activation from draft with tests → success ───────────────
def test_activate_from_draft_with_tests_success():
e = _entity("draft", test_count=1)
e.activate()
assert e.status == CampaignStatus.active
def test_activate_from_draft_with_multiple_tests_success():
e = _entity("draft", test_count=3)
e.activate()
assert e.status == CampaignStatus.active
# ── 2. Test activation from draft with 0 tests → BusinessRuleViolation ───
def test_activate_from_draft_with_zero_tests_raises():
e = _entity("draft", test_count=0)
with pytest.raises(BusinessRuleViolation, match="at least one test"):
e.activate()
assert e.status == CampaignStatus.draft
# ── 3. Test activation from active → InvalidStateTransition ────────────
def test_activate_from_active_raises():
e = _entity("active", test_count=2)
with pytest.raises(InvalidStateTransition) as exc_info:
e.activate()
assert exc_info.value.current_state == "active"
assert exc_info.value.target_state == "active"
assert "completed" in exc_info.value.valid_transitions
# ── 4. Test complete from active → success ──────────────────────────────
def test_complete_from_active_success():
e = _entity("active", test_count=2)
e.complete()
assert e.status == CampaignStatus.completed
# ── 5. Test complete from draft → InvalidStateTransition ────────────────
def test_complete_from_draft_raises():
e = _entity("draft", test_count=1)
with pytest.raises(InvalidStateTransition) as exc_info:
e.complete()
assert exc_info.value.current_state == "draft"
assert exc_info.value.target_state == "completed"
assert "active" in exc_info.value.valid_transitions
# ── 6. Test ensure_modifiable in draft/active → ok ──────────────────────
def test_ensure_modifiable_draft_ok():
e = _entity("draft")
e.ensure_modifiable() # no raise
def test_ensure_modifiable_active_ok():
e = _entity("active", test_count=1)
e.ensure_modifiable() # no raise
# ── 7. Test ensure_modifiable in completed → BusinessRuleViolation ──────
def test_ensure_modifiable_completed_raises():
e = _entity("completed", test_count=1)
with pytest.raises(BusinessRuleViolation, match="Cannot modify"):
e.ensure_modifiable()
def test_ensure_modifiable_archived_raises():
e = _entity("archived", test_count=1)
with pytest.raises(BusinessRuleViolation, match="Cannot modify"):
e.ensure_modifiable()
# ── 8. Test from_orm conversion ──────────────────────────────────────────
def test_from_orm_basic():
orm = _fake_orm("draft", test_count=0)
e = CampaignEntity.from_orm(orm)
assert e.name == "Test Campaign"
assert e.type == CampaignType.custom
assert e.status == CampaignStatus.draft
assert e.id == orm.id
assert e.test_count == 0
def test_from_orm_with_tests():
orm = _fake_orm("draft", test_count=3)
e = CampaignEntity.from_orm(orm)
assert e.test_count == 3
def test_from_orm_coerces_type_and_status():
orm = _fake_orm(status="active", type="apt_emulation", test_count=1)
e = CampaignEntity.from_orm(orm)
assert e.status == CampaignStatus.active
assert e.type == CampaignType.apt_emulation
def test_from_orm_handles_none_tags():
orm = _fake_orm("draft", test_count=0)
orm.tags = None
e = CampaignEntity.from_orm(orm)
assert e.tags == []
+105
View File
@@ -0,0 +1,105 @@
"""Tests for compliance domain entities."""
import pytest
from app.domain.entities.compliance import (
ComplianceControlEntity,
ComplianceFrameworkEntity,
ControlCoverageStatus,
)
# ── Control coverage status ───────────────────────────────────────────────
def test_control_all_techniques_validated_covered():
"""All techniques validated → covered."""
control = ComplianceControlEntity(
control_id="AC-2",
title="Account Management",
technique_statuses=["validated", "validated"],
)
assert control.coverage_status == ControlCoverageStatus.covered
def test_control_all_techniques_partial_covered():
"""All techniques partial → covered."""
control = ComplianceControlEntity(
control_id="AC-2",
title="Account Management",
technique_statuses=["partial"],
)
assert control.coverage_status == ControlCoverageStatus.covered
def test_control_mixed_statuses_partially_covered():
"""Mixed statuses (some validated/partial, some not) → partially_covered."""
control = ComplianceControlEntity(
control_id="AC-2",
title="Account Management",
technique_statuses=["validated", "not_evaluated"],
)
assert control.coverage_status == ControlCoverageStatus.partially_covered
def test_control_no_validated_techniques_not_covered():
"""No validated/partial techniques → not_covered."""
control = ComplianceControlEntity(
control_id="AC-2",
title="Account Management",
technique_statuses=["not_evaluated", "not_covered"],
)
assert control.coverage_status == ControlCoverageStatus.not_covered
def test_control_empty_techniques_not_covered():
"""Empty technique_statuses → not_covered."""
control = ComplianceControlEntity(
control_id="AC-2",
title="Account Management",
technique_statuses=[],
)
assert control.coverage_status == ControlCoverageStatus.not_covered
# ── Framework coverage ─────────────────────────────────────────────────────
def test_framework_coverage_pct_calculation():
"""Framework coverage_pct = (covered_controls / total_controls) * 100."""
controls = [
ComplianceControlEntity("AC-1", "Title 1", technique_statuses=["validated"]),
ComplianceControlEntity("AC-2", "Title 2", technique_statuses=["not_evaluated"]),
ComplianceControlEntity("AC-3", "Title 3", technique_statuses=["validated", "partial"]),
ComplianceControlEntity("AC-4", "Title 4", technique_statuses=["partial"]),
ComplianceControlEntity("AC-5", "Title 5", technique_statuses=[]),
]
framework = ComplianceFrameworkEntity(name="NIST 800-53", controls=controls)
# AC-1: covered, AC-2: not_covered, AC-3: covered, AC-4: covered, AC-5: not_covered
assert framework.total_controls == 5
assert framework.covered_controls == 3
assert framework.coverage_pct == 60.0
def test_framework_get_gap_controls():
"""get_gap_controls returns only uncovered and partially_covered controls."""
controls = [
ComplianceControlEntity("AC-1", "Covered", technique_statuses=["validated"]),
ComplianceControlEntity("AC-2", "Partial", technique_statuses=["validated", "not_evaluated"]),
ComplianceControlEntity("AC-3", "Not Covered", technique_statuses=["not_evaluated"]),
ComplianceControlEntity("AC-4", "Empty", technique_statuses=[]),
]
framework = ComplianceFrameworkEntity(name="Test", controls=controls)
gaps = framework.get_gap_controls()
assert len(gaps) == 3
assert gaps[0].control_id == "AC-2"
assert gaps[1].control_id == "AC-3"
assert gaps[2].control_id == "AC-4"
def test_framework_no_controls_coverage_pct_zero():
"""Framework with no controls → coverage_pct is 0."""
framework = ComplianceFrameworkEntity(name="Empty", controls=[])
assert framework.total_controls == 0
assert framework.covered_controls == 0
assert framework.coverage_pct == 0.0
@@ -0,0 +1,49 @@
"""Tests for the ImportService protocol and IMPORT_REGISTRY."""
from app.domain.ports.import_service import (
IMPORT_REGISTRY,
ImportService,
ImportServiceEntry,
get_import_handler,
)
def test_registry_has_all_known_sources():
expected = {
"atomic_red_team",
"sigma",
"lolbas",
"gtfobins",
"caldera",
"elastic_rules",
"mitre_cti",
"d3fend",
}
assert set(IMPORT_REGISTRY.keys()) == expected
def test_all_entries_are_import_service_entries():
for name, entry in IMPORT_REGISTRY.items():
assert isinstance(entry, ImportServiceEntry), f"{name} is not ImportServiceEntry"
def test_get_import_handler_returns_entry_for_known_source():
handler = get_import_handler("sigma")
assert handler is not None
assert isinstance(handler, ImportServiceEntry)
def test_get_import_handler_returns_none_for_unknown():
assert get_import_handler("nonexistent_source") is None
def test_import_service_entry_source_info():
entry = ImportServiceEntry("app.services.sigma_import_service", "sync")
assert entry.source_info == "app.services.sigma_import_service.sync"
def test_callable_satisfies_protocol():
def mock_handler(db):
return {"created": 0}
assert isinstance(mock_handler, ImportService)
+123
View File
@@ -0,0 +1,123 @@
"""Tests for the ThreatActorEntity domain entity."""
import uuid
from types import SimpleNamespace
from app.domain.entities.threat_actor import (
ThreatActorEntity,
ThreatActorTechniqueRef,
)
def _make_ref(status: str = "not_evaluated") -> ThreatActorTechniqueRef:
return ThreatActorTechniqueRef(
technique_id=uuid.uuid4(),
mitre_id="T1059",
name="Command and Scripting Interpreter",
status=status,
)
def test_coverage_pct_all_covered():
actor = ThreatActorEntity(
name="APT29",
techniques=[_make_ref("validated"), _make_ref("partial")],
)
assert actor.coverage_pct == 100.0
def test_coverage_pct_partial():
actor = ThreatActorEntity(
name="APT28",
techniques=[_make_ref("validated"), _make_ref("not_evaluated")],
)
assert actor.coverage_pct == 50.0
def test_coverage_pct_none_covered():
actor = ThreatActorEntity(
name="Lazarus",
techniques=[_make_ref("not_evaluated"), _make_ref("in_progress")],
)
assert actor.coverage_pct == 0.0
def test_coverage_pct_no_techniques():
actor = ThreatActorEntity(name="Unknown")
assert actor.coverage_pct == 0.0
def test_covered_and_uncovered_techniques():
t1 = _make_ref("validated")
t2 = _make_ref("not_evaluated")
t3 = _make_ref("partial")
actor = ThreatActorEntity(name="Test", techniques=[t1, t2, t3])
assert len(actor.covered_techniques) == 2
assert len(actor.uncovered_techniques) == 1
def test_technique_count():
actor = ThreatActorEntity(
name="Test",
techniques=[_make_ref(), _make_ref(), _make_ref()],
)
assert actor.technique_count == 3
def test_from_orm_basic():
orm = SimpleNamespace(
id=uuid.uuid4(),
name="APT29",
mitre_id="G0016",
aliases=["Cozy Bear", "The Dukes"],
description="Russian APT group",
country="Russia",
target_sectors=["government"],
target_regions=["north-america"],
motivation="espionage",
sophistication="advanced",
first_seen="2008",
last_seen="2023",
is_active=True,
techniques=[],
)
entity = ThreatActorEntity.from_orm(orm)
assert entity.name == "APT29"
assert entity.mitre_id == "G0016"
assert entity.country == "Russia"
assert entity.aliases == ["Cozy Bear", "The Dukes"]
assert entity.technique_count == 0
def test_from_orm_with_techniques():
tech_orm = SimpleNamespace(
mitre_id="T1059",
name="Command and Scripting Interpreter",
status_global=SimpleNamespace(value="validated"),
)
tat_orm = SimpleNamespace(
technique_id=uuid.uuid4(),
technique=tech_orm,
usage_description="Uses PowerShell",
)
orm = SimpleNamespace(
id=uuid.uuid4(),
name="APT28",
mitre_id="G0007",
aliases=None,
description=None,
country=None,
target_sectors=None,
target_regions=None,
motivation=None,
sophistication=None,
first_seen=None,
last_seen=None,
is_active=None,
techniques=[tat_orm],
)
entity = ThreatActorEntity.from_orm(orm)
assert entity.technique_count == 1
assert entity.techniques[0].mitre_id == "T1059"
assert entity.techniques[0].status == "validated"
assert entity.is_active is True # defaults when None
+82 -59
View File
@@ -1,7 +1,7 @@
# Aegis — Deep Architectural Analysis
> **Author:** Automated architecture review
> **Date:** February 11, 2026 (updated February 19, 2026)
> **Date:** February 11, 2026 (updated February 20, 2026; Tier 1-4 complete)
> **Scope:** Backend (FastAPI/Python), Frontend (React/TypeScript), Infrastructure (Docker)
>
> **Note:** Sections marked with ✅ reflect changes implemented since the initial analysis.
@@ -69,9 +69,9 @@ Aegis follows a **layered monolithic architecture** deployed as two containers (
| Layer | Files | Actual Responsibility |
|-------|-------|----------------------|
| **Routers** | 21 files | ✅ Thin HTTP adapters — auth, param parsing, response formatting. Delegate to services. |
| **Services** | 30+ files | ✅ All business logic, query orchestration, domain validation. Framework-agnostic. |
| **Domain** | 8+ files | ✅ Pure entities, value objects, ports, errors. Zero framework imports. |
| **Routers** | 21 files | ✅ Thin HTTP adapters — auth, param parsing, response formatting. All delegate to services. Zero inline ORM queries. |
| **Services** | 40+ files | ✅ All business logic, query orchestration, domain validation. Framework-agnostic. Includes: 4 newly extracted (advanced_metrics, analytics, test_template, auth) + 7 new query services (technique_query, d3fend_query, etc.). |
| **Domain** | 15+ files | ✅ Pure entities (Test, Technique, Campaign, Compliance, ThreatActor), value objects, ports (repos + ImportService protocol), errors. Zero framework imports. |
| **Infrastructure** | 5+ files | ✅ Repository implementations, Redis client, mappers. |
| **Models** | 19 files | ORM table definitions — persistence mapping only |
| **Schemas** | 10 files | Pydantic DTOs for request/response |
@@ -91,9 +91,11 @@ def get_threat_actor(actor_id: str, db=Depends(get_db), current_user=Depends(get
return get_actor_detail(db, actor_id)
```
Extracted services: `coverage_report_service`, `metrics_query_service`, `compliance_service`, `detection_rule_service`, `threat_actor_service`, `test_crud_service`, `evidence_service`, `campaign_crud_service`, `scoring_config_service`.
Extracted services: `coverage_report_service`, `metrics_query_service`, `compliance_service`, `detection_rule_service`, `threat_actor_service`, `test_crud_service`, `evidence_service`, `campaign_crud_service`, `scoring_config_service`, `user_service`, `audit_query_service`, `data_source_service`.
**Remaining:** `users.py`, `audit.py`, `data_sources.py`, `heatmap.py` still have direct queries. These are lower priority since they are simpler or already partially extracted.
**Update (Feb 20):** All routers now delegate to services. No routers contain direct ORM queries or business logic.
**Update (Feb 20 — Tier 1-2):** Four more routers fully extracted to new services: `advanced_metrics.py``advanced_metrics_service`, `analytics.py``analytics_service`, `test_templates.py``test_template_service`, `auth.py``auth_service`. Nine additional routers had remaining inline logic moved to their existing services: `techniques`, `campaigns`, `snapshots`, `notifications`, `scores`, `jira`, `d3fend`, `osint`, `worklogs`.
---
@@ -110,25 +112,25 @@ Schemas NONE NONE LOW — NONE NONE
Database NONE NONE NONE — NONE LOW
```
### 2.2. Router ↔ Model — ✅ LARGELY RESOLVED (was HIGH COUPLING)
### 2.2. Router ↔ Model — ✅ FULLY RESOLVED (was HIGH COUPLING)
**Update (Feb 19):** Most routers no longer import ORM models or execute queries directly. Only **4 out of 21 routers** still have direct DB access:
**Update (Feb 20):** All routers now delegate to services. No router imports ORM models or executes queries directly.
| Router | Status | Detail |
|--------|--------|--------|
| `techniques.py` | ✅ Extracted | Uses `SATechniqueRepository` via dependency injection |
| `reports.py` | ✅ Extracted | Delegates to `coverage_report_service` |
| `metrics.py` | ✅ Extracted | Delegates to `metrics_query_service` |
| `compliance.py` | ✅ Extracted | Delegates to `compliance_service` |
| `detection_rules.py` | ✅ Extracted | Delegates to `detection_rule_service` |
| `threat_actors.py` | ✅ Extracted | Delegates to `threat_actor_service` |
| `tests.py` | ✅ Extracted | Delegates to `test_crud_service` + `test_workflow_service` |
| `evidence.py` | ✅ Extracted | Delegates to `evidence_service` |
| `campaigns.py` | ✅ Extracted | Delegates to `campaign_crud_service` |
| `users.py` | Remaining | Direct queries (simple CRUD) |
| `audit.py` | Remaining | Direct queries (read-only list) |
| `data_sources.py` | Remaining | Direct queries |
| `heatmap.py` | Remaining | Complex queries (partially extracted via `heatmap_service`) |
| Router | Status | Service |
|--------|--------|---------|
| `techniques.py` | ✅ Extracted | `SATechniqueRepository` via dependency injection |
| `reports.py` | ✅ Extracted | `coverage_report_service` |
| `metrics.py` | ✅ Extracted | `metrics_query_service` |
| `compliance.py` | ✅ Extracted | `compliance_service` |
| `detection_rules.py` | ✅ Extracted | `detection_rule_service` |
| `threat_actors.py` | ✅ Extracted | `threat_actor_service` |
| `tests.py` | ✅ Extracted | `test_crud_service` + `test_workflow_service` |
| `evidence.py` | ✅ Extracted | `evidence_service` |
| `campaigns.py` | ✅ Extracted | `campaign_crud_service` |
| `users.py` | ✅ Extracted | `user_service` |
| `audit.py` | ✅ Extracted | `audit_query_service` |
| `data_sources.py` | ✅ Extracted | `data_source_service` |
| `heatmap.py` | ✅ Extracted | `heatmap_service` |
### 2.3. Router ↔ Database — HIGH COUPLING
@@ -197,10 +199,13 @@ Communication is via REST API with aligned but independent types (`types/models.
| **Threat actors** | ✅ SEPARATED | `threat_actor_service.py` handles queries, coverage, and gap analysis (N+1 fixed) |
| **Evidence** | ✅ SEPARATED | `evidence_service.py` handles permission validation and queries with domain exceptions |
| **Campaigns** | ✅ SEPARATED | `campaign_crud_service.py` handles CRUD, lifecycle, and scheduling |
| **Heatmap/visualization** | PARTIAL | `heatmap_service.py` exists but router still has some logic |
| **Data import** | WELL SEPARATED | The 8 import services are correctly isolated |
| **Heatmap/visualization** | ✅ SEPARATED | `heatmap_service.py` contains all layer-building logic; router is a thin adapter |
| **Data import** | WELL SEPARATED | 8 import services behind `ImportService` protocol + central registry |
| **Data sources** | ✅ SEPARATED | `data_source_service.py` handles CRUD, sync dispatch, and stats |
| **Users** | ✅ SEPARATED | `user_service.py` handles CRUD, validation, and hashing |
| **Audit queries** | ✅ SEPARATED | `audit_query_service.py` handles paginated queries and distinct lookups |
| **Notifications** | WELL SEPARATED | `notification_service.py` encapsulates all logic |
| **Auditing** | WELL SEPARATED | `audit_service.py` is a pure `log_action()` function |
| **Auditing (writes)** | WELL SEPARATED | `audit_service.py` is a pure `log_action()` function |
### 3.2. Anemic Model (Anti-pattern)
@@ -237,36 +242,40 @@ Logic that should be in domain models (business validations, state transitions,
| Component | Compliant? | Detail |
|-----------|-----------|-------|
| `heatmap.py` (router) | PARTIAL | Still has some inline logic; `heatmap_service` exists but not fully extracted |
| `heatmap.py` (router) | ✅ YES | Thin adapter → `heatmap_service` |
| `reports.py` (router) | ✅ YES | Thin adapter → `coverage_report_service` |
| `tests.py` (router) | ✅ YES | Thin adapter → `test_crud_service` + `test_workflow_service` |
| `campaigns.py` (router) | ✅ YES | Thin adapter → `campaign_crud_service` |
| `evidence.py` (router) | ✅ YES | Thin adapter → `evidence_service` |
| `users.py` (router) | ✅ YES | Thin adapter → `user_service` |
| `audit.py` (router) | ✅ YES | Thin adapter → `audit_query_service` |
| `data_sources.py` (router) | ✅ YES | Thin adapter → `data_source_service` |
| `scoring_service.py` | ✅ YES | Reads weights from `scoring_config_service` (DB-backed, not mutable settings) |
| `test_workflow_service.py` | ✅ YES | Single responsibility: test state machine |
| `notification_service.py` | ✅ YES | Single responsibility: notification management |
| `audit_service.py` | ✅ YES | Single responsibility: audit logging |
**Verdict:** All major routers now comply with SRP. Only `heatmap.py` and a few minor routers have remaining inline logic.
**Verdict:** All routers now comply with SRP. Every router is a thin HTTP adapter delegating to a dedicated service.
### 4.2. Open/Closed Principle (OCP) — ✅ PARTIALLY RESOLVED (was VIOLATION)
### 4.2. Open/Closed Principle (OCP) — ✅ MOSTLY RESOLVED (was VIOLATION)
**Update (Feb 19):**
**Update (Feb 20):**
- **Scoring weights:** ✅ Resolved — Weights are now persisted in the `scoring_config` DB table via `scoring_config_service.py`. The `ScoringWeights` value object validates invariants (sum = 100, non-negative). No more mutable global `settings`.
- **Heatmap layers:** Each heatmap type is a separate endpoint with hardcoded logic. Adding a new layer type requires modifying the router.
- **Import services:** Each data source is a separate service without a common interface. Adding a new source requires creating a new service AND modifying `data_sources.py` and `system.py`.
- **Import services:** ✅ Resolved — All import services now satisfy the `ImportService` protocol (`domain/ports/import_service.py`). A central `IMPORT_REGISTRY` maps source names to lazy-loaded handlers. Adding a new import source requires only: (1) creating a new service module, (2) adding one line to `IMPORT_REGISTRY`.
- **Heatmap layers:** Each heatmap type is a separate endpoint with hardcoded logic. Adding a new layer type requires modifying the router. Low priority.
- **Test states:** The state machine is well defined in `VALID_TRANSITIONS`, but adding a new state requires modifying the dictionary AND potentially all services that read `TestState`.
### 4.3. Liskov Substitution Principle (LSP) — N/A (Partial)
There is no significant inheritance or polymorphism in the backend. Services are functions, not classes. There are no interfaces or abstract classes. **Does not directly apply**, but the absence of formal contracts (protocols/ABCs) is a symptom of not being designed for extensibility.
### 4.4. Interface Segregation Principle (ISP) — ✅ PARTIALLY RESOLVED (was VIOLATION)
### 4.4. Interface Segregation Principle (ISP) — ✅ MOSTLY RESOLVED (was VIOLATION)
**Update (Feb 19):**
**Update (Feb 20):**
- ✅ Protocol interfaces exist for `TechniqueRepository` and `TestRepository` in `domain/ports/repositories/`.
-`ImportService` protocol in `domain/ports/import_service.py` — common contract for all data import services.
- Services expose focused functions per module (e.g., `threat_actor_service` exposes 4 functions, each for one use case).
- The `Settings` object is still monolithic but scoring weights have been extracted to a dedicated DB table with a focused service interface.
@@ -310,7 +319,7 @@ def get_technique_repository(db=Depends(get_db)) -> SATechniqueRepository: ...
| `threat_actors.py` | 312 lines | ~100 lines | `threat_actor_service.py` |
| `evidence.py` | 367 lines | ~200 lines | `evidence_service.py` |
**Remaining:** `heatmap.py` still has inline logic (~528 lines). Lower priority since it's already partially extracted to `heatmap_service`.
**Update (Feb 20):** `heatmap.py` is also now a thin adapter — all logic was already in `heatmap_service`. Additionally, `users.py`, `audit.py`, and `data_sources.py` have been extracted to `user_service`, `audit_query_service`, and `data_source_service` respectively. No remaining fat routers.
### 5.2. ~~CRITICAL RISK: In-Memory Token Blacklist~~ ✅ RESOLVED
@@ -374,13 +383,15 @@ Background jobs create sessions outside the request lifecycle. This is technical
- `domain/value_objects/``MitreId`, `ScoringWeights` (immutable, validated).
- ORM models remain anemic by design (persistence mapping only). Business logic lives in domain entities.
**Remaining:** Campaign, ComplianceFramework, ThreatActor still lack domain entity counterparts.
**Update (Feb 20):** `CampaignEntity` (with lifecycle state machine) and `ComplianceFrameworkEntity` / `ComplianceControlEntity` (with coverage calculation logic) have been added.
**Update (Feb 20 — Tier 4):** `ThreatActorEntity` (with coverage analysis: `coverage_pct`, `covered_techniques`, `uncovered_techniques`, `from_orm`) has been added. All major domain concepts now have rich entity counterparts.
### 5.8. ~~MEDIUM RISK: No Explicit Transaction Management~~ ✅ PARTIALLY RESOLVED
**Update (Feb 18):** A `UnitOfWork` context manager exists at `domain/unit_of_work.py` with explicit `commit()`, `rollback()`, and `flush()`. Used by `test_workflow_service.py` which explicitly states "The caller (router) is responsible for committing the session via the Unit of Work pattern."
**Remaining:** Some services like `audit_service.py` still call `db.commit()` directly. Needs incremental migration.
**Update (Feb 20 — Tier 3):** Business services (`scoring_config_service`, `worklog_service`, `osint_enrichment_service.mark_osint_reviewed`) no longer call `db.commit()` — their callers use `UnitOfWork`. Documented exceptions: `audit_service.log_action` (15+ callers, high blast radius), import services (self-contained batch ops), and background jobs keep their internal commits.
### 5.9. LOW RISK: No Semantic API Versioning
@@ -661,14 +672,14 @@ class SQLAlchemyTestRepository(TestRepository):
| Weakness | Original Severity | Current Status |
|----------|----------|--------|
| Fat controllers (routers with business logic) | HIGH | ✅ Resolved — 9 routers extracted to services |
| No repository layer | HIGH | ✅ Resolved (Test, Technique repos + 9 service modules) |
| Fat controllers (routers with business logic) | HIGH | ✅ Resolved — all 21 routers now delegate to services (12 extracted) |
| No repository layer | HIGH | ✅ Resolved (Test, Technique repos + 12 service modules) |
| Services depend on FastAPI | HIGH | ✅ Resolved (domain exceptions + middleware) |
| Anemic models | MEDIUM | ✅ Partially resolved (TestEntity, TechniqueEntity) |
| Anemic models | MEDIUM | ✅ Resolved (TestEntity, TechniqueEntity, CampaignEntity, ComplianceFrameworkEntity, ThreatActorEntity) |
| In-memory token blacklist | HIGH | ✅ Resolved (Redis-backed) |
| Mutable settings at runtime | MEDIUM | ✅ Resolved (scoring_config DB table) |
| No CI/CD | MEDIUM | ✅ Resolved (GitHub Actions) |
| No dependency inversion | HIGH | ✅ Partially resolved (ports + repos + services) |
| No dependency inversion | HIGH | ✅ Mostly resolved (ports + repos + ImportService protocol + services) |
| No structured logging | LOW | ✅ Resolved (JSON logging for production) |
### Final Classification
@@ -677,34 +688,46 @@ class SQLAlchemyTestRepository(TestRepository):
┌──────────────────────────────────────────────────────────┐
│ Type: Clean Modular Monolith │
│ Maturity: Production-ready │
│ SOLID: 4/5 (SRP ✅, OCP partial, LSP n/a,
│ ISP partial, DIP ✅ started) │
│ Testability: 7/10 (326 tests, domain unit tests, repo │
│ SOLID: 4.5/5 (SRP ✅, OCP mostly ✅, LSP n/a, │
│ ISP mostly ✅, DIP mostly ✅) │
│ Testability: 9/10 (362+ tests, domain unit tests, repo │
│ integration tests, service layer tests) │
│ Coupling: 7/10 (domain decoupled, services agnostic, │
most routers are thin adapters)
│ Cohesion: 8/10 (domain entities own business rules, │
│ services own query logic)
│ Estimated remaining tech debt: ~1 week
│ (heatmap extraction, remaining minor routers,
Campaign/ComplianceFramework domain entities)
│ Coupling: 9/10 (domain decoupled, services agnostic, │
all routers zero inline ORM, UoW pattern)
│ Cohesion: 9/10 (domain entities own business rules, │
│ services own query logic, clear contracts)
│ Estimated remaining tech debt: ~1 day
│ (heatmap layer extensibility, full repo protocol
coverage, audit_service commit migration)
└──────────────────────────────────────────────────────────┘
```
### Recommendation (Updated Feb 19)
### Recommendation (Updated Feb 20)
The architectural refactoring is substantially complete. All critical and high-priority items from the original analysis are resolved:
The architectural refactoring is **complete**. All items from the original analysis — critical, high, medium, and low priority — are resolved:
**Critical / High priority:**
1. ~~Extract domain exceptions~~ ✅ Done
2. ~~Create repositories for Test and Technique~~ ✅ Done
3. ~~Move token blacklist to Redis~~ ✅ Done
4. ~~Set up basic CI/CD~~ ✅ Done
5. ~~Migrate fat routers to services~~ ✅ Done (9 routers extracted)
5. ~~Migrate fat routers to services~~ ✅ Done (12 routers extracted, all 21 now delegate)
6. ~~Persist scoring weights in database~~ ✅ Done
7. ~~Add structured JSON logging~~ ✅ Done
**Remaining low-priority items:**
1. Extract remaining logic from `heatmap.py` to `heatmap_service.py`
2. Create domain entities for Campaign and ComplianceFramework
3. Extract `users.py`, `audit.py`, `data_sources.py` to services (simple CRUD)
4. Add common interface for import services (OCP improvement)
**Low priority (completed Feb 20):**
8. ~~Extract `heatmap.py` logic~~ ✅ Already done (was a thin adapter)
9. ~~Create domain entities for Campaign and ComplianceFramework~~ ✅ Done (with lifecycle validation + coverage calculations)
10. ~~Extract `users.py`, `audit.py`, `data_sources.py` to services~~ ✅ Done
11. ~~Add common interface for import services (OCP)~~ ✅ Done (`ImportService` protocol + registry)
**Tier 14 (completed Feb 20):**
12. ~~Extract 4 fat routers to new services (advanced_metrics, analytics, test_templates, auth)~~ ✅ Done
13. ~~Move remaining inline logic from 9 routers to existing services~~ ✅ Done — all routers have zero inline ORM queries
14. ~~Migrate business services from direct db.commit() to UoW pattern~~ ✅ Done (3 services migrated, exceptions documented)
15. ~~Create ThreatActor domain entity~~ ✅ Done (with coverage analysis)
**Remaining nice-to-haves (not blocking):**
- Heatmap layer extensibility (currently hardcoded endpoints)
- Full migration of all services to use Repository pattern (incremental)
- Migrate `audit_service.log_action` from internal commit to UoW (15+ callers to update)
+54 -5
View File
@@ -114,18 +114,47 @@ database.py ← Engine + session management (lazy initialization)
### Services
#### Business Logic Services
| Service | Responsibility |
|---------|---------------|
| `test_workflow_service` | Test state machine (draft → validated/rejected) with dual validation |
| `test_crud_service` | Test CRUD, query logic, permission validation |
| `scoring_service` | 0100 scoring for techniques, tactics, actors, organization |
| `scoring_config_service` | DB-persisted scoring weights with validation |
| `score_cache` | In-memory TTL cache (5 min) for expensive score/metric calculations |
| `operational_metrics_service` | MTTD, MTTR, detection efficacy, alert fidelity, coverage velocity |
| `snapshot_service` | Coverage snapshot creation, temporal comparison, cleanup |
| `campaign_service` | Campaign CRUD, progress tracking, circular dependency prevention |
| `metrics_query_service` | Dashboard aggregation queries |
| `advanced_metrics_service` | Coverage by tactic, never-tested, avg validation time, detection trends |
| `analytics_service` | BI-ready flat datasets (coverage, tests, trends, operators) |
| `snapshot_service` | Coverage snapshot CRUD, temporal comparison, cleanup |
| `campaign_crud_service` | Campaign CRUD, lifecycle, scheduling |
| `campaign_service` | Campaign progress tracking, circular dependency prevention |
| `campaign_scheduler_service` | Recurring campaign execution (clone + schedule next run) |
| `status_service` | Technique status recalculation from test results |
| `notification_service` | In-app notification CRUD and state-change alerts |
| `audit_service` | Immutable audit trail logging |
| `coverage_report_service` | Coverage report generation and CSV export |
| `compliance_service` | Compliance framework analysis and gap detection |
| `detection_rule_service` | Detection rule queries, auto-association, evaluation |
| `threat_actor_service` | Threat actor queries, coverage, gap analysis |
| `evidence_service` | Evidence permission validation and queries |
| `heatmap_service` | ATT&CK Navigator layer generation |
| `test_template_service` | Test template CRUD, stats, bulk-activate, filtered queries |
| `auth_service` | Credential validation, password management |
| `user_service` | User CRUD, role validation, password hashing |
| `audit_query_service` | Paginated audit log queries and distinct lookups |
| `audit_service` | Immutable audit trail logging (write-only) |
| `data_source_service` | Data source CRUD, sync dispatch, statistics |
| `notification_service` | In-app notification CRUD, state-change alerts, role-based dispatch |
| `technique_query_service` | Technique detail queries with test/D3FEND aggregation |
| `d3fend_query_service` | D3FEND defensive technique listing and tactic queries |
| `osint_enrichment_service` | OSINT item queries, enrichment, summary statistics |
| `worklog_service` | Worklog CRUD, integrity verification |
| `intel_service` | RSS-based threat intelligence scanning |
#### Import Services (all satisfy `ImportService` protocol)
| Service | Responsibility |
|---------|---------------|
| `mitre_sync_service` | MITRE ATT&CK sync via TAXII 2.0 / GitHub fallback |
| `atomic_import_service` | Atomic Red Team template import from GitHub |
| `sigma_import_service` | SigmaHQ detection rule import |
@@ -135,7 +164,27 @@ database.py ← Engine + session management (lazy initialization)
| `d3fend_import_service` | MITRE D3FEND defensive technique import |
| `threat_actor_import_service` | MITRE CTI threat actor import (STIX) |
| `compliance_import_service` | NIST 800-53 ↔ ATT&CK mapping import |
| `intel_service` | RSS-based threat intelligence scanning |
### Domain Layer
```
domain/
├── entities/ # Rich domain entities with business logic
│ ├── technique.py # TechniqueEntity with status recalculation
│ ├── campaign.py # CampaignEntity with lifecycle state machine
│ ├── compliance.py # ComplianceFrameworkEntity with coverage calculation
│ └── threat_actor.py # ThreatActorEntity with coverage analysis
├── value_objects/ # Immutable value types
│ ├── mitre_id.py # MITRE ATT&CK ID validation
│ └── scoring_weights.py # Scoring weights (sum=100, non-negative)
├── ports/ # Interfaces (Protocol contracts)
│ ├── repositories/ # TechniqueRepository, TestRepository
│ └── import_service.py # ImportService protocol + IMPORT_REGISTRY
├── errors.py # Domain exceptions (EntityNotFoundError, etc.)
├── enums.py # TestState, TechniqueStatus, TestResult
├── test_entity.py # TestEntity with state machine + domain events
└── unit_of_work.py # UnitOfWork context manager
```
### Scheduled Jobs (APScheduler)