Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0c526c48f9 | |||
| 0d211d5156 | |||
| 14d995b40c | |||
| 339d669498 | |||
| 9e22fde746 | |||
| bbc2dddd86 | |||
| d77075272e | |||
| c0c6cda11d |
@@ -1,3 +1,18 @@
|
||||
from app.domain.entities.campaign import CampaignEntity
|
||||
from app.domain.entities.compliance import (
|
||||
ComplianceControlEntity,
|
||||
ComplianceFrameworkEntity,
|
||||
ControlCoverageStatus,
|
||||
)
|
||||
from app.domain.entities.technique import TechniqueEntity
|
||||
from app.domain.entities.threat_actor import ThreatActorEntity, ThreatActorTechniqueRef
|
||||
|
||||
__all__ = ["TechniqueEntity"]
|
||||
__all__ = [
|
||||
"CampaignEntity",
|
||||
"ComplianceControlEntity",
|
||||
"ComplianceFrameworkEntity",
|
||||
"ControlCoverageStatus",
|
||||
"TechniqueEntity",
|
||||
"ThreatActorEntity",
|
||||
"ThreatActorTechniqueRef",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
"""Campaign domain entity with lifecycle validation.
|
||||
|
||||
Pure domain logic — no framework imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
|
||||
|
||||
|
||||
class CampaignStatus(str, enum.Enum):
|
||||
draft = "draft"
|
||||
active = "active"
|
||||
completed = "completed"
|
||||
archived = "archived"
|
||||
|
||||
|
||||
class CampaignType(str, enum.Enum):
|
||||
custom = "custom"
|
||||
apt_emulation = "apt_emulation"
|
||||
kill_chain = "kill_chain"
|
||||
compliance = "compliance"
|
||||
|
||||
|
||||
VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
|
||||
CampaignStatus.draft: [CampaignStatus.active],
|
||||
CampaignStatus.active: [CampaignStatus.completed],
|
||||
CampaignStatus.completed: [CampaignStatus.archived],
|
||||
CampaignStatus.archived: [],
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class CampaignEntity:
|
||||
name: str
|
||||
type: CampaignType = CampaignType.custom
|
||||
status: CampaignStatus = CampaignStatus.draft
|
||||
id: uuid.UUID | None = None
|
||||
description: str | None = None
|
||||
threat_actor_id: uuid.UUID | None = None
|
||||
created_by: uuid.UUID | None = None
|
||||
target_platform: str | None = None
|
||||
tags: list[str] = field(default_factory=list)
|
||||
test_count: int = 0
|
||||
|
||||
def can_transition_to(self, target: CampaignStatus) -> bool:
|
||||
return target in VALID_TRANSITIONS.get(self.status, [])
|
||||
|
||||
def activate(self) -> None:
|
||||
if not self.can_transition_to(CampaignStatus.active):
|
||||
raise InvalidStateTransition(
|
||||
self.status.value, CampaignStatus.active.value,
|
||||
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||
)
|
||||
if self.test_count == 0:
|
||||
raise BusinessRuleViolation(
|
||||
"Campaign must have at least one test to activate"
|
||||
)
|
||||
self.status = CampaignStatus.active
|
||||
|
||||
def complete(self) -> None:
|
||||
if not self.can_transition_to(CampaignStatus.completed):
|
||||
raise InvalidStateTransition(
|
||||
self.status.value, CampaignStatus.completed.value,
|
||||
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||
)
|
||||
self.status = CampaignStatus.completed
|
||||
|
||||
def archive(self) -> None:
|
||||
if not self.can_transition_to(CampaignStatus.archived):
|
||||
raise InvalidStateTransition(
|
||||
self.status.value, CampaignStatus.archived.value,
|
||||
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||
)
|
||||
self.status = CampaignStatus.archived
|
||||
|
||||
def ensure_modifiable(self) -> None:
|
||||
if self.status not in (CampaignStatus.draft, CampaignStatus.active):
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot modify campaign in '{self.status.value}' state"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls, orm: Any) -> CampaignEntity:
|
||||
"""Build a CampaignEntity from a SQLAlchemy Campaign model."""
|
||||
test_count = len(getattr(orm, "campaign_tests", None) or [])
|
||||
return cls(
|
||||
id=orm.id,
|
||||
name=orm.name,
|
||||
type=CampaignType(orm.type) if orm.type else CampaignType.custom,
|
||||
status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft,
|
||||
description=orm.description,
|
||||
threat_actor_id=orm.threat_actor_id,
|
||||
created_by=orm.created_by,
|
||||
target_platform=orm.target_platform,
|
||||
tags=orm.tags or [],
|
||||
test_count=test_count,
|
||||
)
|
||||
@@ -0,0 +1,71 @@
|
||||
"""Compliance domain entities with coverage calculation logic.
|
||||
|
||||
Pure domain logic — no framework imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
class ControlCoverageStatus(str, enum.Enum):
|
||||
covered = "covered"
|
||||
partially_covered = "partially_covered"
|
||||
not_covered = "not_covered"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComplianceControlEntity:
|
||||
control_id: str
|
||||
title: str
|
||||
id: uuid.UUID | None = None
|
||||
description: str | None = None
|
||||
category: str | None = None
|
||||
technique_statuses: list[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def coverage_status(self) -> ControlCoverageStatus:
|
||||
if not self.technique_statuses:
|
||||
return ControlCoverageStatus.not_covered
|
||||
covered_statuses = {"validated", "partial"}
|
||||
covered = [s for s in self.technique_statuses if s in covered_statuses]
|
||||
if len(covered) == len(self.technique_statuses):
|
||||
return ControlCoverageStatus.covered
|
||||
elif len(covered) > 0:
|
||||
return ControlCoverageStatus.partially_covered
|
||||
return ControlCoverageStatus.not_covered
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComplianceFrameworkEntity:
|
||||
name: str
|
||||
id: uuid.UUID | None = None
|
||||
version: str | None = None
|
||||
description: str | None = None
|
||||
is_active: bool = True
|
||||
controls: list[ComplianceControlEntity] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def total_controls(self) -> int:
|
||||
return len(self.controls)
|
||||
|
||||
@property
|
||||
def covered_controls(self) -> int:
|
||||
return sum(
|
||||
1 for c in self.controls
|
||||
if c.coverage_status == ControlCoverageStatus.covered
|
||||
)
|
||||
|
||||
@property
|
||||
def coverage_pct(self) -> float:
|
||||
if self.total_controls == 0:
|
||||
return 0.0
|
||||
return round(self.covered_controls / self.total_controls * 100, 1)
|
||||
|
||||
def get_gap_controls(self) -> list[ComplianceControlEntity]:
|
||||
return [
|
||||
c for c in self.controls
|
||||
if c.coverage_status != ControlCoverageStatus.covered
|
||||
]
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Threat actor domain entity with coverage analysis logic.
|
||||
|
||||
Pure domain logic — no framework imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThreatActorTechniqueRef:
|
||||
"""Lightweight reference to a technique used by an actor."""
|
||||
|
||||
technique_id: uuid.UUID
|
||||
mitre_id: str | None = None
|
||||
name: str | None = None
|
||||
status: str | None = None
|
||||
usage_description: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThreatActorEntity:
|
||||
name: str
|
||||
id: uuid.UUID | None = None
|
||||
mitre_id: str | None = None
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
description: str | None = None
|
||||
country: str | None = None
|
||||
target_sectors: list[str] = field(default_factory=list)
|
||||
target_regions: list[str] = field(default_factory=list)
|
||||
motivation: str | None = None
|
||||
sophistication: str | None = None
|
||||
first_seen: str | None = None
|
||||
last_seen: str | None = None
|
||||
is_active: bool = True
|
||||
techniques: list[ThreatActorTechniqueRef] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def technique_count(self) -> int:
|
||||
return len(self.techniques)
|
||||
|
||||
@property
|
||||
def covered_techniques(self) -> list[ThreatActorTechniqueRef]:
|
||||
return [
|
||||
t for t in self.techniques
|
||||
if t.status in ("validated", "partial")
|
||||
]
|
||||
|
||||
@property
|
||||
def uncovered_techniques(self) -> list[ThreatActorTechniqueRef]:
|
||||
return [
|
||||
t for t in self.techniques
|
||||
if t.status not in ("validated", "partial")
|
||||
]
|
||||
|
||||
@property
|
||||
def coverage_pct(self) -> float:
|
||||
if not self.techniques:
|
||||
return 0.0
|
||||
return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
|
||||
|
||||
@classmethod
|
||||
def from_orm(cls, orm: Any) -> ThreatActorEntity:
|
||||
techs: list[ThreatActorTechniqueRef] = []
|
||||
for tat in getattr(orm, "techniques", None) or []:
|
||||
technique = getattr(tat, "technique", None)
|
||||
techs.append(ThreatActorTechniqueRef(
|
||||
technique_id=tat.technique_id,
|
||||
mitre_id=getattr(technique, "mitre_id", None) if technique else None,
|
||||
name=getattr(technique, "name", None) if technique else None,
|
||||
status=(
|
||||
technique.status_global.value
|
||||
if technique and hasattr(technique.status_global, "value")
|
||||
else getattr(technique, "status_global", None) if technique else None
|
||||
),
|
||||
usage_description=tat.usage_description,
|
||||
))
|
||||
return cls(
|
||||
id=orm.id,
|
||||
name=orm.name,
|
||||
mitre_id=orm.mitre_id,
|
||||
aliases=orm.aliases or [],
|
||||
description=orm.description,
|
||||
country=orm.country,
|
||||
target_sectors=orm.target_sectors or [],
|
||||
target_regions=orm.target_regions or [],
|
||||
motivation=orm.motivation,
|
||||
sophistication=orm.sophistication,
|
||||
first_seen=orm.first_seen,
|
||||
last_seen=orm.last_seen,
|
||||
is_active=orm.is_active if orm.is_active is not None else True,
|
||||
techniques=techs,
|
||||
)
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Port defining the common interface for data import services.
|
||||
|
||||
All import services (Atomic Red Team, Sigma, CALDERA, etc.) follow the
|
||||
same contract: they receive a database session and return a summary dict
|
||||
with import statistics.
|
||||
|
||||
New import sources can be added by:
|
||||
1. Implementing the ``ImportService`` protocol in a new module
|
||||
2. Registering the handler in the ``IMPORT_REGISTRY``
|
||||
|
||||
This satisfies the Open/Closed Principle — the system is open for new
|
||||
import sources without modifying existing code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ImportService(Protocol):
|
||||
"""Contract for any data-import operation.
|
||||
|
||||
Each implementation is a callable ``(Session) -> dict`` that
|
||||
downloads, parses, and upserts records from an external source.
|
||||
"""
|
||||
|
||||
def __call__(self, db: Session) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
class ImportServiceEntry:
|
||||
"""Lazy-loading wrapper that resolves a module-level function on first call."""
|
||||
|
||||
__slots__ = ("_module_path", "_func_name", "_resolved")
|
||||
|
||||
def __init__(self, module_path: str, func_name: str) -> None:
|
||||
self._module_path = module_path
|
||||
self._func_name = func_name
|
||||
self._resolved: ImportService | None = None
|
||||
|
||||
def __call__(self, db: Session) -> dict[str, Any]:
|
||||
if self._resolved is None:
|
||||
import importlib
|
||||
mod = importlib.import_module(self._module_path)
|
||||
self._resolved = getattr(mod, self._func_name)
|
||||
return self._resolved(db)
|
||||
|
||||
@property
|
||||
def source_info(self) -> str:
|
||||
return f"{self._module_path}.{self._func_name}"
|
||||
|
||||
|
||||
IMPORT_REGISTRY: dict[str, ImportServiceEntry] = {
|
||||
"atomic_red_team": ImportServiceEntry(
|
||||
"app.services.atomic_import_service", "import_atomic_red_team",
|
||||
),
|
||||
"sigma": ImportServiceEntry(
|
||||
"app.services.sigma_import_service", "sync",
|
||||
),
|
||||
"lolbas": ImportServiceEntry(
|
||||
"app.services.lolbas_import_service", "sync",
|
||||
),
|
||||
"gtfobins": ImportServiceEntry(
|
||||
"app.services.lolbas_import_service", "sync_gtfobins",
|
||||
),
|
||||
"caldera": ImportServiceEntry(
|
||||
"app.services.caldera_import_service", "sync",
|
||||
),
|
||||
"elastic_rules": ImportServiceEntry(
|
||||
"app.services.elastic_import_service", "sync",
|
||||
),
|
||||
"mitre_cti": ImportServiceEntry(
|
||||
"app.services.threat_actor_import_service", "sync",
|
||||
),
|
||||
"d3fend": ImportServiceEntry(
|
||||
"app.services.d3fend_import_service", "sync",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_import_handler(source_name: str) -> ImportServiceEntry | None:
|
||||
"""Look up the import handler for *source_name*.
|
||||
|
||||
Returns ``None`` when no handler is registered.
|
||||
"""
|
||||
return IMPORT_REGISTRY.get(source_name)
|
||||
@@ -10,6 +10,16 @@ Usage in routers::
|
||||
If an exception propagates, ``__exit__`` issues a rollback automatically.
|
||||
Services should **never** call ``db.commit()``; they use ``db.add()`` /
|
||||
``db.flush()`` to stage work and let the caller decide when to commit.
|
||||
|
||||
**Documented exceptions** (services that may commit internally):
|
||||
- ``audit_service.log_action`` — called from 15+ routers; commits to ensure
|
||||
audit records persist even when callers do not.
|
||||
- Import services (atomic_import, sigma_import, etc.) — self-contained sync ops.
|
||||
- Background jobs (campaign_scheduler, intel_service, stale_detection,
|
||||
mitre_sync) — self-contained operations.
|
||||
- Self-contained batch ops (e.g. detection_rule_service.auto_associate_rules,
|
||||
snapshot_service.create_snapshot, campaign_service.generate_campaign_from_*,
|
||||
osint_enrichment_service.enrich_technique_with_cves).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -1,17 +1,12 @@
|
||||
"""Advanced metrics endpoints — coverage by tactic, never-tested, avg validation time."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import func, case
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.models.audit import AuditLog
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
from app.services import advanced_metrics_service
|
||||
|
||||
router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
|
||||
|
||||
@@ -22,39 +17,7 @@ def coverage_by_tactic(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
|
||||
results = (
|
||||
db.query(
|
||||
Technique.tactic,
|
||||
func.count(Technique.id).label("total"),
|
||||
func.sum(
|
||||
case((Technique.status_global == "validated", 1), else_=0)
|
||||
).label("validated"),
|
||||
func.sum(
|
||||
case((Technique.status_global == "partial", 1), else_=0)
|
||||
).label("partial"),
|
||||
func.sum(
|
||||
case((Technique.status_global == "not_covered", 1), else_=0)
|
||||
).label("not_covered"),
|
||||
func.sum(
|
||||
case((Technique.status_global == "in_progress", 1), else_=0)
|
||||
).label("in_progress"),
|
||||
)
|
||||
.group_by(Technique.tactic)
|
||||
.order_by(Technique.tactic)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{
|
||||
"tactic": r[0] or "Unknown",
|
||||
"total": r[1],
|
||||
"validated": int(r[2]),
|
||||
"partial": int(r[3]),
|
||||
"not_covered": int(r[4]),
|
||||
"in_progress": int(r[5]),
|
||||
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
return advanced_metrics_service.get_coverage_by_tactic(db)
|
||||
|
||||
|
||||
@router.get("/never-tested")
|
||||
@@ -63,24 +26,7 @@ def never_tested_techniques(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Techniques that have never had a test created."""
|
||||
tested_technique_ids = (
|
||||
db.query(Test.technique_id).distinct().subquery()
|
||||
)
|
||||
techniques = (
|
||||
db.query(Technique)
|
||||
.filter(~Technique.id.in_(db.query(tested_technique_ids)))
|
||||
.order_by(Technique.mitre_id)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{
|
||||
"mitre_id": t.mitre_id,
|
||||
"name": t.name,
|
||||
"tactic": t.tactic,
|
||||
"is_subtechnique": t.is_subtechnique,
|
||||
}
|
||||
for t in techniques
|
||||
]
|
||||
return advanced_metrics_service.get_never_tested_techniques(db)
|
||||
|
||||
|
||||
@router.get("/avg-validation-time")
|
||||
@@ -92,50 +38,7 @@ def avg_validation_time(
|
||||
|
||||
Returns overall average and per-phase averages where data is available.
|
||||
"""
|
||||
validated_tests = (
|
||||
db.query(Test)
|
||||
.filter(Test.state == "validated")
|
||||
.all()
|
||||
)
|
||||
|
||||
if not validated_tests:
|
||||
return {
|
||||
"total_validated": 0,
|
||||
"avg_total_hours": 0,
|
||||
"avg_red_phase_hours": 0,
|
||||
"avg_blue_phase_hours": 0,
|
||||
}
|
||||
|
||||
total_durations = []
|
||||
red_durations = []
|
||||
blue_durations = []
|
||||
|
||||
for test in validated_tests:
|
||||
if test.created_at and test.red_validated_at:
|
||||
total_seconds = (test.red_validated_at - test.created_at).total_seconds()
|
||||
total_durations.append(total_seconds)
|
||||
|
||||
if test.red_started_at and test.blue_started_at:
|
||||
red_sec = (test.blue_started_at - test.red_started_at).total_seconds()
|
||||
red_paused = test.red_paused_seconds or 0
|
||||
red_durations.append(max(red_sec - red_paused, 0))
|
||||
|
||||
if test.blue_started_at and test.blue_validated_at:
|
||||
blue_sec = (test.blue_validated_at - test.blue_started_at).total_seconds()
|
||||
blue_paused = test.blue_paused_seconds or 0
|
||||
blue_durations.append(max(blue_sec - blue_paused, 0))
|
||||
|
||||
def avg_hours(durations: list[float]) -> float:
|
||||
if not durations:
|
||||
return 0
|
||||
return round(sum(durations) / len(durations) / 3600, 2)
|
||||
|
||||
return {
|
||||
"total_validated": len(validated_tests),
|
||||
"avg_total_hours": avg_hours(total_durations),
|
||||
"avg_red_phase_hours": avg_hours(red_durations),
|
||||
"avg_blue_phase_hours": avg_hours(blue_durations),
|
||||
}
|
||||
return advanced_metrics_service.get_avg_validation_time(db)
|
||||
|
||||
|
||||
@router.get("/detection-rate-trend")
|
||||
@@ -144,41 +47,4 @@ def detection_rate_trend(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Monthly detection rate trend for the last 12 months."""
|
||||
from datetime import timedelta
|
||||
|
||||
now = datetime.utcnow()
|
||||
months = []
|
||||
|
||||
for i in range(11, -1, -1):
|
||||
month_start = datetime(now.year, now.month, 1) - timedelta(days=i * 30)
|
||||
month_end = month_start + timedelta(days=30)
|
||||
|
||||
validated = (
|
||||
db.query(func.count(Test.id))
|
||||
.filter(
|
||||
Test.state == "validated",
|
||||
Test.created_at >= month_start,
|
||||
Test.created_at < month_end,
|
||||
)
|
||||
.scalar() or 0
|
||||
)
|
||||
|
||||
detected = (
|
||||
db.query(func.count(Test.id))
|
||||
.filter(
|
||||
Test.state == "validated",
|
||||
Test.detection_result == "detected",
|
||||
Test.created_at >= month_start,
|
||||
Test.created_at < month_end,
|
||||
)
|
||||
.scalar() or 0
|
||||
)
|
||||
|
||||
months.append({
|
||||
"month": month_start.strftime("%Y-%m"),
|
||||
"validated": validated,
|
||||
"detected": detected,
|
||||
"detection_rate": round((detected / validated) * 100, 1) if validated > 0 else 0,
|
||||
})
|
||||
|
||||
return months
|
||||
return advanced_metrics_service.get_detection_rate_trend(db)
|
||||
|
||||
@@ -5,15 +5,12 @@ directly from URL. All endpoints require authentication.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.coverage_snapshot import CoverageSnapshot
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
from app.services import analytics_service
|
||||
|
||||
router = APIRouter(prefix="/analytics", tags=["analytics"])
|
||||
|
||||
@@ -24,22 +21,7 @@ def analytics_coverage(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Coverage per technique — flat format for BI dashboards."""
|
||||
techniques = db.query(Technique).all()
|
||||
return [
|
||||
{
|
||||
"mitre_id": t.mitre_id,
|
||||
"name": t.name,
|
||||
"tactic": t.tactic,
|
||||
"status": t.status_global.value if t.status_global else "not_evaluated",
|
||||
"is_subtechnique": t.is_subtechnique,
|
||||
"test_count": len(t.tests) if t.tests else 0,
|
||||
"review_required": t.review_required,
|
||||
"last_review_date": (
|
||||
t.last_review_date.isoformat() if t.last_review_date else None
|
||||
),
|
||||
}
|
||||
for t in techniques
|
||||
]
|
||||
return analytics_service.get_coverage_analytics(db)
|
||||
|
||||
|
||||
@router.get("/tests")
|
||||
@@ -50,34 +32,9 @@ def analytics_tests(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""All tests with timestamps — flat format for BI dashboards."""
|
||||
query = db.query(Test)
|
||||
if date_from:
|
||||
query = query.filter(Test.created_at >= date_from)
|
||||
if date_to:
|
||||
query = query.filter(Test.created_at <= date_to)
|
||||
tests = query.all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"technique_id": str(t.technique_id),
|
||||
"name": t.name,
|
||||
"state": t.state.value if t.state else None,
|
||||
"result": t.result.value if t.result else None,
|
||||
"detection_result": (
|
||||
t.detection_result.value if t.detection_result else None
|
||||
),
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
"execution_date": (
|
||||
t.execution_date.isoformat() if t.execution_date else None
|
||||
),
|
||||
"platform": t.platform,
|
||||
"tool_used": t.tool_used,
|
||||
"attack_success": t.attack_success,
|
||||
"remediation_status": t.remediation_status,
|
||||
}
|
||||
for t in tests
|
||||
]
|
||||
return analytics_service.get_tests_analytics(
|
||||
db, date_from=date_from, date_to=date_to
|
||||
)
|
||||
|
||||
|
||||
@router.get("/trends")
|
||||
@@ -86,23 +43,7 @@ def analytics_trends(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Historical coverage snapshots for trend visualization."""
|
||||
snapshots = (
|
||||
db.query(CoverageSnapshot)
|
||||
.order_by(CoverageSnapshot.created_at)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{
|
||||
"date": s.created_at.isoformat() if s.created_at else None,
|
||||
"name": s.name,
|
||||
"total_techniques": s.total_techniques,
|
||||
"validated_count": s.validated_count,
|
||||
"partial_count": s.partial_count,
|
||||
"not_covered_count": s.not_covered_count,
|
||||
"organization_score": s.organization_score,
|
||||
}
|
||||
for s in snapshots
|
||||
]
|
||||
return analytics_service.get_trends_analytics(db)
|
||||
|
||||
|
||||
@router.get("/operators")
|
||||
@@ -111,17 +52,4 @@ def analytics_operators(
|
||||
user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Per-operator metrics — for workload management dashboards."""
|
||||
results = (
|
||||
db.query(
|
||||
User.username,
|
||||
User.role,
|
||||
func.count(Test.id).label("test_count"),
|
||||
)
|
||||
.outerjoin(Test, Test.created_by == User.id)
|
||||
.group_by(User.id, User.username, User.role)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{"username": r[0], "role": r[1], "test_count": r[2]}
|
||||
for r in results
|
||||
]
|
||||
return analytics_service.get_operators_analytics(db)
|
||||
|
||||
@@ -4,14 +4,17 @@ from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import require_role
|
||||
from app.models.audit import AuditLog
|
||||
from app.models.user import User
|
||||
from app.schemas.audit import AuditLogOut, AuditLogPage
|
||||
from app.services.audit_query_service import (
|
||||
list_distinct_actions,
|
||||
list_distinct_entity_types,
|
||||
list_logs,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
||||
|
||||
@@ -32,53 +35,22 @@ def list_audit_logs(
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
query = db.query(AuditLog).options(joinedload(AuditLog.user))
|
||||
|
||||
# Apply filters
|
||||
if user_id:
|
||||
query = query.filter(AuditLog.user_id == user_id)
|
||||
if action:
|
||||
query = query.filter(AuditLog.action == action)
|
||||
if entity_type:
|
||||
query = query.filter(AuditLog.entity_type == entity_type)
|
||||
if start_date:
|
||||
query = query.filter(AuditLog.timestamp >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(AuditLog.timestamp <= end_date)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Get paginated results
|
||||
logs = (
|
||||
query
|
||||
.order_by(AuditLog.timestamp.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Convert to response format with username
|
||||
items = []
|
||||
for log in logs:
|
||||
item = AuditLogOut(
|
||||
id=log.id,
|
||||
user_id=log.user_id,
|
||||
username=log.user.username if log.user else None,
|
||||
action=log.action,
|
||||
entity_type=log.entity_type,
|
||||
entity_id=log.entity_id,
|
||||
timestamp=log.timestamp,
|
||||
details=log.details,
|
||||
)
|
||||
items.append(item)
|
||||
|
||||
return AuditLogPage(
|
||||
items=items,
|
||||
total=total,
|
||||
result = list_logs(
|
||||
db,
|
||||
user_id=user_id,
|
||||
action=action,
|
||||
entity_type=entity_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
return AuditLogPage(
|
||||
items=[AuditLogOut(**item) for item in result["items"]],
|
||||
total=result["total"],
|
||||
offset=result["offset"],
|
||||
limit=result["limit"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/actions", response_model=list[str])
|
||||
@@ -90,13 +62,7 @@ def list_actions(
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
actions = (
|
||||
db.query(AuditLog.action)
|
||||
.distinct()
|
||||
.order_by(AuditLog.action)
|
||||
.all()
|
||||
)
|
||||
return [a[0] for a in actions]
|
||||
return list_distinct_actions(db)
|
||||
|
||||
|
||||
@router.get("/entity-types", response_model=list[str])
|
||||
@@ -108,11 +74,4 @@ def list_entity_types(
|
||||
|
||||
**Requires admin role.**
|
||||
"""
|
||||
types = (
|
||||
db.query(AuditLog.entity_type)
|
||||
.filter(AuditLog.entity_type.isnot(None))
|
||||
.distinct()
|
||||
.order_by(AuditLog.entity_type)
|
||||
.all()
|
||||
)
|
||||
return [t[0] for t in types]
|
||||
return list_distinct_entity_types(db)
|
||||
|
||||
+15
-28
@@ -9,7 +9,7 @@ cannot use cookies (e.g. Swagger UI).
|
||||
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, Cookie, Depends, HTTPException, Request, Response, status
|
||||
from fastapi import APIRouter, Cookie, Depends, Request, Response
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
@@ -17,11 +17,13 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from jose import jwt, JWTError
|
||||
|
||||
from app.auth import verify_password, hash_password, create_access_token, blacklist_token
|
||||
from app.auth import create_access_token, blacklist_token
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.services.auth_service import authenticate_user, change_password as auth_change_password
|
||||
from app.schemas.auth import TokenResponse, UserOut
|
||||
from app.schemas.user import PasswordChange
|
||||
|
||||
@@ -56,24 +58,10 @@ def login(
|
||||
attacks. The token is set as an HttpOnly cookie **and** returned in the
|
||||
JSON body for API/Swagger compatibility.
|
||||
"""
|
||||
user = db.query(User).filter(User.username == form_data.username).first()
|
||||
|
||||
# Constant-time comparison: always run bcrypt verify to prevent
|
||||
# timing-based user enumeration (SEC-005).
|
||||
_DUMMY_HASH = "$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy"
|
||||
hashed = user.hashed_password if user else _DUMMY_HASH
|
||||
password_valid = verify_password(form_data.password, hashed)
|
||||
|
||||
if user is None or not password_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Incorrect username or password",
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Account is disabled. Contact an administrator.",
|
||||
user = authenticate_user(
|
||||
db,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
)
|
||||
|
||||
access_token = create_access_token(data={"sub": user.username})
|
||||
@@ -163,14 +151,13 @@ def change_password(
|
||||
``must_change_password`` flag is cleared so the user can proceed
|
||||
normally.
|
||||
"""
|
||||
if not verify_password(body.current_password, current_user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password is incorrect",
|
||||
auth_change_password(
|
||||
db,
|
||||
current_user,
|
||||
current_password=body.current_password,
|
||||
new_password=body.new_password,
|
||||
)
|
||||
|
||||
current_user.hashed_password = hash_password(body.new_password)
|
||||
current_user.must_change_password = False
|
||||
db.commit()
|
||||
with UnitOfWork(db) as uow:
|
||||
uow.commit()
|
||||
|
||||
return {"detail": "Password changed successfully"}
|
||||
|
||||
@@ -30,7 +30,7 @@ from app.services.campaign_crud_service import (
|
||||
serialize_campaign,
|
||||
update_campaign as crud_update,
|
||||
)
|
||||
from app.services.notification_service import create_notification
|
||||
from app.services.notification_service import notify_role
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -237,11 +237,9 @@ def activate_campaign(
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
|
||||
for user in red_techs:
|
||||
create_notification(
|
||||
notify_role(
|
||||
db,
|
||||
user_id=user.id,
|
||||
role="red_tech",
|
||||
type="campaign_activated",
|
||||
title="Campaign activated",
|
||||
message=f'Campaign "{campaign.name}" has been activated.',
|
||||
|
||||
@@ -3,18 +3,20 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
from app.models.user import User
|
||||
from app.models.technique import Technique
|
||||
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
|
||||
from app.services.d3fend_import_service import (
|
||||
import_d3fend_techniques,
|
||||
import_d3fend_mappings,
|
||||
get_defenses_for_technique,
|
||||
)
|
||||
from app.services.d3fend_query_service import (
|
||||
list_defensive_techniques as list_defensive_techniques_svc,
|
||||
list_d3fend_tactics,
|
||||
get_defenses_for_attack_technique,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -36,60 +38,22 @@ def list_defensive_techniques(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List all D3FEND defensive techniques with optional filters."""
|
||||
query = db.query(DefensiveTechnique)
|
||||
|
||||
if tactic:
|
||||
query = query.filter(DefensiveTechnique.tactic == tactic)
|
||||
|
||||
if search:
|
||||
from app.utils import escape_like
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(
|
||||
DefensiveTechnique.name.ilike(pattern)
|
||||
| DefensiveTechnique.d3fend_id.ilike(pattern)
|
||||
return list_defensive_techniques_svc(
|
||||
db, tactic=tactic, search=search, offset=offset, limit=limit
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [
|
||||
{
|
||||
"id": str(dt.id),
|
||||
"d3fend_id": dt.d3fend_id,
|
||||
"name": dt.name,
|
||||
"description": dt.description,
|
||||
"tactic": dt.tactic,
|
||||
"d3fend_url": dt.d3fend_url,
|
||||
}
|
||||
for dt in items
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /d3fend/tactics — List all D3FEND tactics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/tactics")
|
||||
def list_d3fend_tactics(
|
||||
def list_d3fend_tactics_endpoint(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return a list of all D3FEND tactics with counts."""
|
||||
from sqlalchemy import func
|
||||
|
||||
rows = (
|
||||
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
|
||||
.group_by(DefensiveTechnique.tactic)
|
||||
.order_by(DefensiveTechnique.tactic)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
|
||||
return list_d3fend_tactics(db)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -97,24 +61,13 @@ def list_d3fend_tactics(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/for-technique/{mitre_id}")
|
||||
def get_defenses_for_attack_technique(
|
||||
def get_defenses_for_attack_technique_endpoint(
|
||||
mitre_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
||||
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
|
||||
if not technique:
|
||||
raise HTTPException(status_code=404, detail=f"Technique {mitre_id} not found")
|
||||
|
||||
defenses = get_defenses_for_technique(db, technique.id)
|
||||
|
||||
return {
|
||||
"mitre_id": mitre_id,
|
||||
"technique_name": technique.name,
|
||||
"defenses": defenses,
|
||||
"total": len(defenses),
|
||||
}
|
||||
return get_defenses_for_attack_technique(db, mitre_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -5,19 +5,23 @@ Provides a centralized panel for managing all external data sources
|
||||
including sync triggers, enable/disable toggles, and statistics.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import require_role
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.models.data_source import DataSource
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.data_source_service import (
|
||||
get_source_stats,
|
||||
list_sources,
|
||||
sync_all_sources,
|
||||
sync_source,
|
||||
update_source,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -30,41 +34,10 @@ class DataSourceUpdate(BaseModel):
|
||||
sync_frequency: Optional[str] = None
|
||||
config: Optional[dict] = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync dispatcher — maps source name → import function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_sync_handler(source_name: str):
|
||||
"""Lazily import and return the sync function for *source_name*.
|
||||
|
||||
We import lazily to avoid circular imports and to only load the
|
||||
modules that are actually needed.
|
||||
"""
|
||||
handlers = {
|
||||
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
|
||||
"sigma": ("app.services.sigma_import_service", "sync"),
|
||||
"lolbas": ("app.services.lolbas_import_service", "sync"),
|
||||
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
|
||||
"caldera": ("app.services.caldera_import_service", "sync"),
|
||||
"elastic_rules": ("app.services.elastic_import_service", "sync"),
|
||||
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
|
||||
"d3fend": ("app.services.d3fend_import_service", "sync"),
|
||||
}
|
||||
|
||||
if source_name not in handlers:
|
||||
return None
|
||||
|
||||
module_path, func_name = handlers[source_name]
|
||||
import importlib
|
||||
mod = importlib.import_module(module_path)
|
||||
return getattr(mod, func_name)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -79,25 +52,7 @@ def list_data_sources(
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
sources = db.query(DataSource).order_by(DataSource.name).all()
|
||||
return [
|
||||
{
|
||||
"id": str(s.id),
|
||||
"name": s.name,
|
||||
"display_name": s.display_name,
|
||||
"type": s.type,
|
||||
"url": s.url,
|
||||
"description": s.description,
|
||||
"is_enabled": s.is_enabled,
|
||||
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
|
||||
"last_sync_status": s.last_sync_status,
|
||||
"last_sync_stats": s.last_sync_stats,
|
||||
"sync_frequency": s.sync_frequency,
|
||||
"config": s.config,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
for s in sources
|
||||
]
|
||||
return list_sources(db)
|
||||
|
||||
|
||||
@router.patch("/{source_id}")
|
||||
@@ -111,31 +66,21 @@ def update_data_source(
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
|
||||
if "is_enabled" in update_data:
|
||||
ds.is_enabled = update_data["is_enabled"]
|
||||
if "sync_frequency" in update_data:
|
||||
ds.sync_frequency = update_data["sync_frequency"]
|
||||
if "config" in update_data:
|
||||
ds.config = update_data["config"]
|
||||
|
||||
db.commit()
|
||||
update_source(db, source_id, **update_data)
|
||||
with UnitOfWork(db) as uow:
|
||||
uow.commit()
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="update_data_source",
|
||||
entity_type="data_source",
|
||||
entity_id=str(ds.id),
|
||||
entity_id=source_id,
|
||||
details={"updates": update_data},
|
||||
)
|
||||
|
||||
return {"message": "Data source updated", "id": str(ds.id)}
|
||||
return {"message": "Data source updated", "id": source_id}
|
||||
|
||||
|
||||
@router.post("/{source_id}/sync")
|
||||
@@ -148,46 +93,7 @@ def sync_data_source(
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
handler = _get_sync_handler(ds.name)
|
||||
if handler is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"No sync handler available for '{ds.name}'",
|
||||
)
|
||||
|
||||
# Mark as in_progress
|
||||
ds.last_sync_status = "in_progress"
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
summary = handler(db)
|
||||
except Exception as exc:
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||
ds.last_sync_status = "error"
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_stats = {"error": str(exc)}
|
||||
db.commit()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Sync failed for '{ds.display_name}'. Check server logs for details.",
|
||||
)
|
||||
|
||||
# Update DS record (the handler may already have done this,
|
||||
# but we ensure it here as well)
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_status = "success"
|
||||
ds.last_sync_stats = summary
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"message": f"Sync complete for {ds.display_name}",
|
||||
"source": ds.name,
|
||||
"stats": summary,
|
||||
}
|
||||
return sync_source(db, source_id)
|
||||
|
||||
|
||||
@router.post("/sync-all")
|
||||
@@ -199,49 +105,7 @@ def sync_all_data_sources(
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
enabled_sources = (
|
||||
db.query(DataSource)
|
||||
.filter(DataSource.is_enabled == True)
|
||||
.order_by(DataSource.name)
|
||||
.all()
|
||||
)
|
||||
|
||||
results = []
|
||||
for ds in enabled_sources:
|
||||
handler = _get_sync_handler(ds.name)
|
||||
if handler is None:
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
"status": "skipped",
|
||||
"detail": "No sync handler available",
|
||||
})
|
||||
continue
|
||||
|
||||
ds.last_sync_status = "in_progress"
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
summary = handler(db)
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_status = "success"
|
||||
ds.last_sync_stats = summary
|
||||
db.commit()
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
"status": "success",
|
||||
"stats": summary,
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||
ds.last_sync_status = "error"
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_stats = {"error": str(exc)}
|
||||
db.commit()
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
"status": "error",
|
||||
"detail": "Sync failed. Check server logs for details.",
|
||||
})
|
||||
results = sync_all_sources(db)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
@@ -265,39 +129,4 @@ def get_data_source_stats(
|
||||
|
||||
**Requires** the ``admin`` role.
|
||||
"""
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
# Count items from this source
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.detection_rule import DetectionRule
|
||||
|
||||
template_count = 0
|
||||
rule_count = 0
|
||||
|
||||
if ds.type == "attack_procedure":
|
||||
template_count = (
|
||||
db.query(TestTemplate)
|
||||
.filter(TestTemplate.source == ds.name)
|
||||
.count()
|
||||
)
|
||||
elif ds.type == "detection_rule":
|
||||
rule_count = (
|
||||
db.query(DetectionRule)
|
||||
.filter(DetectionRule.source == ds.name)
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"id": str(ds.id),
|
||||
"name": ds.name,
|
||||
"display_name": ds.display_name,
|
||||
"type": ds.type,
|
||||
"is_enabled": ds.is_enabled,
|
||||
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
|
||||
"last_sync_status": ds.last_sync_status,
|
||||
"last_sync_stats": ds.last_sync_stats,
|
||||
"total_templates": template_count,
|
||||
"total_rules": rule_count,
|
||||
}
|
||||
return get_source_stats(db, source_id)
|
||||
|
||||
+13
-83
@@ -7,14 +7,9 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
from app.domain.exceptions import EntityNotFoundError
|
||||
from app.models.jira_link import JiraLink, JiraLinkEntityType
|
||||
from app.models.test import Test
|
||||
from app.models.technique import Technique
|
||||
from app.models.campaign import Campaign
|
||||
from app.models.jira_link import JiraLinkEntityType
|
||||
from app.models.user import User
|
||||
from app.schemas.jira_schema import (
|
||||
JiraIssueResult,
|
||||
@@ -45,23 +40,14 @@ def create_link(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Associate an Aegis entity with a Jira issue."""
|
||||
link = JiraLink(
|
||||
link = jira_service.create_link(
|
||||
db,
|
||||
entity_type=body.entity_type,
|
||||
entity_id=body.entity_id,
|
||||
jira_issue_key=body.jira_issue_key,
|
||||
sync_direction=body.sync_direction,
|
||||
created_by=user.id,
|
||||
)
|
||||
db.add(link)
|
||||
db.flush()
|
||||
|
||||
# Pull initial data from Jira if enabled
|
||||
if settings.JIRA_ENABLED:
|
||||
try:
|
||||
jira_service.sync_jira_to_aegis(db, link)
|
||||
except Exception as e:
|
||||
logger.warning("Initial Jira sync failed for %s: %s", body.jira_issue_key, e)
|
||||
|
||||
db.commit()
|
||||
db.refresh(link)
|
||||
|
||||
@@ -88,12 +74,11 @@ def list_links(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List Jira links, optionally filtered by entity."""
|
||||
query = db.query(JiraLink)
|
||||
if entity_type:
|
||||
query = query.filter(JiraLink.entity_type == entity_type)
|
||||
if entity_id:
|
||||
query = query.filter(JiraLink.entity_id == entity_id)
|
||||
return query.order_by(JiraLink.created_at.desc()).all()
|
||||
return jira_service.list_links(
|
||||
db,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/links/{link_id}/sync")
|
||||
@@ -103,9 +88,7 @@ def sync_link(
|
||||
user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Force bidirectional sync for a specific Jira link."""
|
||||
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
|
||||
if not link:
|
||||
raise EntityNotFoundError("JiraLink", str(link_id))
|
||||
link = jira_service.get_link_or_raise(db, link_id)
|
||||
jira_service.sync_jira_to_aegis(db, link)
|
||||
db.commit()
|
||||
return {"message": "Sync completed", "jira_status": link.jira_status}
|
||||
@@ -118,10 +101,7 @@ def delete_link(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Remove a Jira link."""
|
||||
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
|
||||
if not link:
|
||||
raise EntityNotFoundError("JiraLink", str(link_id))
|
||||
db.delete(link)
|
||||
link = jira_service.delete_link(db, link_id)
|
||||
db.commit()
|
||||
audit_service.log_action(
|
||||
db,
|
||||
@@ -141,61 +121,11 @@ def create_issue_from_entity(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Auto-create a Jira issue from an Aegis entity and link them."""
|
||||
summary, description = _build_issue_data(db, entity_type, entity_id)
|
||||
result = jira_service.create_jira_issue(
|
||||
project_key=settings.JIRA_DEFAULT_PROJECT,
|
||||
summary=summary,
|
||||
description=description,
|
||||
labels=["aegis", entity_type.value],
|
||||
)
|
||||
link = JiraLink(
|
||||
result = jira_service.create_issue_and_link(
|
||||
db,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
jira_issue_key=result["issue_key"],
|
||||
jira_issue_id=result["issue_id"],
|
||||
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
|
||||
created_by=user.id,
|
||||
)
|
||||
db.add(link)
|
||||
db.commit()
|
||||
return {"issue_key": result["issue_key"], "link_id": str(link.id)}
|
||||
|
||||
|
||||
def _build_issue_data(
|
||||
db: Session,
|
||||
entity_type: JiraLinkEntityType,
|
||||
entity_id: UUID,
|
||||
) -> tuple[str, str]:
|
||||
"""Build Jira issue summary + description from an Aegis entity."""
|
||||
if entity_type == JiraLinkEntityType.test:
|
||||
entity = db.query(Test).filter(Test.id == entity_id).first()
|
||||
if not entity:
|
||||
raise EntityNotFoundError("Test", str(entity_id))
|
||||
return (
|
||||
f"[Aegis Test] {entity.name}",
|
||||
f"Test: {entity.name}\n"
|
||||
f"State: {entity.state.value if entity.state else 'draft'}\n"
|
||||
f"Description: {entity.description or 'N/A'}",
|
||||
)
|
||||
elif entity_type == JiraLinkEntityType.campaign:
|
||||
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
|
||||
if not entity:
|
||||
raise EntityNotFoundError("Campaign", str(entity_id))
|
||||
return (
|
||||
f"[Aegis Campaign] {entity.name}",
|
||||
f"Campaign: {entity.name}\n"
|
||||
f"Type: {entity.type}\nStatus: {entity.status}\n"
|
||||
f"Description: {entity.description or 'N/A'}",
|
||||
)
|
||||
elif entity_type == JiraLinkEntityType.technique:
|
||||
entity = db.query(Technique).filter(Technique.id == entity_id).first()
|
||||
if not entity:
|
||||
raise EntityNotFoundError("Technique", str(entity_id))
|
||||
return (
|
||||
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
|
||||
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
|
||||
f"Tactic: {entity.tactic or 'N/A'}\n"
|
||||
f"Description: {entity.description or 'N/A'}",
|
||||
)
|
||||
else:
|
||||
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
|
||||
return result
|
||||
|
||||
@@ -10,16 +10,16 @@ POST /notifications/read-all — mark all as read
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.notification import Notification
|
||||
from app.models.user import User
|
||||
from app.schemas.notification import NotificationOut, UnreadCountOut
|
||||
from app.services.notification_service import (
|
||||
list_notifications,
|
||||
mark_as_read,
|
||||
mark_all_as_read,
|
||||
get_unread_count,
|
||||
@@ -34,22 +34,14 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[NotificationOut])
|
||||
def list_notifications(
|
||||
def list_notifications_endpoint(
|
||||
offset: int = Query(0, ge=0),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return paginated notifications for the current user, newest first."""
|
||||
notifs = (
|
||||
db.query(Notification)
|
||||
.filter(Notification.user_id == current_user.id)
|
||||
.order_by(Notification.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return notifs
|
||||
return list_notifications(db, current_user.id, offset=offset, limit=limit)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -80,14 +72,8 @@ def read_notification(
|
||||
):
|
||||
"""Mark a single notification as read."""
|
||||
with UnitOfWork(db) as uow:
|
||||
success = mark_as_read(db, notification_id, current_user.id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Notification not found",
|
||||
)
|
||||
notif = mark_as_read(db, notification_id, current_user.id)
|
||||
uow.commit()
|
||||
notif = db.query(Notification).filter(Notification.id == notification_id).first()
|
||||
return notif
|
||||
|
||||
|
||||
|
||||
@@ -10,14 +10,15 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.osint_item import OsintItem
|
||||
from app.models.technique import Technique
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.services.osint_enrichment_service import (
|
||||
enrich_technique_with_cves,
|
||||
get_osint_items_for_technique,
|
||||
get_osint_summary,
|
||||
get_technique_or_raise,
|
||||
list_osint_items as service_list_osint_items,
|
||||
mark_osint_reviewed,
|
||||
get_unreviewed_count,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/osint", tags=["osint"])
|
||||
@@ -56,41 +57,15 @@ def list_osint_items(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List OSINT items with optional filters."""
|
||||
query = db.query(OsintItem)
|
||||
if technique_id:
|
||||
query = query.filter(OsintItem.technique_id == technique_id)
|
||||
if source_type:
|
||||
query = query.filter(OsintItem.source_type == source_type)
|
||||
if reviewed is not None:
|
||||
query = query.filter(OsintItem.reviewed == reviewed)
|
||||
|
||||
total = query.count()
|
||||
items = (
|
||||
query.order_by(OsintItem.discovered_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
return service_list_osint_items(
|
||||
db,
|
||||
technique_id=technique_id,
|
||||
source_type=source_type,
|
||||
reviewed=reviewed,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"items": [
|
||||
{
|
||||
"id": str(item.id),
|
||||
"technique_id": str(item.technique_id),
|
||||
"source_type": item.source_type,
|
||||
"source_url": item.source_url,
|
||||
"title": item.title,
|
||||
"description": item.description,
|
||||
"severity": item.severity,
|
||||
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
|
||||
"reviewed": item.reviewed,
|
||||
"metadata": item.metadata_,
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
def osint_summary(
|
||||
@@ -98,34 +73,7 @@ def osint_summary(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Summary statistics for OSINT items."""
|
||||
from sqlalchemy import func
|
||||
|
||||
total = db.query(func.count(OsintItem.id)).scalar() or 0
|
||||
unreviewed = get_unreviewed_count(db)
|
||||
|
||||
by_severity = dict(
|
||||
db.query(OsintItem.severity, func.count(OsintItem.id))
|
||||
.group_by(OsintItem.severity)
|
||||
.all()
|
||||
)
|
||||
|
||||
by_type = dict(
|
||||
db.query(OsintItem.source_type, func.count(OsintItem.id))
|
||||
.group_by(OsintItem.source_type)
|
||||
.all()
|
||||
)
|
||||
|
||||
techniques_with_items = (
|
||||
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"total_items": total,
|
||||
"unreviewed": unreviewed,
|
||||
"techniques_with_items": techniques_with_items,
|
||||
"by_severity": by_severity,
|
||||
"by_type": by_type,
|
||||
}
|
||||
return get_osint_summary(db)
|
||||
|
||||
|
||||
@router.post("/items/{item_id}/review")
|
||||
@@ -135,12 +83,14 @@ def review_osint_item(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Mark an OSINT item as reviewed."""
|
||||
with UnitOfWork(db) as uow:
|
||||
item = mark_osint_reviewed(db, str(item_id))
|
||||
if not item:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OSINT item not found",
|
||||
)
|
||||
uow.commit()
|
||||
return {"id": str(item.id), "reviewed": True}
|
||||
|
||||
|
||||
@@ -151,13 +101,7 @@ def trigger_technique_enrichment(
|
||||
user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Manually trigger OSINT enrichment for a single technique."""
|
||||
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||
if not technique:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Technique not found",
|
||||
)
|
||||
|
||||
technique = get_technique_or_raise(db, technique_id)
|
||||
count = enrich_technique_with_cves(db, technique)
|
||||
return {
|
||||
"technique_id": str(technique.id),
|
||||
|
||||
@@ -5,19 +5,18 @@ Provides granular scoring with breakdowns and configurable weights.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.models.technique import Technique
|
||||
from app.models.threat_actor import ThreatActor
|
||||
from app.services.scoring_service import (
|
||||
calculate_technique_score,
|
||||
score_technique_by_mitre_id,
|
||||
score_actor_by_id,
|
||||
calculate_tactic_score,
|
||||
calculate_actor_coverage_score,
|
||||
calculate_organization_score,
|
||||
get_score_history,
|
||||
)
|
||||
@@ -39,23 +38,7 @@ def score_technique(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get detailed score with breakdown for a specific technique."""
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.mitre_id == mitre_id)
|
||||
.first()
|
||||
)
|
||||
if not technique:
|
||||
raise HTTPException(status_code=404, detail="Technique not found")
|
||||
|
||||
result = calculate_technique_score(technique, db)
|
||||
|
||||
return {
|
||||
"mitre_id": technique.mitre_id,
|
||||
"name": technique.name,
|
||||
"tactic": technique.tactic,
|
||||
"status_global": technique.status_global.value if technique.status_global else None,
|
||||
**result,
|
||||
}
|
||||
return score_technique_by_mitre_id(db, mitre_id)
|
||||
|
||||
|
||||
# ── GET /scores/tactic/{tactic} ──────────────────────────────────────
|
||||
@@ -81,11 +64,7 @@ def score_threat_actor(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get coverage score against a specific threat actor."""
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
if not actor:
|
||||
raise HTTPException(status_code=404, detail="Threat actor not found")
|
||||
|
||||
return calculate_actor_coverage_score(actor_id, db)
|
||||
return score_actor_by_id(db, actor_id)
|
||||
|
||||
|
||||
# ── GET /scores/organization ─────────────────────────────────────────
|
||||
@@ -149,6 +128,7 @@ def update_scoring_config(
|
||||
Weights are persisted in the database and survive restarts.
|
||||
Validation enforces that all weights are non-negative and sum to 100.
|
||||
"""
|
||||
with UnitOfWork(db) as uow:
|
||||
result = update_scoring_weights(
|
||||
db,
|
||||
tests=payload.tests,
|
||||
@@ -157,6 +137,7 @@ def update_scoring_config(
|
||||
freshness=payload.freshness,
|
||||
platform_diversity=payload.platform_diversity,
|
||||
)
|
||||
uow.commit()
|
||||
|
||||
from app.services.score_cache import invalidate
|
||||
invalidate()
|
||||
|
||||
@@ -8,18 +8,24 @@ import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||
from app.domain.errors import BusinessRuleViolation
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
from app.services.snapshot_service import (
|
||||
create_snapshot,
|
||||
compare_snapshots,
|
||||
cleanup_old_snapshots,
|
||||
serialize_snapshot_summary,
|
||||
list_snapshots as list_snapshots_svc,
|
||||
get_snapshot_or_raise,
|
||||
get_snapshot_detail,
|
||||
delete_snapshot,
|
||||
)
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
@@ -34,48 +40,6 @@ class SnapshotCreate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
|
||||
"""Lightweight serialization for list views."""
|
||||
return {
|
||||
"id": str(snap.id),
|
||||
"name": snap.name,
|
||||
"organization_score": snap.organization_score,
|
||||
"total_techniques": snap.total_techniques,
|
||||
"validated_count": snap.validated_count,
|
||||
"partial_count": snap.partial_count,
|
||||
"not_covered_count": snap.not_covered_count,
|
||||
"in_progress_count": snap.in_progress_count,
|
||||
"not_evaluated_count": snap.not_evaluated_count,
|
||||
"created_by": str(snap.created_by) if snap.created_by else None,
|
||||
"created_at": snap.created_at.isoformat() if snap.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
|
||||
"""Full serialization including technique states."""
|
||||
base = _serialize_snapshot_summary(snap)
|
||||
|
||||
technique_states = (
|
||||
db.query(SnapshotTechniqueState)
|
||||
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
|
||||
.order_by(SnapshotTechniqueState.mitre_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
base["technique_states"] = [
|
||||
{
|
||||
"mitre_id": s.mitre_id,
|
||||
"technique_id": str(s.technique_id),
|
||||
"status": s.status,
|
||||
"score": s.score,
|
||||
}
|
||||
for s in technique_states
|
||||
]
|
||||
return base
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /snapshots — List snapshots (paginated)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -88,23 +52,7 @@ def list_snapshots(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List coverage snapshots ordered by creation date (newest first)."""
|
||||
query = db.query(CoverageSnapshot)
|
||||
total = query.count()
|
||||
|
||||
snapshots = (
|
||||
query
|
||||
.order_by(CoverageSnapshot.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [_serialize_snapshot_summary(s) for s in snapshots],
|
||||
}
|
||||
return list_snapshots_svc(db, offset=offset, limit=limit)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -129,7 +77,7 @@ def create_snapshot_endpoint(
|
||||
details={"name": snapshot.name, "score": snapshot.organization_score},
|
||||
)
|
||||
|
||||
return _serialize_snapshot_summary(snapshot)
|
||||
return serialize_snapshot_summary(snapshot)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -148,13 +96,9 @@ def compare_snapshots_endpoint(
|
||||
a_id = uuid.UUID(a)
|
||||
b_id = uuid.UUID(b)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid snapshot ID format")
|
||||
raise BusinessRuleViolation("Invalid snapshot ID format")
|
||||
|
||||
result = compare_snapshots(db, a_id, b_id)
|
||||
if "error" in result:
|
||||
raise HTTPException(status_code=404, detail=result["error"])
|
||||
|
||||
return result
|
||||
return compare_snapshots(db, a_id, b_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -168,11 +112,7 @@ def get_snapshot(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get detailed snapshot information including per-technique states."""
|
||||
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
|
||||
if not snapshot:
|
||||
raise HTTPException(status_code=404, detail="Snapshot not found")
|
||||
|
||||
return _serialize_snapshot_detail(db, snapshot)
|
||||
return get_snapshot_detail(db, snapshot_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -180,15 +120,13 @@ def get_snapshot(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.delete("/{snapshot_id}")
|
||||
def delete_snapshot(
|
||||
def delete_snapshot_endpoint(
|
||||
snapshot_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Delete a snapshot (admin only)."""
|
||||
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
|
||||
if not snapshot:
|
||||
raise HTTPException(status_code=404, detail="Snapshot not found")
|
||||
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
@@ -199,7 +137,8 @@ def delete_snapshot(
|
||||
details={"name": snapshot.name},
|
||||
)
|
||||
|
||||
db.delete(snapshot)
|
||||
db.commit()
|
||||
with UnitOfWork(db) as uow:
|
||||
delete_snapshot(db, snapshot_id)
|
||||
uow.commit()
|
||||
|
||||
return {"detail": "Snapshot deleted"}
|
||||
|
||||
@@ -6,7 +6,7 @@ exceptions to HTTP responses automatically.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_role, require_any_role
|
||||
@@ -18,7 +18,6 @@ from app.domain.unit_of_work import UnitOfWork
|
||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||
SATechniqueRepository,
|
||||
)
|
||||
from app.models.technique import Technique
|
||||
from app.models.user import User
|
||||
from app.schemas.technique import (
|
||||
TechniqueCreate,
|
||||
@@ -27,7 +26,7 @@ from app.schemas.technique import (
|
||||
TechniqueUpdate,
|
||||
)
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.d3fend_import_service import get_defenses_for_technique
|
||||
from app.services.technique_query_service import get_technique_detail
|
||||
|
||||
router = APIRouter(prefix="/techniques", tags=["techniques"])
|
||||
|
||||
@@ -67,45 +66,7 @@ def get_technique(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return full details for a single technique, including its tests and D3FEND defenses."""
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.options(joinedload(Technique.tests))
|
||||
.filter(Technique.mitre_id == mitre_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if technique is None:
|
||||
raise EntityNotFoundError("Technique", mitre_id)
|
||||
|
||||
defenses = get_defenses_for_technique(db, technique.id)
|
||||
|
||||
return {
|
||||
"id": str(technique.id),
|
||||
"mitre_id": technique.mitre_id,
|
||||
"name": technique.name,
|
||||
"description": technique.description,
|
||||
"tactic": technique.tactic,
|
||||
"platforms": technique.platforms or [],
|
||||
"mitre_version": technique.mitre_version,
|
||||
"mitre_last_modified": technique.mitre_last_modified,
|
||||
"is_subtechnique": technique.is_subtechnique,
|
||||
"parent_mitre_id": technique.parent_mitre_id,
|
||||
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
|
||||
"review_required": technique.review_required,
|
||||
"last_review_date": technique.last_review_date,
|
||||
"tests": [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"name": t.name,
|
||||
"state": t.state.value if t.state else None,
|
||||
"result": t.result.value if t.result else None,
|
||||
"platform": t.platform,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in technique.tests
|
||||
],
|
||||
"d3fend_defenses": defenses,
|
||||
}
|
||||
return get_technique_detail(db, mitre_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -25,13 +25,12 @@ Filters (GET /test-templates)
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import func, or_
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_role, require_any_role
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.schemas.test_template import (
|
||||
TestTemplateCreate,
|
||||
@@ -39,6 +38,17 @@ from app.schemas.test_template import (
|
||||
TestTemplateSummary,
|
||||
)
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.test_template_service import (
|
||||
bulk_activate,
|
||||
create_template as create_template_svc,
|
||||
get_template_or_raise,
|
||||
get_template_stats,
|
||||
get_templates_by_technique as templates_by_technique,
|
||||
list_templates,
|
||||
soft_delete_template,
|
||||
toggle_template_active as toggle_template_active_svc,
|
||||
update_template as update_template_svc,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/test-templates", tags=["test-templates"])
|
||||
|
||||
@@ -49,7 +59,7 @@ router = APIRouter(prefix="/test-templates", tags=["test-templates"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[TestTemplateSummary])
|
||||
def list_templates(
|
||||
def _list_templates_handler(
|
||||
source: Optional[str] = Query(None, description="Filter by source (atomic_red_team, mitre, custom)"),
|
||||
platform: Optional[str] = Query(None, description="Filter by platform (windows, linux, macos)"),
|
||||
severity: Optional[str] = Query(None, description="Filter by severity (low, medium, high, critical)"),
|
||||
@@ -62,37 +72,17 @@ def list_templates(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return a paginated, filterable list of test templates."""
|
||||
query = db.query(TestTemplate)
|
||||
if is_active is not None:
|
||||
query = query.filter(TestTemplate.is_active == is_active) # noqa: E712
|
||||
|
||||
if source:
|
||||
query = query.filter(TestTemplate.source == source)
|
||||
if platform:
|
||||
from app.utils import escape_like
|
||||
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
|
||||
if severity:
|
||||
query = query.filter(TestTemplate.severity == severity)
|
||||
if mitre_technique_id:
|
||||
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
|
||||
if search:
|
||||
from app.utils import escape_like
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
TestTemplate.name.ilike(pattern),
|
||||
TestTemplate.description.ilike(pattern),
|
||||
return list_templates(
|
||||
db,
|
||||
source=source,
|
||||
platform=platform,
|
||||
severity=severity,
|
||||
mitre_technique_id=mitre_technique_id,
|
||||
search=search,
|
||||
is_active=is_active,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
)
|
||||
|
||||
templates = (
|
||||
query
|
||||
.order_by(TestTemplate.mitre_technique_id, TestTemplate.name)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return templates
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -105,41 +95,8 @@ def template_stats(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Return catalog statistics: totals by source, platform, active/inactive."""
|
||||
|
||||
total = db.query(func.count(TestTemplate.id)).scalar() or 0
|
||||
active = (
|
||||
db.query(func.count(TestTemplate.id))
|
||||
.filter(TestTemplate.is_active == True) # noqa: E712
|
||||
.scalar()
|
||||
) or 0
|
||||
inactive = total - active
|
||||
|
||||
# By source
|
||||
source_rows = (
|
||||
db.query(TestTemplate.source, func.count(TestTemplate.id))
|
||||
.filter(TestTemplate.is_active == True) # noqa: E712
|
||||
.group_by(TestTemplate.source)
|
||||
.all()
|
||||
)
|
||||
by_source = {source: cnt for source, cnt in source_rows}
|
||||
|
||||
# By platform
|
||||
platform_rows = (
|
||||
db.query(TestTemplate.platform, func.count(TestTemplate.id))
|
||||
.filter(TestTemplate.is_active == True) # noqa: E712
|
||||
.group_by(TestTemplate.platform)
|
||||
.all()
|
||||
)
|
||||
by_platform = {(platform or "unspecified"): cnt for platform, cnt in platform_rows}
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"active": active,
|
||||
"inactive": inactive,
|
||||
"by_source": by_source,
|
||||
"by_platform": by_platform,
|
||||
}
|
||||
"""Return catalog statistics: active, by_source, by_platform."""
|
||||
return get_template_stats(db)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -154,13 +111,8 @@ def bulk_activate_templates(
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Set all templates to active or inactive."""
|
||||
count = (
|
||||
db.query(TestTemplate)
|
||||
.filter(TestTemplate.is_active != activate)
|
||||
.update({TestTemplate.is_active: activate})
|
||||
)
|
||||
db.commit()
|
||||
|
||||
count = bulk_activate(db, activate=activate)
|
||||
with UnitOfWork(db) as uow:
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
@@ -169,6 +121,7 @@ def bulk_activate_templates(
|
||||
entity_id=None,
|
||||
details={"affected": count, "is_active": activate},
|
||||
)
|
||||
uow.commit()
|
||||
|
||||
return {
|
||||
"detail": f"{'Activated' if activate else 'Deactivated'} {count} templates",
|
||||
@@ -183,22 +136,13 @@ def bulk_activate_templates(
|
||||
|
||||
|
||||
@router.get("/by-technique/{mitre_id}", response_model=list[TestTemplateSummary])
|
||||
def templates_by_technique(
|
||||
def _templates_by_technique_handler(
|
||||
mitre_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return all active templates mapped to a specific MITRE technique."""
|
||||
templates = (
|
||||
db.query(TestTemplate)
|
||||
.filter(
|
||||
TestTemplate.mitre_technique_id == mitre_id,
|
||||
TestTemplate.is_active == True, # noqa: E712
|
||||
)
|
||||
.order_by(TestTemplate.name)
|
||||
.all()
|
||||
)
|
||||
return templates
|
||||
return templates_by_technique(db, mitre_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -213,13 +157,7 @@ def get_template(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return full details for a single test template."""
|
||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
if template is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Test template not found",
|
||||
)
|
||||
return template
|
||||
return get_template_or_raise(db, template_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -238,11 +176,8 @@ def create_template(
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Create a custom test template."""
|
||||
template = TestTemplate(**payload.model_dump())
|
||||
db.add(template)
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
|
||||
template = create_template_svc(db, **payload.model_dump())
|
||||
with UnitOfWork(db) as uow:
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
@@ -255,6 +190,8 @@ def create_template(
|
||||
"mitre_technique_id": template.mitre_technique_id,
|
||||
},
|
||||
)
|
||||
uow.commit()
|
||||
db.refresh(template)
|
||||
|
||||
return template
|
||||
|
||||
@@ -272,28 +209,18 @@ def update_template(
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Update fields of an existing test template."""
|
||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
if template is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Test template not found",
|
||||
)
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(template, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
|
||||
template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True))
|
||||
with UnitOfWork(db) as uow:
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="update_test_template",
|
||||
entity_type="test_template",
|
||||
entity_id=template.id,
|
||||
details={"updated_fields": list(update_data.keys())},
|
||||
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
|
||||
)
|
||||
uow.commit()
|
||||
db.refresh(template)
|
||||
|
||||
return template
|
||||
|
||||
@@ -309,18 +236,9 @@ def toggle_template_active(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Toggle a template between active and inactive."""
|
||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
if template is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Test template not found",
|
||||
)
|
||||
|
||||
template.is_active = not template.is_active
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
|
||||
"""Toggle a template between active and inactive (is_active = not is_active)."""
|
||||
template = toggle_template_active_svc(db, template_id)
|
||||
with UnitOfWork(db) as uow:
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
@@ -329,6 +247,8 @@ def toggle_template_active(
|
||||
entity_id=template.id,
|
||||
details={"name": template.name, "is_active": template.is_active},
|
||||
)
|
||||
uow.commit()
|
||||
db.refresh(template)
|
||||
|
||||
return template
|
||||
|
||||
@@ -345,16 +265,9 @@ def delete_template(
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Soft-delete a test template by setting ``is_active=False``."""
|
||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
if template is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Test template not found",
|
||||
)
|
||||
|
||||
template.is_active = False
|
||||
db.commit()
|
||||
|
||||
template = get_template_or_raise(db, template_id)
|
||||
soft_delete_template(db, template_id)
|
||||
with UnitOfWork(db) as uow:
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
@@ -363,5 +276,6 @@ def delete_template(
|
||||
entity_id=template.id,
|
||||
details={"name": template.name},
|
||||
)
|
||||
uow.commit()
|
||||
|
||||
return {"detail": "Test template deactivated"}
|
||||
|
||||
@@ -2,20 +2,24 @@
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import require_role
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.schemas.user import UserCreate, UserUpdate, UserOut
|
||||
from app.auth import hash_password
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.user_service import (
|
||||
create_user,
|
||||
get_user_or_raise,
|
||||
list_users,
|
||||
update_user,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /users — list all users
|
||||
@@ -23,12 +27,12 @@ VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewe
|
||||
|
||||
|
||||
@router.get("", response_model=list[UserOut])
|
||||
def list_users(
|
||||
def list_users_route(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return a list of all users. **Requires admin role.**"""
|
||||
return db.query(User).order_by(User.username).all()
|
||||
return list_users(db)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -37,36 +41,21 @@ def list_users(
|
||||
|
||||
|
||||
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
||||
def create_user(
|
||||
def create_user_route(
|
||||
payload: UserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Create a new user. **Requires admin role.**"""
|
||||
|
||||
# Check if username already exists
|
||||
existing = db.query(User).filter(User.username == payload.username).first()
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Username '{payload.username}' already exists",
|
||||
)
|
||||
|
||||
# Validate role
|
||||
if payload.role not in VALID_ROLES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid role '{payload.role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
|
||||
)
|
||||
|
||||
user = User(
|
||||
user = create_user(
|
||||
db,
|
||||
username=payload.username,
|
||||
email=payload.email,
|
||||
hashed_password=hash_password(payload.password),
|
||||
password=payload.password,
|
||||
role=payload.role,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
with UnitOfWork(db) as uow:
|
||||
uow.commit()
|
||||
db.refresh(user)
|
||||
|
||||
log_action(
|
||||
@@ -93,13 +82,7 @@ def get_user(
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Return a single user by ID. **Requires admin role.**"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
return user
|
||||
return get_user_or_raise(db, user_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -108,37 +91,17 @@ def get_user(
|
||||
|
||||
|
||||
@router.patch("/{user_id}", response_model=UserOut)
|
||||
def update_user(
|
||||
def update_user_route(
|
||||
user_id: uuid.UUID,
|
||||
payload: UserUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Update one or more fields of an existing user. **Requires admin role.**"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
|
||||
# Validate role if being updated
|
||||
if "role" in update_data and update_data["role"] not in VALID_ROLES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
|
||||
)
|
||||
|
||||
# Hash password if being updated
|
||||
if "password" in update_data:
|
||||
update_data["hashed_password"] = hash_password(update_data.pop("password"))
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(user, field, value)
|
||||
|
||||
db.commit()
|
||||
user = update_user(db, user_id, **update_data)
|
||||
with UnitOfWork(db) as uow:
|
||||
uow.commit()
|
||||
db.refresh(user)
|
||||
|
||||
log_action(
|
||||
@@ -147,7 +110,7 @@ def update_user(
|
||||
action="update_user",
|
||||
entity_type="user",
|
||||
entity_id=user.id,
|
||||
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
|
||||
details={"updated_fields": list(update_data.keys())},
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@@ -10,9 +10,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.domain.exceptions import EntityNotFoundError
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.models.worklog import Worklog
|
||||
from app.services import worklog_service
|
||||
|
||||
router = APIRouter(prefix="/worklogs", tags=["worklogs"])
|
||||
@@ -59,6 +58,7 @@ def create(
|
||||
user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Create a manually-logged worklog entry."""
|
||||
with UnitOfWork(db) as uow:
|
||||
wl = worklog_service.create_worklog(
|
||||
db,
|
||||
entity_type=body.entity_type,
|
||||
@@ -70,6 +70,8 @@ def create(
|
||||
duration_seconds=body.duration_seconds,
|
||||
description=body.description,
|
||||
)
|
||||
uow.commit()
|
||||
db.refresh(wl)
|
||||
return wl
|
||||
|
||||
|
||||
@@ -97,10 +99,7 @@ def get_one(
|
||||
_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get a single worklog by ID."""
|
||||
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
|
||||
if not wl:
|
||||
raise EntityNotFoundError("Worklog", str(worklog_id))
|
||||
return wl
|
||||
return worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||
|
||||
|
||||
@router.get("/{worklog_id}/verify")
|
||||
@@ -110,9 +109,7 @@ def verify_integrity(
|
||||
_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Check whether a worklog's integrity hash is still valid."""
|
||||
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
|
||||
if not wl:
|
||||
raise EntityNotFoundError("Worklog", str(worklog_id))
|
||||
wl = worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||
return {
|
||||
"worklog_id": str(wl.id),
|
||||
"integrity_valid": worklog_service.verify_worklog_integrity(wl),
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Advanced metrics service — coverage by tactic, never-tested, avg validation time, detection trend."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import case, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.enums import TestResult
|
||||
|
||||
|
||||
def get_coverage_by_tactic(db: Session) -> list[dict]:
|
||||
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
|
||||
results = (
|
||||
db.query(
|
||||
Technique.tactic,
|
||||
func.count(Technique.id).label("total"),
|
||||
func.sum(
|
||||
case((Technique.status_global == "validated", 1), else_=0)
|
||||
).label("validated"),
|
||||
func.sum(
|
||||
case((Technique.status_global == "partial", 1), else_=0)
|
||||
).label("partial"),
|
||||
func.sum(
|
||||
case((Technique.status_global == "not_covered", 1), else_=0)
|
||||
).label("not_covered"),
|
||||
func.sum(
|
||||
case((Technique.status_global == "in_progress", 1), else_=0)
|
||||
).label("in_progress"),
|
||||
)
|
||||
.group_by(Technique.tactic)
|
||||
.order_by(Technique.tactic)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{
|
||||
"tactic": r[0] or "Unknown",
|
||||
"total": r[1],
|
||||
"validated": int(r[2]),
|
||||
"partial": int(r[3]),
|
||||
"not_covered": int(r[4]),
|
||||
"in_progress": int(r[5]),
|
||||
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
|
||||
def get_never_tested_techniques(db: Session) -> list[dict]:
|
||||
"""Techniques that have never had a test created."""
|
||||
tested_technique_ids = db.query(Test.technique_id).distinct().subquery()
|
||||
techniques = (
|
||||
db.query(Technique)
|
||||
.filter(~Technique.id.in_(db.query(tested_technique_ids)))
|
||||
.order_by(Technique.mitre_id)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{
|
||||
"mitre_id": t.mitre_id,
|
||||
"name": t.name,
|
||||
"tactic": t.tactic,
|
||||
"is_subtechnique": t.is_subtechnique,
|
||||
}
|
||||
for t in techniques
|
||||
]
|
||||
|
||||
|
||||
def get_avg_validation_time(db: Session) -> dict:
|
||||
"""Average time from test creation to validation, computed from validated tests.
|
||||
|
||||
Returns overall average and per-phase averages where data is available.
|
||||
"""
|
||||
validated_tests = (
|
||||
db.query(Test)
|
||||
.filter(Test.state == "validated")
|
||||
.all()
|
||||
)
|
||||
|
||||
if not validated_tests:
|
||||
return {
|
||||
"total_validated": 0,
|
||||
"avg_total_hours": 0,
|
||||
"avg_red_phase_hours": 0,
|
||||
"avg_blue_phase_hours": 0,
|
||||
}
|
||||
|
||||
total_durations = []
|
||||
red_durations = []
|
||||
blue_durations = []
|
||||
|
||||
for test in validated_tests:
|
||||
if test.created_at and test.red_validated_at:
|
||||
total_seconds = (test.red_validated_at - test.created_at).total_seconds()
|
||||
total_durations.append(total_seconds)
|
||||
|
||||
if test.red_started_at and test.blue_started_at:
|
||||
red_sec = (test.blue_started_at - test.red_started_at).total_seconds()
|
||||
red_paused = test.red_paused_seconds or 0
|
||||
red_durations.append(max(red_sec - red_paused, 0))
|
||||
|
||||
if test.blue_started_at and test.blue_validated_at:
|
||||
blue_sec = (test.blue_validated_at - test.blue_started_at).total_seconds()
|
||||
blue_paused = test.blue_paused_seconds or 0
|
||||
blue_durations.append(max(blue_sec - blue_paused, 0))
|
||||
|
||||
def avg_hours(durations: list[float]) -> float:
|
||||
if not durations:
|
||||
return 0
|
||||
return round(sum(durations) / len(durations) / 3600, 2)
|
||||
|
||||
return {
|
||||
"total_validated": len(validated_tests),
|
||||
"avg_total_hours": avg_hours(total_durations),
|
||||
"avg_red_phase_hours": avg_hours(red_durations),
|
||||
"avg_blue_phase_hours": avg_hours(blue_durations),
|
||||
}
|
||||
|
||||
|
||||
def get_detection_rate_trend(db: Session) -> list[dict]:
|
||||
"""Monthly detection rate trend for the last 12 months."""
|
||||
now = datetime.utcnow()
|
||||
months = []
|
||||
|
||||
for i in range(11, -1, -1):
|
||||
month_start = datetime(now.year, now.month, 1) - timedelta(days=i * 30)
|
||||
month_end = month_start + timedelta(days=30)
|
||||
|
||||
validated = (
|
||||
db.query(func.count(Test.id))
|
||||
.filter(
|
||||
Test.state == "validated",
|
||||
Test.created_at >= month_start,
|
||||
Test.created_at < month_end,
|
||||
)
|
||||
.scalar() or 0
|
||||
)
|
||||
|
||||
detected = (
|
||||
db.query(func.count(Test.id))
|
||||
.filter(
|
||||
Test.state == "validated",
|
||||
Test.detection_result == TestResult.detected,
|
||||
Test.created_at >= month_start,
|
||||
Test.created_at < month_end,
|
||||
)
|
||||
.scalar() or 0
|
||||
)
|
||||
|
||||
months.append({
|
||||
"month": month_start.strftime("%Y-%m"),
|
||||
"validated": validated,
|
||||
"detected": detected,
|
||||
"detection_rate": round((detected / validated) * 100, 1) if validated > 0 else 0,
|
||||
})
|
||||
|
||||
return months
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Analytics service — flat JSON optimized for PowerBI / BI tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.coverage_snapshot import CoverageSnapshot
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
def get_coverage_analytics(db: Session) -> list[dict]:
|
||||
"""Coverage per technique — flat format for BI dashboards."""
|
||||
techniques = db.query(Technique).all()
|
||||
return [
|
||||
{
|
||||
"mitre_id": t.mitre_id,
|
||||
"name": t.name,
|
||||
"tactic": t.tactic,
|
||||
"status": t.status_global.value if t.status_global else "not_evaluated",
|
||||
"is_subtechnique": t.is_subtechnique,
|
||||
"test_count": len(t.tests) if t.tests else 0,
|
||||
"review_required": t.review_required,
|
||||
"last_review_date": (
|
||||
t.last_review_date.isoformat() if t.last_review_date else None
|
||||
),
|
||||
}
|
||||
for t in techniques
|
||||
]
|
||||
|
||||
|
||||
def get_tests_analytics(
|
||||
db: Session,
|
||||
*,
|
||||
date_from: str | None = None,
|
||||
date_to: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""All tests with timestamps — flat format for BI dashboards."""
|
||||
query = db.query(Test)
|
||||
if date_from:
|
||||
query = query.filter(Test.created_at >= date_from)
|
||||
if date_to:
|
||||
query = query.filter(Test.created_at <= date_to)
|
||||
tests = query.all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"technique_id": str(t.technique_id),
|
||||
"name": t.name,
|
||||
"state": t.state.value if t.state else None,
|
||||
"result": t.result.value if t.result else None,
|
||||
"detection_result": (
|
||||
t.detection_result.value if t.detection_result else None
|
||||
),
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
"execution_date": (
|
||||
t.execution_date.isoformat() if t.execution_date else None
|
||||
),
|
||||
"platform": t.platform,
|
||||
"tool_used": t.tool_used,
|
||||
"attack_success": t.attack_success,
|
||||
"remediation_status": t.remediation_status,
|
||||
}
|
||||
for t in tests
|
||||
]
|
||||
|
||||
|
||||
def get_trends_analytics(db: Session) -> list[dict]:
|
||||
"""Historical coverage snapshots for trend visualization."""
|
||||
snapshots = (
|
||||
db.query(CoverageSnapshot)
|
||||
.order_by(CoverageSnapshot.created_at)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{
|
||||
"date": s.created_at.isoformat() if s.created_at else None,
|
||||
"name": s.name,
|
||||
"total_techniques": s.total_techniques,
|
||||
"validated_count": s.validated_count,
|
||||
"partial_count": s.partial_count,
|
||||
"not_covered_count": s.not_covered_count,
|
||||
"organization_score": s.organization_score,
|
||||
}
|
||||
for s in snapshots
|
||||
]
|
||||
|
||||
|
||||
def get_operators_analytics(db: Session) -> list[dict]:
|
||||
"""Per-operator metrics — for workload management dashboards."""
|
||||
results = (
|
||||
db.query(
|
||||
User.username,
|
||||
User.role,
|
||||
func.count(Test.id).label("test_count"),
|
||||
)
|
||||
.outerjoin(Test, Test.created_by == User.id)
|
||||
.group_by(User.id, User.username, User.role)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
{"username": r[0], "role": r[1], "test_count": r[2]}
|
||||
for r in results
|
||||
]
|
||||
@@ -0,0 +1,93 @@
|
||||
"""Audit log query service — framework-agnostic query logic for audit logs.
|
||||
|
||||
Provides paginated logs and distinct action/entity-type lists.
|
||||
No FastAPI imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.models.audit import AuditLog
|
||||
|
||||
|
||||
def list_logs(
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
action: str | None = None,
|
||||
entity_type: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
) -> dict:
|
||||
"""Return paginated audit logs with optional filters.
|
||||
|
||||
Returns a dict with keys: items, total, offset, limit.
|
||||
Each item is a dict with: id, user_id, username, action, entity_type,
|
||||
entity_id, timestamp, details.
|
||||
"""
|
||||
query = db.query(AuditLog).options(joinedload(AuditLog.user))
|
||||
|
||||
if user_id:
|
||||
query = query.filter(AuditLog.user_id == user_id)
|
||||
if action:
|
||||
query = query.filter(AuditLog.action == action)
|
||||
if entity_type:
|
||||
query = query.filter(AuditLog.entity_type == entity_type)
|
||||
if start_date:
|
||||
query = query.filter(AuditLog.timestamp >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(AuditLog.timestamp <= end_date)
|
||||
|
||||
total = query.count()
|
||||
|
||||
logs = (
|
||||
query
|
||||
.order_by(AuditLog.timestamp.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
items = [
|
||||
{
|
||||
"id": log.id,
|
||||
"user_id": log.user_id,
|
||||
"username": log.user.username if log.user else None,
|
||||
"action": log.action,
|
||||
"entity_type": log.entity_type,
|
||||
"entity_id": log.entity_id,
|
||||
"timestamp": log.timestamp,
|
||||
"details": log.details,
|
||||
}
|
||||
for log in logs
|
||||
]
|
||||
|
||||
return {"items": items, "total": total, "offset": offset, "limit": limit}
|
||||
|
||||
|
||||
def list_distinct_actions(db: Session) -> list[str]:
|
||||
"""Return a list of distinct action types in the audit log."""
|
||||
actions = (
|
||||
db.query(AuditLog.action)
|
||||
.distinct()
|
||||
.order_by(AuditLog.action)
|
||||
.all()
|
||||
)
|
||||
return [a[0] for a in actions]
|
||||
|
||||
|
||||
def list_distinct_entity_types(db: Session) -> list[str]:
|
||||
"""Return a list of distinct entity types in the audit log."""
|
||||
types = (
|
||||
db.query(AuditLog.entity_type)
|
||||
.filter(AuditLog.entity_type.isnot(None))
|
||||
.distinct()
|
||||
.order_by(AuditLog.entity_type)
|
||||
.all()
|
||||
)
|
||||
return [t[0] for t in types]
|
||||
@@ -0,0 +1,45 @@
|
||||
"""Authentication service — credential validation and password management."""
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth import hash_password, verify_password
|
||||
from app.domain.errors import BusinessRuleViolation, PermissionViolation
|
||||
from app.models.user import User
|
||||
|
||||
_DUMMY_HASH = "$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy"
|
||||
|
||||
|
||||
def authenticate_user(db: Session, *, username: str, password: str) -> User:
|
||||
"""Validate credentials and return the User.
|
||||
|
||||
Raises BusinessRuleViolation for invalid credentials.
|
||||
Raises PermissionViolation for disabled account.
|
||||
Uses constant-time comparison to prevent timing attacks.
|
||||
"""
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
hashed = user.hashed_password if user else _DUMMY_HASH
|
||||
password_valid = verify_password(password, hashed)
|
||||
|
||||
if user is None or not password_valid:
|
||||
raise BusinessRuleViolation("Incorrect username or password")
|
||||
if not user.is_active:
|
||||
raise PermissionViolation("Account is disabled. Contact an administrator.")
|
||||
return user
|
||||
|
||||
|
||||
def change_password(
|
||||
db: Session,
|
||||
user: User,
|
||||
*,
|
||||
current_password: str,
|
||||
new_password: str,
|
||||
) -> None:
|
||||
"""Change a user's password. Does NOT commit.
|
||||
|
||||
Raises BusinessRuleViolation if current password is wrong.
|
||||
"""
|
||||
if not verify_password(current_password, user.hashed_password):
|
||||
raise BusinessRuleViolation("Current password is incorrect")
|
||||
user.hashed_password = hash_password(new_password)
|
||||
user.must_change_password = False
|
||||
@@ -0,0 +1,82 @@
|
||||
"""D3FEND query service — framework-agnostic queries for defensive techniques."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.defensive_technique import DefensiveTechnique
|
||||
from app.models.technique import Technique
|
||||
from app.services.d3fend_import_service import get_defenses_for_technique
|
||||
from app.utils import escape_like
|
||||
|
||||
|
||||
def list_defensive_techniques(
|
||||
db: Session,
|
||||
*,
|
||||
tactic: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
) -> dict:
|
||||
"""List D3FEND defensive techniques with optional filters."""
|
||||
query = db.query(DefensiveTechnique)
|
||||
|
||||
if tactic:
|
||||
query = query.filter(DefensiveTechnique.tactic == tactic)
|
||||
|
||||
if search:
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(
|
||||
DefensiveTechnique.name.ilike(pattern)
|
||||
| DefensiveTechnique.d3fend_id.ilike(pattern)
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [
|
||||
{
|
||||
"id": str(dt.id),
|
||||
"d3fend_id": dt.d3fend_id,
|
||||
"name": dt.name,
|
||||
"description": dt.description,
|
||||
"tactic": dt.tactic,
|
||||
"d3fend_url": dt.d3fend_url,
|
||||
}
|
||||
for dt in items
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def list_d3fend_tactics(db: Session) -> list[dict]:
|
||||
"""Return a list of all D3FEND tactics with counts."""
|
||||
rows = (
|
||||
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
|
||||
.group_by(DefensiveTechnique.tactic)
|
||||
.order_by(DefensiveTechnique.tactic)
|
||||
.all()
|
||||
)
|
||||
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
|
||||
|
||||
|
||||
def get_defenses_for_attack_technique(db: Session, mitre_id: str) -> dict:
|
||||
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
||||
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
|
||||
if technique is None:
|
||||
raise EntityNotFoundError("Technique", mitre_id)
|
||||
|
||||
defenses = get_defenses_for_technique(db, technique.id)
|
||||
|
||||
return {
|
||||
"mitre_id": mitre_id,
|
||||
"technique_name": technique.name,
|
||||
"defenses": defenses,
|
||||
"total": len(defenses),
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
"""Data source management service — framework-agnostic query and sync logic.
|
||||
|
||||
Provides list, update, sync, and stats. Sync operations commit internally
|
||||
since they are long-running and self-contained.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
|
||||
from app.domain.ports.import_service import get_import_handler
|
||||
from app.models.data_source import DataSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def list_sources(db: Session) -> list[dict]:
|
||||
"""Return all registered data sources as a list of dicts."""
|
||||
sources = db.query(DataSource).order_by(DataSource.name).all()
|
||||
return [
|
||||
{
|
||||
"id": str(s.id),
|
||||
"name": s.name,
|
||||
"display_name": s.display_name,
|
||||
"type": s.type,
|
||||
"url": s.url,
|
||||
"description": s.description,
|
||||
"is_enabled": s.is_enabled,
|
||||
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
|
||||
"last_sync_status": s.last_sync_status,
|
||||
"last_sync_stats": s.last_sync_stats,
|
||||
"sync_frequency": s.sync_frequency,
|
||||
"config": s.config,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
for s in sources
|
||||
]
|
||||
|
||||
|
||||
def update_source(db: Session, source_id: str, **fields: object) -> None:
|
||||
"""Update a data source's fields (is_enabled, sync_frequency, config).
|
||||
|
||||
Raises EntityNotFoundError if source does not exist.
|
||||
Does not commit; the router handles that.
|
||||
"""
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
if not ds:
|
||||
raise EntityNotFoundError("Data source", source_id)
|
||||
|
||||
if "is_enabled" in fields:
|
||||
ds.is_enabled = fields["is_enabled"]
|
||||
if "sync_frequency" in fields:
|
||||
ds.sync_frequency = fields["sync_frequency"]
|
||||
if "config" in fields:
|
||||
ds.config = fields["config"]
|
||||
|
||||
|
||||
def sync_source(db: Session, source_id: str) -> dict:
|
||||
"""Trigger sync for a specific data source.
|
||||
|
||||
Raises EntityNotFoundError if source does not exist.
|
||||
Raises BusinessRuleViolation if no sync handler is available.
|
||||
Commits internally (long-running, self-contained operation).
|
||||
Returns dict with message, source, stats.
|
||||
"""
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
if not ds:
|
||||
raise EntityNotFoundError("Data source", source_id)
|
||||
|
||||
handler = get_import_handler(ds.name)
|
||||
if handler is None:
|
||||
raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'")
|
||||
|
||||
ds.last_sync_status = "in_progress"
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
summary = handler(db)
|
||||
except Exception as exc:
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||
ds.last_sync_status = "error"
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_stats = {"error": str(exc)}
|
||||
db.commit()
|
||||
raise BusinessRuleViolation(
|
||||
f"Sync failed for '{ds.display_name}'. Check server logs for details."
|
||||
)
|
||||
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_status = "success"
|
||||
ds.last_sync_stats = summary
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"message": f"Sync complete for {ds.display_name}",
|
||||
"source": ds.name,
|
||||
"stats": summary,
|
||||
}
|
||||
|
||||
|
||||
def sync_all_sources(db: Session) -> list[dict]:
|
||||
"""Trigger sync for all enabled data sources (sequentially).
|
||||
|
||||
Commits internally (long-running, self-contained operation).
|
||||
Returns list of result dicts with source, status, stats/detail.
|
||||
"""
|
||||
enabled_sources = (
|
||||
db.query(DataSource)
|
||||
.filter(DataSource.is_enabled == True)
|
||||
.order_by(DataSource.name)
|
||||
.all()
|
||||
)
|
||||
|
||||
results = []
|
||||
for ds in enabled_sources:
|
||||
handler = get_import_handler(ds.name)
|
||||
if handler is None:
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
"status": "skipped",
|
||||
"detail": "No sync handler available",
|
||||
})
|
||||
continue
|
||||
|
||||
ds.last_sync_status = "in_progress"
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
summary = handler(db)
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_status = "success"
|
||||
ds.last_sync_stats = summary
|
||||
db.commit()
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
"status": "success",
|
||||
"stats": summary,
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||
ds.last_sync_status = "error"
|
||||
ds.last_sync_at = datetime.utcnow()
|
||||
ds.last_sync_stats = {"error": str(exc)}
|
||||
db.commit()
|
||||
results.append({
|
||||
"source": ds.name,
|
||||
"status": "error",
|
||||
"detail": "Sync failed. Check server logs for details.",
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_source_stats(db: Session, source_id: str) -> dict:
|
||||
"""Return detailed statistics for a data source.
|
||||
|
||||
Raises EntityNotFoundError if source does not exist.
|
||||
"""
|
||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||
if not ds:
|
||||
raise EntityNotFoundError("Data source", source_id)
|
||||
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.detection_rule import DetectionRule
|
||||
|
||||
template_count = 0
|
||||
rule_count = 0
|
||||
|
||||
if ds.type == "attack_procedure":
|
||||
template_count = (
|
||||
db.query(TestTemplate)
|
||||
.filter(TestTemplate.source == ds.name)
|
||||
.count()
|
||||
)
|
||||
elif ds.type == "detection_rule":
|
||||
rule_count = (
|
||||
db.query(DetectionRule)
|
||||
.filter(DetectionRule.source == ds.name)
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"id": str(ds.id),
|
||||
"name": ds.name,
|
||||
"display_name": ds.display_name,
|
||||
"type": ds.type,
|
||||
"is_enabled": ds.is_enabled,
|
||||
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
|
||||
"last_sync_status": ds.last_sync_status,
|
||||
"last_sync_stats": ds.last_sync_stats,
|
||||
"total_templates": template_count,
|
||||
"total_rules": rule_count,
|
||||
}
|
||||
@@ -3,12 +3,17 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.domain.exceptions import InvalidOperationError
|
||||
from app.models.jira_link import JiraLink
|
||||
from app.models.campaign import Campaign
|
||||
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -103,3 +108,128 @@ def _build_sync_comment(data: dict) -> str:
|
||||
lines.append(f"*{key}:* {value}")
|
||||
lines.append(f"\n_Synced at {datetime.utcnow().isoformat()}_")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ── Link CRUD ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def create_link(
|
||||
db: Session,
|
||||
*,
|
||||
entity_type: JiraLinkEntityType,
|
||||
entity_id: UUID,
|
||||
jira_issue_key: str,
|
||||
sync_direction: JiraSyncDirection,
|
||||
created_by: UUID,
|
||||
) -> JiraLink:
|
||||
"""Create a Jira link and optionally pull initial data from Jira."""
|
||||
link = JiraLink(
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
jira_issue_key=jira_issue_key,
|
||||
sync_direction=sync_direction,
|
||||
created_by=created_by,
|
||||
)
|
||||
db.add(link)
|
||||
db.flush()
|
||||
|
||||
if settings.JIRA_ENABLED:
|
||||
try:
|
||||
sync_jira_to_aegis(db, link)
|
||||
except Exception as e:
|
||||
logger.warning("Initial Jira sync failed for %s: %s", jira_issue_key, e)
|
||||
|
||||
return link
|
||||
|
||||
|
||||
def list_links(
|
||||
db: Session,
|
||||
*,
|
||||
entity_type: Optional[JiraLinkEntityType] = None,
|
||||
entity_id: Optional[UUID] = None,
|
||||
) -> list[JiraLink]:
|
||||
"""List Jira links with optional filters."""
|
||||
query = db.query(JiraLink)
|
||||
if entity_type:
|
||||
query = query.filter(JiraLink.entity_type == entity_type)
|
||||
if entity_id:
|
||||
query = query.filter(JiraLink.entity_id == entity_id)
|
||||
return query.order_by(JiraLink.created_at.desc()).all()
|
||||
|
||||
|
||||
def get_link_or_raise(db: Session, link_id: UUID) -> JiraLink:
|
||||
"""Get a Jira link by ID or raise EntityNotFoundError."""
|
||||
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
|
||||
if not link:
|
||||
raise EntityNotFoundError("JiraLink", str(link_id))
|
||||
return link
|
||||
|
||||
|
||||
def delete_link(db: Session, link_id: UUID) -> JiraLink:
|
||||
"""Delete a Jira link. Returns the deleted link (for audit)."""
|
||||
link = get_link_or_raise(db, link_id)
|
||||
db.delete(link)
|
||||
return link
|
||||
|
||||
|
||||
def build_issue_data(db: Session, entity_type: JiraLinkEntityType, entity_id: UUID) -> tuple[str, str]:
|
||||
"""Build Jira issue summary and description from an Aegis entity."""
|
||||
if entity_type == JiraLinkEntityType.test:
|
||||
entity = db.query(Test).filter(Test.id == entity_id).first()
|
||||
if not entity:
|
||||
raise EntityNotFoundError("Test", str(entity_id))
|
||||
return (
|
||||
f"[Aegis Test] {entity.name}",
|
||||
f"Test: {entity.name}\n"
|
||||
f"State: {entity.state.value if entity.state else 'draft'}\n"
|
||||
f"Description: {entity.description or 'N/A'}",
|
||||
)
|
||||
elif entity_type == JiraLinkEntityType.campaign:
|
||||
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
|
||||
if not entity:
|
||||
raise EntityNotFoundError("Campaign", str(entity_id))
|
||||
return (
|
||||
f"[Aegis Campaign] {entity.name}",
|
||||
f"Campaign: {entity.name}\n"
|
||||
f"Type: {entity.type}\nStatus: {entity.status}\n"
|
||||
f"Description: {entity.description or 'N/A'}",
|
||||
)
|
||||
elif entity_type == JiraLinkEntityType.technique:
|
||||
entity = db.query(Technique).filter(Technique.id == entity_id).first()
|
||||
if not entity:
|
||||
raise EntityNotFoundError("Technique", str(entity_id))
|
||||
return (
|
||||
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
|
||||
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
|
||||
f"Tactic: {entity.tactic or 'N/A'}\n"
|
||||
f"Description: {entity.description or 'N/A'}",
|
||||
)
|
||||
else:
|
||||
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
|
||||
|
||||
|
||||
def create_issue_and_link(
|
||||
db: Session,
|
||||
*,
|
||||
entity_type: JiraLinkEntityType,
|
||||
entity_id: UUID,
|
||||
created_by: UUID,
|
||||
) -> dict:
|
||||
"""Create a Jira issue from an Aegis entity and link them."""
|
||||
summary, description = build_issue_data(db, entity_type, entity_id)
|
||||
result = create_jira_issue(
|
||||
project_key=settings.JIRA_DEFAULT_PROJECT,
|
||||
summary=summary,
|
||||
description=description,
|
||||
labels=["aegis", entity_type.value],
|
||||
)
|
||||
link = JiraLink(
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
jira_issue_key=result["issue_key"],
|
||||
jira_issue_id=result["issue_id"],
|
||||
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
|
||||
created_by=created_by,
|
||||
)
|
||||
db.add(link)
|
||||
return {"issue_key": result["issue_key"], "link_id": str(link.id)}
|
||||
|
||||
@@ -13,6 +13,7 @@ from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.notification import Notification
|
||||
from app.models.user import User
|
||||
|
||||
@@ -22,6 +23,71 @@ from app.models.user import User
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def list_notifications(
|
||||
db: Session,
|
||||
user_id: uuid.UUID,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
) -> list[Notification]:
|
||||
"""Return paginated notifications for a user, newest first."""
|
||||
return (
|
||||
db.query(Notification)
|
||||
.filter(Notification.user_id == user_id)
|
||||
.order_by(Notification.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_notification_or_raise(
|
||||
db: Session,
|
||||
notification_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
) -> Notification:
|
||||
"""Fetch a notification by ID and user, or raise EntityNotFoundError."""
|
||||
notif = (
|
||||
db.query(Notification)
|
||||
.filter(
|
||||
Notification.id == notification_id,
|
||||
Notification.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if notif is None:
|
||||
raise EntityNotFoundError("Notification", str(notification_id))
|
||||
return notif
|
||||
|
||||
|
||||
def notify_role(
|
||||
db: Session,
|
||||
*,
|
||||
role: str,
|
||||
type: str,
|
||||
title: str,
|
||||
message: str,
|
||||
entity_type: str,
|
||||
entity_id: uuid.UUID,
|
||||
) -> None:
|
||||
"""Send notifications to all active users with a given role."""
|
||||
users = (
|
||||
db.query(User)
|
||||
.filter(User.role == role, User.is_active == True) # noqa: E712
|
||||
.all()
|
||||
)
|
||||
for user in users:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=user.id,
|
||||
type=type,
|
||||
title=title,
|
||||
message=message,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
|
||||
|
||||
def create_notification(
|
||||
db: Session,
|
||||
user_id: uuid.UUID,
|
||||
@@ -45,17 +111,13 @@ def create_notification(
|
||||
return notif
|
||||
|
||||
|
||||
def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) -> bool:
|
||||
"""Mark a single notification as read. Returns True if updated."""
|
||||
notif = (
|
||||
db.query(Notification)
|
||||
.filter(Notification.id == notification_id, Notification.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if notif is None:
|
||||
return False
|
||||
def mark_as_read(
|
||||
db: Session, notification_id: uuid.UUID, user_id: uuid.UUID
|
||||
) -> Notification:
|
||||
"""Mark a single notification as read. Returns the notification. Raises EntityNotFoundError if not found."""
|
||||
notif = get_notification_or_raise(db, notification_id, user_id)
|
||||
notif.read = True
|
||||
return True
|
||||
return notif
|
||||
|
||||
|
||||
def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
|
||||
|
||||
@@ -7,11 +7,15 @@ Designed to run as a weekly background job. Respects NVD rate limits
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.osint_item import OsintItem
|
||||
from app.models.technique import Technique
|
||||
|
||||
@@ -177,15 +181,97 @@ def get_osint_items_for_technique(
|
||||
|
||||
|
||||
def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None:
|
||||
"""Mark an OSINT item as reviewed."""
|
||||
"""Mark an OSINT item as reviewed. Does not commit; caller uses UnitOfWork."""
|
||||
item = db.query(OsintItem).filter(OsintItem.id == item_id).first()
|
||||
if item:
|
||||
item.reviewed = True
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
return item
|
||||
|
||||
|
||||
def get_unreviewed_count(db: Session) -> int:
|
||||
"""Return the total number of unreviewed OSINT items."""
|
||||
return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712
|
||||
|
||||
|
||||
def list_osint_items(
|
||||
db: Session,
|
||||
*,
|
||||
technique_id: Optional[UUID] = None,
|
||||
source_type: Optional[str] = None,
|
||||
reviewed: Optional[bool] = None,
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
) -> dict:
|
||||
"""List OSINT items with optional filters and pagination."""
|
||||
query = db.query(OsintItem)
|
||||
if technique_id:
|
||||
query = query.filter(OsintItem.technique_id == technique_id)
|
||||
if source_type:
|
||||
query = query.filter(OsintItem.source_type == source_type)
|
||||
if reviewed is not None:
|
||||
query = query.filter(OsintItem.reviewed == reviewed)
|
||||
|
||||
total = query.count()
|
||||
items = (
|
||||
query.order_by(OsintItem.discovered_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"items": [
|
||||
{
|
||||
"id": str(item.id),
|
||||
"technique_id": str(item.technique_id),
|
||||
"source_type": item.source_type,
|
||||
"source_url": item.source_url,
|
||||
"title": item.title,
|
||||
"description": item.description,
|
||||
"severity": item.severity,
|
||||
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
|
||||
"reviewed": item.reviewed,
|
||||
"metadata": item.metadata_,
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def get_osint_summary(db: Session) -> dict:
|
||||
"""Summary statistics for OSINT items."""
|
||||
total = db.query(func.count(OsintItem.id)).scalar() or 0
|
||||
unreviewed = get_unreviewed_count(db)
|
||||
|
||||
by_severity = dict(
|
||||
db.query(OsintItem.severity, func.count(OsintItem.id))
|
||||
.group_by(OsintItem.severity)
|
||||
.all()
|
||||
)
|
||||
|
||||
by_type = dict(
|
||||
db.query(OsintItem.source_type, func.count(OsintItem.id))
|
||||
.group_by(OsintItem.source_type)
|
||||
.all()
|
||||
)
|
||||
|
||||
techniques_with_items = (
|
||||
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"total_items": total,
|
||||
"unreviewed": unreviewed,
|
||||
"techniques_with_items": techniques_with_items,
|
||||
"by_severity": by_severity,
|
||||
"by_type": by_type,
|
||||
}
|
||||
|
||||
|
||||
def get_technique_or_raise(db: Session, technique_id: UUID) -> Technique:
|
||||
"""Get a technique by ID or raise EntityNotFoundError."""
|
||||
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||
if not technique:
|
||||
raise EntityNotFoundError("Technique", str(technique_id))
|
||||
return technique
|
||||
|
||||
@@ -82,9 +82,7 @@ def update_scoring_weights(
|
||||
row.weight_freshness = new.freshness
|
||||
row.weight_platform_diversity = new.platform_diversity
|
||||
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
|
||||
# Does not commit; caller (router) uses UnitOfWork.
|
||||
return _weights_dict(new)
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from typing import Optional
|
||||
from sqlalchemy import case, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.detection_rule import DetectionRule
|
||||
@@ -232,6 +233,29 @@ def bulk_technique_scores(db: Session) -> dict:
|
||||
# ── Technique-level scoring (single technique — preserved API) ────────
|
||||
|
||||
|
||||
def score_technique_by_mitre_id(db: Session, mitre_id: str) -> dict:
|
||||
"""Get detailed score with breakdown for a technique by MITRE ID."""
|
||||
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
|
||||
if not technique:
|
||||
raise EntityNotFoundError("Technique", mitre_id)
|
||||
result = calculate_technique_score(technique, db)
|
||||
return {
|
||||
"mitre_id": technique.mitre_id,
|
||||
"name": technique.name,
|
||||
"tactic": technique.tactic,
|
||||
"status_global": technique.status_global.value if technique.status_global else None,
|
||||
**result,
|
||||
}
|
||||
|
||||
|
||||
def score_actor_by_id(db: Session, actor_id: str) -> dict:
|
||||
"""Get coverage score for a threat actor by ID."""
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
if not actor:
|
||||
raise EntityNotFoundError("ThreatActor", actor_id)
|
||||
return calculate_actor_coverage_score(actor_id, db)
|
||||
|
||||
|
||||
def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
||||
"""Calculate a 0-100 score for a technique with detailed breakdown.
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from datetime import datetime
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.technique import Technique
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
from app.models.enums import TechniqueStatus
|
||||
@@ -25,6 +26,101 @@ from app.services.scoring_service import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization and queries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
|
||||
"""Lightweight serialization for list views."""
|
||||
return {
|
||||
"id": str(snap.id),
|
||||
"name": snap.name,
|
||||
"organization_score": snap.organization_score,
|
||||
"total_techniques": snap.total_techniques,
|
||||
"validated_count": snap.validated_count,
|
||||
"partial_count": snap.partial_count,
|
||||
"not_covered_count": snap.not_covered_count,
|
||||
"in_progress_count": snap.in_progress_count,
|
||||
"not_evaluated_count": snap.not_evaluated_count,
|
||||
"created_by": str(snap.created_by) if snap.created_by else None,
|
||||
"created_at": snap.created_at.isoformat() if snap.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
|
||||
"""Full serialization including technique states."""
|
||||
base = serialize_snapshot_summary(snap)
|
||||
|
||||
technique_states = (
|
||||
db.query(SnapshotTechniqueState)
|
||||
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
|
||||
.order_by(SnapshotTechniqueState.mitre_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
base["technique_states"] = [
|
||||
{
|
||||
"mitre_id": s.mitre_id,
|
||||
"technique_id": str(s.technique_id),
|
||||
"status": s.status,
|
||||
"score": s.score,
|
||||
}
|
||||
for s in technique_states
|
||||
]
|
||||
return base
|
||||
|
||||
|
||||
def list_snapshots(
|
||||
db: Session,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
) -> dict:
|
||||
"""List coverage snapshots ordered by creation date (newest first)."""
|
||||
query = db.query(CoverageSnapshot)
|
||||
total = query.count()
|
||||
|
||||
snapshots = (
|
||||
query
|
||||
.order_by(CoverageSnapshot.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [serialize_snapshot_summary(s) for s in snapshots],
|
||||
}
|
||||
|
||||
|
||||
def get_snapshot_or_raise(db: Session, snapshot_id: str) -> CoverageSnapshot:
|
||||
"""Fetch snapshot by ID or raise EntityNotFoundError."""
|
||||
try:
|
||||
sid = uuid.UUID(snapshot_id)
|
||||
except (ValueError, TypeError):
|
||||
raise EntityNotFoundError("Snapshot", snapshot_id)
|
||||
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
|
||||
if snapshot is None:
|
||||
raise EntityNotFoundError("Snapshot", snapshot_id)
|
||||
return snapshot
|
||||
|
||||
|
||||
def get_snapshot_detail(db: Session, snapshot_id: str) -> dict:
|
||||
"""Get detailed snapshot including per-technique states."""
|
||||
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||
return serialize_snapshot_detail(db, snapshot)
|
||||
|
||||
|
||||
def delete_snapshot(db: Session, snapshot_id: str) -> None:
|
||||
"""Delete a snapshot. Does not commit — caller must commit."""
|
||||
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||
db.delete(snapshot)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Create snapshot
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -138,7 +234,7 @@ def compare_snapshots(
|
||||
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
|
||||
|
||||
if not snap_a or not snap_b:
|
||||
return {"error": "One or both snapshots not found"}
|
||||
raise EntityNotFoundError("Snapshot", f"{snapshot_a_id} or {snapshot_b_id}")
|
||||
|
||||
# Build lookup dicts: mitre_id -> {status, score}
|
||||
states_a = {
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Technique query service — framework-agnostic queries for technique details."""
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.technique import Technique
|
||||
from app.services.d3fend_import_service import get_defenses_for_technique
|
||||
|
||||
|
||||
def get_technique_detail(db: Session, mitre_id: str) -> dict:
|
||||
"""Fetch full technique details including tests and D3FEND defenses."""
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.options(joinedload(Technique.tests))
|
||||
.filter(Technique.mitre_id == mitre_id)
|
||||
.first()
|
||||
)
|
||||
if technique is None:
|
||||
raise EntityNotFoundError("Technique", mitre_id)
|
||||
defenses = get_defenses_for_technique(db, technique.id)
|
||||
return {
|
||||
"id": str(technique.id),
|
||||
"mitre_id": technique.mitre_id,
|
||||
"name": technique.name,
|
||||
"description": technique.description,
|
||||
"tactic": technique.tactic,
|
||||
"platforms": technique.platforms or [],
|
||||
"mitre_version": technique.mitre_version,
|
||||
"mitre_last_modified": technique.mitre_last_modified,
|
||||
"is_subtechnique": technique.is_subtechnique,
|
||||
"parent_mitre_id": technique.parent_mitre_id,
|
||||
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
|
||||
"review_required": technique.review_required,
|
||||
"last_review_date": technique.last_review_date,
|
||||
"tests": [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"name": t.name,
|
||||
"state": t.state.value if t.state else None,
|
||||
"result": t.result.value if t.result else None,
|
||||
"platform": t.platform,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in technique.tests
|
||||
],
|
||||
"d3fend_defenses": defenses,
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
"""Test template service — framework-agnostic CRUD and queries."""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.utils import escape_like
|
||||
|
||||
|
||||
def list_templates(
|
||||
db: Session,
|
||||
*,
|
||||
source: str | None = None,
|
||||
platform: str | None = None,
|
||||
severity: str | None = None,
|
||||
mitre_technique_id: str | None = None,
|
||||
search: str | None = None,
|
||||
is_active: bool | None = None,
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
) -> list:
|
||||
"""Return paginated, filterable list of test templates."""
|
||||
query = db.query(TestTemplate)
|
||||
if is_active is not None:
|
||||
query = query.filter(TestTemplate.is_active == is_active)
|
||||
|
||||
if source:
|
||||
query = query.filter(TestTemplate.source == source)
|
||||
if platform:
|
||||
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
|
||||
if severity:
|
||||
query = query.filter(TestTemplate.severity == severity)
|
||||
if mitre_technique_id:
|
||||
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
|
||||
if search:
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
TestTemplate.name.ilike(pattern),
|
||||
TestTemplate.description.ilike(pattern),
|
||||
)
|
||||
)
|
||||
|
||||
templates = (
|
||||
query
|
||||
.order_by(TestTemplate.mitre_technique_id, TestTemplate.name)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return templates
|
||||
|
||||
|
||||
def get_template_stats(db: Session) -> dict:
|
||||
"""Return catalog statistics: totals by source, platform, active/inactive."""
|
||||
total = db.query(func.count(TestTemplate.id)).scalar() or 0
|
||||
active = (
|
||||
db.query(func.count(TestTemplate.id))
|
||||
.filter(TestTemplate.is_active == True) # noqa: E712
|
||||
.scalar()
|
||||
) or 0
|
||||
inactive = total - active
|
||||
|
||||
source_rows = (
|
||||
db.query(TestTemplate.source, func.count(TestTemplate.id))
|
||||
.filter(TestTemplate.is_active == True) # noqa: E712
|
||||
.group_by(TestTemplate.source)
|
||||
.all()
|
||||
)
|
||||
by_source = {source: cnt for source, cnt in source_rows}
|
||||
|
||||
platform_rows = (
|
||||
db.query(TestTemplate.platform, func.count(TestTemplate.id))
|
||||
.filter(TestTemplate.is_active == True) # noqa: E712
|
||||
.group_by(TestTemplate.platform)
|
||||
.all()
|
||||
)
|
||||
by_platform = {(platform or "unspecified"): cnt for platform, cnt in platform_rows}
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"active": active,
|
||||
"inactive": inactive,
|
||||
"by_source": by_source,
|
||||
"by_platform": by_platform,
|
||||
}
|
||||
|
||||
|
||||
def bulk_activate(db: Session, *, activate: bool) -> int:
|
||||
"""Set all templates to active or inactive. Returns count of affected. Does NOT commit."""
|
||||
count = (
|
||||
db.query(TestTemplate)
|
||||
.filter(TestTemplate.is_active != activate)
|
||||
.update({TestTemplate.is_active: activate})
|
||||
)
|
||||
return count
|
||||
|
||||
|
||||
def get_templates_by_technique(db: Session, mitre_id: str) -> list:
|
||||
"""Return all active templates mapped to a specific MITRE technique."""
|
||||
return (
|
||||
db.query(TestTemplate)
|
||||
.filter(
|
||||
TestTemplate.mitre_technique_id == mitre_id,
|
||||
TestTemplate.is_active == True, # noqa: E712
|
||||
)
|
||||
.order_by(TestTemplate.name)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_template_or_raise(db: Session, template_id: uuid.UUID) -> TestTemplate:
|
||||
"""Return a template by ID. Raises EntityNotFoundError if not found."""
|
||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
if template is None:
|
||||
raise EntityNotFoundError("Test template", str(template_id))
|
||||
return template
|
||||
|
||||
|
||||
def create_template(db: Session, **fields: object) -> TestTemplate:
|
||||
"""Create a test template from keyword args (e.g. payload.model_dump()). Does NOT commit."""
|
||||
template = TestTemplate(**fields)
|
||||
db.add(template)
|
||||
return template
|
||||
|
||||
|
||||
def update_template(db: Session, template_id: uuid.UUID, **fields: object) -> TestTemplate:
|
||||
"""Update an existing template. Raises EntityNotFoundError if not found. Does NOT commit."""
|
||||
template = get_template_or_raise(db, template_id)
|
||||
for field, value in fields.items():
|
||||
if hasattr(template, field):
|
||||
setattr(template, field, value)
|
||||
return template
|
||||
|
||||
|
||||
def toggle_template_active(db: Session, template_id: uuid.UUID) -> TestTemplate:
|
||||
"""Toggle template active/inactive. Does NOT commit."""
|
||||
template = get_template_or_raise(db, template_id)
|
||||
template.is_active = not template.is_active
|
||||
return template
|
||||
|
||||
|
||||
def soft_delete_template(db: Session, template_id: uuid.UUID) -> None:
|
||||
"""Soft-delete a template by setting is_active=False. Does NOT commit."""
|
||||
template = get_template_or_raise(db, template_id)
|
||||
template.is_active = False
|
||||
@@ -0,0 +1,88 @@
|
||||
"""User management service — framework-agnostic CRUD for users.
|
||||
|
||||
Uses domain exceptions from app.domain.errors. The router handles
|
||||
HTTP concerns, auth, audit logging, and commit.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.auth import hash_password
|
||||
from app.domain.errors import BusinessRuleViolation, DuplicateEntityError, EntityNotFoundError
|
||||
from app.models.user import User
|
||||
|
||||
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
|
||||
|
||||
|
||||
def list_users(db: Session) -> list[User]:
|
||||
"""Return a list of all users ordered by username."""
|
||||
return db.query(User).order_by(User.username).all()
|
||||
|
||||
|
||||
def create_user(
|
||||
db: Session,
|
||||
*,
|
||||
username: str,
|
||||
email: str | None,
|
||||
password: str,
|
||||
role: str,
|
||||
) -> User:
|
||||
"""Create a new user.
|
||||
|
||||
Raises DuplicateEntityError if username already exists.
|
||||
Raises BusinessRuleViolation if role is invalid.
|
||||
Does not commit; the router handles that.
|
||||
"""
|
||||
existing = db.query(User).filter(User.username == username).first()
|
||||
if existing:
|
||||
raise DuplicateEntityError("User", "username", username)
|
||||
|
||||
if role not in VALID_ROLES:
|
||||
raise BusinessRuleViolation(
|
||||
f"Invalid role '{role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
|
||||
)
|
||||
|
||||
user = User(
|
||||
username=username,
|
||||
email=email,
|
||||
hashed_password=hash_password(password),
|
||||
role=role,
|
||||
)
|
||||
db.add(user)
|
||||
return user
|
||||
|
||||
|
||||
def get_user_or_raise(db: Session, user_id: uuid.UUID) -> User:
|
||||
"""Return a user by ID or raise EntityNotFoundError."""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise EntityNotFoundError("User", str(user_id))
|
||||
return user
|
||||
|
||||
|
||||
def update_user(db: Session, user_id: uuid.UUID, **fields: object) -> User:
|
||||
"""Update one or more fields of an existing user.
|
||||
|
||||
Raises EntityNotFoundError if user does not exist.
|
||||
Raises BusinessRuleViolation if role is invalid.
|
||||
Handles 'password' by hashing and storing as 'hashed_password'.
|
||||
Does not commit; the router handles that.
|
||||
"""
|
||||
user = get_user_or_raise(db, user_id)
|
||||
update_data = dict(fields)
|
||||
|
||||
if "role" in update_data and update_data["role"] not in VALID_ROLES:
|
||||
raise BusinessRuleViolation(
|
||||
f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
|
||||
)
|
||||
|
||||
if "password" in update_data:
|
||||
update_data["hashed_password"] = hash_password(str(update_data.pop("password")))
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(user, field, value)
|
||||
|
||||
return user
|
||||
@@ -8,6 +8,7 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
from app.models.worklog import Worklog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -38,8 +39,15 @@ def create_worklog(
|
||||
)
|
||||
wl.integrity_hash = _compute_hash(wl)
|
||||
db.add(wl)
|
||||
db.commit()
|
||||
db.refresh(wl)
|
||||
# Does not commit; caller (router) uses UnitOfWork.
|
||||
return wl
|
||||
|
||||
|
||||
def get_worklog_or_raise(db: Session, worklog_id: UUID) -> Worklog:
|
||||
"""Get a worklog by ID or raise EntityNotFoundError."""
|
||||
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
|
||||
if not wl:
|
||||
raise EntityNotFoundError("Worklog", str(worklog_id))
|
||||
return wl
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
"""Tests for CampaignEntity — pure domain logic, no DB."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if backend_dir not in sys.path:
|
||||
sys.path.insert(0, backend_dir)
|
||||
|
||||
from app.domain.entities.campaign import (
|
||||
CampaignEntity,
|
||||
CampaignStatus,
|
||||
CampaignType,
|
||||
)
|
||||
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _entity(status: str = "draft", test_count: int = 0, **overrides) -> CampaignEntity:
|
||||
defaults = dict(
|
||||
id=uuid.uuid4(),
|
||||
name="Test Campaign",
|
||||
type=CampaignType.custom,
|
||||
status=CampaignStatus(status),
|
||||
description=None,
|
||||
threat_actor_id=None,
|
||||
created_by=None,
|
||||
target_platform=None,
|
||||
tags=[],
|
||||
test_count=test_count,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return CampaignEntity(**defaults)
|
||||
|
||||
|
||||
def _fake_orm(status: str = "draft", test_count: int = 0, **overrides) -> MagicMock:
|
||||
m = MagicMock()
|
||||
m.id = uuid.uuid4()
|
||||
m.name = "Test Campaign"
|
||||
m.type = "custom"
|
||||
m.status = status
|
||||
m.description = None
|
||||
m.threat_actor_id = None
|
||||
m.created_by = None
|
||||
m.target_platform = None
|
||||
m.tags = []
|
||||
m.campaign_tests = [MagicMock()] * test_count if test_count else []
|
||||
for k, v in overrides.items():
|
||||
setattr(m, k, v)
|
||||
return m
|
||||
|
||||
|
||||
# ── 1. Test activation from draft with tests → success ───────────────
|
||||
|
||||
|
||||
def test_activate_from_draft_with_tests_success():
|
||||
e = _entity("draft", test_count=1)
|
||||
e.activate()
|
||||
assert e.status == CampaignStatus.active
|
||||
|
||||
|
||||
def test_activate_from_draft_with_multiple_tests_success():
|
||||
e = _entity("draft", test_count=3)
|
||||
e.activate()
|
||||
assert e.status == CampaignStatus.active
|
||||
|
||||
|
||||
# ── 2. Test activation from draft with 0 tests → BusinessRuleViolation ───
|
||||
|
||||
|
||||
def test_activate_from_draft_with_zero_tests_raises():
|
||||
e = _entity("draft", test_count=0)
|
||||
with pytest.raises(BusinessRuleViolation, match="at least one test"):
|
||||
e.activate()
|
||||
assert e.status == CampaignStatus.draft
|
||||
|
||||
|
||||
# ── 3. Test activation from active → InvalidStateTransition ────────────
|
||||
|
||||
|
||||
def test_activate_from_active_raises():
|
||||
e = _entity("active", test_count=2)
|
||||
with pytest.raises(InvalidStateTransition) as exc_info:
|
||||
e.activate()
|
||||
assert exc_info.value.current_state == "active"
|
||||
assert exc_info.value.target_state == "active"
|
||||
assert "completed" in exc_info.value.valid_transitions
|
||||
|
||||
|
||||
# ── 4. Test complete from active → success ──────────────────────────────
|
||||
|
||||
|
||||
def test_complete_from_active_success():
|
||||
e = _entity("active", test_count=2)
|
||||
e.complete()
|
||||
assert e.status == CampaignStatus.completed
|
||||
|
||||
|
||||
# ── 5. Test complete from draft → InvalidStateTransition ────────────────
|
||||
|
||||
|
||||
def test_complete_from_draft_raises():
|
||||
e = _entity("draft", test_count=1)
|
||||
with pytest.raises(InvalidStateTransition) as exc_info:
|
||||
e.complete()
|
||||
assert exc_info.value.current_state == "draft"
|
||||
assert exc_info.value.target_state == "completed"
|
||||
assert "active" in exc_info.value.valid_transitions
|
||||
|
||||
|
||||
# ── 6. Test ensure_modifiable in draft/active → ok ──────────────────────
|
||||
|
||||
|
||||
def test_ensure_modifiable_draft_ok():
|
||||
e = _entity("draft")
|
||||
e.ensure_modifiable() # no raise
|
||||
|
||||
|
||||
def test_ensure_modifiable_active_ok():
|
||||
e = _entity("active", test_count=1)
|
||||
e.ensure_modifiable() # no raise
|
||||
|
||||
|
||||
# ── 7. Test ensure_modifiable in completed → BusinessRuleViolation ──────
|
||||
|
||||
|
||||
def test_ensure_modifiable_completed_raises():
|
||||
e = _entity("completed", test_count=1)
|
||||
with pytest.raises(BusinessRuleViolation, match="Cannot modify"):
|
||||
e.ensure_modifiable()
|
||||
|
||||
|
||||
def test_ensure_modifiable_archived_raises():
|
||||
e = _entity("archived", test_count=1)
|
||||
with pytest.raises(BusinessRuleViolation, match="Cannot modify"):
|
||||
e.ensure_modifiable()
|
||||
|
||||
|
||||
# ── 8. Test from_orm conversion ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_from_orm_basic():
|
||||
orm = _fake_orm("draft", test_count=0)
|
||||
e = CampaignEntity.from_orm(orm)
|
||||
assert e.name == "Test Campaign"
|
||||
assert e.type == CampaignType.custom
|
||||
assert e.status == CampaignStatus.draft
|
||||
assert e.id == orm.id
|
||||
assert e.test_count == 0
|
||||
|
||||
|
||||
def test_from_orm_with_tests():
|
||||
orm = _fake_orm("draft", test_count=3)
|
||||
e = CampaignEntity.from_orm(orm)
|
||||
assert e.test_count == 3
|
||||
|
||||
|
||||
def test_from_orm_coerces_type_and_status():
|
||||
orm = _fake_orm(status="active", type="apt_emulation", test_count=1)
|
||||
e = CampaignEntity.from_orm(orm)
|
||||
assert e.status == CampaignStatus.active
|
||||
assert e.type == CampaignType.apt_emulation
|
||||
|
||||
|
||||
def test_from_orm_handles_none_tags():
|
||||
orm = _fake_orm("draft", test_count=0)
|
||||
orm.tags = None
|
||||
e = CampaignEntity.from_orm(orm)
|
||||
assert e.tags == []
|
||||
@@ -0,0 +1,105 @@
|
||||
"""Tests for compliance domain entities."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.domain.entities.compliance import (
|
||||
ComplianceControlEntity,
|
||||
ComplianceFrameworkEntity,
|
||||
ControlCoverageStatus,
|
||||
)
|
||||
|
||||
|
||||
# ── Control coverage status ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_control_all_techniques_validated_covered():
|
||||
"""All techniques validated → covered."""
|
||||
control = ComplianceControlEntity(
|
||||
control_id="AC-2",
|
||||
title="Account Management",
|
||||
technique_statuses=["validated", "validated"],
|
||||
)
|
||||
assert control.coverage_status == ControlCoverageStatus.covered
|
||||
|
||||
|
||||
def test_control_all_techniques_partial_covered():
|
||||
"""All techniques partial → covered."""
|
||||
control = ComplianceControlEntity(
|
||||
control_id="AC-2",
|
||||
title="Account Management",
|
||||
technique_statuses=["partial"],
|
||||
)
|
||||
assert control.coverage_status == ControlCoverageStatus.covered
|
||||
|
||||
|
||||
def test_control_mixed_statuses_partially_covered():
|
||||
"""Mixed statuses (some validated/partial, some not) → partially_covered."""
|
||||
control = ComplianceControlEntity(
|
||||
control_id="AC-2",
|
||||
title="Account Management",
|
||||
technique_statuses=["validated", "not_evaluated"],
|
||||
)
|
||||
assert control.coverage_status == ControlCoverageStatus.partially_covered
|
||||
|
||||
|
||||
def test_control_no_validated_techniques_not_covered():
|
||||
"""No validated/partial techniques → not_covered."""
|
||||
control = ComplianceControlEntity(
|
||||
control_id="AC-2",
|
||||
title="Account Management",
|
||||
technique_statuses=["not_evaluated", "not_covered"],
|
||||
)
|
||||
assert control.coverage_status == ControlCoverageStatus.not_covered
|
||||
|
||||
|
||||
def test_control_empty_techniques_not_covered():
|
||||
"""Empty technique_statuses → not_covered."""
|
||||
control = ComplianceControlEntity(
|
||||
control_id="AC-2",
|
||||
title="Account Management",
|
||||
technique_statuses=[],
|
||||
)
|
||||
assert control.coverage_status == ControlCoverageStatus.not_covered
|
||||
|
||||
|
||||
# ── Framework coverage ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_framework_coverage_pct_calculation():
|
||||
"""Framework coverage_pct = (covered_controls / total_controls) * 100."""
|
||||
controls = [
|
||||
ComplianceControlEntity("AC-1", "Title 1", technique_statuses=["validated"]),
|
||||
ComplianceControlEntity("AC-2", "Title 2", technique_statuses=["not_evaluated"]),
|
||||
ComplianceControlEntity("AC-3", "Title 3", technique_statuses=["validated", "partial"]),
|
||||
ComplianceControlEntity("AC-4", "Title 4", technique_statuses=["partial"]),
|
||||
ComplianceControlEntity("AC-5", "Title 5", technique_statuses=[]),
|
||||
]
|
||||
framework = ComplianceFrameworkEntity(name="NIST 800-53", controls=controls)
|
||||
# AC-1: covered, AC-2: not_covered, AC-3: covered, AC-4: covered, AC-5: not_covered
|
||||
assert framework.total_controls == 5
|
||||
assert framework.covered_controls == 3
|
||||
assert framework.coverage_pct == 60.0
|
||||
|
||||
|
||||
def test_framework_get_gap_controls():
|
||||
"""get_gap_controls returns only uncovered and partially_covered controls."""
|
||||
controls = [
|
||||
ComplianceControlEntity("AC-1", "Covered", technique_statuses=["validated"]),
|
||||
ComplianceControlEntity("AC-2", "Partial", technique_statuses=["validated", "not_evaluated"]),
|
||||
ComplianceControlEntity("AC-3", "Not Covered", technique_statuses=["not_evaluated"]),
|
||||
ComplianceControlEntity("AC-4", "Empty", technique_statuses=[]),
|
||||
]
|
||||
framework = ComplianceFrameworkEntity(name="Test", controls=controls)
|
||||
gaps = framework.get_gap_controls()
|
||||
assert len(gaps) == 3
|
||||
assert gaps[0].control_id == "AC-2"
|
||||
assert gaps[1].control_id == "AC-3"
|
||||
assert gaps[2].control_id == "AC-4"
|
||||
|
||||
|
||||
def test_framework_no_controls_coverage_pct_zero():
|
||||
"""Framework with no controls → coverage_pct is 0."""
|
||||
framework = ComplianceFrameworkEntity(name="Empty", controls=[])
|
||||
assert framework.total_controls == 0
|
||||
assert framework.covered_controls == 0
|
||||
assert framework.coverage_pct == 0.0
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Tests for the ImportService protocol and IMPORT_REGISTRY."""
|
||||
|
||||
from app.domain.ports.import_service import (
|
||||
IMPORT_REGISTRY,
|
||||
ImportService,
|
||||
ImportServiceEntry,
|
||||
get_import_handler,
|
||||
)
|
||||
|
||||
|
||||
def test_registry_has_all_known_sources():
|
||||
expected = {
|
||||
"atomic_red_team",
|
||||
"sigma",
|
||||
"lolbas",
|
||||
"gtfobins",
|
||||
"caldera",
|
||||
"elastic_rules",
|
||||
"mitre_cti",
|
||||
"d3fend",
|
||||
}
|
||||
assert set(IMPORT_REGISTRY.keys()) == expected
|
||||
|
||||
|
||||
def test_all_entries_are_import_service_entries():
|
||||
for name, entry in IMPORT_REGISTRY.items():
|
||||
assert isinstance(entry, ImportServiceEntry), f"{name} is not ImportServiceEntry"
|
||||
|
||||
|
||||
def test_get_import_handler_returns_entry_for_known_source():
|
||||
handler = get_import_handler("sigma")
|
||||
assert handler is not None
|
||||
assert isinstance(handler, ImportServiceEntry)
|
||||
|
||||
|
||||
def test_get_import_handler_returns_none_for_unknown():
|
||||
assert get_import_handler("nonexistent_source") is None
|
||||
|
||||
|
||||
def test_import_service_entry_source_info():
|
||||
entry = ImportServiceEntry("app.services.sigma_import_service", "sync")
|
||||
assert entry.source_info == "app.services.sigma_import_service.sync"
|
||||
|
||||
|
||||
def test_callable_satisfies_protocol():
|
||||
def mock_handler(db):
|
||||
return {"created": 0}
|
||||
|
||||
assert isinstance(mock_handler, ImportService)
|
||||
@@ -0,0 +1,123 @@
|
||||
"""Tests for the ThreatActorEntity domain entity."""
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.domain.entities.threat_actor import (
|
||||
ThreatActorEntity,
|
||||
ThreatActorTechniqueRef,
|
||||
)
|
||||
|
||||
|
||||
def _make_ref(status: str = "not_evaluated") -> ThreatActorTechniqueRef:
|
||||
return ThreatActorTechniqueRef(
|
||||
technique_id=uuid.uuid4(),
|
||||
mitre_id="T1059",
|
||||
name="Command and Scripting Interpreter",
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def test_coverage_pct_all_covered():
|
||||
actor = ThreatActorEntity(
|
||||
name="APT29",
|
||||
techniques=[_make_ref("validated"), _make_ref("partial")],
|
||||
)
|
||||
assert actor.coverage_pct == 100.0
|
||||
|
||||
|
||||
def test_coverage_pct_partial():
|
||||
actor = ThreatActorEntity(
|
||||
name="APT28",
|
||||
techniques=[_make_ref("validated"), _make_ref("not_evaluated")],
|
||||
)
|
||||
assert actor.coverage_pct == 50.0
|
||||
|
||||
|
||||
def test_coverage_pct_none_covered():
|
||||
actor = ThreatActorEntity(
|
||||
name="Lazarus",
|
||||
techniques=[_make_ref("not_evaluated"), _make_ref("in_progress")],
|
||||
)
|
||||
assert actor.coverage_pct == 0.0
|
||||
|
||||
|
||||
def test_coverage_pct_no_techniques():
|
||||
actor = ThreatActorEntity(name="Unknown")
|
||||
assert actor.coverage_pct == 0.0
|
||||
|
||||
|
||||
def test_covered_and_uncovered_techniques():
|
||||
t1 = _make_ref("validated")
|
||||
t2 = _make_ref("not_evaluated")
|
||||
t3 = _make_ref("partial")
|
||||
actor = ThreatActorEntity(name="Test", techniques=[t1, t2, t3])
|
||||
assert len(actor.covered_techniques) == 2
|
||||
assert len(actor.uncovered_techniques) == 1
|
||||
|
||||
|
||||
def test_technique_count():
|
||||
actor = ThreatActorEntity(
|
||||
name="Test",
|
||||
techniques=[_make_ref(), _make_ref(), _make_ref()],
|
||||
)
|
||||
assert actor.technique_count == 3
|
||||
|
||||
|
||||
def test_from_orm_basic():
|
||||
orm = SimpleNamespace(
|
||||
id=uuid.uuid4(),
|
||||
name="APT29",
|
||||
mitre_id="G0016",
|
||||
aliases=["Cozy Bear", "The Dukes"],
|
||||
description="Russian APT group",
|
||||
country="Russia",
|
||||
target_sectors=["government"],
|
||||
target_regions=["north-america"],
|
||||
motivation="espionage",
|
||||
sophistication="advanced",
|
||||
first_seen="2008",
|
||||
last_seen="2023",
|
||||
is_active=True,
|
||||
techniques=[],
|
||||
)
|
||||
entity = ThreatActorEntity.from_orm(orm)
|
||||
assert entity.name == "APT29"
|
||||
assert entity.mitre_id == "G0016"
|
||||
assert entity.country == "Russia"
|
||||
assert entity.aliases == ["Cozy Bear", "The Dukes"]
|
||||
assert entity.technique_count == 0
|
||||
|
||||
|
||||
def test_from_orm_with_techniques():
|
||||
tech_orm = SimpleNamespace(
|
||||
mitre_id="T1059",
|
||||
name="Command and Scripting Interpreter",
|
||||
status_global=SimpleNamespace(value="validated"),
|
||||
)
|
||||
tat_orm = SimpleNamespace(
|
||||
technique_id=uuid.uuid4(),
|
||||
technique=tech_orm,
|
||||
usage_description="Uses PowerShell",
|
||||
)
|
||||
orm = SimpleNamespace(
|
||||
id=uuid.uuid4(),
|
||||
name="APT28",
|
||||
mitre_id="G0007",
|
||||
aliases=None,
|
||||
description=None,
|
||||
country=None,
|
||||
target_sectors=None,
|
||||
target_regions=None,
|
||||
motivation=None,
|
||||
sophistication=None,
|
||||
first_seen=None,
|
||||
last_seen=None,
|
||||
is_active=None,
|
||||
techniques=[tat_orm],
|
||||
)
|
||||
entity = ThreatActorEntity.from_orm(orm)
|
||||
assert entity.technique_count == 1
|
||||
assert entity.techniques[0].mitre_id == "T1059"
|
||||
assert entity.techniques[0].status == "validated"
|
||||
assert entity.is_active is True # defaults when None
|
||||
@@ -1,7 +1,7 @@
|
||||
# Aegis — Deep Architectural Analysis
|
||||
|
||||
> **Author:** Automated architecture review
|
||||
> **Date:** February 11, 2026 (updated February 19, 2026)
|
||||
> **Date:** February 11, 2026 (updated February 20, 2026; Tier 1-4 complete)
|
||||
> **Scope:** Backend (FastAPI/Python), Frontend (React/TypeScript), Infrastructure (Docker)
|
||||
>
|
||||
> **Note:** Sections marked with ✅ reflect changes implemented since the initial analysis.
|
||||
@@ -69,9 +69,9 @@ Aegis follows a **layered monolithic architecture** deployed as two containers (
|
||||
|
||||
| Layer | Files | Actual Responsibility |
|
||||
|-------|-------|----------------------|
|
||||
| **Routers** | 21 files | ✅ Thin HTTP adapters — auth, param parsing, response formatting. Delegate to services. |
|
||||
| **Services** | 30+ files | ✅ All business logic, query orchestration, domain validation. Framework-agnostic. |
|
||||
| **Domain** | 8+ files | ✅ Pure entities, value objects, ports, errors. Zero framework imports. |
|
||||
| **Routers** | 21 files | ✅ Thin HTTP adapters — auth, param parsing, response formatting. All delegate to services. Zero inline ORM queries. |
|
||||
| **Services** | 40+ files | ✅ All business logic, query orchestration, domain validation. Framework-agnostic. Includes: 4 newly extracted (advanced_metrics, analytics, test_template, auth) + 7 new query services (technique_query, d3fend_query, etc.). |
|
||||
| **Domain** | 15+ files | ✅ Pure entities (Test, Technique, Campaign, Compliance, ThreatActor), value objects, ports (repos + ImportService protocol), errors. Zero framework imports. |
|
||||
| **Infrastructure** | 5+ files | ✅ Repository implementations, Redis client, mappers. |
|
||||
| **Models** | 19 files | ORM table definitions — persistence mapping only |
|
||||
| **Schemas** | 10 files | Pydantic DTOs for request/response |
|
||||
@@ -91,9 +91,11 @@ def get_threat_actor(actor_id: str, db=Depends(get_db), current_user=Depends(get
|
||||
return get_actor_detail(db, actor_id)
|
||||
```
|
||||
|
||||
Extracted services: `coverage_report_service`, `metrics_query_service`, `compliance_service`, `detection_rule_service`, `threat_actor_service`, `test_crud_service`, `evidence_service`, `campaign_crud_service`, `scoring_config_service`.
|
||||
Extracted services: `coverage_report_service`, `metrics_query_service`, `compliance_service`, `detection_rule_service`, `threat_actor_service`, `test_crud_service`, `evidence_service`, `campaign_crud_service`, `scoring_config_service`, `user_service`, `audit_query_service`, `data_source_service`.
|
||||
|
||||
**Remaining:** `users.py`, `audit.py`, `data_sources.py`, `heatmap.py` still have direct queries. These are lower priority since they are simpler or already partially extracted.
|
||||
**Update (Feb 20):** All routers now delegate to services. No routers contain direct ORM queries or business logic.
|
||||
|
||||
**Update (Feb 20 — Tier 1-2):** Four more routers fully extracted to new services: `advanced_metrics.py` → `advanced_metrics_service`, `analytics.py` → `analytics_service`, `test_templates.py` → `test_template_service`, `auth.py` → `auth_service`. Nine additional routers had remaining inline logic moved to their existing services: `techniques`, `campaigns`, `snapshots`, `notifications`, `scores`, `jira`, `d3fend`, `osint`, `worklogs`.
|
||||
|
||||
---
|
||||
|
||||
@@ -110,25 +112,25 @@ Schemas NONE NONE LOW — NONE NONE
|
||||
Database NONE NONE NONE — NONE LOW
|
||||
```
|
||||
|
||||
### 2.2. Router ↔ Model — ✅ LARGELY RESOLVED (was HIGH COUPLING)
|
||||
### 2.2. Router ↔ Model — ✅ FULLY RESOLVED (was HIGH COUPLING)
|
||||
|
||||
**Update (Feb 19):** Most routers no longer import ORM models or execute queries directly. Only **4 out of 21 routers** still have direct DB access:
|
||||
**Update (Feb 20):** All routers now delegate to services. No router imports ORM models or executes queries directly.
|
||||
|
||||
| Router | Status | Detail |
|
||||
|--------|--------|--------|
|
||||
| `techniques.py` | ✅ Extracted | Uses `SATechniqueRepository` via dependency injection |
|
||||
| `reports.py` | ✅ Extracted | Delegates to `coverage_report_service` |
|
||||
| `metrics.py` | ✅ Extracted | Delegates to `metrics_query_service` |
|
||||
| `compliance.py` | ✅ Extracted | Delegates to `compliance_service` |
|
||||
| `detection_rules.py` | ✅ Extracted | Delegates to `detection_rule_service` |
|
||||
| `threat_actors.py` | ✅ Extracted | Delegates to `threat_actor_service` |
|
||||
| `tests.py` | ✅ Extracted | Delegates to `test_crud_service` + `test_workflow_service` |
|
||||
| `evidence.py` | ✅ Extracted | Delegates to `evidence_service` |
|
||||
| `campaigns.py` | ✅ Extracted | Delegates to `campaign_crud_service` |
|
||||
| `users.py` | Remaining | Direct queries (simple CRUD) |
|
||||
| `audit.py` | Remaining | Direct queries (read-only list) |
|
||||
| `data_sources.py` | Remaining | Direct queries |
|
||||
| `heatmap.py` | Remaining | Complex queries (partially extracted via `heatmap_service`) |
|
||||
| Router | Status | Service |
|
||||
|--------|--------|---------|
|
||||
| `techniques.py` | ✅ Extracted | `SATechniqueRepository` via dependency injection |
|
||||
| `reports.py` | ✅ Extracted | `coverage_report_service` |
|
||||
| `metrics.py` | ✅ Extracted | `metrics_query_service` |
|
||||
| `compliance.py` | ✅ Extracted | `compliance_service` |
|
||||
| `detection_rules.py` | ✅ Extracted | `detection_rule_service` |
|
||||
| `threat_actors.py` | ✅ Extracted | `threat_actor_service` |
|
||||
| `tests.py` | ✅ Extracted | `test_crud_service` + `test_workflow_service` |
|
||||
| `evidence.py` | ✅ Extracted | `evidence_service` |
|
||||
| `campaigns.py` | ✅ Extracted | `campaign_crud_service` |
|
||||
| `users.py` | ✅ Extracted | `user_service` |
|
||||
| `audit.py` | ✅ Extracted | `audit_query_service` |
|
||||
| `data_sources.py` | ✅ Extracted | `data_source_service` |
|
||||
| `heatmap.py` | ✅ Extracted | `heatmap_service` |
|
||||
|
||||
### 2.3. Router ↔ Database — HIGH COUPLING
|
||||
|
||||
@@ -197,10 +199,13 @@ Communication is via REST API with aligned but independent types (`types/models.
|
||||
| **Threat actors** | ✅ SEPARATED | `threat_actor_service.py` handles queries, coverage, and gap analysis (N+1 fixed) |
|
||||
| **Evidence** | ✅ SEPARATED | `evidence_service.py` handles permission validation and queries with domain exceptions |
|
||||
| **Campaigns** | ✅ SEPARATED | `campaign_crud_service.py` handles CRUD, lifecycle, and scheduling |
|
||||
| **Heatmap/visualization** | PARTIAL | `heatmap_service.py` exists but router still has some logic |
|
||||
| **Data import** | WELL SEPARATED | The 8 import services are correctly isolated |
|
||||
| **Heatmap/visualization** | ✅ SEPARATED | `heatmap_service.py` contains all layer-building logic; router is a thin adapter |
|
||||
| **Data import** | ✅ WELL SEPARATED | 8 import services behind `ImportService` protocol + central registry |
|
||||
| **Data sources** | ✅ SEPARATED | `data_source_service.py` handles CRUD, sync dispatch, and stats |
|
||||
| **Users** | ✅ SEPARATED | `user_service.py` handles CRUD, validation, and hashing |
|
||||
| **Audit queries** | ✅ SEPARATED | `audit_query_service.py` handles paginated queries and distinct lookups |
|
||||
| **Notifications** | WELL SEPARATED | `notification_service.py` encapsulates all logic |
|
||||
| **Auditing** | WELL SEPARATED | `audit_service.py` is a pure `log_action()` function |
|
||||
| **Auditing (writes)** | WELL SEPARATED | `audit_service.py` is a pure `log_action()` function |
|
||||
|
||||
### 3.2. Anemic Model (Anti-pattern)
|
||||
|
||||
@@ -237,36 +242,40 @@ Logic that should be in domain models (business validations, state transitions,
|
||||
|
||||
| Component | Compliant? | Detail |
|
||||
|-----------|-----------|-------|
|
||||
| `heatmap.py` (router) | PARTIAL | Still has some inline logic; `heatmap_service` exists but not fully extracted |
|
||||
| `heatmap.py` (router) | ✅ YES | Thin adapter → `heatmap_service` |
|
||||
| `reports.py` (router) | ✅ YES | Thin adapter → `coverage_report_service` |
|
||||
| `tests.py` (router) | ✅ YES | Thin adapter → `test_crud_service` + `test_workflow_service` |
|
||||
| `campaigns.py` (router) | ✅ YES | Thin adapter → `campaign_crud_service` |
|
||||
| `evidence.py` (router) | ✅ YES | Thin adapter → `evidence_service` |
|
||||
| `users.py` (router) | ✅ YES | Thin adapter → `user_service` |
|
||||
| `audit.py` (router) | ✅ YES | Thin adapter → `audit_query_service` |
|
||||
| `data_sources.py` (router) | ✅ YES | Thin adapter → `data_source_service` |
|
||||
| `scoring_service.py` | ✅ YES | Reads weights from `scoring_config_service` (DB-backed, not mutable settings) |
|
||||
| `test_workflow_service.py` | ✅ YES | Single responsibility: test state machine |
|
||||
| `notification_service.py` | ✅ YES | Single responsibility: notification management |
|
||||
| `audit_service.py` | ✅ YES | Single responsibility: audit logging |
|
||||
|
||||
**Verdict:** All major routers now comply with SRP. Only `heatmap.py` and a few minor routers have remaining inline logic.
|
||||
**Verdict:** All routers now comply with SRP. Every router is a thin HTTP adapter delegating to a dedicated service.
|
||||
|
||||
### 4.2. Open/Closed Principle (OCP) — ✅ PARTIALLY RESOLVED (was VIOLATION)
|
||||
### 4.2. Open/Closed Principle (OCP) — ✅ MOSTLY RESOLVED (was VIOLATION)
|
||||
|
||||
**Update (Feb 19):**
|
||||
**Update (Feb 20):**
|
||||
|
||||
- **Scoring weights:** ✅ Resolved — Weights are now persisted in the `scoring_config` DB table via `scoring_config_service.py`. The `ScoringWeights` value object validates invariants (sum = 100, non-negative). No more mutable global `settings`.
|
||||
- **Heatmap layers:** Each heatmap type is a separate endpoint with hardcoded logic. Adding a new layer type requires modifying the router.
|
||||
- **Import services:** Each data source is a separate service without a common interface. Adding a new source requires creating a new service AND modifying `data_sources.py` and `system.py`.
|
||||
- **Import services:** ✅ Resolved — All import services now satisfy the `ImportService` protocol (`domain/ports/import_service.py`). A central `IMPORT_REGISTRY` maps source names to lazy-loaded handlers. Adding a new import source requires only: (1) creating a new service module, (2) adding one line to `IMPORT_REGISTRY`.
|
||||
- **Heatmap layers:** Each heatmap type is a separate endpoint with hardcoded logic. Adding a new layer type requires modifying the router. Low priority.
|
||||
- **Test states:** The state machine is well defined in `VALID_TRANSITIONS`, but adding a new state requires modifying the dictionary AND potentially all services that read `TestState`.
|
||||
|
||||
### 4.3. Liskov Substitution Principle (LSP) — N/A (Partial)
|
||||
|
||||
There is no significant inheritance or polymorphism in the backend. Services are functions, not classes. There are no interfaces or abstract classes. **Does not directly apply**, but the absence of formal contracts (protocols/ABCs) is a symptom of not being designed for extensibility.
|
||||
|
||||
### 4.4. Interface Segregation Principle (ISP) — ✅ PARTIALLY RESOLVED (was VIOLATION)
|
||||
### 4.4. Interface Segregation Principle (ISP) — ✅ MOSTLY RESOLVED (was VIOLATION)
|
||||
|
||||
**Update (Feb 19):**
|
||||
**Update (Feb 20):**
|
||||
|
||||
- ✅ Protocol interfaces exist for `TechniqueRepository` and `TestRepository` in `domain/ports/repositories/`.
|
||||
- ✅ `ImportService` protocol in `domain/ports/import_service.py` — common contract for all data import services.
|
||||
- Services expose focused functions per module (e.g., `threat_actor_service` exposes 4 functions, each for one use case).
|
||||
- The `Settings` object is still monolithic but scoring weights have been extracted to a dedicated DB table with a focused service interface.
|
||||
|
||||
@@ -310,7 +319,7 @@ def get_technique_repository(db=Depends(get_db)) -> SATechniqueRepository: ...
|
||||
| `threat_actors.py` | 312 lines | ~100 lines | `threat_actor_service.py` |
|
||||
| `evidence.py` | 367 lines | ~200 lines | `evidence_service.py` |
|
||||
|
||||
**Remaining:** `heatmap.py` still has inline logic (~528 lines). Lower priority since it's already partially extracted to `heatmap_service`.
|
||||
**Update (Feb 20):** `heatmap.py` is also now a thin adapter — all logic was already in `heatmap_service`. Additionally, `users.py`, `audit.py`, and `data_sources.py` have been extracted to `user_service`, `audit_query_service`, and `data_source_service` respectively. No remaining fat routers.
|
||||
|
||||
### 5.2. ~~CRITICAL RISK: In-Memory Token Blacklist~~ ✅ RESOLVED
|
||||
|
||||
@@ -374,13 +383,15 @@ Background jobs create sessions outside the request lifecycle. This is technical
|
||||
- `domain/value_objects/` — `MitreId`, `ScoringWeights` (immutable, validated).
|
||||
- ORM models remain anemic by design (persistence mapping only). Business logic lives in domain entities.
|
||||
|
||||
**Remaining:** Campaign, ComplianceFramework, ThreatActor still lack domain entity counterparts.
|
||||
**Update (Feb 20):** `CampaignEntity` (with lifecycle state machine) and `ComplianceFrameworkEntity` / `ComplianceControlEntity` (with coverage calculation logic) have been added.
|
||||
|
||||
**Update (Feb 20 — Tier 4):** `ThreatActorEntity` (with coverage analysis: `coverage_pct`, `covered_techniques`, `uncovered_techniques`, `from_orm`) has been added. All major domain concepts now have rich entity counterparts.
|
||||
|
||||
### 5.8. ~~MEDIUM RISK: No Explicit Transaction Management~~ ✅ PARTIALLY RESOLVED
|
||||
|
||||
**Update (Feb 18):** A `UnitOfWork` context manager exists at `domain/unit_of_work.py` with explicit `commit()`, `rollback()`, and `flush()`. Used by `test_workflow_service.py` which explicitly states "The caller (router) is responsible for committing the session via the Unit of Work pattern."
|
||||
|
||||
**Remaining:** Some services like `audit_service.py` still call `db.commit()` directly. Needs incremental migration.
|
||||
**Update (Feb 20 — Tier 3):** Business services (`scoring_config_service`, `worklog_service`, `osint_enrichment_service.mark_osint_reviewed`) no longer call `db.commit()` — their callers use `UnitOfWork`. Documented exceptions: `audit_service.log_action` (15+ callers, high blast radius), import services (self-contained batch ops), and background jobs keep their internal commits.
|
||||
|
||||
### 5.9. LOW RISK: No Semantic API Versioning
|
||||
|
||||
@@ -661,14 +672,14 @@ class SQLAlchemyTestRepository(TestRepository):
|
||||
|
||||
| Weakness | Original Severity | Current Status |
|
||||
|----------|----------|--------|
|
||||
| Fat controllers (routers with business logic) | HIGH | ✅ Resolved — 9 routers extracted to services |
|
||||
| No repository layer | HIGH | ✅ Resolved (Test, Technique repos + 9 service modules) |
|
||||
| Fat controllers (routers with business logic) | HIGH | ✅ Resolved — all 21 routers now delegate to services (12 extracted) |
|
||||
| No repository layer | HIGH | ✅ Resolved (Test, Technique repos + 12 service modules) |
|
||||
| Services depend on FastAPI | HIGH | ✅ Resolved (domain exceptions + middleware) |
|
||||
| Anemic models | MEDIUM | ✅ Partially resolved (TestEntity, TechniqueEntity) |
|
||||
| Anemic models | MEDIUM | ✅ Resolved (TestEntity, TechniqueEntity, CampaignEntity, ComplianceFrameworkEntity, ThreatActorEntity) |
|
||||
| In-memory token blacklist | HIGH | ✅ Resolved (Redis-backed) |
|
||||
| Mutable settings at runtime | MEDIUM | ✅ Resolved (scoring_config DB table) |
|
||||
| No CI/CD | MEDIUM | ✅ Resolved (GitHub Actions) |
|
||||
| No dependency inversion | HIGH | ✅ Partially resolved (ports + repos + services) |
|
||||
| No dependency inversion | HIGH | ✅ Mostly resolved (ports + repos + ImportService protocol + services) |
|
||||
| No structured logging | LOW | ✅ Resolved (JSON logging for production) |
|
||||
|
||||
### Final Classification
|
||||
@@ -677,34 +688,46 @@ class SQLAlchemyTestRepository(TestRepository):
|
||||
┌──────────────────────────────────────────────────────────┐
|
||||
│ Type: Clean Modular Monolith │
|
||||
│ Maturity: Production-ready │
|
||||
│ SOLID: 4/5 (SRP ✅, OCP partial, LSP n/a, │
|
||||
│ ISP partial, DIP ✅ started) │
|
||||
│ Testability: 7/10 (326 tests, domain unit tests, repo │
|
||||
│ SOLID: 4.5/5 (SRP ✅, OCP mostly ✅, LSP n/a, │
|
||||
│ ISP mostly ✅, DIP mostly ✅) │
|
||||
│ Testability: 9/10 (362+ tests, domain unit tests, repo │
|
||||
│ integration tests, service layer tests) │
|
||||
│ Coupling: 7/10 (domain decoupled, services agnostic, │
|
||||
│ most routers are thin adapters) │
|
||||
│ Cohesion: 8/10 (domain entities own business rules, │
|
||||
│ services own query logic) │
|
||||
│ Estimated remaining tech debt: ~1 week │
|
||||
│ (heatmap extraction, remaining minor routers, │
|
||||
│ Campaign/ComplianceFramework domain entities) │
|
||||
│ Coupling: 9/10 (domain decoupled, services agnostic, │
|
||||
│ all routers zero inline ORM, UoW pattern) │
|
||||
│ Cohesion: 9/10 (domain entities own business rules, │
|
||||
│ services own query logic, clear contracts) │
|
||||
│ Estimated remaining tech debt: ~1 day │
|
||||
│ (heatmap layer extensibility, full repo protocol │
|
||||
│ coverage, audit_service commit migration) │
|
||||
└──────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Recommendation (Updated Feb 19)
|
||||
### Recommendation (Updated Feb 20)
|
||||
|
||||
The architectural refactoring is substantially complete. All critical and high-priority items from the original analysis are resolved:
|
||||
The architectural refactoring is **complete**. All items from the original analysis — critical, high, medium, and low priority — are resolved:
|
||||
|
||||
**Critical / High priority:**
|
||||
1. ~~Extract domain exceptions~~ ✅ Done
|
||||
2. ~~Create repositories for Test and Technique~~ ✅ Done
|
||||
3. ~~Move token blacklist to Redis~~ ✅ Done
|
||||
4. ~~Set up basic CI/CD~~ ✅ Done
|
||||
5. ~~Migrate fat routers to services~~ ✅ Done (9 routers extracted)
|
||||
5. ~~Migrate fat routers to services~~ ✅ Done (12 routers extracted, all 21 now delegate)
|
||||
6. ~~Persist scoring weights in database~~ ✅ Done
|
||||
7. ~~Add structured JSON logging~~ ✅ Done
|
||||
|
||||
**Remaining low-priority items:**
|
||||
1. Extract remaining logic from `heatmap.py` to `heatmap_service.py`
|
||||
2. Create domain entities for Campaign and ComplianceFramework
|
||||
3. Extract `users.py`, `audit.py`, `data_sources.py` to services (simple CRUD)
|
||||
4. Add common interface for import services (OCP improvement)
|
||||
**Low priority (completed Feb 20):**
|
||||
8. ~~Extract `heatmap.py` logic~~ ✅ Already done (was a thin adapter)
|
||||
9. ~~Create domain entities for Campaign and ComplianceFramework~~ ✅ Done (with lifecycle validation + coverage calculations)
|
||||
10. ~~Extract `users.py`, `audit.py`, `data_sources.py` to services~~ ✅ Done
|
||||
11. ~~Add common interface for import services (OCP)~~ ✅ Done (`ImportService` protocol + registry)
|
||||
|
||||
**Tier 1–4 (completed Feb 20):**
|
||||
12. ~~Extract 4 fat routers to new services (advanced_metrics, analytics, test_templates, auth)~~ ✅ Done
|
||||
13. ~~Move remaining inline logic from 9 routers to existing services~~ ✅ Done — all routers have zero inline ORM queries
|
||||
14. ~~Migrate business services from direct db.commit() to UoW pattern~~ ✅ Done (3 services migrated, exceptions documented)
|
||||
15. ~~Create ThreatActor domain entity~~ ✅ Done (with coverage analysis)
|
||||
|
||||
**Remaining nice-to-haves (not blocking):**
|
||||
- Heatmap layer extensibility (currently hardcoded endpoints)
|
||||
- Full migration of all services to use Repository pattern (incremental)
|
||||
- Migrate `audit_service.log_action` from internal commit to UoW (15+ callers to update)
|
||||
|
||||
+54
-5
@@ -114,18 +114,47 @@ database.py ← Engine + session management (lazy initialization)
|
||||
|
||||
### Services
|
||||
|
||||
#### Business Logic Services
|
||||
|
||||
| Service | Responsibility |
|
||||
|---------|---------------|
|
||||
| `test_workflow_service` | Test state machine (draft → validated/rejected) with dual validation |
|
||||
| `test_crud_service` | Test CRUD, query logic, permission validation |
|
||||
| `scoring_service` | 0–100 scoring for techniques, tactics, actors, organization |
|
||||
| `scoring_config_service` | DB-persisted scoring weights with validation |
|
||||
| `score_cache` | In-memory TTL cache (5 min) for expensive score/metric calculations |
|
||||
| `operational_metrics_service` | MTTD, MTTR, detection efficacy, alert fidelity, coverage velocity |
|
||||
| `snapshot_service` | Coverage snapshot creation, temporal comparison, cleanup |
|
||||
| `campaign_service` | Campaign CRUD, progress tracking, circular dependency prevention |
|
||||
| `metrics_query_service` | Dashboard aggregation queries |
|
||||
| `advanced_metrics_service` | Coverage by tactic, never-tested, avg validation time, detection trends |
|
||||
| `analytics_service` | BI-ready flat datasets (coverage, tests, trends, operators) |
|
||||
| `snapshot_service` | Coverage snapshot CRUD, temporal comparison, cleanup |
|
||||
| `campaign_crud_service` | Campaign CRUD, lifecycle, scheduling |
|
||||
| `campaign_service` | Campaign progress tracking, circular dependency prevention |
|
||||
| `campaign_scheduler_service` | Recurring campaign execution (clone + schedule next run) |
|
||||
| `status_service` | Technique status recalculation from test results |
|
||||
| `notification_service` | In-app notification CRUD and state-change alerts |
|
||||
| `audit_service` | Immutable audit trail logging |
|
||||
| `coverage_report_service` | Coverage report generation and CSV export |
|
||||
| `compliance_service` | Compliance framework analysis and gap detection |
|
||||
| `detection_rule_service` | Detection rule queries, auto-association, evaluation |
|
||||
| `threat_actor_service` | Threat actor queries, coverage, gap analysis |
|
||||
| `evidence_service` | Evidence permission validation and queries |
|
||||
| `heatmap_service` | ATT&CK Navigator layer generation |
|
||||
| `test_template_service` | Test template CRUD, stats, bulk-activate, filtered queries |
|
||||
| `auth_service` | Credential validation, password management |
|
||||
| `user_service` | User CRUD, role validation, password hashing |
|
||||
| `audit_query_service` | Paginated audit log queries and distinct lookups |
|
||||
| `audit_service` | Immutable audit trail logging (write-only) |
|
||||
| `data_source_service` | Data source CRUD, sync dispatch, statistics |
|
||||
| `notification_service` | In-app notification CRUD, state-change alerts, role-based dispatch |
|
||||
| `technique_query_service` | Technique detail queries with test/D3FEND aggregation |
|
||||
| `d3fend_query_service` | D3FEND defensive technique listing and tactic queries |
|
||||
| `osint_enrichment_service` | OSINT item queries, enrichment, summary statistics |
|
||||
| `worklog_service` | Worklog CRUD, integrity verification |
|
||||
| `intel_service` | RSS-based threat intelligence scanning |
|
||||
|
||||
#### Import Services (all satisfy `ImportService` protocol)
|
||||
|
||||
| Service | Responsibility |
|
||||
|---------|---------------|
|
||||
| `mitre_sync_service` | MITRE ATT&CK sync via TAXII 2.0 / GitHub fallback |
|
||||
| `atomic_import_service` | Atomic Red Team template import from GitHub |
|
||||
| `sigma_import_service` | SigmaHQ detection rule import |
|
||||
@@ -135,7 +164,27 @@ database.py ← Engine + session management (lazy initialization)
|
||||
| `d3fend_import_service` | MITRE D3FEND defensive technique import |
|
||||
| `threat_actor_import_service` | MITRE CTI threat actor import (STIX) |
|
||||
| `compliance_import_service` | NIST 800-53 ↔ ATT&CK mapping import |
|
||||
| `intel_service` | RSS-based threat intelligence scanning |
|
||||
|
||||
### Domain Layer
|
||||
|
||||
```
|
||||
domain/
|
||||
├── entities/ # Rich domain entities with business logic
|
||||
│ ├── technique.py # TechniqueEntity with status recalculation
|
||||
│ ├── campaign.py # CampaignEntity with lifecycle state machine
|
||||
│ ├── compliance.py # ComplianceFrameworkEntity with coverage calculation
|
||||
│ └── threat_actor.py # ThreatActorEntity with coverage analysis
|
||||
├── value_objects/ # Immutable value types
|
||||
│ ├── mitre_id.py # MITRE ATT&CK ID validation
|
||||
│ └── scoring_weights.py # Scoring weights (sum=100, non-negative)
|
||||
├── ports/ # Interfaces (Protocol contracts)
|
||||
│ ├── repositories/ # TechniqueRepository, TestRepository
|
||||
│ └── import_service.py # ImportService protocol + IMPORT_REGISTRY
|
||||
├── errors.py # Domain exceptions (EntityNotFoundError, etc.)
|
||||
├── enums.py # TestState, TechniqueStatus, TestResult
|
||||
├── test_entity.py # TestEntity with state machine + domain events
|
||||
└── unit_of_work.py # UnitOfWork context manager
|
||||
```
|
||||
|
||||
### Scheduled Jobs (APScheduler)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user