feat: move all remaining inline logic from routers to services (Tier 2)

This commit is contained in:
2026-02-20 14:34:24 +01:00
parent 9e22fde746
commit 339d669498
17 changed files with 632 additions and 414 deletions

View File

@@ -30,7 +30,7 @@ from app.services.campaign_crud_service import (
serialize_campaign, serialize_campaign,
update_campaign as crud_update, update_campaign as crud_update,
) )
from app.services.notification_service import create_notification from app.services.notification_service import notify_role
from app.services.audit_service import log_action from app.services.audit_service import log_action
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -237,17 +237,15 @@ def activate_campaign(
db.commit() db.commit()
db.refresh(campaign) db.refresh(campaign)
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712 notify_role(
for user in red_techs: db,
create_notification( role="red_tech",
db, type="campaign_activated",
user_id=user.id, title="Campaign activated",
type="campaign_activated", message=f'Campaign "{campaign.name}" has been activated.',
title="Campaign activated", entity_type="campaign",
message=f'Campaign "{campaign.name}" has been activated.', entity_id=campaign.id,
entity_type="campaign", )
entity_id=campaign.id,
)
log_action( log_action(
db, db,

View File

@@ -3,18 +3,20 @@
import logging import logging
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_role from app.dependencies.auth import get_current_user, require_role
from app.models.user import User from app.models.user import User
from app.models.technique import Technique
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
from app.services.d3fend_import_service import ( from app.services.d3fend_import_service import (
import_d3fend_techniques, import_d3fend_techniques,
import_d3fend_mappings, import_d3fend_mappings,
get_defenses_for_technique, )
from app.services.d3fend_query_service import (
list_defensive_techniques as list_defensive_techniques_svc,
list_d3fend_tactics,
get_defenses_for_attack_technique,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -36,38 +38,9 @@ def list_defensive_techniques(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""List all D3FEND defensive techniques with optional filters.""" """List all D3FEND defensive techniques with optional filters."""
query = db.query(DefensiveTechnique) return list_defensive_techniques_svc(
db, tactic=tactic, search=search, offset=offset, limit=limit
if tactic: )
query = query.filter(DefensiveTechnique.tactic == tactic)
if search:
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(
DefensiveTechnique.name.ilike(pattern)
| DefensiveTechnique.d3fend_id.ilike(pattern)
)
total = query.count()
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [
{
"id": str(dt.id),
"d3fend_id": dt.d3fend_id,
"name": dt.name,
"description": dt.description,
"tactic": dt.tactic,
"d3fend_url": dt.d3fend_url,
}
for dt in items
],
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -75,21 +48,12 @@ def list_defensive_techniques(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/tactics") @router.get("/tactics")
def list_d3fend_tactics( def list_d3fend_tactics_endpoint(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Return a list of all D3FEND tactics with counts.""" """Return a list of all D3FEND tactics with counts."""
from sqlalchemy import func return list_d3fend_tactics(db)
rows = (
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
.group_by(DefensiveTechnique.tactic)
.order_by(DefensiveTechnique.tactic)
.all()
)
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -97,24 +61,13 @@ def list_d3fend_tactics(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/for-technique/{mitre_id}") @router.get("/for-technique/{mitre_id}")
def get_defenses_for_attack_technique( def get_defenses_for_attack_technique_endpoint(
mitre_id: str, mitre_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique.""" """Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first() return get_defenses_for_attack_technique(db, mitre_id)
if not technique:
raise HTTPException(status_code=404, detail=f"Technique {mitre_id} not found")
defenses = get_defenses_for_technique(db, technique.id)
return {
"mitre_id": mitre_id,
"technique_name": technique.name,
"defenses": defenses,
"total": len(defenses),
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -7,14 +7,9 @@ from uuid import UUID
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.config import settings
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_role from app.dependencies.auth import get_current_user, require_role
from app.domain.exceptions import EntityNotFoundError from app.models.jira_link import JiraLinkEntityType
from app.models.jira_link import JiraLink, JiraLinkEntityType
from app.models.test import Test
from app.models.technique import Technique
from app.models.campaign import Campaign
from app.models.user import User from app.models.user import User
from app.schemas.jira_schema import ( from app.schemas.jira_schema import (
JiraIssueResult, JiraIssueResult,
@@ -45,23 +40,14 @@ def create_link(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ):
"""Associate an Aegis entity with a Jira issue.""" """Associate an Aegis entity with a Jira issue."""
link = JiraLink( link = jira_service.create_link(
db,
entity_type=body.entity_type, entity_type=body.entity_type,
entity_id=body.entity_id, entity_id=body.entity_id,
jira_issue_key=body.jira_issue_key, jira_issue_key=body.jira_issue_key,
sync_direction=body.sync_direction, sync_direction=body.sync_direction,
created_by=user.id, created_by=user.id,
) )
db.add(link)
db.flush()
# Pull initial data from Jira if enabled
if settings.JIRA_ENABLED:
try:
jira_service.sync_jira_to_aegis(db, link)
except Exception as e:
logger.warning("Initial Jira sync failed for %s: %s", body.jira_issue_key, e)
db.commit() db.commit()
db.refresh(link) db.refresh(link)
@@ -88,12 +74,11 @@ def list_links(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ):
"""List Jira links, optionally filtered by entity.""" """List Jira links, optionally filtered by entity."""
query = db.query(JiraLink) return jira_service.list_links(
if entity_type: db,
query = query.filter(JiraLink.entity_type == entity_type) entity_type=entity_type,
if entity_id: entity_id=entity_id,
query = query.filter(JiraLink.entity_id == entity_id) )
return query.order_by(JiraLink.created_at.desc()).all()
@router.post("/links/{link_id}/sync") @router.post("/links/{link_id}/sync")
@@ -103,9 +88,7 @@ def sync_link(
user: User = Depends(require_role("admin")), user: User = Depends(require_role("admin")),
): ):
"""Force bidirectional sync for a specific Jira link.""" """Force bidirectional sync for a specific Jira link."""
link = db.query(JiraLink).filter(JiraLink.id == link_id).first() link = jira_service.get_link_or_raise(db, link_id)
if not link:
raise EntityNotFoundError("JiraLink", str(link_id))
jira_service.sync_jira_to_aegis(db, link) jira_service.sync_jira_to_aegis(db, link)
db.commit() db.commit()
return {"message": "Sync completed", "jira_status": link.jira_status} return {"message": "Sync completed", "jira_status": link.jira_status}
@@ -118,10 +101,7 @@ def delete_link(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ):
"""Remove a Jira link.""" """Remove a Jira link."""
link = db.query(JiraLink).filter(JiraLink.id == link_id).first() link = jira_service.delete_link(db, link_id)
if not link:
raise EntityNotFoundError("JiraLink", str(link_id))
db.delete(link)
db.commit() db.commit()
audit_service.log_action( audit_service.log_action(
db, db,
@@ -141,61 +121,11 @@ def create_issue_from_entity(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ):
"""Auto-create a Jira issue from an Aegis entity and link them.""" """Auto-create a Jira issue from an Aegis entity and link them."""
summary, description = _build_issue_data(db, entity_type, entity_id) result = jira_service.create_issue_and_link(
result = jira_service.create_jira_issue( db,
project_key=settings.JIRA_DEFAULT_PROJECT,
summary=summary,
description=description,
labels=["aegis", entity_type.value],
)
link = JiraLink(
entity_type=entity_type, entity_type=entity_type,
entity_id=entity_id, entity_id=entity_id,
jira_issue_key=result["issue_key"],
jira_issue_id=result["issue_id"],
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
created_by=user.id, created_by=user.id,
) )
db.add(link)
db.commit() db.commit()
return {"issue_key": result["issue_key"], "link_id": str(link.id)} return result
def _build_issue_data(
db: Session,
entity_type: JiraLinkEntityType,
entity_id: UUID,
) -> tuple[str, str]:
"""Build Jira issue summary + description from an Aegis entity."""
if entity_type == JiraLinkEntityType.test:
entity = db.query(Test).filter(Test.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Test", str(entity_id))
return (
f"[Aegis Test] {entity.name}",
f"Test: {entity.name}\n"
f"State: {entity.state.value if entity.state else 'draft'}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.campaign:
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Campaign", str(entity_id))
return (
f"[Aegis Campaign] {entity.name}",
f"Campaign: {entity.name}\n"
f"Type: {entity.type}\nStatus: {entity.status}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.technique:
entity = db.query(Technique).filter(Technique.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Technique", str(entity_id))
return (
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
f"Tactic: {entity.tactic or 'N/A'}\n"
f"Description: {entity.description or 'N/A'}",
)
else:
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"

View File

@@ -10,16 +10,16 @@ POST /notifications/read-all — mark all as read
import uuid import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user from app.dependencies.auth import get_current_user
from app.domain.unit_of_work import UnitOfWork from app.domain.unit_of_work import UnitOfWork
from app.models.notification import Notification
from app.models.user import User from app.models.user import User
from app.schemas.notification import NotificationOut, UnreadCountOut from app.schemas.notification import NotificationOut, UnreadCountOut
from app.services.notification_service import ( from app.services.notification_service import (
list_notifications,
mark_as_read, mark_as_read,
mark_all_as_read, mark_all_as_read,
get_unread_count, get_unread_count,
@@ -34,22 +34,14 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
@router.get("", response_model=list[NotificationOut]) @router.get("", response_model=list[NotificationOut])
def list_notifications( def list_notifications_endpoint(
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100), limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Return paginated notifications for the current user, newest first.""" """Return paginated notifications for the current user, newest first."""
notifs = ( return list_notifications(db, current_user.id, offset=offset, limit=limit)
db.query(Notification)
.filter(Notification.user_id == current_user.id)
.order_by(Notification.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return notifs
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -80,14 +72,8 @@ def read_notification(
): ):
"""Mark a single notification as read.""" """Mark a single notification as read."""
with UnitOfWork(db) as uow: with UnitOfWork(db) as uow:
success = mark_as_read(db, notification_id, current_user.id) notif = mark_as_read(db, notification_id, current_user.id)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Notification not found",
)
uow.commit() uow.commit()
notif = db.query(Notification).filter(Notification.id == notification_id).first()
return notif return notif

View File

@@ -10,14 +10,14 @@ from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role from app.dependencies.auth import get_current_user, require_any_role
from app.models.osint_item import OsintItem
from app.models.technique import Technique
from app.models.user import User from app.models.user import User
from app.services.osint_enrichment_service import ( from app.services.osint_enrichment_service import (
enrich_technique_with_cves, enrich_technique_with_cves,
get_osint_items_for_technique, get_osint_items_for_technique,
get_osint_summary,
get_technique_or_raise,
list_osint_items as service_list_osint_items,
mark_osint_reviewed, mark_osint_reviewed,
get_unreviewed_count,
) )
router = APIRouter(prefix="/osint", tags=["osint"]) router = APIRouter(prefix="/osint", tags=["osint"])
@@ -56,41 +56,15 @@ def list_osint_items(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ):
"""List OSINT items with optional filters.""" """List OSINT items with optional filters."""
query = db.query(OsintItem) return service_list_osint_items(
if technique_id: db,
query = query.filter(OsintItem.technique_id == technique_id) technique_id=technique_id,
if source_type: source_type=source_type,
query = query.filter(OsintItem.source_type == source_type) reviewed=reviewed,
if reviewed is not None: offset=offset,
query = query.filter(OsintItem.reviewed == reviewed) limit=limit,
total = query.count()
items = (
query.order_by(OsintItem.discovered_at.desc())
.offset(offset)
.limit(limit)
.all()
) )
return {
"total": total,
"items": [
{
"id": str(item.id),
"technique_id": str(item.technique_id),
"source_type": item.source_type,
"source_url": item.source_url,
"title": item.title,
"description": item.description,
"severity": item.severity,
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
"reviewed": item.reviewed,
"metadata": item.metadata_,
}
for item in items
],
}
@router.get("/summary") @router.get("/summary")
def osint_summary( def osint_summary(
@@ -98,34 +72,7 @@ def osint_summary(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ):
"""Summary statistics for OSINT items.""" """Summary statistics for OSINT items."""
from sqlalchemy import func return get_osint_summary(db)
total = db.query(func.count(OsintItem.id)).scalar() or 0
unreviewed = get_unreviewed_count(db)
by_severity = dict(
db.query(OsintItem.severity, func.count(OsintItem.id))
.group_by(OsintItem.severity)
.all()
)
by_type = dict(
db.query(OsintItem.source_type, func.count(OsintItem.id))
.group_by(OsintItem.source_type)
.all()
)
techniques_with_items = (
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
)
return {
"total_items": total,
"unreviewed": unreviewed,
"techniques_with_items": techniques_with_items,
"by_severity": by_severity,
"by_type": by_type,
}
@router.post("/items/{item_id}/review") @router.post("/items/{item_id}/review")
@@ -151,13 +98,7 @@ def trigger_technique_enrichment(
user: User = Depends(require_any_role("red_lead", "blue_lead")), user: User = Depends(require_any_role("red_lead", "blue_lead")),
): ):
"""Manually trigger OSINT enrichment for a single technique.""" """Manually trigger OSINT enrichment for a single technique."""
technique = db.query(Technique).filter(Technique.id == technique_id).first() technique = get_technique_or_raise(db, technique_id)
if not technique:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Technique not found",
)
count = enrich_technique_with_cves(db, technique) count = enrich_technique_with_cves(db, technique)
return { return {
"technique_id": str(technique.id), "technique_id": str(technique.id),

View File

@@ -5,19 +5,17 @@ Provides granular scoring with breakdowns and configurable weights.
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_role from app.dependencies.auth import get_current_user, require_role
from app.models.user import User from app.models.user import User
from app.models.technique import Technique
from app.models.threat_actor import ThreatActor
from app.services.scoring_service import ( from app.services.scoring_service import (
calculate_technique_score, score_technique_by_mitre_id,
score_actor_by_id,
calculate_tactic_score, calculate_tactic_score,
calculate_actor_coverage_score,
calculate_organization_score, calculate_organization_score,
get_score_history, get_score_history,
) )
@@ -39,23 +37,7 @@ def score_technique(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Get detailed score with breakdown for a specific technique.""" """Get detailed score with breakdown for a specific technique."""
technique = ( return score_technique_by_mitre_id(db, mitre_id)
db.query(Technique)
.filter(Technique.mitre_id == mitre_id)
.first()
)
if not technique:
raise HTTPException(status_code=404, detail="Technique not found")
result = calculate_technique_score(technique, db)
return {
"mitre_id": technique.mitre_id,
"name": technique.name,
"tactic": technique.tactic,
"status_global": technique.status_global.value if technique.status_global else None,
**result,
}
# ── GET /scores/tactic/{tactic} ────────────────────────────────────── # ── GET /scores/tactic/{tactic} ──────────────────────────────────────
@@ -81,11 +63,7 @@ def score_threat_actor(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Get coverage score against a specific threat actor.""" """Get coverage score against a specific threat actor."""
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() return score_actor_by_id(db, actor_id)
if not actor:
raise HTTPException(status_code=404, detail="Threat actor not found")
return calculate_actor_coverage_score(actor_id, db)
# ── GET /scores/organization ───────────────────────────────────────── # ── GET /scores/organization ─────────────────────────────────────────

View File

@@ -8,18 +8,24 @@ import logging
import uuid import uuid
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role, require_role from app.dependencies.auth import get_current_user, require_any_role, require_role
from app.domain.errors import BusinessRuleViolation
from app.domain.unit_of_work import UnitOfWork
from app.models.user import User from app.models.user import User
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.services.snapshot_service import ( from app.services.snapshot_service import (
create_snapshot, create_snapshot,
compare_snapshots, compare_snapshots,
cleanup_old_snapshots, cleanup_old_snapshots,
serialize_snapshot_summary,
list_snapshots as list_snapshots_svc,
get_snapshot_or_raise,
get_snapshot_detail,
delete_snapshot,
) )
from app.services.audit_service import log_action from app.services.audit_service import log_action
@@ -34,48 +40,6 @@ class SnapshotCreate(BaseModel):
name: Optional[str] = None name: Optional[str] = None
# ── Helpers ──────────────────────────────────────────────────────────
def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"""Lightweight serialization for list views."""
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_by": str(snap.created_by) if snap.created_by else None,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
"""Full serialization including technique states."""
base = _serialize_snapshot_summary(snap)
technique_states = (
db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
.order_by(SnapshotTechniqueState.mitre_id)
.all()
)
base["technique_states"] = [
{
"mitre_id": s.mitre_id,
"technique_id": str(s.technique_id),
"status": s.status,
"score": s.score,
}
for s in technique_states
]
return base
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# GET /snapshots — List snapshots (paginated) # GET /snapshots — List snapshots (paginated)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -88,23 +52,7 @@ def list_snapshots(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""List coverage snapshots ordered by creation date (newest first).""" """List coverage snapshots ordered by creation date (newest first)."""
query = db.query(CoverageSnapshot) return list_snapshots_svc(db, offset=offset, limit=limit)
total = query.count()
snapshots = (
query
.order_by(CoverageSnapshot.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [_serialize_snapshot_summary(s) for s in snapshots],
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -129,7 +77,7 @@ def create_snapshot_endpoint(
details={"name": snapshot.name, "score": snapshot.organization_score}, details={"name": snapshot.name, "score": snapshot.organization_score},
) )
return _serialize_snapshot_summary(snapshot) return serialize_snapshot_summary(snapshot)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -148,13 +96,9 @@ def compare_snapshots_endpoint(
a_id = uuid.UUID(a) a_id = uuid.UUID(a)
b_id = uuid.UUID(b) b_id = uuid.UUID(b)
except ValueError: except ValueError:
raise HTTPException(status_code=400, detail="Invalid snapshot ID format") raise BusinessRuleViolation("Invalid snapshot ID format")
result = compare_snapshots(db, a_id, b_id) return compare_snapshots(db, a_id, b_id)
if "error" in result:
raise HTTPException(status_code=404, detail=result["error"])
return result
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -168,11 +112,7 @@ def get_snapshot(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Get detailed snapshot information including per-technique states.""" """Get detailed snapshot information including per-technique states."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first() return get_snapshot_detail(db, snapshot_id)
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
return _serialize_snapshot_detail(db, snapshot)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -180,15 +120,13 @@ def get_snapshot(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.delete("/{snapshot_id}") @router.delete("/{snapshot_id}")
def delete_snapshot( def delete_snapshot_endpoint(
snapshot_id: str, snapshot_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")), current_user: User = Depends(require_role("admin")),
): ):
"""Delete a snapshot (admin only).""" """Delete a snapshot (admin only)."""
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first() snapshot = get_snapshot_or_raise(db, snapshot_id)
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
log_action( log_action(
db, db,
@@ -199,7 +137,8 @@ def delete_snapshot(
details={"name": snapshot.name}, details={"name": snapshot.name},
) )
db.delete(snapshot) with UnitOfWork(db) as uow:
db.commit() delete_snapshot(db, snapshot_id)
uow.commit()
return {"detail": "Snapshot deleted"} return {"detail": "Snapshot deleted"}

View File

@@ -6,7 +6,7 @@ exceptions to HTTP responses automatically.
""" """
from fastapi import APIRouter, Depends, Query, status from fastapi import APIRouter, Depends, Query, status
from sqlalchemy.orm import Session, joinedload from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_role, require_any_role from app.dependencies.auth import get_current_user, require_role, require_any_role
@@ -18,7 +18,6 @@ from app.domain.unit_of_work import UnitOfWork
from app.infrastructure.persistence.repositories.sa_technique_repository import ( from app.infrastructure.persistence.repositories.sa_technique_repository import (
SATechniqueRepository, SATechniqueRepository,
) )
from app.models.technique import Technique
from app.models.user import User from app.models.user import User
from app.schemas.technique import ( from app.schemas.technique import (
TechniqueCreate, TechniqueCreate,
@@ -27,7 +26,7 @@ from app.schemas.technique import (
TechniqueUpdate, TechniqueUpdate,
) )
from app.services.audit_service import log_action from app.services.audit_service import log_action
from app.services.d3fend_import_service import get_defenses_for_technique from app.services.technique_query_service import get_technique_detail
router = APIRouter(prefix="/techniques", tags=["techniques"]) router = APIRouter(prefix="/techniques", tags=["techniques"])
@@ -67,45 +66,7 @@ def get_technique(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Return full details for a single technique, including its tests and D3FEND defenses.""" """Return full details for a single technique, including its tests and D3FEND defenses."""
technique = ( return get_technique_detail(db, mitre_id)
db.query(Technique)
.options(joinedload(Technique.tests))
.filter(Technique.mitre_id == mitre_id)
.first()
)
if technique is None:
raise EntityNotFoundError("Technique", mitre_id)
defenses = get_defenses_for_technique(db, technique.id)
return {
"id": str(technique.id),
"mitre_id": technique.mitre_id,
"name": technique.name,
"description": technique.description,
"tactic": technique.tactic,
"platforms": technique.platforms or [],
"mitre_version": technique.mitre_version,
"mitre_last_modified": technique.mitre_last_modified,
"is_subtechnique": technique.is_subtechnique,
"parent_mitre_id": technique.parent_mitre_id,
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
"review_required": technique.review_required,
"last_review_date": technique.last_review_date,
"tests": [
{
"id": str(t.id),
"name": t.name,
"state": t.state.value if t.state else None,
"result": t.result.value if t.result else None,
"platform": t.platform,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in technique.tests
],
"d3fend_defenses": defenses,
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -10,9 +10,7 @@ from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role from app.dependencies.auth import get_current_user, require_any_role
from app.domain.exceptions import EntityNotFoundError
from app.models.user import User from app.models.user import User
from app.models.worklog import Worklog
from app.services import worklog_service from app.services import worklog_service
router = APIRouter(prefix="/worklogs", tags=["worklogs"]) router = APIRouter(prefix="/worklogs", tags=["worklogs"])
@@ -97,10 +95,7 @@ def get_one(
_user: User = Depends(get_current_user), _user: User = Depends(get_current_user),
): ):
"""Get a single worklog by ID.""" """Get a single worklog by ID."""
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first() return worklog_service.get_worklog_or_raise(db, worklog_id)
if not wl:
raise EntityNotFoundError("Worklog", str(worklog_id))
return wl
@router.get("/{worklog_id}/verify") @router.get("/{worklog_id}/verify")
@@ -110,9 +105,7 @@ def verify_integrity(
_user: User = Depends(get_current_user), _user: User = Depends(get_current_user),
): ):
"""Check whether a worklog's integrity hash is still valid.""" """Check whether a worklog's integrity hash is still valid."""
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first() wl = worklog_service.get_worklog_or_raise(db, worklog_id)
if not wl:
raise EntityNotFoundError("Worklog", str(worklog_id))
return { return {
"worklog_id": str(wl.id), "worklog_id": str(wl.id),
"integrity_valid": worklog_service.verify_worklog_integrity(wl), "integrity_valid": worklog_service.verify_worklog_integrity(wl),

View File

@@ -0,0 +1,82 @@
"""D3FEND query service — framework-agnostic queries for defensive techniques."""
from __future__ import annotations
from typing import Optional
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.defensive_technique import DefensiveTechnique
from app.models.technique import Technique
from app.services.d3fend_import_service import get_defenses_for_technique
from app.utils import escape_like
def list_defensive_techniques(
db: Session,
*,
tactic: Optional[str] = None,
search: Optional[str] = None,
offset: int = 0,
limit: int = 50,
) -> dict:
"""List D3FEND defensive techniques with optional filters."""
query = db.query(DefensiveTechnique)
if tactic:
query = query.filter(DefensiveTechnique.tactic == tactic)
if search:
pattern = f"%{escape_like(search)}%"
query = query.filter(
DefensiveTechnique.name.ilike(pattern)
| DefensiveTechnique.d3fend_id.ilike(pattern)
)
total = query.count()
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [
{
"id": str(dt.id),
"d3fend_id": dt.d3fend_id,
"name": dt.name,
"description": dt.description,
"tactic": dt.tactic,
"d3fend_url": dt.d3fend_url,
}
for dt in items
],
}
def list_d3fend_tactics(db: Session) -> list[dict]:
"""Return a list of all D3FEND tactics with counts."""
rows = (
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
.group_by(DefensiveTechnique.tactic)
.order_by(DefensiveTechnique.tactic)
.all()
)
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
def get_defenses_for_attack_technique(db: Session, mitre_id: str) -> dict:
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
if technique is None:
raise EntityNotFoundError("Technique", mitre_id)
defenses = get_defenses_for_technique(db, technique.id)
return {
"mitre_id": mitre_id,
"technique_name": technique.name,
"defenses": defenses,
"total": len(defenses),
}

View File

@@ -3,12 +3,17 @@
import logging import logging
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.config import settings from app.config import settings
from app.domain.errors import EntityNotFoundError
from app.domain.exceptions import InvalidOperationError from app.domain.exceptions import InvalidOperationError
from app.models.jira_link import JiraLink from app.models.campaign import Campaign
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
from app.models.technique import Technique
from app.models.test import Test
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -103,3 +108,128 @@ def _build_sync_comment(data: dict) -> str:
lines.append(f"*{key}:* {value}") lines.append(f"*{key}:* {value}")
lines.append(f"\n_Synced at {datetime.utcnow().isoformat()}_") lines.append(f"\n_Synced at {datetime.utcnow().isoformat()}_")
return "\n".join(lines) return "\n".join(lines)
# ── Link CRUD ────────────────────────────────────────────────────────
def create_link(
db: Session,
*,
entity_type: JiraLinkEntityType,
entity_id: UUID,
jira_issue_key: str,
sync_direction: JiraSyncDirection,
created_by: UUID,
) -> JiraLink:
"""Create a Jira link and optionally pull initial data from Jira."""
link = JiraLink(
entity_type=entity_type,
entity_id=entity_id,
jira_issue_key=jira_issue_key,
sync_direction=sync_direction,
created_by=created_by,
)
db.add(link)
db.flush()
if settings.JIRA_ENABLED:
try:
sync_jira_to_aegis(db, link)
except Exception as e:
logger.warning("Initial Jira sync failed for %s: %s", jira_issue_key, e)
return link
def list_links(
db: Session,
*,
entity_type: Optional[JiraLinkEntityType] = None,
entity_id: Optional[UUID] = None,
) -> list[JiraLink]:
"""List Jira links with optional filters."""
query = db.query(JiraLink)
if entity_type:
query = query.filter(JiraLink.entity_type == entity_type)
if entity_id:
query = query.filter(JiraLink.entity_id == entity_id)
return query.order_by(JiraLink.created_at.desc()).all()
def get_link_or_raise(db: Session, link_id: UUID) -> JiraLink:
"""Get a Jira link by ID or raise EntityNotFoundError."""
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
if not link:
raise EntityNotFoundError("JiraLink", str(link_id))
return link
def delete_link(db: Session, link_id: UUID) -> JiraLink:
"""Delete a Jira link. Returns the deleted link (for audit)."""
link = get_link_or_raise(db, link_id)
db.delete(link)
return link
def build_issue_data(db: Session, entity_type: JiraLinkEntityType, entity_id: UUID) -> tuple[str, str]:
"""Build Jira issue summary and description from an Aegis entity."""
if entity_type == JiraLinkEntityType.test:
entity = db.query(Test).filter(Test.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Test", str(entity_id))
return (
f"[Aegis Test] {entity.name}",
f"Test: {entity.name}\n"
f"State: {entity.state.value if entity.state else 'draft'}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.campaign:
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Campaign", str(entity_id))
return (
f"[Aegis Campaign] {entity.name}",
f"Campaign: {entity.name}\n"
f"Type: {entity.type}\nStatus: {entity.status}\n"
f"Description: {entity.description or 'N/A'}",
)
elif entity_type == JiraLinkEntityType.technique:
entity = db.query(Technique).filter(Technique.id == entity_id).first()
if not entity:
raise EntityNotFoundError("Technique", str(entity_id))
return (
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
f"Tactic: {entity.tactic or 'N/A'}\n"
f"Description: {entity.description or 'N/A'}",
)
else:
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
def create_issue_and_link(
db: Session,
*,
entity_type: JiraLinkEntityType,
entity_id: UUID,
created_by: UUID,
) -> dict:
"""Create a Jira issue from an Aegis entity and link them."""
summary, description = build_issue_data(db, entity_type, entity_id)
result = create_jira_issue(
project_key=settings.JIRA_DEFAULT_PROJECT,
summary=summary,
description=description,
labels=["aegis", entity_type.value],
)
link = JiraLink(
entity_type=entity_type,
entity_id=entity_id,
jira_issue_key=result["issue_key"],
jira_issue_id=result["issue_id"],
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
created_by=created_by,
)
db.add(link)
return {"issue_key": result["issue_key"], "link_id": str(link.id)}

View File

@@ -13,6 +13,7 @@ from datetime import datetime, timedelta
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import func from sqlalchemy import func
from app.domain.errors import EntityNotFoundError
from app.models.notification import Notification from app.models.notification import Notification
from app.models.user import User from app.models.user import User
@@ -22,6 +23,71 @@ from app.models.user import User
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def list_notifications(
db: Session,
user_id: uuid.UUID,
*,
offset: int = 0,
limit: int = 20,
) -> list[Notification]:
"""Return paginated notifications for a user, newest first."""
return (
db.query(Notification)
.filter(Notification.user_id == user_id)
.order_by(Notification.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
def get_notification_or_raise(
db: Session,
notification_id: uuid.UUID,
user_id: uuid.UUID,
) -> Notification:
"""Fetch a notification by ID and user, or raise EntityNotFoundError."""
notif = (
db.query(Notification)
.filter(
Notification.id == notification_id,
Notification.user_id == user_id,
)
.first()
)
if notif is None:
raise EntityNotFoundError("Notification", str(notification_id))
return notif
def notify_role(
db: Session,
*,
role: str,
type: str,
title: str,
message: str,
entity_type: str,
entity_id: uuid.UUID,
) -> None:
"""Send notifications to all active users with a given role."""
users = (
db.query(User)
.filter(User.role == role, User.is_active == True) # noqa: E712
.all()
)
for user in users:
create_notification(
db,
user_id=user.id,
type=type,
title=title,
message=message,
entity_type=entity_type,
entity_id=entity_id,
)
def create_notification( def create_notification(
db: Session, db: Session,
user_id: uuid.UUID, user_id: uuid.UUID,
@@ -45,17 +111,13 @@ def create_notification(
return notif return notif
def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) -> bool: def mark_as_read(
"""Mark a single notification as read. Returns True if updated.""" db: Session, notification_id: uuid.UUID, user_id: uuid.UUID
notif = ( ) -> Notification:
db.query(Notification) """Mark a single notification as read. Returns the notification. Raises EntityNotFoundError if not found."""
.filter(Notification.id == notification_id, Notification.user_id == user_id) notif = get_notification_or_raise(db, notification_id, user_id)
.first()
)
if notif is None:
return False
notif.read = True notif.read = True
return True return notif
def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int: def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:

View File

@@ -7,11 +7,15 @@ Designed to run as a weekly background job. Respects NVD rate limits
import logging import logging
import time import time
from typing import Optional
from uuid import UUID
import requests import requests
from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.config import settings from app.config import settings
from app.domain.errors import EntityNotFoundError
from app.models.osint_item import OsintItem from app.models.osint_item import OsintItem
from app.models.technique import Technique from app.models.technique import Technique
@@ -189,3 +193,87 @@ def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None:
def get_unreviewed_count(db: Session) -> int: def get_unreviewed_count(db: Session) -> int:
"""Return the total number of unreviewed OSINT items.""" """Return the total number of unreviewed OSINT items."""
return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712 return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712
def list_osint_items(
db: Session,
*,
technique_id: Optional[UUID] = None,
source_type: Optional[str] = None,
reviewed: Optional[bool] = None,
offset: int = 0,
limit: int = 50,
) -> dict:
"""List OSINT items with optional filters and pagination."""
query = db.query(OsintItem)
if technique_id:
query = query.filter(OsintItem.technique_id == technique_id)
if source_type:
query = query.filter(OsintItem.source_type == source_type)
if reviewed is not None:
query = query.filter(OsintItem.reviewed == reviewed)
total = query.count()
items = (
query.order_by(OsintItem.discovered_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"items": [
{
"id": str(item.id),
"technique_id": str(item.technique_id),
"source_type": item.source_type,
"source_url": item.source_url,
"title": item.title,
"description": item.description,
"severity": item.severity,
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
"reviewed": item.reviewed,
"metadata": item.metadata_,
}
for item in items
],
}
def get_osint_summary(db: Session) -> dict:
"""Summary statistics for OSINT items."""
total = db.query(func.count(OsintItem.id)).scalar() or 0
unreviewed = get_unreviewed_count(db)
by_severity = dict(
db.query(OsintItem.severity, func.count(OsintItem.id))
.group_by(OsintItem.severity)
.all()
)
by_type = dict(
db.query(OsintItem.source_type, func.count(OsintItem.id))
.group_by(OsintItem.source_type)
.all()
)
techniques_with_items = (
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
)
return {
"total_items": total,
"unreviewed": unreviewed,
"techniques_with_items": techniques_with_items,
"by_severity": by_severity,
"by_type": by_type,
}
def get_technique_or_raise(db: Session, technique_id: UUID) -> Technique:
"""Get a technique by ID or raise EntityNotFoundError."""
technique = db.query(Technique).filter(Technique.id == technique_id).first()
if not technique:
raise EntityNotFoundError("Technique", str(technique_id))
return technique

View File

@@ -15,6 +15,7 @@ from typing import Optional
from sqlalchemy import case, func from sqlalchemy import case, func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique from app.models.technique import Technique
from app.models.test import Test from app.models.test import Test
from app.models.detection_rule import DetectionRule from app.models.detection_rule import DetectionRule
@@ -232,6 +233,29 @@ def bulk_technique_scores(db: Session) -> dict:
# ── Technique-level scoring (single technique — preserved API) ──────── # ── Technique-level scoring (single technique — preserved API) ────────
def score_technique_by_mitre_id(db: Session, mitre_id: str) -> dict:
"""Get detailed score with breakdown for a technique by MITRE ID."""
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
if not technique:
raise EntityNotFoundError("Technique", mitre_id)
result = calculate_technique_score(technique, db)
return {
"mitre_id": technique.mitre_id,
"name": technique.name,
"tactic": technique.tactic,
"status_global": technique.status_global.value if technique.status_global else None,
**result,
}
def score_actor_by_id(db: Session, actor_id: str) -> dict:
"""Get coverage score for a threat actor by ID."""
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
if not actor:
raise EntityNotFoundError("ThreatActor", actor_id)
return calculate_actor_coverage_score(actor_id, db)
def calculate_technique_score(technique: Technique, db: Session) -> dict: def calculate_technique_score(technique: Technique, db: Session) -> dict:
"""Calculate a 0-100 score for a technique with detailed breakdown. """Calculate a 0-100 score for a technique with detailed breakdown.

View File

@@ -14,6 +14,7 @@ from datetime import datetime
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique from app.models.technique import Technique
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
from app.models.enums import TechniqueStatus from app.models.enums import TechniqueStatus
@@ -25,6 +26,101 @@ from app.services.scoring_service import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Serialization and queries
# ---------------------------------------------------------------------------
def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"""Lightweight serialization for list views."""
return {
"id": str(snap.id),
"name": snap.name,
"organization_score": snap.organization_score,
"total_techniques": snap.total_techniques,
"validated_count": snap.validated_count,
"partial_count": snap.partial_count,
"not_covered_count": snap.not_covered_count,
"in_progress_count": snap.in_progress_count,
"not_evaluated_count": snap.not_evaluated_count,
"created_by": str(snap.created_by) if snap.created_by else None,
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
"""Full serialization including technique states."""
base = serialize_snapshot_summary(snap)
technique_states = (
db.query(SnapshotTechniqueState)
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
.order_by(SnapshotTechniqueState.mitre_id)
.all()
)
base["technique_states"] = [
{
"mitre_id": s.mitre_id,
"technique_id": str(s.technique_id),
"status": s.status,
"score": s.score,
}
for s in technique_states
]
return base
def list_snapshots(
db: Session,
*,
offset: int = 0,
limit: int = 50,
) -> dict:
"""List coverage snapshots ordered by creation date (newest first)."""
query = db.query(CoverageSnapshot)
total = query.count()
snapshots = (
query
.order_by(CoverageSnapshot.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [serialize_snapshot_summary(s) for s in snapshots],
}
def get_snapshot_or_raise(db: Session, snapshot_id: str) -> CoverageSnapshot:
"""Fetch snapshot by ID or raise EntityNotFoundError."""
try:
sid = uuid.UUID(snapshot_id)
except (ValueError, TypeError):
raise EntityNotFoundError("Snapshot", snapshot_id)
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
if snapshot is None:
raise EntityNotFoundError("Snapshot", snapshot_id)
return snapshot
def get_snapshot_detail(db: Session, snapshot_id: str) -> dict:
"""Get detailed snapshot including per-technique states."""
snapshot = get_snapshot_or_raise(db, snapshot_id)
return serialize_snapshot_detail(db, snapshot)
def delete_snapshot(db: Session, snapshot_id: str) -> None:
"""Delete a snapshot. Does not commit — caller must commit."""
snapshot = get_snapshot_or_raise(db, snapshot_id)
db.delete(snapshot)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Create snapshot # Create snapshot
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -138,7 +234,7 @@ def compare_snapshots(
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first() snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
if not snap_a or not snap_b: if not snap_a or not snap_b:
return {"error": "One or both snapshots not found"} raise EntityNotFoundError("Snapshot", f"{snapshot_a_id} or {snapshot_b_id}")
# Build lookup dicts: mitre_id -> {status, score} # Build lookup dicts: mitre_id -> {status, score}
states_a = { states_a = {

View File

@@ -0,0 +1,48 @@
"""Technique query service — framework-agnostic queries for technique details."""
from __future__ import annotations
from sqlalchemy.orm import Session, joinedload
from app.domain.errors import EntityNotFoundError
from app.models.technique import Technique
from app.services.d3fend_import_service import get_defenses_for_technique
def get_technique_detail(db: Session, mitre_id: str) -> dict:
"""Fetch full technique details including tests and D3FEND defenses."""
technique = (
db.query(Technique)
.options(joinedload(Technique.tests))
.filter(Technique.mitre_id == mitre_id)
.first()
)
if technique is None:
raise EntityNotFoundError("Technique", mitre_id)
defenses = get_defenses_for_technique(db, technique.id)
return {
"id": str(technique.id),
"mitre_id": technique.mitre_id,
"name": technique.name,
"description": technique.description,
"tactic": technique.tactic,
"platforms": technique.platforms or [],
"mitre_version": technique.mitre_version,
"mitre_last_modified": technique.mitre_last_modified,
"is_subtechnique": technique.is_subtechnique,
"parent_mitre_id": technique.parent_mitre_id,
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
"review_required": technique.review_required,
"last_review_date": technique.last_review_date,
"tests": [
{
"id": str(t.id),
"name": t.name,
"state": t.state.value if t.state else None,
"result": t.result.value if t.result else None,
"platform": t.platform,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in technique.tests
],
"d3fend_defenses": defenses,
}

View File

@@ -8,6 +8,7 @@ from uuid import UUID
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.domain.errors import EntityNotFoundError
from app.models.worklog import Worklog from app.models.worklog import Worklog
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -43,6 +44,14 @@ def create_worklog(
return wl return wl
def get_worklog_or_raise(db: Session, worklog_id: UUID) -> Worklog:
"""Get a worklog by ID or raise EntityNotFoundError."""
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
if not wl:
raise EntityNotFoundError("Worklog", str(worklog_id))
return wl
def list_worklogs( def list_worklogs(
db: Session, db: Session,
*, *,