feat: move all remaining inline logic from routers to services (Tier 2)

This commit is contained in:
2026-02-20 14:34:24 +01:00
parent 9e22fde746
commit 339d669498
17 changed files with 632 additions and 414 deletions

View File

@@ -0,0 +1,82 @@
"""D3FEND query service — framework-agnostic queries for defensive techniques."""
from __future__ import annotations
from typing import Optional
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.defensive_technique import DefensiveTechnique
from app.models.technique import Technique
from app.services.d3fend_import_service import get_defenses_for_technique
from app.utils import escape_like
def list_defensive_techniques(
db: Session,
*,
tactic: Optional[str] = None,
search: Optional[str] = None,
offset: int = 0,
limit: int = 50,
) -> dict:
"""List D3FEND defensive techniques with optional filters."""
query = db.query(DefensiveTechnique)
if tactic:
query = query.filter(DefensiveTechnique.tactic == tactic)
if search:
pattern = f"%{escape_like(search)}%"
query = query.filter(
DefensiveTechnique.name.ilike(pattern)
| DefensiveTechnique.d3fend_id.ilike(pattern)
)
total = query.count()
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [
{
"id": str(dt.id),
"d3fend_id": dt.d3fend_id,
"name": dt.name,
"description": dt.description,
"tactic": dt.tactic,
"d3fend_url": dt.d3fend_url,
}
for dt in items
],
}
def list_d3fend_tactics(db: Session) -> list[dict]:
"""Return a list of all D3FEND tactics with counts."""
rows = (
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
.group_by(DefensiveTechnique.tactic)
.order_by(DefensiveTechnique.tactic)
.all()
)
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
def get_defenses_for_attack_technique(db: Session, mitre_id: str) -> dict:
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
if technique is None:
raise EntityNotFoundError("Technique", mitre_id)
defenses = get_defenses_for_technique(db, technique.id)
return {
"mitre_id": mitre_id,
"technique_name": technique.name,
"defenses": defenses,
"total": len(defenses),
}

View File

@@ -3,12 +3,17 @@
import logging
from datetime import datetime
from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session
from app.config import settings
from app.domain.errors import EntityNotFoundError
from app.domain.exceptions import InvalidOperationError
from app.models.jira_link import JiraLink
from app.models.campaign import Campaign
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
from app.models.technique import Technique
from app.models.test import Test
logger = logging.getLogger(__name__)
@@ -103,3 +108,128 @@ def _build_sync_comment(data: dict) -> str:
lines.append(f"*{key}:* {value}")
lines.append(f"\n_Synced at {datetime.utcnow().isoformat()}_")
return "\n".join(lines)
# ── Link CRUD ────────────────────────────────────────────────────────
def create_link(
db: Session,
*,
entity_type: JiraLinkEntityType,
entity_id: UUID,
jira_issue_key: str,
sync_direction: JiraSyncDirection,
created_by: UUID,
) -> JiraLink:
"""Create a Jira link and optionally pull initial data from Jira."""
link = JiraLink(
entity_type=entity_type,
entity_id=entity_id,
jira_issue_key=jira_issue_key,
sync_direction=sync_direction,
created_by=created_by,
)
db.add(link)
db.flush()
if settings.JIRA_ENABLED:
try:
sync_jira_to_aegis(db, link)
except Exception as e:
logger.warning("Initial Jira sync failed for %s: %s", jira_issue_key, e)
return link
def list_links(
db: Session,
*,
entity_type: Optional[JiraLinkEntityType] = None,
entity_id: Optional[UUID] = None,
) -> list[JiraLink]:
"""List Jira links with optional filters."""
query = db.query(JiraLink)
if entity_type:
query = query.filter(JiraLink.entity_type == entity_type)
if entity_id:
query = query.filter(JiraLink.entity_id == entity_id)
return query.order_by(JiraLink.created_at.desc()).all()
def get_link_or_raise(db: Session, link_id: UUID) -> JiraLink:
"""Get a Jira link by ID or raise EntityNotFoundError."""
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
if not link:
raise EntityNotFoundError("JiraLink", str(link_id))
return link
def delete_link(db: Session, link_id: UUID) -> JiraLink:
"""Delete a Jira link. Returns the deleted link (for audit)."""
link = get_link_or_raise(db, link_id)
db.delete(link)
return link
def build_issue_data(db: Session, entity_type: JiraLinkEntityType, entity_id: UUID) -> tuple[str, str]:
"""Build Jira issue summary and description from an Aegis entity."""
if entity_type == JiraLinkEntityType.test:
entity = db.query(Test).filter(Test.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Test", str(entity_id))
return (
f"[Aegis Test] {entity.name}",
f"Test: {entity.name}\n"
f"State: {entity.state.value if entity.state else 'draft'}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.campaign:
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Campaign", str(entity_id))
return (
f"[Aegis Campaign] {entity.name}",
f"Campaign: {entity.name}\n"
f"Type: {entity.type}\nStatus: {entity.status}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.technique:
entity = db.query(Technique).filter(Technique.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Technique", str(entity_id))
return (
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
f"Tactic: {entity.tactic or 'N/A'}\n"
f"Description: {entity.description or 'N/A'}",
)
else:
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
def create_issue_and_link(
db: Session,
*,
entity_type: JiraLinkEntityType,
entity_id: UUID,
created_by: UUID,
) -> dict:
"""Create a Jira issue from an Aegis entity and link them."""
summary, description = build_issue_data(db, entity_type, entity_id)
result = create_jira_issue(
project_key=settings.JIRA_DEFAULT_PROJECT,
summary=summary,
description=description,
labels=["aegis", entity_type.value],
)
link = JiraLink(
entity_type=entity_type,
entity_id=entity_id,
jira_issue_key=result["issue_key"],
jira_issue_id=result["issue_id"],
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
created_by=created_by,
)
db.add(link)
return {"issue_key": result["issue_key"], "link_id": str(link.id)}

View File

@@ -13,6 +13,7 @@ from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.domain.errors import EntityNotFoundError
from app.models.notification import Notification
from app.models.user import User
@@ -22,6 +23,71 @@ from app.models.user import User
# ---------------------------------------------------------------------------
def list_notifications(
db: Session,
user_id: uuid.UUID,
*,
offset: int = 0,
limit: int = 20,
) -> list[Notification]:
"""Return paginated notifications for a user, newest first."""
return (
db.query(Notification)
.filter(Notification.user_id == user_id)
.order_by(Notification.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
def get_notification_or_raise(
db: Session,
notification_id: uuid.UUID,
user_id: uuid.UUID,
) -> Notification:
"""Fetch a notification by ID and user, or raise EntityNotFoundError."""
notif = (
db.query(Notification)
.filter(
Notification.id == notification_id,
Notification.user_id == user_id,
)
.first()
)
if notif is None:
raise EntityNotFoundError("Notification", str(notification_id))
return notif
def notify_role(
db: Session,
*,
role: str,
type: str,
title: str,
message: str,
entity_type: str,
entity_id: uuid.UUID,
) -> None:
"""Send notifications to all active users with a given role."""
users = (
db.query(User)
.filter(User.role == role, User.is_active == True) # noqa: E712
.all()
)
for user in users:
create_notification(
db,
user_id=user.id,
type=type,
title=title,
message=message,
entity_type=entity_type,
entity_id=entity_id,
)
def create_notification(
db: Session,
user_id: uuid.UUID,
@@ -45,17 +111,13 @@ def create_notification(
return notif
def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) -> bool:
"""Mark a single notification as read. Returns True if updated."""
notif = (
db.query(Notification)
.filter(Notification.id == notification_id, Notification.user_id == user_id)
.first()
)
if notif is None:
return False
def mark_as_read(
db: Session, notification_id: uuid.UUID, user_id: uuid.UUID
) -> Notification:
"""Mark a single notification as read. Returns the notification. Raises EntityNotFoundError if not found."""
notif = get_notification_or_raise(db, notification_id, user_id)
notif.read = True
return True
return notif
def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:

View File

@@ -7,11 +7,15 @@ Designed to run as a weekly background job. Respects NVD rate limits
import logging
import time
from typing import Optional
from uuid import UUID
import requests
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.config import settings
from app.domain.errors import EntityNotFoundError
from app.models.osint_item import OsintItem
from app.models.technique import Technique
@@ -189,3 +193,87 @@ def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None:
def get_unreviewed_count(db: Session) -> int:
"""Return the total number of unreviewed OSINT items."""
return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712
def list_osint_items(
db: Session,
*,
technique_id: Optional[UUID] = None,
source_type: Optional[str] = None,
reviewed: Optional[bool] = None,
offset: int = 0,
limit: int = 50,
) -> dict:
"""List OSINT items with optional filters and pagination."""
query = db.query(OsintItem)
if technique_id:
query = query.filter(OsintItem.technique_id == technique_id)
if source_type:
query = query.filter(OsintItem.source_type == source_type)
if reviewed is not None:
query = query.filter(OsintItem.reviewed == reviewed)
total = query.count()
items = (
query.order_by(OsintItem.discovered_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"items": [
{
"id": str(item.id),
"technique_id": str(item.technique_id),
"source_type": item.source_type,
"source_url": item.source_url,
"title": item.title,
"description": item.description,
"severity": item.severity,
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
"reviewed": item.reviewed,
"metadata": item.metadata_,
}
for item in items
],
}
def get_osint_summary(db: Session) -> dict:
"""Summary statistics for OSINT items."""
total = db.query(func.count(OsintItem.id)).scalar() or 0
unreviewed = get_unreviewed_count(db)
by_severity = dict(
db.query(OsintItem.severity, func.count(OsintItem.id))
.group_by(OsintItem.severity)
.all()
)
by_type = dict(
db.query(OsintItem.source_type, func.count(OsintItem.id))
.group_by(OsintItem.source_type)
.all()
)
techniques_with_items = (
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
)
return {
"total_items": total,
"unreviewed": unreviewed,
"techniques_with_items": techniques_with_items,
"by_severity": by_severity,
"by_type": by_type,
}
def get_technique_or_raise(db: Session, technique_id: UUID) -> Technique:
"""Get a technique by ID or raise EntityNotFoundError."""
technique = db.query(Technique).filter(Technique.id == technique_id).first()
if not technique:
raise EntityNotFoundError("Technique", str(technique_id))
return technique

View File

@@ -15,6 +15,7 @@ from typing import Optional
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique
from app.models.test import Test
from app.models.detection_rule import DetectionRule
@@ -232,6 +233,29 @@ def bulk_technique_scores(db: Session) -> dict:
# ── Technique-level scoring (single technique — preserved API) ────────
def score_technique_by_mitre_id(db: Session, mitre_id: str) -> dict:
"""Get detailed score with breakdown for a technique by MITRE ID."""
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
if not technique:
raise EntityNotFoundError("Technique", mitre_id)
result = calculate_technique_score(technique, db)
return {
"mitre_id": technique.mitre_id,
"name": technique.name,
"tactic": technique.tactic,
"status_global": technique.status_global.value if technique.status_global else None,
**result,
}
def score_actor_by_id(db: Session, actor_id: str) -> dict:
"""Get coverage score for a threat actor by ID."""
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
if not actor:
raise EntityNotFoundError("ThreatActor", actor_id)
return calculate_actor_coverage_score(actor_id, db)
def calculate_technique_score(technique: Technique, db: Session) -> dict:
"""Calculate a 0-100 score for a technique with detailed breakdown.

View File

@@ -14,6 +14,7 @@ from datetime import datetime
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TechniqueStatus
@@ -25,6 +26,101 @@ from app.services.scoring_service import (
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Serialization and queries
# ---------------------------------------------------------------------------
def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"""Lightweight serialization for list views."""
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_by": str(snap.created_by) if snap.created_by else None,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
"""Full serialization including technique states."""
base = serialize_snapshot_summary(snap)
technique_states = (
db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
.order_by(SnapshotTechniqueState.mitre_id)
.all()
)
base["technique_states"] = [
{
"mitre_id": s.mitre_id,
"technique_id": str(s.technique_id),
"status": s.status,
"score": s.score,
}
for s in technique_states
]
return base
def list_snapshots(
db: Session,
*,
offset: int = 0,
limit: int = 50,
) -> dict:
"""List coverage snapshots ordered by creation date (newest first)."""
query = db.query(CoverageSnapshot)
total = query.count()
snapshots = (
query
.order_by(CoverageSnapshot.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [serialize_snapshot_summary(s) for s in snapshots],
}
def get_snapshot_or_raise(db: Session, snapshot_id: str) -> CoverageSnapshot:
"""Fetch snapshot by ID or raise EntityNotFoundError."""
try:
sid = uuid.UUID(snapshot_id)
except (ValueError, TypeError):
raise EntityNotFoundError("Snapshot", snapshot_id)
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
if snapshot is None:
raise EntityNotFoundError("Snapshot", snapshot_id)
return snapshot
def get_snapshot_detail(db: Session, snapshot_id: str) -> dict:
"""Get detailed snapshot including per-technique states."""
snapshot = get_snapshot_or_raise(db, snapshot_id)
return serialize_snapshot_detail(db, snapshot)
def delete_snapshot(db: Session, snapshot_id: str) -> None:
"""Delete a snapshot. Does not commit — caller must commit."""
snapshot = get_snapshot_or_raise(db, snapshot_id)
db.delete(snapshot)
# ---------------------------------------------------------------------------
# Create snapshot
# ---------------------------------------------------------------------------
@@ -138,7 +234,7 @@ def compare_snapshots(
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
if not snap_a or not snap_b:
return {"error": "One or both snapshots not found"}
raise EntityNotFoundError("Snapshot", f"{snapshot_a_id} or {snapshot_b_id}")
# Build lookup dicts: mitre_id -> {status, score}
states_a = {

View File

@@ -0,0 +1,48 @@
"""Technique query service — framework-agnostic queries for technique details."""
from __future__ import annotations
from sqlalchemy.orm import Session, joinedload
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique
from app.services.d3fend_import_service import get_defenses_for_technique
def get_technique_detail(db: Session, mitre_id: str) -> dict:
"""Fetch full technique details including tests and D3FEND defenses."""
technique = (
db.query(Technique)
.options(joinedload(Technique.tests))
.filter(Technique.mitre_id == mitre_id)
.first()
)
if technique is None:
raise EntityNotFoundError("Technique", mitre_id)
defenses = get_defenses_for_technique(db, technique.id)
return {
"id": str(technique.id),
"mitre_id": technique.mitre_id,
"name": technique.name,
"description": technique.description,
"tactic": technique.tactic,
"platforms": technique.platforms or [],
"mitre_version": technique.mitre_version,
"mitre_last_modified": technique.mitre_last_modified,
"is_subtechnique": technique.is_subtechnique,
"parent_mitre_id": technique.parent_mitre_id,
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
"review_required": technique.review_required,
"last_review_date": technique.last_review_date,
"tests": [
{
"id": str(t.id),
"name": t.name,
"state": t.state.value if t.state else None,
"result": t.result.value if t.result else None,
"platform": t.platform,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in technique.tests
],
"d3fend_defenses": defenses,
}

View File

@@ -8,6 +8,7 @@ from uuid import UUID
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.worklog import Worklog
logger = logging.getLogger(__name__)
@@ -43,6 +44,14 @@ def create_worklog(
return wl
def get_worklog_or_raise(db: Session, worklog_id: UUID) -> Worklog:
"""Get a worklog by ID or raise EntityNotFoundError."""
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
if not wl:
raise EntityNotFoundError("Worklog", str(worklog_id))
return wl
def list_worklogs(
db: Session,
*,