From 339d66949843314e6db87dbec7d63992a511008e Mon Sep 17 00:00:00 2001 From: Kitos Date: Fri, 20 Feb 2026 14:34:24 +0100 Subject: [PATCH] feat: move all remaining inline logic from routers to services (Tier 2) --- backend/app/routers/campaigns.py | 22 ++- backend/app/routers/d3fend.py | 73 ++-------- backend/app/routers/jira.py | 96 ++----------- backend/app/routers/notifications.py | 24 +--- backend/app/routers/osint.py | 83 ++--------- backend/app/routers/scores.py | 32 +---- backend/app/routers/snapshots.py | 97 +++---------- backend/app/routers/techniques.py | 45 +----- backend/app/routers/worklogs.py | 11 +- backend/app/services/d3fend_query_service.py | 82 +++++++++++ backend/app/services/jira_service.py | 132 +++++++++++++++++- backend/app/services/notification_service.py | 82 +++++++++-- .../app/services/osint_enrichment_service.py | 88 ++++++++++++ backend/app/services/scoring_service.py | 24 ++++ backend/app/services/snapshot_service.py | 98 ++++++++++++- .../app/services/technique_query_service.py | 48 +++++++ backend/app/services/worklog_service.py | 9 ++ 17 files changed, 632 insertions(+), 414 deletions(-) create mode 100644 backend/app/services/d3fend_query_service.py create mode 100644 backend/app/services/technique_query_service.py diff --git a/backend/app/routers/campaigns.py b/backend/app/routers/campaigns.py index 335c906..a149f10 100644 --- a/backend/app/routers/campaigns.py +++ b/backend/app/routers/campaigns.py @@ -30,7 +30,7 @@ from app.services.campaign_crud_service import ( serialize_campaign, update_campaign as crud_update, ) -from app.services.notification_service import create_notification +from app.services.notification_service import notify_role from app.services.audit_service import log_action logger = logging.getLogger(__name__) @@ -237,17 +237,15 @@ def activate_campaign( db.commit() db.refresh(campaign) - red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712 - for user in red_techs: - create_notification( - db, - user_id=user.id, - type="campaign_activated", - title="Campaign activated", - message=f'Campaign "{campaign.name}" has been activated.', - entity_type="campaign", - entity_id=campaign.id, - ) + notify_role( + db, + role="red_tech", + type="campaign_activated", + title="Campaign activated", + message=f'Campaign "{campaign.name}" has been activated.', + entity_type="campaign", + entity_id=campaign.id, + ) log_action( db, diff --git a/backend/app/routers/d3fend.py b/backend/app/routers/d3fend.py index 307936e..4afa50a 100644 --- a/backend/app/routers/d3fend.py +++ b/backend/app/routers/d3fend.py @@ -3,18 +3,20 @@ import logging from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlalchemy.orm import Session from app.database import get_db from app.dependencies.auth import get_current_user, require_role from app.models.user import User -from app.models.technique import Technique -from app.models.defensive_technique import DefensiveTechnique, DefensiveTechniqueMapping from app.services.d3fend_import_service import ( import_d3fend_techniques, import_d3fend_mappings, - get_defenses_for_technique, +) +from app.services.d3fend_query_service import ( + list_defensive_techniques as list_defensive_techniques_svc, + list_d3fend_tactics, + get_defenses_for_attack_technique, ) logger = logging.getLogger(__name__) @@ -36,38 +38,9 @@ def list_defensive_techniques( current_user: User = Depends(get_current_user), ): """List all D3FEND defensive techniques with optional filters.""" - query = db.query(DefensiveTechnique) - - if tactic: - query = query.filter(DefensiveTechnique.tactic == tactic) - - if search: - from app.utils import escape_like - pattern = f"%{escape_like(search)}%" - query = query.filter( - DefensiveTechnique.name.ilike(pattern) - | DefensiveTechnique.d3fend_id.ilike(pattern) - ) - - total = query.count() - items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all() - - return { - "total": total, - "offset": offset, - "limit": limit, - "items": [ - { - "id": str(dt.id), - "d3fend_id": dt.d3fend_id, - "name": dt.name, - "description": dt.description, - "tactic": dt.tactic, - "d3fend_url": dt.d3fend_url, - } - for dt in items - ], - } + return list_defensive_techniques_svc( + db, tactic=tactic, search=search, offset=offset, limit=limit + ) # --------------------------------------------------------------------------- @@ -75,21 +48,12 @@ def list_defensive_techniques( # --------------------------------------------------------------------------- @router.get("/tactics") -def list_d3fend_tactics( +def list_d3fend_tactics_endpoint( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """Return a list of all D3FEND tactics with counts.""" - from sqlalchemy import func - - rows = ( - db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id)) - .group_by(DefensiveTechnique.tactic) - .order_by(DefensiveTechnique.tactic) - .all() - ) - - return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows] + return list_d3fend_tactics(db) # --------------------------------------------------------------------------- @@ -97,24 +61,13 @@ def list_d3fend_tactics( # --------------------------------------------------------------------------- @router.get("/for-technique/{mitre_id}") -def get_defenses_for_attack_technique( +def get_defenses_for_attack_technique_endpoint( mitre_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """Get all D3FEND defensive techniques mapped to a given ATT&CK technique.""" - technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first() - if not technique: - raise HTTPException(status_code=404, detail=f"Technique {mitre_id} not found") - - defenses = get_defenses_for_technique(db, technique.id) - - return { - "mitre_id": mitre_id, - "technique_name": technique.name, - "defenses": defenses, - "total": len(defenses), - } + return get_defenses_for_attack_technique(db, mitre_id) # --------------------------------------------------------------------------- diff --git a/backend/app/routers/jira.py b/backend/app/routers/jira.py index f647539..c06b570 100644 --- a/backend/app/routers/jira.py +++ b/backend/app/routers/jira.py @@ -7,14 +7,9 @@ from uuid import UUID from fastapi import APIRouter, Depends, Query from sqlalchemy.orm import Session -from app.config import settings from app.database import get_db from app.dependencies.auth import get_current_user, require_role -from app.domain.exceptions import EntityNotFoundError -from app.models.jira_link import JiraLink, JiraLinkEntityType -from app.models.test import Test -from app.models.technique import Technique -from app.models.campaign import Campaign +from app.models.jira_link import JiraLinkEntityType from app.models.user import User from app.schemas.jira_schema import ( JiraIssueResult, @@ -45,23 +40,14 @@ def create_link( user: User = Depends(get_current_user), ): """Associate an Aegis entity with a Jira issue.""" - link = JiraLink( + link = jira_service.create_link( + db, entity_type=body.entity_type, entity_id=body.entity_id, jira_issue_key=body.jira_issue_key, sync_direction=body.sync_direction, created_by=user.id, ) - db.add(link) - db.flush() - - # Pull initial data from Jira if enabled - if settings.JIRA_ENABLED: - try: - jira_service.sync_jira_to_aegis(db, link) - except Exception as e: - logger.warning("Initial Jira sync failed for %s: %s", body.jira_issue_key, e) - db.commit() db.refresh(link) @@ -88,12 +74,11 @@ def list_links( user: User = Depends(get_current_user), ): """List Jira links, optionally filtered by entity.""" - query = db.query(JiraLink) - if entity_type: - query = query.filter(JiraLink.entity_type == entity_type) - if entity_id: - query = query.filter(JiraLink.entity_id == entity_id) - return query.order_by(JiraLink.created_at.desc()).all() + return jira_service.list_links( + db, + entity_type=entity_type, + entity_id=entity_id, + ) @router.post("/links/{link_id}/sync") @@ -103,9 +88,7 @@ def sync_link( user: User = Depends(require_role("admin")), ): """Force bidirectional sync for a specific Jira link.""" - link = db.query(JiraLink).filter(JiraLink.id == link_id).first() - if not link: - raise EntityNotFoundError("JiraLink", str(link_id)) + link = jira_service.get_link_or_raise(db, link_id) jira_service.sync_jira_to_aegis(db, link) db.commit() return {"message": "Sync completed", "jira_status": link.jira_status} @@ -118,10 +101,7 @@ def delete_link( user: User = Depends(get_current_user), ): """Remove a Jira link.""" - link = db.query(JiraLink).filter(JiraLink.id == link_id).first() - if not link: - raise EntityNotFoundError("JiraLink", str(link_id)) - db.delete(link) + link = jira_service.delete_link(db, link_id) db.commit() audit_service.log_action( db, @@ -141,61 +121,11 @@ def create_issue_from_entity( user: User = Depends(get_current_user), ): """Auto-create a Jira issue from an Aegis entity and link them.""" - summary, description = _build_issue_data(db, entity_type, entity_id) - result = jira_service.create_jira_issue( - project_key=settings.JIRA_DEFAULT_PROJECT, - summary=summary, - description=description, - labels=["aegis", entity_type.value], - ) - link = JiraLink( + result = jira_service.create_issue_and_link( + db, entity_type=entity_type, entity_id=entity_id, - jira_issue_key=result["issue_key"], - jira_issue_id=result["issue_id"], - jira_project_key=settings.JIRA_DEFAULT_PROJECT, created_by=user.id, ) - db.add(link) db.commit() - return {"issue_key": result["issue_key"], "link_id": str(link.id)} - - -def _build_issue_data( - db: Session, - entity_type: JiraLinkEntityType, - entity_id: UUID, -) -> tuple[str, str]: - """Build Jira issue summary + description from an Aegis entity.""" - if entity_type == JiraLinkEntityType.test: - entity = db.query(Test).filter(Test.id == entity_id).first() - if not entity: - raise EntityNotFoundError("Test", str(entity_id)) - return ( - f"[Aegis Test] {entity.name}", - f"Test: {entity.name}\n" - f"State: {entity.state.value if entity.state else 'draft'}\n" - f"Description: {entity.description or 'N/A'}", - ) - elif entity_type == JiraLinkEntityType.campaign: - entity = db.query(Campaign).filter(Campaign.id == entity_id).first() - if not entity: - raise EntityNotFoundError("Campaign", str(entity_id)) - return ( - f"[Aegis Campaign] {entity.name}", - f"Campaign: {entity.name}\n" - f"Type: {entity.type}\nStatus: {entity.status}\n" - f"Description: {entity.description or 'N/A'}", - ) - elif entity_type == JiraLinkEntityType.technique: - entity = db.query(Technique).filter(Technique.id == entity_id).first() - if not entity: - raise EntityNotFoundError("Technique", str(entity_id)) - return ( - f"[Aegis Technique] {entity.mitre_id} - {entity.name}", - f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n" - f"Tactic: {entity.tactic or 'N/A'}\n" - f"Description: {entity.description or 'N/A'}", - ) - else: - return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}" + return result diff --git a/backend/app/routers/notifications.py b/backend/app/routers/notifications.py index fa7d6b2..da42c58 100644 --- a/backend/app/routers/notifications.py +++ b/backend/app/routers/notifications.py @@ -10,16 +10,16 @@ POST /notifications/read-all — mark all as read import uuid -from fastapi import APIRouter, Depends, HTTPException, Query, status +from fastapi import APIRouter, Depends, Query from sqlalchemy.orm import Session from app.database import get_db from app.dependencies.auth import get_current_user from app.domain.unit_of_work import UnitOfWork -from app.models.notification import Notification from app.models.user import User from app.schemas.notification import NotificationOut, UnreadCountOut from app.services.notification_service import ( + list_notifications, mark_as_read, mark_all_as_read, get_unread_count, @@ -34,22 +34,14 @@ router = APIRouter(prefix="/notifications", tags=["notifications"]) @router.get("", response_model=list[NotificationOut]) -def list_notifications( +def list_notifications_endpoint( offset: int = Query(0, ge=0), limit: int = Query(20, ge=1, le=100), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """Return paginated notifications for the current user, newest first.""" - notifs = ( - db.query(Notification) - .filter(Notification.user_id == current_user.id) - .order_by(Notification.created_at.desc()) - .offset(offset) - .limit(limit) - .all() - ) - return notifs + return list_notifications(db, current_user.id, offset=offset, limit=limit) # --------------------------------------------------------------------------- @@ -80,14 +72,8 @@ def read_notification( ): """Mark a single notification as read.""" with UnitOfWork(db) as uow: - success = mark_as_read(db, notification_id, current_user.id) - if not success: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Notification not found", - ) + notif = mark_as_read(db, notification_id, current_user.id) uow.commit() - notif = db.query(Notification).filter(Notification.id == notification_id).first() return notif diff --git a/backend/app/routers/osint.py b/backend/app/routers/osint.py index 13bec68..6b16b06 100644 --- a/backend/app/routers/osint.py +++ b/backend/app/routers/osint.py @@ -10,14 +10,14 @@ from sqlalchemy.orm import Session from app.database import get_db from app.dependencies.auth import get_current_user, require_any_role -from app.models.osint_item import OsintItem -from app.models.technique import Technique from app.models.user import User from app.services.osint_enrichment_service import ( enrich_technique_with_cves, get_osint_items_for_technique, + get_osint_summary, + get_technique_or_raise, + list_osint_items as service_list_osint_items, mark_osint_reviewed, - get_unreviewed_count, ) router = APIRouter(prefix="/osint", tags=["osint"]) @@ -56,41 +56,15 @@ def list_osint_items( user: User = Depends(get_current_user), ): """List OSINT items with optional filters.""" - query = db.query(OsintItem) - if technique_id: - query = query.filter(OsintItem.technique_id == technique_id) - if source_type: - query = query.filter(OsintItem.source_type == source_type) - if reviewed is not None: - query = query.filter(OsintItem.reviewed == reviewed) - - total = query.count() - items = ( - query.order_by(OsintItem.discovered_at.desc()) - .offset(offset) - .limit(limit) - .all() + return service_list_osint_items( + db, + technique_id=technique_id, + source_type=source_type, + reviewed=reviewed, + offset=offset, + limit=limit, ) - return { - "total": total, - "items": [ - { - "id": str(item.id), - "technique_id": str(item.technique_id), - "source_type": item.source_type, - "source_url": item.source_url, - "title": item.title, - "description": item.description, - "severity": item.severity, - "discovered_at": item.discovered_at.isoformat() if item.discovered_at else None, - "reviewed": item.reviewed, - "metadata": item.metadata_, - } - for item in items - ], - } - @router.get("/summary") def osint_summary( @@ -98,34 +72,7 @@ def osint_summary( user: User = Depends(get_current_user), ): """Summary statistics for OSINT items.""" - from sqlalchemy import func - - total = db.query(func.count(OsintItem.id)).scalar() or 0 - unreviewed = get_unreviewed_count(db) - - by_severity = dict( - db.query(OsintItem.severity, func.count(OsintItem.id)) - .group_by(OsintItem.severity) - .all() - ) - - by_type = dict( - db.query(OsintItem.source_type, func.count(OsintItem.id)) - .group_by(OsintItem.source_type) - .all() - ) - - techniques_with_items = ( - db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0 - ) - - return { - "total_items": total, - "unreviewed": unreviewed, - "techniques_with_items": techniques_with_items, - "by_severity": by_severity, - "by_type": by_type, - } + return get_osint_summary(db) @router.post("/items/{item_id}/review") @@ -151,13 +98,7 @@ def trigger_technique_enrichment( user: User = Depends(require_any_role("red_lead", "blue_lead")), ): """Manually trigger OSINT enrichment for a single technique.""" - technique = db.query(Technique).filter(Technique.id == technique_id).first() - if not technique: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Technique not found", - ) - + technique = get_technique_or_raise(db, technique_id) count = enrich_technique_with_cves(db, technique) return { "technique_id": str(technique.id), diff --git a/backend/app/routers/scores.py b/backend/app/routers/scores.py index 0e9621a..80dce66 100644 --- a/backend/app/routers/scores.py +++ b/backend/app/routers/scores.py @@ -5,19 +5,17 @@ Provides granular scoring with breakdowns and configurable weights. from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from pydantic import BaseModel from sqlalchemy.orm import Session from app.database import get_db from app.dependencies.auth import get_current_user, require_role from app.models.user import User -from app.models.technique import Technique -from app.models.threat_actor import ThreatActor from app.services.scoring_service import ( - calculate_technique_score, + score_technique_by_mitre_id, + score_actor_by_id, calculate_tactic_score, - calculate_actor_coverage_score, calculate_organization_score, get_score_history, ) @@ -39,23 +37,7 @@ def score_technique( current_user: User = Depends(get_current_user), ): """Get detailed score with breakdown for a specific technique.""" - technique = ( - db.query(Technique) - .filter(Technique.mitre_id == mitre_id) - .first() - ) - if not technique: - raise HTTPException(status_code=404, detail="Technique not found") - - result = calculate_technique_score(technique, db) - - return { - "mitre_id": technique.mitre_id, - "name": technique.name, - "tactic": technique.tactic, - "status_global": technique.status_global.value if technique.status_global else None, - **result, - } + return score_technique_by_mitre_id(db, mitre_id) # ── GET /scores/tactic/{tactic} ────────────────────────────────────── @@ -81,11 +63,7 @@ def score_threat_actor( current_user: User = Depends(get_current_user), ): """Get coverage score against a specific threat actor.""" - actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() - if not actor: - raise HTTPException(status_code=404, detail="Threat actor not found") - - return calculate_actor_coverage_score(actor_id, db) + return score_actor_by_id(db, actor_id) # ── GET /scores/organization ───────────────────────────────────────── diff --git a/backend/app/routers/snapshots.py b/backend/app/routers/snapshots.py index cacea02..78fa89b 100644 --- a/backend/app/routers/snapshots.py +++ b/backend/app/routers/snapshots.py @@ -8,18 +8,24 @@ import logging import uuid from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from pydantic import BaseModel from sqlalchemy.orm import Session from app.database import get_db from app.dependencies.auth import get_current_user, require_any_role, require_role +from app.domain.errors import BusinessRuleViolation +from app.domain.unit_of_work import UnitOfWork from app.models.user import User -from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState from app.services.snapshot_service import ( create_snapshot, compare_snapshots, cleanup_old_snapshots, + serialize_snapshot_summary, + list_snapshots as list_snapshots_svc, + get_snapshot_or_raise, + get_snapshot_detail, + delete_snapshot, ) from app.services.audit_service import log_action @@ -34,48 +40,6 @@ class SnapshotCreate(BaseModel): name: Optional[str] = None -# ── Helpers ────────────────────────────────────────────────────────── - -def _serialize_snapshot_summary(snap: CoverageSnapshot) -> dict: - """Lightweight serialization for list views.""" - return { - "id": str(snap.id), - "name": snap.name, - "organization_score": snap.organization_score, - "total_techniques": snap.total_techniques, - "validated_count": snap.validated_count, - "partial_count": snap.partial_count, - "not_covered_count": snap.not_covered_count, - "in_progress_count": snap.in_progress_count, - "not_evaluated_count": snap.not_evaluated_count, - "created_by": str(snap.created_by) if snap.created_by else None, - "created_at": snap.created_at.isoformat() if snap.created_at else None, - } - - -def _serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict: - """Full serialization including technique states.""" - base = _serialize_snapshot_summary(snap) - - technique_states = ( - db.query(SnapshotTechniqueState) - .filter(SnapshotTechniqueState.snapshot_id == snap.id) - .order_by(SnapshotTechniqueState.mitre_id) - .all() - ) - - base["technique_states"] = [ - { - "mitre_id": s.mitre_id, - "technique_id": str(s.technique_id), - "status": s.status, - "score": s.score, - } - for s in technique_states - ] - return base - - # --------------------------------------------------------------------------- # GET /snapshots — List snapshots (paginated) # --------------------------------------------------------------------------- @@ -88,23 +52,7 @@ def list_snapshots( current_user: User = Depends(get_current_user), ): """List coverage snapshots ordered by creation date (newest first).""" - query = db.query(CoverageSnapshot) - total = query.count() - - snapshots = ( - query - .order_by(CoverageSnapshot.created_at.desc()) - .offset(offset) - .limit(limit) - .all() - ) - - return { - "total": total, - "offset": offset, - "limit": limit, - "items": [_serialize_snapshot_summary(s) for s in snapshots], - } + return list_snapshots_svc(db, offset=offset, limit=limit) # --------------------------------------------------------------------------- @@ -129,7 +77,7 @@ def create_snapshot_endpoint( details={"name": snapshot.name, "score": snapshot.organization_score}, ) - return _serialize_snapshot_summary(snapshot) + return serialize_snapshot_summary(snapshot) # --------------------------------------------------------------------------- @@ -148,13 +96,9 @@ def compare_snapshots_endpoint( a_id = uuid.UUID(a) b_id = uuid.UUID(b) except ValueError: - raise HTTPException(status_code=400, detail="Invalid snapshot ID format") + raise BusinessRuleViolation("Invalid snapshot ID format") - result = compare_snapshots(db, a_id, b_id) - if "error" in result: - raise HTTPException(status_code=404, detail=result["error"]) - - return result + return compare_snapshots(db, a_id, b_id) # --------------------------------------------------------------------------- @@ -168,11 +112,7 @@ def get_snapshot( current_user: User = Depends(get_current_user), ): """Get detailed snapshot information including per-technique states.""" - snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first() - if not snapshot: - raise HTTPException(status_code=404, detail="Snapshot not found") - - return _serialize_snapshot_detail(db, snapshot) + return get_snapshot_detail(db, snapshot_id) # --------------------------------------------------------------------------- @@ -180,15 +120,13 @@ def get_snapshot( # --------------------------------------------------------------------------- @router.delete("/{snapshot_id}") -def delete_snapshot( +def delete_snapshot_endpoint( snapshot_id: str, db: Session = Depends(get_db), current_user: User = Depends(require_role("admin")), ): """Delete a snapshot (admin only).""" - snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_id).first() - if not snapshot: - raise HTTPException(status_code=404, detail="Snapshot not found") + snapshot = get_snapshot_or_raise(db, snapshot_id) log_action( db, @@ -199,7 +137,8 @@ def delete_snapshot( details={"name": snapshot.name}, ) - db.delete(snapshot) - db.commit() + with UnitOfWork(db) as uow: + delete_snapshot(db, snapshot_id) + uow.commit() return {"detail": "Snapshot deleted"} diff --git a/backend/app/routers/techniques.py b/backend/app/routers/techniques.py index 5ddb604..92f524d 100644 --- a/backend/app/routers/techniques.py +++ b/backend/app/routers/techniques.py @@ -6,7 +6,7 @@ exceptions to HTTP responses automatically. """ from fastapi import APIRouter, Depends, Query, status -from sqlalchemy.orm import Session, joinedload +from sqlalchemy.orm import Session from app.database import get_db from app.dependencies.auth import get_current_user, require_role, require_any_role @@ -18,7 +18,6 @@ from app.domain.unit_of_work import UnitOfWork from app.infrastructure.persistence.repositories.sa_technique_repository import ( SATechniqueRepository, ) -from app.models.technique import Technique from app.models.user import User from app.schemas.technique import ( TechniqueCreate, @@ -27,7 +26,7 @@ from app.schemas.technique import ( TechniqueUpdate, ) from app.services.audit_service import log_action -from app.services.d3fend_import_service import get_defenses_for_technique +from app.services.technique_query_service import get_technique_detail router = APIRouter(prefix="/techniques", tags=["techniques"]) @@ -67,45 +66,7 @@ def get_technique( current_user: User = Depends(get_current_user), ): """Return full details for a single technique, including its tests and D3FEND defenses.""" - technique = ( - db.query(Technique) - .options(joinedload(Technique.tests)) - .filter(Technique.mitre_id == mitre_id) - .first() - ) - - if technique is None: - raise EntityNotFoundError("Technique", mitre_id) - - defenses = get_defenses_for_technique(db, technique.id) - - return { - "id": str(technique.id), - "mitre_id": technique.mitre_id, - "name": technique.name, - "description": technique.description, - "tactic": technique.tactic, - "platforms": technique.platforms or [], - "mitre_version": technique.mitre_version, - "mitre_last_modified": technique.mitre_last_modified, - "is_subtechnique": technique.is_subtechnique, - "parent_mitre_id": technique.parent_mitre_id, - "status_global": technique.status_global.value if technique.status_global else "not_evaluated", - "review_required": technique.review_required, - "last_review_date": technique.last_review_date, - "tests": [ - { - "id": str(t.id), - "name": t.name, - "state": t.state.value if t.state else None, - "result": t.result.value if t.result else None, - "platform": t.platform, - "created_at": t.created_at.isoformat() if t.created_at else None, - } - for t in technique.tests - ], - "d3fend_defenses": defenses, - } + return get_technique_detail(db, mitre_id) # --------------------------------------------------------------------------- diff --git a/backend/app/routers/worklogs.py b/backend/app/routers/worklogs.py index be84999..bb5a5cf 100644 --- a/backend/app/routers/worklogs.py +++ b/backend/app/routers/worklogs.py @@ -10,9 +10,7 @@ from sqlalchemy.orm import Session from app.database import get_db from app.dependencies.auth import get_current_user, require_any_role -from app.domain.exceptions import EntityNotFoundError from app.models.user import User -from app.models.worklog import Worklog from app.services import worklog_service router = APIRouter(prefix="/worklogs", tags=["worklogs"]) @@ -97,10 +95,7 @@ def get_one( _user: User = Depends(get_current_user), ): """Get a single worklog by ID.""" - wl = db.query(Worklog).filter(Worklog.id == worklog_id).first() - if not wl: - raise EntityNotFoundError("Worklog", str(worklog_id)) - return wl + return worklog_service.get_worklog_or_raise(db, worklog_id) @router.get("/{worklog_id}/verify") @@ -110,9 +105,7 @@ def verify_integrity( _user: User = Depends(get_current_user), ): """Check whether a worklog's integrity hash is still valid.""" - wl = db.query(Worklog).filter(Worklog.id == worklog_id).first() - if not wl: - raise EntityNotFoundError("Worklog", str(worklog_id)) + wl = worklog_service.get_worklog_or_raise(db, worklog_id) return { "worklog_id": str(wl.id), "integrity_valid": worklog_service.verify_worklog_integrity(wl), diff --git a/backend/app/services/d3fend_query_service.py b/backend/app/services/d3fend_query_service.py new file mode 100644 index 0000000..dc433ee --- /dev/null +++ b/backend/app/services/d3fend_query_service.py @@ -0,0 +1,82 @@ +"""D3FEND query service — framework-agnostic queries for defensive techniques.""" +from __future__ import annotations + +from typing import Optional + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.domain.errors import EntityNotFoundError +from app.models.defensive_technique import DefensiveTechnique +from app.models.technique import Technique +from app.services.d3fend_import_service import get_defenses_for_technique +from app.utils import escape_like + + +def list_defensive_techniques( + db: Session, + *, + tactic: Optional[str] = None, + search: Optional[str] = None, + offset: int = 0, + limit: int = 50, +) -> dict: + """List D3FEND defensive techniques with optional filters.""" + query = db.query(DefensiveTechnique) + + if tactic: + query = query.filter(DefensiveTechnique.tactic == tactic) + + if search: + pattern = f"%{escape_like(search)}%" + query = query.filter( + DefensiveTechnique.name.ilike(pattern) + | DefensiveTechnique.d3fend_id.ilike(pattern) + ) + + total = query.count() + items = query.order_by(DefensiveTechnique.d3fend_id).offset(offset).limit(limit).all() + + return { + "total": total, + "offset": offset, + "limit": limit, + "items": [ + { + "id": str(dt.id), + "d3fend_id": dt.d3fend_id, + "name": dt.name, + "description": dt.description, + "tactic": dt.tactic, + "d3fend_url": dt.d3fend_url, + } + for dt in items + ], + } + + +def list_d3fend_tactics(db: Session) -> list[dict]: + """Return a list of all D3FEND tactics with counts.""" + rows = ( + db.query(DefensiveTechnique.tactic, func.count(DefensiveTechnique.id)) + .group_by(DefensiveTechnique.tactic) + .order_by(DefensiveTechnique.tactic) + .all() + ) + return [{"tactic": tactic or "Unknown", "count": count} for tactic, count in rows] + + +def get_defenses_for_attack_technique(db: Session, mitre_id: str) -> dict: + """Get all D3FEND defensive techniques mapped to a given ATT&CK technique.""" + technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first() + if technique is None: + raise EntityNotFoundError("Technique", mitre_id) + + defenses = get_defenses_for_technique(db, technique.id) + + return { + "mitre_id": mitre_id, + "technique_name": technique.name, + "defenses": defenses, + "total": len(defenses), + } diff --git a/backend/app/services/jira_service.py b/backend/app/services/jira_service.py index cc6b13f..9bc760e 100644 --- a/backend/app/services/jira_service.py +++ b/backend/app/services/jira_service.py @@ -3,12 +3,17 @@ import logging from datetime import datetime from typing import Optional +from uuid import UUID from sqlalchemy.orm import Session from app.config import settings +from app.domain.errors import EntityNotFoundError from app.domain.exceptions import InvalidOperationError -from app.models.jira_link import JiraLink +from app.models.campaign import Campaign +from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection +from app.models.technique import Technique +from app.models.test import Test logger = logging.getLogger(__name__) @@ -103,3 +108,128 @@ def _build_sync_comment(data: dict) -> str: lines.append(f"*{key}:* {value}") lines.append(f"\n_Synced at {datetime.utcnow().isoformat()}_") return "\n".join(lines) + + +# ── Link CRUD ──────────────────────────────────────────────────────── + + +def create_link( + db: Session, + *, + entity_type: JiraLinkEntityType, + entity_id: UUID, + jira_issue_key: str, + sync_direction: JiraSyncDirection, + created_by: UUID, +) -> JiraLink: + """Create a Jira link and optionally pull initial data from Jira.""" + link = JiraLink( + entity_type=entity_type, + entity_id=entity_id, + jira_issue_key=jira_issue_key, + sync_direction=sync_direction, + created_by=created_by, + ) + db.add(link) + db.flush() + + if settings.JIRA_ENABLED: + try: + sync_jira_to_aegis(db, link) + except Exception as e: + logger.warning("Initial Jira sync failed for %s: %s", jira_issue_key, e) + + return link + + +def list_links( + db: Session, + *, + entity_type: Optional[JiraLinkEntityType] = None, + entity_id: Optional[UUID] = None, +) -> list[JiraLink]: + """List Jira links with optional filters.""" + query = db.query(JiraLink) + if entity_type: + query = query.filter(JiraLink.entity_type == entity_type) + if entity_id: + query = query.filter(JiraLink.entity_id == entity_id) + return query.order_by(JiraLink.created_at.desc()).all() + + +def get_link_or_raise(db: Session, link_id: UUID) -> JiraLink: + """Get a Jira link by ID or raise EntityNotFoundError.""" + link = db.query(JiraLink).filter(JiraLink.id == link_id).first() + if not link: + raise EntityNotFoundError("JiraLink", str(link_id)) + return link + + +def delete_link(db: Session, link_id: UUID) -> JiraLink: + """Delete a Jira link. Returns the deleted link (for audit).""" + link = get_link_or_raise(db, link_id) + db.delete(link) + return link + + +def build_issue_data(db: Session, entity_type: JiraLinkEntityType, entity_id: UUID) -> tuple[str, str]: + """Build Jira issue summary and description from an Aegis entity.""" + if entity_type == JiraLinkEntityType.test: + entity = db.query(Test).filter(Test.id == entity_id).first() + if not entity: + raise EntityNotFoundError("Test", str(entity_id)) + return ( + f"[Aegis Test] {entity.name}", + f"Test: {entity.name}\n" + f"State: {entity.state.value if entity.state else 'draft'}\n" + f"Description: {entity.description or 'N/A'}", + ) + elif entity_type == JiraLinkEntityType.campaign: + entity = db.query(Campaign).filter(Campaign.id == entity_id).first() + if not entity: + raise EntityNotFoundError("Campaign", str(entity_id)) + return ( + f"[Aegis Campaign] {entity.name}", + f"Campaign: {entity.name}\n" + f"Type: {entity.type}\nStatus: {entity.status}\n" + f"Description: {entity.description or 'N/A'}", + ) + elif entity_type == JiraLinkEntityType.technique: + entity = db.query(Technique).filter(Technique.id == entity_id).first() + if not entity: + raise EntityNotFoundError("Technique", str(entity_id)) + return ( + f"[Aegis Technique] {entity.mitre_id} - {entity.name}", + f"MITRE ID: {entity.mitre_id}\nName: {entity.name}\n" + f"Tactic: {entity.tactic or 'N/A'}\n" + f"Description: {entity.description or 'N/A'}", + ) + else: + return f"[Aegis] Entity {entity_id}", f"Entity type: {entity_type.value}" + + +def create_issue_and_link( + db: Session, + *, + entity_type: JiraLinkEntityType, + entity_id: UUID, + created_by: UUID, +) -> dict: + """Create a Jira issue from an Aegis entity and link them.""" + summary, description = build_issue_data(db, entity_type, entity_id) + result = create_jira_issue( + project_key=settings.JIRA_DEFAULT_PROJECT, + summary=summary, + description=description, + labels=["aegis", entity_type.value], + ) + link = JiraLink( + entity_type=entity_type, + entity_id=entity_id, + jira_issue_key=result["issue_key"], + jira_issue_id=result["issue_id"], + jira_project_key=settings.JIRA_DEFAULT_PROJECT, + created_by=created_by, + ) + db.add(link) + return {"issue_key": result["issue_key"], "link_id": str(link.id)} diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index 86d523f..4838468 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -13,6 +13,7 @@ from datetime import datetime, timedelta from sqlalchemy.orm import Session from sqlalchemy import func +from app.domain.errors import EntityNotFoundError from app.models.notification import Notification from app.models.user import User @@ -22,6 +23,71 @@ from app.models.user import User # --------------------------------------------------------------------------- +def list_notifications( + db: Session, + user_id: uuid.UUID, + *, + offset: int = 0, + limit: int = 20, +) -> list[Notification]: + """Return paginated notifications for a user, newest first.""" + return ( + db.query(Notification) + .filter(Notification.user_id == user_id) + .order_by(Notification.created_at.desc()) + .offset(offset) + .limit(limit) + .all() + ) + + +def get_notification_or_raise( + db: Session, + notification_id: uuid.UUID, + user_id: uuid.UUID, +) -> Notification: + """Fetch a notification by ID and user, or raise EntityNotFoundError.""" + notif = ( + db.query(Notification) + .filter( + Notification.id == notification_id, + Notification.user_id == user_id, + ) + .first() + ) + if notif is None: + raise EntityNotFoundError("Notification", str(notification_id)) + return notif + + +def notify_role( + db: Session, + *, + role: str, + type: str, + title: str, + message: str, + entity_type: str, + entity_id: uuid.UUID, +) -> None: + """Send notifications to all active users with a given role.""" + users = ( + db.query(User) + .filter(User.role == role, User.is_active == True) # noqa: E712 + .all() + ) + for user in users: + create_notification( + db, + user_id=user.id, + type=type, + title=title, + message=message, + entity_type=entity_type, + entity_id=entity_id, + ) + + def create_notification( db: Session, user_id: uuid.UUID, @@ -45,17 +111,13 @@ def create_notification( return notif -def mark_as_read(db: Session, notification_id: uuid.UUID, user_id: uuid.UUID) -> bool: - """Mark a single notification as read. Returns True if updated.""" - notif = ( - db.query(Notification) - .filter(Notification.id == notification_id, Notification.user_id == user_id) - .first() - ) - if notif is None: - return False +def mark_as_read( + db: Session, notification_id: uuid.UUID, user_id: uuid.UUID +) -> Notification: + """Mark a single notification as read. Returns the notification. Raises EntityNotFoundError if not found.""" + notif = get_notification_or_raise(db, notification_id, user_id) notif.read = True - return True + return notif def mark_all_as_read(db: Session, user_id: uuid.UUID) -> int: diff --git a/backend/app/services/osint_enrichment_service.py b/backend/app/services/osint_enrichment_service.py index 47f0b2d..7143735 100644 --- a/backend/app/services/osint_enrichment_service.py +++ b/backend/app/services/osint_enrichment_service.py @@ -7,11 +7,15 @@ Designed to run as a weekly background job. Respects NVD rate limits import logging import time +from typing import Optional +from uuid import UUID import requests +from sqlalchemy import func from sqlalchemy.orm import Session from app.config import settings +from app.domain.errors import EntityNotFoundError from app.models.osint_item import OsintItem from app.models.technique import Technique @@ -189,3 +193,87 @@ def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None: def get_unreviewed_count(db: Session) -> int: """Return the total number of unreviewed OSINT items.""" return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712 + + +def list_osint_items( + db: Session, + *, + technique_id: Optional[UUID] = None, + source_type: Optional[str] = None, + reviewed: Optional[bool] = None, + offset: int = 0, + limit: int = 50, +) -> dict: + """List OSINT items with optional filters and pagination.""" + query = db.query(OsintItem) + if technique_id: + query = query.filter(OsintItem.technique_id == technique_id) + if source_type: + query = query.filter(OsintItem.source_type == source_type) + if reviewed is not None: + query = query.filter(OsintItem.reviewed == reviewed) + + total = query.count() + items = ( + query.order_by(OsintItem.discovered_at.desc()) + .offset(offset) + .limit(limit) + .all() + ) + + return { + "total": total, + "items": [ + { + "id": str(item.id), + "technique_id": str(item.technique_id), + "source_type": item.source_type, + "source_url": item.source_url, + "title": item.title, + "description": item.description, + "severity": item.severity, + "discovered_at": item.discovered_at.isoformat() if item.discovered_at else None, + "reviewed": item.reviewed, + "metadata": item.metadata_, + } + for item in items + ], + } + + +def get_osint_summary(db: Session) -> dict: + """Summary statistics for OSINT items.""" + total = db.query(func.count(OsintItem.id)).scalar() or 0 + unreviewed = get_unreviewed_count(db) + + by_severity = dict( + db.query(OsintItem.severity, func.count(OsintItem.id)) + .group_by(OsintItem.severity) + .all() + ) + + by_type = dict( + db.query(OsintItem.source_type, func.count(OsintItem.id)) + .group_by(OsintItem.source_type) + .all() + ) + + techniques_with_items = ( + db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0 + ) + + return { + "total_items": total, + "unreviewed": unreviewed, + "techniques_with_items": techniques_with_items, + "by_severity": by_severity, + "by_type": by_type, + } + + +def get_technique_or_raise(db: Session, technique_id: UUID) -> Technique: + """Get a technique by ID or raise EntityNotFoundError.""" + technique = db.query(Technique).filter(Technique.id == technique_id).first() + if not technique: + raise EntityNotFoundError("Technique", str(technique_id)) + return technique diff --git a/backend/app/services/scoring_service.py b/backend/app/services/scoring_service.py index 25436e1..f53c9e7 100644 --- a/backend/app/services/scoring_service.py +++ b/backend/app/services/scoring_service.py @@ -15,6 +15,7 @@ from typing import Optional from sqlalchemy import case, func from sqlalchemy.orm import Session +from app.domain.errors import EntityNotFoundError from app.models.technique import Technique from app.models.test import Test from app.models.detection_rule import DetectionRule @@ -232,6 +233,29 @@ def bulk_technique_scores(db: Session) -> dict: # ── Technique-level scoring (single technique — preserved API) ──────── +def score_technique_by_mitre_id(db: Session, mitre_id: str) -> dict: + """Get detailed score with breakdown for a technique by MITRE ID.""" + technique = db.query(Technique).filter(Technique.mitre_id == mitre_id).first() + if not technique: + raise EntityNotFoundError("Technique", mitre_id) + result = calculate_technique_score(technique, db) + return { + "mitre_id": technique.mitre_id, + "name": technique.name, + "tactic": technique.tactic, + "status_global": technique.status_global.value if technique.status_global else None, + **result, + } + + +def score_actor_by_id(db: Session, actor_id: str) -> dict: + """Get coverage score for a threat actor by ID.""" + actor = db.query(ThreatActor).filter(ThreatActor.id == actor_id).first() + if not actor: + raise EntityNotFoundError("ThreatActor", actor_id) + return calculate_actor_coverage_score(actor_id, db) + + def calculate_technique_score(technique: Technique, db: Session) -> dict: """Calculate a 0-100 score for a technique with detailed breakdown. diff --git a/backend/app/services/snapshot_service.py b/backend/app/services/snapshot_service.py index e04f222..bdd9b49 100644 --- a/backend/app/services/snapshot_service.py +++ b/backend/app/services/snapshot_service.py @@ -14,6 +14,7 @@ from datetime import datetime from sqlalchemy import func from sqlalchemy.orm import Session +from app.domain.errors import EntityNotFoundError from app.models.technique import Technique from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState from app.models.enums import TechniqueStatus @@ -25,6 +26,101 @@ from app.services.scoring_service import ( logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Serialization and queries +# --------------------------------------------------------------------------- + + +def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict: + """Lightweight serialization for list views.""" + return { + "id": str(snap.id), + "name": snap.name, + "organization_score": snap.organization_score, + "total_techniques": snap.total_techniques, + "validated_count": snap.validated_count, + "partial_count": snap.partial_count, + "not_covered_count": snap.not_covered_count, + "in_progress_count": snap.in_progress_count, + "not_evaluated_count": snap.not_evaluated_count, + "created_by": str(snap.created_by) if snap.created_by else None, + "created_at": snap.created_at.isoformat() if snap.created_at else None, + } + + +def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict: + """Full serialization including technique states.""" + base = serialize_snapshot_summary(snap) + + technique_states = ( + db.query(SnapshotTechniqueState) + .filter(SnapshotTechniqueState.snapshot_id == snap.id) + .order_by(SnapshotTechniqueState.mitre_id) + .all() + ) + + base["technique_states"] = [ + { + "mitre_id": s.mitre_id, + "technique_id": str(s.technique_id), + "status": s.status, + "score": s.score, + } + for s in technique_states + ] + return base + + +def list_snapshots( + db: Session, + *, + offset: int = 0, + limit: int = 50, +) -> dict: + """List coverage snapshots ordered by creation date (newest first).""" + query = db.query(CoverageSnapshot) + total = query.count() + + snapshots = ( + query + .order_by(CoverageSnapshot.created_at.desc()) + .offset(offset) + .limit(limit) + .all() + ) + + return { + "total": total, + "offset": offset, + "limit": limit, + "items": [serialize_snapshot_summary(s) for s in snapshots], + } + + +def get_snapshot_or_raise(db: Session, snapshot_id: str) -> CoverageSnapshot: + """Fetch snapshot by ID or raise EntityNotFoundError.""" + try: + sid = uuid.UUID(snapshot_id) + except (ValueError, TypeError): + raise EntityNotFoundError("Snapshot", snapshot_id) + snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first() + if snapshot is None: + raise EntityNotFoundError("Snapshot", snapshot_id) + return snapshot + + +def get_snapshot_detail(db: Session, snapshot_id: str) -> dict: + """Get detailed snapshot including per-technique states.""" + snapshot = get_snapshot_or_raise(db, snapshot_id) + return serialize_snapshot_detail(db, snapshot) + + +def delete_snapshot(db: Session, snapshot_id: str) -> None: + """Delete a snapshot. Does not commit — caller must commit.""" + snapshot = get_snapshot_or_raise(db, snapshot_id) + db.delete(snapshot) + + # --------------------------------------------------------------------------- # Create snapshot # --------------------------------------------------------------------------- @@ -138,7 +234,7 @@ def compare_snapshots( snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first() if not snap_a or not snap_b: - return {"error": "One or both snapshots not found"} + raise EntityNotFoundError("Snapshot", f"{snapshot_a_id} or {snapshot_b_id}") # Build lookup dicts: mitre_id -> {status, score} states_a = { diff --git a/backend/app/services/technique_query_service.py b/backend/app/services/technique_query_service.py new file mode 100644 index 0000000..6e9e049 --- /dev/null +++ b/backend/app/services/technique_query_service.py @@ -0,0 +1,48 @@ +"""Technique query service — framework-agnostic queries for technique details.""" +from __future__ import annotations + +from sqlalchemy.orm import Session, joinedload + +from app.domain.errors import EntityNotFoundError +from app.models.technique import Technique +from app.services.d3fend_import_service import get_defenses_for_technique + + +def get_technique_detail(db: Session, mitre_id: str) -> dict: + """Fetch full technique details including tests and D3FEND defenses.""" + technique = ( + db.query(Technique) + .options(joinedload(Technique.tests)) + .filter(Technique.mitre_id == mitre_id) + .first() + ) + if technique is None: + raise EntityNotFoundError("Technique", mitre_id) + defenses = get_defenses_for_technique(db, technique.id) + return { + "id": str(technique.id), + "mitre_id": technique.mitre_id, + "name": technique.name, + "description": technique.description, + "tactic": technique.tactic, + "platforms": technique.platforms or [], + "mitre_version": technique.mitre_version, + "mitre_last_modified": technique.mitre_last_modified, + "is_subtechnique": technique.is_subtechnique, + "parent_mitre_id": technique.parent_mitre_id, + "status_global": technique.status_global.value if technique.status_global else "not_evaluated", + "review_required": technique.review_required, + "last_review_date": technique.last_review_date, + "tests": [ + { + "id": str(t.id), + "name": t.name, + "state": t.state.value if t.state else None, + "result": t.result.value if t.result else None, + "platform": t.platform, + "created_at": t.created_at.isoformat() if t.created_at else None, + } + for t in technique.tests + ], + "d3fend_defenses": defenses, + } diff --git a/backend/app/services/worklog_service.py b/backend/app/services/worklog_service.py index 7e1358f..7cee816 100644 --- a/backend/app/services/worklog_service.py +++ b/backend/app/services/worklog_service.py @@ -8,6 +8,7 @@ from uuid import UUID from sqlalchemy.orm import Session +from app.domain.errors import EntityNotFoundError from app.models.worklog import Worklog logger = logging.getLogger(__name__) @@ -43,6 +44,14 @@ def create_worklog( return wl +def get_worklog_or_raise(db: Session, worklog_id: UUID) -> Worklog: + """Get a worklog by ID or raise EntityNotFoundError.""" + wl = db.query(Worklog).filter(Worklog.id == worklog_id).first() + if not wl: + raise EntityNotFoundError("Worklog", str(worklog_id)) + return wl + + def list_worklogs( db: Session, *,