Compare commits
8 Commits
44621364be
...
0c526c48f9
| Author | SHA1 | Date | |
|---|---|---|---|
| 0c526c48f9 | |||
| 0d211d5156 | |||
| 14d995b40c | |||
| 339d669498 | |||
| 9e22fde746 | |||
| bbc2dddd86 | |||
| d77075272e | |||
| c0c6cda11d |
@@ -1,3 +1,18 @@
|
|||||||
|
from app.domain.entities.campaign import CampaignEntity
|
||||||
|
from app.domain.entities.compliance import (
|
||||||
|
ComplianceControlEntity,
|
||||||
|
ComplianceFrameworkEntity,
|
||||||
|
ControlCoverageStatus,
|
||||||
|
)
|
||||||
from app.domain.entities.technique import TechniqueEntity
|
from app.domain.entities.technique import TechniqueEntity
|
||||||
|
from app.domain.entities.threat_actor import ThreatActorEntity, ThreatActorTechniqueRef
|
||||||
|
|
||||||
__all__ = ["TechniqueEntity"]
|
__all__ = [
|
||||||
|
"CampaignEntity",
|
||||||
|
"ComplianceControlEntity",
|
||||||
|
"ComplianceFrameworkEntity",
|
||||||
|
"ControlCoverageStatus",
|
||||||
|
"TechniqueEntity",
|
||||||
|
"ThreatActorEntity",
|
||||||
|
"ThreatActorTechniqueRef",
|
||||||
|
]
|
||||||
|
|||||||
@@ -0,0 +1,103 @@
|
|||||||
|
"""Campaign domain entity with lifecycle validation.
|
||||||
|
|
||||||
|
Pure domain logic — no framework imports.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
|
||||||
|
|
||||||
|
|
||||||
|
class CampaignStatus(str, enum.Enum):
|
||||||
|
draft = "draft"
|
||||||
|
active = "active"
|
||||||
|
completed = "completed"
|
||||||
|
archived = "archived"
|
||||||
|
|
||||||
|
|
||||||
|
class CampaignType(str, enum.Enum):
|
||||||
|
custom = "custom"
|
||||||
|
apt_emulation = "apt_emulation"
|
||||||
|
kill_chain = "kill_chain"
|
||||||
|
compliance = "compliance"
|
||||||
|
|
||||||
|
|
||||||
|
VALID_TRANSITIONS: dict[CampaignStatus, list[CampaignStatus]] = {
|
||||||
|
CampaignStatus.draft: [CampaignStatus.active],
|
||||||
|
CampaignStatus.active: [CampaignStatus.completed],
|
||||||
|
CampaignStatus.completed: [CampaignStatus.archived],
|
||||||
|
CampaignStatus.archived: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CampaignEntity:
|
||||||
|
name: str
|
||||||
|
type: CampaignType = CampaignType.custom
|
||||||
|
status: CampaignStatus = CampaignStatus.draft
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
description: str | None = None
|
||||||
|
threat_actor_id: uuid.UUID | None = None
|
||||||
|
created_by: uuid.UUID | None = None
|
||||||
|
target_platform: str | None = None
|
||||||
|
tags: list[str] = field(default_factory=list)
|
||||||
|
test_count: int = 0
|
||||||
|
|
||||||
|
def can_transition_to(self, target: CampaignStatus) -> bool:
|
||||||
|
return target in VALID_TRANSITIONS.get(self.status, [])
|
||||||
|
|
||||||
|
def activate(self) -> None:
|
||||||
|
if not self.can_transition_to(CampaignStatus.active):
|
||||||
|
raise InvalidStateTransition(
|
||||||
|
self.status.value, CampaignStatus.active.value,
|
||||||
|
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||||
|
)
|
||||||
|
if self.test_count == 0:
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
"Campaign must have at least one test to activate"
|
||||||
|
)
|
||||||
|
self.status = CampaignStatus.active
|
||||||
|
|
||||||
|
def complete(self) -> None:
|
||||||
|
if not self.can_transition_to(CampaignStatus.completed):
|
||||||
|
raise InvalidStateTransition(
|
||||||
|
self.status.value, CampaignStatus.completed.value,
|
||||||
|
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||||
|
)
|
||||||
|
self.status = CampaignStatus.completed
|
||||||
|
|
||||||
|
def archive(self) -> None:
|
||||||
|
if not self.can_transition_to(CampaignStatus.archived):
|
||||||
|
raise InvalidStateTransition(
|
||||||
|
self.status.value, CampaignStatus.archived.value,
|
||||||
|
[s.value for s in VALID_TRANSITIONS[self.status]],
|
||||||
|
)
|
||||||
|
self.status = CampaignStatus.archived
|
||||||
|
|
||||||
|
def ensure_modifiable(self) -> None:
|
||||||
|
if self.status not in (CampaignStatus.draft, CampaignStatus.active):
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Cannot modify campaign in '{self.status.value}' state"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_orm(cls, orm: Any) -> CampaignEntity:
|
||||||
|
"""Build a CampaignEntity from a SQLAlchemy Campaign model."""
|
||||||
|
test_count = len(getattr(orm, "campaign_tests", None) or [])
|
||||||
|
return cls(
|
||||||
|
id=orm.id,
|
||||||
|
name=orm.name,
|
||||||
|
type=CampaignType(orm.type) if orm.type else CampaignType.custom,
|
||||||
|
status=CampaignStatus(orm.status) if orm.status else CampaignStatus.draft,
|
||||||
|
description=orm.description,
|
||||||
|
threat_actor_id=orm.threat_actor_id,
|
||||||
|
created_by=orm.created_by,
|
||||||
|
target_platform=orm.target_platform,
|
||||||
|
tags=orm.tags or [],
|
||||||
|
test_count=test_count,
|
||||||
|
)
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
"""Compliance domain entities with coverage calculation logic.
|
||||||
|
|
||||||
|
Pure domain logic — no framework imports.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
class ControlCoverageStatus(str, enum.Enum):
|
||||||
|
covered = "covered"
|
||||||
|
partially_covered = "partially_covered"
|
||||||
|
not_covered = "not_covered"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ComplianceControlEntity:
|
||||||
|
control_id: str
|
||||||
|
title: str
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
description: str | None = None
|
||||||
|
category: str | None = None
|
||||||
|
technique_statuses: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def coverage_status(self) -> ControlCoverageStatus:
|
||||||
|
if not self.technique_statuses:
|
||||||
|
return ControlCoverageStatus.not_covered
|
||||||
|
covered_statuses = {"validated", "partial"}
|
||||||
|
covered = [s for s in self.technique_statuses if s in covered_statuses]
|
||||||
|
if len(covered) == len(self.technique_statuses):
|
||||||
|
return ControlCoverageStatus.covered
|
||||||
|
elif len(covered) > 0:
|
||||||
|
return ControlCoverageStatus.partially_covered
|
||||||
|
return ControlCoverageStatus.not_covered
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ComplianceFrameworkEntity:
|
||||||
|
name: str
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
version: str | None = None
|
||||||
|
description: str | None = None
|
||||||
|
is_active: bool = True
|
||||||
|
controls: list[ComplianceControlEntity] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_controls(self) -> int:
|
||||||
|
return len(self.controls)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def covered_controls(self) -> int:
|
||||||
|
return sum(
|
||||||
|
1 for c in self.controls
|
||||||
|
if c.coverage_status == ControlCoverageStatus.covered
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def coverage_pct(self) -> float:
|
||||||
|
if self.total_controls == 0:
|
||||||
|
return 0.0
|
||||||
|
return round(self.covered_controls / self.total_controls * 100, 1)
|
||||||
|
|
||||||
|
def get_gap_controls(self) -> list[ComplianceControlEntity]:
|
||||||
|
return [
|
||||||
|
c for c in self.controls
|
||||||
|
if c.coverage_status != ControlCoverageStatus.covered
|
||||||
|
]
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
"""Threat actor domain entity with coverage analysis logic.
|
||||||
|
|
||||||
|
Pure domain logic — no framework imports.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ThreatActorTechniqueRef:
|
||||||
|
"""Lightweight reference to a technique used by an actor."""
|
||||||
|
|
||||||
|
technique_id: uuid.UUID
|
||||||
|
mitre_id: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
status: str | None = None
|
||||||
|
usage_description: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ThreatActorEntity:
|
||||||
|
name: str
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
mitre_id: str | None = None
|
||||||
|
aliases: list[str] = field(default_factory=list)
|
||||||
|
description: str | None = None
|
||||||
|
country: str | None = None
|
||||||
|
target_sectors: list[str] = field(default_factory=list)
|
||||||
|
target_regions: list[str] = field(default_factory=list)
|
||||||
|
motivation: str | None = None
|
||||||
|
sophistication: str | None = None
|
||||||
|
first_seen: str | None = None
|
||||||
|
last_seen: str | None = None
|
||||||
|
is_active: bool = True
|
||||||
|
techniques: list[ThreatActorTechniqueRef] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def technique_count(self) -> int:
|
||||||
|
return len(self.techniques)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def covered_techniques(self) -> list[ThreatActorTechniqueRef]:
|
||||||
|
return [
|
||||||
|
t for t in self.techniques
|
||||||
|
if t.status in ("validated", "partial")
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def uncovered_techniques(self) -> list[ThreatActorTechniqueRef]:
|
||||||
|
return [
|
||||||
|
t for t in self.techniques
|
||||||
|
if t.status not in ("validated", "partial")
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def coverage_pct(self) -> float:
|
||||||
|
if not self.techniques:
|
||||||
|
return 0.0
|
||||||
|
return round(len(self.covered_techniques) / len(self.techniques) * 100, 1)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_orm(cls, orm: Any) -> ThreatActorEntity:
|
||||||
|
techs: list[ThreatActorTechniqueRef] = []
|
||||||
|
for tat in getattr(orm, "techniques", None) or []:
|
||||||
|
technique = getattr(tat, "technique", None)
|
||||||
|
techs.append(ThreatActorTechniqueRef(
|
||||||
|
technique_id=tat.technique_id,
|
||||||
|
mitre_id=getattr(technique, "mitre_id", None) if technique else None,
|
||||||
|
name=getattr(technique, "name", None) if technique else None,
|
||||||
|
status=(
|
||||||
|
technique.status_global.value
|
||||||
|
if technique and hasattr(technique.status_global, "value")
|
||||||
|
else getattr(technique, "status_global", None) if technique else None
|
||||||
|
),
|
||||||
|
usage_description=tat.usage_description,
|
||||||
|
))
|
||||||
|
return cls(
|
||||||
|
id=orm.id,
|
||||||
|
name=orm.name,
|
||||||
|
mitre_id=orm.mitre_id,
|
||||||
|
aliases=orm.aliases or [],
|
||||||
|
description=orm.description,
|
||||||
|
country=orm.country,
|
||||||
|
target_sectors=orm.target_sectors or [],
|
||||||
|
target_regions=orm.target_regions or [],
|
||||||
|
motivation=orm.motivation,
|
||||||
|
sophistication=orm.sophistication,
|
||||||
|
first_seen=orm.first_seen,
|
||||||
|
last_seen=orm.last_seen,
|
||||||
|
is_active=orm.is_active if orm.is_active is not None else True,
|
||||||
|
techniques=techs,
|
||||||
|
)
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
"""Port defining the common interface for data import services.
|
||||||
|
|
||||||
|
All import services (Atomic Red Team, Sigma, CALDERA, etc.) follow the
|
||||||
|
same contract: they receive a database session and return a summary dict
|
||||||
|
with import statistics.
|
||||||
|
|
||||||
|
New import sources can be added by:
|
||||||
|
1. Implementing the ``ImportService`` protocol in a new module
|
||||||
|
2. Registering the handler in the ``IMPORT_REGISTRY``
|
||||||
|
|
||||||
|
This satisfies the Open/Closed Principle — the system is open for new
|
||||||
|
import sources without modifying existing code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class ImportService(Protocol):
|
||||||
|
"""Contract for any data-import operation.
|
||||||
|
|
||||||
|
Each implementation is a callable ``(Session) -> dict`` that
|
||||||
|
downloads, parses, and upserts records from an external source.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, db: Session) -> dict[str, Any]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class ImportServiceEntry:
|
||||||
|
"""Lazy-loading wrapper that resolves a module-level function on first call."""
|
||||||
|
|
||||||
|
__slots__ = ("_module_path", "_func_name", "_resolved")
|
||||||
|
|
||||||
|
def __init__(self, module_path: str, func_name: str) -> None:
|
||||||
|
self._module_path = module_path
|
||||||
|
self._func_name = func_name
|
||||||
|
self._resolved: ImportService | None = None
|
||||||
|
|
||||||
|
def __call__(self, db: Session) -> dict[str, Any]:
|
||||||
|
if self._resolved is None:
|
||||||
|
import importlib
|
||||||
|
mod = importlib.import_module(self._module_path)
|
||||||
|
self._resolved = getattr(mod, self._func_name)
|
||||||
|
return self._resolved(db)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def source_info(self) -> str:
|
||||||
|
return f"{self._module_path}.{self._func_name}"
|
||||||
|
|
||||||
|
|
||||||
|
IMPORT_REGISTRY: dict[str, ImportServiceEntry] = {
|
||||||
|
"atomic_red_team": ImportServiceEntry(
|
||||||
|
"app.services.atomic_import_service", "import_atomic_red_team",
|
||||||
|
),
|
||||||
|
"sigma": ImportServiceEntry(
|
||||||
|
"app.services.sigma_import_service", "sync",
|
||||||
|
),
|
||||||
|
"lolbas": ImportServiceEntry(
|
||||||
|
"app.services.lolbas_import_service", "sync",
|
||||||
|
),
|
||||||
|
"gtfobins": ImportServiceEntry(
|
||||||
|
"app.services.lolbas_import_service", "sync_gtfobins",
|
||||||
|
),
|
||||||
|
"caldera": ImportServiceEntry(
|
||||||
|
"app.services.caldera_import_service", "sync",
|
||||||
|
),
|
||||||
|
"elastic_rules": ImportServiceEntry(
|
||||||
|
"app.services.elastic_import_service", "sync",
|
||||||
|
),
|
||||||
|
"mitre_cti": ImportServiceEntry(
|
||||||
|
"app.services.threat_actor_import_service", "sync",
|
||||||
|
),
|
||||||
|
"d3fend": ImportServiceEntry(
|
||||||
|
"app.services.d3fend_import_service", "sync",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_import_handler(source_name: str) -> ImportServiceEntry | None:
|
||||||
|
"""Look up the import handler for *source_name*.
|
||||||
|
|
||||||
|
Returns ``None`` when no handler is registered.
|
||||||
|
"""
|
||||||
|
return IMPORT_REGISTRY.get(source_name)
|
||||||
@@ -10,6 +10,16 @@ Usage in routers::
|
|||||||
If an exception propagates, ``__exit__`` issues a rollback automatically.
|
If an exception propagates, ``__exit__`` issues a rollback automatically.
|
||||||
Services should **never** call ``db.commit()``; they use ``db.add()`` /
|
Services should **never** call ``db.commit()``; they use ``db.add()`` /
|
||||||
``db.flush()`` to stage work and let the caller decide when to commit.
|
``db.flush()`` to stage work and let the caller decide when to commit.
|
||||||
|
|
||||||
|
**Documented exceptions** (services that may commit internally):
|
||||||
|
- ``audit_service.log_action`` — called from 15+ routers; commits to ensure
|
||||||
|
audit records persist even when callers do not.
|
||||||
|
- Import services (atomic_import, sigma_import, etc.) — self-contained sync ops.
|
||||||
|
- Background jobs (campaign_scheduler, intel_service, stale_detection,
|
||||||
|
mitre_sync) — self-contained operations.
|
||||||
|
- Self-contained batch ops (e.g. detection_rule_service.auto_associate_rules,
|
||||||
|
snapshot_service.create_snapshot, campaign_service.generate_campaign_from_*,
|
||||||
|
osint_enrichment_service.enrich_technique_with_cves).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|||||||
@@ -1,17 +1,12 @@
|
|||||||
"""Advanced metrics endpoints — coverage by tactic, never-tested, avg validation time."""
|
"""Advanced metrics endpoints — coverage by tactic, never-tested, avg validation time."""
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from sqlalchemy import func, case
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
from app.models.audit import AuditLog
|
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.test import Test
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.services import advanced_metrics_service
|
||||||
|
|
||||||
router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
|
router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
|
||||||
|
|
||||||
@@ -22,39 +17,7 @@ def coverage_by_tactic(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
|
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
|
||||||
results = (
|
return advanced_metrics_service.get_coverage_by_tactic(db)
|
||||||
db.query(
|
|
||||||
Technique.tactic,
|
|
||||||
func.count(Technique.id).label("total"),
|
|
||||||
func.sum(
|
|
||||||
case((Technique.status_global == "validated", 1), else_=0)
|
|
||||||
).label("validated"),
|
|
||||||
func.sum(
|
|
||||||
case((Technique.status_global == "partial", 1), else_=0)
|
|
||||||
).label("partial"),
|
|
||||||
func.sum(
|
|
||||||
case((Technique.status_global == "not_covered", 1), else_=0)
|
|
||||||
).label("not_covered"),
|
|
||||||
func.sum(
|
|
||||||
case((Technique.status_global == "in_progress", 1), else_=0)
|
|
||||||
).label("in_progress"),
|
|
||||||
)
|
|
||||||
.group_by(Technique.tactic)
|
|
||||||
.order_by(Technique.tactic)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"tactic": r[0] or "Unknown",
|
|
||||||
"total": r[1],
|
|
||||||
"validated": int(r[2]),
|
|
||||||
"partial": int(r[3]),
|
|
||||||
"not_covered": int(r[4]),
|
|
||||||
"in_progress": int(r[5]),
|
|
||||||
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
|
|
||||||
}
|
|
||||||
for r in results
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/never-tested")
|
@router.get("/never-tested")
|
||||||
@@ -63,24 +26,7 @@ def never_tested_techniques(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Techniques that have never had a test created."""
|
"""Techniques that have never had a test created."""
|
||||||
tested_technique_ids = (
|
return advanced_metrics_service.get_never_tested_techniques(db)
|
||||||
db.query(Test.technique_id).distinct().subquery()
|
|
||||||
)
|
|
||||||
techniques = (
|
|
||||||
db.query(Technique)
|
|
||||||
.filter(~Technique.id.in_(db.query(tested_technique_ids)))
|
|
||||||
.order_by(Technique.mitre_id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"mitre_id": t.mitre_id,
|
|
||||||
"name": t.name,
|
|
||||||
"tactic": t.tactic,
|
|
||||||
"is_subtechnique": t.is_subtechnique,
|
|
||||||
}
|
|
||||||
for t in techniques
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/avg-validation-time")
|
@router.get("/avg-validation-time")
|
||||||
@@ -92,50 +38,7 @@ def avg_validation_time(
|
|||||||
|
|
||||||
Returns overall average and per-phase averages where data is available.
|
Returns overall average and per-phase averages where data is available.
|
||||||
"""
|
"""
|
||||||
validated_tests = (
|
return advanced_metrics_service.get_avg_validation_time(db)
|
||||||
db.query(Test)
|
|
||||||
.filter(Test.state == "validated")
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not validated_tests:
|
|
||||||
return {
|
|
||||||
"total_validated": 0,
|
|
||||||
"avg_total_hours": 0,
|
|
||||||
"avg_red_phase_hours": 0,
|
|
||||||
"avg_blue_phase_hours": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
total_durations = []
|
|
||||||
red_durations = []
|
|
||||||
blue_durations = []
|
|
||||||
|
|
||||||
for test in validated_tests:
|
|
||||||
if test.created_at and test.red_validated_at:
|
|
||||||
total_seconds = (test.red_validated_at - test.created_at).total_seconds()
|
|
||||||
total_durations.append(total_seconds)
|
|
||||||
|
|
||||||
if test.red_started_at and test.blue_started_at:
|
|
||||||
red_sec = (test.blue_started_at - test.red_started_at).total_seconds()
|
|
||||||
red_paused = test.red_paused_seconds or 0
|
|
||||||
red_durations.append(max(red_sec - red_paused, 0))
|
|
||||||
|
|
||||||
if test.blue_started_at and test.blue_validated_at:
|
|
||||||
blue_sec = (test.blue_validated_at - test.blue_started_at).total_seconds()
|
|
||||||
blue_paused = test.blue_paused_seconds or 0
|
|
||||||
blue_durations.append(max(blue_sec - blue_paused, 0))
|
|
||||||
|
|
||||||
def avg_hours(durations: list[float]) -> float:
|
|
||||||
if not durations:
|
|
||||||
return 0
|
|
||||||
return round(sum(durations) / len(durations) / 3600, 2)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_validated": len(validated_tests),
|
|
||||||
"avg_total_hours": avg_hours(total_durations),
|
|
||||||
"avg_red_phase_hours": avg_hours(red_durations),
|
|
||||||
"avg_blue_phase_hours": avg_hours(blue_durations),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/detection-rate-trend")
|
@router.get("/detection-rate-trend")
|
||||||
@@ -144,41 +47,4 @@ def detection_rate_trend(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Monthly detection rate trend for the last 12 months."""
|
"""Monthly detection rate trend for the last 12 months."""
|
||||||
from datetime import timedelta
|
return advanced_metrics_service.get_detection_rate_trend(db)
|
||||||
|
|
||||||
now = datetime.utcnow()
|
|
||||||
months = []
|
|
||||||
|
|
||||||
for i in range(11, -1, -1):
|
|
||||||
month_start = datetime(now.year, now.month, 1) - timedelta(days=i * 30)
|
|
||||||
month_end = month_start + timedelta(days=30)
|
|
||||||
|
|
||||||
validated = (
|
|
||||||
db.query(func.count(Test.id))
|
|
||||||
.filter(
|
|
||||||
Test.state == "validated",
|
|
||||||
Test.created_at >= month_start,
|
|
||||||
Test.created_at < month_end,
|
|
||||||
)
|
|
||||||
.scalar() or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
detected = (
|
|
||||||
db.query(func.count(Test.id))
|
|
||||||
.filter(
|
|
||||||
Test.state == "validated",
|
|
||||||
Test.detection_result == "detected",
|
|
||||||
Test.created_at >= month_start,
|
|
||||||
Test.created_at < month_end,
|
|
||||||
)
|
|
||||||
.scalar() or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
months.append({
|
|
||||||
"month": month_start.strftime("%Y-%m"),
|
|
||||||
"validated": validated,
|
|
||||||
"detected": detected,
|
|
||||||
"detection_rate": round((detected / validated) * 100, 1) if validated > 0 else 0,
|
|
||||||
})
|
|
||||||
|
|
||||||
return months
|
|
||||||
|
|||||||
@@ -5,15 +5,12 @@ directly from URL. All endpoints require authentication.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy import func
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
from app.models.coverage_snapshot import CoverageSnapshot
|
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.test import Test
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.services import analytics_service
|
||||||
|
|
||||||
router = APIRouter(prefix="/analytics", tags=["analytics"])
|
router = APIRouter(prefix="/analytics", tags=["analytics"])
|
||||||
|
|
||||||
@@ -24,22 +21,7 @@ def analytics_coverage(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Coverage per technique — flat format for BI dashboards."""
|
"""Coverage per technique — flat format for BI dashboards."""
|
||||||
techniques = db.query(Technique).all()
|
return analytics_service.get_coverage_analytics(db)
|
||||||
return [
|
|
||||||
{
|
|
||||||
"mitre_id": t.mitre_id,
|
|
||||||
"name": t.name,
|
|
||||||
"tactic": t.tactic,
|
|
||||||
"status": t.status_global.value if t.status_global else "not_evaluated",
|
|
||||||
"is_subtechnique": t.is_subtechnique,
|
|
||||||
"test_count": len(t.tests) if t.tests else 0,
|
|
||||||
"review_required": t.review_required,
|
|
||||||
"last_review_date": (
|
|
||||||
t.last_review_date.isoformat() if t.last_review_date else None
|
|
||||||
),
|
|
||||||
}
|
|
||||||
for t in techniques
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/tests")
|
@router.get("/tests")
|
||||||
@@ -50,34 +32,9 @@ def analytics_tests(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""All tests with timestamps — flat format for BI dashboards."""
|
"""All tests with timestamps — flat format for BI dashboards."""
|
||||||
query = db.query(Test)
|
return analytics_service.get_tests_analytics(
|
||||||
if date_from:
|
db, date_from=date_from, date_to=date_to
|
||||||
query = query.filter(Test.created_at >= date_from)
|
)
|
||||||
if date_to:
|
|
||||||
query = query.filter(Test.created_at <= date_to)
|
|
||||||
tests = query.all()
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": str(t.id),
|
|
||||||
"technique_id": str(t.technique_id),
|
|
||||||
"name": t.name,
|
|
||||||
"state": t.state.value if t.state else None,
|
|
||||||
"result": t.result.value if t.result else None,
|
|
||||||
"detection_result": (
|
|
||||||
t.detection_result.value if t.detection_result else None
|
|
||||||
),
|
|
||||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
|
||||||
"execution_date": (
|
|
||||||
t.execution_date.isoformat() if t.execution_date else None
|
|
||||||
),
|
|
||||||
"platform": t.platform,
|
|
||||||
"tool_used": t.tool_used,
|
|
||||||
"attack_success": t.attack_success,
|
|
||||||
"remediation_status": t.remediation_status,
|
|
||||||
}
|
|
||||||
for t in tests
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/trends")
|
@router.get("/trends")
|
||||||
@@ -86,23 +43,7 @@ def analytics_trends(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Historical coverage snapshots for trend visualization."""
|
"""Historical coverage snapshots for trend visualization."""
|
||||||
snapshots = (
|
return analytics_service.get_trends_analytics(db)
|
||||||
db.query(CoverageSnapshot)
|
|
||||||
.order_by(CoverageSnapshot.created_at)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"date": s.created_at.isoformat() if s.created_at else None,
|
|
||||||
"name": s.name,
|
|
||||||
"total_techniques": s.total_techniques,
|
|
||||||
"validated_count": s.validated_count,
|
|
||||||
"partial_count": s.partial_count,
|
|
||||||
"not_covered_count": s.not_covered_count,
|
|
||||||
"organization_score": s.organization_score,
|
|
||||||
}
|
|
||||||
for s in snapshots
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/operators")
|
@router.get("/operators")
|
||||||
@@ -111,17 +52,4 @@ def analytics_operators(
|
|||||||
user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Per-operator metrics — for workload management dashboards."""
|
"""Per-operator metrics — for workload management dashboards."""
|
||||||
results = (
|
return analytics_service.get_operators_analytics(db)
|
||||||
db.query(
|
|
||||||
User.username,
|
|
||||||
User.role,
|
|
||||||
func.count(Test.id).label("test_count"),
|
|
||||||
)
|
|
||||||
.outerjoin(Test, Test.created_by == User.id)
|
|
||||||
.group_by(User.id, User.username, User.role)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
{"username": r[0], "role": r[1], "test_count": r[2]}
|
|
||||||
for r in results
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -4,14 +4,17 @@ from datetime import datetime
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy import func
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy.orm import Session, joinedload
|
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import require_role
|
from app.dependencies.auth import require_role
|
||||||
from app.models.audit import AuditLog
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.audit import AuditLogOut, AuditLogPage
|
from app.schemas.audit import AuditLogOut, AuditLogPage
|
||||||
|
from app.services.audit_query_service import (
|
||||||
|
list_distinct_actions,
|
||||||
|
list_distinct_entity_types,
|
||||||
|
list_logs,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
router = APIRouter(prefix="/audit-logs", tags=["audit"])
|
||||||
|
|
||||||
@@ -32,53 +35,22 @@ def list_audit_logs(
|
|||||||
|
|
||||||
**Requires admin role.**
|
**Requires admin role.**
|
||||||
"""
|
"""
|
||||||
query = db.query(AuditLog).options(joinedload(AuditLog.user))
|
result = list_logs(
|
||||||
|
db,
|
||||||
# Apply filters
|
user_id=user_id,
|
||||||
if user_id:
|
action=action,
|
||||||
query = query.filter(AuditLog.user_id == user_id)
|
entity_type=entity_type,
|
||||||
if action:
|
start_date=start_date,
|
||||||
query = query.filter(AuditLog.action == action)
|
end_date=end_date,
|
||||||
if entity_type:
|
|
||||||
query = query.filter(AuditLog.entity_type == entity_type)
|
|
||||||
if start_date:
|
|
||||||
query = query.filter(AuditLog.timestamp >= start_date)
|
|
||||||
if end_date:
|
|
||||||
query = query.filter(AuditLog.timestamp <= end_date)
|
|
||||||
|
|
||||||
# Get total count
|
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
# Get paginated results
|
|
||||||
logs = (
|
|
||||||
query
|
|
||||||
.order_by(AuditLog.timestamp.desc())
|
|
||||||
.offset(offset)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to response format with username
|
|
||||||
items = []
|
|
||||||
for log in logs:
|
|
||||||
item = AuditLogOut(
|
|
||||||
id=log.id,
|
|
||||||
user_id=log.user_id,
|
|
||||||
username=log.user.username if log.user else None,
|
|
||||||
action=log.action,
|
|
||||||
entity_type=log.entity_type,
|
|
||||||
entity_id=log.entity_id,
|
|
||||||
timestamp=log.timestamp,
|
|
||||||
details=log.details,
|
|
||||||
)
|
|
||||||
items.append(item)
|
|
||||||
|
|
||||||
return AuditLogPage(
|
|
||||||
items=items,
|
|
||||||
total=total,
|
|
||||||
offset=offset,
|
offset=offset,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
return AuditLogPage(
|
||||||
|
items=[AuditLogOut(**item) for item in result["items"]],
|
||||||
|
total=result["total"],
|
||||||
|
offset=result["offset"],
|
||||||
|
limit=result["limit"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/actions", response_model=list[str])
|
@router.get("/actions", response_model=list[str])
|
||||||
@@ -90,13 +62,7 @@ def list_actions(
|
|||||||
|
|
||||||
**Requires admin role.**
|
**Requires admin role.**
|
||||||
"""
|
"""
|
||||||
actions = (
|
return list_distinct_actions(db)
|
||||||
db.query(AuditLog.action)
|
|
||||||
.distinct()
|
|
||||||
.order_by(AuditLog.action)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [a[0] for a in actions]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/entity-types", response_model=list[str])
|
@router.get("/entity-types", response_model=list[str])
|
||||||
@@ -108,11 +74,4 @@ def list_entity_types(
|
|||||||
|
|
||||||
**Requires admin role.**
|
**Requires admin role.**
|
||||||
"""
|
"""
|
||||||
types = (
|
return list_distinct_entity_types(db)
|
||||||
db.query(AuditLog.entity_type)
|
|
||||||
.filter(AuditLog.entity_type.isnot(None))
|
|
||||||
.distinct()
|
|
||||||
.order_by(AuditLog.entity_type)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [t[0] for t in types]
|
|
||||||
|
|||||||
+15
-28
@@ -9,7 +9,7 @@ cannot use cookies (e.g. Swagger UI).
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from fastapi import APIRouter, Cookie, Depends, HTTPException, Request, Response, status
|
from fastapi import APIRouter, Cookie, Depends, Request, Response
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from slowapi import Limiter
|
from slowapi import Limiter
|
||||||
from slowapi.util import get_remote_address
|
from slowapi.util import get_remote_address
|
||||||
@@ -17,11 +17,13 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from jose import jwt, JWTError
|
from jose import jwt, JWTError
|
||||||
|
|
||||||
from app.auth import verify_password, hash_password, create_access_token, blacklist_token
|
from app.auth import create_access_token, blacklist_token
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.services.auth_service import authenticate_user, change_password as auth_change_password
|
||||||
from app.schemas.auth import TokenResponse, UserOut
|
from app.schemas.auth import TokenResponse, UserOut
|
||||||
from app.schemas.user import PasswordChange
|
from app.schemas.user import PasswordChange
|
||||||
|
|
||||||
@@ -56,24 +58,10 @@ def login(
|
|||||||
attacks. The token is set as an HttpOnly cookie **and** returned in the
|
attacks. The token is set as an HttpOnly cookie **and** returned in the
|
||||||
JSON body for API/Swagger compatibility.
|
JSON body for API/Swagger compatibility.
|
||||||
"""
|
"""
|
||||||
user = db.query(User).filter(User.username == form_data.username).first()
|
user = authenticate_user(
|
||||||
|
db,
|
||||||
# Constant-time comparison: always run bcrypt verify to prevent
|
username=form_data.username,
|
||||||
# timing-based user enumeration (SEC-005).
|
password=form_data.password,
|
||||||
_DUMMY_HASH = "$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy"
|
|
||||||
hashed = user.hashed_password if user else _DUMMY_HASH
|
|
||||||
password_valid = verify_password(form_data.password, hashed)
|
|
||||||
|
|
||||||
if user is None or not password_valid:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Incorrect username or password",
|
|
||||||
)
|
|
||||||
|
|
||||||
if not user.is_active:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Account is disabled. Contact an administrator.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = create_access_token(data={"sub": user.username})
|
access_token = create_access_token(data={"sub": user.username})
|
||||||
@@ -163,14 +151,13 @@ def change_password(
|
|||||||
``must_change_password`` flag is cleared so the user can proceed
|
``must_change_password`` flag is cleared so the user can proceed
|
||||||
normally.
|
normally.
|
||||||
"""
|
"""
|
||||||
if not verify_password(body.current_password, current_user.hashed_password):
|
auth_change_password(
|
||||||
raise HTTPException(
|
db,
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
current_user,
|
||||||
detail="Current password is incorrect",
|
current_password=body.current_password,
|
||||||
|
new_password=body.new_password,
|
||||||
)
|
)
|
||||||
|
with UnitOfWork(db) as uow:
|
||||||
current_user.hashed_password = hash_password(body.new_password)
|
uow.commit()
|
||||||
current_user.must_change_password = False
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return {"detail": "Password changed successfully"}
|
return {"detail": "Password changed successfully"}
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from app.services.campaign_crud_service import (
|
|||||||
serialize_campaign,
|
serialize_campaign,
|
||||||
update_campaign as crud_update,
|
update_campaign as crud_update,
|
||||||
)
|
)
|
||||||
from app.services.notification_service import create_notification
|
from app.services.notification_service import notify_role
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -237,11 +237,9 @@ def activate_campaign(
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(campaign)
|
db.refresh(campaign)
|
||||||
|
|
||||||
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
|
notify_role(
|
||||||
for user in red_techs:
|
|
||||||
create_notification(
|
|
||||||
db,
|
db,
|
||||||
user_id=user.id,
|
role="red_tech",
|
||||||
type="campaign_activated",
|
type="campaign_activated",
|
||||||
title="Campaign activated",
|
title="Campaign activated",
|
||||||
message=f'Campaign "{campaign.name}" has been activated.',
|
message=f'Campaign "{campaign.name}" has been activated.',
|
||||||
|
|||||||
@@ -3,18 +3,20 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_role
|
from app.dependencies.auth import get_current_user, require_role
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
|
|
||||||
from app.services.d3fend_import_service import (
|
from app.services.d3fend_import_service import (
|
||||||
import_d3fend_techniques,
|
import_d3fend_techniques,
|
||||||
import_d3fend_mappings,
|
import_d3fend_mappings,
|
||||||
get_defenses_for_technique,
|
)
|
||||||
|
from app.services.d3fend_query_service import (
|
||||||
|
list_defensive_techniques as list_defensive_techniques_svc,
|
||||||
|
list_d3fend_tactics,
|
||||||
|
get_defenses_for_attack_technique,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -36,60 +38,22 @@ def list_defensive_techniques(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List all D3FEND defensive techniques with optional filters."""
|
"""List all D3FEND defensive techniques with optional filters."""
|
||||||
query = db.query(DefensiveTechnique)
|
return list_defensive_techniques_svc(
|
||||||
|
db, tactic=tactic, search=search, offset=offset, limit=limit
|
||||||
if tactic:
|
|
||||||
query = query.filter(DefensiveTechnique.tactic == tactic)
|
|
||||||
|
|
||||||
if search:
|
|
||||||
from app.utils import escape_like
|
|
||||||
pattern = f"%{escape_like(search)}%"
|
|
||||||
query = query.filter(
|
|
||||||
DefensiveTechnique.name.ilike(pattern)
|
|
||||||
| DefensiveTechnique.d3fend_id.ilike(pattern)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
total = query.count()
|
|
||||||
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total": total,
|
|
||||||
"offset": offset,
|
|
||||||
"limit": limit,
|
|
||||||
"items": [
|
|
||||||
{
|
|
||||||
"id": str(dt.id),
|
|
||||||
"d3fend_id": dt.d3fend_id,
|
|
||||||
"name": dt.name,
|
|
||||||
"description": dt.description,
|
|
||||||
"tactic": dt.tactic,
|
|
||||||
"d3fend_url": dt.d3fend_url,
|
|
||||||
}
|
|
||||||
for dt in items
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GET /d3fend/tactics — List all D3FEND tactics
|
# GET /d3fend/tactics — List all D3FEND tactics
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("/tactics")
|
@router.get("/tactics")
|
||||||
def list_d3fend_tactics(
|
def list_d3fend_tactics_endpoint(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return a list of all D3FEND tactics with counts."""
|
"""Return a list of all D3FEND tactics with counts."""
|
||||||
from sqlalchemy import func
|
return list_d3fend_tactics(db)
|
||||||
|
|
||||||
rows = (
|
|
||||||
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
|
|
||||||
.group_by(DefensiveTechnique.tactic)
|
|
||||||
.order_by(DefensiveTechnique.tactic)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -97,24 +61,13 @@ def list_d3fend_tactics(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("/for-technique/{mitre_id}")
|
@router.get("/for-technique/{mitre_id}")
|
||||||
def get_defenses_for_attack_technique(
|
def get_defenses_for_attack_technique_endpoint(
|
||||||
mitre_id: str,
|
mitre_id: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
||||||
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
|
return get_defenses_for_attack_technique(db, mitre_id)
|
||||||
if not technique:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Technique {mitre_id} not found")
|
|
||||||
|
|
||||||
defenses = get_defenses_for_technique(db, technique.id)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"mitre_id": mitre_id,
|
|
||||||
"technique_name": technique.name,
|
|
||||||
"defenses": defenses,
|
|
||||||
"total": len(defenses),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -5,19 +5,23 @@ Provides a centralized panel for managing all external data sources
|
|||||||
including sync triggers, enable/disable toggles, and statistics.
|
including sync triggers, enable/disable toggles, and statistics.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
from fastapi import APIRouter, Depends
|
||||||
from datetime import datetime
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import require_role
|
from app.dependencies.auth import require_role
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.data_source import DataSource
|
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
from app.services.data_source_service import (
|
||||||
|
get_source_stats,
|
||||||
|
list_sources,
|
||||||
|
sync_all_sources,
|
||||||
|
sync_source,
|
||||||
|
update_source,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -30,41 +34,10 @@ class DataSourceUpdate(BaseModel):
|
|||||||
sync_frequency: Optional[str] = None
|
sync_frequency: Optional[str] = None
|
||||||
config: Optional[dict] = None
|
config: Optional[dict] = None
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
router = APIRouter(prefix="/data-sources", tags=["data-sources"])
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Sync dispatcher — maps source name → import function
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _get_sync_handler(source_name: str):
|
|
||||||
"""Lazily import and return the sync function for *source_name*.
|
|
||||||
|
|
||||||
We import lazily to avoid circular imports and to only load the
|
|
||||||
modules that are actually needed.
|
|
||||||
"""
|
|
||||||
handlers = {
|
|
||||||
"atomic_red_team": ("app.services.atomic_import_service", "import_atomic_red_team"),
|
|
||||||
"sigma": ("app.services.sigma_import_service", "sync"),
|
|
||||||
"lolbas": ("app.services.lolbas_import_service", "sync"),
|
|
||||||
"gtfobins": ("app.services.lolbas_import_service", "sync_gtfobins"),
|
|
||||||
"caldera": ("app.services.caldera_import_service", "sync"),
|
|
||||||
"elastic_rules": ("app.services.elastic_import_service", "sync"),
|
|
||||||
"mitre_cti": ("app.services.threat_actor_import_service", "sync"),
|
|
||||||
"d3fend": ("app.services.d3fend_import_service", "sync"),
|
|
||||||
}
|
|
||||||
|
|
||||||
if source_name not in handlers:
|
|
||||||
return None
|
|
||||||
|
|
||||||
module_path, func_name = handlers[source_name]
|
|
||||||
import importlib
|
|
||||||
mod = importlib.import_module(module_path)
|
|
||||||
return getattr(mod, func_name)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Endpoints
|
# Endpoints
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -79,25 +52,7 @@ def list_data_sources(
|
|||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
sources = db.query(DataSource).order_by(DataSource.name).all()
|
return list_sources(db)
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": str(s.id),
|
|
||||||
"name": s.name,
|
|
||||||
"display_name": s.display_name,
|
|
||||||
"type": s.type,
|
|
||||||
"url": s.url,
|
|
||||||
"description": s.description,
|
|
||||||
"is_enabled": s.is_enabled,
|
|
||||||
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
|
|
||||||
"last_sync_status": s.last_sync_status,
|
|
||||||
"last_sync_stats": s.last_sync_stats,
|
|
||||||
"sync_frequency": s.sync_frequency,
|
|
||||||
"config": s.config,
|
|
||||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
|
||||||
}
|
|
||||||
for s in sources
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/{source_id}")
|
@router.patch("/{source_id}")
|
||||||
@@ -111,31 +66,21 @@ def update_data_source(
|
|||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
|
||||||
if not ds:
|
|
||||||
raise HTTPException(status_code=404, detail="Data source not found")
|
|
||||||
|
|
||||||
update_data = body.model_dump(exclude_unset=True)
|
update_data = body.model_dump(exclude_unset=True)
|
||||||
|
update_source(db, source_id, **update_data)
|
||||||
if "is_enabled" in update_data:
|
with UnitOfWork(db) as uow:
|
||||||
ds.is_enabled = update_data["is_enabled"]
|
uow.commit()
|
||||||
if "sync_frequency" in update_data:
|
|
||||||
ds.sync_frequency = update_data["sync_frequency"]
|
|
||||||
if "config" in update_data:
|
|
||||||
ds.config = update_data["config"]
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
action="update_data_source",
|
action="update_data_source",
|
||||||
entity_type="data_source",
|
entity_type="data_source",
|
||||||
entity_id=str(ds.id),
|
entity_id=source_id,
|
||||||
details={"updates": update_data},
|
details={"updates": update_data},
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"message": "Data source updated", "id": str(ds.id)}
|
return {"message": "Data source updated", "id": source_id}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{source_id}/sync")
|
@router.post("/{source_id}/sync")
|
||||||
@@ -148,46 +93,7 @@ def sync_data_source(
|
|||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
return sync_source(db, source_id)
|
||||||
if not ds:
|
|
||||||
raise HTTPException(status_code=404, detail="Data source not found")
|
|
||||||
|
|
||||||
handler = _get_sync_handler(ds.name)
|
|
||||||
if handler is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"No sync handler available for '{ds.name}'",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark as in_progress
|
|
||||||
ds.last_sync_status = "in_progress"
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
try:
|
|
||||||
summary = handler(db)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
|
||||||
ds.last_sync_status = "error"
|
|
||||||
ds.last_sync_at = datetime.utcnow()
|
|
||||||
ds.last_sync_stats = {"error": str(exc)}
|
|
||||||
db.commit()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail=f"Sync failed for '{ds.display_name}'. Check server logs for details.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update DS record (the handler may already have done this,
|
|
||||||
# but we ensure it here as well)
|
|
||||||
ds.last_sync_at = datetime.utcnow()
|
|
||||||
ds.last_sync_status = "success"
|
|
||||||
ds.last_sync_stats = summary
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"message": f"Sync complete for {ds.display_name}",
|
|
||||||
"source": ds.name,
|
|
||||||
"stats": summary,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sync-all")
|
@router.post("/sync-all")
|
||||||
@@ -199,49 +105,7 @@ def sync_all_data_sources(
|
|||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
enabled_sources = (
|
results = sync_all_sources(db)
|
||||||
db.query(DataSource)
|
|
||||||
.filter(DataSource.is_enabled == True)
|
|
||||||
.order_by(DataSource.name)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for ds in enabled_sources:
|
|
||||||
handler = _get_sync_handler(ds.name)
|
|
||||||
if handler is None:
|
|
||||||
results.append({
|
|
||||||
"source": ds.name,
|
|
||||||
"status": "skipped",
|
|
||||||
"detail": "No sync handler available",
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
|
|
||||||
ds.last_sync_status = "in_progress"
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
try:
|
|
||||||
summary = handler(db)
|
|
||||||
ds.last_sync_at = datetime.utcnow()
|
|
||||||
ds.last_sync_status = "success"
|
|
||||||
ds.last_sync_stats = summary
|
|
||||||
db.commit()
|
|
||||||
results.append({
|
|
||||||
"source": ds.name,
|
|
||||||
"status": "success",
|
|
||||||
"stats": summary,
|
|
||||||
})
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
|
||||||
ds.last_sync_status = "error"
|
|
||||||
ds.last_sync_at = datetime.utcnow()
|
|
||||||
ds.last_sync_stats = {"error": str(exc)}
|
|
||||||
db.commit()
|
|
||||||
results.append({
|
|
||||||
"source": ds.name,
|
|
||||||
"status": "error",
|
|
||||||
"detail": "Sync failed. Check server logs for details.",
|
|
||||||
})
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
@@ -265,39 +129,4 @@ def get_data_source_stats(
|
|||||||
|
|
||||||
**Requires** the ``admin`` role.
|
**Requires** the ``admin`` role.
|
||||||
"""
|
"""
|
||||||
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
return get_source_stats(db, source_id)
|
||||||
if not ds:
|
|
||||||
raise HTTPException(status_code=404, detail="Data source not found")
|
|
||||||
|
|
||||||
# Count items from this source
|
|
||||||
from app.models.test_template import TestTemplate
|
|
||||||
from app.models.detection_rule import DetectionRule
|
|
||||||
|
|
||||||
template_count = 0
|
|
||||||
rule_count = 0
|
|
||||||
|
|
||||||
if ds.type == "attack_procedure":
|
|
||||||
template_count = (
|
|
||||||
db.query(TestTemplate)
|
|
||||||
.filter(TestTemplate.source == ds.name)
|
|
||||||
.count()
|
|
||||||
)
|
|
||||||
elif ds.type == "detection_rule":
|
|
||||||
rule_count = (
|
|
||||||
db.query(DetectionRule)
|
|
||||||
.filter(DetectionRule.source == ds.name)
|
|
||||||
.count()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": str(ds.id),
|
|
||||||
"name": ds.name,
|
|
||||||
"display_name": ds.display_name,
|
|
||||||
"type": ds.type,
|
|
||||||
"is_enabled": ds.is_enabled,
|
|
||||||
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
|
|
||||||
"last_sync_status": ds.last_sync_status,
|
|
||||||
"last_sync_stats": ds.last_sync_stats,
|
|
||||||
"total_templates": template_count,
|
|
||||||
"total_rules": rule_count,
|
|
||||||
}
|
|
||||||
|
|||||||
+13
-83
@@ -7,14 +7,9 @@ from uuid import UUID
|
|||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_role
|
from app.dependencies.auth import get_current_user, require_role
|
||||||
from app.domain.exceptions import EntityNotFoundError
|
from app.models.jira_link import JiraLinkEntityType
|
||||||
from app.models.jira_link import JiraLink, JiraLinkEntityType
|
|
||||||
from app.models.test import Test
|
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.campaign import Campaign
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.jira_schema import (
|
from app.schemas.jira_schema import (
|
||||||
JiraIssueResult,
|
JiraIssueResult,
|
||||||
@@ -45,23 +40,14 @@ def create_link(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Associate an Aegis entity with a Jira issue."""
|
"""Associate an Aegis entity with a Jira issue."""
|
||||||
link = JiraLink(
|
link = jira_service.create_link(
|
||||||
|
db,
|
||||||
entity_type=body.entity_type,
|
entity_type=body.entity_type,
|
||||||
entity_id=body.entity_id,
|
entity_id=body.entity_id,
|
||||||
jira_issue_key=body.jira_issue_key,
|
jira_issue_key=body.jira_issue_key,
|
||||||
sync_direction=body.sync_direction,
|
sync_direction=body.sync_direction,
|
||||||
created_by=user.id,
|
created_by=user.id,
|
||||||
)
|
)
|
||||||
db.add(link)
|
|
||||||
db.flush()
|
|
||||||
|
|
||||||
# Pull initial data from Jira if enabled
|
|
||||||
if settings.JIRA_ENABLED:
|
|
||||||
try:
|
|
||||||
jira_service.sync_jira_to_aegis(db, link)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Initial Jira sync failed for %s: %s", body.jira_issue_key, e)
|
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(link)
|
db.refresh(link)
|
||||||
|
|
||||||
@@ -88,12 +74,11 @@ def list_links(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List Jira links, optionally filtered by entity."""
|
"""List Jira links, optionally filtered by entity."""
|
||||||
query = db.query(JiraLink)
|
return jira_service.list_links(
|
||||||
if entity_type:
|
db,
|
||||||
query = query.filter(JiraLink.entity_type == entity_type)
|
entity_type=entity_type,
|
||||||
if entity_id:
|
entity_id=entity_id,
|
||||||
query = query.filter(JiraLink.entity_id == entity_id)
|
)
|
||||||
return query.order_by(JiraLink.created_at.desc()).all()
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/links/{link_id}/sync")
|
@router.post("/links/{link_id}/sync")
|
||||||
@@ -103,9 +88,7 @@ def sync_link(
|
|||||||
user: User = Depends(require_role("admin")),
|
user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Force bidirectional sync for a specific Jira link."""
|
"""Force bidirectional sync for a specific Jira link."""
|
||||||
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
|
link = jira_service.get_link_or_raise(db, link_id)
|
||||||
if not link:
|
|
||||||
raise EntityNotFoundError("JiraLink", str(link_id))
|
|
||||||
jira_service.sync_jira_to_aegis(db, link)
|
jira_service.sync_jira_to_aegis(db, link)
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"message": "Sync completed", "jira_status": link.jira_status}
|
return {"message": "Sync completed", "jira_status": link.jira_status}
|
||||||
@@ -118,10 +101,7 @@ def delete_link(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Remove a Jira link."""
|
"""Remove a Jira link."""
|
||||||
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
|
link = jira_service.delete_link(db, link_id)
|
||||||
if not link:
|
|
||||||
raise EntityNotFoundError("JiraLink", str(link_id))
|
|
||||||
db.delete(link)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
audit_service.log_action(
|
audit_service.log_action(
|
||||||
db,
|
db,
|
||||||
@@ -141,61 +121,11 @@ def create_issue_from_entity(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Auto-create a Jira issue from an Aegis entity and link them."""
|
"""Auto-create a Jira issue from an Aegis entity and link them."""
|
||||||
summary, description = _build_issue_data(db, entity_type, entity_id)
|
result = jira_service.create_issue_and_link(
|
||||||
result = jira_service.create_jira_issue(
|
db,
|
||||||
project_key=settings.JIRA_DEFAULT_PROJECT,
|
|
||||||
summary=summary,
|
|
||||||
description=description,
|
|
||||||
labels=["aegis", entity_type.value],
|
|
||||||
)
|
|
||||||
link = JiraLink(
|
|
||||||
entity_type=entity_type,
|
entity_type=entity_type,
|
||||||
entity_id=entity_id,
|
entity_id=entity_id,
|
||||||
jira_issue_key=result["issue_key"],
|
|
||||||
jira_issue_id=result["issue_id"],
|
|
||||||
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
|
|
||||||
created_by=user.id,
|
created_by=user.id,
|
||||||
)
|
)
|
||||||
db.add(link)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"issue_key": result["issue_key"], "link_id": str(link.id)}
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _build_issue_data(
|
|
||||||
db: Session,
|
|
||||||
entity_type: JiraLinkEntityType,
|
|
||||||
entity_id: UUID,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
"""Build Jira issue summary + description from an Aegis entity."""
|
|
||||||
if entity_type == JiraLinkEntityType.test:
|
|
||||||
entity = db.query(Test).filter(Test.id == entity_id).first()
|
|
||||||
if not entity:
|
|
||||||
raise EntityNotFoundError("Test", str(entity_id))
|
|
||||||
return (
|
|
||||||
f"[Aegis Test] {entity.name}",
|
|
||||||
f"Test: {entity.name}\n"
|
|
||||||
f"State: {entity.state.value if entity.state else 'draft'}\n"
|
|
||||||
f"Description: {entity.description or 'N/A'}",
|
|
||||||
)
|
|
||||||
elif entity_type == JiraLinkEntityType.campaign:
|
|
||||||
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
|
|
||||||
if not entity:
|
|
||||||
raise EntityNotFoundError("Campaign", str(entity_id))
|
|
||||||
return (
|
|
||||||
f"[Aegis Campaign] {entity.name}",
|
|
||||||
f"Campaign: {entity.name}\n"
|
|
||||||
f"Type: {entity.type}\nStatus: {entity.status}\n"
|
|
||||||
f"Description: {entity.description or 'N/A'}",
|
|
||||||
)
|
|
||||||
elif entity_type == JiraLinkEntityType.technique:
|
|
||||||
entity = db.query(Technique).filter(Technique.id == entity_id).first()
|
|
||||||
if not entity:
|
|
||||||
raise EntityNotFoundError("Technique", str(entity_id))
|
|
||||||
return (
|
|
||||||
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
|
|
||||||
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
|
|
||||||
f"Tactic: {entity.tactic or 'N/A'}\n"
|
|
||||||
f"Description: {entity.description or 'N/A'}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
|
|
||||||
|
|||||||
@@ -10,16 +10,16 @@ POST /notifications/read-all — mark all as read
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.notification import Notification
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.notification import NotificationOut, UnreadCountOut
|
from app.schemas.notification import NotificationOut, UnreadCountOut
|
||||||
from app.services.notification_service import (
|
from app.services.notification_service import (
|
||||||
|
list_notifications,
|
||||||
mark_as_read,
|
mark_as_read,
|
||||||
mark_all_as_read,
|
mark_all_as_read,
|
||||||
get_unread_count,
|
get_unread_count,
|
||||||
@@ -34,22 +34,14 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=list[NotificationOut])
|
@router.get("", response_model=list[NotificationOut])
|
||||||
def list_notifications(
|
def list_notifications_endpoint(
|
||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
limit: int = Query(20, ge=1, le=100),
|
limit: int = Query(20, ge=1, le=100),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return paginated notifications for the current user, newest first."""
|
"""Return paginated notifications for the current user, newest first."""
|
||||||
notifs = (
|
return list_notifications(db, current_user.id, offset=offset, limit=limit)
|
||||||
db.query(Notification)
|
|
||||||
.filter(Notification.user_id == current_user.id)
|
|
||||||
.order_by(Notification.created_at.desc())
|
|
||||||
.offset(offset)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return notifs
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -80,14 +72,8 @@ def read_notification(
|
|||||||
):
|
):
|
||||||
"""Mark a single notification as read."""
|
"""Mark a single notification as read."""
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
success = mark_as_read(db, notification_id, current_user.id)
|
notif = mark_as_read(db, notification_id, current_user.id)
|
||||||
if not success:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Notification not found",
|
|
||||||
)
|
|
||||||
uow.commit()
|
uow.commit()
|
||||||
notif = db.query(Notification).filter(Notification.id == notification_id).first()
|
|
||||||
return notif
|
return notif
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,14 +10,15 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
from app.models.osint_item import OsintItem
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.services.osint_enrichment_service import (
|
from app.services.osint_enrichment_service import (
|
||||||
enrich_technique_with_cves,
|
enrich_technique_with_cves,
|
||||||
get_osint_items_for_technique,
|
get_osint_items_for_technique,
|
||||||
|
get_osint_summary,
|
||||||
|
get_technique_or_raise,
|
||||||
|
list_osint_items as service_list_osint_items,
|
||||||
mark_osint_reviewed,
|
mark_osint_reviewed,
|
||||||
get_unreviewed_count,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix="/osint", tags=["osint"])
|
router = APIRouter(prefix="/osint", tags=["osint"])
|
||||||
@@ -56,41 +57,15 @@ def list_osint_items(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List OSINT items with optional filters."""
|
"""List OSINT items with optional filters."""
|
||||||
query = db.query(OsintItem)
|
return service_list_osint_items(
|
||||||
if technique_id:
|
db,
|
||||||
query = query.filter(OsintItem.technique_id == technique_id)
|
technique_id=technique_id,
|
||||||
if source_type:
|
source_type=source_type,
|
||||||
query = query.filter(OsintItem.source_type == source_type)
|
reviewed=reviewed,
|
||||||
if reviewed is not None:
|
offset=offset,
|
||||||
query = query.filter(OsintItem.reviewed == reviewed)
|
limit=limit,
|
||||||
|
|
||||||
total = query.count()
|
|
||||||
items = (
|
|
||||||
query.order_by(OsintItem.discovered_at.desc())
|
|
||||||
.offset(offset)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
|
||||||
"total": total,
|
|
||||||
"items": [
|
|
||||||
{
|
|
||||||
"id": str(item.id),
|
|
||||||
"technique_id": str(item.technique_id),
|
|
||||||
"source_type": item.source_type,
|
|
||||||
"source_url": item.source_url,
|
|
||||||
"title": item.title,
|
|
||||||
"description": item.description,
|
|
||||||
"severity": item.severity,
|
|
||||||
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
|
|
||||||
"reviewed": item.reviewed,
|
|
||||||
"metadata": item.metadata_,
|
|
||||||
}
|
|
||||||
for item in items
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/summary")
|
@router.get("/summary")
|
||||||
def osint_summary(
|
def osint_summary(
|
||||||
@@ -98,34 +73,7 @@ def osint_summary(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Summary statistics for OSINT items."""
|
"""Summary statistics for OSINT items."""
|
||||||
from sqlalchemy import func
|
return get_osint_summary(db)
|
||||||
|
|
||||||
total = db.query(func.count(OsintItem.id)).scalar() or 0
|
|
||||||
unreviewed = get_unreviewed_count(db)
|
|
||||||
|
|
||||||
by_severity = dict(
|
|
||||||
db.query(OsintItem.severity, func.count(OsintItem.id))
|
|
||||||
.group_by(OsintItem.severity)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
by_type = dict(
|
|
||||||
db.query(OsintItem.source_type, func.count(OsintItem.id))
|
|
||||||
.group_by(OsintItem.source_type)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
techniques_with_items = (
|
|
||||||
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_items": total,
|
|
||||||
"unreviewed": unreviewed,
|
|
||||||
"techniques_with_items": techniques_with_items,
|
|
||||||
"by_severity": by_severity,
|
|
||||||
"by_type": by_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/items/{item_id}/review")
|
@router.post("/items/{item_id}/review")
|
||||||
@@ -135,12 +83,14 @@ def review_osint_item(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Mark an OSINT item as reviewed."""
|
"""Mark an OSINT item as reviewed."""
|
||||||
|
with UnitOfWork(db) as uow:
|
||||||
item = mark_osint_reviewed(db, str(item_id))
|
item = mark_osint_reviewed(db, str(item_id))
|
||||||
if not item:
|
if not item:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="OSINT item not found",
|
detail="OSINT item not found",
|
||||||
)
|
)
|
||||||
|
uow.commit()
|
||||||
return {"id": str(item.id), "reviewed": True}
|
return {"id": str(item.id), "reviewed": True}
|
||||||
|
|
||||||
|
|
||||||
@@ -151,13 +101,7 @@ def trigger_technique_enrichment(
|
|||||||
user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Manually trigger OSINT enrichment for a single technique."""
|
"""Manually trigger OSINT enrichment for a single technique."""
|
||||||
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
technique = get_technique_or_raise(db, technique_id)
|
||||||
if not technique:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Technique not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
count = enrich_technique_with_cves(db, technique)
|
count = enrich_technique_with_cves(db, technique)
|
||||||
return {
|
return {
|
||||||
"technique_id": str(technique.id),
|
"technique_id": str(technique.id),
|
||||||
|
|||||||
@@ -5,19 +5,18 @@ Provides granular scoring with breakdowns and configurable weights.
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_role
|
from app.dependencies.auth import get_current_user, require_role
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.threat_actor import ThreatActor
|
|
||||||
from app.services.scoring_service import (
|
from app.services.scoring_service import (
|
||||||
calculate_technique_score,
|
score_technique_by_mitre_id,
|
||||||
|
score_actor_by_id,
|
||||||
calculate_tactic_score,
|
calculate_tactic_score,
|
||||||
calculate_actor_coverage_score,
|
|
||||||
calculate_organization_score,
|
calculate_organization_score,
|
||||||
get_score_history,
|
get_score_history,
|
||||||
)
|
)
|
||||||
@@ -39,23 +38,7 @@ def score_technique(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get detailed score with breakdown for a specific technique."""
|
"""Get detailed score with breakdown for a specific technique."""
|
||||||
technique = (
|
return score_technique_by_mitre_id(db, mitre_id)
|
||||||
db.query(Technique)
|
|
||||||
.filter(Technique.mitre_id == mitre_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not technique:
|
|
||||||
raise HTTPException(status_code=404, detail="Technique not found")
|
|
||||||
|
|
||||||
result = calculate_technique_score(technique, db)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"mitre_id": technique.mitre_id,
|
|
||||||
"name": technique.name,
|
|
||||||
"tactic": technique.tactic,
|
|
||||||
"status_global": technique.status_global.value if technique.status_global else None,
|
|
||||||
**result,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /scores/tactic/{tactic} ──────────────────────────────────────
|
# ── GET /scores/tactic/{tactic} ──────────────────────────────────────
|
||||||
@@ -81,11 +64,7 @@ def score_threat_actor(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get coverage score against a specific threat actor."""
|
"""Get coverage score against a specific threat actor."""
|
||||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
return score_actor_by_id(db, actor_id)
|
||||||
if not actor:
|
|
||||||
raise HTTPException(status_code=404, detail="Threat actor not found")
|
|
||||||
|
|
||||||
return calculate_actor_coverage_score(actor_id, db)
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /scores/organization ─────────────────────────────────────────
|
# ── GET /scores/organization ─────────────────────────────────────────
|
||||||
@@ -149,6 +128,7 @@ def update_scoring_config(
|
|||||||
Weights are persisted in the database and survive restarts.
|
Weights are persisted in the database and survive restarts.
|
||||||
Validation enforces that all weights are non-negative and sum to 100.
|
Validation enforces that all weights are non-negative and sum to 100.
|
||||||
"""
|
"""
|
||||||
|
with UnitOfWork(db) as uow:
|
||||||
result = update_scoring_weights(
|
result = update_scoring_weights(
|
||||||
db,
|
db,
|
||||||
tests=payload.tests,
|
tests=payload.tests,
|
||||||
@@ -157,6 +137,7 @@ def update_scoring_config(
|
|||||||
freshness=payload.freshness,
|
freshness=payload.freshness,
|
||||||
platform_diversity=payload.platform_diversity,
|
platform_diversity=payload.platform_diversity,
|
||||||
)
|
)
|
||||||
|
uow.commit()
|
||||||
|
|
||||||
from app.services.score_cache import invalidate
|
from app.services.score_cache import invalidate
|
||||||
invalidate()
|
invalidate()
|
||||||
|
|||||||
@@ -8,18 +8,24 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||||
|
from app.domain.errors import BusinessRuleViolation
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
|
||||||
from app.services.snapshot_service import (
|
from app.services.snapshot_service import (
|
||||||
create_snapshot,
|
create_snapshot,
|
||||||
compare_snapshots,
|
compare_snapshots,
|
||||||
cleanup_old_snapshots,
|
cleanup_old_snapshots,
|
||||||
|
serialize_snapshot_summary,
|
||||||
|
list_snapshots as list_snapshots_svc,
|
||||||
|
get_snapshot_or_raise,
|
||||||
|
get_snapshot_detail,
|
||||||
|
delete_snapshot,
|
||||||
)
|
)
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
|
||||||
@@ -34,48 +40,6 @@ class SnapshotCreate(BaseModel):
|
|||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
|
|
||||||
"""Lightweight serialization for list views."""
|
|
||||||
return {
|
|
||||||
"id": str(snap.id),
|
|
||||||
"name": snap.name,
|
|
||||||
"organization_score": snap.organization_score,
|
|
||||||
"total_techniques": snap.total_techniques,
|
|
||||||
"validated_count": snap.validated_count,
|
|
||||||
"partial_count": snap.partial_count,
|
|
||||||
"not_covered_count": snap.not_covered_count,
|
|
||||||
"in_progress_count": snap.in_progress_count,
|
|
||||||
"not_evaluated_count": snap.not_evaluated_count,
|
|
||||||
"created_by": str(snap.created_by) if snap.created_by else None,
|
|
||||||
"created_at": snap.created_at.isoformat() if snap.created_at else None,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
|
|
||||||
"""Full serialization including technique states."""
|
|
||||||
base = _serialize_snapshot_summary(snap)
|
|
||||||
|
|
||||||
technique_states = (
|
|
||||||
db.query(SnapshotTechniqueState)
|
|
||||||
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
|
|
||||||
.order_by(SnapshotTechniqueState.mitre_id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
base["technique_states"] = [
|
|
||||||
{
|
|
||||||
"mitre_id": s.mitre_id,
|
|
||||||
"technique_id": str(s.technique_id),
|
|
||||||
"status": s.status,
|
|
||||||
"score": s.score,
|
|
||||||
}
|
|
||||||
for s in technique_states
|
|
||||||
]
|
|
||||||
return base
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GET /snapshots — List snapshots (paginated)
|
# GET /snapshots — List snapshots (paginated)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -88,23 +52,7 @@ def list_snapshots(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List coverage snapshots ordered by creation date (newest first)."""
|
"""List coverage snapshots ordered by creation date (newest first)."""
|
||||||
query = db.query(CoverageSnapshot)
|
return list_snapshots_svc(db, offset=offset, limit=limit)
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
snapshots = (
|
|
||||||
query
|
|
||||||
.order_by(CoverageSnapshot.created_at.desc())
|
|
||||||
.offset(offset)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total": total,
|
|
||||||
"offset": offset,
|
|
||||||
"limit": limit,
|
|
||||||
"items": [_serialize_snapshot_summary(s) for s in snapshots],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -129,7 +77,7 @@ def create_snapshot_endpoint(
|
|||||||
details={"name": snapshot.name, "score": snapshot.organization_score},
|
details={"name": snapshot.name, "score": snapshot.organization_score},
|
||||||
)
|
)
|
||||||
|
|
||||||
return _serialize_snapshot_summary(snapshot)
|
return serialize_snapshot_summary(snapshot)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -148,13 +96,9 @@ def compare_snapshots_endpoint(
|
|||||||
a_id = uuid.UUID(a)
|
a_id = uuid.UUID(a)
|
||||||
b_id = uuid.UUID(b)
|
b_id = uuid.UUID(b)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise HTTPException(status_code=400, detail="Invalid snapshot ID format")
|
raise BusinessRuleViolation("Invalid snapshot ID format")
|
||||||
|
|
||||||
result = compare_snapshots(db, a_id, b_id)
|
return compare_snapshots(db, a_id, b_id)
|
||||||
if "error" in result:
|
|
||||||
raise HTTPException(status_code=404, detail=result["error"])
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -168,11 +112,7 @@ def get_snapshot(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get detailed snapshot information including per-technique states."""
|
"""Get detailed snapshot information including per-technique states."""
|
||||||
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
|
return get_snapshot_detail(db, snapshot_id)
|
||||||
if not snapshot:
|
|
||||||
raise HTTPException(status_code=404, detail="Snapshot not found")
|
|
||||||
|
|
||||||
return _serialize_snapshot_detail(db, snapshot)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -180,15 +120,13 @@ def get_snapshot(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.delete("/{snapshot_id}")
|
@router.delete("/{snapshot_id}")
|
||||||
def delete_snapshot(
|
def delete_snapshot_endpoint(
|
||||||
snapshot_id: str,
|
snapshot_id: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Delete a snapshot (admin only)."""
|
"""Delete a snapshot (admin only)."""
|
||||||
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
|
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||||
if not snapshot:
|
|
||||||
raise HTTPException(status_code=404, detail="Snapshot not found")
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
@@ -199,7 +137,8 @@ def delete_snapshot(
|
|||||||
details={"name": snapshot.name},
|
details={"name": snapshot.name},
|
||||||
)
|
)
|
||||||
|
|
||||||
db.delete(snapshot)
|
with UnitOfWork(db) as uow:
|
||||||
db.commit()
|
delete_snapshot(db, snapshot_id)
|
||||||
|
uow.commit()
|
||||||
|
|
||||||
return {"detail": "Snapshot deleted"}
|
return {"detail": "Snapshot deleted"}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ exceptions to HTTP responses automatically.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, status
|
from fastapi import APIRouter, Depends, Query, status
|
||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_role, require_any_role
|
from app.dependencies.auth import get_current_user, require_role, require_any_role
|
||||||
@@ -18,7 +18,6 @@ from app.domain.unit_of_work import UnitOfWork
|
|||||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||||
SATechniqueRepository,
|
SATechniqueRepository,
|
||||||
)
|
)
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.technique import (
|
from app.schemas.technique import (
|
||||||
TechniqueCreate,
|
TechniqueCreate,
|
||||||
@@ -27,7 +26,7 @@ from app.schemas.technique import (
|
|||||||
TechniqueUpdate,
|
TechniqueUpdate,
|
||||||
)
|
)
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
from app.services.d3fend_import_service import get_defenses_for_technique
|
from app.services.technique_query_service import get_technique_detail
|
||||||
|
|
||||||
router = APIRouter(prefix="/techniques", tags=["techniques"])
|
router = APIRouter(prefix="/techniques", tags=["techniques"])
|
||||||
|
|
||||||
@@ -67,45 +66,7 @@ def get_technique(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return full details for a single technique, including its tests and D3FEND defenses."""
|
"""Return full details for a single technique, including its tests and D3FEND defenses."""
|
||||||
technique = (
|
return get_technique_detail(db, mitre_id)
|
||||||
db.query(Technique)
|
|
||||||
.options(joinedload(Technique.tests))
|
|
||||||
.filter(Technique.mitre_id == mitre_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if technique is None:
|
|
||||||
raise EntityNotFoundError("Technique", mitre_id)
|
|
||||||
|
|
||||||
defenses = get_defenses_for_technique(db, technique.id)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": str(technique.id),
|
|
||||||
"mitre_id": technique.mitre_id,
|
|
||||||
"name": technique.name,
|
|
||||||
"description": technique.description,
|
|
||||||
"tactic": technique.tactic,
|
|
||||||
"platforms": technique.platforms or [],
|
|
||||||
"mitre_version": technique.mitre_version,
|
|
||||||
"mitre_last_modified": technique.mitre_last_modified,
|
|
||||||
"is_subtechnique": technique.is_subtechnique,
|
|
||||||
"parent_mitre_id": technique.parent_mitre_id,
|
|
||||||
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
|
|
||||||
"review_required": technique.review_required,
|
|
||||||
"last_review_date": technique.last_review_date,
|
|
||||||
"tests": [
|
|
||||||
{
|
|
||||||
"id": str(t.id),
|
|
||||||
"name": t.name,
|
|
||||||
"state": t.state.value if t.state else None,
|
|
||||||
"result": t.result.value if t.result else None,
|
|
||||||
"platform": t.platform,
|
|
||||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
|
||||||
}
|
|
||||||
for t in technique.tests
|
|
||||||
],
|
|
||||||
"d3fend_defenses": defenses,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -25,13 +25,12 @@ Filters (GET /test-templates)
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, Query, status
|
||||||
from sqlalchemy import func, or_
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_role, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
from app.models.test_template import TestTemplate
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.test_template import (
|
from app.schemas.test_template import (
|
||||||
TestTemplateCreate,
|
TestTemplateCreate,
|
||||||
@@ -39,6 +38,17 @@ from app.schemas.test_template import (
|
|||||||
TestTemplateSummary,
|
TestTemplateSummary,
|
||||||
)
|
)
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
from app.services.test_template_service import (
|
||||||
|
bulk_activate,
|
||||||
|
create_template as create_template_svc,
|
||||||
|
get_template_or_raise,
|
||||||
|
get_template_stats,
|
||||||
|
get_templates_by_technique as templates_by_technique,
|
||||||
|
list_templates,
|
||||||
|
soft_delete_template,
|
||||||
|
toggle_template_active as toggle_template_active_svc,
|
||||||
|
update_template as update_template_svc,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix="/test-templates", tags=["test-templates"])
|
router = APIRouter(prefix="/test-templates", tags=["test-templates"])
|
||||||
|
|
||||||
@@ -49,7 +59,7 @@ router = APIRouter(prefix="/test-templates", tags=["test-templates"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=list[TestTemplateSummary])
|
@router.get("", response_model=list[TestTemplateSummary])
|
||||||
def list_templates(
|
def _list_templates_handler(
|
||||||
source: Optional[str] = Query(None, description="Filter by source (atomic_red_team, mitre, custom)"),
|
source: Optional[str] = Query(None, description="Filter by source (atomic_red_team, mitre, custom)"),
|
||||||
platform: Optional[str] = Query(None, description="Filter by platform (windows, linux, macos)"),
|
platform: Optional[str] = Query(None, description="Filter by platform (windows, linux, macos)"),
|
||||||
severity: Optional[str] = Query(None, description="Filter by severity (low, medium, high, critical)"),
|
severity: Optional[str] = Query(None, description="Filter by severity (low, medium, high, critical)"),
|
||||||
@@ -62,37 +72,17 @@ def list_templates(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return a paginated, filterable list of test templates."""
|
"""Return a paginated, filterable list of test templates."""
|
||||||
query = db.query(TestTemplate)
|
return list_templates(
|
||||||
if is_active is not None:
|
db,
|
||||||
query = query.filter(TestTemplate.is_active == is_active) # noqa: E712
|
source=source,
|
||||||
|
platform=platform,
|
||||||
if source:
|
severity=severity,
|
||||||
query = query.filter(TestTemplate.source == source)
|
mitre_technique_id=mitre_technique_id,
|
||||||
if platform:
|
search=search,
|
||||||
from app.utils import escape_like
|
is_active=is_active,
|
||||||
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
|
offset=offset,
|
||||||
if severity:
|
limit=limit,
|
||||||
query = query.filter(TestTemplate.severity == severity)
|
|
||||||
if mitre_technique_id:
|
|
||||||
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
|
|
||||||
if search:
|
|
||||||
from app.utils import escape_like
|
|
||||||
pattern = f"%{escape_like(search)}%"
|
|
||||||
query = query.filter(
|
|
||||||
or_(
|
|
||||||
TestTemplate.name.ilike(pattern),
|
|
||||||
TestTemplate.description.ilike(pattern),
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
templates = (
|
|
||||||
query
|
|
||||||
.order_by(TestTemplate.mitre_technique_id, TestTemplate.name)
|
|
||||||
.offset(offset)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return templates
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -105,41 +95,8 @@ def template_stats(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Return catalog statistics: totals by source, platform, active/inactive."""
|
"""Return catalog statistics: active, by_source, by_platform."""
|
||||||
|
return get_template_stats(db)
|
||||||
total = db.query(func.count(TestTemplate.id)).scalar() or 0
|
|
||||||
active = (
|
|
||||||
db.query(func.count(TestTemplate.id))
|
|
||||||
.filter(TestTemplate.is_active == True) # noqa: E712
|
|
||||||
.scalar()
|
|
||||||
) or 0
|
|
||||||
inactive = total - active
|
|
||||||
|
|
||||||
# By source
|
|
||||||
source_rows = (
|
|
||||||
db.query(TestTemplate.source, func.count(TestTemplate.id))
|
|
||||||
.filter(TestTemplate.is_active == True) # noqa: E712
|
|
||||||
.group_by(TestTemplate.source)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
by_source = {source: cnt for source, cnt in source_rows}
|
|
||||||
|
|
||||||
# By platform
|
|
||||||
platform_rows = (
|
|
||||||
db.query(TestTemplate.platform, func.count(TestTemplate.id))
|
|
||||||
.filter(TestTemplate.is_active == True) # noqa: E712
|
|
||||||
.group_by(TestTemplate.platform)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
by_platform = {(platform or "unspecified"): cnt for platform, cnt in platform_rows}
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total": total,
|
|
||||||
"active": active,
|
|
||||||
"inactive": inactive,
|
|
||||||
"by_source": by_source,
|
|
||||||
"by_platform": by_platform,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -154,13 +111,8 @@ def bulk_activate_templates(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Set all templates to active or inactive."""
|
"""Set all templates to active or inactive."""
|
||||||
count = (
|
count = bulk_activate(db, activate=activate)
|
||||||
db.query(TestTemplate)
|
with UnitOfWork(db) as uow:
|
||||||
.filter(TestTemplate.is_active != activate)
|
|
||||||
.update({TestTemplate.is_active: activate})
|
|
||||||
)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
@@ -169,6 +121,7 @@ def bulk_activate_templates(
|
|||||||
entity_id=None,
|
entity_id=None,
|
||||||
details={"affected": count, "is_active": activate},
|
details={"affected": count, "is_active": activate},
|
||||||
)
|
)
|
||||||
|
uow.commit()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"detail": f"{'Activated' if activate else 'Deactivated'} {count} templates",
|
"detail": f"{'Activated' if activate else 'Deactivated'} {count} templates",
|
||||||
@@ -183,22 +136,13 @@ def bulk_activate_templates(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/by-technique/{mitre_id}", response_model=list[TestTemplateSummary])
|
@router.get("/by-technique/{mitre_id}", response_model=list[TestTemplateSummary])
|
||||||
def templates_by_technique(
|
def _templates_by_technique_handler(
|
||||||
mitre_id: str,
|
mitre_id: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return all active templates mapped to a specific MITRE technique."""
|
"""Return all active templates mapped to a specific MITRE technique."""
|
||||||
templates = (
|
return templates_by_technique(db, mitre_id)
|
||||||
db.query(TestTemplate)
|
|
||||||
.filter(
|
|
||||||
TestTemplate.mitre_technique_id == mitre_id,
|
|
||||||
TestTemplate.is_active == True, # noqa: E712
|
|
||||||
)
|
|
||||||
.order_by(TestTemplate.name)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return templates
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -213,13 +157,7 @@ def get_template(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return full details for a single test template."""
|
"""Return full details for a single test template."""
|
||||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
return get_template_or_raise(db, template_id)
|
||||||
if template is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Test template not found",
|
|
||||||
)
|
|
||||||
return template
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -238,11 +176,8 @@ def create_template(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Create a custom test template."""
|
"""Create a custom test template."""
|
||||||
template = TestTemplate(**payload.model_dump())
|
template = create_template_svc(db, **payload.model_dump())
|
||||||
db.add(template)
|
with UnitOfWork(db) as uow:
|
||||||
db.commit()
|
|
||||||
db.refresh(template)
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
@@ -255,6 +190,8 @@ def create_template(
|
|||||||
"mitre_technique_id": template.mitre_technique_id,
|
"mitre_technique_id": template.mitre_technique_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
uow.commit()
|
||||||
|
db.refresh(template)
|
||||||
|
|
||||||
return template
|
return template
|
||||||
|
|
||||||
@@ -272,28 +209,18 @@ def update_template(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Update fields of an existing test template."""
|
"""Update fields of an existing test template."""
|
||||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True))
|
||||||
if template is None:
|
with UnitOfWork(db) as uow:
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Test template not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
|
||||||
for field, value in update_data.items():
|
|
||||||
setattr(template, field, value)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
db.refresh(template)
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
action="update_test_template",
|
action="update_test_template",
|
||||||
entity_type="test_template",
|
entity_type="test_template",
|
||||||
entity_id=template.id,
|
entity_id=template.id,
|
||||||
details={"updated_fields": list(update_data.keys())},
|
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
|
||||||
)
|
)
|
||||||
|
uow.commit()
|
||||||
|
db.refresh(template)
|
||||||
|
|
||||||
return template
|
return template
|
||||||
|
|
||||||
@@ -309,18 +236,9 @@ def toggle_template_active(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Toggle a template between active and inactive."""
|
"""Toggle a template between active and inactive (is_active = not is_active)."""
|
||||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
template = toggle_template_active_svc(db, template_id)
|
||||||
if template is None:
|
with UnitOfWork(db) as uow:
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Test template not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
template.is_active = not template.is_active
|
|
||||||
db.commit()
|
|
||||||
db.refresh(template)
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
@@ -329,6 +247,8 @@ def toggle_template_active(
|
|||||||
entity_id=template.id,
|
entity_id=template.id,
|
||||||
details={"name": template.name, "is_active": template.is_active},
|
details={"name": template.name, "is_active": template.is_active},
|
||||||
)
|
)
|
||||||
|
uow.commit()
|
||||||
|
db.refresh(template)
|
||||||
|
|
||||||
return template
|
return template
|
||||||
|
|
||||||
@@ -345,16 +265,9 @@ def delete_template(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Soft-delete a test template by setting ``is_active=False``."""
|
"""Soft-delete a test template by setting ``is_active=False``."""
|
||||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
template = get_template_or_raise(db, template_id)
|
||||||
if template is None:
|
soft_delete_template(db, template_id)
|
||||||
raise HTTPException(
|
with UnitOfWork(db) as uow:
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Test template not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
template.is_active = False
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
@@ -363,5 +276,6 @@ def delete_template(
|
|||||||
entity_id=template.id,
|
entity_id=template.id,
|
||||||
details={"name": template.name},
|
details={"name": template.name},
|
||||||
)
|
)
|
||||||
|
uow.commit()
|
||||||
|
|
||||||
return {"detail": "Test template deactivated"}
|
return {"detail": "Test template deactivated"}
|
||||||
|
|||||||
@@ -2,20 +2,24 @@
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import require_role
|
from app.dependencies.auth import require_role
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.user import UserCreate, UserUpdate, UserOut
|
from app.schemas.user import UserCreate, UserUpdate, UserOut
|
||||||
from app.auth import hash_password
|
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
from app.services.user_service import (
|
||||||
|
create_user,
|
||||||
|
get_user_or_raise,
|
||||||
|
list_users,
|
||||||
|
update_user,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix="/users", tags=["users"])
|
router = APIRouter(prefix="/users", tags=["users"])
|
||||||
|
|
||||||
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GET /users — list all users
|
# GET /users — list all users
|
||||||
@@ -23,12 +27,12 @@ VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewe
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=list[UserOut])
|
@router.get("", response_model=list[UserOut])
|
||||||
def list_users(
|
def list_users_route(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Return a list of all users. **Requires admin role.**"""
|
"""Return a list of all users. **Requires admin role.**"""
|
||||||
return db.query(User).order_by(User.username).all()
|
return list_users(db)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -37,36 +41,21 @@ def list_users(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
@router.post("", response_model=UserOut, status_code=status.HTTP_201_CREATED)
|
||||||
def create_user(
|
def create_user_route(
|
||||||
payload: UserCreate,
|
payload: UserCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Create a new user. **Requires admin role.**"""
|
"""Create a new user. **Requires admin role.**"""
|
||||||
|
user = create_user(
|
||||||
# Check if username already exists
|
db,
|
||||||
existing = db.query(User).filter(User.username == payload.username).first()
|
|
||||||
if existing:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail=f"Username '{payload.username}' already exists",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate role
|
|
||||||
if payload.role not in VALID_ROLES:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Invalid role '{payload.role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
|
|
||||||
)
|
|
||||||
|
|
||||||
user = User(
|
|
||||||
username=payload.username,
|
username=payload.username,
|
||||||
email=payload.email,
|
email=payload.email,
|
||||||
hashed_password=hash_password(payload.password),
|
password=payload.password,
|
||||||
role=payload.role,
|
role=payload.role,
|
||||||
)
|
)
|
||||||
db.add(user)
|
with UnitOfWork(db) as uow:
|
||||||
db.commit()
|
uow.commit()
|
||||||
db.refresh(user)
|
db.refresh(user)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -93,13 +82,7 @@ def get_user(
|
|||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Return a single user by ID. **Requires admin role.**"""
|
"""Return a single user by ID. **Requires admin role.**"""
|
||||||
user = db.query(User).filter(User.id == user_id).first()
|
return get_user_or_raise(db, user_id)
|
||||||
if user is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="User not found",
|
|
||||||
)
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -108,37 +91,17 @@ def get_user(
|
|||||||
|
|
||||||
|
|
||||||
@router.patch("/{user_id}", response_model=UserOut)
|
@router.patch("/{user_id}", response_model=UserOut)
|
||||||
def update_user(
|
def update_user_route(
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
payload: UserUpdate,
|
payload: UserUpdate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Update one or more fields of an existing user. **Requires admin role.**"""
|
"""Update one or more fields of an existing user. **Requires admin role.**"""
|
||||||
user = db.query(User).filter(User.id == user_id).first()
|
|
||||||
if user is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="User not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
|
user = update_user(db, user_id, **update_data)
|
||||||
# Validate role if being updated
|
with UnitOfWork(db) as uow:
|
||||||
if "role" in update_data and update_data["role"] not in VALID_ROLES:
|
uow.commit()
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hash password if being updated
|
|
||||||
if "password" in update_data:
|
|
||||||
update_data["hashed_password"] = hash_password(update_data.pop("password"))
|
|
||||||
|
|
||||||
for field, value in update_data.items():
|
|
||||||
setattr(user, field, value)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
db.refresh(user)
|
db.refresh(user)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -147,7 +110,7 @@ def update_user(
|
|||||||
action="update_user",
|
action="update_user",
|
||||||
entity_type="user",
|
entity_type="user",
|
||||||
entity_id=user.id,
|
entity_id=user.id,
|
||||||
details={"updated_fields": list(payload.model_dump(exclude_unset=True).keys())},
|
details={"updated_fields": list(update_data.keys())},
|
||||||
)
|
)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|||||||
@@ -10,9 +10,8 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
from app.domain.exceptions import EntityNotFoundError
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.worklog import Worklog
|
|
||||||
from app.services import worklog_service
|
from app.services import worklog_service
|
||||||
|
|
||||||
router = APIRouter(prefix="/worklogs", tags=["worklogs"])
|
router = APIRouter(prefix="/worklogs", tags=["worklogs"])
|
||||||
@@ -59,6 +58,7 @@ def create(
|
|||||||
user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Create a manually-logged worklog entry."""
|
"""Create a manually-logged worklog entry."""
|
||||||
|
with UnitOfWork(db) as uow:
|
||||||
wl = worklog_service.create_worklog(
|
wl = worklog_service.create_worklog(
|
||||||
db,
|
db,
|
||||||
entity_type=body.entity_type,
|
entity_type=body.entity_type,
|
||||||
@@ -70,6 +70,8 @@ def create(
|
|||||||
duration_seconds=body.duration_seconds,
|
duration_seconds=body.duration_seconds,
|
||||||
description=body.description,
|
description=body.description,
|
||||||
)
|
)
|
||||||
|
uow.commit()
|
||||||
|
db.refresh(wl)
|
||||||
return wl
|
return wl
|
||||||
|
|
||||||
|
|
||||||
@@ -97,10 +99,7 @@ def get_one(
|
|||||||
_user: User = Depends(get_current_user),
|
_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get a single worklog by ID."""
|
"""Get a single worklog by ID."""
|
||||||
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
|
return worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||||
if not wl:
|
|
||||||
raise EntityNotFoundError("Worklog", str(worklog_id))
|
|
||||||
return wl
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{worklog_id}/verify")
|
@router.get("/{worklog_id}/verify")
|
||||||
@@ -110,9 +109,7 @@ def verify_integrity(
|
|||||||
_user: User = Depends(get_current_user),
|
_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Check whether a worklog's integrity hash is still valid."""
|
"""Check whether a worklog's integrity hash is still valid."""
|
||||||
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
|
wl = worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||||
if not wl:
|
|
||||||
raise EntityNotFoundError("Worklog", str(worklog_id))
|
|
||||||
return {
|
return {
|
||||||
"worklog_id": str(wl.id),
|
"worklog_id": str(wl.id),
|
||||||
"integrity_valid": worklog_service.verify_worklog_integrity(wl),
|
"integrity_valid": worklog_service.verify_worklog_integrity(wl),
|
||||||
|
|||||||
@@ -0,0 +1,160 @@
|
|||||||
|
"""Advanced metrics service — coverage by tactic, never-tested, avg validation time, detection trend."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from sqlalchemy import case, func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.models.test import Test
|
||||||
|
from app.models.enums import TestResult
|
||||||
|
|
||||||
|
|
||||||
|
def get_coverage_by_tactic(db: Session) -> list[dict]:
|
||||||
|
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
|
||||||
|
results = (
|
||||||
|
db.query(
|
||||||
|
Technique.tactic,
|
||||||
|
func.count(Technique.id).label("total"),
|
||||||
|
func.sum(
|
||||||
|
case((Technique.status_global == "validated", 1), else_=0)
|
||||||
|
).label("validated"),
|
||||||
|
func.sum(
|
||||||
|
case((Technique.status_global == "partial", 1), else_=0)
|
||||||
|
).label("partial"),
|
||||||
|
func.sum(
|
||||||
|
case((Technique.status_global == "not_covered", 1), else_=0)
|
||||||
|
).label("not_covered"),
|
||||||
|
func.sum(
|
||||||
|
case((Technique.status_global == "in_progress", 1), else_=0)
|
||||||
|
).label("in_progress"),
|
||||||
|
)
|
||||||
|
.group_by(Technique.tactic)
|
||||||
|
.order_by(Technique.tactic)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"tactic": r[0] or "Unknown",
|
||||||
|
"total": r[1],
|
||||||
|
"validated": int(r[2]),
|
||||||
|
"partial": int(r[3]),
|
||||||
|
"not_covered": int(r[4]),
|
||||||
|
"in_progress": int(r[5]),
|
||||||
|
"coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0,
|
||||||
|
}
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_never_tested_techniques(db: Session) -> list[dict]:
|
||||||
|
"""Techniques that have never had a test created."""
|
||||||
|
tested_technique_ids = db.query(Test.technique_id).distinct().subquery()
|
||||||
|
techniques = (
|
||||||
|
db.query(Technique)
|
||||||
|
.filter(~Technique.id.in_(db.query(tested_technique_ids)))
|
||||||
|
.order_by(Technique.mitre_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"mitre_id": t.mitre_id,
|
||||||
|
"name": t.name,
|
||||||
|
"tactic": t.tactic,
|
||||||
|
"is_subtechnique": t.is_subtechnique,
|
||||||
|
}
|
||||||
|
for t in techniques
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_avg_validation_time(db: Session) -> dict:
|
||||||
|
"""Average time from test creation to validation, computed from validated tests.
|
||||||
|
|
||||||
|
Returns overall average and per-phase averages where data is available.
|
||||||
|
"""
|
||||||
|
validated_tests = (
|
||||||
|
db.query(Test)
|
||||||
|
.filter(Test.state == "validated")
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not validated_tests:
|
||||||
|
return {
|
||||||
|
"total_validated": 0,
|
||||||
|
"avg_total_hours": 0,
|
||||||
|
"avg_red_phase_hours": 0,
|
||||||
|
"avg_blue_phase_hours": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
total_durations = []
|
||||||
|
red_durations = []
|
||||||
|
blue_durations = []
|
||||||
|
|
||||||
|
for test in validated_tests:
|
||||||
|
if test.created_at and test.red_validated_at:
|
||||||
|
total_seconds = (test.red_validated_at - test.created_at).total_seconds()
|
||||||
|
total_durations.append(total_seconds)
|
||||||
|
|
||||||
|
if test.red_started_at and test.blue_started_at:
|
||||||
|
red_sec = (test.blue_started_at - test.red_started_at).total_seconds()
|
||||||
|
red_paused = test.red_paused_seconds or 0
|
||||||
|
red_durations.append(max(red_sec - red_paused, 0))
|
||||||
|
|
||||||
|
if test.blue_started_at and test.blue_validated_at:
|
||||||
|
blue_sec = (test.blue_validated_at - test.blue_started_at).total_seconds()
|
||||||
|
blue_paused = test.blue_paused_seconds or 0
|
||||||
|
blue_durations.append(max(blue_sec - blue_paused, 0))
|
||||||
|
|
||||||
|
def avg_hours(durations: list[float]) -> float:
|
||||||
|
if not durations:
|
||||||
|
return 0
|
||||||
|
return round(sum(durations) / len(durations) / 3600, 2)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_validated": len(validated_tests),
|
||||||
|
"avg_total_hours": avg_hours(total_durations),
|
||||||
|
"avg_red_phase_hours": avg_hours(red_durations),
|
||||||
|
"avg_blue_phase_hours": avg_hours(blue_durations),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_detection_rate_trend(db: Session) -> list[dict]:
|
||||||
|
"""Monthly detection rate trend for the last 12 months."""
|
||||||
|
now = datetime.utcnow()
|
||||||
|
months = []
|
||||||
|
|
||||||
|
for i in range(11, -1, -1):
|
||||||
|
month_start = datetime(now.year, now.month, 1) - timedelta(days=i * 30)
|
||||||
|
month_end = month_start + timedelta(days=30)
|
||||||
|
|
||||||
|
validated = (
|
||||||
|
db.query(func.count(Test.id))
|
||||||
|
.filter(
|
||||||
|
Test.state == "validated",
|
||||||
|
Test.created_at >= month_start,
|
||||||
|
Test.created_at < month_end,
|
||||||
|
)
|
||||||
|
.scalar() or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
detected = (
|
||||||
|
db.query(func.count(Test.id))
|
||||||
|
.filter(
|
||||||
|
Test.state == "validated",
|
||||||
|
Test.detection_result == TestResult.detected,
|
||||||
|
Test.created_at >= month_start,
|
||||||
|
Test.created_at < month_end,
|
||||||
|
)
|
||||||
|
.scalar() or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
months.append({
|
||||||
|
"month": month_start.strftime("%Y-%m"),
|
||||||
|
"validated": validated,
|
||||||
|
"detected": detected,
|
||||||
|
"detection_rate": round((detected / validated) * 100, 1) if validated > 0 else 0,
|
||||||
|
})
|
||||||
|
|
||||||
|
return months
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
"""Analytics service — flat JSON optimized for PowerBI / BI tools."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy import func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.models.coverage_snapshot import CoverageSnapshot
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.models.test import Test
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
|
||||||
|
def get_coverage_analytics(db: Session) -> list[dict]:
|
||||||
|
"""Coverage per technique — flat format for BI dashboards."""
|
||||||
|
techniques = db.query(Technique).all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"mitre_id": t.mitre_id,
|
||||||
|
"name": t.name,
|
||||||
|
"tactic": t.tactic,
|
||||||
|
"status": t.status_global.value if t.status_global else "not_evaluated",
|
||||||
|
"is_subtechnique": t.is_subtechnique,
|
||||||
|
"test_count": len(t.tests) if t.tests else 0,
|
||||||
|
"review_required": t.review_required,
|
||||||
|
"last_review_date": (
|
||||||
|
t.last_review_date.isoformat() if t.last_review_date else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
for t in techniques
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_tests_analytics(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
date_from: str | None = None,
|
||||||
|
date_to: str | None = None,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""All tests with timestamps — flat format for BI dashboards."""
|
||||||
|
query = db.query(Test)
|
||||||
|
if date_from:
|
||||||
|
query = query.filter(Test.created_at >= date_from)
|
||||||
|
if date_to:
|
||||||
|
query = query.filter(Test.created_at <= date_to)
|
||||||
|
tests = query.all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": str(t.id),
|
||||||
|
"technique_id": str(t.technique_id),
|
||||||
|
"name": t.name,
|
||||||
|
"state": t.state.value if t.state else None,
|
||||||
|
"result": t.result.value if t.result else None,
|
||||||
|
"detection_result": (
|
||||||
|
t.detection_result.value if t.detection_result else None
|
||||||
|
),
|
||||||
|
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||||
|
"execution_date": (
|
||||||
|
t.execution_date.isoformat() if t.execution_date else None
|
||||||
|
),
|
||||||
|
"platform": t.platform,
|
||||||
|
"tool_used": t.tool_used,
|
||||||
|
"attack_success": t.attack_success,
|
||||||
|
"remediation_status": t.remediation_status,
|
||||||
|
}
|
||||||
|
for t in tests
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_trends_analytics(db: Session) -> list[dict]:
|
||||||
|
"""Historical coverage snapshots for trend visualization."""
|
||||||
|
snapshots = (
|
||||||
|
db.query(CoverageSnapshot)
|
||||||
|
.order_by(CoverageSnapshot.created_at)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"date": s.created_at.isoformat() if s.created_at else None,
|
||||||
|
"name": s.name,
|
||||||
|
"total_techniques": s.total_techniques,
|
||||||
|
"validated_count": s.validated_count,
|
||||||
|
"partial_count": s.partial_count,
|
||||||
|
"not_covered_count": s.not_covered_count,
|
||||||
|
"organization_score": s.organization_score,
|
||||||
|
}
|
||||||
|
for s in snapshots
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_operators_analytics(db: Session) -> list[dict]:
|
||||||
|
"""Per-operator metrics — for workload management dashboards."""
|
||||||
|
results = (
|
||||||
|
db.query(
|
||||||
|
User.username,
|
||||||
|
User.role,
|
||||||
|
func.count(Test.id).label("test_count"),
|
||||||
|
)
|
||||||
|
.outerjoin(Test, Test.created_by == User.id)
|
||||||
|
.group_by(User.id, User.username, User.role)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
{"username": r[0], "role": r[1], "test_count": r[2]}
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
"""Audit log query service — framework-agnostic query logic for audit logs.
|
||||||
|
|
||||||
|
Provides paginated logs and distinct action/entity-type lists.
|
||||||
|
No FastAPI imports.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from app.models.audit import AuditLog
|
||||||
|
|
||||||
|
|
||||||
|
def list_logs(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
|
action: str | None = None,
|
||||||
|
entity_type: str | None = None,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> dict:
|
||||||
|
"""Return paginated audit logs with optional filters.
|
||||||
|
|
||||||
|
Returns a dict with keys: items, total, offset, limit.
|
||||||
|
Each item is a dict with: id, user_id, username, action, entity_type,
|
||||||
|
entity_id, timestamp, details.
|
||||||
|
"""
|
||||||
|
query = db.query(AuditLog).options(joinedload(AuditLog.user))
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
query = query.filter(AuditLog.user_id == user_id)
|
||||||
|
if action:
|
||||||
|
query = query.filter(AuditLog.action == action)
|
||||||
|
if entity_type:
|
||||||
|
query = query.filter(AuditLog.entity_type == entity_type)
|
||||||
|
if start_date:
|
||||||
|
query = query.filter(AuditLog.timestamp >= start_date)
|
||||||
|
if end_date:
|
||||||
|
query = query.filter(AuditLog.timestamp <= end_date)
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
|
||||||
|
logs = (
|
||||||
|
query
|
||||||
|
.order_by(AuditLog.timestamp.desc())
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
items = [
|
||||||
|
{
|
||||||
|
"id": log.id,
|
||||||
|
"user_id": log.user_id,
|
||||||
|
"username": log.user.username if log.user else None,
|
||||||
|
"action": log.action,
|
||||||
|
"entity_type": log.entity_type,
|
||||||
|
"entity_id": log.entity_id,
|
||||||
|
"timestamp": log.timestamp,
|
||||||
|
"details": log.details,
|
||||||
|
}
|
||||||
|
for log in logs
|
||||||
|
]
|
||||||
|
|
||||||
|
return {"items": items, "total": total, "offset": offset, "limit": limit}
|
||||||
|
|
||||||
|
|
||||||
|
def list_distinct_actions(db: Session) -> list[str]:
|
||||||
|
"""Return a list of distinct action types in the audit log."""
|
||||||
|
actions = (
|
||||||
|
db.query(AuditLog.action)
|
||||||
|
.distinct()
|
||||||
|
.order_by(AuditLog.action)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [a[0] for a in actions]
|
||||||
|
|
||||||
|
|
||||||
|
def list_distinct_entity_types(db: Session) -> list[str]:
|
||||||
|
"""Return a list of distinct entity types in the audit log."""
|
||||||
|
types = (
|
||||||
|
db.query(AuditLog.entity_type)
|
||||||
|
.filter(AuditLog.entity_type.isnot(None))
|
||||||
|
.distinct()
|
||||||
|
.order_by(AuditLog.entity_type)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [t[0] for t in types]
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
"""Authentication service — credential validation and password management."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth import hash_password, verify_password
|
||||||
|
from app.domain.errors import BusinessRuleViolation, PermissionViolation
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
_DUMMY_HASH = "$2b$12$LJ3m4ys3Lg3dMO/NpNmOaeVwFpWJMxlB2FLmEAo9fZr.S8H1vC4Wy"
|
||||||
|
|
||||||
|
|
||||||
|
def authenticate_user(db: Session, *, username: str, password: str) -> User:
|
||||||
|
"""Validate credentials and return the User.
|
||||||
|
|
||||||
|
Raises BusinessRuleViolation for invalid credentials.
|
||||||
|
Raises PermissionViolation for disabled account.
|
||||||
|
Uses constant-time comparison to prevent timing attacks.
|
||||||
|
"""
|
||||||
|
user = db.query(User).filter(User.username == username).first()
|
||||||
|
hashed = user.hashed_password if user else _DUMMY_HASH
|
||||||
|
password_valid = verify_password(password, hashed)
|
||||||
|
|
||||||
|
if user is None or not password_valid:
|
||||||
|
raise BusinessRuleViolation("Incorrect username or password")
|
||||||
|
if not user.is_active:
|
||||||
|
raise PermissionViolation("Account is disabled. Contact an administrator.")
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def change_password(
|
||||||
|
db: Session,
|
||||||
|
user: User,
|
||||||
|
*,
|
||||||
|
current_password: str,
|
||||||
|
new_password: str,
|
||||||
|
) -> None:
|
||||||
|
"""Change a user's password. Does NOT commit.
|
||||||
|
|
||||||
|
Raises BusinessRuleViolation if current password is wrong.
|
||||||
|
"""
|
||||||
|
if not verify_password(current_password, user.hashed_password):
|
||||||
|
raise BusinessRuleViolation("Current password is incorrect")
|
||||||
|
user.hashed_password = hash_password(new_password)
|
||||||
|
user.must_change_password = False
|
||||||
@@ -0,0 +1,82 @@
|
|||||||
|
"""D3FEND query service — framework-agnostic queries for defensive techniques."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
|
from app.models.defensive_technique import DefensiveTechnique
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.services.d3fend_import_service import get_defenses_for_technique
|
||||||
|
from app.utils import escape_like
|
||||||
|
|
||||||
|
|
||||||
|
def list_defensive_techniques(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
tactic: Optional[str] = None,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> dict:
|
||||||
|
"""List D3FEND defensive techniques with optional filters."""
|
||||||
|
query = db.query(DefensiveTechnique)
|
||||||
|
|
||||||
|
if tactic:
|
||||||
|
query = query.filter(DefensiveTechnique.tactic == tactic)
|
||||||
|
|
||||||
|
if search:
|
||||||
|
pattern = f"%{escape_like(search)}%"
|
||||||
|
query = query.filter(
|
||||||
|
DefensiveTechnique.name.ilike(pattern)
|
||||||
|
| DefensiveTechnique.d3fend_id.ilike(pattern)
|
||||||
|
)
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"offset": offset,
|
||||||
|
"limit": limit,
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"id": str(dt.id),
|
||||||
|
"d3fend_id": dt.d3fend_id,
|
||||||
|
"name": dt.name,
|
||||||
|
"description": dt.description,
|
||||||
|
"tactic": dt.tactic,
|
||||||
|
"d3fend_url": dt.d3fend_url,
|
||||||
|
}
|
||||||
|
for dt in items
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def list_d3fend_tactics(db: Session) -> list[dict]:
|
||||||
|
"""Return a list of all D3FEND tactics with counts."""
|
||||||
|
rows = (
|
||||||
|
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
|
||||||
|
.group_by(DefensiveTechnique.tactic)
|
||||||
|
.order_by(DefensiveTechnique.tactic)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def get_defenses_for_attack_technique(db: Session, mitre_id: str) -> dict:
|
||||||
|
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
||||||
|
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
|
||||||
|
if technique is None:
|
||||||
|
raise EntityNotFoundError("Technique", mitre_id)
|
||||||
|
|
||||||
|
defenses = get_defenses_for_technique(db, technique.id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"mitre_id": mitre_id,
|
||||||
|
"technique_name": technique.name,
|
||||||
|
"defenses": defenses,
|
||||||
|
"total": len(defenses),
|
||||||
|
}
|
||||||
@@ -0,0 +1,197 @@
|
|||||||
|
"""Data source management service — framework-agnostic query and sync logic.
|
||||||
|
|
||||||
|
Provides list, update, sync, and stats. Sync operations commit internally
|
||||||
|
since they are long-running and self-contained.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import BusinessRuleViolation, EntityNotFoundError
|
||||||
|
from app.domain.ports.import_service import get_import_handler
|
||||||
|
from app.models.data_source import DataSource
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def list_sources(db: Session) -> list[dict]:
|
||||||
|
"""Return all registered data sources as a list of dicts."""
|
||||||
|
sources = db.query(DataSource).order_by(DataSource.name).all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": str(s.id),
|
||||||
|
"name": s.name,
|
||||||
|
"display_name": s.display_name,
|
||||||
|
"type": s.type,
|
||||||
|
"url": s.url,
|
||||||
|
"description": s.description,
|
||||||
|
"is_enabled": s.is_enabled,
|
||||||
|
"last_sync_at": s.last_sync_at.isoformat() if s.last_sync_at else None,
|
||||||
|
"last_sync_status": s.last_sync_status,
|
||||||
|
"last_sync_stats": s.last_sync_stats,
|
||||||
|
"sync_frequency": s.sync_frequency,
|
||||||
|
"config": s.config,
|
||||||
|
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||||
|
}
|
||||||
|
for s in sources
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def update_source(db: Session, source_id: str, **fields: object) -> None:
|
||||||
|
"""Update a data source's fields (is_enabled, sync_frequency, config).
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if source does not exist.
|
||||||
|
Does not commit; the router handles that.
|
||||||
|
"""
|
||||||
|
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||||
|
if not ds:
|
||||||
|
raise EntityNotFoundError("Data source", source_id)
|
||||||
|
|
||||||
|
if "is_enabled" in fields:
|
||||||
|
ds.is_enabled = fields["is_enabled"]
|
||||||
|
if "sync_frequency" in fields:
|
||||||
|
ds.sync_frequency = fields["sync_frequency"]
|
||||||
|
if "config" in fields:
|
||||||
|
ds.config = fields["config"]
|
||||||
|
|
||||||
|
|
||||||
|
def sync_source(db: Session, source_id: str) -> dict:
|
||||||
|
"""Trigger sync for a specific data source.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if source does not exist.
|
||||||
|
Raises BusinessRuleViolation if no sync handler is available.
|
||||||
|
Commits internally (long-running, self-contained operation).
|
||||||
|
Returns dict with message, source, stats.
|
||||||
|
"""
|
||||||
|
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||||
|
if not ds:
|
||||||
|
raise EntityNotFoundError("Data source", source_id)
|
||||||
|
|
||||||
|
handler = get_import_handler(ds.name)
|
||||||
|
if handler is None:
|
||||||
|
raise BusinessRuleViolation(f"No sync handler available for '{ds.name}'")
|
||||||
|
|
||||||
|
ds.last_sync_status = "in_progress"
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
summary = handler(db)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||||
|
ds.last_sync_status = "error"
|
||||||
|
ds.last_sync_at = datetime.utcnow()
|
||||||
|
ds.last_sync_stats = {"error": str(exc)}
|
||||||
|
db.commit()
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Sync failed for '{ds.display_name}'. Check server logs for details."
|
||||||
|
)
|
||||||
|
|
||||||
|
ds.last_sync_at = datetime.utcnow()
|
||||||
|
ds.last_sync_status = "success"
|
||||||
|
ds.last_sync_stats = summary
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": f"Sync complete for {ds.display_name}",
|
||||||
|
"source": ds.name,
|
||||||
|
"stats": summary,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def sync_all_sources(db: Session) -> list[dict]:
|
||||||
|
"""Trigger sync for all enabled data sources (sequentially).
|
||||||
|
|
||||||
|
Commits internally (long-running, self-contained operation).
|
||||||
|
Returns list of result dicts with source, status, stats/detail.
|
||||||
|
"""
|
||||||
|
enabled_sources = (
|
||||||
|
db.query(DataSource)
|
||||||
|
.filter(DataSource.is_enabled == True)
|
||||||
|
.order_by(DataSource.name)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for ds in enabled_sources:
|
||||||
|
handler = get_import_handler(ds.name)
|
||||||
|
if handler is None:
|
||||||
|
results.append({
|
||||||
|
"source": ds.name,
|
||||||
|
"status": "skipped",
|
||||||
|
"detail": "No sync handler available",
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
ds.last_sync_status = "in_progress"
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
try:
|
||||||
|
summary = handler(db)
|
||||||
|
ds.last_sync_at = datetime.utcnow()
|
||||||
|
ds.last_sync_status = "success"
|
||||||
|
ds.last_sync_stats = summary
|
||||||
|
db.commit()
|
||||||
|
results.append({
|
||||||
|
"source": ds.name,
|
||||||
|
"status": "success",
|
||||||
|
"stats": summary,
|
||||||
|
})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Sync failed for %s: %s", ds.name, exc, exc_info=True)
|
||||||
|
ds.last_sync_status = "error"
|
||||||
|
ds.last_sync_at = datetime.utcnow()
|
||||||
|
ds.last_sync_stats = {"error": str(exc)}
|
||||||
|
db.commit()
|
||||||
|
results.append({
|
||||||
|
"source": ds.name,
|
||||||
|
"status": "error",
|
||||||
|
"detail": "Sync failed. Check server logs for details.",
|
||||||
|
})
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def get_source_stats(db: Session, source_id: str) -> dict:
|
||||||
|
"""Return detailed statistics for a data source.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if source does not exist.
|
||||||
|
"""
|
||||||
|
ds = db.query(DataSource).filter(DataSource.id == source_id).first()
|
||||||
|
if not ds:
|
||||||
|
raise EntityNotFoundError("Data source", source_id)
|
||||||
|
|
||||||
|
from app.models.test_template import TestTemplate
|
||||||
|
from app.models.detection_rule import DetectionRule
|
||||||
|
|
||||||
|
template_count = 0
|
||||||
|
rule_count = 0
|
||||||
|
|
||||||
|
if ds.type == "attack_procedure":
|
||||||
|
template_count = (
|
||||||
|
db.query(TestTemplate)
|
||||||
|
.filter(TestTemplate.source == ds.name)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
elif ds.type == "detection_rule":
|
||||||
|
rule_count = (
|
||||||
|
db.query(DetectionRule)
|
||||||
|
.filter(DetectionRule.source == ds.name)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(ds.id),
|
||||||
|
"name": ds.name,
|
||||||
|
"display_name": ds.display_name,
|
||||||
|
"type": ds.type,
|
||||||
|
"is_enabled": ds.is_enabled,
|
||||||
|
"last_sync_at": ds.last_sync_at.isoformat() if ds.last_sync_at else None,
|
||||||
|
"last_sync_status": ds.last_sync_status,
|
||||||
|
"last_sync_stats": ds.last_sync_stats,
|
||||||
|
"total_templates": template_count,
|
||||||
|
"total_rules": rule_count,
|
||||||
|
}
|
||||||
@@ -3,12 +3,17 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
from app.domain.exceptions import InvalidOperationError
|
from app.domain.exceptions import InvalidOperationError
|
||||||
from app.models.jira_link import JiraLink
|
from app.models.campaign import Campaign
|
||||||
|
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.models.test import Test
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -103,3 +108,128 @@ def _build_sync_comment(data: dict) -> str:
|
|||||||
lines.append(f"*{key}:* {value}")
|
lines.append(f"*{key}:* {value}")
|
||||||
lines.append(f"\n_Synced at {datetime.utcnow().isoformat()}_")
|
lines.append(f"\n_Synced at {datetime.utcnow().isoformat()}_")
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Link CRUD ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def create_link(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
entity_type: JiraLinkEntityType,
|
||||||
|
entity_id: UUID,
|
||||||
|
jira_issue_key: str,
|
||||||
|
sync_direction: JiraSyncDirection,
|
||||||
|
created_by: UUID,
|
||||||
|
) -> JiraLink:
|
||||||
|
"""Create a Jira link and optionally pull initial data from Jira."""
|
||||||
|
link = JiraLink(
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_id=entity_id,
|
||||||
|
jira_issue_key=jira_issue_key,
|
||||||
|
sync_direction=sync_direction,
|
||||||
|
created_by=created_by,
|
||||||
|
)
|
||||||
|
db.add(link)
|
||||||
|
db.flush()
|
||||||
|
|
||||||
|
if settings.JIRA_ENABLED:
|
||||||
|
try:
|
||||||
|
sync_jira_to_aegis(db, link)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Initial Jira sync failed for %s: %s", jira_issue_key, e)
|
||||||
|
|
||||||
|
return link
|
||||||
|
|
||||||
|
|
||||||
|
def list_links(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
entity_type: Optional[JiraLinkEntityType] = None,
|
||||||
|
entity_id: Optional[UUID] = None,
|
||||||
|
) -> list[JiraLink]:
|
||||||
|
"""List Jira links with optional filters."""
|
||||||
|
query = db.query(JiraLink)
|
||||||
|
if entity_type:
|
||||||
|
query = query.filter(JiraLink.entity_type == entity_type)
|
||||||
|
if entity_id:
|
||||||
|
query = query.filter(JiraLink.entity_id == entity_id)
|
||||||
|
return query.order_by(JiraLink.created_at.desc()).all()
|
||||||
|
|
||||||
|
|
||||||
|
def get_link_or_raise(db: Session, link_id: UUID) -> JiraLink:
|
||||||
|
"""Get a Jira link by ID or raise EntityNotFoundError."""
|
||||||
|
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
|
||||||
|
if not link:
|
||||||
|
raise EntityNotFoundError("JiraLink", str(link_id))
|
||||||
|
return link
|
||||||
|
|
||||||
|
|
||||||
|
def delete_link(db: Session, link_id: UUID) -> JiraLink:
|
||||||
|
"""Delete a Jira link. Returns the deleted link (for audit)."""
|
||||||
|
link = get_link_or_raise(db, link_id)
|
||||||
|
db.delete(link)
|
||||||
|
return link
|
||||||
|
|
||||||
|
|
||||||
|
def build_issue_data(db: Session, entity_type: JiraLinkEntityType, entity_id: UUID) -> tuple[str, str]:
|
||||||
|
"""Build Jira issue summary and description from an Aegis entity."""
|
||||||
|
if entity_type == JiraLinkEntityType.test:
|
||||||
|
entity = db.query(Test).filter(Test.id == entity_id).first()
|
||||||
|
if not entity:
|
||||||
|
raise EntityNotFoundError("Test", str(entity_id))
|
||||||
|
return (
|
||||||
|
f"[Aegis Test] {entity.name}",
|
||||||
|
f"Test: {entity.name}\n"
|
||||||
|
f"State: {entity.state.value if entity.state else 'draft'}\n"
|
||||||
|
f"Description: {entity.description or 'N/A'}",
|
||||||
|
)
|
||||||
|
elif entity_type == JiraLinkEntityType.campaign:
|
||||||
|
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
|
||||||
|
if not entity:
|
||||||
|
raise EntityNotFoundError("Campaign", str(entity_id))
|
||||||
|
return (
|
||||||
|
f"[Aegis Campaign] {entity.name}",
|
||||||
|
f"Campaign: {entity.name}\n"
|
||||||
|
f"Type: {entity.type}\nStatus: {entity.status}\n"
|
||||||
|
f"Description: {entity.description or 'N/A'}",
|
||||||
|
)
|
||||||
|
elif entity_type == JiraLinkEntityType.technique:
|
||||||
|
entity = db.query(Technique).filter(Technique.id == entity_id).first()
|
||||||
|
if not entity:
|
||||||
|
raise EntityNotFoundError("Technique", str(entity_id))
|
||||||
|
return (
|
||||||
|
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
|
||||||
|
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
|
||||||
|
f"Tactic: {entity.tactic or 'N/A'}\n"
|
||||||
|
f"Description: {entity.description or 'N/A'}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
|
||||||
|
|
||||||
|
|
||||||
|
def create_issue_and_link(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
entity_type: JiraLinkEntityType,
|
||||||
|
entity_id: UUID,
|
||||||
|
created_by: UUID,
|
||||||
|
) -> dict:
|
||||||
|
"""Create a Jira issue from an Aegis entity and link them."""
|
||||||
|
summary, description = build_issue_data(db, entity_type, entity_id)
|
||||||
|
result = create_jira_issue(
|
||||||
|
project_key=settings.JIRA_DEFAULT_PROJECT,
|
||||||
|
summary=summary,
|
||||||
|
description=description,
|
||||||
|
labels=["aegis", entity_type.value],
|
||||||
|
)
|
||||||
|
link = JiraLink(
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_id=entity_id,
|
||||||
|
jira_issue_key=result["issue_key"],
|
||||||
|
jira_issue_id=result["issue_id"],
|
||||||
|
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
|
||||||
|
created_by=created_by,
|
||||||
|
)
|
||||||
|
db.add(link)
|
||||||
|
return {"issue_key": result["issue_key"], "link_id": str(link.id)}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from datetime import datetime, timedelta
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
from app.models.notification import Notification
|
from app.models.notification import Notification
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
@@ -22,6 +23,71 @@ from app.models.user import User
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def list_notifications(
|
||||||
|
db: Session,
|
||||||
|
user_id: uuid.UUID,
|
||||||
|
*,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 20,
|
||||||
|
) -> list[Notification]:
|
||||||
|
"""Return paginated notifications for a user, newest first."""
|
||||||
|
return (
|
||||||
|
db.query(Notification)
|
||||||
|
.filter(Notification.user_id == user_id)
|
||||||
|
.order_by(Notification.created_at.desc())
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_notification_or_raise(
|
||||||
|
db: Session,
|
||||||
|
notification_id: uuid.UUID,
|
||||||
|
user_id: uuid.UUID,
|
||||||
|
) -> Notification:
|
||||||
|
"""Fetch a notification by ID and user, or raise EntityNotFoundError."""
|
||||||
|
notif = (
|
||||||
|
db.query(Notification)
|
||||||
|
.filter(
|
||||||
|
Notification.id == notification_id,
|
||||||
|
Notification.user_id == user_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if notif is None:
|
||||||
|
raise EntityNotFoundError("Notification", str(notification_id))
|
||||||
|
return notif
|
||||||
|
|
||||||
|
|
||||||
|
def notify_role(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
role: str,
|
||||||
|
type: str,
|
||||||
|
title: str,
|
||||||
|
message: str,
|
||||||
|
entity_type: str,
|
||||||
|
entity_id: uuid.UUID,
|
||||||
|
) -> None:
|
||||||
|
"""Send notifications to all active users with a given role."""
|
||||||
|
users = (
|
||||||
|
db.query(User)
|
||||||
|
.filter(User.role == role, User.is_active == True) # noqa: E712
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
for user in users:
|
||||||
|
create_notification(
|
||||||
|
db,
|
||||||
|
user_id=user.id,
|
||||||
|
type=type,
|
||||||
|
title=title,
|
||||||
|
message=message,
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_id=entity_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_notification(
|
def create_notification(
|
||||||
db: Session,
|
db: Session,
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
@@ -45,17 +111,13 @@ def create_notification(
|
|||||||
return notif
|
return notif
|
||||||
|
|
||||||
|
|
||||||
def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) -> bool:
|
def mark_as_read(
|
||||||
"""Mark a single notification as read. Returns True if updated."""
|
db: Session, notification_id: uuid.UUID, user_id: uuid.UUID
|
||||||
notif = (
|
) -> Notification:
|
||||||
db.query(Notification)
|
"""Mark a single notification as read. Returns the notification. Raises EntityNotFoundError if not found."""
|
||||||
.filter(Notification.id == notification_id, Notification.user_id == user_id)
|
notif = get_notification_or_raise(db, notification_id, user_id)
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if notif is None:
|
|
||||||
return False
|
|
||||||
notif.read = True
|
notif.read = True
|
||||||
return True
|
return notif
|
||||||
|
|
||||||
|
|
||||||
def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
|
def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
|
||||||
|
|||||||
@@ -7,11 +7,15 @@ Designed to run as a weekly background job. Respects NVD rate limits
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
from app.models.osint_item import OsintItem
|
from app.models.osint_item import OsintItem
|
||||||
from app.models.technique import Technique
|
from app.models.technique import Technique
|
||||||
|
|
||||||
@@ -177,15 +181,97 @@ def get_osint_items_for_technique(
|
|||||||
|
|
||||||
|
|
||||||
def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None:
|
def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None:
|
||||||
"""Mark an OSINT item as reviewed."""
|
"""Mark an OSINT item as reviewed. Does not commit; caller uses UnitOfWork."""
|
||||||
item = db.query(OsintItem).filter(OsintItem.id == item_id).first()
|
item = db.query(OsintItem).filter(OsintItem.id == item_id).first()
|
||||||
if item:
|
if item:
|
||||||
item.reviewed = True
|
item.reviewed = True
|
||||||
db.commit()
|
|
||||||
db.refresh(item)
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
|
||||||
def get_unreviewed_count(db: Session) -> int:
|
def get_unreviewed_count(db: Session) -> int:
|
||||||
"""Return the total number of unreviewed OSINT items."""
|
"""Return the total number of unreviewed OSINT items."""
|
||||||
return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712
|
return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712
|
||||||
|
|
||||||
|
|
||||||
|
def list_osint_items(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
technique_id: Optional[UUID] = None,
|
||||||
|
source_type: Optional[str] = None,
|
||||||
|
reviewed: Optional[bool] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> dict:
|
||||||
|
"""List OSINT items with optional filters and pagination."""
|
||||||
|
query = db.query(OsintItem)
|
||||||
|
if technique_id:
|
||||||
|
query = query.filter(OsintItem.technique_id == technique_id)
|
||||||
|
if source_type:
|
||||||
|
query = query.filter(OsintItem.source_type == source_type)
|
||||||
|
if reviewed is not None:
|
||||||
|
query = query.filter(OsintItem.reviewed == reviewed)
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
items = (
|
||||||
|
query.order_by(OsintItem.discovered_at.desc())
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"id": str(item.id),
|
||||||
|
"technique_id": str(item.technique_id),
|
||||||
|
"source_type": item.source_type,
|
||||||
|
"source_url": item.source_url,
|
||||||
|
"title": item.title,
|
||||||
|
"description": item.description,
|
||||||
|
"severity": item.severity,
|
||||||
|
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
|
||||||
|
"reviewed": item.reviewed,
|
||||||
|
"metadata": item.metadata_,
|
||||||
|
}
|
||||||
|
for item in items
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_osint_summary(db: Session) -> dict:
|
||||||
|
"""Summary statistics for OSINT items."""
|
||||||
|
total = db.query(func.count(OsintItem.id)).scalar() or 0
|
||||||
|
unreviewed = get_unreviewed_count(db)
|
||||||
|
|
||||||
|
by_severity = dict(
|
||||||
|
db.query(OsintItem.severity, func.count(OsintItem.id))
|
||||||
|
.group_by(OsintItem.severity)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
by_type = dict(
|
||||||
|
db.query(OsintItem.source_type, func.count(OsintItem.id))
|
||||||
|
.group_by(OsintItem.source_type)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
techniques_with_items = (
|
||||||
|
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_items": total,
|
||||||
|
"unreviewed": unreviewed,
|
||||||
|
"techniques_with_items": techniques_with_items,
|
||||||
|
"by_severity": by_severity,
|
||||||
|
"by_type": by_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_technique_or_raise(db: Session, technique_id: UUID) -> Technique:
|
||||||
|
"""Get a technique by ID or raise EntityNotFoundError."""
|
||||||
|
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||||
|
if not technique:
|
||||||
|
raise EntityNotFoundError("Technique", str(technique_id))
|
||||||
|
return technique
|
||||||
|
|||||||
@@ -82,9 +82,7 @@ def update_scoring_weights(
|
|||||||
row.weight_freshness = new.freshness
|
row.weight_freshness = new.freshness
|
||||||
row.weight_platform_diversity = new.platform_diversity
|
row.weight_platform_diversity = new.platform_diversity
|
||||||
|
|
||||||
db.commit()
|
# Does not commit; caller (router) uses UnitOfWork.
|
||||||
db.refresh(row)
|
|
||||||
|
|
||||||
return _weights_dict(new)
|
return _weights_dict(new)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from typing import Optional
|
|||||||
from sqlalchemy import case, func
|
from sqlalchemy import case, func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
from app.models.technique import Technique
|
from app.models.technique import Technique
|
||||||
from app.models.test import Test
|
from app.models.test import Test
|
||||||
from app.models.detection_rule import DetectionRule
|
from app.models.detection_rule import DetectionRule
|
||||||
@@ -232,6 +233,29 @@ def bulk_technique_scores(db: Session) -> dict:
|
|||||||
# ── Technique-level scoring (single technique — preserved API) ────────
|
# ── Technique-level scoring (single technique — preserved API) ────────
|
||||||
|
|
||||||
|
|
||||||
|
def score_technique_by_mitre_id(db: Session, mitre_id: str) -> dict:
|
||||||
|
"""Get detailed score with breakdown for a technique by MITRE ID."""
|
||||||
|
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
|
||||||
|
if not technique:
|
||||||
|
raise EntityNotFoundError("Technique", mitre_id)
|
||||||
|
result = calculate_technique_score(technique, db)
|
||||||
|
return {
|
||||||
|
"mitre_id": technique.mitre_id,
|
||||||
|
"name": technique.name,
|
||||||
|
"tactic": technique.tactic,
|
||||||
|
"status_global": technique.status_global.value if technique.status_global else None,
|
||||||
|
**result,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def score_actor_by_id(db: Session, actor_id: str) -> dict:
|
||||||
|
"""Get coverage score for a threat actor by ID."""
|
||||||
|
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||||
|
if not actor:
|
||||||
|
raise EntityNotFoundError("ThreatActor", actor_id)
|
||||||
|
return calculate_actor_coverage_score(actor_id, db)
|
||||||
|
|
||||||
|
|
||||||
def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
||||||
"""Calculate a 0-100 score for a technique with detailed breakdown.
|
"""Calculate a 0-100 score for a technique with detailed breakdown.
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from datetime import datetime
|
|||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
from app.models.technique import Technique
|
from app.models.technique import Technique
|
||||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||||
from app.models.enums import TechniqueStatus
|
from app.models.enums import TechniqueStatus
|
||||||
@@ -25,6 +26,101 @@ from app.services.scoring_service import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Serialization and queries
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
|
||||||
|
"""Lightweight serialization for list views."""
|
||||||
|
return {
|
||||||
|
"id": str(snap.id),
|
||||||
|
"name": snap.name,
|
||||||
|
"organization_score": snap.organization_score,
|
||||||
|
"total_techniques": snap.total_techniques,
|
||||||
|
"validated_count": snap.validated_count,
|
||||||
|
"partial_count": snap.partial_count,
|
||||||
|
"not_covered_count": snap.not_covered_count,
|
||||||
|
"in_progress_count": snap.in_progress_count,
|
||||||
|
"not_evaluated_count": snap.not_evaluated_count,
|
||||||
|
"created_by": str(snap.created_by) if snap.created_by else None,
|
||||||
|
"created_at": snap.created_at.isoformat() if snap.created_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
|
||||||
|
"""Full serialization including technique states."""
|
||||||
|
base = serialize_snapshot_summary(snap)
|
||||||
|
|
||||||
|
technique_states = (
|
||||||
|
db.query(SnapshotTechniqueState)
|
||||||
|
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
|
||||||
|
.order_by(SnapshotTechniqueState.mitre_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
base["technique_states"] = [
|
||||||
|
{
|
||||||
|
"mitre_id": s.mitre_id,
|
||||||
|
"technique_id": str(s.technique_id),
|
||||||
|
"status": s.status,
|
||||||
|
"score": s.score,
|
||||||
|
}
|
||||||
|
for s in technique_states
|
||||||
|
]
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
def list_snapshots(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> dict:
|
||||||
|
"""List coverage snapshots ordered by creation date (newest first)."""
|
||||||
|
query = db.query(CoverageSnapshot)
|
||||||
|
total = query.count()
|
||||||
|
|
||||||
|
snapshots = (
|
||||||
|
query
|
||||||
|
.order_by(CoverageSnapshot.created_at.desc())
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"offset": offset,
|
||||||
|
"limit": limit,
|
||||||
|
"items": [serialize_snapshot_summary(s) for s in snapshots],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_snapshot_or_raise(db: Session, snapshot_id: str) -> CoverageSnapshot:
|
||||||
|
"""Fetch snapshot by ID or raise EntityNotFoundError."""
|
||||||
|
try:
|
||||||
|
sid = uuid.UUID(snapshot_id)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise EntityNotFoundError("Snapshot", snapshot_id)
|
||||||
|
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
|
||||||
|
if snapshot is None:
|
||||||
|
raise EntityNotFoundError("Snapshot", snapshot_id)
|
||||||
|
return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def get_snapshot_detail(db: Session, snapshot_id: str) -> dict:
|
||||||
|
"""Get detailed snapshot including per-technique states."""
|
||||||
|
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||||
|
return serialize_snapshot_detail(db, snapshot)
|
||||||
|
|
||||||
|
|
||||||
|
def delete_snapshot(db: Session, snapshot_id: str) -> None:
|
||||||
|
"""Delete a snapshot. Does not commit — caller must commit."""
|
||||||
|
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||||
|
db.delete(snapshot)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Create snapshot
|
# Create snapshot
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -138,7 +234,7 @@ def compare_snapshots(
|
|||||||
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
|
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
|
||||||
|
|
||||||
if not snap_a or not snap_b:
|
if not snap_a or not snap_b:
|
||||||
return {"error": "One or both snapshots not found"}
|
raise EntityNotFoundError("Snapshot", f"{snapshot_a_id} or {snapshot_b_id}")
|
||||||
|
|
||||||
# Build lookup dicts: mitre_id -> {status, score}
|
# Build lookup dicts: mitre_id -> {status, score}
|
||||||
states_a = {
|
states_a = {
|
||||||
|
|||||||
@@ -0,0 +1,48 @@
|
|||||||
|
"""Technique query service — framework-agnostic queries for technique details."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.services.d3fend_import_service import get_defenses_for_technique
|
||||||
|
|
||||||
|
|
||||||
|
def get_technique_detail(db: Session, mitre_id: str) -> dict:
|
||||||
|
"""Fetch full technique details including tests and D3FEND defenses."""
|
||||||
|
technique = (
|
||||||
|
db.query(Technique)
|
||||||
|
.options(joinedload(Technique.tests))
|
||||||
|
.filter(Technique.mitre_id == mitre_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if technique is None:
|
||||||
|
raise EntityNotFoundError("Technique", mitre_id)
|
||||||
|
defenses = get_defenses_for_technique(db, technique.id)
|
||||||
|
return {
|
||||||
|
"id": str(technique.id),
|
||||||
|
"mitre_id": technique.mitre_id,
|
||||||
|
"name": technique.name,
|
||||||
|
"description": technique.description,
|
||||||
|
"tactic": technique.tactic,
|
||||||
|
"platforms": technique.platforms or [],
|
||||||
|
"mitre_version": technique.mitre_version,
|
||||||
|
"mitre_last_modified": technique.mitre_last_modified,
|
||||||
|
"is_subtechnique": technique.is_subtechnique,
|
||||||
|
"parent_mitre_id": technique.parent_mitre_id,
|
||||||
|
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
|
||||||
|
"review_required": technique.review_required,
|
||||||
|
"last_review_date": technique.last_review_date,
|
||||||
|
"tests": [
|
||||||
|
{
|
||||||
|
"id": str(t.id),
|
||||||
|
"name": t.name,
|
||||||
|
"state": t.state.value if t.state else None,
|
||||||
|
"result": t.result.value if t.result else None,
|
||||||
|
"platform": t.platform,
|
||||||
|
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||||
|
}
|
||||||
|
for t in technique.tests
|
||||||
|
],
|
||||||
|
"d3fend_defenses": defenses,
|
||||||
|
}
|
||||||
@@ -0,0 +1,150 @@
|
|||||||
|
"""Test template service — framework-agnostic CRUD and queries."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from sqlalchemy import func, or_
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
|
from app.models.test_template import TestTemplate
|
||||||
|
from app.utils import escape_like
|
||||||
|
|
||||||
|
|
||||||
|
def list_templates(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
source: str | None = None,
|
||||||
|
platform: str | None = None,
|
||||||
|
severity: str | None = None,
|
||||||
|
mitre_technique_id: str | None = None,
|
||||||
|
search: str | None = None,
|
||||||
|
is_active: bool | None = None,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> list:
|
||||||
|
"""Return paginated, filterable list of test templates."""
|
||||||
|
query = db.query(TestTemplate)
|
||||||
|
if is_active is not None:
|
||||||
|
query = query.filter(TestTemplate.is_active == is_active)
|
||||||
|
|
||||||
|
if source:
|
||||||
|
query = query.filter(TestTemplate.source == source)
|
||||||
|
if platform:
|
||||||
|
query = query.filter(TestTemplate.platform.ilike(f"%{escape_like(platform)}%"))
|
||||||
|
if severity:
|
||||||
|
query = query.filter(TestTemplate.severity == severity)
|
||||||
|
if mitre_technique_id:
|
||||||
|
query = query.filter(TestTemplate.mitre_technique_id == mitre_technique_id)
|
||||||
|
if search:
|
||||||
|
pattern = f"%{escape_like(search)}%"
|
||||||
|
query = query.filter(
|
||||||
|
or_(
|
||||||
|
TestTemplate.name.ilike(pattern),
|
||||||
|
TestTemplate.description.ilike(pattern),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
templates = (
|
||||||
|
query
|
||||||
|
.order_by(TestTemplate.mitre_technique_id, TestTemplate.name)
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return templates
|
||||||
|
|
||||||
|
|
||||||
|
def get_template_stats(db: Session) -> dict:
|
||||||
|
"""Return catalog statistics: totals by source, platform, active/inactive."""
|
||||||
|
total = db.query(func.count(TestTemplate.id)).scalar() or 0
|
||||||
|
active = (
|
||||||
|
db.query(func.count(TestTemplate.id))
|
||||||
|
.filter(TestTemplate.is_active == True) # noqa: E712
|
||||||
|
.scalar()
|
||||||
|
) or 0
|
||||||
|
inactive = total - active
|
||||||
|
|
||||||
|
source_rows = (
|
||||||
|
db.query(TestTemplate.source, func.count(TestTemplate.id))
|
||||||
|
.filter(TestTemplate.is_active == True) # noqa: E712
|
||||||
|
.group_by(TestTemplate.source)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
by_source = {source: cnt for source, cnt in source_rows}
|
||||||
|
|
||||||
|
platform_rows = (
|
||||||
|
db.query(TestTemplate.platform, func.count(TestTemplate.id))
|
||||||
|
.filter(TestTemplate.is_active == True) # noqa: E712
|
||||||
|
.group_by(TestTemplate.platform)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
by_platform = {(platform or "unspecified"): cnt for platform, cnt in platform_rows}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"active": active,
|
||||||
|
"inactive": inactive,
|
||||||
|
"by_source": by_source,
|
||||||
|
"by_platform": by_platform,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def bulk_activate(db: Session, *, activate: bool) -> int:
|
||||||
|
"""Set all templates to active or inactive. Returns count of affected. Does NOT commit."""
|
||||||
|
count = (
|
||||||
|
db.query(TestTemplate)
|
||||||
|
.filter(TestTemplate.is_active != activate)
|
||||||
|
.update({TestTemplate.is_active: activate})
|
||||||
|
)
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
def get_templates_by_technique(db: Session, mitre_id: str) -> list:
|
||||||
|
"""Return all active templates mapped to a specific MITRE technique."""
|
||||||
|
return (
|
||||||
|
db.query(TestTemplate)
|
||||||
|
.filter(
|
||||||
|
TestTemplate.mitre_technique_id == mitre_id,
|
||||||
|
TestTemplate.is_active == True, # noqa: E712
|
||||||
|
)
|
||||||
|
.order_by(TestTemplate.name)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_template_or_raise(db: Session, template_id: uuid.UUID) -> TestTemplate:
|
||||||
|
"""Return a template by ID. Raises EntityNotFoundError if not found."""
|
||||||
|
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||||
|
if template is None:
|
||||||
|
raise EntityNotFoundError("Test template", str(template_id))
|
||||||
|
return template
|
||||||
|
|
||||||
|
|
||||||
|
def create_template(db: Session, **fields: object) -> TestTemplate:
|
||||||
|
"""Create a test template from keyword args (e.g. payload.model_dump()). Does NOT commit."""
|
||||||
|
template = TestTemplate(**fields)
|
||||||
|
db.add(template)
|
||||||
|
return template
|
||||||
|
|
||||||
|
|
||||||
|
def update_template(db: Session, template_id: uuid.UUID, **fields: object) -> TestTemplate:
|
||||||
|
"""Update an existing template. Raises EntityNotFoundError if not found. Does NOT commit."""
|
||||||
|
template = get_template_or_raise(db, template_id)
|
||||||
|
for field, value in fields.items():
|
||||||
|
if hasattr(template, field):
|
||||||
|
setattr(template, field, value)
|
||||||
|
return template
|
||||||
|
|
||||||
|
|
||||||
|
def toggle_template_active(db: Session, template_id: uuid.UUID) -> TestTemplate:
|
||||||
|
"""Toggle template active/inactive. Does NOT commit."""
|
||||||
|
template = get_template_or_raise(db, template_id)
|
||||||
|
template.is_active = not template.is_active
|
||||||
|
return template
|
||||||
|
|
||||||
|
|
||||||
|
def soft_delete_template(db: Session, template_id: uuid.UUID) -> None:
|
||||||
|
"""Soft-delete a template by setting is_active=False. Does NOT commit."""
|
||||||
|
template = get_template_or_raise(db, template_id)
|
||||||
|
template.is_active = False
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
"""User management service — framework-agnostic CRUD for users.
|
||||||
|
|
||||||
|
Uses domain exceptions from app.domain.errors. The router handles
|
||||||
|
HTTP concerns, auth, audit logging, and commit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.auth import hash_password
|
||||||
|
from app.domain.errors import BusinessRuleViolation, DuplicateEntityError, EntityNotFoundError
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
VALID_ROLES = {"admin", "red_tech", "blue_tech", "red_lead", "blue_lead", "viewer"}
|
||||||
|
|
||||||
|
|
||||||
|
def list_users(db: Session) -> list[User]:
|
||||||
|
"""Return a list of all users ordered by username."""
|
||||||
|
return db.query(User).order_by(User.username).all()
|
||||||
|
|
||||||
|
|
||||||
|
def create_user(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
username: str,
|
||||||
|
email: str | None,
|
||||||
|
password: str,
|
||||||
|
role: str,
|
||||||
|
) -> User:
|
||||||
|
"""Create a new user.
|
||||||
|
|
||||||
|
Raises DuplicateEntityError if username already exists.
|
||||||
|
Raises BusinessRuleViolation if role is invalid.
|
||||||
|
Does not commit; the router handles that.
|
||||||
|
"""
|
||||||
|
existing = db.query(User).filter(User.username == username).first()
|
||||||
|
if existing:
|
||||||
|
raise DuplicateEntityError("User", "username", username)
|
||||||
|
|
||||||
|
if role not in VALID_ROLES:
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Invalid role '{role}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
username=username,
|
||||||
|
email=email,
|
||||||
|
hashed_password=hash_password(password),
|
||||||
|
role=role,
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_or_raise(db: Session, user_id: uuid.UUID) -> User:
|
||||||
|
"""Return a user by ID or raise EntityNotFoundError."""
|
||||||
|
user = db.query(User).filter(User.id == user_id).first()
|
||||||
|
if user is None:
|
||||||
|
raise EntityNotFoundError("User", str(user_id))
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def update_user(db: Session, user_id: uuid.UUID, **fields: object) -> User:
|
||||||
|
"""Update one or more fields of an existing user.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if user does not exist.
|
||||||
|
Raises BusinessRuleViolation if role is invalid.
|
||||||
|
Handles 'password' by hashing and storing as 'hashed_password'.
|
||||||
|
Does not commit; the router handles that.
|
||||||
|
"""
|
||||||
|
user = get_user_or_raise(db, user_id)
|
||||||
|
update_data = dict(fields)
|
||||||
|
|
||||||
|
if "role" in update_data and update_data["role"] not in VALID_ROLES:
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Invalid role '{update_data['role']}'. Must be one of: {', '.join(sorted(VALID_ROLES))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if "password" in update_data:
|
||||||
|
update_data["hashed_password"] = hash_password(str(update_data.pop("password")))
|
||||||
|
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(user, field, value)
|
||||||
|
|
||||||
|
return user
|
||||||
@@ -8,6 +8,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
from app.models.worklog import Worklog
|
from app.models.worklog import Worklog
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -38,8 +39,15 @@ def create_worklog(
|
|||||||
)
|
)
|
||||||
wl.integrity_hash = _compute_hash(wl)
|
wl.integrity_hash = _compute_hash(wl)
|
||||||
db.add(wl)
|
db.add(wl)
|
||||||
db.commit()
|
# Does not commit; caller (router) uses UnitOfWork.
|
||||||
db.refresh(wl)
|
return wl
|
||||||
|
|
||||||
|
|
||||||
|
def get_worklog_or_raise(db: Session, worklog_id: UUID) -> Worklog:
|
||||||
|
"""Get a worklog by ID or raise EntityNotFoundError."""
|
||||||
|
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
|
||||||
|
if not wl:
|
||||||
|
raise EntityNotFoundError("Worklog", str(worklog_id))
|
||||||
return wl
|
return wl
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,175 @@
|
|||||||
|
"""Tests for CampaignEntity — pure domain logic, no DB."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
if backend_dir not in sys.path:
|
||||||
|
sys.path.insert(0, backend_dir)
|
||||||
|
|
||||||
|
from app.domain.entities.campaign import (
|
||||||
|
CampaignEntity,
|
||||||
|
CampaignStatus,
|
||||||
|
CampaignType,
|
||||||
|
)
|
||||||
|
from app.domain.errors import BusinessRuleViolation, InvalidStateTransition
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _entity(status: str = "draft", test_count: int = 0, **overrides) -> CampaignEntity:
|
||||||
|
defaults = dict(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Campaign",
|
||||||
|
type=CampaignType.custom,
|
||||||
|
status=CampaignStatus(status),
|
||||||
|
description=None,
|
||||||
|
threat_actor_id=None,
|
||||||
|
created_by=None,
|
||||||
|
target_platform=None,
|
||||||
|
tags=[],
|
||||||
|
test_count=test_count,
|
||||||
|
)
|
||||||
|
defaults.update(overrides)
|
||||||
|
return CampaignEntity(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_orm(status: str = "draft", test_count: int = 0, **overrides) -> MagicMock:
|
||||||
|
m = MagicMock()
|
||||||
|
m.id = uuid.uuid4()
|
||||||
|
m.name = "Test Campaign"
|
||||||
|
m.type = "custom"
|
||||||
|
m.status = status
|
||||||
|
m.description = None
|
||||||
|
m.threat_actor_id = None
|
||||||
|
m.created_by = None
|
||||||
|
m.target_platform = None
|
||||||
|
m.tags = []
|
||||||
|
m.campaign_tests = [MagicMock()] * test_count if test_count else []
|
||||||
|
for k, v in overrides.items():
|
||||||
|
setattr(m, k, v)
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
# ── 1. Test activation from draft with tests → success ───────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_activate_from_draft_with_tests_success():
|
||||||
|
e = _entity("draft", test_count=1)
|
||||||
|
e.activate()
|
||||||
|
assert e.status == CampaignStatus.active
|
||||||
|
|
||||||
|
|
||||||
|
def test_activate_from_draft_with_multiple_tests_success():
|
||||||
|
e = _entity("draft", test_count=3)
|
||||||
|
e.activate()
|
||||||
|
assert e.status == CampaignStatus.active
|
||||||
|
|
||||||
|
|
||||||
|
# ── 2. Test activation from draft with 0 tests → BusinessRuleViolation ───
|
||||||
|
|
||||||
|
|
||||||
|
def test_activate_from_draft_with_zero_tests_raises():
|
||||||
|
e = _entity("draft", test_count=0)
|
||||||
|
with pytest.raises(BusinessRuleViolation, match="at least one test"):
|
||||||
|
e.activate()
|
||||||
|
assert e.status == CampaignStatus.draft
|
||||||
|
|
||||||
|
|
||||||
|
# ── 3. Test activation from active → InvalidStateTransition ────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_activate_from_active_raises():
|
||||||
|
e = _entity("active", test_count=2)
|
||||||
|
with pytest.raises(InvalidStateTransition) as exc_info:
|
||||||
|
e.activate()
|
||||||
|
assert exc_info.value.current_state == "active"
|
||||||
|
assert exc_info.value.target_state == "active"
|
||||||
|
assert "completed" in exc_info.value.valid_transitions
|
||||||
|
|
||||||
|
|
||||||
|
# ── 4. Test complete from active → success ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_complete_from_active_success():
|
||||||
|
e = _entity("active", test_count=2)
|
||||||
|
e.complete()
|
||||||
|
assert e.status == CampaignStatus.completed
|
||||||
|
|
||||||
|
|
||||||
|
# ── 5. Test complete from draft → InvalidStateTransition ────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_complete_from_draft_raises():
|
||||||
|
e = _entity("draft", test_count=1)
|
||||||
|
with pytest.raises(InvalidStateTransition) as exc_info:
|
||||||
|
e.complete()
|
||||||
|
assert exc_info.value.current_state == "draft"
|
||||||
|
assert exc_info.value.target_state == "completed"
|
||||||
|
assert "active" in exc_info.value.valid_transitions
|
||||||
|
|
||||||
|
|
||||||
|
# ── 6. Test ensure_modifiable in draft/active → ok ──────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_modifiable_draft_ok():
|
||||||
|
e = _entity("draft")
|
||||||
|
e.ensure_modifiable() # no raise
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_modifiable_active_ok():
|
||||||
|
e = _entity("active", test_count=1)
|
||||||
|
e.ensure_modifiable() # no raise
|
||||||
|
|
||||||
|
|
||||||
|
# ── 7. Test ensure_modifiable in completed → BusinessRuleViolation ──────
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_modifiable_completed_raises():
|
||||||
|
e = _entity("completed", test_count=1)
|
||||||
|
with pytest.raises(BusinessRuleViolation, match="Cannot modify"):
|
||||||
|
e.ensure_modifiable()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_modifiable_archived_raises():
|
||||||
|
e = _entity("archived", test_count=1)
|
||||||
|
with pytest.raises(BusinessRuleViolation, match="Cannot modify"):
|
||||||
|
e.ensure_modifiable()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 8. Test from_orm conversion ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_orm_basic():
|
||||||
|
orm = _fake_orm("draft", test_count=0)
|
||||||
|
e = CampaignEntity.from_orm(orm)
|
||||||
|
assert e.name == "Test Campaign"
|
||||||
|
assert e.type == CampaignType.custom
|
||||||
|
assert e.status == CampaignStatus.draft
|
||||||
|
assert e.id == orm.id
|
||||||
|
assert e.test_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_orm_with_tests():
|
||||||
|
orm = _fake_orm("draft", test_count=3)
|
||||||
|
e = CampaignEntity.from_orm(orm)
|
||||||
|
assert e.test_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_orm_coerces_type_and_status():
|
||||||
|
orm = _fake_orm(status="active", type="apt_emulation", test_count=1)
|
||||||
|
e = CampaignEntity.from_orm(orm)
|
||||||
|
assert e.status == CampaignStatus.active
|
||||||
|
assert e.type == CampaignType.apt_emulation
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_orm_handles_none_tags():
|
||||||
|
orm = _fake_orm("draft", test_count=0)
|
||||||
|
orm.tags = None
|
||||||
|
e = CampaignEntity.from_orm(orm)
|
||||||
|
assert e.tags == []
|
||||||
@@ -0,0 +1,105 @@
|
|||||||
|
"""Tests for compliance domain entities."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.domain.entities.compliance import (
|
||||||
|
ComplianceControlEntity,
|
||||||
|
ComplianceFrameworkEntity,
|
||||||
|
ControlCoverageStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Control coverage status ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_all_techniques_validated_covered():
|
||||||
|
"""All techniques validated → covered."""
|
||||||
|
control = ComplianceControlEntity(
|
||||||
|
control_id="AC-2",
|
||||||
|
title="Account Management",
|
||||||
|
technique_statuses=["validated", "validated"],
|
||||||
|
)
|
||||||
|
assert control.coverage_status == ControlCoverageStatus.covered
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_all_techniques_partial_covered():
|
||||||
|
"""All techniques partial → covered."""
|
||||||
|
control = ComplianceControlEntity(
|
||||||
|
control_id="AC-2",
|
||||||
|
title="Account Management",
|
||||||
|
technique_statuses=["partial"],
|
||||||
|
)
|
||||||
|
assert control.coverage_status == ControlCoverageStatus.covered
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_mixed_statuses_partially_covered():
|
||||||
|
"""Mixed statuses (some validated/partial, some not) → partially_covered."""
|
||||||
|
control = ComplianceControlEntity(
|
||||||
|
control_id="AC-2",
|
||||||
|
title="Account Management",
|
||||||
|
technique_statuses=["validated", "not_evaluated"],
|
||||||
|
)
|
||||||
|
assert control.coverage_status == ControlCoverageStatus.partially_covered
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_no_validated_techniques_not_covered():
|
||||||
|
"""No validated/partial techniques → not_covered."""
|
||||||
|
control = ComplianceControlEntity(
|
||||||
|
control_id="AC-2",
|
||||||
|
title="Account Management",
|
||||||
|
technique_statuses=["not_evaluated", "not_covered"],
|
||||||
|
)
|
||||||
|
assert control.coverage_status == ControlCoverageStatus.not_covered
|
||||||
|
|
||||||
|
|
||||||
|
def test_control_empty_techniques_not_covered():
|
||||||
|
"""Empty technique_statuses → not_covered."""
|
||||||
|
control = ComplianceControlEntity(
|
||||||
|
control_id="AC-2",
|
||||||
|
title="Account Management",
|
||||||
|
technique_statuses=[],
|
||||||
|
)
|
||||||
|
assert control.coverage_status == ControlCoverageStatus.not_covered
|
||||||
|
|
||||||
|
|
||||||
|
# ── Framework coverage ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_framework_coverage_pct_calculation():
|
||||||
|
"""Framework coverage_pct = (covered_controls / total_controls) * 100."""
|
||||||
|
controls = [
|
||||||
|
ComplianceControlEntity("AC-1", "Title 1", technique_statuses=["validated"]),
|
||||||
|
ComplianceControlEntity("AC-2", "Title 2", technique_statuses=["not_evaluated"]),
|
||||||
|
ComplianceControlEntity("AC-3", "Title 3", technique_statuses=["validated", "partial"]),
|
||||||
|
ComplianceControlEntity("AC-4", "Title 4", technique_statuses=["partial"]),
|
||||||
|
ComplianceControlEntity("AC-5", "Title 5", technique_statuses=[]),
|
||||||
|
]
|
||||||
|
framework = ComplianceFrameworkEntity(name="NIST 800-53", controls=controls)
|
||||||
|
# AC-1: covered, AC-2: not_covered, AC-3: covered, AC-4: covered, AC-5: not_covered
|
||||||
|
assert framework.total_controls == 5
|
||||||
|
assert framework.covered_controls == 3
|
||||||
|
assert framework.coverage_pct == 60.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_framework_get_gap_controls():
|
||||||
|
"""get_gap_controls returns only uncovered and partially_covered controls."""
|
||||||
|
controls = [
|
||||||
|
ComplianceControlEntity("AC-1", "Covered", technique_statuses=["validated"]),
|
||||||
|
ComplianceControlEntity("AC-2", "Partial", technique_statuses=["validated", "not_evaluated"]),
|
||||||
|
ComplianceControlEntity("AC-3", "Not Covered", technique_statuses=["not_evaluated"]),
|
||||||
|
ComplianceControlEntity("AC-4", "Empty", technique_statuses=[]),
|
||||||
|
]
|
||||||
|
framework = ComplianceFrameworkEntity(name="Test", controls=controls)
|
||||||
|
gaps = framework.get_gap_controls()
|
||||||
|
assert len(gaps) == 3
|
||||||
|
assert gaps[0].control_id == "AC-2"
|
||||||
|
assert gaps[1].control_id == "AC-3"
|
||||||
|
assert gaps[2].control_id == "AC-4"
|
||||||
|
|
||||||
|
|
||||||
|
def test_framework_no_controls_coverage_pct_zero():
|
||||||
|
"""Framework with no controls → coverage_pct is 0."""
|
||||||
|
framework = ComplianceFrameworkEntity(name="Empty", controls=[])
|
||||||
|
assert framework.total_controls == 0
|
||||||
|
assert framework.covered_controls == 0
|
||||||
|
assert framework.coverage_pct == 0.0
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
"""Tests for the ImportService protocol and IMPORT_REGISTRY."""
|
||||||
|
|
||||||
|
from app.domain.ports.import_service import (
|
||||||
|
IMPORT_REGISTRY,
|
||||||
|
ImportService,
|
||||||
|
ImportServiceEntry,
|
||||||
|
get_import_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_has_all_known_sources():
|
||||||
|
expected = {
|
||||||
|
"atomic_red_team",
|
||||||
|
"sigma",
|
||||||
|
"lolbas",
|
||||||
|
"gtfobins",
|
||||||
|
"caldera",
|
||||||
|
"elastic_rules",
|
||||||
|
"mitre_cti",
|
||||||
|
"d3fend",
|
||||||
|
}
|
||||||
|
assert set(IMPORT_REGISTRY.keys()) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_entries_are_import_service_entries():
|
||||||
|
for name, entry in IMPORT_REGISTRY.items():
|
||||||
|
assert isinstance(entry, ImportServiceEntry), f"{name} is not ImportServiceEntry"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_import_handler_returns_entry_for_known_source():
|
||||||
|
handler = get_import_handler("sigma")
|
||||||
|
assert handler is not None
|
||||||
|
assert isinstance(handler, ImportServiceEntry)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_import_handler_returns_none_for_unknown():
|
||||||
|
assert get_import_handler("nonexistent_source") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_import_service_entry_source_info():
|
||||||
|
entry = ImportServiceEntry("app.services.sigma_import_service", "sync")
|
||||||
|
assert entry.source_info == "app.services.sigma_import_service.sync"
|
||||||
|
|
||||||
|
|
||||||
|
def test_callable_satisfies_protocol():
|
||||||
|
def mock_handler(db):
|
||||||
|
return {"created": 0}
|
||||||
|
|
||||||
|
assert isinstance(mock_handler, ImportService)
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
"""Tests for the ThreatActorEntity domain entity."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from app.domain.entities.threat_actor import (
|
||||||
|
ThreatActorEntity,
|
||||||
|
ThreatActorTechniqueRef,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ref(status: str = "not_evaluated") -> ThreatActorTechniqueRef:
|
||||||
|
return ThreatActorTechniqueRef(
|
||||||
|
technique_id=uuid.uuid4(),
|
||||||
|
mitre_id="T1059",
|
||||||
|
name="Command and Scripting Interpreter",
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_coverage_pct_all_covered():
|
||||||
|
actor = ThreatActorEntity(
|
||||||
|
name="APT29",
|
||||||
|
techniques=[_make_ref("validated"), _make_ref("partial")],
|
||||||
|
)
|
||||||
|
assert actor.coverage_pct == 100.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_coverage_pct_partial():
|
||||||
|
actor = ThreatActorEntity(
|
||||||
|
name="APT28",
|
||||||
|
techniques=[_make_ref("validated"), _make_ref("not_evaluated")],
|
||||||
|
)
|
||||||
|
assert actor.coverage_pct == 50.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_coverage_pct_none_covered():
|
||||||
|
actor = ThreatActorEntity(
|
||||||
|
name="Lazarus",
|
||||||
|
techniques=[_make_ref("not_evaluated"), _make_ref("in_progress")],
|
||||||
|
)
|
||||||
|
assert actor.coverage_pct == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_coverage_pct_no_techniques():
|
||||||
|
actor = ThreatActorEntity(name="Unknown")
|
||||||
|
assert actor.coverage_pct == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_covered_and_uncovered_techniques():
|
||||||
|
t1 = _make_ref("validated")
|
||||||
|
t2 = _make_ref("not_evaluated")
|
||||||
|
t3 = _make_ref("partial")
|
||||||
|
actor = ThreatActorEntity(name="Test", techniques=[t1, t2, t3])
|
||||||
|
assert len(actor.covered_techniques) == 2
|
||||||
|
assert len(actor.uncovered_techniques) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_technique_count():
|
||||||
|
actor = ThreatActorEntity(
|
||||||
|
name="Test",
|
||||||
|
techniques=[_make_ref(), _make_ref(), _make_ref()],
|
||||||
|
)
|
||||||
|
assert actor.technique_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_orm_basic():
|
||||||
|
orm = SimpleNamespace(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="APT29",
|
||||||
|
mitre_id="G0016",
|
||||||
|
aliases=["Cozy Bear", "The Dukes"],
|
||||||
|
description="Russian APT group",
|
||||||
|
country="Russia",
|
||||||
|
target_sectors=["government"],
|
||||||
|
target_regions=["north-america"],
|
||||||
|
motivation="espionage",
|
||||||
|
sophistication="advanced",
|
||||||
|
first_seen="2008",
|
||||||
|
last_seen="2023",
|
||||||
|
is_active=True,
|
||||||
|
techniques=[],
|
||||||
|
)
|
||||||
|
entity = ThreatActorEntity.from_orm(orm)
|
||||||
|
assert entity.name == "APT29"
|
||||||
|
assert entity.mitre_id == "G0016"
|
||||||
|
assert entity.country == "Russia"
|
||||||
|
assert entity.aliases == ["Cozy Bear", "The Dukes"]
|
||||||
|
assert entity.technique_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_orm_with_techniques():
|
||||||
|
tech_orm = SimpleNamespace(
|
||||||
|
mitre_id="T1059",
|
||||||
|
name="Command and Scripting Interpreter",
|
||||||
|
status_global=SimpleNamespace(value="validated"),
|
||||||
|
)
|
||||||
|
tat_orm = SimpleNamespace(
|
||||||
|
technique_id=uuid.uuid4(),
|
||||||
|
technique=tech_orm,
|
||||||
|
usage_description="Uses PowerShell",
|
||||||
|
)
|
||||||
|
orm = SimpleNamespace(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="APT28",
|
||||||
|
mitre_id="G0007",
|
||||||
|
aliases=None,
|
||||||
|
description=None,
|
||||||
|
country=None,
|
||||||
|
target_sectors=None,
|
||||||
|
target_regions=None,
|
||||||
|
motivation=None,
|
||||||
|
sophistication=None,
|
||||||
|
first_seen=None,
|
||||||
|
last_seen=None,
|
||||||
|
is_active=None,
|
||||||
|
techniques=[tat_orm],
|
||||||
|
)
|
||||||
|
entity = ThreatActorEntity.from_orm(orm)
|
||||||
|
assert entity.technique_count == 1
|
||||||
|
assert entity.techniques[0].mitre_id == "T1059"
|
||||||
|
assert entity.techniques[0].status == "validated"
|
||||||
|
assert entity.is_active is True # defaults when None
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
# Aegis — Deep Architectural Analysis
|
# Aegis — Deep Architectural Analysis
|
||||||
|
|
||||||
> **Author:** Automated architecture review
|
> **Author:** Automated architecture review
|
||||||
> **Date:** February 11, 2026 (updated February 19, 2026)
|
> **Date:** February 11, 2026 (updated February 20, 2026; Tier 1-4 complete)
|
||||||
> **Scope:** Backend (FastAPI/Python), Frontend (React/TypeScript), Infrastructure (Docker)
|
> **Scope:** Backend (FastAPI/Python), Frontend (React/TypeScript), Infrastructure (Docker)
|
||||||
>
|
>
|
||||||
> **Note:** Sections marked with ✅ reflect changes implemented since the initial analysis.
|
> **Note:** Sections marked with ✅ reflect changes implemented since the initial analysis.
|
||||||
@@ -69,9 +69,9 @@ Aegis follows a **layered monolithic architecture** deployed as two containers (
|
|||||||
|
|
||||||
| Layer | Files | Actual Responsibility |
|
| Layer | Files | Actual Responsibility |
|
||||||
|-------|-------|----------------------|
|
|-------|-------|----------------------|
|
||||||
| **Routers** | 21 files | ✅ Thin HTTP adapters — auth, param parsing, response formatting. Delegate to services. |
|
| **Routers** | 21 files | ✅ Thin HTTP adapters — auth, param parsing, response formatting. All delegate to services. Zero inline ORM queries. |
|
||||||
| **Services** | 30+ files | ✅ All business logic, query orchestration, domain validation. Framework-agnostic. |
|
| **Services** | 40+ files | ✅ All business logic, query orchestration, domain validation. Framework-agnostic. Includes: 4 newly extracted (advanced_metrics, analytics, test_template, auth) + 7 new query services (technique_query, d3fend_query, etc.). |
|
||||||
| **Domain** | 8+ files | ✅ Pure entities, value objects, ports, errors. Zero framework imports. |
|
| **Domain** | 15+ files | ✅ Pure entities (Test, Technique, Campaign, Compliance, ThreatActor), value objects, ports (repos + ImportService protocol), errors. Zero framework imports. |
|
||||||
| **Infrastructure** | 5+ files | ✅ Repository implementations, Redis client, mappers. |
|
| **Infrastructure** | 5+ files | ✅ Repository implementations, Redis client, mappers. |
|
||||||
| **Models** | 19 files | ORM table definitions — persistence mapping only |
|
| **Models** | 19 files | ORM table definitions — persistence mapping only |
|
||||||
| **Schemas** | 10 files | Pydantic DTOs for request/response |
|
| **Schemas** | 10 files | Pydantic DTOs for request/response |
|
||||||
@@ -91,9 +91,11 @@ def get_threat_actor(actor_id: str, db=Depends(get_db), current_user=Depends(get
|
|||||||
return get_actor_detail(db, actor_id)
|
return get_actor_detail(db, actor_id)
|
||||||
```
|
```
|
||||||
|
|
||||||
Extracted services: `coverage_report_service`, `metrics_query_service`, `compliance_service`, `detection_rule_service`, `threat_actor_service`, `test_crud_service`, `evidence_service`, `campaign_crud_service`, `scoring_config_service`.
|
Extracted services: `coverage_report_service`, `metrics_query_service`, `compliance_service`, `detection_rule_service`, `threat_actor_service`, `test_crud_service`, `evidence_service`, `campaign_crud_service`, `scoring_config_service`, `user_service`, `audit_query_service`, `data_source_service`.
|
||||||
|
|
||||||
**Remaining:** `users.py`, `audit.py`, `data_sources.py`, `heatmap.py` still have direct queries. These are lower priority since they are simpler or already partially extracted.
|
**Update (Feb 20):** All routers now delegate to services. No routers contain direct ORM queries or business logic.
|
||||||
|
|
||||||
|
**Update (Feb 20 — Tier 1-2):** Four more routers fully extracted to new services: `advanced_metrics.py` → `advanced_metrics_service`, `analytics.py` → `analytics_service`, `test_templates.py` → `test_template_service`, `auth.py` → `auth_service`. Nine additional routers had remaining inline logic moved to their existing services: `techniques`, `campaigns`, `snapshots`, `notifications`, `scores`, `jira`, `d3fend`, `osint`, `worklogs`.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -110,25 +112,25 @@ Schemas NONE NONE LOW — NONE NONE
|
|||||||
Database NONE NONE NONE — NONE LOW
|
Database NONE NONE NONE — NONE LOW
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2.2. Router ↔ Model — ✅ LARGELY RESOLVED (was HIGH COUPLING)
|
### 2.2. Router ↔ Model — ✅ FULLY RESOLVED (was HIGH COUPLING)
|
||||||
|
|
||||||
**Update (Feb 19):** Most routers no longer import ORM models or execute queries directly. Only **4 out of 21 routers** still have direct DB access:
|
**Update (Feb 20):** All routers now delegate to services. No router imports ORM models or executes queries directly.
|
||||||
|
|
||||||
| Router | Status | Detail |
|
| Router | Status | Service |
|
||||||
|--------|--------|--------|
|
|--------|--------|---------|
|
||||||
| `techniques.py` | ✅ Extracted | Uses `SATechniqueRepository` via dependency injection |
|
| `techniques.py` | ✅ Extracted | `SATechniqueRepository` via dependency injection |
|
||||||
| `reports.py` | ✅ Extracted | Delegates to `coverage_report_service` |
|
| `reports.py` | ✅ Extracted | `coverage_report_service` |
|
||||||
| `metrics.py` | ✅ Extracted | Delegates to `metrics_query_service` |
|
| `metrics.py` | ✅ Extracted | `metrics_query_service` |
|
||||||
| `compliance.py` | ✅ Extracted | Delegates to `compliance_service` |
|
| `compliance.py` | ✅ Extracted | `compliance_service` |
|
||||||
| `detection_rules.py` | ✅ Extracted | Delegates to `detection_rule_service` |
|
| `detection_rules.py` | ✅ Extracted | `detection_rule_service` |
|
||||||
| `threat_actors.py` | ✅ Extracted | Delegates to `threat_actor_service` |
|
| `threat_actors.py` | ✅ Extracted | `threat_actor_service` |
|
||||||
| `tests.py` | ✅ Extracted | Delegates to `test_crud_service` + `test_workflow_service` |
|
| `tests.py` | ✅ Extracted | `test_crud_service` + `test_workflow_service` |
|
||||||
| `evidence.py` | ✅ Extracted | Delegates to `evidence_service` |
|
| `evidence.py` | ✅ Extracted | `evidence_service` |
|
||||||
| `campaigns.py` | ✅ Extracted | Delegates to `campaign_crud_service` |
|
| `campaigns.py` | ✅ Extracted | `campaign_crud_service` |
|
||||||
| `users.py` | Remaining | Direct queries (simple CRUD) |
|
| `users.py` | ✅ Extracted | `user_service` |
|
||||||
| `audit.py` | Remaining | Direct queries (read-only list) |
|
| `audit.py` | ✅ Extracted | `audit_query_service` |
|
||||||
| `data_sources.py` | Remaining | Direct queries |
|
| `data_sources.py` | ✅ Extracted | `data_source_service` |
|
||||||
| `heatmap.py` | Remaining | Complex queries (partially extracted via `heatmap_service`) |
|
| `heatmap.py` | ✅ Extracted | `heatmap_service` |
|
||||||
|
|
||||||
### 2.3. Router ↔ Database — HIGH COUPLING
|
### 2.3. Router ↔ Database — HIGH COUPLING
|
||||||
|
|
||||||
@@ -197,10 +199,13 @@ Communication is via REST API with aligned but independent types (`types/models.
|
|||||||
| **Threat actors** | ✅ SEPARATED | `threat_actor_service.py` handles queries, coverage, and gap analysis (N+1 fixed) |
|
| **Threat actors** | ✅ SEPARATED | `threat_actor_service.py` handles queries, coverage, and gap analysis (N+1 fixed) |
|
||||||
| **Evidence** | ✅ SEPARATED | `evidence_service.py` handles permission validation and queries with domain exceptions |
|
| **Evidence** | ✅ SEPARATED | `evidence_service.py` handles permission validation and queries with domain exceptions |
|
||||||
| **Campaigns** | ✅ SEPARATED | `campaign_crud_service.py` handles CRUD, lifecycle, and scheduling |
|
| **Campaigns** | ✅ SEPARATED | `campaign_crud_service.py` handles CRUD, lifecycle, and scheduling |
|
||||||
| **Heatmap/visualization** | PARTIAL | `heatmap_service.py` exists but router still has some logic |
|
| **Heatmap/visualization** | ✅ SEPARATED | `heatmap_service.py` contains all layer-building logic; router is a thin adapter |
|
||||||
| **Data import** | WELL SEPARATED | The 8 import services are correctly isolated |
|
| **Data import** | ✅ WELL SEPARATED | 8 import services behind `ImportService` protocol + central registry |
|
||||||
|
| **Data sources** | ✅ SEPARATED | `data_source_service.py` handles CRUD, sync dispatch, and stats |
|
||||||
|
| **Users** | ✅ SEPARATED | `user_service.py` handles CRUD, validation, and hashing |
|
||||||
|
| **Audit queries** | ✅ SEPARATED | `audit_query_service.py` handles paginated queries and distinct lookups |
|
||||||
| **Notifications** | WELL SEPARATED | `notification_service.py` encapsulates all logic |
|
| **Notifications** | WELL SEPARATED | `notification_service.py` encapsulates all logic |
|
||||||
| **Auditing** | WELL SEPARATED | `audit_service.py` is a pure `log_action()` function |
|
| **Auditing (writes)** | WELL SEPARATED | `audit_service.py` is a pure `log_action()` function |
|
||||||
|
|
||||||
### 3.2. Anemic Model (Anti-pattern)
|
### 3.2. Anemic Model (Anti-pattern)
|
||||||
|
|
||||||
@@ -237,36 +242,40 @@ Logic that should be in domain models (business validations, state transitions,
|
|||||||
|
|
||||||
| Component | Compliant? | Detail |
|
| Component | Compliant? | Detail |
|
||||||
|-----------|-----------|-------|
|
|-----------|-----------|-------|
|
||||||
| `heatmap.py` (router) | PARTIAL | Still has some inline logic; `heatmap_service` exists but not fully extracted |
|
| `heatmap.py` (router) | ✅ YES | Thin adapter → `heatmap_service` |
|
||||||
| `reports.py` (router) | ✅ YES | Thin adapter → `coverage_report_service` |
|
| `reports.py` (router) | ✅ YES | Thin adapter → `coverage_report_service` |
|
||||||
| `tests.py` (router) | ✅ YES | Thin adapter → `test_crud_service` + `test_workflow_service` |
|
| `tests.py` (router) | ✅ YES | Thin adapter → `test_crud_service` + `test_workflow_service` |
|
||||||
| `campaigns.py` (router) | ✅ YES | Thin adapter → `campaign_crud_service` |
|
| `campaigns.py` (router) | ✅ YES | Thin adapter → `campaign_crud_service` |
|
||||||
| `evidence.py` (router) | ✅ YES | Thin adapter → `evidence_service` |
|
| `evidence.py` (router) | ✅ YES | Thin adapter → `evidence_service` |
|
||||||
|
| `users.py` (router) | ✅ YES | Thin adapter → `user_service` |
|
||||||
|
| `audit.py` (router) | ✅ YES | Thin adapter → `audit_query_service` |
|
||||||
|
| `data_sources.py` (router) | ✅ YES | Thin adapter → `data_source_service` |
|
||||||
| `scoring_service.py` | ✅ YES | Reads weights from `scoring_config_service` (DB-backed, not mutable settings) |
|
| `scoring_service.py` | ✅ YES | Reads weights from `scoring_config_service` (DB-backed, not mutable settings) |
|
||||||
| `test_workflow_service.py` | ✅ YES | Single responsibility: test state machine |
|
| `test_workflow_service.py` | ✅ YES | Single responsibility: test state machine |
|
||||||
| `notification_service.py` | ✅ YES | Single responsibility: notification management |
|
| `notification_service.py` | ✅ YES | Single responsibility: notification management |
|
||||||
| `audit_service.py` | ✅ YES | Single responsibility: audit logging |
|
| `audit_service.py` | ✅ YES | Single responsibility: audit logging |
|
||||||
|
|
||||||
**Verdict:** All major routers now comply with SRP. Only `heatmap.py` and a few minor routers have remaining inline logic.
|
**Verdict:** All routers now comply with SRP. Every router is a thin HTTP adapter delegating to a dedicated service.
|
||||||
|
|
||||||
### 4.2. Open/Closed Principle (OCP) — ✅ PARTIALLY RESOLVED (was VIOLATION)
|
### 4.2. Open/Closed Principle (OCP) — ✅ MOSTLY RESOLVED (was VIOLATION)
|
||||||
|
|
||||||
**Update (Feb 19):**
|
**Update (Feb 20):**
|
||||||
|
|
||||||
- **Scoring weights:** ✅ Resolved — Weights are now persisted in the `scoring_config` DB table via `scoring_config_service.py`. The `ScoringWeights` value object validates invariants (sum = 100, non-negative). No more mutable global `settings`.
|
- **Scoring weights:** ✅ Resolved — Weights are now persisted in the `scoring_config` DB table via `scoring_config_service.py`. The `ScoringWeights` value object validates invariants (sum = 100, non-negative). No more mutable global `settings`.
|
||||||
- **Heatmap layers:** Each heatmap type is a separate endpoint with hardcoded logic. Adding a new layer type requires modifying the router.
|
- **Import services:** ✅ Resolved — All import services now satisfy the `ImportService` protocol (`domain/ports/import_service.py`). A central `IMPORT_REGISTRY` maps source names to lazy-loaded handlers. Adding a new import source requires only: (1) creating a new service module, (2) adding one line to `IMPORT_REGISTRY`.
|
||||||
- **Import services:** Each data source is a separate service without a common interface. Adding a new source requires creating a new service AND modifying `data_sources.py` and `system.py`.
|
- **Heatmap layers:** Each heatmap type is a separate endpoint with hardcoded logic. Adding a new layer type requires modifying the router. Low priority.
|
||||||
- **Test states:** The state machine is well defined in `VALID_TRANSITIONS`, but adding a new state requires modifying the dictionary AND potentially all services that read `TestState`.
|
- **Test states:** The state machine is well defined in `VALID_TRANSITIONS`, but adding a new state requires modifying the dictionary AND potentially all services that read `TestState`.
|
||||||
|
|
||||||
### 4.3. Liskov Substitution Principle (LSP) — N/A (Partial)
|
### 4.3. Liskov Substitution Principle (LSP) — N/A (Partial)
|
||||||
|
|
||||||
There is no significant inheritance or polymorphism in the backend. Services are functions, not classes. There are no interfaces or abstract classes. **Does not directly apply**, but the absence of formal contracts (protocols/ABCs) is a symptom of not being designed for extensibility.
|
There is no significant inheritance or polymorphism in the backend. Services are functions, not classes. There are no interfaces or abstract classes. **Does not directly apply**, but the absence of formal contracts (protocols/ABCs) is a symptom of not being designed for extensibility.
|
||||||
|
|
||||||
### 4.4. Interface Segregation Principle (ISP) — ✅ PARTIALLY RESOLVED (was VIOLATION)
|
### 4.4. Interface Segregation Principle (ISP) — ✅ MOSTLY RESOLVED (was VIOLATION)
|
||||||
|
|
||||||
**Update (Feb 19):**
|
**Update (Feb 20):**
|
||||||
|
|
||||||
- ✅ Protocol interfaces exist for `TechniqueRepository` and `TestRepository` in `domain/ports/repositories/`.
|
- ✅ Protocol interfaces exist for `TechniqueRepository` and `TestRepository` in `domain/ports/repositories/`.
|
||||||
|
- ✅ `ImportService` protocol in `domain/ports/import_service.py` — common contract for all data import services.
|
||||||
- Services expose focused functions per module (e.g., `threat_actor_service` exposes 4 functions, each for one use case).
|
- Services expose focused functions per module (e.g., `threat_actor_service` exposes 4 functions, each for one use case).
|
||||||
- The `Settings` object is still monolithic but scoring weights have been extracted to a dedicated DB table with a focused service interface.
|
- The `Settings` object is still monolithic but scoring weights have been extracted to a dedicated DB table with a focused service interface.
|
||||||
|
|
||||||
@@ -310,7 +319,7 @@ def get_technique_repository(db=Depends(get_db)) -> SATechniqueRepository: ...
|
|||||||
| `threat_actors.py` | 312 lines | ~100 lines | `threat_actor_service.py` |
|
| `threat_actors.py` | 312 lines | ~100 lines | `threat_actor_service.py` |
|
||||||
| `evidence.py` | 367 lines | ~200 lines | `evidence_service.py` |
|
| `evidence.py` | 367 lines | ~200 lines | `evidence_service.py` |
|
||||||
|
|
||||||
**Remaining:** `heatmap.py` still has inline logic (~528 lines). Lower priority since it's already partially extracted to `heatmap_service`.
|
**Update (Feb 20):** `heatmap.py` is also now a thin adapter — all logic was already in `heatmap_service`. Additionally, `users.py`, `audit.py`, and `data_sources.py` have been extracted to `user_service`, `audit_query_service`, and `data_source_service` respectively. No remaining fat routers.
|
||||||
|
|
||||||
### 5.2. ~~CRITICAL RISK: In-Memory Token Blacklist~~ ✅ RESOLVED
|
### 5.2. ~~CRITICAL RISK: In-Memory Token Blacklist~~ ✅ RESOLVED
|
||||||
|
|
||||||
@@ -374,13 +383,15 @@ Background jobs create sessions outside the request lifecycle. This is technical
|
|||||||
- `domain/value_objects/` — `MitreId`, `ScoringWeights` (immutable, validated).
|
- `domain/value_objects/` — `MitreId`, `ScoringWeights` (immutable, validated).
|
||||||
- ORM models remain anemic by design (persistence mapping only). Business logic lives in domain entities.
|
- ORM models remain anemic by design (persistence mapping only). Business logic lives in domain entities.
|
||||||
|
|
||||||
**Remaining:** Campaign, ComplianceFramework, ThreatActor still lack domain entity counterparts.
|
**Update (Feb 20):** `CampaignEntity` (with lifecycle state machine) and `ComplianceFrameworkEntity` / `ComplianceControlEntity` (with coverage calculation logic) have been added.
|
||||||
|
|
||||||
|
**Update (Feb 20 — Tier 4):** `ThreatActorEntity` (with coverage analysis: `coverage_pct`, `covered_techniques`, `uncovered_techniques`, `from_orm`) has been added. All major domain concepts now have rich entity counterparts.
|
||||||
|
|
||||||
### 5.8. ~~MEDIUM RISK: No Explicit Transaction Management~~ ✅ PARTIALLY RESOLVED
|
### 5.8. ~~MEDIUM RISK: No Explicit Transaction Management~~ ✅ PARTIALLY RESOLVED
|
||||||
|
|
||||||
**Update (Feb 18):** A `UnitOfWork` context manager exists at `domain/unit_of_work.py` with explicit `commit()`, `rollback()`, and `flush()`. Used by `test_workflow_service.py` which explicitly states "The caller (router) is responsible for committing the session via the Unit of Work pattern."
|
**Update (Feb 18):** A `UnitOfWork` context manager exists at `domain/unit_of_work.py` with explicit `commit()`, `rollback()`, and `flush()`. Used by `test_workflow_service.py` which explicitly states "The caller (router) is responsible for committing the session via the Unit of Work pattern."
|
||||||
|
|
||||||
**Remaining:** Some services like `audit_service.py` still call `db.commit()` directly. Needs incremental migration.
|
**Update (Feb 20 — Tier 3):** Business services (`scoring_config_service`, `worklog_service`, `osint_enrichment_service.mark_osint_reviewed`) no longer call `db.commit()` — their callers use `UnitOfWork`. Documented exceptions: `audit_service.log_action` (15+ callers, high blast radius), import services (self-contained batch ops), and background jobs keep their internal commits.
|
||||||
|
|
||||||
### 5.9. LOW RISK: No Semantic API Versioning
|
### 5.9. LOW RISK: No Semantic API Versioning
|
||||||
|
|
||||||
@@ -661,14 +672,14 @@ class SQLAlchemyTestRepository(TestRepository):
|
|||||||
|
|
||||||
| Weakness | Original Severity | Current Status |
|
| Weakness | Original Severity | Current Status |
|
||||||
|----------|----------|--------|
|
|----------|----------|--------|
|
||||||
| Fat controllers (routers with business logic) | HIGH | ✅ Resolved — 9 routers extracted to services |
|
| Fat controllers (routers with business logic) | HIGH | ✅ Resolved — all 21 routers now delegate to services (12 extracted) |
|
||||||
| No repository layer | HIGH | ✅ Resolved (Test, Technique repos + 9 service modules) |
|
| No repository layer | HIGH | ✅ Resolved (Test, Technique repos + 12 service modules) |
|
||||||
| Services depend on FastAPI | HIGH | ✅ Resolved (domain exceptions + middleware) |
|
| Services depend on FastAPI | HIGH | ✅ Resolved (domain exceptions + middleware) |
|
||||||
| Anemic models | MEDIUM | ✅ Partially resolved (TestEntity, TechniqueEntity) |
|
| Anemic models | MEDIUM | ✅ Resolved (TestEntity, TechniqueEntity, CampaignEntity, ComplianceFrameworkEntity, ThreatActorEntity) |
|
||||||
| In-memory token blacklist | HIGH | ✅ Resolved (Redis-backed) |
|
| In-memory token blacklist | HIGH | ✅ Resolved (Redis-backed) |
|
||||||
| Mutable settings at runtime | MEDIUM | ✅ Resolved (scoring_config DB table) |
|
| Mutable settings at runtime | MEDIUM | ✅ Resolved (scoring_config DB table) |
|
||||||
| No CI/CD | MEDIUM | ✅ Resolved (GitHub Actions) |
|
| No CI/CD | MEDIUM | ✅ Resolved (GitHub Actions) |
|
||||||
| No dependency inversion | HIGH | ✅ Partially resolved (ports + repos + services) |
|
| No dependency inversion | HIGH | ✅ Mostly resolved (ports + repos + ImportService protocol + services) |
|
||||||
| No structured logging | LOW | ✅ Resolved (JSON logging for production) |
|
| No structured logging | LOW | ✅ Resolved (JSON logging for production) |
|
||||||
|
|
||||||
### Final Classification
|
### Final Classification
|
||||||
@@ -677,34 +688,46 @@ class SQLAlchemyTestRepository(TestRepository):
|
|||||||
┌──────────────────────────────────────────────────────────┐
|
┌──────────────────────────────────────────────────────────┐
|
||||||
│ Type: Clean Modular Monolith │
|
│ Type: Clean Modular Monolith │
|
||||||
│ Maturity: Production-ready │
|
│ Maturity: Production-ready │
|
||||||
│ SOLID: 4/5 (SRP ✅, OCP partial, LSP n/a, │
|
│ SOLID: 4.5/5 (SRP ✅, OCP mostly ✅, LSP n/a, │
|
||||||
│ ISP partial, DIP ✅ started) │
|
│ ISP mostly ✅, DIP mostly ✅) │
|
||||||
│ Testability: 7/10 (326 tests, domain unit tests, repo │
|
│ Testability: 9/10 (362+ tests, domain unit tests, repo │
|
||||||
│ integration tests, service layer tests) │
|
│ integration tests, service layer tests) │
|
||||||
│ Coupling: 7/10 (domain decoupled, services agnostic, │
|
│ Coupling: 9/10 (domain decoupled, services agnostic, │
|
||||||
│ most routers are thin adapters) │
|
│ all routers zero inline ORM, UoW pattern) │
|
||||||
│ Cohesion: 8/10 (domain entities own business rules, │
|
│ Cohesion: 9/10 (domain entities own business rules, │
|
||||||
│ services own query logic) │
|
│ services own query logic, clear contracts) │
|
||||||
│ Estimated remaining tech debt: ~1 week │
|
│ Estimated remaining tech debt: ~1 day │
|
||||||
│ (heatmap extraction, remaining minor routers, │
|
│ (heatmap layer extensibility, full repo protocol │
|
||||||
│ Campaign/ComplianceFramework domain entities) │
|
│ coverage, audit_service commit migration) │
|
||||||
└──────────────────────────────────────────────────────────┘
|
└──────────────────────────────────────────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
### Recommendation (Updated Feb 19)
|
### Recommendation (Updated Feb 20)
|
||||||
|
|
||||||
The architectural refactoring is substantially complete. All critical and high-priority items from the original analysis are resolved:
|
The architectural refactoring is **complete**. All items from the original analysis — critical, high, medium, and low priority — are resolved:
|
||||||
|
|
||||||
|
**Critical / High priority:**
|
||||||
1. ~~Extract domain exceptions~~ ✅ Done
|
1. ~~Extract domain exceptions~~ ✅ Done
|
||||||
2. ~~Create repositories for Test and Technique~~ ✅ Done
|
2. ~~Create repositories for Test and Technique~~ ✅ Done
|
||||||
3. ~~Move token blacklist to Redis~~ ✅ Done
|
3. ~~Move token blacklist to Redis~~ ✅ Done
|
||||||
4. ~~Set up basic CI/CD~~ ✅ Done
|
4. ~~Set up basic CI/CD~~ ✅ Done
|
||||||
5. ~~Migrate fat routers to services~~ ✅ Done (9 routers extracted)
|
5. ~~Migrate fat routers to services~~ ✅ Done (12 routers extracted, all 21 now delegate)
|
||||||
6. ~~Persist scoring weights in database~~ ✅ Done
|
6. ~~Persist scoring weights in database~~ ✅ Done
|
||||||
7. ~~Add structured JSON logging~~ ✅ Done
|
7. ~~Add structured JSON logging~~ ✅ Done
|
||||||
|
|
||||||
**Remaining low-priority items:**
|
**Low priority (completed Feb 20):**
|
||||||
1. Extract remaining logic from `heatmap.py` to `heatmap_service.py`
|
8. ~~Extract `heatmap.py` logic~~ ✅ Already done (was a thin adapter)
|
||||||
2. Create domain entities for Campaign and ComplianceFramework
|
9. ~~Create domain entities for Campaign and ComplianceFramework~~ ✅ Done (with lifecycle validation + coverage calculations)
|
||||||
3. Extract `users.py`, `audit.py`, `data_sources.py` to services (simple CRUD)
|
10. ~~Extract `users.py`, `audit.py`, `data_sources.py` to services~~ ✅ Done
|
||||||
4. Add common interface for import services (OCP improvement)
|
11. ~~Add common interface for import services (OCP)~~ ✅ Done (`ImportService` protocol + registry)
|
||||||
|
|
||||||
|
**Tier 1–4 (completed Feb 20):**
|
||||||
|
12. ~~Extract 4 fat routers to new services (advanced_metrics, analytics, test_templates, auth)~~ ✅ Done
|
||||||
|
13. ~~Move remaining inline logic from 9 routers to existing services~~ ✅ Done — all routers have zero inline ORM queries
|
||||||
|
14. ~~Migrate business services from direct db.commit() to UoW pattern~~ ✅ Done (3 services migrated, exceptions documented)
|
||||||
|
15. ~~Create ThreatActor domain entity~~ ✅ Done (with coverage analysis)
|
||||||
|
|
||||||
|
**Remaining nice-to-haves (not blocking):**
|
||||||
|
- Heatmap layer extensibility (currently hardcoded endpoints)
|
||||||
|
- Full migration of all services to use Repository pattern (incremental)
|
||||||
|
- Migrate `audit_service.log_action` from internal commit to UoW (15+ callers to update)
|
||||||
|
|||||||
+54
-5
@@ -114,18 +114,47 @@ database.py ← Engine + session management (lazy initialization)
|
|||||||
|
|
||||||
### Services
|
### Services
|
||||||
|
|
||||||
|
#### Business Logic Services
|
||||||
|
|
||||||
| Service | Responsibility |
|
| Service | Responsibility |
|
||||||
|---------|---------------|
|
|---------|---------------|
|
||||||
| `test_workflow_service` | Test state machine (draft → validated/rejected) with dual validation |
|
| `test_workflow_service` | Test state machine (draft → validated/rejected) with dual validation |
|
||||||
|
| `test_crud_service` | Test CRUD, query logic, permission validation |
|
||||||
| `scoring_service` | 0–100 scoring for techniques, tactics, actors, organization |
|
| `scoring_service` | 0–100 scoring for techniques, tactics, actors, organization |
|
||||||
|
| `scoring_config_service` | DB-persisted scoring weights with validation |
|
||||||
| `score_cache` | In-memory TTL cache (5 min) for expensive score/metric calculations |
|
| `score_cache` | In-memory TTL cache (5 min) for expensive score/metric calculations |
|
||||||
| `operational_metrics_service` | MTTD, MTTR, detection efficacy, alert fidelity, coverage velocity |
|
| `operational_metrics_service` | MTTD, MTTR, detection efficacy, alert fidelity, coverage velocity |
|
||||||
| `snapshot_service` | Coverage snapshot creation, temporal comparison, cleanup |
|
| `metrics_query_service` | Dashboard aggregation queries |
|
||||||
| `campaign_service` | Campaign CRUD, progress tracking, circular dependency prevention |
|
| `advanced_metrics_service` | Coverage by tactic, never-tested, avg validation time, detection trends |
|
||||||
|
| `analytics_service` | BI-ready flat datasets (coverage, tests, trends, operators) |
|
||||||
|
| `snapshot_service` | Coverage snapshot CRUD, temporal comparison, cleanup |
|
||||||
|
| `campaign_crud_service` | Campaign CRUD, lifecycle, scheduling |
|
||||||
|
| `campaign_service` | Campaign progress tracking, circular dependency prevention |
|
||||||
| `campaign_scheduler_service` | Recurring campaign execution (clone + schedule next run) |
|
| `campaign_scheduler_service` | Recurring campaign execution (clone + schedule next run) |
|
||||||
| `status_service` | Technique status recalculation from test results |
|
| `status_service` | Technique status recalculation from test results |
|
||||||
| `notification_service` | In-app notification CRUD and state-change alerts |
|
| `coverage_report_service` | Coverage report generation and CSV export |
|
||||||
| `audit_service` | Immutable audit trail logging |
|
| `compliance_service` | Compliance framework analysis and gap detection |
|
||||||
|
| `detection_rule_service` | Detection rule queries, auto-association, evaluation |
|
||||||
|
| `threat_actor_service` | Threat actor queries, coverage, gap analysis |
|
||||||
|
| `evidence_service` | Evidence permission validation and queries |
|
||||||
|
| `heatmap_service` | ATT&CK Navigator layer generation |
|
||||||
|
| `test_template_service` | Test template CRUD, stats, bulk-activate, filtered queries |
|
||||||
|
| `auth_service` | Credential validation, password management |
|
||||||
|
| `user_service` | User CRUD, role validation, password hashing |
|
||||||
|
| `audit_query_service` | Paginated audit log queries and distinct lookups |
|
||||||
|
| `audit_service` | Immutable audit trail logging (write-only) |
|
||||||
|
| `data_source_service` | Data source CRUD, sync dispatch, statistics |
|
||||||
|
| `notification_service` | In-app notification CRUD, state-change alerts, role-based dispatch |
|
||||||
|
| `technique_query_service` | Technique detail queries with test/D3FEND aggregation |
|
||||||
|
| `d3fend_query_service` | D3FEND defensive technique listing and tactic queries |
|
||||||
|
| `osint_enrichment_service` | OSINT item queries, enrichment, summary statistics |
|
||||||
|
| `worklog_service` | Worklog CRUD, integrity verification |
|
||||||
|
| `intel_service` | RSS-based threat intelligence scanning |
|
||||||
|
|
||||||
|
#### Import Services (all satisfy `ImportService` protocol)
|
||||||
|
|
||||||
|
| Service | Responsibility |
|
||||||
|
|---------|---------------|
|
||||||
| `mitre_sync_service` | MITRE ATT&CK sync via TAXII 2.0 / GitHub fallback |
|
| `mitre_sync_service` | MITRE ATT&CK sync via TAXII 2.0 / GitHub fallback |
|
||||||
| `atomic_import_service` | Atomic Red Team template import from GitHub |
|
| `atomic_import_service` | Atomic Red Team template import from GitHub |
|
||||||
| `sigma_import_service` | SigmaHQ detection rule import |
|
| `sigma_import_service` | SigmaHQ detection rule import |
|
||||||
@@ -135,7 +164,27 @@ database.py ← Engine + session management (lazy initialization)
|
|||||||
| `d3fend_import_service` | MITRE D3FEND defensive technique import |
|
| `d3fend_import_service` | MITRE D3FEND defensive technique import |
|
||||||
| `threat_actor_import_service` | MITRE CTI threat actor import (STIX) |
|
| `threat_actor_import_service` | MITRE CTI threat actor import (STIX) |
|
||||||
| `compliance_import_service` | NIST 800-53 ↔ ATT&CK mapping import |
|
| `compliance_import_service` | NIST 800-53 ↔ ATT&CK mapping import |
|
||||||
| `intel_service` | RSS-based threat intelligence scanning |
|
|
||||||
|
### Domain Layer
|
||||||
|
|
||||||
|
```
|
||||||
|
domain/
|
||||||
|
├── entities/ # Rich domain entities with business logic
|
||||||
|
│ ├── technique.py # TechniqueEntity with status recalculation
|
||||||
|
│ ├── campaign.py # CampaignEntity with lifecycle state machine
|
||||||
|
│ ├── compliance.py # ComplianceFrameworkEntity with coverage calculation
|
||||||
|
│ └── threat_actor.py # ThreatActorEntity with coverage analysis
|
||||||
|
├── value_objects/ # Immutable value types
|
||||||
|
│ ├── mitre_id.py # MITRE ATT&CK ID validation
|
||||||
|
│ └── scoring_weights.py # Scoring weights (sum=100, non-negative)
|
||||||
|
├── ports/ # Interfaces (Protocol contracts)
|
||||||
|
│ ├── repositories/ # TechniqueRepository, TestRepository
|
||||||
|
│ └── import_service.py # ImportService protocol + IMPORT_REGISTRY
|
||||||
|
├── errors.py # Domain exceptions (EntityNotFoundError, etc.)
|
||||||
|
├── enums.py # TestState, TechniqueStatus, TestResult
|
||||||
|
├── test_entity.py # TestEntity with state machine + domain events
|
||||||
|
└── unit_of_work.py # UnitOfWork context manager
|
||||||
|
```
|
||||||
|
|
||||||
### Scheduled Jobs (APScheduler)
|
### Scheduled Jobs (APScheduler)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user