feat: move all remaining inline logic from routers to services (Tier 2)
This commit is contained in:
@@ -30,7 +30,7 @@ from app.services.campaign_crud_service import (
|
|||||||
serialize_campaign,
|
serialize_campaign,
|
||||||
update_campaign as crud_update,
|
update_campaign as crud_update,
|
||||||
)
|
)
|
||||||
from app.services.notification_service import create_notification
|
from app.services.notification_service import notify_role
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -237,17 +237,15 @@ def activate_campaign(
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(campaign)
|
db.refresh(campaign)
|
||||||
|
|
||||||
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
|
notify_role(
|
||||||
for user in red_techs:
|
db,
|
||||||
create_notification(
|
role="red_tech",
|
||||||
db,
|
type="campaign_activated",
|
||||||
user_id=user.id,
|
title="Campaign activated",
|
||||||
type="campaign_activated",
|
message=f'Campaign "{campaign.name}" has been activated.',
|
||||||
title="Campaign activated",
|
entity_type="campaign",
|
||||||
message=f'Campaign "{campaign.name}" has been activated.',
|
entity_id=campaign.id,
|
||||||
entity_type="campaign",
|
)
|
||||||
entity_id=campaign.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
|
|||||||
@@ -3,18 +3,20 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_role
|
from app.dependencies.auth import get_current_user, require_role
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping
|
|
||||||
from app.services.d3fend_import_service import (
|
from app.services.d3fend_import_service import (
|
||||||
import_d3fend_techniques,
|
import_d3fend_techniques,
|
||||||
import_d3fend_mappings,
|
import_d3fend_mappings,
|
||||||
get_defenses_for_technique,
|
)
|
||||||
|
from app.services.d3fend_query_service import (
|
||||||
|
list_defensive_techniques as list_defensive_techniques_svc,
|
||||||
|
list_d3fend_tactics,
|
||||||
|
get_defenses_for_attack_technique,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -36,38 +38,9 @@ def list_defensive_techniques(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List all D3FEND defensive techniques with optional filters."""
|
"""List all D3FEND defensive techniques with optional filters."""
|
||||||
query = db.query(DefensiveTechnique)
|
return list_defensive_techniques_svc(
|
||||||
|
db, tactic=tactic, search=search, offset=offset, limit=limit
|
||||||
if tactic:
|
)
|
||||||
query = query.filter(DefensiveTechnique.tactic == tactic)
|
|
||||||
|
|
||||||
if search:
|
|
||||||
from app.utils import escape_like
|
|
||||||
pattern = f"%{escape_like(search)}%"
|
|
||||||
query = query.filter(
|
|
||||||
DefensiveTechnique.name.ilike(pattern)
|
|
||||||
| DefensiveTechnique.d3fend_id.ilike(pattern)
|
|
||||||
)
|
|
||||||
|
|
||||||
total = query.count()
|
|
||||||
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total": total,
|
|
||||||
"offset": offset,
|
|
||||||
"limit": limit,
|
|
||||||
"items": [
|
|
||||||
{
|
|
||||||
"id": str(dt.id),
|
|
||||||
"d3fend_id": dt.d3fend_id,
|
|
||||||
"name": dt.name,
|
|
||||||
"description": dt.description,
|
|
||||||
"tactic": dt.tactic,
|
|
||||||
"d3fend_url": dt.d3fend_url,
|
|
||||||
}
|
|
||||||
for dt in items
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -75,21 +48,12 @@ def list_defensive_techniques(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("/tactics")
|
@router.get("/tactics")
|
||||||
def list_d3fend_tactics(
|
def list_d3fend_tactics_endpoint(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return a list of all D3FEND tactics with counts."""
|
"""Return a list of all D3FEND tactics with counts."""
|
||||||
from sqlalchemy import func
|
return list_d3fend_tactics(db)
|
||||||
|
|
||||||
rows = (
|
|
||||||
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
|
|
||||||
.group_by(DefensiveTechnique.tactic)
|
|
||||||
.order_by(DefensiveTechnique.tactic)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -97,24 +61,13 @@ def list_d3fend_tactics(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.get("/for-technique/{mitre_id}")
|
@router.get("/for-technique/{mitre_id}")
|
||||||
def get_defenses_for_attack_technique(
|
def get_defenses_for_attack_technique_endpoint(
|
||||||
mitre_id: str,
|
mitre_id: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
||||||
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
|
return get_defenses_for_attack_technique(db, mitre_id)
|
||||||
if not technique:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Technique {mitre_id} not found")
|
|
||||||
|
|
||||||
defenses = get_defenses_for_technique(db, technique.id)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"mitre_id": mitre_id,
|
|
||||||
"technique_name": technique.name,
|
|
||||||
"defenses": defenses,
|
|
||||||
"total": len(defenses),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -7,14 +7,9 @@ from uuid import UUID
|
|||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_role
|
from app.dependencies.auth import get_current_user, require_role
|
||||||
from app.domain.exceptions import EntityNotFoundError
|
from app.models.jira_link import JiraLinkEntityType
|
||||||
from app.models.jira_link import JiraLink, JiraLinkEntityType
|
|
||||||
from app.models.test import Test
|
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.campaign import Campaign
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.jira_schema import (
|
from app.schemas.jira_schema import (
|
||||||
JiraIssueResult,
|
JiraIssueResult,
|
||||||
@@ -45,23 +40,14 @@ def create_link(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Associate an Aegis entity with a Jira issue."""
|
"""Associate an Aegis entity with a Jira issue."""
|
||||||
link = JiraLink(
|
link = jira_service.create_link(
|
||||||
|
db,
|
||||||
entity_type=body.entity_type,
|
entity_type=body.entity_type,
|
||||||
entity_id=body.entity_id,
|
entity_id=body.entity_id,
|
||||||
jira_issue_key=body.jira_issue_key,
|
jira_issue_key=body.jira_issue_key,
|
||||||
sync_direction=body.sync_direction,
|
sync_direction=body.sync_direction,
|
||||||
created_by=user.id,
|
created_by=user.id,
|
||||||
)
|
)
|
||||||
db.add(link)
|
|
||||||
db.flush()
|
|
||||||
|
|
||||||
# Pull initial data from Jira if enabled
|
|
||||||
if settings.JIRA_ENABLED:
|
|
||||||
try:
|
|
||||||
jira_service.sync_jira_to_aegis(db, link)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Initial Jira sync failed for %s: %s", body.jira_issue_key, e)
|
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(link)
|
db.refresh(link)
|
||||||
|
|
||||||
@@ -88,12 +74,11 @@ def list_links(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List Jira links, optionally filtered by entity."""
|
"""List Jira links, optionally filtered by entity."""
|
||||||
query = db.query(JiraLink)
|
return jira_service.list_links(
|
||||||
if entity_type:
|
db,
|
||||||
query = query.filter(JiraLink.entity_type == entity_type)
|
entity_type=entity_type,
|
||||||
if entity_id:
|
entity_id=entity_id,
|
||||||
query = query.filter(JiraLink.entity_id == entity_id)
|
)
|
||||||
return query.order_by(JiraLink.created_at.desc()).all()
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/links/{link_id}/sync")
|
@router.post("/links/{link_id}/sync")
|
||||||
@@ -103,9 +88,7 @@ def sync_link(
|
|||||||
user: User = Depends(require_role("admin")),
|
user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Force bidirectional sync for a specific Jira link."""
|
"""Force bidirectional sync for a specific Jira link."""
|
||||||
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
|
link = jira_service.get_link_or_raise(db, link_id)
|
||||||
if not link:
|
|
||||||
raise EntityNotFoundError("JiraLink", str(link_id))
|
|
||||||
jira_service.sync_jira_to_aegis(db, link)
|
jira_service.sync_jira_to_aegis(db, link)
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"message": "Sync completed", "jira_status": link.jira_status}
|
return {"message": "Sync completed", "jira_status": link.jira_status}
|
||||||
@@ -118,10 +101,7 @@ def delete_link(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Remove a Jira link."""
|
"""Remove a Jira link."""
|
||||||
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
|
link = jira_service.delete_link(db, link_id)
|
||||||
if not link:
|
|
||||||
raise EntityNotFoundError("JiraLink", str(link_id))
|
|
||||||
db.delete(link)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
audit_service.log_action(
|
audit_service.log_action(
|
||||||
db,
|
db,
|
||||||
@@ -141,61 +121,11 @@ def create_issue_from_entity(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Auto-create a Jira issue from an Aegis entity and link them."""
|
"""Auto-create a Jira issue from an Aegis entity and link them."""
|
||||||
summary, description = _build_issue_data(db, entity_type, entity_id)
|
result = jira_service.create_issue_and_link(
|
||||||
result = jira_service.create_jira_issue(
|
db,
|
||||||
project_key=settings.JIRA_DEFAULT_PROJECT,
|
|
||||||
summary=summary,
|
|
||||||
description=description,
|
|
||||||
labels=["aegis", entity_type.value],
|
|
||||||
)
|
|
||||||
link = JiraLink(
|
|
||||||
entity_type=entity_type,
|
entity_type=entity_type,
|
||||||
entity_id=entity_id,
|
entity_id=entity_id,
|
||||||
jira_issue_key=result["issue_key"],
|
|
||||||
jira_issue_id=result["issue_id"],
|
|
||||||
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
|
|
||||||
created_by=user.id,
|
created_by=user.id,
|
||||||
)
|
)
|
||||||
db.add(link)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"issue_key": result["issue_key"], "link_id": str(link.id)}
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _build_issue_data(
|
|
||||||
db: Session,
|
|
||||||
entity_type: JiraLinkEntityType,
|
|
||||||
entity_id: UUID,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
"""Build Jira issue summary + description from an Aegis entity."""
|
|
||||||
if entity_type == JiraLinkEntityType.test:
|
|
||||||
entity = db.query(Test).filter(Test.id == entity_id).first()
|
|
||||||
if not entity:
|
|
||||||
raise EntityNotFoundError("Test", str(entity_id))
|
|
||||||
return (
|
|
||||||
f"[Aegis Test] {entity.name}",
|
|
||||||
f"Test: {entity.name}\n"
|
|
||||||
f"State: {entity.state.value if entity.state else 'draft'}\n"
|
|
||||||
f"Description: {entity.description or 'N/A'}",
|
|
||||||
)
|
|
||||||
elif entity_type == JiraLinkEntityType.campaign:
|
|
||||||
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
|
|
||||||
if not entity:
|
|
||||||
raise EntityNotFoundError("Campaign", str(entity_id))
|
|
||||||
return (
|
|
||||||
f"[Aegis Campaign] {entity.name}",
|
|
||||||
f"Campaign: {entity.name}\n"
|
|
||||||
f"Type: {entity.type}\nStatus: {entity.status}\n"
|
|
||||||
f"Description: {entity.description or 'N/A'}",
|
|
||||||
)
|
|
||||||
elif entity_type == JiraLinkEntityType.technique:
|
|
||||||
entity = db.query(Technique).filter(Technique.id == entity_id).first()
|
|
||||||
if not entity:
|
|
||||||
raise EntityNotFoundError("Technique", str(entity_id))
|
|
||||||
return (
|
|
||||||
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
|
|
||||||
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
|
|
||||||
f"Tactic: {entity.tactic or 'N/A'}\n"
|
|
||||||
f"Description: {entity.description or 'N/A'}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
|
|
||||||
|
|||||||
@@ -10,16 +10,16 @@ POST /notifications/read-all — mark all as read
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.notification import Notification
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.notification import NotificationOut, UnreadCountOut
|
from app.schemas.notification import NotificationOut, UnreadCountOut
|
||||||
from app.services.notification_service import (
|
from app.services.notification_service import (
|
||||||
|
list_notifications,
|
||||||
mark_as_read,
|
mark_as_read,
|
||||||
mark_all_as_read,
|
mark_all_as_read,
|
||||||
get_unread_count,
|
get_unread_count,
|
||||||
@@ -34,22 +34,14 @@ router = APIRouter(prefix="/notifications", tags=["notifications"])
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=list[NotificationOut])
|
@router.get("", response_model=list[NotificationOut])
|
||||||
def list_notifications(
|
def list_notifications_endpoint(
|
||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
limit: int = Query(20, ge=1, le=100),
|
limit: int = Query(20, ge=1, le=100),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return paginated notifications for the current user, newest first."""
|
"""Return paginated notifications for the current user, newest first."""
|
||||||
notifs = (
|
return list_notifications(db, current_user.id, offset=offset, limit=limit)
|
||||||
db.query(Notification)
|
|
||||||
.filter(Notification.user_id == current_user.id)
|
|
||||||
.order_by(Notification.created_at.desc())
|
|
||||||
.offset(offset)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return notifs
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -80,14 +72,8 @@ def read_notification(
|
|||||||
):
|
):
|
||||||
"""Mark a single notification as read."""
|
"""Mark a single notification as read."""
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
success = mark_as_read(db, notification_id, current_user.id)
|
notif = mark_as_read(db, notification_id, current_user.id)
|
||||||
if not success:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Notification not found",
|
|
||||||
)
|
|
||||||
uow.commit()
|
uow.commit()
|
||||||
notif = db.query(Notification).filter(Notification.id == notification_id).first()
|
|
||||||
return notif
|
return notif
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,14 +10,14 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
from app.models.osint_item import OsintItem
|
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.services.osint_enrichment_service import (
|
from app.services.osint_enrichment_service import (
|
||||||
enrich_technique_with_cves,
|
enrich_technique_with_cves,
|
||||||
get_osint_items_for_technique,
|
get_osint_items_for_technique,
|
||||||
|
get_osint_summary,
|
||||||
|
get_technique_or_raise,
|
||||||
|
list_osint_items as service_list_osint_items,
|
||||||
mark_osint_reviewed,
|
mark_osint_reviewed,
|
||||||
get_unreviewed_count,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix="/osint", tags=["osint"])
|
router = APIRouter(prefix="/osint", tags=["osint"])
|
||||||
@@ -56,41 +56,15 @@ def list_osint_items(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List OSINT items with optional filters."""
|
"""List OSINT items with optional filters."""
|
||||||
query = db.query(OsintItem)
|
return service_list_osint_items(
|
||||||
if technique_id:
|
db,
|
||||||
query = query.filter(OsintItem.technique_id == technique_id)
|
technique_id=technique_id,
|
||||||
if source_type:
|
source_type=source_type,
|
||||||
query = query.filter(OsintItem.source_type == source_type)
|
reviewed=reviewed,
|
||||||
if reviewed is not None:
|
offset=offset,
|
||||||
query = query.filter(OsintItem.reviewed == reviewed)
|
limit=limit,
|
||||||
|
|
||||||
total = query.count()
|
|
||||||
items = (
|
|
||||||
query.order_by(OsintItem.discovered_at.desc())
|
|
||||||
.offset(offset)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
|
||||||
"total": total,
|
|
||||||
"items": [
|
|
||||||
{
|
|
||||||
"id": str(item.id),
|
|
||||||
"technique_id": str(item.technique_id),
|
|
||||||
"source_type": item.source_type,
|
|
||||||
"source_url": item.source_url,
|
|
||||||
"title": item.title,
|
|
||||||
"description": item.description,
|
|
||||||
"severity": item.severity,
|
|
||||||
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
|
|
||||||
"reviewed": item.reviewed,
|
|
||||||
"metadata": item.metadata_,
|
|
||||||
}
|
|
||||||
for item in items
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/summary")
|
@router.get("/summary")
|
||||||
def osint_summary(
|
def osint_summary(
|
||||||
@@ -98,34 +72,7 @@ def osint_summary(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Summary statistics for OSINT items."""
|
"""Summary statistics for OSINT items."""
|
||||||
from sqlalchemy import func
|
return get_osint_summary(db)
|
||||||
|
|
||||||
total = db.query(func.count(OsintItem.id)).scalar() or 0
|
|
||||||
unreviewed = get_unreviewed_count(db)
|
|
||||||
|
|
||||||
by_severity = dict(
|
|
||||||
db.query(OsintItem.severity, func.count(OsintItem.id))
|
|
||||||
.group_by(OsintItem.severity)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
by_type = dict(
|
|
||||||
db.query(OsintItem.source_type, func.count(OsintItem.id))
|
|
||||||
.group_by(OsintItem.source_type)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
techniques_with_items = (
|
|
||||||
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_items": total,
|
|
||||||
"unreviewed": unreviewed,
|
|
||||||
"techniques_with_items": techniques_with_items,
|
|
||||||
"by_severity": by_severity,
|
|
||||||
"by_type": by_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/items/{item_id}/review")
|
@router.post("/items/{item_id}/review")
|
||||||
@@ -151,13 +98,7 @@ def trigger_technique_enrichment(
|
|||||||
user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Manually trigger OSINT enrichment for a single technique."""
|
"""Manually trigger OSINT enrichment for a single technique."""
|
||||||
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
technique = get_technique_or_raise(db, technique_id)
|
||||||
if not technique:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Technique not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
count = enrich_technique_with_cves(db, technique)
|
count = enrich_technique_with_cves(db, technique)
|
||||||
return {
|
return {
|
||||||
"technique_id": str(technique.id),
|
"technique_id": str(technique.id),
|
||||||
|
|||||||
@@ -5,19 +5,17 @@ Provides granular scoring with breakdowns and configurable weights.
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_role
|
from app.dependencies.auth import get_current_user, require_role
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.threat_actor import ThreatActor
|
|
||||||
from app.services.scoring_service import (
|
from app.services.scoring_service import (
|
||||||
calculate_technique_score,
|
score_technique_by_mitre_id,
|
||||||
|
score_actor_by_id,
|
||||||
calculate_tactic_score,
|
calculate_tactic_score,
|
||||||
calculate_actor_coverage_score,
|
|
||||||
calculate_organization_score,
|
calculate_organization_score,
|
||||||
get_score_history,
|
get_score_history,
|
||||||
)
|
)
|
||||||
@@ -39,23 +37,7 @@ def score_technique(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get detailed score with breakdown for a specific technique."""
|
"""Get detailed score with breakdown for a specific technique."""
|
||||||
technique = (
|
return score_technique_by_mitre_id(db, mitre_id)
|
||||||
db.query(Technique)
|
|
||||||
.filter(Technique.mitre_id == mitre_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not technique:
|
|
||||||
raise HTTPException(status_code=404, detail="Technique not found")
|
|
||||||
|
|
||||||
result = calculate_technique_score(technique, db)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"mitre_id": technique.mitre_id,
|
|
||||||
"name": technique.name,
|
|
||||||
"tactic": technique.tactic,
|
|
||||||
"status_global": technique.status_global.value if technique.status_global else None,
|
|
||||||
**result,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /scores/tactic/{tactic} ──────────────────────────────────────
|
# ── GET /scores/tactic/{tactic} ──────────────────────────────────────
|
||||||
@@ -81,11 +63,7 @@ def score_threat_actor(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get coverage score against a specific threat actor."""
|
"""Get coverage score against a specific threat actor."""
|
||||||
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
return score_actor_by_id(db, actor_id)
|
||||||
if not actor:
|
|
||||||
raise HTTPException(status_code=404, detail="Threat actor not found")
|
|
||||||
|
|
||||||
return calculate_actor_coverage_score(actor_id, db)
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /scores/organization ─────────────────────────────────────────
|
# ── GET /scores/organization ─────────────────────────────────────────
|
||||||
|
|||||||
@@ -8,18 +8,24 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
from app.dependencies.auth import get_current_user, require_any_role, require_role
|
||||||
|
from app.domain.errors import BusinessRuleViolation
|
||||||
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
|
||||||
from app.services.snapshot_service import (
|
from app.services.snapshot_service import (
|
||||||
create_snapshot,
|
create_snapshot,
|
||||||
compare_snapshots,
|
compare_snapshots,
|
||||||
cleanup_old_snapshots,
|
cleanup_old_snapshots,
|
||||||
|
serialize_snapshot_summary,
|
||||||
|
list_snapshots as list_snapshots_svc,
|
||||||
|
get_snapshot_or_raise,
|
||||||
|
get_snapshot_detail,
|
||||||
|
delete_snapshot,
|
||||||
)
|
)
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
|
||||||
@@ -34,48 +40,6 @@ class SnapshotCreate(BaseModel):
|
|||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
|
|
||||||
"""Lightweight serialization for list views."""
|
|
||||||
return {
|
|
||||||
"id": str(snap.id),
|
|
||||||
"name": snap.name,
|
|
||||||
"organization_score": snap.organization_score,
|
|
||||||
"total_techniques": snap.total_techniques,
|
|
||||||
"validated_count": snap.validated_count,
|
|
||||||
"partial_count": snap.partial_count,
|
|
||||||
"not_covered_count": snap.not_covered_count,
|
|
||||||
"in_progress_count": snap.in_progress_count,
|
|
||||||
"not_evaluated_count": snap.not_evaluated_count,
|
|
||||||
"created_by": str(snap.created_by) if snap.created_by else None,
|
|
||||||
"created_at": snap.created_at.isoformat() if snap.created_at else None,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
|
|
||||||
"""Full serialization including technique states."""
|
|
||||||
base = _serialize_snapshot_summary(snap)
|
|
||||||
|
|
||||||
technique_states = (
|
|
||||||
db.query(SnapshotTechniqueState)
|
|
||||||
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
|
|
||||||
.order_by(SnapshotTechniqueState.mitre_id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
base["technique_states"] = [
|
|
||||||
{
|
|
||||||
"mitre_id": s.mitre_id,
|
|
||||||
"technique_id": str(s.technique_id),
|
|
||||||
"status": s.status,
|
|
||||||
"score": s.score,
|
|
||||||
}
|
|
||||||
for s in technique_states
|
|
||||||
]
|
|
||||||
return base
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GET /snapshots — List snapshots (paginated)
|
# GET /snapshots — List snapshots (paginated)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -88,23 +52,7 @@ def list_snapshots(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List coverage snapshots ordered by creation date (newest first)."""
|
"""List coverage snapshots ordered by creation date (newest first)."""
|
||||||
query = db.query(CoverageSnapshot)
|
return list_snapshots_svc(db, offset=offset, limit=limit)
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
snapshots = (
|
|
||||||
query
|
|
||||||
.order_by(CoverageSnapshot.created_at.desc())
|
|
||||||
.offset(offset)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total": total,
|
|
||||||
"offset": offset,
|
|
||||||
"limit": limit,
|
|
||||||
"items": [_serialize_snapshot_summary(s) for s in snapshots],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -129,7 +77,7 @@ def create_snapshot_endpoint(
|
|||||||
details={"name": snapshot.name, "score": snapshot.organization_score},
|
details={"name": snapshot.name, "score": snapshot.organization_score},
|
||||||
)
|
)
|
||||||
|
|
||||||
return _serialize_snapshot_summary(snapshot)
|
return serialize_snapshot_summary(snapshot)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -148,13 +96,9 @@ def compare_snapshots_endpoint(
|
|||||||
a_id = uuid.UUID(a)
|
a_id = uuid.UUID(a)
|
||||||
b_id = uuid.UUID(b)
|
b_id = uuid.UUID(b)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise HTTPException(status_code=400, detail="Invalid snapshot ID format")
|
raise BusinessRuleViolation("Invalid snapshot ID format")
|
||||||
|
|
||||||
result = compare_snapshots(db, a_id, b_id)
|
return compare_snapshots(db, a_id, b_id)
|
||||||
if "error" in result:
|
|
||||||
raise HTTPException(status_code=404, detail=result["error"])
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -168,11 +112,7 @@ def get_snapshot(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get detailed snapshot information including per-technique states."""
|
"""Get detailed snapshot information including per-technique states."""
|
||||||
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
|
return get_snapshot_detail(db, snapshot_id)
|
||||||
if not snapshot:
|
|
||||||
raise HTTPException(status_code=404, detail="Snapshot not found")
|
|
||||||
|
|
||||||
return _serialize_snapshot_detail(db, snapshot)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -180,15 +120,13 @@ def get_snapshot(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@router.delete("/{snapshot_id}")
|
@router.delete("/{snapshot_id}")
|
||||||
def delete_snapshot(
|
def delete_snapshot_endpoint(
|
||||||
snapshot_id: str,
|
snapshot_id: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Delete a snapshot (admin only)."""
|
"""Delete a snapshot (admin only)."""
|
||||||
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first()
|
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||||
if not snapshot:
|
|
||||||
raise HTTPException(status_code=404, detail="Snapshot not found")
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
@@ -199,7 +137,8 @@ def delete_snapshot(
|
|||||||
details={"name": snapshot.name},
|
details={"name": snapshot.name},
|
||||||
)
|
)
|
||||||
|
|
||||||
db.delete(snapshot)
|
with UnitOfWork(db) as uow:
|
||||||
db.commit()
|
delete_snapshot(db, snapshot_id)
|
||||||
|
uow.commit()
|
||||||
|
|
||||||
return {"detail": "Snapshot deleted"}
|
return {"detail": "Snapshot deleted"}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ exceptions to HTTP responses automatically.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, status
|
from fastapi import APIRouter, Depends, Query, status
|
||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_role, require_any_role
|
from app.dependencies.auth import get_current_user, require_role, require_any_role
|
||||||
@@ -18,7 +18,6 @@ from app.domain.unit_of_work import UnitOfWork
|
|||||||
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
from app.infrastructure.persistence.repositories.sa_technique_repository import (
|
||||||
SATechniqueRepository,
|
SATechniqueRepository,
|
||||||
)
|
)
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.technique import (
|
from app.schemas.technique import (
|
||||||
TechniqueCreate,
|
TechniqueCreate,
|
||||||
@@ -27,7 +26,7 @@ from app.schemas.technique import (
|
|||||||
TechniqueUpdate,
|
TechniqueUpdate,
|
||||||
)
|
)
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
from app.services.d3fend_import_service import get_defenses_for_technique
|
from app.services.technique_query_service import get_technique_detail
|
||||||
|
|
||||||
router = APIRouter(prefix="/techniques", tags=["techniques"])
|
router = APIRouter(prefix="/techniques", tags=["techniques"])
|
||||||
|
|
||||||
@@ -67,45 +66,7 @@ def get_technique(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return full details for a single technique, including its tests and D3FEND defenses."""
|
"""Return full details for a single technique, including its tests and D3FEND defenses."""
|
||||||
technique = (
|
return get_technique_detail(db, mitre_id)
|
||||||
db.query(Technique)
|
|
||||||
.options(joinedload(Technique.tests))
|
|
||||||
.filter(Technique.mitre_id == mitre_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if technique is None:
|
|
||||||
raise EntityNotFoundError("Technique", mitre_id)
|
|
||||||
|
|
||||||
defenses = get_defenses_for_technique(db, technique.id)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": str(technique.id),
|
|
||||||
"mitre_id": technique.mitre_id,
|
|
||||||
"name": technique.name,
|
|
||||||
"description": technique.description,
|
|
||||||
"tactic": technique.tactic,
|
|
||||||
"platforms": technique.platforms or [],
|
|
||||||
"mitre_version": technique.mitre_version,
|
|
||||||
"mitre_last_modified": technique.mitre_last_modified,
|
|
||||||
"is_subtechnique": technique.is_subtechnique,
|
|
||||||
"parent_mitre_id": technique.parent_mitre_id,
|
|
||||||
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
|
|
||||||
"review_required": technique.review_required,
|
|
||||||
"last_review_date": technique.last_review_date,
|
|
||||||
"tests": [
|
|
||||||
{
|
|
||||||
"id": str(t.id),
|
|
||||||
"name": t.name,
|
|
||||||
"state": t.state.value if t.state else None,
|
|
||||||
"result": t.result.value if t.result else None,
|
|
||||||
"platform": t.platform,
|
|
||||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
|
||||||
}
|
|
||||||
for t in technique.tests
|
|
||||||
],
|
|
||||||
"d3fend_defenses": defenses,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -10,9 +10,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
from app.domain.exceptions import EntityNotFoundError
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.worklog import Worklog
|
|
||||||
from app.services import worklog_service
|
from app.services import worklog_service
|
||||||
|
|
||||||
router = APIRouter(prefix="/worklogs", tags=["worklogs"])
|
router = APIRouter(prefix="/worklogs", tags=["worklogs"])
|
||||||
@@ -97,10 +95,7 @@ def get_one(
|
|||||||
_user: User = Depends(get_current_user),
|
_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get a single worklog by ID."""
|
"""Get a single worklog by ID."""
|
||||||
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
|
return worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||||
if not wl:
|
|
||||||
raise EntityNotFoundError("Worklog", str(worklog_id))
|
|
||||||
return wl
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{worklog_id}/verify")
|
@router.get("/{worklog_id}/verify")
|
||||||
@@ -110,9 +105,7 @@ def verify_integrity(
|
|||||||
_user: User = Depends(get_current_user),
|
_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Check whether a worklog's integrity hash is still valid."""
|
"""Check whether a worklog's integrity hash is still valid."""
|
||||||
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
|
wl = worklog_service.get_worklog_or_raise(db, worklog_id)
|
||||||
if not wl:
|
|
||||||
raise EntityNotFoundError("Worklog", str(worklog_id))
|
|
||||||
return {
|
return {
|
||||||
"worklog_id": str(wl.id),
|
"worklog_id": str(wl.id),
|
||||||
"integrity_valid": worklog_service.verify_worklog_integrity(wl),
|
"integrity_valid": worklog_service.verify_worklog_integrity(wl),
|
||||||
|
|||||||
82
backend/app/services/d3fend_query_service.py
Normal file
82
backend/app/services/d3fend_query_service.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""D3FEND query service — framework-agnostic queries for defensive techniques."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
|
from app.models.defensive_technique import DefensiveTechnique
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.services.d3fend_import_service import get_defenses_for_technique
|
||||||
|
from app.utils import escape_like
|
||||||
|
|
||||||
|
|
||||||
|
def list_defensive_techniques(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
tactic: Optional[str] = None,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> dict:
|
||||||
|
"""List D3FEND defensive techniques with optional filters."""
|
||||||
|
query = db.query(DefensiveTechnique)
|
||||||
|
|
||||||
|
if tactic:
|
||||||
|
query = query.filter(DefensiveTechnique.tactic == tactic)
|
||||||
|
|
||||||
|
if search:
|
||||||
|
pattern = f"%{escape_like(search)}%"
|
||||||
|
query = query.filter(
|
||||||
|
DefensiveTechnique.name.ilike(pattern)
|
||||||
|
| DefensiveTechnique.d3fend_id.ilike(pattern)
|
||||||
|
)
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"offset": offset,
|
||||||
|
"limit": limit,
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"id": str(dt.id),
|
||||||
|
"d3fend_id": dt.d3fend_id,
|
||||||
|
"name": dt.name,
|
||||||
|
"description": dt.description,
|
||||||
|
"tactic": dt.tactic,
|
||||||
|
"d3fend_url": dt.d3fend_url,
|
||||||
|
}
|
||||||
|
for dt in items
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def list_d3fend_tactics(db: Session) -> list[dict]:
|
||||||
|
"""Return a list of all D3FEND tactics with counts."""
|
||||||
|
rows = (
|
||||||
|
db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id))
|
||||||
|
.group_by(DefensiveTechnique.tactic)
|
||||||
|
.order_by(DefensiveTechnique.tactic)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def get_defenses_for_attack_technique(db: Session, mitre_id: str) -> dict:
|
||||||
|
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
|
||||||
|
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
|
||||||
|
if technique is None:
|
||||||
|
raise EntityNotFoundError("Technique", mitre_id)
|
||||||
|
|
||||||
|
defenses = get_defenses_for_technique(db, technique.id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"mitre_id": mitre_id,
|
||||||
|
"technique_name": technique.name,
|
||||||
|
"defenses": defenses,
|
||||||
|
"total": len(defenses),
|
||||||
|
}
|
||||||
@@ -3,12 +3,17 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
from app.domain.exceptions import InvalidOperationError
|
from app.domain.exceptions import InvalidOperationError
|
||||||
from app.models.jira_link import JiraLink
|
from app.models.campaign import Campaign
|
||||||
|
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.models.test import Test
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -103,3 +108,128 @@ def _build_sync_comment(data: dict) -> str:
|
|||||||
lines.append(f"*{key}:* {value}")
|
lines.append(f"*{key}:* {value}")
|
||||||
lines.append(f"\n_Synced at {datetime.utcnow().isoformat()}_")
|
lines.append(f"\n_Synced at {datetime.utcnow().isoformat()}_")
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Link CRUD ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def create_link(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
entity_type: JiraLinkEntityType,
|
||||||
|
entity_id: UUID,
|
||||||
|
jira_issue_key: str,
|
||||||
|
sync_direction: JiraSyncDirection,
|
||||||
|
created_by: UUID,
|
||||||
|
) -> JiraLink:
|
||||||
|
"""Create a Jira link and optionally pull initial data from Jira."""
|
||||||
|
link = JiraLink(
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_id=entity_id,
|
||||||
|
jira_issue_key=jira_issue_key,
|
||||||
|
sync_direction=sync_direction,
|
||||||
|
created_by=created_by,
|
||||||
|
)
|
||||||
|
db.add(link)
|
||||||
|
db.flush()
|
||||||
|
|
||||||
|
if settings.JIRA_ENABLED:
|
||||||
|
try:
|
||||||
|
sync_jira_to_aegis(db, link)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Initial Jira sync failed for %s: %s", jira_issue_key, e)
|
||||||
|
|
||||||
|
return link
|
||||||
|
|
||||||
|
|
||||||
|
def list_links(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
entity_type: Optional[JiraLinkEntityType] = None,
|
||||||
|
entity_id: Optional[UUID] = None,
|
||||||
|
) -> list[JiraLink]:
|
||||||
|
"""List Jira links with optional filters."""
|
||||||
|
query = db.query(JiraLink)
|
||||||
|
if entity_type:
|
||||||
|
query = query.filter(JiraLink.entity_type == entity_type)
|
||||||
|
if entity_id:
|
||||||
|
query = query.filter(JiraLink.entity_id == entity_id)
|
||||||
|
return query.order_by(JiraLink.created_at.desc()).all()
|
||||||
|
|
||||||
|
|
||||||
|
def get_link_or_raise(db: Session, link_id: UUID) -> JiraLink:
|
||||||
|
"""Get a Jira link by ID or raise EntityNotFoundError."""
|
||||||
|
link = db.query(JiraLink).filter(JiraLink.id == link_id).first()
|
||||||
|
if not link:
|
||||||
|
raise EntityNotFoundError("JiraLink", str(link_id))
|
||||||
|
return link
|
||||||
|
|
||||||
|
|
||||||
|
def delete_link(db: Session, link_id: UUID) -> JiraLink:
|
||||||
|
"""Delete a Jira link. Returns the deleted link (for audit)."""
|
||||||
|
link = get_link_or_raise(db, link_id)
|
||||||
|
db.delete(link)
|
||||||
|
return link
|
||||||
|
|
||||||
|
|
||||||
|
def build_issue_data(db: Session, entity_type: JiraLinkEntityType, entity_id: UUID) -> tuple[str, str]:
|
||||||
|
"""Build Jira issue summary and description from an Aegis entity."""
|
||||||
|
if entity_type == JiraLinkEntityType.test:
|
||||||
|
entity = db.query(Test).filter(Test.id == entity_id).first()
|
||||||
|
if not entity:
|
||||||
|
raise EntityNotFoundError("Test", str(entity_id))
|
||||||
|
return (
|
||||||
|
f"[Aegis Test] {entity.name}",
|
||||||
|
f"Test: {entity.name}\n"
|
||||||
|
f"State: {entity.state.value if entity.state else 'draft'}\n"
|
||||||
|
f"Description: {entity.description or 'N/A'}",
|
||||||
|
)
|
||||||
|
elif entity_type == JiraLinkEntityType.campaign:
|
||||||
|
entity = db.query(Campaign).filter(Campaign.id == entity_id).first()
|
||||||
|
if not entity:
|
||||||
|
raise EntityNotFoundError("Campaign", str(entity_id))
|
||||||
|
return (
|
||||||
|
f"[Aegis Campaign] {entity.name}",
|
||||||
|
f"Campaign: {entity.name}\n"
|
||||||
|
f"Type: {entity.type}\nStatus: {entity.status}\n"
|
||||||
|
f"Description: {entity.description or 'N/A'}",
|
||||||
|
)
|
||||||
|
elif entity_type == JiraLinkEntityType.technique:
|
||||||
|
entity = db.query(Technique).filter(Technique.id == entity_id).first()
|
||||||
|
if not entity:
|
||||||
|
raise EntityNotFoundError("Technique", str(entity_id))
|
||||||
|
return (
|
||||||
|
f"[Aegis Technique] {entity.mitre_id} - {entity.name}",
|
||||||
|
f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n"
|
||||||
|
f"Tactic: {entity.tactic or 'N/A'}\n"
|
||||||
|
f"Description: {entity.description or 'N/A'}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}"
|
||||||
|
|
||||||
|
|
||||||
|
def create_issue_and_link(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
entity_type: JiraLinkEntityType,
|
||||||
|
entity_id: UUID,
|
||||||
|
created_by: UUID,
|
||||||
|
) -> dict:
|
||||||
|
"""Create a Jira issue from an Aegis entity and link them."""
|
||||||
|
summary, description = build_issue_data(db, entity_type, entity_id)
|
||||||
|
result = create_jira_issue(
|
||||||
|
project_key=settings.JIRA_DEFAULT_PROJECT,
|
||||||
|
summary=summary,
|
||||||
|
description=description,
|
||||||
|
labels=["aegis", entity_type.value],
|
||||||
|
)
|
||||||
|
link = JiraLink(
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_id=entity_id,
|
||||||
|
jira_issue_key=result["issue_key"],
|
||||||
|
jira_issue_id=result["issue_id"],
|
||||||
|
jira_project_key=settings.JIRA_DEFAULT_PROJECT,
|
||||||
|
created_by=created_by,
|
||||||
|
)
|
||||||
|
db.add(link)
|
||||||
|
return {"issue_key": result["issue_key"], "link_id": str(link.id)}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from datetime import datetime, timedelta
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
from app.models.notification import Notification
|
from app.models.notification import Notification
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
@@ -22,6 +23,71 @@ from app.models.user import User
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def list_notifications(
|
||||||
|
db: Session,
|
||||||
|
user_id: uuid.UUID,
|
||||||
|
*,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 20,
|
||||||
|
) -> list[Notification]:
|
||||||
|
"""Return paginated notifications for a user, newest first."""
|
||||||
|
return (
|
||||||
|
db.query(Notification)
|
||||||
|
.filter(Notification.user_id == user_id)
|
||||||
|
.order_by(Notification.created_at.desc())
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_notification_or_raise(
|
||||||
|
db: Session,
|
||||||
|
notification_id: uuid.UUID,
|
||||||
|
user_id: uuid.UUID,
|
||||||
|
) -> Notification:
|
||||||
|
"""Fetch a notification by ID and user, or raise EntityNotFoundError."""
|
||||||
|
notif = (
|
||||||
|
db.query(Notification)
|
||||||
|
.filter(
|
||||||
|
Notification.id == notification_id,
|
||||||
|
Notification.user_id == user_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if notif is None:
|
||||||
|
raise EntityNotFoundError("Notification", str(notification_id))
|
||||||
|
return notif
|
||||||
|
|
||||||
|
|
||||||
|
def notify_role(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
role: str,
|
||||||
|
type: str,
|
||||||
|
title: str,
|
||||||
|
message: str,
|
||||||
|
entity_type: str,
|
||||||
|
entity_id: uuid.UUID,
|
||||||
|
) -> None:
|
||||||
|
"""Send notifications to all active users with a given role."""
|
||||||
|
users = (
|
||||||
|
db.query(User)
|
||||||
|
.filter(User.role == role, User.is_active == True) # noqa: E712
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
for user in users:
|
||||||
|
create_notification(
|
||||||
|
db,
|
||||||
|
user_id=user.id,
|
||||||
|
type=type,
|
||||||
|
title=title,
|
||||||
|
message=message,
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_id=entity_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_notification(
|
def create_notification(
|
||||||
db: Session,
|
db: Session,
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
@@ -45,17 +111,13 @@ def create_notification(
|
|||||||
return notif
|
return notif
|
||||||
|
|
||||||
|
|
||||||
def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) -> bool:
|
def mark_as_read(
|
||||||
"""Mark a single notification as read. Returns True if updated."""
|
db: Session, notification_id: uuid.UUID, user_id: uuid.UUID
|
||||||
notif = (
|
) -> Notification:
|
||||||
db.query(Notification)
|
"""Mark a single notification as read. Returns the notification. Raises EntityNotFoundError if not found."""
|
||||||
.filter(Notification.id == notification_id, Notification.user_id == user_id)
|
notif = get_notification_or_raise(db, notification_id, user_id)
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if notif is None:
|
|
||||||
return False
|
|
||||||
notif.read = True
|
notif.read = True
|
||||||
return True
|
return notif
|
||||||
|
|
||||||
|
|
||||||
def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
|
def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int:
|
||||||
|
|||||||
@@ -7,11 +7,15 @@ Designed to run as a weekly background job. Respects NVD rate limits
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
from app.models.osint_item import OsintItem
|
from app.models.osint_item import OsintItem
|
||||||
from app.models.technique import Technique
|
from app.models.technique import Technique
|
||||||
|
|
||||||
@@ -189,3 +193,87 @@ def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None:
|
|||||||
def get_unreviewed_count(db: Session) -> int:
|
def get_unreviewed_count(db: Session) -> int:
|
||||||
"""Return the total number of unreviewed OSINT items."""
|
"""Return the total number of unreviewed OSINT items."""
|
||||||
return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712
|
return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712
|
||||||
|
|
||||||
|
|
||||||
|
def list_osint_items(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
technique_id: Optional[UUID] = None,
|
||||||
|
source_type: Optional[str] = None,
|
||||||
|
reviewed: Optional[bool] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> dict:
|
||||||
|
"""List OSINT items with optional filters and pagination."""
|
||||||
|
query = db.query(OsintItem)
|
||||||
|
if technique_id:
|
||||||
|
query = query.filter(OsintItem.technique_id == technique_id)
|
||||||
|
if source_type:
|
||||||
|
query = query.filter(OsintItem.source_type == source_type)
|
||||||
|
if reviewed is not None:
|
||||||
|
query = query.filter(OsintItem.reviewed == reviewed)
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
items = (
|
||||||
|
query.order_by(OsintItem.discovered_at.desc())
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"id": str(item.id),
|
||||||
|
"technique_id": str(item.technique_id),
|
||||||
|
"source_type": item.source_type,
|
||||||
|
"source_url": item.source_url,
|
||||||
|
"title": item.title,
|
||||||
|
"description": item.description,
|
||||||
|
"severity": item.severity,
|
||||||
|
"discovered_at": item.discovered_at.isoformat() if item.discovered_at else None,
|
||||||
|
"reviewed": item.reviewed,
|
||||||
|
"metadata": item.metadata_,
|
||||||
|
}
|
||||||
|
for item in items
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_osint_summary(db: Session) -> dict:
|
||||||
|
"""Summary statistics for OSINT items."""
|
||||||
|
total = db.query(func.count(OsintItem.id)).scalar() or 0
|
||||||
|
unreviewed = get_unreviewed_count(db)
|
||||||
|
|
||||||
|
by_severity = dict(
|
||||||
|
db.query(OsintItem.severity, func.count(OsintItem.id))
|
||||||
|
.group_by(OsintItem.severity)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
by_type = dict(
|
||||||
|
db.query(OsintItem.source_type, func.count(OsintItem.id))
|
||||||
|
.group_by(OsintItem.source_type)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
techniques_with_items = (
|
||||||
|
db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_items": total,
|
||||||
|
"unreviewed": unreviewed,
|
||||||
|
"techniques_with_items": techniques_with_items,
|
||||||
|
"by_severity": by_severity,
|
||||||
|
"by_type": by_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_technique_or_raise(db: Session, technique_id: UUID) -> Technique:
|
||||||
|
"""Get a technique by ID or raise EntityNotFoundError."""
|
||||||
|
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||||
|
if not technique:
|
||||||
|
raise EntityNotFoundError("Technique", str(technique_id))
|
||||||
|
return technique
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from typing import Optional
|
|||||||
from sqlalchemy import case, func
|
from sqlalchemy import case, func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
from app.models.technique import Technique
|
from app.models.technique import Technique
|
||||||
from app.models.test import Test
|
from app.models.test import Test
|
||||||
from app.models.detection_rule import DetectionRule
|
from app.models.detection_rule import DetectionRule
|
||||||
@@ -232,6 +233,29 @@ def bulk_technique_scores(db: Session) -> dict:
|
|||||||
# ── Technique-level scoring (single technique — preserved API) ────────
|
# ── Technique-level scoring (single technique — preserved API) ────────
|
||||||
|
|
||||||
|
|
||||||
|
def score_technique_by_mitre_id(db: Session, mitre_id: str) -> dict:
|
||||||
|
"""Get detailed score with breakdown for a technique by MITRE ID."""
|
||||||
|
technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first()
|
||||||
|
if not technique:
|
||||||
|
raise EntityNotFoundError("Technique", mitre_id)
|
||||||
|
result = calculate_technique_score(technique, db)
|
||||||
|
return {
|
||||||
|
"mitre_id": technique.mitre_id,
|
||||||
|
"name": technique.name,
|
||||||
|
"tactic": technique.tactic,
|
||||||
|
"status_global": technique.status_global.value if technique.status_global else None,
|
||||||
|
**result,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def score_actor_by_id(db: Session, actor_id: str) -> dict:
|
||||||
|
"""Get coverage score for a threat actor by ID."""
|
||||||
|
actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first()
|
||||||
|
if not actor:
|
||||||
|
raise EntityNotFoundError("ThreatActor", actor_id)
|
||||||
|
return calculate_actor_coverage_score(actor_id, db)
|
||||||
|
|
||||||
|
|
||||||
def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
||||||
"""Calculate a 0-100 score for a technique with detailed breakdown.
|
"""Calculate a 0-100 score for a technique with detailed breakdown.
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from datetime import datetime
|
|||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
from app.models.technique import Technique
|
from app.models.technique import Technique
|
||||||
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
|
||||||
from app.models.enums import TechniqueStatus
|
from app.models.enums import TechniqueStatus
|
||||||
@@ -25,6 +26,101 @@ from app.services.scoring_service import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Serialization and queries
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
|
||||||
|
"""Lightweight serialization for list views."""
|
||||||
|
return {
|
||||||
|
"id": str(snap.id),
|
||||||
|
"name": snap.name,
|
||||||
|
"organization_score": snap.organization_score,
|
||||||
|
"total_techniques": snap.total_techniques,
|
||||||
|
"validated_count": snap.validated_count,
|
||||||
|
"partial_count": snap.partial_count,
|
||||||
|
"not_covered_count": snap.not_covered_count,
|
||||||
|
"in_progress_count": snap.in_progress_count,
|
||||||
|
"not_evaluated_count": snap.not_evaluated_count,
|
||||||
|
"created_by": str(snap.created_by) if snap.created_by else None,
|
||||||
|
"created_at": snap.created_at.isoformat() if snap.created_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
|
||||||
|
"""Full serialization including technique states."""
|
||||||
|
base = serialize_snapshot_summary(snap)
|
||||||
|
|
||||||
|
technique_states = (
|
||||||
|
db.query(SnapshotTechniqueState)
|
||||||
|
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
|
||||||
|
.order_by(SnapshotTechniqueState.mitre_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
base["technique_states"] = [
|
||||||
|
{
|
||||||
|
"mitre_id": s.mitre_id,
|
||||||
|
"technique_id": str(s.technique_id),
|
||||||
|
"status": s.status,
|
||||||
|
"score": s.score,
|
||||||
|
}
|
||||||
|
for s in technique_states
|
||||||
|
]
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
def list_snapshots(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> dict:
|
||||||
|
"""List coverage snapshots ordered by creation date (newest first)."""
|
||||||
|
query = db.query(CoverageSnapshot)
|
||||||
|
total = query.count()
|
||||||
|
|
||||||
|
snapshots = (
|
||||||
|
query
|
||||||
|
.order_by(CoverageSnapshot.created_at.desc())
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"offset": offset,
|
||||||
|
"limit": limit,
|
||||||
|
"items": [serialize_snapshot_summary(s) for s in snapshots],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_snapshot_or_raise(db: Session, snapshot_id: str) -> CoverageSnapshot:
|
||||||
|
"""Fetch snapshot by ID or raise EntityNotFoundError."""
|
||||||
|
try:
|
||||||
|
sid = uuid.UUID(snapshot_id)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise EntityNotFoundError("Snapshot", snapshot_id)
|
||||||
|
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
|
||||||
|
if snapshot is None:
|
||||||
|
raise EntityNotFoundError("Snapshot", snapshot_id)
|
||||||
|
return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def get_snapshot_detail(db: Session, snapshot_id: str) -> dict:
|
||||||
|
"""Get detailed snapshot including per-technique states."""
|
||||||
|
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||||
|
return serialize_snapshot_detail(db, snapshot)
|
||||||
|
|
||||||
|
|
||||||
|
def delete_snapshot(db: Session, snapshot_id: str) -> None:
|
||||||
|
"""Delete a snapshot. Does not commit — caller must commit."""
|
||||||
|
snapshot = get_snapshot_or_raise(db, snapshot_id)
|
||||||
|
db.delete(snapshot)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Create snapshot
|
# Create snapshot
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -138,7 +234,7 @@ def compare_snapshots(
|
|||||||
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
|
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
|
||||||
|
|
||||||
if not snap_a or not snap_b:
|
if not snap_a or not snap_b:
|
||||||
return {"error": "One or both snapshots not found"}
|
raise EntityNotFoundError("Snapshot", f"{snapshot_a_id} or {snapshot_b_id}")
|
||||||
|
|
||||||
# Build lookup dicts: mitre_id -> {status, score}
|
# Build lookup dicts: mitre_id -> {status, score}
|
||||||
states_a = {
|
states_a = {
|
||||||
|
|||||||
48
backend/app/services/technique_query_service.py
Normal file
48
backend/app/services/technique_query_service.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Technique query service — framework-agnostic queries for technique details."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.services.d3fend_import_service import get_defenses_for_technique
|
||||||
|
|
||||||
|
|
||||||
|
def get_technique_detail(db: Session, mitre_id: str) -> dict:
|
||||||
|
"""Fetch full technique details including tests and D3FEND defenses."""
|
||||||
|
technique = (
|
||||||
|
db.query(Technique)
|
||||||
|
.options(joinedload(Technique.tests))
|
||||||
|
.filter(Technique.mitre_id == mitre_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if technique is None:
|
||||||
|
raise EntityNotFoundError("Technique", mitre_id)
|
||||||
|
defenses = get_defenses_for_technique(db, technique.id)
|
||||||
|
return {
|
||||||
|
"id": str(technique.id),
|
||||||
|
"mitre_id": technique.mitre_id,
|
||||||
|
"name": technique.name,
|
||||||
|
"description": technique.description,
|
||||||
|
"tactic": technique.tactic,
|
||||||
|
"platforms": technique.platforms or [],
|
||||||
|
"mitre_version": technique.mitre_version,
|
||||||
|
"mitre_last_modified": technique.mitre_last_modified,
|
||||||
|
"is_subtechnique": technique.is_subtechnique,
|
||||||
|
"parent_mitre_id": technique.parent_mitre_id,
|
||||||
|
"status_global": technique.status_global.value if technique.status_global else "not_evaluated",
|
||||||
|
"review_required": technique.review_required,
|
||||||
|
"last_review_date": technique.last_review_date,
|
||||||
|
"tests": [
|
||||||
|
{
|
||||||
|
"id": str(t.id),
|
||||||
|
"name": t.name,
|
||||||
|
"state": t.state.value if t.state else None,
|
||||||
|
"result": t.result.value if t.result else None,
|
||||||
|
"platform": t.platform,
|
||||||
|
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||||
|
}
|
||||||
|
for t in technique.tests
|
||||||
|
],
|
||||||
|
"d3fend_defenses": defenses,
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import EntityNotFoundError
|
||||||
from app.models.worklog import Worklog
|
from app.models.worklog import Worklog
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -43,6 +44,14 @@ def create_worklog(
|
|||||||
return wl
|
return wl
|
||||||
|
|
||||||
|
|
||||||
|
def get_worklog_or_raise(db: Session, worklog_id: UUID) -> Worklog:
|
||||||
|
"""Get a worklog by ID or raise EntityNotFoundError."""
|
||||||
|
wl = db.query(Worklog).filter(Worklog.id == worklog_id).first()
|
||||||
|
if not wl:
|
||||||
|
raise EntityNotFoundError("Worklog", str(worklog_id))
|
||||||
|
return wl
|
||||||
|
|
||||||
|
|
||||||
def list_worklogs(
|
def list_worklogs(
|
||||||
db: Session,
|
db: Session,
|
||||||
*,
|
*,
|
||||||
|
|||||||
Reference in New Issue
Block a user