feat: move all remaining inline logic from routers to services (Tier 2)
This commit is contained in:
@@ -30,7 +30,7 @@ from app.services.campaign_crud_service import (
|
||||
serialize_campaign,
|
||||
update_campaign as crud_update,
|
||||
)
|
||||
from app.services.notification_service import create_notification
|
||||
from app.services.notification_service import notify_role
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -237,17 +237,15 @@ def activate_campaign(
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
|
||||
for user in red_techs:
|
||||
create_notification(
|
||||
db,
|
||||
user_id=user.id,
|
||||
type="campaign_activated",
|
||||
title="Campaign activated",
|
||||
message=f'Campaign "{campaign.name}" has been activated.',
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
)
|
||||
notify_role(
|
||||
db,
|
||||
role="red_tech",
|
||||
type="campaign_activated",
|
||||
title="Campaign activated",
|
||||
message=f'Campaign "{campaign.name}" has been activated.',
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
|
||||
@@ -3,18 +3,20 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
from app.models.user import User
|
||||
from app.models.technique import Technique
|
||||
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
|
||||
from app.services.d3fend_import_service import (
|
||||
import_d3fend_techniques,
|
||||
import_d3fend_mappings,
|
||||
get_defenses_for_technique,
|
||||
)
|
||||
from app.services.d3fend_query_service import (
|
||||
list_defensive_techniques as list_defensive_techniques_svc,
|
||||
list_d3fend_tactics,
|
||||
get_defenses_for_attack_technique,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -36,38 +38,9 @@ def list_defensive_techniques(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List all D3FEND defensive techniques with optional filters."""
|
||||
query = db.query(DefensiveTechnique)
|
||||
|
||||
if tactic:
|
||||
query = query.filter(DefensiveTechnique.tactic == tactic)
|
||||
|
||||
if search:
|
||||
from app.utils import escape_like
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(
|
||||
DefensiveTechnique.name.ilike(pattern)
|
||||
| DefensiveTechnique.d3fend_id.ilike(pattern)
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [
|
||||
{
|
||||
"id": str(dt.id),
|
||||
"d3fend_id": dt.d3fend_id,
|
||||
"name": dt.name,
|
||||
"description": dt.description,
|
||||
"tactic": dt.tactic,
|
||||
"d3fend_url": dt.d3fend_url,
|
||||
}
|
||||
for dt in items
|
||||
],
|
||||
}
|
||||
return list_defensive_techniques_svc(
|
||||
db, tactic=tactic, search=search, offset=offset, limit=limit
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -75,21 +48,12 @@ def list_defensive_techniques(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/tactics")
|
||||
def list_d3fend_tactics(
|
||||
def list_d3fend_tactics_endpoint(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return a list of all D3FEND tactics with counts."""
|
||||
from sqlalchemy import func
|
||||
|
||||
rows = (
|
||||
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
|
||||
.group_by(DefensiveTechnique.tactic)
|
||||
.order_by(DefensiveTechnique.tactic)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
|
||||
return list_d3fend_tactics(db)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -97,24 +61,13 @@ def list_d3fend_tactics(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/for-technique/{mitre_id}")
|
||||
def get_defenses_for_attack_technique(
|
||||
def get_defenses_for_attack_technique_endpoint(
|
||||
mitre_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
||||
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
|
||||
if not technique:
|
||||
raise HTTPException(status_code=404, detail=f"Technique {mitre_id} not found")
|
||||
|
||||
defenses = get_defenses_for_technique(db, technique.id)
|
||||
|
||||
return {
|
||||
"mitre_id": mitre_id,
|
||||
"technique_name": technique.name,
|
||||
"defenses": defenses,
|
||||
"total": len(defenses),
|
||||
}
|
||||
return get_defenses_for_attack_technique(db, mitre_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -7,14 +7,9 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
from app.domain.exceptions import EntityNotFoundError
|
||||
from app.models.jira_link import JiraLink, JiraLinkEntityType
|
||||
from app.models.test import Test
|
||||
from app.models.technique import Technique
|
||||
from app.models.campaign import Campaign
|
||||
from app.models.jira_link import JiraLinkEntityType
|
||||
from app.models.user import User
|
||||
from app.schemas.jira_schema import (
|
||||
JiraIssueResult,
|
||||
@@ -45,23 +40,14 @@ def create_link(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Associate an Aegis entity with a Jira issue."""
|
||||
link = JiraLink(
|
||||
link = jira_service.create_link(
|
||||
db,
|
||||
entity_type=body.entity_type,
|
||||
entity_id=body.entity_id,
|
||||
jira_issue_key=body.jira_issue_key,
|
||||
sync_direction=body.sync_direction,
|
||||
created_by=user.id,
|
||||
)
|
||||
db.add(link)
|
||||
db.flush()
|
||||
|
||||
# Pull initial data from Jira if enabled
|
||||
if settings.JIRA_ENABLED:
|
||||
try:
|
||||
jira_service.sync_jira_to_aegis(db, link)
|
||||
except Exception as e:
|
||||
logger.warning("Initial Jira sync failed for %s: %s", body.jira_issue_key, e)
|
||||
|
||||
db.commit()
|
||||
db.refresh(link)
|
||||
|
||||
@@ -88,12 +74,11 @@ def list_links(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List Jira links, optionally filtered by entity."""
|
||||
query = db.query(JiraLink)
|
||||
if entity_type:
|
||||
query = query.filter(JiraLink.entity_type == entity_type)
|
||||
if entity_id:
|
||||
query = query.filter(JiraLink.entity_id == entity_id)
|
||||
return query.order_by(JiraLink.created_at.desc()).all()
|
||||
return jira_service.list_links(
|
||||
db,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/links/{link_id}/sync")
|
||||
@@ -103,9 +88,7 @@ def sync_link(
|
||||
user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Force bidirectional sync for a specific Jira link."""
|
||||
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
|
||||
if not link:
|
||||
raise EntityNotFoundError("JiraLink", str(link_id))
|
||||
link = jira_service.get_link_or_raise(db, link_id)
|
||||
jira_service.sync_jira_to_aegis(db, link)
|
||||
db.commit()
|
||||
return {"message": "Sync completed", "jira_status": link.jira_status}
|
||||
@@ -118,10 +101,7 @@ def delete_link(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Remove a Jira link."""
|
||||
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
|
||||
if not link:
|
||||
raise EntityNotFoundError("JiraLink", str(link_id))
|
||||
db.delete(link)
|
||||
link = jira_service.delete_link(db, link_id)
|
||||
db.commit()
|
||||
audit_service.log_action(
|
||||
db,
|
||||
@@ -141,61 +121,11 @@ def create_issue_from_entity(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Auto-create a Jira issue from an Aegis entity and link them."""
|
||||
summary, description = _build_issue_data(db, entity_type, entity_id)
|
||||
result = jira_service.create_jira_issue(
|
||||
project_key=settings.JIRA_DEFAULT_PROJECT,
|
||||
summary=summary,
|
||||
description=description,
|
||||
labels=["aegis", entity_type.value],
|
||||
)
|
||||
link = JiraLink(
|
||||
result = jira_service.create_issue_and_link(
|
||||
db,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
jira_issue_key=result["issue_key"],
|
||||
jira_issue_id=result["issue_id"],
|
||||
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
|
||||
created_by=user.id,
|
||||
)
|
||||
db.add(link)
|
||||
db.commit()
|
||||
return {"issue_key": result["issue_key"], "link_id": str(link.id)}
|
||||
|
||||
|
||||
def _build_issue_data(
|
||||
db: Session,
|
||||
entity_type: JiraLinkEntityType,
|
||||
entity_id: UUID,
|
||||
) -> tuple[str, str]:
|
||||
"""Build Jira issue summary + description from an Aegis entity."""
|
||||
if entity_type == JiraLinkEntityType.test:
|
||||
entity = db.query(Test).filter(Test.id == entity_id).first()
|
||||
if not entity:
|
||||
raise EntityNotFoundError("Test", str(entity_id))
|
||||
return (
|
||||
f"[Aegis Test] {entity.name}",
|
||||
f"Test: {entity.name}\n"
|
||||
f"State: {entity.state.value if entity.state else 'draft'}\n"
|
||||
f"Description: {entity.description or 'N/A'}",
|
||||
)
|
||||
elif entity_type == JiraLinkEntityType.campaign:
|
||||
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
|
||||
if not entity:
|
||||
raise EntityNotFoundError("Campaign", str(entity_id))
|
||||
return (
|
||||
f"[Aegis Campaign] {entity.name}",
|
||||
f"Campaign: {entity.name}\n"
|
||||
f"Type: {entity.type}\nStatus: {entity.status}\n"
|
||||
f"Description: {entity.description or 'N/A'}",
|
||||
)
|
||||
elif entity_type == JiraLinkEntityType.technique:
|
||||
entity = db.query(Technique).filter(Technique.id == entity_id).first()
|
||||
if not entity:
|
||||
raise EntityNotFoundError("Technique", str(entity_id))
|
||||
return (
|
||||
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
|
||||
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
|
||||
f"Tactic: {entity.tactic or 'N/A'}\n"
|
||||
f"Description: {entity.description or 'N/A'}",
|
||||
)
|
||||
else:
|
||||
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
|
||||
return result
|
||||
|
||||
@@ -10,16 +10,16 @@ POST /notifications/read-all — mark all as read
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.notification import Notification
|
||||
from app.models.user import User
|
||||
from app.schemas.notification import NotificationOut, UnreadCountOut
|
||||
from app.services.notification_service import (
|
||||
list_notifications,
|
||||
mark_as_read,
|
||||
mark_all_as_read,
|
||||
get_unread_count,
|
||||
@@ -34,22 +34,14 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[NotificationOut])
|
||||
def list_notifications(
|
||||
def list_notifications_endpoint(
|
||||
offset: int = Query(0, ge=0),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return paginated notifications for the current user, newest first."""
|
||||
notifs = (
|
||||
db.query(Notification)
|
||||
.filter(Notification.user_id == current_user.id)
|
||||
.order_by(Notification.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return notifs
|
||||
return list_notifications(db, current_user.id, offset=offset, limit=limit)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -80,14 +72,8 @@ def read_notification(
|
||||
):
|
||||
"""Mark a single notification as read."""
|
||||
with UnitOfWork(db) as uow:
|
||||
success = mark_as_read(db, notification_id, current_user.id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Notification not found",
|
||||
)
|
||||
notif = mark_as_read(db, notification_id, current_user.id)
|
||||
uow.commit()
|
||||
notif = db.query(Notification).filter(Notification.id == notification_id).first()
|
||||
return notif
|
||||
|
||||
|
||||
|
||||
@@ -10,14 +10,14 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.osint_item import OsintItem
|
||||
from app.models.technique import Technique
|
||||
from app.models.user import User
|
||||
from app.services.osint_enrichment_service import (
|
||||
enrich_technique_with_cves,
|
||||
get_osint_items_for_technique,
|
||||
get_osint_summary,
|
||||
get_technique_or_raise,
|
||||
list_osint_items as service_list_osint_items,
|
||||
mark_osint_reviewed,
|
||||
get_unreviewed_count,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/osint", tags=["osint"])
|
||||
@@ -56,41 +56,15 @@ def list_osint_items(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List OSINT items with optional filters."""
|
||||
query = db.query(OsintItem)
|
||||
if technique_id:
|
||||
query = query.filter(OsintItem.technique_id == technique_id)
|
||||
if source_type:
|
||||
query = query.filter(OsintItem.source_type == source_type)
|
||||
if reviewed is not None:
|
||||
query = query.filter(OsintItem.reviewed == reviewed)
|
||||
|
||||
total = query.count()
|
||||
items = (
|
||||
query.order_by(OsintItem.discovered_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
return service_list_osint_items(
|
||||
db,
|
||||
technique_id=technique_id,
|
||||
source_type=source_type,
|
||||
reviewed=reviewed,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"items": [
|
||||
{
|
||||
"id": str(item.id),
|
||||
"technique_id": str(item.technique_id),
|
||||
"source_type": item.source_type,
|
||||
"source_url": item.source_url,
|
||||
"title": item.title,
|
||||
"description": item.description,
|
||||
"severity": item.severity,
|
||||
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
|
||||
"reviewed": item.reviewed,
|
||||
"metadata": item.metadata_,
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
def osint_summary(
|
||||
@@ -98,34 +72,7 @@ def osint_summary(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Summary statistics for OSINT items."""
|
||||
from sqlalchemy import func
|
||||
|
||||
total = db.query(func.count(OsintItem.id)).scalar() or 0
|
||||
unreviewed = get_unreviewed_count(db)
|
||||
|
||||
by_severity = dict(
|
||||
db.query(OsintItem.severity, func.count(OsintItem.id))
|
||||
.group_by(OsintItem.severity)
|
||||
.all()
|
||||
)
|
||||
|
||||
by_type = dict(
|
||||
db.query(OsintItem.source_type, func.count(OsintItem.id))
|
||||
.group_by(OsintItem.source_type)
|
||||
.all()
|
||||
)
|
||||
|
||||
techniques_with_items = (
|
||||
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"total_items": total,
|
||||
"unreviewed": unreviewed,
|
||||
"techniques_with_items": techniques_with_items,
|
||||
"by_severity": by_severity,
|
||||
"by_type": by_type,
|
||||
}
|
||||
return get_osint_summary(db)
|
||||
|
||||
|
||||
@router.post("/items/{item_id}/review")
|
||||
@@ -151,13 +98,7 @@ def trigger_technique_enrichment(
|
||||
user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Manually trigger OSINT enrichment for a single technique."""
|
||||
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||
if not technique:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Technique not found",
|
||||
)
|
||||
|
||||
technique = get_technique_or_raise(db, technique_id)
|
||||
count = enrich_technique_with_cves(db, technique)
|
||||
return {
|
||||
"technique_id": str(technique.id),
|
||||
|
||||
@@ -5,19 +5,17 @@ Provides granular scoring with breakdowns and configurable weights.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_role
|
||||
from app.models.user import User
|
||||
from app.models.technique import Technique
|
||||
from app.models.threat_actor import ThreatActor
|
||||
from app.services.scoring_service import (
|
||||
calculate_technique_score,
|
||||
score_technique_by_mitre_id,
|
||||
score_actor_by_id,
|
||||
calculate_tactic_score,
|
||||
calculate_actor_coverage_score,
|
||||
calculate_organization_score,
|
||||
get_score_history,
|
||||
)
|
||||
@@ -39,23 +37,7 @@ def score_technique(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get detailed score with breakdown for a specific technique."""
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.filter(Technique.mitre_id == mitre_id)
|
||||
.first()
|
||||
)
|
||||
if not technique:
|
||||
raise HTTPException(status_code=404, detail="Technique not found")
|
||||
|
||||
result = calculate_technique_score(technique, db)
|
||||
|
||||
return {
|
||||
"mitre_id": technique.mitre_id,
|
||||
"name": technique.name,
|
||||
"tactic": technique.tactic,
|
||||
"status_global": technique.status_global.value if technique.status_global else None,
|
||||
**result,
|
||||
}
|
||||
return score_technique_by_mitre_id(db, mitre_id)
|
||||
|
||||
|
||||
# ── GET /scores/tactic/{tactic} ──────────────────────────────────────
|
||||
@@ -81,11 +63,7 @@ def score_threat_actor(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get coverage score against a specific threat actor."""
|
||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||
if not actor:
|
||||
raise HTTPException(status_code=404, detail="Threat actor not found")
|
||||
|
||||
return calculate_actor_coverage_score(actor_id, db)
|
||||
return score_actor_by_id(db, actor_id)
|
||||
|
||||
|
||||
# ── GET /scores/organization ─────────────────────────────────────────
|
||||
|
||||
@@ -8,18 +8,24 @@ import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||
from app.domain.errors import BusinessRuleViolation
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.models.user import User
|
||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||
from app.services.snapshot_service import (
|
||||
create_snapshot,
|
||||
compare_snapshots,
|
||||
cleanup_old_snapshots,
|
||||
serialize_snapshot_summary,
|
||||
list_snapshots as list_snapshots_svc,
|
||||
get_snapshot_or_raise,
|
||||
get_snapshot_detail,
|
||||
delete_snapshot,
|
||||
)
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
@@ -34,48 +40,6 @@ class SnapshotCreate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
|
||||
"""Lightweight serialization for list views."""
|
||||
return {
|
||||
"id": str(snap.id),
|
||||
"name": snap.name,
|
||||
"organization_score": snap.organization_score,
|
||||
"total_techniques": snap.total_techniques,
|
||||
"validated_count": snap.validated_count,
|
||||
"partial_count": snap.partial_count,
|
||||
"not_covered_count": snap.not_covered_count,
|
||||
"in_progress_count": snap.in_progress_count,
|
||||
"not_evaluated_count": snap.not_evaluated_count,
|
||||
"created_by": str(snap.created_by) if snap.created_by else None,
|
||||
"created_at": snap.created_at.isoformat() if snap.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
|
||||
"""Full serialization including technique states."""
|
||||
base = _serialize_snapshot_summary(snap)
|
||||
|
||||
technique_states = (
|
||||
db.query(SnapshotTechniqueState)
|
||||
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
|
||||
.order_by(SnapshotTechniqueState.mitre_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
base["technique_states"] = [
|
||||
{
|
||||
"mitre_id": s.mitre_id,
|
||||
"technique_id": str(s.technique_id),
|
||||
"status": s.status,
|
||||
"score": s.score,
|
||||
}
|
||||
for s in technique_states
|
||||
]
|
||||
return base
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /snapshots — List snapshots (paginated)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -88,23 +52,7 @@ def list_snapshots(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List coverage snapshots ordered by creation date (newest first)."""
|
||||
query = db.query(CoverageSnapshot)
|
||||
total = query.count()
|
||||
|
||||
snapshots = (
|
||||
query
|
||||
.order_by(CoverageSnapshot.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [_serialize_snapshot_summary(s) for s in snapshots],
|
||||
}
|
||||
return list_snapshots_svc(db, offset=offset, limit=limit)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -129,7 +77,7 @@ def create_snapshot_endpoint(
|
||||
details={"name": snapshot.name, "score": snapshot.organization_score},
|
||||
)
|
||||
|
||||
return _serialize_snapshot_summary(snapshot)
|
||||
return serialize_snapshot_summary(snapshot)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -148,13 +96,9 @@ def compare_snapshots_endpoint(
|
||||
a_id = uuid.UUID(a)
|
||||
b_id = uuid.UUID(b)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid snapshot ID format")
|
||||
raise BusinessRuleViolation("Invalid snapshot ID format")
|
||||
|
||||
result = compare_snapshots(db, a_id, b_id)
|
||||
if "error" in result:
|
||||
raise HTTPException(status_code=404, detail=result["error"])
|
||||
|
||||
return result
|
||||
return compare_snapshots(db, a_id, b_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -168,11 +112,7 @@ def get_snapshot(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get detailed snapshot information including per-technique states."""
|
||||
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
|
||||
if not snapshot:
|
||||
raise HTTPException(status_code=404, detail="Snapshot not found")
|
||||
|
||||
return _serialize_snapshot_detail(db, snapshot)
|
||||
return get_snapshot_detail(db, snapshot_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -180,15 +120,13 @@ def get_snapshot(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.delete("/{snapshot_id}")
|
||||
def delete_snapshot(
|
||||
def delete_snapshot_endpoint(
|
||||
snapshot_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Delete a snapshot (admin only)."""
|
||||
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
|
||||
if not snapshot:
|
||||
raise HTTPException(status_code=404, detail="Snapshot not found")
|
||||
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
@@ -199,7 +137,8 @@ def delete_snapshot(
|
||||
details={"name": snapshot.name},
|
||||
)
|
||||
|
||||
db.delete(snapshot)
|
||||
db.commit()
|
||||
with UnitOfWork(db) as uow:
|
||||
delete_snapshot(db, snapshot_id)
|
||||
uow.commit()
|
||||
|
||||
return {"detail": "Snapshot deleted"}
|
||||
|
||||
@@ -6,7 +6,7 @@ exceptions to HTTP responses automatically.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_role, require_any_role
|
||||
@@ -18,7 +18,6 @@ from app.domain.unit_of_work import UnitOfWork
|
||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||
SATechniqueRepository,
|
||||
)
|
||||
from app.models.technique import Technique
|
||||
from app.models.user import User
|
||||
from app.schemas.technique import (
|
||||
TechniqueCreate,
|
||||
@@ -27,7 +26,7 @@ from app.schemas.technique import (
|
||||
TechniqueUpdate,
|
||||
)
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.d3fend_import_service import get_defenses_for_technique
|
||||
from app.services.technique_query_service import get_technique_detail
|
||||
|
||||
router = APIRouter(prefix="/techniques", tags=["techniques"])
|
||||
|
||||
@@ -67,45 +66,7 @@ def get_technique(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return full details for a single technique, including its tests and D3FEND defenses."""
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
.options(joinedload(Technique.tests))
|
||||
.filter(Technique.mitre_id == mitre_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if technique is None:
|
||||
raise EntityNotFoundError("Technique", mitre_id)
|
||||
|
||||
defenses = get_defenses_for_technique(db, technique.id)
|
||||
|
||||
return {
|
||||
"id": str(technique.id),
|
||||
"mitre_id": technique.mitre_id,
|
||||
"name": technique.name,
|
||||
"description": technique.description,
|
||||
"tactic": technique.tactic,
|
||||
"platforms": technique.platforms or [],
|
||||
"mitre_version": technique.mitre_version,
|
||||
"mitre_last_modified": technique.mitre_last_modified,
|
||||
"is_subtechnique": technique.is_subtechnique,
|
||||
"parent_mitre_id": technique.parent_mitre_id,
|
||||
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
|
||||
"review_required": technique.review_required,
|
||||
"last_review_date": technique.last_review_date,
|
||||
"tests": [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"name": t.name,
|
||||
"state": t.state.value if t.state else None,
|
||||
"result": t.result.value if t.result else None,
|
||||
"platform": t.platform,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in technique.tests
|
||||
],
|
||||
"d3fend_defenses": defenses,
|
||||
}
|
||||
return get_technique_detail(db, mitre_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -10,9 +10,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.domain.exceptions import EntityNotFoundError
|
||||
from app.models.user import User
|
||||
from app.models.worklog import Worklog
|
||||
from app.services import worklog_service
|
||||
|
||||
router = APIRouter(prefix="/worklogs", tags=["worklogs"])
|
||||
@@ -97,10 +95,7 @@ def get_one(
|
||||
_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get a single worklog by ID."""
|
||||
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
|
||||
if not wl:
|
||||
raise EntityNotFoundError("Worklog", str(worklog_id))
|
||||
return wl
|
||||
return worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||
|
||||
|
||||
@router.get("/{worklog_id}/verify")
|
||||
@@ -110,9 +105,7 @@ def verify_integrity(
|
||||
_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Check whether a worklog's integrity hash is still valid."""
|
||||
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
|
||||
if not wl:
|
||||
raise EntityNotFoundError("Worklog", str(worklog_id))
|
||||
wl = worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||
return {
|
||||
"worklog_id": str(wl.id),
|
||||
"integrity_valid": worklog_service.verify_worklog_integrity(wl),
|
||||
|
||||
Reference in New Issue
Block a user