feat: move all remaining inline logic from routers to services (Tier 2)

This commit is contained in:
2026-02-20 14:34:24 +01:00
parent 9e22fde746
commit 339d669498
17 changed files with 632 additions and 414 deletions

View File

@@ -30,7 +30,7 @@ from app.services.campaign_crud_service import (
serialize_campaign,
update_campaign as crud_update,
)
from app.services.notification_service import create_notification
from app.services.notification_service import notify_role
from app.services.audit_service import log_action
logger = logging.getLogger(__name__)
@@ -237,17 +237,15 @@ def activate_campaign(
db.commit()
db.refresh(campaign)
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
for user in red_techs:
create_notification(
db,
user_id=user.id,
type="campaign_activated",
title="Campaign activated",
message=f'Campaign "{campaign.name}" has been activated.',
entity_type="campaign",
entity_id=campaign.id,
)
notify_role(
db,
role="red_tech",
type="campaign_activated",
title="Campaign activated",
message=f'Campaign "{campaign.name}" has been activated.',
entity_type="campaign",
entity_id=campaign.id,
)
log_action(
db,

View File

@@ -3,18 +3,20 @@
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role
from app.models.user import User
from app.models.technique import Technique
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
from app.services.d3fend_import_service import (
import_d3fend_techniques,
import_d3fend_mappings,
get_defenses_for_technique,
)
from app.services.d3fend_query_service import (
list_defensive_techniques as list_defensive_techniques_svc,
list_d3fend_tactics,
get_defenses_for_attack_technique,
)
logger = logging.getLogger(__name__)
@@ -36,38 +38,9 @@ def list_defensive_techniques(
current_user: User = Depends(get_current_user),
):
"""List all D3FEND defensive techniques with optional filters."""
query = db.query(DefensiveTechnique)
if tactic:
query = query.filter(DefensiveTechnique.tactic == tactic)
if search:
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
DefensiveTechnique.name.ilike(pattern)
| DefensiveTechnique.d3fend_id.ilike(pattern)
)
total = query.count()
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [
{
"id": str(dt.id),
"d3fend_id": dt.d3fend_id,
"name": dt.name,
"description": dt.description,
"tactic": dt.tactic,
"d3fend_url": dt.d3fend_url,
}
for dt in items
],
}
return list_defensive_techniques_svc(
db, tactic=tactic, search=search, offset=offset, limit=limit
)
# ---------------------------------------------------------------------------
@@ -75,21 +48,12 @@ def list_defensive_techniques(
# ---------------------------------------------------------------------------
@router.get("/tactics")
def list_d3fend_tactics(
def list_d3fend_tactics_endpoint(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return a list of all D3FEND tactics with counts."""
from sqlalchemy import func
rows = (
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
.group_by(DefensiveTechnique.tactic)
.order_by(DefensiveTechnique.tactic)
.all()
)
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
return list_d3fend_tactics(db)
# ---------------------------------------------------------------------------
@@ -97,24 +61,13 @@ def list_d3fend_tactics(
# ---------------------------------------------------------------------------
@router.get("/for-technique/{mitre_id}")
def get_defenses_for_attack_technique(
def get_defenses_for_attack_technique_endpoint(
mitre_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
if not technique:
raise HTTPException(status_code=404, detail=f"Technique {mitre_id} not found")
defenses = get_defenses_for_technique(db, technique.id)
return {
"mitre_id": mitre_id,
"technique_name": technique.name,
"defenses": defenses,
"total": len(defenses),
}
return get_defenses_for_attack_technique(db, mitre_id)
# ---------------------------------------------------------------------------

View File

@@ -7,14 +7,9 @@ from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.config import settings
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role
from app.domain.exceptions import EntityNotFoundError
from app.models.jira_link import JiraLink, JiraLinkEntityType
from app.models.test import Test
from app.models.technique import Technique
from app.models.campaign import Campaign
from app.models.jira_link import JiraLinkEntityType
from app.models.user import User
from app.schemas.jira_schema import (
JiraIssueResult,
@@ -45,23 +40,14 @@ def create_link(
user: User = Depends(get_current_user),
):
"""Associate an Aegis entity with a Jira issue."""
link = JiraLink(
link = jira_service.create_link(
db,
entity_type=body.entity_type,
entity_id=body.entity_id,
jira_issue_key=body.jira_issue_key,
sync_direction=body.sync_direction,
created_by=user.id,
)
db.add(link)
db.flush()
# Pull initial data from Jira if enabled
if settings.JIRA_ENABLED:
try:
jira_service.sync_jira_to_aegis(db, link)
except Exception as e:
logger.warning("Initial Jira sync failed for %s: %s", body.jira_issue_key, e)
db.commit()
db.refresh(link)
@@ -88,12 +74,11 @@ def list_links(
user: User = Depends(get_current_user),
):
"""List Jira links, optionally filtered by entity."""
query = db.query(JiraLink)
if entity_type:
query = query.filter(JiraLink.entity_type == entity_type)
if entity_id:
query = query.filter(JiraLink.entity_id == entity_id)
return query.order_by(JiraLink.created_at.desc()).all()
return jira_service.list_links(
db,
entity_type=entity_type,
entity_id=entity_id,
)
@router.post("/links/{link_id}/sync")
@@ -103,9 +88,7 @@ def sync_link(
user: User = Depends(require_role("admin")),
):
"""Force bidirectional sync for a specific Jira link."""
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
if not link:
raise EntityNotFoundError("JiraLink", str(link_id))
link = jira_service.get_link_or_raise(db, link_id)
jira_service.sync_jira_to_aegis(db, link)
db.commit()
return {"message": "Sync completed", "jira_status": link.jira_status}
@@ -118,10 +101,7 @@ def delete_link(
user: User = Depends(get_current_user),
):
"""Remove a Jira link."""
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
if not link:
raise EntityNotFoundError("JiraLink", str(link_id))
db.delete(link)
link = jira_service.delete_link(db, link_id)
db.commit()
audit_service.log_action(
db,
@@ -141,61 +121,11 @@ def create_issue_from_entity(
user: User = Depends(get_current_user),
):
"""Auto-create a Jira issue from an Aegis entity and link them."""
summary, description = _build_issue_data(db, entity_type, entity_id)
result = jira_service.create_jira_issue(
project_key=settings.JIRA_DEFAULT_PROJECT,
summary=summary,
description=description,
labels=["aegis", entity_type.value],
)
link = JiraLink(
result = jira_service.create_issue_and_link(
db,
entity_type=entity_type,
entity_id=entity_id,
jira_issue_key=result["issue_key"],
jira_issue_id=result["issue_id"],
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
created_by=user.id,
)
db.add(link)
db.commit()
return {"issue_key": result["issue_key"], "link_id": str(link.id)}
def _build_issue_data(
db: Session,
entity_type: JiraLinkEntityType,
entity_id: UUID,
) -> tuple[str, str]:
"""Build Jira issue summary + description from an Aegis entity."""
if entity_type == JiraLinkEntityType.test:
entity = db.query(Test).filter(Test.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Test", str(entity_id))
return (
f"[Aegis Test] {entity.name}",
f"Test: {entity.name}\n"
f"State: {entity.state.value if entity.state else 'draft'}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.campaign:
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Campaign", str(entity_id))
return (
f"[Aegis Campaign] {entity.name}",
f"Campaign: {entity.name}\n"
f"Type: {entity.type}\nStatus: {entity.status}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.technique:
entity = db.query(Technique).filter(Technique.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Technique", str(entity_id))
return (
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
f"Tactic: {entity.tactic or 'N/A'}\n"
f"Description: {entity.description or 'N/A'}",
)
else:
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
return result

View File

@@ -10,16 +10,16 @@ POST /notifications/read-all — mark all as read
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user
from app.domain.unit_of_work import UnitOfWork
from app.models.notification import Notification
from app.models.user import User
from app.schemas.notification import NotificationOut, UnreadCountOut
from app.services.notification_service import (
list_notifications,
mark_as_read,
mark_all_as_read,
get_unread_count,
@@ -34,22 +34,14 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
@router.get("", response_model=list[NotificationOut])
def list_notifications(
def list_notifications_endpoint(
offset: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Return paginated notifications for the current user, newest first."""
notifs = (
db.query(Notification)
.filter(Notification.user_id == current_user.id)
.order_by(Notification.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return notifs
return list_notifications(db, current_user.id, offset=offset, limit=limit)
# ---------------------------------------------------------------------------
@@ -80,14 +72,8 @@ def read_notification(
):
"""Mark a single notification as read."""
with UnitOfWork(db) as uow:
success = mark_as_read(db, notification_id, current_user.id)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Notification not found",
)
notif = mark_as_read(db, notification_id, current_user.id)
uow.commit()
notif = db.query(Notification).filter(Notification.id == notification_id).first()
return notif

View File

@@ -10,14 +10,14 @@ from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.models.osint_item import OsintItem
from app.models.technique import Technique
from app.models.user import User
from app.services.osint_enrichment_service import (
enrich_technique_with_cves,
get_osint_items_for_technique,
get_osint_summary,
get_technique_or_raise,
list_osint_items as service_list_osint_items,
mark_osint_reviewed,
get_unreviewed_count,
)
router = APIRouter(prefix="/osint", tags=["osint"])
@@ -56,41 +56,15 @@ def list_osint_items(
user: User = Depends(get_current_user),
):
"""List OSINT items with optional filters."""
query = db.query(OsintItem)
if technique_id:
query = query.filter(OsintItem.technique_id == technique_id)
if source_type:
query = query.filter(OsintItem.source_type == source_type)
if reviewed is not None:
query = query.filter(OsintItem.reviewed == reviewed)
total = query.count()
items = (
query.order_by(OsintItem.discovered_at.desc())
.offset(offset)
.limit(limit)
.all()
return service_list_osint_items(
db,
technique_id=technique_id,
source_type=source_type,
reviewed=reviewed,
offset=offset,
limit=limit,
)
return {
"total": total,
"items": [
{
"id": str(item.id),
"technique_id": str(item.technique_id),
"source_type": item.source_type,
"source_url": item.source_url,
"title": item.title,
"description": item.description,
"severity": item.severity,
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
"reviewed": item.reviewed,
"metadata": item.metadata_,
}
for item in items
],
}
@router.get("/summary")
def osint_summary(
@@ -98,34 +72,7 @@ def osint_summary(
user: User = Depends(get_current_user),
):
"""Summary statistics for OSINT items."""
from sqlalchemy import func
total = db.query(func.count(OsintItem.id)).scalar() or 0
unreviewed = get_unreviewed_count(db)
by_severity = dict(
db.query(OsintItem.severity, func.count(OsintItem.id))
.group_by(OsintItem.severity)
.all()
)
by_type = dict(
db.query(OsintItem.source_type, func.count(OsintItem.id))
.group_by(OsintItem.source_type)
.all()
)
techniques_with_items = (
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
)
return {
"total_items": total,
"unreviewed": unreviewed,
"techniques_with_items": techniques_with_items,
"by_severity": by_severity,
"by_type": by_type,
}
return get_osint_summary(db)
@router.post("/items/{item_id}/review")
@@ -151,13 +98,7 @@ def trigger_technique_enrichment(
user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Manually trigger OSINT enrichment for a single technique."""
technique = db.query(Technique).filter(Technique.id == technique_id).first()
if not technique:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Technique not found",
)
technique = get_technique_or_raise(db, technique_id)
count = enrich_technique_with_cves(db, technique)
return {
"technique_id": str(technique.id),

View File

@@ -5,19 +5,17 @@ Provides granular scoring with breakdowns and configurable weights.
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role
from app.models.user import User
from app.models.technique import Technique
from app.models.threat_actor import ThreatActor
from app.services.scoring_service import (
calculate_technique_score,
score_technique_by_mitre_id,
score_actor_by_id,
calculate_tactic_score,
calculate_actor_coverage_score,
calculate_organization_score,
get_score_history,
)
@@ -39,23 +37,7 @@ def score_technique(
current_user: User = Depends(get_current_user),
):
"""Get detailed score with breakdown for a specific technique."""
technique = (
db.query(Technique)
.filter(Technique.mitre_id == mitre_id)
.first()
)
if not technique:
raise HTTPException(status_code=404, detail="Technique not found")
result = calculate_technique_score(technique, db)
return {
"mitre_id": technique.mitre_id,
"name": technique.name,
"tactic": technique.tactic,
"status_global": technique.status_global.value if technique.status_global else None,
**result,
}
return score_technique_by_mitre_id(db, mitre_id)
# ── GET /scores/tactic/{tactic} ──────────────────────────────────────
@@ -81,11 +63,7 @@ def score_threat_actor(
current_user: User = Depends(get_current_user),
):
"""Get coverage score against a specific threat actor."""
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
if not actor:
raise HTTPException(status_code=404, detail="Threat actor not found")
return calculate_actor_coverage_score(actor_id, db)
return score_actor_by_id(db, actor_id)
# ── GET /scores/organization ─────────────────────────────────────────

View File

@@ -8,18 +8,24 @@ import logging
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role, require_role
from app.domain.errors import BusinessRuleViolation
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.services.snapshot_service import (
create_snapshot,
compare_snapshots,
cleanup_old_snapshots,
serialize_snapshot_summary,
list_snapshots as list_snapshots_svc,
get_snapshot_or_raise,
get_snapshot_detail,
delete_snapshot,
)
from app.services.audit_service import log_action
@@ -34,48 +40,6 @@ class SnapshotCreate(BaseModel):
name: Optional[str] = None
# ── Helpers ──────────────────────────────────────────────────────────
def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"""Lightweight serialization for list views."""
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_by": str(snap.created_by) if snap.created_by else None,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
"""Full serialization including technique states."""
base = _serialize_snapshot_summary(snap)
technique_states = (
db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
.order_by(SnapshotTechniqueState.mitre_id)
.all()
)
base["technique_states"] = [
{
"mitre_id": s.mitre_id,
"technique_id": str(s.technique_id),
"status": s.status,
"score": s.score,
}
for s in technique_states
]
return base
# ---------------------------------------------------------------------------
# GET /snapshots — List snapshots (paginated)
# ---------------------------------------------------------------------------
@@ -88,23 +52,7 @@ def list_snapshots(
current_user: User = Depends(get_current_user),
):
"""List coverage snapshots ordered by creation date (newest first)."""
query = db.query(CoverageSnapshot)
total = query.count()
snapshots = (
query
.order_by(CoverageSnapshot.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [_serialize_snapshot_summary(s) for s in snapshots],
}
return list_snapshots_svc(db, offset=offset, limit=limit)
# ---------------------------------------------------------------------------
@@ -129,7 +77,7 @@ def create_snapshot_endpoint(
details={"name": snapshot.name, "score": snapshot.organization_score},
)
return _serialize_snapshot_summary(snapshot)
return serialize_snapshot_summary(snapshot)
# ---------------------------------------------------------------------------
@@ -148,13 +96,9 @@ def compare_snapshots_endpoint(
a_id = uuid.UUID(a)
b_id = uuid.UUID(b)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid snapshot ID format")
raise BusinessRuleViolation("Invalid snapshot ID format")
result = compare_snapshots(db, a_id, b_id)
if "error" in result:
raise HTTPException(status_code=404, detail=result["error"])
return result
return compare_snapshots(db, a_id, b_id)
# ---------------------------------------------------------------------------
@@ -168,11 +112,7 @@ def get_snapshot(
current_user: User = Depends(get_current_user),
):
"""Get detailed snapshot information including per-technique states."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
return _serialize_snapshot_detail(db, snapshot)
return get_snapshot_detail(db, snapshot_id)
# ---------------------------------------------------------------------------
@@ -180,15 +120,13 @@ def get_snapshot(
# ---------------------------------------------------------------------------
@router.delete("/{snapshot_id}")
def delete_snapshot(
def delete_snapshot_endpoint(
snapshot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Delete a snapshot (admin only)."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
snapshot = get_snapshot_or_raise(db, snapshot_id)
log_action(
db,
@@ -199,7 +137,8 @@ def delete_snapshot(
details={"name": snapshot.name},
)
db.delete(snapshot)
db.commit()
with UnitOfWork(db) as uow:
delete_snapshot(db, snapshot_id)
uow.commit()
return {"detail": "Snapshot deleted"}

View File

@@ -6,7 +6,7 @@ exceptions to HTTP responses automatically.
"""
from fastapi import APIRouter, Depends, Query, status
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_role, require_any_role
@@ -18,7 +18,6 @@ from app.domain.unit_of_work import UnitOfWork
from app.infrastructure.persistence.repositories.sa_technique_repository import (
SATechniqueRepository,
)
from app.models.technique import Technique
from app.models.user import User
from app.schemas.technique import (
TechniqueCreate,
@@ -27,7 +26,7 @@ from app.schemas.technique import (
TechniqueUpdate,
)
from app.services.audit_service import log_action
from app.services.d3fend_import_service import get_defenses_for_technique
from app.services.technique_query_service import get_technique_detail
router = APIRouter(prefix="/techniques", tags=["techniques"])
@@ -67,45 +66,7 @@ def get_technique(
current_user: User = Depends(get_current_user),
):
"""Return full details for a single technique, including its tests and D3FEND defenses."""
technique = (
db.query(Technique)
.options(joinedload(Technique.tests))
.filter(Technique.mitre_id == mitre_id)
.first()
)
if technique is None:
raise EntityNotFoundError("Technique", mitre_id)
defenses = get_defenses_for_technique(db, technique.id)
return {
"id": str(technique.id),
"mitre_id": technique.mitre_id,
"name": technique.name,
"description": technique.description,
"tactic": technique.tactic,
"platforms": technique.platforms or [],
"mitre_version": technique.mitre_version,
"mitre_last_modified": technique.mitre_last_modified,
"is_subtechnique": technique.is_subtechnique,
"parent_mitre_id": technique.parent_mitre_id,
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
"review_required": technique.review_required,
"last_review_date": technique.last_review_date,
"tests": [
{
"id": str(t.id),
"name": t.name,
"state": t.state.value if t.state else None,
"result": t.result.value if t.result else None,
"platform": t.platform,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in technique.tests
],
"d3fend_defenses": defenses,
}
return get_technique_detail(db, mitre_id)
# ---------------------------------------------------------------------------

View File

@@ -10,9 +10,7 @@ from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.domain.exceptions import EntityNotFoundError
from app.models.user import User
from app.models.worklog import Worklog
from app.services import worklog_service
router = APIRouter(prefix="/worklogs", tags=["worklogs"])
@@ -97,10 +95,7 @@ def get_one(
_user: User = Depends(get_current_user),
):
"""Get a single worklog by ID."""
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
if not wl:
raise EntityNotFoundError("Worklog", str(worklog_id))
return wl
return worklog_service.get_worklog_or_raise(db, worklog_id)
@router.get("/{worklog_id}/verify")
@@ -110,9 +105,7 @@ def verify_integrity(
_user: User = Depends(get_current_user),
):
"""Check whether a worklog's integrity hash is still valid."""
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
if not wl:
raise EntityNotFoundError("Worklog", str(worklog_id))
wl = worklog_service.get_worklog_or_raise(db, worklog_id)
return {
"worklog_id": str(wl.id),
"integrity_valid": worklog_service.verify_worklog_integrity(wl),

View File

@@ -0,0 +1,82 @@
"""D3FEND query service — framework-agnostic queries for defensive techniques."""
from __future__ import annotations
from typing import Optional
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.defensive_technique import DefensiveTechnique
from app.models.technique import Technique
from app.services.d3fend_import_service import get_defenses_for_technique
from app.utils import escape_like
def list_defensive_techniques(
db: Session,
*,
tactic: Optional[str] = None,
search: Optional[str] = None,
offset: int = 0,
limit: int = 50,
) -> dict:
"""List D3FEND defensive techniques with optional filters."""
query = db.query(DefensiveTechnique)
if tactic:
query = query.filter(DefensiveTechnique.tactic == tactic)
if search:
pattern = f"%{escape_like(search)}%"
query = query.filter(
DefensiveTechnique.name.ilike(pattern)
| DefensiveTechnique.d3fend_id.ilike(pattern)
)
total = query.count()
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [
{
"id": str(dt.id),
"d3fend_id": dt.d3fend_id,
"name": dt.name,
"description": dt.description,
"tactic": dt.tactic,
"d3fend_url": dt.d3fend_url,
}
for dt in items
],
}
def list_d3fend_tactics(db: Session) -> list[dict]:
"""Return a list of all D3FEND tactics with counts."""
rows = (
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
.group_by(DefensiveTechnique.tactic)
.order_by(DefensiveTechnique.tactic)
.all()
)
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
def get_defenses_for_attack_technique(db: Session, mitre_id: str) -> dict:
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
if technique is None:
raise EntityNotFoundError("Technique", mitre_id)
defenses = get_defenses_for_technique(db, technique.id)
return {
"mitre_id": mitre_id,
"technique_name": technique.name,
"defenses": defenses,
"total": len(defenses),
}

View File

@@ -3,12 +3,17 @@
import logging
from datetime import datetime
from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session
from app.config import settings
from app.domain.errors import EntityNotFoundError
from app.domain.exceptions import InvalidOperationError
from app.models.jira_link import JiraLink
from app.models.campaign import Campaign
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
from app.models.technique import Technique
from app.models.test import Test
logger = logging.getLogger(__name__)
@@ -103,3 +108,128 @@ def _build_sync_comment(data: dict) -> str:
lines.append(f"*{key}:* {value}")
lines.append(f"\n_Synced at {datetime.utcnow().isoformat()}_")
return "\n".join(lines)
# ── Link CRUD ────────────────────────────────────────────────────────
def create_link(
db: Session,
*,
entity_type: JiraLinkEntityType,
entity_id: UUID,
jira_issue_key: str,
sync_direction: JiraSyncDirection,
created_by: UUID,
) -> JiraLink:
"""Create a Jira link and optionally pull initial data from Jira."""
link = JiraLink(
entity_type=entity_type,
entity_id=entity_id,
jira_issue_key=jira_issue_key,
sync_direction=sync_direction,
created_by=created_by,
)
db.add(link)
db.flush()
if settings.JIRA_ENABLED:
try:
sync_jira_to_aegis(db, link)
except Exception as e:
logger.warning("Initial Jira sync failed for %s: %s", jira_issue_key, e)
return link
def list_links(
db: Session,
*,
entity_type: Optional[JiraLinkEntityType] = None,
entity_id: Optional[UUID] = None,
) -> list[JiraLink]:
"""List Jira links with optional filters."""
query = db.query(JiraLink)
if entity_type:
query = query.filter(JiraLink.entity_type == entity_type)
if entity_id:
query = query.filter(JiraLink.entity_id == entity_id)
return query.order_by(JiraLink.created_at.desc()).all()
def get_link_or_raise(db: Session, link_id: UUID) -> JiraLink:
"""Get a Jira link by ID or raise EntityNotFoundError."""
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
if not link:
raise EntityNotFoundError("JiraLink", str(link_id))
return link
def delete_link(db: Session, link_id: UUID) -> JiraLink:
"""Delete a Jira link. Returns the deleted link (for audit)."""
link = get_link_or_raise(db, link_id)
db.delete(link)
return link
def build_issue_data(db: Session, entity_type: JiraLinkEntityType, entity_id: UUID) -> tuple[str, str]:
"""Build Jira issue summary and description from an Aegis entity."""
if entity_type == JiraLinkEntityType.test:
entity = db.query(Test).filter(Test.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Test", str(entity_id))
return (
f"[Aegis Test] {entity.name}",
f"Test: {entity.name}\n"
f"State: {entity.state.value if entity.state else 'draft'}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.campaign:
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Campaign", str(entity_id))
return (
f"[Aegis Campaign] {entity.name}",
f"Campaign: {entity.name}\n"
f"Type: {entity.type}\nStatus: {entity.status}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.technique:
entity = db.query(Technique).filter(Technique.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Technique", str(entity_id))
return (
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
f"Tactic: {entity.tactic or 'N/A'}\n"
f"Description: {entity.description or 'N/A'}",
)
else:
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
def create_issue_and_link(
db: Session,
*,
entity_type: JiraLinkEntityType,
entity_id: UUID,
created_by: UUID,
) -> dict:
"""Create a Jira issue from an Aegis entity and link them."""
summary, description = build_issue_data(db, entity_type, entity_id)
result = create_jira_issue(
project_key=settings.JIRA_DEFAULT_PROJECT,
summary=summary,
description=description,
labels=["aegis", entity_type.value],
)
link = JiraLink(
entity_type=entity_type,
entity_id=entity_id,
jira_issue_key=result["issue_key"],
jira_issue_id=result["issue_id"],
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
created_by=created_by,
)
db.add(link)
return {"issue_key": result["issue_key"], "link_id": str(link.id)}

View File

@@ -13,6 +13,7 @@ from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.domain.errors import EntityNotFoundError
from app.models.notification import Notification
from app.models.user import User
@@ -22,6 +23,71 @@ from app.models.user import User
# ---------------------------------------------------------------------------
def list_notifications(
db: Session,
user_id: uuid.UUID,
*,
offset: int = 0,
limit: int = 20,
) -> list[Notification]:
"""Return paginated notifications for a user, newest first."""
return (
db.query(Notification)
.filter(Notification.user_id == user_id)
.order_by(Notification.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
def get_notification_or_raise(
db: Session,
notification_id: uuid.UUID,
user_id: uuid.UUID,
) -> Notification:
"""Fetch a notification by ID and user, or raise EntityNotFoundError."""
notif = (
db.query(Notification)
.filter(
Notification.id == notification_id,
Notification.user_id == user_id,
)
.first()
)
if notif is None:
raise EntityNotFoundError("Notification", str(notification_id))
return notif
def notify_role(
db: Session,
*,
role: str,
type: str,
title: str,
message: str,
entity_type: str,
entity_id: uuid.UUID,
) -> None:
"""Send notifications to all active users with a given role."""
users = (
db.query(User)
.filter(User.role == role, User.is_active == True) # noqa: E712
.all()
)
for user in users:
create_notification(
db,
user_id=user.id,
type=type,
title=title,
message=message,
entity_type=entity_type,
entity_id=entity_id,
)
def create_notification(
db: Session,
user_id: uuid.UUID,
@@ -45,17 +111,13 @@ def create_notification(
return notif
def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) -> bool:
"""Mark a single notification as read. Returns True if updated."""
notif = (
db.query(Notification)
.filter(Notification.id == notification_id, Notification.user_id == user_id)
.first()
)
if notif is None:
return False
def mark_as_read(
db: Session, notification_id: uuid.UUID, user_id: uuid.UUID
) -> Notification:
"""Mark a single notification as read. Returns the notification. Raises EntityNotFoundError if not found."""
notif = get_notification_or_raise(db, notification_id, user_id)
notif.read = True
return True
return notif
def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:

View File

@@ -7,11 +7,15 @@ Designed to run as a weekly background job. Respects NVD rate limits
import logging
import time
from typing import Optional
from uuid import UUID
import requests
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.config import settings
from app.domain.errors import EntityNotFoundError
from app.models.osint_item import OsintItem
from app.models.technique import Technique
@@ -189,3 +193,87 @@ def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None:
def get_unreviewed_count(db: Session) -> int:
"""Return the total number of unreviewed OSINT items."""
return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712
def list_osint_items(
db: Session,
*,
technique_id: Optional[UUID] = None,
source_type: Optional[str] = None,
reviewed: Optional[bool] = None,
offset: int = 0,
limit: int = 50,
) -> dict:
"""List OSINT items with optional filters and pagination."""
query = db.query(OsintItem)
if technique_id:
query = query.filter(OsintItem.technique_id == technique_id)
if source_type:
query = query.filter(OsintItem.source_type == source_type)
if reviewed is not None:
query = query.filter(OsintItem.reviewed == reviewed)
total = query.count()
items = (
query.order_by(OsintItem.discovered_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"items": [
{
"id": str(item.id),
"technique_id": str(item.technique_id),
"source_type": item.source_type,
"source_url": item.source_url,
"title": item.title,
"description": item.description,
"severity": item.severity,
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
"reviewed": item.reviewed,
"metadata": item.metadata_,
}
for item in items
],
}
def get_osint_summary(db: Session) -> dict:
"""Summary statistics for OSINT items."""
total = db.query(func.count(OsintItem.id)).scalar() or 0
unreviewed = get_unreviewed_count(db)
by_severity = dict(
db.query(OsintItem.severity, func.count(OsintItem.id))
.group_by(OsintItem.severity)
.all()
)
by_type = dict(
db.query(OsintItem.source_type, func.count(OsintItem.id))
.group_by(OsintItem.source_type)
.all()
)
techniques_with_items = (
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
)
return {
"total_items": total,
"unreviewed": unreviewed,
"techniques_with_items": techniques_with_items,
"by_severity": by_severity,
"by_type": by_type,
}
def get_technique_or_raise(db: Session, technique_id: UUID) -> Technique:
"""Get a technique by ID or raise EntityNotFoundError."""
technique = db.query(Technique).filter(Technique.id == technique_id).first()
if not technique:
raise EntityNotFoundError("Technique", str(technique_id))
return technique

View File

@@ -15,6 +15,7 @@ from typing import Optional
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique
from app.models.test import Test
from app.models.detection_rule import DetectionRule
@@ -232,6 +233,29 @@ def bulk_technique_scores(db: Session) -> dict:
# ── Technique-level scoring (single technique — preserved API) ────────
def score_technique_by_mitre_id(db: Session, mitre_id: str) -> dict:
"""Get detailed score with breakdown for a technique by MITRE ID."""
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
if not technique:
raise EntityNotFoundError("Technique", mitre_id)
result = calculate_technique_score(technique, db)
return {
"mitre_id": technique.mitre_id,
"name": technique.name,
"tactic": technique.tactic,
"status_global": technique.status_global.value if technique.status_global else None,
**result,
}
def score_actor_by_id(db: Session, actor_id: str) -> dict:
"""Get coverage score for a threat actor by ID."""
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
if not actor:
raise EntityNotFoundError("ThreatActor", actor_id)
return calculate_actor_coverage_score(actor_id, db)
def calculate_technique_score(technique: Technique, db: Session) -> dict:
"""Calculate a 0-100 score for a technique with detailed breakdown.

View File

@@ -14,6 +14,7 @@ from datetime import datetime
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TechniqueStatus
@@ -25,6 +26,101 @@ from app.services.scoring_service import (
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Serialization and queries
# ---------------------------------------------------------------------------
def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"""Lightweight serialization for list views."""
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_by": str(snap.created_by) if snap.created_by else None,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
"""Full serialization including technique states."""
base = serialize_snapshot_summary(snap)
technique_states = (
db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
.order_by(SnapshotTechniqueState.mitre_id)
.all()
)
base["technique_states"] = [
{
"mitre_id": s.mitre_id,
"technique_id": str(s.technique_id),
"status": s.status,
"score": s.score,
}
for s in technique_states
]
return base
def list_snapshots(
db: Session,
*,
offset: int = 0,
limit: int = 50,
) -> dict:
"""List coverage snapshots ordered by creation date (newest first)."""
query = db.query(CoverageSnapshot)
total = query.count()
snapshots = (
query
.order_by(CoverageSnapshot.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [serialize_snapshot_summary(s) for s in snapshots],
}
def get_snapshot_or_raise(db: Session, snapshot_id: str) -> CoverageSnapshot:
"""Fetch snapshot by ID or raise EntityNotFoundError."""
try:
sid = uuid.UUID(snapshot_id)
except (ValueError, TypeError):
raise EntityNotFoundError("Snapshot", snapshot_id)
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
if snapshot is None:
raise EntityNotFoundError("Snapshot", snapshot_id)
return snapshot
def get_snapshot_detail(db: Session, snapshot_id: str) -> dict:
"""Get detailed snapshot including per-technique states."""
snapshot = get_snapshot_or_raise(db, snapshot_id)
return serialize_snapshot_detail(db, snapshot)
def delete_snapshot(db: Session, snapshot_id: str) -> None:
"""Delete a snapshot. Does not commit — caller must commit."""
snapshot = get_snapshot_or_raise(db, snapshot_id)
db.delete(snapshot)
# ---------------------------------------------------------------------------
# Create snapshot
# ---------------------------------------------------------------------------
@@ -138,7 +234,7 @@ def compare_snapshots(
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
if not snap_a or not snap_b:
return {"error": "One or both snapshots not found"}
raise EntityNotFoundError("Snapshot", f"{snapshot_a_id} or {snapshot_b_id}")
# Build lookup dicts: mitre_id -> {status, score}
states_a = {

View File

@@ -0,0 +1,48 @@
"""Technique query service — framework-agnostic queries for technique details."""
from __future__ import annotations
from sqlalchemy.orm import Session, joinedload
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique
from app.services.d3fend_import_service import get_defenses_for_technique
def get_technique_detail(db: Session, mitre_id: str) -> dict:
"""Fetch full technique details including tests and D3FEND defenses."""
technique = (
db.query(Technique)
.options(joinedload(Technique.tests))
.filter(Technique.mitre_id == mitre_id)
.first()
)
if technique is None:
raise EntityNotFoundError("Technique", mitre_id)
defenses = get_defenses_for_technique(db, technique.id)
return {
"id": str(technique.id),
"mitre_id": technique.mitre_id,
"name": technique.name,
"description": technique.description,
"tactic": technique.tactic,
"platforms": technique.platforms or [],
"mitre_version": technique.mitre_version,
"mitre_last_modified": technique.mitre_last_modified,
"is_subtechnique": technique.is_subtechnique,
"parent_mitre_id": technique.parent_mitre_id,
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
"review_required": technique.review_required,
"last_review_date": technique.last_review_date,
"tests": [
{
"id": str(t.id),
"name": t.name,
"state": t.state.value if t.state else None,
"result": t.result.value if t.result else None,
"platform": t.platform,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in technique.tests
],
"d3fend_defenses": defenses,
}

View File

@@ -8,6 +8,7 @@ from uuid import UUID
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.worklog import Worklog
logger = logging.getLogger(__name__)
@@ -43,6 +44,14 @@ def create_worklog(
return wl
def get_worklog_or_raise(db: Session, worklog_id: UUID) -> Worklog:
"""Get a worklog by ID or raise EntityNotFoundError."""
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
if not wl:
raise EntityNotFoundError("Worklog", str(worklog_id))
return wl
def list_worklogs(
db: Session,
*,