Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0eff48c768 | |||
| 764a2f7579 | |||
| f4c74230ec | |||
| 50b70704ae | |||
| 20738d11b3 | |||
| 4e3787d091 |
@@ -0,0 +1,37 @@
|
|||||||
|
"""add_scoring_config
|
||||||
|
|
||||||
|
Single-row table to persist scoring weights in the database,
|
||||||
|
replacing the mutable in-process Settings approach.
|
||||||
|
|
||||||
|
Revision ID: b027scorecfg
|
||||||
|
Revises: b026techidx
|
||||||
|
Create Date: 2026-02-19 10:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "b027scorecfg"
|
||||||
|
down_revision: Union[str, None] = "b026techidx"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"scoring_config",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||||
|
sa.Column("weight_tests", sa.Float(), nullable=False, server_default="40.0"),
|
||||||
|
sa.Column("weight_detection_rules", sa.Float(), nullable=False, server_default="20.0"),
|
||||||
|
sa.Column("weight_d3fend", sa.Float(), nullable=False, server_default="15.0"),
|
||||||
|
sa.Column("weight_freshness", sa.Float(), nullable=False, server_default="15.0"),
|
||||||
|
sa.Column("weight_platform_diversity", sa.Float(), nullable=False, server_default="10.0"),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("scoring_config")
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
"""Structured JSON logging configuration.
|
||||||
|
|
||||||
|
In **production** (``AEGIS_ENV=production``), emits one JSON object per
|
||||||
|
line so that log aggregators (ELK, CloudWatch, Datadog) can ingest them
|
||||||
|
without custom parsing.
|
||||||
|
|
||||||
|
In **development** (default), uses a human-readable text format for
|
||||||
|
comfortable local work.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
|
class _JSONFormatter(logging.Formatter):
|
||||||
|
"""Emit each log record as a single-line JSON object."""
|
||||||
|
|
||||||
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
|
payload: dict = {
|
||||||
|
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
|
||||||
|
"level": record.levelname,
|
||||||
|
"logger": record.name,
|
||||||
|
"message": record.getMessage(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if record.exc_info and record.exc_info[1] is not None:
|
||||||
|
payload["exception"] = self.formatException(record.exc_info)
|
||||||
|
|
||||||
|
extra = getattr(record, "_extra", None)
|
||||||
|
if extra:
|
||||||
|
payload.update(extra)
|
||||||
|
|
||||||
|
return json.dumps(payload, default=str)
|
||||||
|
|
||||||
|
|
||||||
|
_DEV_FORMAT = "%(asctime)s %(levelname)-8s %(name)s — %(message)s"
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging() -> None:
|
||||||
|
"""Configure the root logger based on the environment."""
|
||||||
|
is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||||
|
level_name = os.environ.get("LOG_LEVEL", "INFO").upper()
|
||||||
|
level = getattr(logging, level_name, logging.INFO)
|
||||||
|
|
||||||
|
root = logging.getLogger()
|
||||||
|
root.setLevel(level)
|
||||||
|
|
||||||
|
if root.handlers:
|
||||||
|
root.handlers.clear()
|
||||||
|
|
||||||
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
|
handler.setLevel(level)
|
||||||
|
|
||||||
|
if is_production:
|
||||||
|
handler.setFormatter(_JSONFormatter())
|
||||||
|
else:
|
||||||
|
handler.setFormatter(logging.Formatter(_DEV_FORMAT))
|
||||||
|
|
||||||
|
root.addHandler(handler)
|
||||||
|
|
||||||
|
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||||
+3
-4
@@ -47,10 +47,9 @@ from app.jobs.mitre_sync_job import start_scheduler, scheduler
|
|||||||
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||||
|
|
||||||
# ── Logging ───────────────────────────────────────────────────────────────
|
# ── Logging ───────────────────────────────────────────────────────────────
|
||||||
logging.basicConfig(
|
from app.logging_config import setup_logging
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s %(levelname)-8s %(name)s — %(message)s",
|
setup_logging()
|
||||||
)
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueStat
|
|||||||
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
|
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
|
||||||
from app.models.worklog import Worklog
|
from app.models.worklog import Worklog
|
||||||
from app.models.osint_item import OsintItem
|
from app.models.osint_item import OsintItem
|
||||||
|
from app.models.scoring_config import ScoringConfig
|
||||||
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
|
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -31,6 +32,6 @@ __all__ = [
|
|||||||
"ComplianceFramework", "ComplianceControl", "ComplianceControlMapping",
|
"ComplianceFramework", "ComplianceControl", "ComplianceControlMapping",
|
||||||
"CoverageSnapshot", "SnapshotTechniqueState",
|
"CoverageSnapshot", "SnapshotTechniqueState",
|
||||||
"JiraLink", "JiraLinkEntityType", "JiraSyncDirection",
|
"JiraLink", "JiraLinkEntityType", "JiraSyncDirection",
|
||||||
"Worklog", "OsintItem",
|
"Worklog", "OsintItem", "ScoringConfig",
|
||||||
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
|
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
"""ScoringConfig — single-row table for persisted scoring weights.
|
||||||
|
|
||||||
|
Replaces the mutable-settings approach where PATCH /scores/config
|
||||||
|
mutated the in-process ``Settings`` object (lost on restart).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from sqlalchemy import Column, Float, DateTime, func
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
class ScoringConfig(Base):
|
||||||
|
__tablename__ = "scoring_config"
|
||||||
|
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||||
|
weight_tests = Column(Float, nullable=False, default=40.0)
|
||||||
|
weight_detection_rules = Column(Float, nullable=False, default=20.0)
|
||||||
|
weight_d3fend = Column(Float, nullable=False, default=15.0)
|
||||||
|
weight_freshness = Column(Float, nullable=False, default=15.0)
|
||||||
|
weight_platform_diversity = Column(Float, nullable=False, default=10.0)
|
||||||
|
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||||
@@ -7,26 +7,29 @@ test ordering, progress tracking, and threat actor integration.
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES
|
from app.services.campaign_service import generate_campaign_from_threat_actor
|
||||||
from app.models.test import Test
|
from app.services.campaign_crud_service import (
|
||||||
from app.models.technique import Technique
|
add_test_to_campaign as crud_add_test,
|
||||||
from app.models.threat_actor import ThreatActor
|
activate_campaign as crud_activate,
|
||||||
from app.services.campaign_service import (
|
complete_campaign as crud_complete,
|
||||||
validate_no_circular_dependency,
|
create_campaign as crud_create,
|
||||||
get_campaign_progress,
|
get_campaign_detail as crud_get_detail,
|
||||||
generate_campaign_from_threat_actor,
|
get_campaign_history as crud_get_history,
|
||||||
TACTIC_TO_PHASE,
|
get_campaign_progress_data as crud_get_progress,
|
||||||
|
list_campaigns as crud_list,
|
||||||
|
remove_test_from_campaign as crud_remove_test,
|
||||||
|
schedule_campaign as crud_schedule,
|
||||||
|
serialize_campaign,
|
||||||
|
update_campaign as crud_update,
|
||||||
)
|
)
|
||||||
from app.services.campaign_scheduler_service import calculate_next_run
|
|
||||||
from app.services.notification_service import create_notification
|
from app.services.notification_service import create_notification
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
|
||||||
@@ -67,89 +70,6 @@ class SchedulePayload(BaseModel):
|
|||||||
next_run_at: Optional[str] = None
|
next_run_at: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _serialize_campaign(db: Session, campaign: Campaign) -> dict:
|
|
||||||
"""Serialize a campaign with its tests and progress."""
|
|
||||||
progress = get_campaign_progress(db, campaign.id)
|
|
||||||
|
|
||||||
campaign_tests = (
|
|
||||||
db.query(CampaignTest)
|
|
||||||
.filter(CampaignTest.campaign_id == campaign.id)
|
|
||||||
.order_by(CampaignTest.order_index)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
tests = []
|
|
||||||
for ct in campaign_tests:
|
|
||||||
test = ct.test
|
|
||||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None
|
|
||||||
|
|
||||||
tests.append({
|
|
||||||
"id": str(ct.id),
|
|
||||||
"test_id": str(ct.test_id),
|
|
||||||
"order_index": ct.order_index,
|
|
||||||
"depends_on": str(ct.depends_on) if ct.depends_on else None,
|
|
||||||
"phase": ct.phase,
|
|
||||||
"test_name": test.name if test else None,
|
|
||||||
"test_state": test.state.value if test and test.state else None,
|
|
||||||
"test_result": test.result.value if test and test.result else None,
|
|
||||||
"technique_mitre_id": technique.mitre_id if technique else None,
|
|
||||||
"technique_name": technique.name if technique else None,
|
|
||||||
"platform": test.platform if test else None,
|
|
||||||
})
|
|
||||||
|
|
||||||
actor = campaign.threat_actor
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": str(campaign.id),
|
|
||||||
"name": campaign.name,
|
|
||||||
"description": campaign.description,
|
|
||||||
"type": campaign.type,
|
|
||||||
"status": campaign.status,
|
|
||||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
|
||||||
"threat_actor_name": actor.name if actor else None,
|
|
||||||
"created_by": str(campaign.created_by) if campaign.created_by else None,
|
|
||||||
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
|
|
||||||
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
|
|
||||||
"target_platform": campaign.target_platform,
|
|
||||||
"tags": campaign.tags or [],
|
|
||||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
|
||||||
"is_recurring": campaign.is_recurring or False,
|
|
||||||
"recurrence_pattern": campaign.recurrence_pattern,
|
|
||||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
|
||||||
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
|
||||||
"parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
|
|
||||||
"tests": tests,
|
|
||||||
"progress": progress,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
|
|
||||||
"""Lightweight campaign serialization for list views."""
|
|
||||||
progress = get_campaign_progress(db, campaign.id)
|
|
||||||
actor = campaign.threat_actor
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": str(campaign.id),
|
|
||||||
"name": campaign.name,
|
|
||||||
"description": campaign.description,
|
|
||||||
"type": campaign.type,
|
|
||||||
"status": campaign.status,
|
|
||||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
|
||||||
"threat_actor_name": actor.name if actor else None,
|
|
||||||
"target_platform": campaign.target_platform,
|
|
||||||
"tags": campaign.tags or [],
|
|
||||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
|
||||||
"is_recurring": campaign.is_recurring or False,
|
|
||||||
"recurrence_pattern": campaign.recurrence_pattern,
|
|
||||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
|
||||||
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
|
||||||
"test_count": progress["total"],
|
|
||||||
"completion_pct": progress["completion_pct"],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GET /campaigns — List campaigns with filters
|
# GET /campaigns — List campaigns with filters
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -166,28 +86,15 @@ def list_campaigns(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List campaigns with optional filters and pagination."""
|
"""List campaigns with optional filters and pagination."""
|
||||||
query = db.query(Campaign)
|
return crud_list(
|
||||||
|
db,
|
||||||
if type:
|
type=type,
|
||||||
query = query.filter(Campaign.type == type)
|
status=status,
|
||||||
if status:
|
threat_actor_id=threat_actor_id,
|
||||||
query = query.filter(Campaign.status == status)
|
search=search,
|
||||||
if threat_actor_id:
|
offset=offset,
|
||||||
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
|
limit=limit,
|
||||||
if search:
|
)
|
||||||
from app.utils import escape_like
|
|
||||||
pattern = f"%{escape_like(search)}%"
|
|
||||||
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
|
|
||||||
|
|
||||||
total = query.count()
|
|
||||||
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total": total,
|
|
||||||
"offset": offset,
|
|
||||||
"limit": limit,
|
|
||||||
"items": [_serialize_campaign_summary(db, c) for c in campaigns],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -201,30 +108,29 @@ def create_campaign(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Create a new campaign."""
|
"""Create a new campaign."""
|
||||||
campaign = Campaign(
|
result = crud_create(
|
||||||
|
db,
|
||||||
|
creator_id=current_user.id,
|
||||||
name=payload.name,
|
name=payload.name,
|
||||||
description=payload.description,
|
description=payload.description,
|
||||||
type=payload.type,
|
type=payload.type,
|
||||||
threat_actor_id=payload.threat_actor_id,
|
threat_actor_id=payload.threat_actor_id,
|
||||||
target_platform=payload.target_platform,
|
target_platform=payload.target_platform,
|
||||||
tags=payload.tags or [],
|
tags=payload.tags,
|
||||||
created_by=current_user.id,
|
scheduled_at=payload.scheduled_at,
|
||||||
scheduled_at=datetime.fromisoformat(payload.scheduled_at) if payload.scheduled_at else None,
|
|
||||||
)
|
)
|
||||||
db.add(campaign)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(campaign)
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
action="create_campaign",
|
action="create_campaign",
|
||||||
entity_type="campaign",
|
entity_type="campaign",
|
||||||
entity_id=campaign.id,
|
entity_id=result["id"],
|
||||||
details={"name": campaign.name, "type": campaign.type},
|
details={"name": payload.name, "type": payload.type},
|
||||||
)
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
return result
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -238,11 +144,7 @@ def get_campaign(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get detailed campaign info including tests and progress."""
|
"""Get detailed campaign info including tests and progress."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
return crud_get_detail(db, campaign_id)
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -257,37 +159,26 @@ def update_campaign(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Update a campaign. Only allowed in draft or active state."""
|
"""Update a campaign. Only allowed in draft or active state."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
if campaign.status not in ("draft", "active"):
|
|
||||||
raise HTTPException(status_code=400, detail="Can only update draft or active campaigns")
|
|
||||||
|
|
||||||
# Check ownership or admin
|
|
||||||
if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin":
|
|
||||||
raise HTTPException(status_code=403, detail="Only the creator or admin can update this campaign")
|
|
||||||
|
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
if "scheduled_at" in update_data and update_data["scheduled_at"]:
|
result = crud_update(
|
||||||
update_data["scheduled_at"] = datetime.fromisoformat(update_data["scheduled_at"])
|
db,
|
||||||
|
campaign_id,
|
||||||
for field, value in update_data.items():
|
updater_id=current_user.id,
|
||||||
setattr(campaign, field, value)
|
updater_role=current_user.role,
|
||||||
|
**update_data,
|
||||||
db.commit()
|
)
|
||||||
db.refresh(campaign)
|
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
db,
|
db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
action="update_campaign",
|
action="update_campaign",
|
||||||
entity_type="campaign",
|
entity_type="campaign",
|
||||||
entity_id=campaign.id,
|
entity_id=campaign_id,
|
||||||
details={"updated_fields": list(update_data.keys())},
|
details={"updated_fields": list(update_data.keys())},
|
||||||
)
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
return result
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -302,63 +193,16 @@ def add_test_to_campaign(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Add a test to a campaign with optional ordering and dependency."""
|
"""Add a test to a campaign with optional ordering and dependency."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
result = crud_add_test(
|
||||||
if not campaign:
|
db,
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
campaign_id,
|
||||||
|
|
||||||
if campaign.status not in ("draft", "active"):
|
|
||||||
raise HTTPException(status_code=400, detail="Can only add tests to draft or active campaigns")
|
|
||||||
|
|
||||||
test = db.query(Test).filter(Test.id == payload.test_id).first()
|
|
||||||
if not test:
|
|
||||||
raise HTTPException(status_code=404, detail="Test not found")
|
|
||||||
|
|
||||||
# Calculate order_index if not provided
|
|
||||||
if payload.order_index is not None:
|
|
||||||
order_index = payload.order_index
|
|
||||||
else:
|
|
||||||
max_order = (
|
|
||||||
db.query(CampaignTest.order_index)
|
|
||||||
.filter(CampaignTest.campaign_id == campaign_id)
|
|
||||||
.order_by(CampaignTest.order_index.desc())
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
order_index = (max_order[0] + 1) if max_order else 0
|
|
||||||
|
|
||||||
depends_on = uuid.UUID(payload.depends_on) if payload.depends_on else None
|
|
||||||
|
|
||||||
# Validate circular dependency
|
|
||||||
ct_id = uuid.uuid4()
|
|
||||||
if depends_on:
|
|
||||||
validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on)
|
|
||||||
|
|
||||||
# Auto-detect kill chain phase from the test's technique tactic if not provided
|
|
||||||
phase = payload.phase
|
|
||||||
if not phase and test.technique_id:
|
|
||||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
|
||||||
if technique and technique.tactic:
|
|
||||||
phase = TACTIC_TO_PHASE.get(technique.tactic, None)
|
|
||||||
|
|
||||||
campaign_test = CampaignTest(
|
|
||||||
id=ct_id,
|
|
||||||
campaign_id=campaign_id,
|
|
||||||
test_id=payload.test_id,
|
test_id=payload.test_id,
|
||||||
order_index=order_index,
|
order_index=payload.order_index,
|
||||||
depends_on=depends_on,
|
depends_on=payload.depends_on,
|
||||||
phase=phase,
|
phase=payload.phase,
|
||||||
)
|
)
|
||||||
db.add(campaign_test)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(campaign_test)
|
return result
|
||||||
|
|
||||||
return {
|
|
||||||
"id": str(campaign_test.id),
|
|
||||||
"campaign_id": str(campaign_test.campaign_id),
|
|
||||||
"test_id": str(campaign_test.test_id),
|
|
||||||
"order_index": campaign_test.order_index,
|
|
||||||
"depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
|
|
||||||
"phase": campaign_test.phase,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -373,36 +217,8 @@ def remove_test_from_campaign(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Remove a test from a campaign."""
|
"""Remove a test from a campaign."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
crud_remove_test(db, campaign_id, campaign_test_id)
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
if campaign.status not in ("draft", "active"):
|
|
||||||
raise HTTPException(status_code=400, detail="Can only modify draft or active campaigns")
|
|
||||||
|
|
||||||
ct = (
|
|
||||||
db.query(CampaignTest)
|
|
||||||
.filter(
|
|
||||||
CampaignTest.id == campaign_test_id,
|
|
||||||
CampaignTest.campaign_id == campaign_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not ct:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign test not found")
|
|
||||||
|
|
||||||
# Clear any references to this campaign_test
|
|
||||||
dependents = (
|
|
||||||
db.query(CampaignTest)
|
|
||||||
.filter(CampaignTest.depends_on == campaign_test_id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
for dep in dependents:
|
|
||||||
dep.depends_on = None
|
|
||||||
|
|
||||||
db.delete(ct)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return {"detail": "Test removed from campaign"}
|
return {"detail": "Test removed from campaign"}
|
||||||
|
|
||||||
|
|
||||||
@@ -417,23 +233,10 @@ def activate_campaign(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Activate a campaign, moving it from draft to active."""
|
"""Activate a campaign, moving it from draft to active."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
campaign = crud_activate(db, campaign_id)
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
if campaign.status != "draft":
|
|
||||||
raise HTTPException(status_code=400, detail="Only draft campaigns can be activated")
|
|
||||||
|
|
||||||
# Verify campaign has at least one test
|
|
||||||
test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
|
|
||||||
if test_count == 0:
|
|
||||||
raise HTTPException(status_code=400, detail="Campaign must have at least one test to activate")
|
|
||||||
|
|
||||||
campaign.status = "active"
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(campaign)
|
db.refresh(campaign)
|
||||||
|
|
||||||
# Notify relevant users
|
|
||||||
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
|
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
|
||||||
for user in red_techs:
|
for user in red_techs:
|
||||||
create_notification(
|
create_notification(
|
||||||
@@ -455,7 +258,7 @@ def activate_campaign(
|
|||||||
details={"name": campaign.name},
|
details={"name": campaign.name},
|
||||||
)
|
)
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -469,15 +272,7 @@ def complete_campaign(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "admin")),
|
current_user: User = Depends(require_any_role("red_lead", "admin")),
|
||||||
):
|
):
|
||||||
"""Mark a campaign as completed."""
|
"""Mark a campaign as completed."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
campaign = crud_complete(db, campaign_id)
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
if campaign.status != "active":
|
|
||||||
raise HTTPException(status_code=400, detail="Only active campaigns can be completed")
|
|
||||||
|
|
||||||
campaign.status = "completed"
|
|
||||||
campaign.completed_at = datetime.utcnow()
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(campaign)
|
db.refresh(campaign)
|
||||||
|
|
||||||
@@ -490,7 +285,7 @@ def complete_campaign(
|
|||||||
details={"name": campaign.name},
|
details={"name": campaign.name},
|
||||||
)
|
)
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -504,16 +299,7 @@ def get_campaign_progress_endpoint(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Get progress statistics for a campaign."""
|
"""Get progress statistics for a campaign."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
return crud_get_progress(db, campaign_id)
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
progress = get_campaign_progress(db, uuid.UUID(campaign_id))
|
|
||||||
return {
|
|
||||||
"campaign_id": str(campaign.id),
|
|
||||||
"campaign_name": campaign.name,
|
|
||||||
**progress,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -546,7 +332,7 @@ def generate_campaign_from_actor(
|
|||||||
details={"actor_id": actor_id, "campaign_name": campaign.name},
|
details={"actor_id": actor_id, "campaign_name": campaign.name},
|
||||||
)
|
)
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -564,31 +350,15 @@ def schedule_campaign(
|
|||||||
|
|
||||||
Only the campaign creator or admin can change scheduling.
|
Only the campaign creator or admin can change scheduling.
|
||||||
"""
|
"""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
campaign = crud_schedule(
|
||||||
if not campaign:
|
db,
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
campaign_id,
|
||||||
|
owner_id=current_user.id,
|
||||||
# Check ownership or admin
|
owner_role=current_user.role,
|
||||||
if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin":
|
is_recurring=payload.is_recurring,
|
||||||
raise HTTPException(status_code=403, detail="Only the creator or admin can configure scheduling")
|
recurrence_pattern=payload.recurrence_pattern,
|
||||||
|
next_run_at=payload.next_run_at,
|
||||||
campaign.is_recurring = payload.is_recurring
|
)
|
||||||
|
|
||||||
if payload.is_recurring:
|
|
||||||
if payload.recurrence_pattern not in ("weekly", "monthly", "quarterly"):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'",
|
|
||||||
)
|
|
||||||
campaign.recurrence_pattern = payload.recurrence_pattern
|
|
||||||
if payload.next_run_at:
|
|
||||||
campaign.next_run_at = datetime.fromisoformat(payload.next_run_at.replace("Z", "+00:00").replace("+00:00", ""))
|
|
||||||
elif not campaign.next_run_at:
|
|
||||||
campaign.next_run_at = calculate_next_run(datetime.utcnow(), payload.recurrence_pattern)
|
|
||||||
else:
|
|
||||||
campaign.recurrence_pattern = None
|
|
||||||
campaign.next_run_at = None
|
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(campaign)
|
db.refresh(campaign)
|
||||||
|
|
||||||
@@ -605,7 +375,7 @@ def schedule_campaign(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return _serialize_campaign(db, campaign)
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -619,30 +389,4 @@ def get_campaign_history(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List all child campaigns (execution history) of a recurring campaign."""
|
"""List all child campaigns (execution history) of a recurring campaign."""
|
||||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
return crud_get_history(db, campaign_id)
|
||||||
if not campaign:
|
|
||||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
|
||||||
|
|
||||||
children = (
|
|
||||||
db.query(Campaign)
|
|
||||||
.filter(Campaign.parent_campaign_id == campaign_id)
|
|
||||||
.order_by(Campaign.created_at.desc())
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"campaign_id": str(campaign.id),
|
|
||||||
"campaign_name": campaign.name,
|
|
||||||
"items": [
|
|
||||||
{
|
|
||||||
"id": str(child.id),
|
|
||||||
"name": child.name,
|
|
||||||
"status": child.status,
|
|
||||||
"test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
|
|
||||||
"completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
|
|
||||||
"created_at": child.created_at.isoformat() if child.created_at else None,
|
|
||||||
"completed_at": child.completed_at.isoformat() if child.completed_at else None,
|
|
||||||
}
|
|
||||||
for child in children
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|||||||
+23
-174
@@ -24,52 +24,32 @@ import os
|
|||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
|
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user
|
from app.dependencies.auth import get_current_user
|
||||||
from app.models.enums import TeamSide, TestState
|
from app.models.enums import TeamSide
|
||||||
from app.models.evidence import Evidence
|
from app.models.evidence import Evidence
|
||||||
from app.models.test import Test
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.evidence import EvidenceOut
|
from app.schemas.evidence import EvidenceOut
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
|
from app.services.evidence_service import (
|
||||||
|
get_evidence_or_raise,
|
||||||
|
get_test_or_raise,
|
||||||
|
list_evidence_for_test,
|
||||||
|
MAX_UPLOAD_SIZE,
|
||||||
|
validate_delete_permission,
|
||||||
|
validate_file,
|
||||||
|
validate_upload_permission,
|
||||||
|
)
|
||||||
from app.storage import get_presigned_url, upload_file
|
from app.storage import get_presigned_url, upload_file
|
||||||
|
|
||||||
router = APIRouter(tags=["evidence"])
|
router = APIRouter(tags=["evidence"])
|
||||||
|
|
||||||
# States where red evidence can be uploaded / deleted
|
|
||||||
_RED_EDITABLE_STATES = (TestState.draft, TestState.red_executing)
|
|
||||||
# States where blue evidence can be uploaded / deleted
|
|
||||||
_BLUE_EDITABLE_STATES = (TestState.blue_evaluating,)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Upload safety limits
|
# Helpers (router-specific: infrastructure / HTTP concerns)
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Maximum upload size in bytes (default 50 MB)
|
|
||||||
_MAX_UPLOAD_SIZE = 50 * 1024 * 1024
|
|
||||||
|
|
||||||
# Allowed file extensions (lowercase, with leading dot)
|
|
||||||
_ALLOWED_EXTENSIONS: set[str] = {
|
|
||||||
# Images / screenshots
|
|
||||||
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg",
|
|
||||||
# Documents
|
|
||||||
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".csv", ".txt",
|
|
||||||
".md", ".rtf", ".odt", ".ods",
|
|
||||||
# Logs & captures
|
|
||||||
".log", ".pcap", ".pcapng", ".evtx", ".json", ".xml",
|
|
||||||
".yaml", ".yml", ".toml",
|
|
||||||
# Archives (for bundled evidence)
|
|
||||||
".zip", ".tar", ".gz", ".7z",
|
|
||||||
# Other common evidence types
|
|
||||||
".har", ".eml", ".msg",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
||||||
@@ -87,85 +67,6 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_upload_permission(
|
|
||||||
test: Test,
|
|
||||||
team: TeamSide,
|
|
||||||
user: User,
|
|
||||||
) -> None:
|
|
||||||
"""Raise 403 if the user/team combination is not allowed in the current state."""
|
|
||||||
# Admins bypass all checks
|
|
||||||
if user.role == "admin":
|
|
||||||
return
|
|
||||||
|
|
||||||
if team == TeamSide.red:
|
|
||||||
if user.role not in ("red_tech", "red_lead"):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Only red_tech, red_lead or admin can upload red evidence",
|
|
||||||
)
|
|
||||||
if test.state not in _RED_EDITABLE_STATES:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Cannot upload red evidence in '{test.state.value}' state "
|
|
||||||
f"(allowed in: draft, red_executing)",
|
|
||||||
)
|
|
||||||
elif team == TeamSide.blue:
|
|
||||||
if user.role not in ("blue_tech", "blue_lead"):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Only blue_tech, blue_lead or admin can upload blue evidence",
|
|
||||||
)
|
|
||||||
if test.state not in _BLUE_EDITABLE_STATES:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Cannot upload blue evidence in '{test.state.value}' state "
|
|
||||||
f"(allowed in: blue_evaluating)",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_delete_permission(
|
|
||||||
test: Test,
|
|
||||||
evidence: Evidence,
|
|
||||||
user: User,
|
|
||||||
) -> None:
|
|
||||||
"""Raise 403 if the user cannot delete this evidence in the current state."""
|
|
||||||
# No deletions in review / validated / rejected
|
|
||||||
if test.state in (TestState.in_review, TestState.validated, TestState.rejected):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=f"Cannot delete evidence when test is in '{test.state.value}' state",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Admin can delete in editable states
|
|
||||||
if user.role == "admin":
|
|
||||||
return
|
|
||||||
|
|
||||||
ev_team = evidence.team
|
|
||||||
|
|
||||||
if ev_team == TeamSide.red:
|
|
||||||
if test.state not in _RED_EDITABLE_STATES:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Cannot delete red evidence outside draft/red_executing",
|
|
||||||
)
|
|
||||||
if user.role not in ("red_tech", "red_lead") and evidence.uploaded_by != user.id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Not enough permissions to delete this evidence",
|
|
||||||
)
|
|
||||||
elif ev_team == TeamSide.blue:
|
|
||||||
if test.state not in _BLUE_EDITABLE_STATES:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Cannot delete blue evidence outside blue_evaluating",
|
|
||||||
)
|
|
||||||
if user.role not in ("blue_tech", "blue_lead") and evidence.uploaded_by != user.id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Not enough permissions to delete this evidence",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# POST /tests/{test_id}/evidence — upload with team
|
# POST /tests/{test_id}/evidence — upload with team
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -189,36 +90,14 @@ async def upload_evidence(
|
|||||||
The ``team`` field (sent as form data) determines whether this is
|
The ``team`` field (sent as form data) determines whether this is
|
||||||
Red Team (attack) or Blue Team (detection) evidence.
|
Red Team (attack) or Blue Team (detection) evidence.
|
||||||
"""
|
"""
|
||||||
test = db.query(Test).filter(Test.id == test_id).first()
|
test = get_test_or_raise(db, test_id)
|
||||||
if test is None:
|
validate_upload_permission(test, team, current_user.role)
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Test not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate permissions
|
|
||||||
_validate_upload_permission(test, team, current_user)
|
|
||||||
|
|
||||||
# 1. Validate file extension
|
|
||||||
file_name = file.filename or "unnamed"
|
file_name = file.filename or "unnamed"
|
||||||
_, ext = os.path.splitext(file_name)
|
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||||
if ext.lower() not in _ALLOWED_EXTENSIONS:
|
validate_file(file_name, len(content))
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"File type '{ext}' is not allowed. "
|
|
||||||
f"Permitted types: {', '.join(sorted(_ALLOWED_EXTENSIONS))}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Read content with size limit
|
# Hash
|
||||||
content = await file.read(_MAX_UPLOAD_SIZE + 1)
|
|
||||||
if len(content) > _MAX_UPLOAD_SIZE:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
|
||||||
detail=f"File exceeds maximum upload size of "
|
|
||||||
f"{_MAX_UPLOAD_SIZE // (1024 * 1024)} MB",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Hash
|
|
||||||
sha256 = hashlib.sha256(content).hexdigest()
|
sha256 = hashlib.sha256(content).hexdigest()
|
||||||
|
|
||||||
# 4. Object key (sanitise filename to prevent path traversal in storage)
|
# 4. Object key (sanitise filename to prevent path traversal in storage)
|
||||||
@@ -273,19 +152,8 @@ def list_evidence(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""List all evidences for a test, optionally filtered by team."""
|
"""List all evidences for a test, optionally filtered by team."""
|
||||||
test = db.query(Test).filter(Test.id == test_id).first()
|
get_test_or_raise(db, test_id)
|
||||||
if test is None:
|
evidences = list_evidence_for_test(db, test_id, team=team)
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Test not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
query = db.query(Evidence).filter(Evidence.test_id == test_id)
|
|
||||||
|
|
||||||
if team:
|
|
||||||
query = query.filter(Evidence.team == team)
|
|
||||||
|
|
||||||
evidences = query.order_by(Evidence.uploaded_at.desc()).all()
|
|
||||||
return [_evidence_to_out(e) for e in evidences]
|
return [_evidence_to_out(e) for e in evidences]
|
||||||
|
|
||||||
|
|
||||||
@@ -301,13 +169,7 @@ def get_evidence(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return evidence metadata together with a presigned download URL."""
|
"""Return evidence metadata together with a presigned download URL."""
|
||||||
evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
|
evidence = get_evidence_or_raise(db, evidence_id)
|
||||||
if evidence is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Evidence not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
return _evidence_to_out(evidence)
|
return _evidence_to_out(evidence)
|
||||||
|
|
||||||
|
|
||||||
@@ -329,22 +191,9 @@ def delete_evidence(
|
|||||||
- Blue evidence: ``blue_evaluating``
|
- Blue evidence: ``blue_evaluating``
|
||||||
- No deletions in ``in_review``, ``validated``, ``rejected``
|
- No deletions in ``in_review``, ``validated``, ``rejected``
|
||||||
"""
|
"""
|
||||||
evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
|
evidence = get_evidence_or_raise(db, evidence_id)
|
||||||
if evidence is None:
|
test = get_test_or_raise(db, evidence.test_id)
|
||||||
raise HTTPException(
|
validate_delete_permission(test, evidence, current_user.role, current_user.id)
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Evidence not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
test = db.query(Test).filter(Test.id == evidence.test_id).first()
|
|
||||||
if test is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Parent test not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Permission checks
|
|
||||||
_validate_delete_permission(test, evidence, current_user)
|
|
||||||
|
|
||||||
# Audit before deletion
|
# Audit before deletion
|
||||||
log_action(
|
log_action(
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from app.dependencies.auth import get_current_user, require_role
|
|||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.technique import Technique
|
from app.models.technique import Technique
|
||||||
from app.models.threat_actor import ThreatActor
|
from app.models.threat_actor import ThreatActor
|
||||||
from app.config import settings
|
|
||||||
from app.services.scoring_service import (
|
from app.services.scoring_service import (
|
||||||
calculate_technique_score,
|
calculate_technique_score,
|
||||||
calculate_tactic_score,
|
calculate_tactic_score,
|
||||||
@@ -22,6 +21,10 @@ from app.services.scoring_service import (
|
|||||||
calculate_organization_score,
|
calculate_organization_score,
|
||||||
get_score_history,
|
get_score_history,
|
||||||
)
|
)
|
||||||
|
from app.services.scoring_config_service import (
|
||||||
|
get_weights_dict,
|
||||||
|
update_scoring_weights,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix="/scores", tags=["scores"])
|
router = APIRouter(prefix="/scores", tags=["scores"])
|
||||||
|
|
||||||
@@ -117,79 +120,45 @@ def score_history(
|
|||||||
|
|
||||||
@router.get("/config")
|
@router.get("/config")
|
||||||
def get_scoring_config(
|
def get_scoring_config(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Get current scoring weights (admin only)."""
|
"""Get current scoring weights (admin only)."""
|
||||||
return {
|
return get_weights_dict(db)
|
||||||
"weights": {
|
|
||||||
"tests": settings.SCORING_WEIGHT_TESTS,
|
|
||||||
"detection_rules": settings.SCORING_WEIGHT_DETECTION_RULES,
|
|
||||||
"d3fend": settings.SCORING_WEIGHT_D3FEND,
|
|
||||||
"freshness": settings.SCORING_WEIGHT_FRESHNESS,
|
|
||||||
"platform_diversity": settings.SCORING_WEIGHT_PLATFORM_DIVERSITY,
|
|
||||||
},
|
|
||||||
"total": (
|
|
||||||
settings.SCORING_WEIGHT_TESTS
|
|
||||||
+ settings.SCORING_WEIGHT_DETECTION_RULES
|
|
||||||
+ settings.SCORING_WEIGHT_D3FEND
|
|
||||||
+ settings.SCORING_WEIGHT_FRESHNESS
|
|
||||||
+ settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── PATCH /scores/config ─────────────────────────────────────────────
|
# ── PATCH /scores/config ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class ScoringConfigUpdate(BaseModel):
|
class ScoringConfigUpdate(BaseModel):
|
||||||
tests: Optional[int] = None
|
tests: Optional[float] = None
|
||||||
detection_rules: Optional[int] = None
|
detection_rules: Optional[float] = None
|
||||||
d3fend: Optional[int] = None
|
d3fend: Optional[float] = None
|
||||||
freshness: Optional[int] = None
|
freshness: Optional[float] = None
|
||||||
platform_diversity: Optional[int] = None
|
platform_diversity: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/config")
|
@router.patch("/config")
|
||||||
def update_scoring_config(
|
def update_scoring_config(
|
||||||
payload: ScoringConfigUpdate,
|
payload: ScoringConfigUpdate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(require_role("admin")),
|
current_user: User = Depends(require_role("admin")),
|
||||||
):
|
):
|
||||||
"""Update scoring weights (admin only).
|
"""Update scoring weights (admin only).
|
||||||
|
|
||||||
Note: Since we're using Opcion A (env vars / Settings), changes
|
Weights are persisted in the database and survive restarts.
|
||||||
are applied at runtime but won't persist across restarts unless
|
Validation enforces that all weights are non-negative and sum to 100.
|
||||||
the .env file is also updated. For production, consider migrating
|
|
||||||
to Option B (database table).
|
|
||||||
"""
|
"""
|
||||||
if payload.tests is not None:
|
result = update_scoring_weights(
|
||||||
settings.SCORING_WEIGHT_TESTS = payload.tests
|
db,
|
||||||
if payload.detection_rules is not None:
|
tests=payload.tests,
|
||||||
settings.SCORING_WEIGHT_DETECTION_RULES = payload.detection_rules
|
detection_rules=payload.detection_rules,
|
||||||
if payload.d3fend is not None:
|
d3fend=payload.d3fend,
|
||||||
settings.SCORING_WEIGHT_D3FEND = payload.d3fend
|
freshness=payload.freshness,
|
||||||
if payload.freshness is not None:
|
platform_diversity=payload.platform_diversity,
|
||||||
settings.SCORING_WEIGHT_FRESHNESS = payload.freshness
|
)
|
||||||
if payload.platform_diversity is not None:
|
|
||||||
settings.SCORING_WEIGHT_PLATFORM_DIVERSITY = payload.platform_diversity
|
|
||||||
|
|
||||||
# Weights changed — bust the score cache
|
|
||||||
from app.services.score_cache import invalidate
|
from app.services.score_cache import invalidate
|
||||||
invalidate()
|
invalidate()
|
||||||
|
|
||||||
return {
|
return {"message": "Scoring config updated", **result}
|
||||||
"message": "Scoring config updated",
|
|
||||||
"weights": {
|
|
||||||
"tests": settings.SCORING_WEIGHT_TESTS,
|
|
||||||
"detection_rules": settings.SCORING_WEIGHT_DETECTION_RULES,
|
|
||||||
"d3fend": settings.SCORING_WEIGHT_D3FEND,
|
|
||||||
"freshness": settings.SCORING_WEIGHT_FRESHNESS,
|
|
||||||
"platform_diversity": settings.SCORING_WEIGHT_PLATFORM_DIVERSITY,
|
|
||||||
},
|
|
||||||
"total": (
|
|
||||||
settings.SCORING_WEIGHT_TESTS
|
|
||||||
+ settings.SCORING_WEIGHT_DETECTION_RULES
|
|
||||||
+ settings.SCORING_WEIGHT_D3FEND
|
|
||||||
+ settings.SCORING_WEIGHT_FRESHNESS
|
|
||||||
+ settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|||||||
+65
-203
@@ -22,15 +22,11 @@ import uuid
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.dependencies.auth import get_current_user, require_any_role
|
from app.dependencies.auth import get_current_user, require_any_role
|
||||||
from app.models.audit import AuditLog
|
from app.models.enums import TestState
|
||||||
from app.models.enums import TestState, TeamSide
|
|
||||||
from app.models.technique import Technique
|
|
||||||
from app.models.test import Test
|
|
||||||
from app.models.test_template import TestTemplate
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.test import (
|
from app.schemas.test import (
|
||||||
TestCreate,
|
TestCreate,
|
||||||
@@ -46,6 +42,18 @@ from app.schemas.test_template import TestTemplateInstantiate
|
|||||||
from app.domain.unit_of_work import UnitOfWork
|
from app.domain.unit_of_work import UnitOfWork
|
||||||
from app.services.audit_service import log_action
|
from app.services.audit_service import log_action
|
||||||
from app.services.status_service import recalculate_technique_status
|
from app.services.status_service import recalculate_technique_status
|
||||||
|
from app.services.test_crud_service import (
|
||||||
|
create_test as crud_create_test,
|
||||||
|
create_test_from_template as crud_create_from_template,
|
||||||
|
get_test_detail as crud_get_test_detail,
|
||||||
|
get_test_or_raise as crud_get_test_or_raise,
|
||||||
|
get_test_timeline as crud_get_test_timeline,
|
||||||
|
get_test_with_technique as crud_get_test_with_technique,
|
||||||
|
list_tests as crud_list_tests,
|
||||||
|
update_test as crud_update_test,
|
||||||
|
update_test_blue as crud_update_test_blue,
|
||||||
|
update_test_red as crud_update_test_red,
|
||||||
|
)
|
||||||
from app.services.test_workflow_service import (
|
from app.services.test_workflow_service import (
|
||||||
start_execution as wf_start_execution,
|
start_execution as wf_start_execution,
|
||||||
submit_red_evidence as wf_submit_red,
|
submit_red_evidence as wf_submit_red,
|
||||||
@@ -62,29 +70,6 @@ from app.services.test_workflow_service import (
|
|||||||
router = APIRouter(prefix="/tests", tags=["tests"])
|
router = APIRouter(prefix="/tests", tags=["tests"])
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _get_test_or_404(db: Session, test_id: uuid.UUID) -> Test:
|
|
||||||
test = db.query(Test).filter(Test.id == test_id).first()
|
|
||||||
if test is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
|
|
||||||
return test
|
|
||||||
|
|
||||||
|
|
||||||
def _get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
|
|
||||||
test = (
|
|
||||||
db.query(Test)
|
|
||||||
.options(joinedload(Test.technique))
|
|
||||||
.filter(Test.id == test_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if test is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
|
|
||||||
return test
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GET /tests — list with filters
|
# GET /tests — list with filters
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -105,30 +90,16 @@ def list_tests(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return a paginated list of tests, optionally filtered by state, technique, platform or creator."""
|
"""Return a paginated list of tests, optionally filtered by state, technique, platform or creator."""
|
||||||
query = db.query(Test).options(joinedload(Test.technique))
|
return crud_list_tests(
|
||||||
|
db,
|
||||||
if state:
|
state=state,
|
||||||
query = query.filter(Test.state == state)
|
technique_id=technique_id,
|
||||||
if technique_id:
|
platform=platform,
|
||||||
query = query.filter(Test.technique_id == technique_id)
|
created_by=created_by,
|
||||||
if platform:
|
pending_validation_side=pending_validation_side,
|
||||||
from app.utils import escape_like
|
offset=offset,
|
||||||
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
|
limit=limit,
|
||||||
if created_by:
|
)
|
||||||
query = query.filter(Test.created_by == created_by)
|
|
||||||
if pending_validation_side == "red":
|
|
||||||
query = query.filter(
|
|
||||||
Test.state == TestState.in_review,
|
|
||||||
Test.red_validation_status.in_(["pending", None]),
|
|
||||||
)
|
|
||||||
elif pending_validation_side == "blue":
|
|
||||||
query = query.filter(
|
|
||||||
Test.state == TestState.in_review,
|
|
||||||
Test.blue_validation_status.in_(["pending", None]),
|
|
||||||
)
|
|
||||||
|
|
||||||
tests = query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
|
|
||||||
return tests
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -150,20 +121,14 @@ def create_test(
|
|||||||
|
|
||||||
``created_by`` is set automatically and ``state`` defaults to *draft*.
|
``created_by`` is set automatically and ``state`` defaults to *draft*.
|
||||||
"""
|
"""
|
||||||
technique = db.query(Technique).filter(Technique.id == payload.technique_id).first()
|
with UnitOfWork(db) as uow:
|
||||||
if technique is None:
|
test = crud_create_test(
|
||||||
raise HTTPException(
|
db,
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
technique_id=payload.technique_id,
|
||||||
detail=f"Technique with id '{payload.technique_id}' not found",
|
creator_id=current_user.id,
|
||||||
|
**payload.model_dump(exclude={"technique_id"}),
|
||||||
)
|
)
|
||||||
|
uow.commit()
|
||||||
test = Test(
|
|
||||||
**payload.model_dump(),
|
|
||||||
created_by=current_user.id,
|
|
||||||
state=TestState.draft,
|
|
||||||
)
|
|
||||||
db.add(test)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -197,43 +162,14 @@ def create_test_from_template(
|
|||||||
|
|
||||||
The template's fields are copied into the new test as starting data.
|
The template's fields are copied into the new test as starting data.
|
||||||
"""
|
"""
|
||||||
template = db.query(TestTemplate).filter(TestTemplate.id == payload.template_id).first()
|
with UnitOfWork(db) as uow:
|
||||||
if template is None:
|
test = crud_create_from_template(
|
||||||
raise HTTPException(
|
db,
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
template_id=payload.template_id,
|
||||||
detail=f"TestTemplate with id '{payload.template_id}' not found",
|
technique_id_or_mitre=payload.technique_id,
|
||||||
|
creator_id=current_user.id,
|
||||||
)
|
)
|
||||||
|
uow.commit()
|
||||||
# Resolve technique_id: accept both UUID and MITRE ID (e.g. "T1059.001")
|
|
||||||
technique = None
|
|
||||||
try:
|
|
||||||
technique_uuid = uuid.UUID(payload.technique_id)
|
|
||||||
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if technique is None:
|
|
||||||
technique = db.query(Technique).filter(Technique.mitre_id == payload.technique_id).first()
|
|
||||||
|
|
||||||
if technique is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Technique '{payload.technique_id}' not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
test = Test(
|
|
||||||
technique_id=technique.id,
|
|
||||||
name=template.name,
|
|
||||||
description=template.description,
|
|
||||||
platform=template.platform,
|
|
||||||
procedure_text=template.attack_procedure,
|
|
||||||
tool_used=template.tool_suggested,
|
|
||||||
remediation_steps=template.suggested_remediation,
|
|
||||||
created_by=current_user.id,
|
|
||||||
state=TestState.draft,
|
|
||||||
)
|
|
||||||
db.add(test)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -244,7 +180,7 @@ def create_test_from_template(
|
|||||||
entity_id=test.id,
|
entity_id=test.id,
|
||||||
details={
|
details={
|
||||||
"name": test.name,
|
"name": test.name,
|
||||||
"template_id": str(template.id),
|
"template_id": str(payload.template_id),
|
||||||
"technique_id": str(test.technique_id),
|
"technique_id": str(test.technique_id),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -264,20 +200,7 @@ def get_test(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return full details for a single test, including its evidences."""
|
"""Return full details for a single test, including its evidences."""
|
||||||
test = (
|
return crud_get_test_detail(db, test_id)
|
||||||
db.query(Test)
|
|
||||||
.options(joinedload(Test.evidences))
|
|
||||||
.filter(Test.id == test_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if test is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Test not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
return test
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -297,29 +220,16 @@ def update_test(
|
|||||||
Only leads or admins can update general test fields.
|
Only leads or admins can update general test fields.
|
||||||
The test must be in ``draft`` or ``rejected`` state.
|
The test must be in ``draft`` or ``rejected`` state.
|
||||||
"""
|
"""
|
||||||
test = _get_test_or_404(db, test_id)
|
|
||||||
|
|
||||||
if current_user.role != "admin" and test.created_by != current_user.id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail={"message": "Only the test creator or an admin can update this test", "code": "FORBIDDEN"},
|
|
||||||
)
|
|
||||||
|
|
||||||
if test.state not in (TestState.draft, TestState.rejected):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
"message": f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)",
|
|
||||||
"code": "INVALID_STATE",
|
|
||||||
"current_state": test.state.value,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
with UnitOfWork(db) as uow:
|
||||||
setattr(test, field, value)
|
test = crud_update_test(
|
||||||
|
db,
|
||||||
db.commit()
|
test_id,
|
||||||
|
updater_id=current_user.id,
|
||||||
|
updater_role=current_user.role,
|
||||||
|
**update_data,
|
||||||
|
)
|
||||||
|
uow.commit()
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -347,23 +257,10 @@ def update_test_red(
|
|||||||
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
||||||
):
|
):
|
||||||
"""Red Team updates their fields (allowed in ``draft`` and ``red_executing``)."""
|
"""Red Team updates their fields (allowed in ``draft`` and ``red_executing``)."""
|
||||||
test = _get_test_or_404(db, test_id)
|
|
||||||
|
|
||||||
if test.state not in (TestState.draft, TestState.red_executing):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
"message": f"Cannot update red fields in '{test.state.value}' state (must be draft or red_executing)",
|
|
||||||
"code": "INVALID_STATE",
|
|
||||||
"current_state": test.state.value,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
with UnitOfWork(db) as uow:
|
||||||
setattr(test, field, value)
|
test = crud_update_test_red(db, test_id, **update_data)
|
||||||
|
uow.commit()
|
||||||
db.commit()
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -391,23 +288,10 @@ def update_test_blue(
|
|||||||
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Blue Team updates their fields (allowed only in ``blue_evaluating``)."""
|
"""Blue Team updates their fields (allowed only in ``blue_evaluating``)."""
|
||||||
test = _get_test_or_404(db, test_id)
|
|
||||||
|
|
||||||
if test.state != TestState.blue_evaluating:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
"message": f"Cannot update blue fields in '{test.state.value}' state (must be blue_evaluating)",
|
|
||||||
"code": "INVALID_STATE",
|
|
||||||
"current_state": test.state.value,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
update_data = payload.model_dump(exclude_unset=True)
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
with UnitOfWork(db) as uow:
|
||||||
setattr(test, field, value)
|
test = crud_update_test_blue(db, test_id, **update_data)
|
||||||
|
uow.commit()
|
||||||
db.commit()
|
|
||||||
db.refresh(test)
|
db.refresh(test)
|
||||||
|
|
||||||
log_action(
|
log_action(
|
||||||
@@ -434,7 +318,7 @@ def start_execution(
|
|||||||
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
||||||
):
|
):
|
||||||
"""Move a test from ``draft`` to ``red_executing``."""
|
"""Move a test from ``draft`` to ``red_executing``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_start_execution(db, test, current_user)
|
test = wf_start_execution(db, test, current_user)
|
||||||
uow.commit()
|
uow.commit()
|
||||||
@@ -454,7 +338,7 @@ def submit_red(
|
|||||||
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
||||||
):
|
):
|
||||||
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
|
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_submit_red(db, test, current_user)
|
test = wf_submit_red(db, test, current_user)
|
||||||
uow.commit()
|
uow.commit()
|
||||||
@@ -474,7 +358,7 @@ def submit_blue(
|
|||||||
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
|
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_submit_blue(db, test, current_user)
|
test = wf_submit_blue(db, test, current_user)
|
||||||
uow.commit()
|
uow.commit()
|
||||||
@@ -494,7 +378,7 @@ def pause_timer(
|
|||||||
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
|
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_pause_timer(db, test, current_user)
|
test = wf_pause_timer(db, test, current_user)
|
||||||
uow.commit()
|
uow.commit()
|
||||||
@@ -514,7 +398,7 @@ def resume_timer(
|
|||||||
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Resume the paused timer for the current phase."""
|
"""Resume the paused timer for the current phase."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_resume_timer(db, test, current_user)
|
test = wf_resume_timer(db, test, current_user)
|
||||||
uow.commit()
|
uow.commit()
|
||||||
@@ -535,7 +419,7 @@ def validate_red(
|
|||||||
current_user: User = Depends(require_any_role("red_lead")),
|
current_user: User = Depends(require_any_role("red_lead")),
|
||||||
):
|
):
|
||||||
"""Red Lead approves or rejects the red side of a test."""
|
"""Red Lead approves or rejects the red side of a test."""
|
||||||
test = _get_test_with_technique(db, test_id)
|
test = crud_get_test_with_technique(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_validate_red(
|
test = wf_validate_red(
|
||||||
db, test, current_user,
|
db, test, current_user,
|
||||||
@@ -562,7 +446,7 @@ def validate_blue(
|
|||||||
current_user: User = Depends(require_any_role("blue_lead")),
|
current_user: User = Depends(require_any_role("blue_lead")),
|
||||||
):
|
):
|
||||||
"""Blue Lead approves or rejects the blue side of a test."""
|
"""Blue Lead approves or rejects the blue side of a test."""
|
||||||
test = _get_test_with_technique(db, test_id)
|
test = crud_get_test_with_technique(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_validate_blue(
|
test = wf_validate_blue(
|
||||||
db, test, current_user,
|
db, test, current_user,
|
||||||
@@ -588,7 +472,7 @@ def reopen(
|
|||||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||||
):
|
):
|
||||||
"""Reopen a rejected test, moving it back to ``draft``."""
|
"""Reopen a rejected test, moving it back to ``draft``."""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
with UnitOfWork(db) as uow:
|
with UnitOfWork(db) as uow:
|
||||||
test = wf_reopen(db, test, current_user)
|
test = wf_reopen(db, test, current_user)
|
||||||
uow.commit()
|
uow.commit()
|
||||||
@@ -613,7 +497,7 @@ def update_remediation(
|
|||||||
When ``remediation_status`` transitions to ``'completed'``, an automatic
|
When ``remediation_status`` transitions to ``'completed'``, an automatic
|
||||||
re-test is created (subject to ``MAX_RETEST_COUNT``).
|
re-test is created (subject to ``MAX_RETEST_COUNT``).
|
||||||
"""
|
"""
|
||||||
test = _get_test_or_404(db, test_id)
|
test = crud_get_test_or_raise(db, test_id)
|
||||||
|
|
||||||
old_remediation_status = test.remediation_status
|
old_remediation_status = test.remediation_status
|
||||||
|
|
||||||
@@ -653,29 +537,7 @@ def get_test_timeline(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Return the chronological audit-log history for a test."""
|
"""Return the chronological audit-log history for a test."""
|
||||||
# Verify the test exists
|
return crud_get_test_timeline(db, test_id)
|
||||||
_get_test_or_404(db, test_id)
|
|
||||||
|
|
||||||
logs = (
|
|
||||||
db.query(AuditLog)
|
|
||||||
.filter(
|
|
||||||
AuditLog.entity_type == "test",
|
|
||||||
AuditLog.entity_id == str(test_id),
|
|
||||||
)
|
|
||||||
.order_by(AuditLog.timestamp.asc())
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": str(log.id),
|
|
||||||
"action": log.action,
|
|
||||||
"user_id": str(log.user_id) if log.user_id else None,
|
|
||||||
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
|
|
||||||
"details": log.details,
|
|
||||||
}
|
|
||||||
for log in logs
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -0,0 +1,460 @@
|
|||||||
|
"""Campaign CRUD service — list, create, update, and business logic.
|
||||||
|
|
||||||
|
Framework-agnostic; uses domain exceptions from app.domain.errors.
|
||||||
|
The router is responsible for HTTP concerns, auth, audit logging, and commit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import (
|
||||||
|
BusinessRuleViolation,
|
||||||
|
EntityNotFoundError,
|
||||||
|
PermissionViolation,
|
||||||
|
)
|
||||||
|
from app.models.campaign import Campaign, CampaignTest
|
||||||
|
from app.models.test import Test
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.utils import escape_like
|
||||||
|
from app.services.campaign_service import (
|
||||||
|
get_campaign_progress,
|
||||||
|
validate_no_circular_dependency,
|
||||||
|
TACTIC_TO_PHASE,
|
||||||
|
)
|
||||||
|
from app.services.campaign_scheduler_service import calculate_next_run
|
||||||
|
|
||||||
|
|
||||||
|
# ── Serialization helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_campaign(db: Session, campaign: Campaign) -> dict:
|
||||||
|
"""Serialize a campaign with its tests and progress."""
|
||||||
|
progress = get_campaign_progress(db, campaign.id)
|
||||||
|
|
||||||
|
campaign_tests = (
|
||||||
|
db.query(CampaignTest)
|
||||||
|
.filter(CampaignTest.campaign_id == campaign.id)
|
||||||
|
.order_by(CampaignTest.order_index)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
tests = []
|
||||||
|
for ct in campaign_tests:
|
||||||
|
test = ct.test
|
||||||
|
technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None
|
||||||
|
|
||||||
|
tests.append({
|
||||||
|
"id": str(ct.id),
|
||||||
|
"test_id": str(ct.test_id),
|
||||||
|
"order_index": ct.order_index,
|
||||||
|
"depends_on": str(ct.depends_on) if ct.depends_on else None,
|
||||||
|
"phase": ct.phase,
|
||||||
|
"test_name": test.name if test else None,
|
||||||
|
"test_state": test.state.value if test and test.state else None,
|
||||||
|
"test_result": test.result.value if test and test.result else None,
|
||||||
|
"technique_mitre_id": technique.mitre_id if technique else None,
|
||||||
|
"technique_name": technique.name if technique else None,
|
||||||
|
"platform": test.platform if test else None,
|
||||||
|
})
|
||||||
|
|
||||||
|
actor = campaign.threat_actor
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(campaign.id),
|
||||||
|
"name": campaign.name,
|
||||||
|
"description": campaign.description,
|
||||||
|
"type": campaign.type,
|
||||||
|
"status": campaign.status,
|
||||||
|
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||||
|
"threat_actor_name": actor.name if actor else None,
|
||||||
|
"created_by": str(campaign.created_by) if campaign.created_by else None,
|
||||||
|
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
|
||||||
|
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
|
||||||
|
"target_platform": campaign.target_platform,
|
||||||
|
"tags": campaign.tags or [],
|
||||||
|
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||||
|
"is_recurring": campaign.is_recurring or False,
|
||||||
|
"recurrence_pattern": campaign.recurrence_pattern,
|
||||||
|
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||||
|
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
||||||
|
"parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
|
||||||
|
"tests": tests,
|
||||||
|
"progress": progress,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
|
||||||
|
"""Lightweight campaign serialization for list views."""
|
||||||
|
progress = get_campaign_progress(db, campaign.id)
|
||||||
|
actor = campaign.threat_actor
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(campaign.id),
|
||||||
|
"name": campaign.name,
|
||||||
|
"description": campaign.description,
|
||||||
|
"type": campaign.type,
|
||||||
|
"status": campaign.status,
|
||||||
|
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||||
|
"threat_actor_name": actor.name if actor else None,
|
||||||
|
"target_platform": campaign.target_platform,
|
||||||
|
"tags": campaign.tags or [],
|
||||||
|
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||||
|
"is_recurring": campaign.is_recurring or False,
|
||||||
|
"recurrence_pattern": campaign.recurrence_pattern,
|
||||||
|
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||||
|
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
||||||
|
"test_count": progress["total"],
|
||||||
|
"completion_pct": progress["completion_pct"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── CRUD operations ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def list_campaigns(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
type: Optional[str] = None,
|
||||||
|
status: Optional[str] = None,
|
||||||
|
threat_actor_id: Optional[str] = None,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> dict:
|
||||||
|
"""Return a paginated list of campaigns with optional filters."""
|
||||||
|
query = db.query(Campaign)
|
||||||
|
|
||||||
|
if type:
|
||||||
|
query = query.filter(Campaign.type == type)
|
||||||
|
if status:
|
||||||
|
query = query.filter(Campaign.status == status)
|
||||||
|
if threat_actor_id:
|
||||||
|
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
|
||||||
|
if search:
|
||||||
|
pattern = f"%{escape_like(search)}%"
|
||||||
|
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"offset": offset,
|
||||||
|
"limit": limit,
|
||||||
|
"items": [serialize_campaign_summary(db, c) for c in campaigns],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_campaign(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
creator_id: uuid.UUID,
|
||||||
|
name: str,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
type: str = "custom",
|
||||||
|
threat_actor_id: Optional[str] = None,
|
||||||
|
target_platform: Optional[str] = None,
|
||||||
|
tags: Optional[list[str]] = None,
|
||||||
|
scheduled_at: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Create a new campaign. Does not commit; caller commits."""
|
||||||
|
campaign = Campaign(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
type=type,
|
||||||
|
threat_actor_id=uuid.UUID(threat_actor_id) if threat_actor_id else None,
|
||||||
|
target_platform=target_platform,
|
||||||
|
tags=tags or [],
|
||||||
|
created_by=creator_id,
|
||||||
|
scheduled_at=datetime.fromisoformat(scheduled_at) if scheduled_at else None,
|
||||||
|
)
|
||||||
|
db.add(campaign)
|
||||||
|
db.flush()
|
||||||
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
|
def get_campaign_detail(db: Session, campaign_id: str) -> dict:
|
||||||
|
"""Get detailed campaign info including tests and progress.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if campaign not found.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
|
def update_campaign(
|
||||||
|
db: Session,
|
||||||
|
campaign_id: str,
|
||||||
|
*,
|
||||||
|
updater_id: uuid.UUID,
|
||||||
|
updater_role: str,
|
||||||
|
**fields,
|
||||||
|
) -> dict:
|
||||||
|
"""Update a campaign. Only allowed in draft or active state.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError, BusinessRuleViolation, or PermissionViolation.
|
||||||
|
Does not commit; caller commits.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
if campaign.status not in ("draft", "active"):
|
||||||
|
raise BusinessRuleViolation("Can only update draft or active campaigns")
|
||||||
|
|
||||||
|
if str(campaign.created_by) != str(updater_id) and updater_role != "admin":
|
||||||
|
raise PermissionViolation("Only the creator or admin can update this campaign")
|
||||||
|
|
||||||
|
if "scheduled_at" in fields and fields["scheduled_at"]:
|
||||||
|
fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"])
|
||||||
|
|
||||||
|
for field, value in fields.items():
|
||||||
|
setattr(campaign, field, value)
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
return serialize_campaign(db, campaign)
|
||||||
|
|
||||||
|
|
||||||
|
def add_test_to_campaign(
|
||||||
|
db: Session,
|
||||||
|
campaign_id: str,
|
||||||
|
*,
|
||||||
|
test_id: str,
|
||||||
|
order_index: Optional[int] = None,
|
||||||
|
depends_on: Optional[str] = None,
|
||||||
|
phase: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Add a test to a campaign with optional ordering and dependency.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError for missing campaign or test.
|
||||||
|
Raises BusinessRuleViolation for invalid state or circular dependency.
|
||||||
|
Does not commit; caller commits.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
if campaign.status not in ("draft", "active"):
|
||||||
|
raise BusinessRuleViolation("Can only add tests to draft or active campaigns")
|
||||||
|
|
||||||
|
test = db.query(Test).filter(Test.id == test_id).first()
|
||||||
|
if not test:
|
||||||
|
raise EntityNotFoundError("Test", test_id)
|
||||||
|
|
||||||
|
if order_index is not None:
|
||||||
|
final_order_index = order_index
|
||||||
|
else:
|
||||||
|
max_order = (
|
||||||
|
db.query(CampaignTest.order_index)
|
||||||
|
.filter(CampaignTest.campaign_id == campaign_id)
|
||||||
|
.order_by(CampaignTest.order_index.desc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
final_order_index = (max_order[0] + 1) if max_order else 0
|
||||||
|
|
||||||
|
depends_on_uuid = uuid.UUID(depends_on) if depends_on else None
|
||||||
|
|
||||||
|
ct_id = uuid.uuid4()
|
||||||
|
if depends_on_uuid:
|
||||||
|
validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on_uuid)
|
||||||
|
|
||||||
|
if not phase and test.technique_id:
|
||||||
|
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
||||||
|
if technique and technique.tactic:
|
||||||
|
phase = TACTIC_TO_PHASE.get(technique.tactic, None)
|
||||||
|
|
||||||
|
campaign_test = CampaignTest(
|
||||||
|
id=ct_id,
|
||||||
|
campaign_id=campaign_id,
|
||||||
|
test_id=test_id,
|
||||||
|
order_index=final_order_index,
|
||||||
|
depends_on=depends_on_uuid,
|
||||||
|
phase=phase,
|
||||||
|
)
|
||||||
|
db.add(campaign_test)
|
||||||
|
db.flush()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(campaign_test.id),
|
||||||
|
"campaign_id": str(campaign_test.campaign_id),
|
||||||
|
"test_id": str(campaign_test.test_id),
|
||||||
|
"order_index": campaign_test.order_index,
|
||||||
|
"depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
|
||||||
|
"phase": campaign_test.phase,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: str) -> None:
|
||||||
|
"""Remove a test from a campaign.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError for missing campaign or campaign test.
|
||||||
|
Raises BusinessRuleViolation for invalid state.
|
||||||
|
Does not commit; caller commits.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
if campaign.status not in ("draft", "active"):
|
||||||
|
raise BusinessRuleViolation("Can only modify draft or active campaigns")
|
||||||
|
|
||||||
|
ct = (
|
||||||
|
db.query(CampaignTest)
|
||||||
|
.filter(
|
||||||
|
CampaignTest.id == campaign_test_id,
|
||||||
|
CampaignTest.campaign_id == campaign_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if not ct:
|
||||||
|
raise EntityNotFoundError("CampaignTest", campaign_test_id)
|
||||||
|
|
||||||
|
dep_id = uuid.UUID(campaign_test_id)
|
||||||
|
dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == dep_id).all()
|
||||||
|
for dep in dependents:
|
||||||
|
dep.depends_on = None
|
||||||
|
|
||||||
|
db.delete(ct)
|
||||||
|
db.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def activate_campaign(db: Session, campaign_id: str) -> Campaign:
|
||||||
|
"""Activate a campaign, moving it from draft to active.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError, BusinessRuleViolation.
|
||||||
|
Does not commit; caller commits.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
if campaign.status != "draft":
|
||||||
|
raise BusinessRuleViolation("Only draft campaigns can be activated")
|
||||||
|
|
||||||
|
test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
|
||||||
|
if test_count == 0:
|
||||||
|
raise BusinessRuleViolation("Campaign must have at least one test to activate")
|
||||||
|
|
||||||
|
campaign.status = "active"
|
||||||
|
db.flush()
|
||||||
|
return campaign
|
||||||
|
|
||||||
|
|
||||||
|
def complete_campaign(db: Session, campaign_id: str) -> Campaign:
|
||||||
|
"""Mark a campaign as completed.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError, BusinessRuleViolation.
|
||||||
|
Does not commit; caller commits.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
if campaign.status != "active":
|
||||||
|
raise BusinessRuleViolation("Only active campaigns can be completed")
|
||||||
|
|
||||||
|
campaign.status = "completed"
|
||||||
|
campaign.completed_at = datetime.utcnow()
|
||||||
|
db.flush()
|
||||||
|
return campaign
|
||||||
|
|
||||||
|
|
||||||
|
def get_campaign_progress_data(db: Session, campaign_id: str) -> dict:
|
||||||
|
"""Get progress statistics for a campaign.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if campaign not found.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
progress = get_campaign_progress(db, uuid.UUID(campaign_id))
|
||||||
|
return {
|
||||||
|
"campaign_id": str(campaign.id),
|
||||||
|
"campaign_name": campaign.name,
|
||||||
|
**progress,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def schedule_campaign(
|
||||||
|
db: Session,
|
||||||
|
campaign_id: str,
|
||||||
|
*,
|
||||||
|
owner_id: uuid.UUID,
|
||||||
|
owner_role: str,
|
||||||
|
is_recurring: bool,
|
||||||
|
recurrence_pattern: Optional[str] = None,
|
||||||
|
next_run_at: Optional[str] = None,
|
||||||
|
) -> Campaign:
|
||||||
|
"""Configure or update the recurrence schedule for a campaign.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError, PermissionViolation, BusinessRuleViolation.
|
||||||
|
Does not commit; caller commits.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
if str(campaign.created_by) != str(owner_id) and owner_role != "admin":
|
||||||
|
raise PermissionViolation("Only the creator or admin can configure scheduling")
|
||||||
|
|
||||||
|
campaign.is_recurring = is_recurring
|
||||||
|
|
||||||
|
if is_recurring:
|
||||||
|
if recurrence_pattern not in ("weekly", "monthly", "quarterly"):
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
"recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'"
|
||||||
|
)
|
||||||
|
campaign.recurrence_pattern = recurrence_pattern
|
||||||
|
if next_run_at:
|
||||||
|
campaign.next_run_at = datetime.fromisoformat(
|
||||||
|
next_run_at.replace("Z", "+00:00").replace("+00:00", "")
|
||||||
|
)
|
||||||
|
elif not campaign.next_run_at:
|
||||||
|
campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern)
|
||||||
|
else:
|
||||||
|
campaign.recurrence_pattern = None
|
||||||
|
campaign.next_run_at = None
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
return campaign
|
||||||
|
|
||||||
|
|
||||||
|
def get_campaign_history(db: Session, campaign_id: str) -> dict:
|
||||||
|
"""List all child campaigns (execution history) of a recurring campaign.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if campaign not found.
|
||||||
|
"""
|
||||||
|
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||||
|
if not campaign:
|
||||||
|
raise EntityNotFoundError("Campaign", campaign_id)
|
||||||
|
|
||||||
|
campaign_uuid = uuid.UUID(campaign_id)
|
||||||
|
children = (
|
||||||
|
db.query(Campaign)
|
||||||
|
.filter(Campaign.parent_campaign_id == campaign_uuid)
|
||||||
|
.order_by(Campaign.created_at.desc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"campaign_id": str(campaign.id),
|
||||||
|
"campaign_name": campaign.name,
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"id": str(child.id),
|
||||||
|
"name": child.name,
|
||||||
|
"status": child.status,
|
||||||
|
"test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
|
||||||
|
"completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
|
||||||
|
"created_at": child.created_at.isoformat() if child.created_at else None,
|
||||||
|
"completed_at": child.completed_at.isoformat() if child.completed_at else None,
|
||||||
|
}
|
||||||
|
for child in children
|
||||||
|
],
|
||||||
|
}
|
||||||
@@ -0,0 +1,167 @@
|
|||||||
|
"""Evidence service — permission validation, file validation, and query logic.
|
||||||
|
|
||||||
|
Framework-agnostic; uses domain exceptions from app.domain.errors.
|
||||||
|
The router is responsible for HTTP concerns, file I/O, MinIO upload,
|
||||||
|
audit logging, and response formatting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.domain.errors import (
|
||||||
|
BusinessRuleViolation,
|
||||||
|
EntityNotFoundError,
|
||||||
|
PermissionViolation,
|
||||||
|
)
|
||||||
|
from app.models.enums import TeamSide, TestState
|
||||||
|
from app.models.evidence import Evidence
|
||||||
|
from app.models.test import Test
|
||||||
|
|
||||||
|
# States where red evidence can be uploaded / deleted
|
||||||
|
RED_EDITABLE_STATES = (TestState.draft, TestState.red_executing)
|
||||||
|
# States where blue evidence can be uploaded / deleted
|
||||||
|
BLUE_EDITABLE_STATES = (TestState.blue_evaluating,)
|
||||||
|
|
||||||
|
# Maximum upload size in bytes (50 MB)
|
||||||
|
MAX_UPLOAD_SIZE = 50 * 1024 * 1024
|
||||||
|
|
||||||
|
# Allowed file extensions (lowercase, with leading dot)
|
||||||
|
ALLOWED_EXTENSIONS: frozenset[str] = frozenset({
|
||||||
|
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg",
|
||||||
|
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".csv", ".txt",
|
||||||
|
".md", ".rtf", ".odt", ".ods",
|
||||||
|
".log", ".pcap", ".pcapng", ".evtx", ".json", ".xml",
|
||||||
|
".yaml", ".yml", ".toml",
|
||||||
|
".zip", ".tar", ".gz", ".7z",
|
||||||
|
".har", ".eml", ".msg",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def validate_upload_permission(
|
||||||
|
test: Test,
|
||||||
|
team: TeamSide,
|
||||||
|
user_role: str,
|
||||||
|
) -> None:
|
||||||
|
"""Validate that the user can upload evidence for the given team in the current state.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PermissionViolation: If user lacks role to upload for this team.
|
||||||
|
BusinessRuleViolation: If test state does not allow uploading for this team.
|
||||||
|
"""
|
||||||
|
if user_role == "admin":
|
||||||
|
return
|
||||||
|
|
||||||
|
if team == TeamSide.red:
|
||||||
|
if user_role not in ("red_tech", "red_lead"):
|
||||||
|
raise PermissionViolation(
|
||||||
|
"Only red_tech, red_lead or admin can upload red evidence"
|
||||||
|
)
|
||||||
|
if test.state not in RED_EDITABLE_STATES:
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Cannot upload red evidence in '{test.state.value}' state "
|
||||||
|
"(allowed in: draft, red_executing)"
|
||||||
|
)
|
||||||
|
elif team == TeamSide.blue:
|
||||||
|
if user_role not in ("blue_tech", "blue_lead"):
|
||||||
|
raise PermissionViolation(
|
||||||
|
"Only blue_tech, blue_lead or admin can upload blue evidence"
|
||||||
|
)
|
||||||
|
if test.state not in BLUE_EDITABLE_STATES:
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Cannot upload blue evidence in '{test.state.value}' state "
|
||||||
|
"(allowed in: blue_evaluating)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_delete_permission(
|
||||||
|
test: Test,
|
||||||
|
evidence: Evidence,
|
||||||
|
user_role: str,
|
||||||
|
user_id: uuid.UUID,
|
||||||
|
) -> None:
|
||||||
|
"""Validate that the user can delete this evidence in the current state.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PermissionViolation: If user cannot delete in this state or lacks permission.
|
||||||
|
"""
|
||||||
|
if test.state in (TestState.in_review, TestState.validated, TestState.rejected):
|
||||||
|
raise PermissionViolation(
|
||||||
|
f"Cannot delete evidence when test is in '{test.state.value}' state"
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_role == "admin":
|
||||||
|
return
|
||||||
|
|
||||||
|
ev_team = evidence.team
|
||||||
|
|
||||||
|
if ev_team == TeamSide.red:
|
||||||
|
if test.state not in RED_EDITABLE_STATES:
|
||||||
|
raise PermissionViolation(
|
||||||
|
"Cannot delete red evidence outside draft/red_executing"
|
||||||
|
)
|
||||||
|
if user_role not in ("red_tech", "red_lead") and evidence.uploaded_by != user_id:
|
||||||
|
raise PermissionViolation(
|
||||||
|
"Not enough permissions to delete this evidence"
|
||||||
|
)
|
||||||
|
elif ev_team == TeamSide.blue:
|
||||||
|
if test.state not in BLUE_EDITABLE_STATES:
|
||||||
|
raise PermissionViolation(
|
||||||
|
"Cannot delete blue evidence outside blue_evaluating"
|
||||||
|
)
|
||||||
|
if user_role not in ("blue_tech", "blue_lead") and evidence.uploaded_by != user_id:
|
||||||
|
raise PermissionViolation(
|
||||||
|
"Not enough permissions to delete this evidence"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_file(file_name: str, content_size: int) -> None:
|
||||||
|
"""Validate file extension and size.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessRuleViolation: If extension is not allowed or file exceeds size limit.
|
||||||
|
"""
|
||||||
|
_, ext = os.path.splitext(file_name)
|
||||||
|
ext_lower = ext.lower() if ext else ""
|
||||||
|
if ext_lower not in ALLOWED_EXTENSIONS:
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"File type '{ext}' is not allowed. "
|
||||||
|
f"Permitted types: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
|
||||||
|
)
|
||||||
|
if content_size > MAX_UPLOAD_SIZE:
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"File exceeds maximum upload size of {MAX_UPLOAD_SIZE // (1024 * 1024)} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def list_evidence_for_test(
|
||||||
|
db: Session,
|
||||||
|
test_id: uuid.UUID,
|
||||||
|
*,
|
||||||
|
team: TeamSide | str | None = None,
|
||||||
|
) -> list[Evidence]:
|
||||||
|
"""Return evidence for a test, optionally filtered by team."""
|
||||||
|
query = db.query(Evidence).filter(Evidence.test_id == test_id)
|
||||||
|
if team is not None:
|
||||||
|
team_enum = TeamSide(team) if isinstance(team, str) else team
|
||||||
|
query = query.filter(Evidence.team == team_enum)
|
||||||
|
return query.order_by(Evidence.uploaded_at.desc()).all()
|
||||||
|
|
||||||
|
|
||||||
|
def get_evidence_or_raise(db: Session, evidence_id: uuid.UUID) -> Evidence:
|
||||||
|
"""Fetch evidence by ID. Raises EntityNotFoundError if not found."""
|
||||||
|
evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
|
||||||
|
if evidence is None:
|
||||||
|
raise EntityNotFoundError("Evidence", str(evidence_id))
|
||||||
|
return evidence
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
|
||||||
|
"""Fetch test by ID. Raises EntityNotFoundError if not found."""
|
||||||
|
test = db.query(Test).filter(Test.id == test_id).first()
|
||||||
|
if test is None:
|
||||||
|
raise EntityNotFoundError("Test", str(test_id))
|
||||||
|
return test
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
"""Scoring configuration persistence service.
|
||||||
|
|
||||||
|
Reads and writes scoring weights from the ``scoring_config`` table.
|
||||||
|
Falls back to environment-variable defaults (from ``Settings``) when
|
||||||
|
no row has been persisted yet.
|
||||||
|
|
||||||
|
This module is framework-agnostic: no FastAPI imports.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.domain.value_objects.scoring_weights import ScoringWeights
|
||||||
|
from app.models.scoring_config import ScoringConfig
|
||||||
|
|
||||||
|
|
||||||
|
def get_scoring_weights(db: Session) -> ScoringWeights:
|
||||||
|
"""Return the active scoring weights.
|
||||||
|
|
||||||
|
Reads the single ``scoring_config`` row. If the table is empty
|
||||||
|
(first run or migration just applied), falls back to the values
|
||||||
|
from the environment / ``Settings``.
|
||||||
|
"""
|
||||||
|
row = db.query(ScoringConfig).first()
|
||||||
|
if row is not None:
|
||||||
|
return ScoringWeights(
|
||||||
|
tests=row.weight_tests,
|
||||||
|
detection_rules=row.weight_detection_rules,
|
||||||
|
d3fend=row.weight_d3fend,
|
||||||
|
freshness=row.weight_freshness,
|
||||||
|
platform_diversity=row.weight_platform_diversity,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ScoringWeights(
|
||||||
|
tests=float(settings.SCORING_WEIGHT_TESTS),
|
||||||
|
detection_rules=float(settings.SCORING_WEIGHT_DETECTION_RULES),
|
||||||
|
d3fend=float(settings.SCORING_WEIGHT_D3FEND),
|
||||||
|
freshness=float(settings.SCORING_WEIGHT_FRESHNESS),
|
||||||
|
platform_diversity=float(settings.SCORING_WEIGHT_PLATFORM_DIVERSITY),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def update_scoring_weights(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
tests: float | None = None,
|
||||||
|
detection_rules: float | None = None,
|
||||||
|
d3fend: float | None = None,
|
||||||
|
freshness: float | None = None,
|
||||||
|
platform_diversity: float | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Upsert scoring weights into the database.
|
||||||
|
|
||||||
|
Only provided fields are overwritten; ``None`` values keep the
|
||||||
|
current (or default) value. Validates via ``ScoringWeights``
|
||||||
|
before persisting.
|
||||||
|
|
||||||
|
Returns a dict with ``weights`` and ``total``.
|
||||||
|
"""
|
||||||
|
current = get_scoring_weights(db)
|
||||||
|
|
||||||
|
new = ScoringWeights(
|
||||||
|
tests=tests if tests is not None else current.tests,
|
||||||
|
detection_rules=detection_rules if detection_rules is not None else current.detection_rules,
|
||||||
|
d3fend=d3fend if d3fend is not None else current.d3fend,
|
||||||
|
freshness=freshness if freshness is not None else current.freshness,
|
||||||
|
platform_diversity=platform_diversity if platform_diversity is not None else current.platform_diversity,
|
||||||
|
)
|
||||||
|
|
||||||
|
row = db.query(ScoringConfig).first()
|
||||||
|
if row is None:
|
||||||
|
row = ScoringConfig()
|
||||||
|
db.add(row)
|
||||||
|
|
||||||
|
row.weight_tests = new.tests
|
||||||
|
row.weight_detection_rules = new.detection_rules
|
||||||
|
row.weight_d3fend = new.d3fend
|
||||||
|
row.weight_freshness = new.freshness
|
||||||
|
row.weight_platform_diversity = new.platform_diversity
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(row)
|
||||||
|
|
||||||
|
return _weights_dict(new)
|
||||||
|
|
||||||
|
|
||||||
|
def get_weights_dict(db: Session) -> dict[str, Any]:
|
||||||
|
"""Return current weights as a serialisable dict."""
|
||||||
|
return _weights_dict(get_scoring_weights(db))
|
||||||
|
|
||||||
|
|
||||||
|
def _weights_dict(w: ScoringWeights) -> dict[str, Any]:
|
||||||
|
weights = {
|
||||||
|
"tests": w.tests,
|
||||||
|
"detection_rules": w.detection_rules,
|
||||||
|
"d3fend": w.d3fend,
|
||||||
|
"freshness": w.freshness,
|
||||||
|
"platform_diversity": w.platform_diversity,
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"weights": weights,
|
||||||
|
"total": sum(weights.values()),
|
||||||
|
}
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
"""Scoring service — granular 0-100 scoring for techniques, tactics, actors, and org.
|
"""Scoring service — granular 0-100 scoring for techniques, tactics, actors, and org.
|
||||||
|
|
||||||
Uses configurable weights from Settings to compute coverage scores with
|
Reads configurable weights from the ``scoring_config`` table (falling
|
||||||
detailed breakdowns.
|
back to env-var defaults) to compute coverage scores with detailed
|
||||||
|
breakdowns.
|
||||||
|
|
||||||
Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a
|
Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a
|
||||||
fixed number of aggregated queries so that organisation-wide calculations
|
fixed number of aggregated queries so that organisation-wide calculations
|
||||||
@@ -14,7 +15,6 @@ from typing import Optional
|
|||||||
from sqlalchemy import case, func
|
from sqlalchemy import case, func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.config import settings
|
|
||||||
from app.models.technique import Technique
|
from app.models.technique import Technique
|
||||||
from app.models.test import Test
|
from app.models.test import Test
|
||||||
from app.models.detection_rule import DetectionRule
|
from app.models.detection_rule import DetectionRule
|
||||||
@@ -22,20 +22,12 @@ from app.models.test_detection_result import TestDetectionResult
|
|||||||
from app.models.defensive_technique import DefensiveTechniqueMapping
|
from app.models.defensive_technique import DefensiveTechniqueMapping
|
||||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||||
from app.models.enums import TestState, TestResult
|
from app.models.enums import TestState, TestResult
|
||||||
|
from app.services.scoring_config_service import get_scoring_weights
|
||||||
|
|
||||||
|
|
||||||
# ── Bulk scoring helpers (5 queries for ALL techniques) ───────────────
|
# ── Bulk scoring helpers (5 queries for ALL techniques) ───────────────
|
||||||
|
|
||||||
|
|
||||||
def _build_empty_stats():
|
|
||||||
return {
|
|
||||||
"validated": 0,
|
|
||||||
"detected": 0,
|
|
||||||
"platforms": set(),
|
|
||||||
"latest_validated_at": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def bulk_technique_scores(db: Session) -> dict:
|
def bulk_technique_scores(db: Session) -> dict:
|
||||||
"""Pre-fetch all scoring data and compute per-technique scores in memory.
|
"""Pre-fetch all scoring data and compute per-technique scores in memory.
|
||||||
|
|
||||||
@@ -48,11 +40,12 @@ def bulk_technique_scores(db: Session) -> dict:
|
|||||||
|
|
||||||
Returns ``{technique_id: {"total_score": float, "breakdown": dict}}``.
|
Returns ``{technique_id: {"total_score": float, "breakdown": dict}}``.
|
||||||
"""
|
"""
|
||||||
w_tests = settings.SCORING_WEIGHT_TESTS
|
w = get_scoring_weights(db)
|
||||||
w_detection = settings.SCORING_WEIGHT_DETECTION_RULES
|
w_tests = w.tests
|
||||||
w_d3fend = settings.SCORING_WEIGHT_D3FEND
|
w_detection = w.detection_rules
|
||||||
w_freshness = settings.SCORING_WEIGHT_FRESHNESS
|
w_d3fend = w.d3fend
|
||||||
w_diversity = settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
|
w_freshness = w.freshness
|
||||||
|
w_diversity = w.platform_diversity
|
||||||
|
|
||||||
# Q1: test stats grouped by technique_id
|
# Q1: test stats grouped by technique_id
|
||||||
test_rows = (
|
test_rows = (
|
||||||
@@ -242,18 +235,14 @@ def bulk_technique_scores(db: Session) -> dict:
|
|||||||
def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
||||||
"""Calculate a 0-100 score for a technique with detailed breakdown.
|
"""Calculate a 0-100 score for a technique with detailed breakdown.
|
||||||
|
|
||||||
Weights (configurable via settings):
|
Weights are read from the ``scoring_config`` table (or env defaults).
|
||||||
- tests_validated: weight from SCORING_WEIGHT_TESTS
|
|
||||||
- detection_rules: weight from SCORING_WEIGHT_DETECTION_RULES
|
|
||||||
- d3fend_coverage: weight from SCORING_WEIGHT_D3FEND
|
|
||||||
- freshness: weight from SCORING_WEIGHT_FRESHNESS
|
|
||||||
- platform_diversity: weight from SCORING_WEIGHT_PLATFORM_DIVERSITY
|
|
||||||
"""
|
"""
|
||||||
w_tests = settings.SCORING_WEIGHT_TESTS
|
w = get_scoring_weights(db)
|
||||||
w_detection = settings.SCORING_WEIGHT_DETECTION_RULES
|
w_tests = w.tests
|
||||||
w_d3fend = settings.SCORING_WEIGHT_D3FEND
|
w_detection = w.detection_rules
|
||||||
w_freshness = settings.SCORING_WEIGHT_FRESHNESS
|
w_d3fend = w.d3fend
|
||||||
w_diversity = settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
|
w_freshness = w.freshness
|
||||||
|
w_diversity = w.platform_diversity
|
||||||
|
|
||||||
breakdown = {}
|
breakdown = {}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,277 @@
|
|||||||
|
"""Test CRUD service — list, create, update, and query logic for security tests.
|
||||||
|
|
||||||
|
Framework-agnostic; uses domain exceptions from app.domain.errors.
|
||||||
|
The router is responsible for HTTP concerns, auth, audit logging, and commit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from app.domain.errors import (
|
||||||
|
BusinessRuleViolation,
|
||||||
|
EntityNotFoundError,
|
||||||
|
PermissionViolation,
|
||||||
|
)
|
||||||
|
from app.models.enums import TestState
|
||||||
|
from app.models.technique import Technique
|
||||||
|
from app.models.test import Test
|
||||||
|
from app.models.test_template import TestTemplate
|
||||||
|
from app.models.audit import AuditLog
|
||||||
|
from app.utils import escape_like
|
||||||
|
|
||||||
|
|
||||||
|
def list_tests(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
state: str | None = None,
|
||||||
|
technique_id: uuid.UUID | None = None,
|
||||||
|
platform: str | None = None,
|
||||||
|
created_by: uuid.UUID | None = None,
|
||||||
|
pending_validation_side: str | None = None,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> list[Test]:
|
||||||
|
"""Return a paginated list of tests with optional filters."""
|
||||||
|
query = db.query(Test).options(joinedload(Test.technique))
|
||||||
|
|
||||||
|
if state:
|
||||||
|
query = query.filter(Test.state == state)
|
||||||
|
if technique_id:
|
||||||
|
query = query.filter(Test.technique_id == technique_id)
|
||||||
|
if platform:
|
||||||
|
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
|
||||||
|
if created_by:
|
||||||
|
query = query.filter(Test.created_by == created_by)
|
||||||
|
if pending_validation_side == "red":
|
||||||
|
query = query.filter(
|
||||||
|
Test.state == TestState.in_review,
|
||||||
|
Test.red_validation_status.in_(["pending", None]),
|
||||||
|
)
|
||||||
|
elif pending_validation_side == "blue":
|
||||||
|
query = query.filter(
|
||||||
|
Test.state == TestState.in_review,
|
||||||
|
Test.blue_validation_status.in_(["pending", None]),
|
||||||
|
)
|
||||||
|
|
||||||
|
return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
|
||||||
|
|
||||||
|
|
||||||
|
def create_test(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
technique_id: uuid.UUID,
|
||||||
|
creator_id: uuid.UUID,
|
||||||
|
**fields: Any,
|
||||||
|
) -> Test:
|
||||||
|
"""Create a new test linked to an existing technique.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if the technique does not exist.
|
||||||
|
Does not commit; caller uses UnitOfWork.
|
||||||
|
"""
|
||||||
|
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||||
|
if technique is None:
|
||||||
|
raise EntityNotFoundError("Technique", str(technique_id))
|
||||||
|
|
||||||
|
test = Test(
|
||||||
|
technique_id=technique_id,
|
||||||
|
created_by=creator_id,
|
||||||
|
state=TestState.draft,
|
||||||
|
**fields,
|
||||||
|
)
|
||||||
|
db.add(test)
|
||||||
|
db.flush()
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_from_template(
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
template_id: uuid.UUID,
|
||||||
|
technique_id_or_mitre: str,
|
||||||
|
creator_id: uuid.UUID,
|
||||||
|
) -> Test:
|
||||||
|
"""Instantiate a Test from a TestTemplate.
|
||||||
|
|
||||||
|
technique_id_or_mitre can be a UUID string or MITRE ID (e.g. T1059.001).
|
||||||
|
Raises EntityNotFoundError if template or technique not found.
|
||||||
|
Does not commit; caller uses UnitOfWork.
|
||||||
|
"""
|
||||||
|
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||||
|
if template is None:
|
||||||
|
raise EntityNotFoundError("TestTemplate", str(template_id))
|
||||||
|
|
||||||
|
technique = None
|
||||||
|
try:
|
||||||
|
technique_uuid = uuid.UUID(technique_id_or_mitre)
|
||||||
|
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if technique is None:
|
||||||
|
technique = db.query(Technique).filter(
|
||||||
|
Technique.mitre_id == technique_id_or_mitre
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if technique is None:
|
||||||
|
raise EntityNotFoundError("Technique", technique_id_or_mitre)
|
||||||
|
|
||||||
|
test = Test(
|
||||||
|
technique_id=technique.id,
|
||||||
|
name=template.name,
|
||||||
|
description=template.description,
|
||||||
|
platform=template.platform,
|
||||||
|
procedure_text=template.attack_procedure,
|
||||||
|
tool_used=template.tool_suggested,
|
||||||
|
remediation_steps=template.suggested_remediation,
|
||||||
|
created_by=creator_id,
|
||||||
|
state=TestState.draft,
|
||||||
|
)
|
||||||
|
db.add(test)
|
||||||
|
db.flush()
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
|
||||||
|
"""Fetch a test with evidences eager-loaded.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if the test does not exist.
|
||||||
|
"""
|
||||||
|
test = (
|
||||||
|
db.query(Test)
|
||||||
|
.options(joinedload(Test.evidences))
|
||||||
|
.filter(Test.id == test_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if test is None:
|
||||||
|
raise EntityNotFoundError("Test", str(test_id))
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
|
||||||
|
"""Fetch a test by ID. Raises EntityNotFoundError if not found."""
|
||||||
|
test = db.query(Test).filter(Test.id == test_id).first()
|
||||||
|
if test is None:
|
||||||
|
raise EntityNotFoundError("Test", str(test_id))
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
|
||||||
|
"""Fetch a test with technique joined. Raises EntityNotFoundError if not found."""
|
||||||
|
test = (
|
||||||
|
db.query(Test)
|
||||||
|
.options(joinedload(Test.technique))
|
||||||
|
.filter(Test.id == test_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if test is None:
|
||||||
|
raise EntityNotFoundError("Test", str(test_id))
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def update_test(
|
||||||
|
db: Session,
|
||||||
|
test_id: uuid.UUID,
|
||||||
|
*,
|
||||||
|
updater_id: uuid.UUID,
|
||||||
|
updater_role: str,
|
||||||
|
**fields: Any,
|
||||||
|
) -> Test:
|
||||||
|
"""Update general test fields (draft or rejected only).
|
||||||
|
|
||||||
|
Raises PermissionViolation if not creator or admin.
|
||||||
|
Raises BusinessRuleViolation if state is not draft or rejected.
|
||||||
|
Raises EntityNotFoundError if test not found.
|
||||||
|
Does not commit; caller uses UnitOfWork.
|
||||||
|
"""
|
||||||
|
test = get_test_or_raise(db, test_id)
|
||||||
|
|
||||||
|
if updater_role != "admin" and test.created_by != updater_id:
|
||||||
|
raise PermissionViolation(
|
||||||
|
"Only the test creator or an admin can update this test"
|
||||||
|
)
|
||||||
|
|
||||||
|
if test.state not in (TestState.draft, TestState.rejected):
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field, value in fields.items():
|
||||||
|
setattr(test, field, value)
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
|
||||||
|
"""Update Red Team fields (draft or red_executing only).
|
||||||
|
|
||||||
|
Raises BusinessRuleViolation if state not in (draft, red_executing).
|
||||||
|
Raises EntityNotFoundError if test not found.
|
||||||
|
Does not commit; caller uses UnitOfWork.
|
||||||
|
"""
|
||||||
|
test = get_test_or_raise(db, test_id)
|
||||||
|
|
||||||
|
if test.state not in (TestState.draft, TestState.red_executing):
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Cannot update red fields in '{test.state.value}' state "
|
||||||
|
"(must be draft or red_executing)"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field, value in fields.items():
|
||||||
|
setattr(test, field, value)
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
|
||||||
|
"""Update Blue Team fields (blue_evaluating only).
|
||||||
|
|
||||||
|
Raises BusinessRuleViolation if state is not blue_evaluating.
|
||||||
|
Raises EntityNotFoundError if test not found.
|
||||||
|
Does not commit; caller uses UnitOfWork.
|
||||||
|
"""
|
||||||
|
test = get_test_or_raise(db, test_id)
|
||||||
|
|
||||||
|
if test.state != TestState.blue_evaluating:
|
||||||
|
raise BusinessRuleViolation(
|
||||||
|
f"Cannot update blue fields in '{test.state.value}' state "
|
||||||
|
"(must be blue_evaluating)"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field, value in fields.items():
|
||||||
|
setattr(test, field, value)
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
|
||||||
|
"""Return chronological audit-log history for a test.
|
||||||
|
|
||||||
|
Raises EntityNotFoundError if the test does not exist.
|
||||||
|
"""
|
||||||
|
get_test_or_raise(db, test_id)
|
||||||
|
|
||||||
|
logs = (
|
||||||
|
db.query(AuditLog)
|
||||||
|
.filter(
|
||||||
|
AuditLog.entity_type == "test",
|
||||||
|
AuditLog.entity_id == str(test_id),
|
||||||
|
)
|
||||||
|
.order_by(AuditLog.timestamp.asc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": str(log.id),
|
||||||
|
"action": log.action,
|
||||||
|
"user_id": str(log.user_id) if log.user_id else None,
|
||||||
|
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
|
||||||
|
"details": log.details,
|
||||||
|
}
|
||||||
|
for log in logs
|
||||||
|
]
|
||||||
@@ -89,12 +89,12 @@ for mod_name in [
|
|||||||
# Imports
|
# Imports
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from app.domain.errors import PermissionViolation
|
||||||
from app.models.enums import TeamSide, TestState
|
from app.models.enums import TeamSide, TestState
|
||||||
from app.routers.evidence import (
|
from app.routers.evidence import router
|
||||||
router,
|
from app.services.evidence_service import (
|
||||||
_validate_upload_permission,
|
validate_delete_permission,
|
||||||
_validate_delete_permission,
|
validate_upload_permission,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -132,7 +132,7 @@ def test_red_tech_upload_red_in_red_executing():
|
|||||||
test = _make_test(TestState.red_executing)
|
test = _make_test(TestState.red_executing)
|
||||||
user = _make_user("red_tech")
|
user = _make_user("red_tech")
|
||||||
# Should not raise
|
# Should not raise
|
||||||
_validate_upload_permission(test, TeamSide.red, user)
|
validate_upload_permission(test, TeamSide.red, user.role)
|
||||||
print(" [PASS] red_tech can upload team=red in red_executing")
|
print(" [PASS] red_tech can upload team=red in red_executing")
|
||||||
|
|
||||||
|
|
||||||
@@ -144,7 +144,7 @@ def test_red_tech_upload_red_in_red_executing():
|
|||||||
def test_red_tech_upload_red_in_draft():
|
def test_red_tech_upload_red_in_draft():
|
||||||
test = _make_test(TestState.draft)
|
test = _make_test(TestState.draft)
|
||||||
user = _make_user("red_tech")
|
user = _make_user("red_tech")
|
||||||
_validate_upload_permission(test, TeamSide.red, user)
|
validate_upload_permission(test, TeamSide.red, user.role)
|
||||||
print(" [PASS] red_tech can upload team=red in draft")
|
print(" [PASS] red_tech can upload team=red in draft")
|
||||||
|
|
||||||
|
|
||||||
@@ -157,10 +157,10 @@ def test_red_tech_cannot_upload_blue():
|
|||||||
test = _make_test(TestState.red_executing)
|
test = _make_test(TestState.red_executing)
|
||||||
user = _make_user("red_tech")
|
user = _make_user("red_tech")
|
||||||
try:
|
try:
|
||||||
_validate_upload_permission(test, TeamSide.blue, user)
|
validate_upload_permission(test, TeamSide.blue, user.role)
|
||||||
assert False, "Should have raised HTTPException"
|
assert False, "Should have raised PermissionViolation"
|
||||||
except HTTPException as exc:
|
except PermissionViolation:
|
||||||
assert exc.status_code == 403
|
pass
|
||||||
print(" [PASS] red_tech CANNOT upload team=blue (403)")
|
print(" [PASS] red_tech CANNOT upload team=blue (403)")
|
||||||
|
|
||||||
|
|
||||||
@@ -172,7 +172,7 @@ def test_red_tech_cannot_upload_blue():
|
|||||||
def test_blue_tech_upload_blue_in_blue_evaluating():
|
def test_blue_tech_upload_blue_in_blue_evaluating():
|
||||||
test = _make_test(TestState.blue_evaluating)
|
test = _make_test(TestState.blue_evaluating)
|
||||||
user = _make_user("blue_tech")
|
user = _make_user("blue_tech")
|
||||||
_validate_upload_permission(test, TeamSide.blue, user)
|
validate_upload_permission(test, TeamSide.blue, user.role)
|
||||||
print(" [PASS] blue_tech can upload team=blue in blue_evaluating")
|
print(" [PASS] blue_tech can upload team=blue in blue_evaluating")
|
||||||
|
|
||||||
|
|
||||||
@@ -185,10 +185,10 @@ def test_blue_tech_cannot_upload_red():
|
|||||||
test = _make_test(TestState.blue_evaluating)
|
test = _make_test(TestState.blue_evaluating)
|
||||||
user = _make_user("blue_tech")
|
user = _make_user("blue_tech")
|
||||||
try:
|
try:
|
||||||
_validate_upload_permission(test, TeamSide.red, user)
|
validate_upload_permission(test, TeamSide.red, user.role)
|
||||||
assert False, "Should have raised HTTPException"
|
assert False, "Should have raised PermissionViolation"
|
||||||
except HTTPException as exc:
|
except PermissionViolation:
|
||||||
assert exc.status_code == 403
|
pass
|
||||||
print(" [PASS] blue_tech CANNOT upload team=red (403)")
|
print(" [PASS] blue_tech CANNOT upload team=red (403)")
|
||||||
|
|
||||||
|
|
||||||
@@ -223,10 +223,10 @@ def test_delete_in_review_fails():
|
|||||||
user = _make_user("red_tech")
|
user = _make_user("red_tech")
|
||||||
evidence = _make_evidence(TeamSide.red, uploaded_by=user.id)
|
evidence = _make_evidence(TeamSide.red, uploaded_by=user.id)
|
||||||
try:
|
try:
|
||||||
_validate_delete_permission(test, evidence, user)
|
validate_delete_permission(test, evidence, user.role, user.id)
|
||||||
assert False, "Should have raised HTTPException"
|
assert False, "Should have raised PermissionViolation"
|
||||||
except HTTPException as exc:
|
except PermissionViolation:
|
||||||
assert exc.status_code == 403
|
pass
|
||||||
print(" [PASS] DELETE in in_review -> 403")
|
print(" [PASS] DELETE in in_review -> 403")
|
||||||
|
|
||||||
|
|
||||||
@@ -240,7 +240,7 @@ def test_delete_red_evidence_in_red_executing():
|
|||||||
user = _make_user("red_tech")
|
user = _make_user("red_tech")
|
||||||
evidence = _make_evidence(TeamSide.red, uploaded_by=user.id)
|
evidence = _make_evidence(TeamSide.red, uploaded_by=user.id)
|
||||||
# Should not raise
|
# Should not raise
|
||||||
_validate_delete_permission(test, evidence, user)
|
validate_delete_permission(test, evidence, user.role, user.id)
|
||||||
print(" [PASS] DELETE red evidence in red_executing -> allowed")
|
print(" [PASS] DELETE red evidence in red_executing -> allowed")
|
||||||
|
|
||||||
|
|
||||||
@@ -254,11 +254,11 @@ def test_admin_bypass():
|
|||||||
|
|
||||||
# Red in blue_evaluating (normally blocked)
|
# Red in blue_evaluating (normally blocked)
|
||||||
test1 = _make_test(TestState.blue_evaluating)
|
test1 = _make_test(TestState.blue_evaluating)
|
||||||
_validate_upload_permission(test1, TeamSide.red, admin)
|
validate_upload_permission(test1, TeamSide.red, admin.role)
|
||||||
|
|
||||||
# Blue in draft (normally blocked)
|
# Blue in draft (normally blocked)
|
||||||
test2 = _make_test(TestState.draft)
|
test2 = _make_test(TestState.draft)
|
||||||
_validate_upload_permission(test2, TeamSide.blue, admin)
|
validate_upload_permission(test2, TeamSide.blue, admin.role)
|
||||||
|
|
||||||
print(" [PASS] Admin can upload any team in any state")
|
print(" [PASS] Admin can upload any team in any state")
|
||||||
|
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ from app.routers.test_templates import (
|
|||||||
toggle_template_active,
|
toggle_template_active,
|
||||||
template_stats,
|
template_stats,
|
||||||
)
|
)
|
||||||
from app.routers.tests import create_test_from_template
|
from app.services.test_crud_service import create_test_from_template as crud_create_from_template
|
||||||
from app.schemas.test_template import TestTemplateCreate
|
from app.schemas.test_template import TestTemplateCreate
|
||||||
|
|
||||||
|
|
||||||
@@ -174,7 +174,8 @@ def test_get_templates_by_technique():
|
|||||||
|
|
||||||
def test_instantiate_template():
|
def test_instantiate_template():
|
||||||
"""POST /tests/from-template creates a test pre-filled from template data."""
|
"""POST /tests/from-template creates a test pre-filled from template data."""
|
||||||
source = inspect.getsource(create_test_from_template)
|
# Template field copying lives in the service; router delegates to it
|
||||||
|
source = inspect.getsource(crud_create_from_template)
|
||||||
|
|
||||||
# Verify it reads from template and copies fields
|
# Verify it reads from template and copies fields
|
||||||
assert "template" in source, "Must reference template"
|
assert "template" in source, "Must reference template"
|
||||||
|
|||||||
@@ -419,56 +419,57 @@ def test_dual_validation_red_approves_blue_rejects(mock_log):
|
|||||||
|
|
||||||
def test_evidence_team_separation():
|
def test_evidence_team_separation():
|
||||||
"""Verify evidence router logic separates red and blue evidence correctly."""
|
"""Verify evidence router logic separates red and blue evidence correctly."""
|
||||||
from app.routers.evidence import _validate_upload_permission, _RED_EDITABLE_STATES, _BLUE_EDITABLE_STATES
|
from app.domain.errors import BusinessRuleViolation, PermissionViolation
|
||||||
|
from app.models.enums import TeamSide
|
||||||
|
from app.services.evidence_service import validate_upload_permission
|
||||||
|
|
||||||
# Red tech can upload red evidence in draft
|
# Red tech can upload red evidence in draft
|
||||||
test = _make_test(TestState.draft)
|
test = _make_test(TestState.draft)
|
||||||
red_user = _make_user("red_tech")
|
red_user = _make_user("red_tech")
|
||||||
red_user.role = "red_tech"
|
red_user.role = "red_tech"
|
||||||
from app.models.enums import TeamSide
|
validate_upload_permission(test, TeamSide.red, red_user.role) # should not raise
|
||||||
_validate_upload_permission(test, TeamSide.red, red_user) # should not raise
|
|
||||||
|
|
||||||
# Red tech can upload red evidence in red_executing
|
# Red tech can upload red evidence in red_executing
|
||||||
test.state = TestState.red_executing
|
test.state = TestState.red_executing
|
||||||
_validate_upload_permission(test, TeamSide.red, red_user) # should not raise
|
validate_upload_permission(test, TeamSide.red, red_user.role) # should not raise
|
||||||
|
|
||||||
# Red tech CANNOT upload red evidence in blue_evaluating
|
# Red tech CANNOT upload red evidence in blue_evaluating (state violation -> 400)
|
||||||
test.state = TestState.blue_evaluating
|
test.state = TestState.blue_evaluating
|
||||||
try:
|
try:
|
||||||
_validate_upload_permission(test, TeamSide.red, red_user)
|
validate_upload_permission(test, TeamSide.red, red_user.role)
|
||||||
assert False, "Should have raised HTTPException"
|
assert False, "Should have raised BusinessRuleViolation"
|
||||||
except HTTPException as exc:
|
except BusinessRuleViolation:
|
||||||
assert exc.status_code == 400
|
pass
|
||||||
|
|
||||||
# Red tech CANNOT upload blue evidence
|
# Red tech CANNOT upload blue evidence (role violation -> 403)
|
||||||
test.state = TestState.blue_evaluating
|
test.state = TestState.blue_evaluating
|
||||||
try:
|
try:
|
||||||
_validate_upload_permission(test, TeamSide.blue, red_user)
|
validate_upload_permission(test, TeamSide.blue, red_user.role)
|
||||||
assert False, "Should have raised HTTPException"
|
assert False, "Should have raised PermissionViolation"
|
||||||
except HTTPException as exc:
|
except PermissionViolation:
|
||||||
assert exc.status_code == 403
|
pass
|
||||||
|
|
||||||
# Blue tech can upload blue evidence in blue_evaluating
|
# Blue tech can upload blue evidence in blue_evaluating
|
||||||
test.state = TestState.blue_evaluating
|
test.state = TestState.blue_evaluating
|
||||||
blue_user = _make_user("blue_tech")
|
blue_user = _make_user("blue_tech")
|
||||||
blue_user.role = "blue_tech"
|
blue_user.role = "blue_tech"
|
||||||
_validate_upload_permission(test, TeamSide.blue, blue_user) # should not raise
|
validate_upload_permission(test, TeamSide.blue, blue_user.role) # should not raise
|
||||||
|
|
||||||
# Blue tech CANNOT upload blue evidence in draft
|
# Blue tech CANNOT upload blue evidence in draft (state violation -> 400)
|
||||||
test.state = TestState.draft
|
test.state = TestState.draft
|
||||||
try:
|
try:
|
||||||
_validate_upload_permission(test, TeamSide.blue, blue_user)
|
validate_upload_permission(test, TeamSide.blue, blue_user.role)
|
||||||
assert False, "Should have raised HTTPException"
|
assert False, "Should have raised BusinessRuleViolation"
|
||||||
except HTTPException as exc:
|
except BusinessRuleViolation:
|
||||||
assert exc.status_code == 400
|
pass
|
||||||
|
|
||||||
# Blue tech CANNOT upload red evidence
|
# Blue tech CANNOT upload red evidence (role violation -> 403)
|
||||||
test.state = TestState.draft
|
test.state = TestState.draft
|
||||||
try:
|
try:
|
||||||
_validate_upload_permission(test, TeamSide.red, blue_user)
|
validate_upload_permission(test, TeamSide.red, blue_user.role)
|
||||||
assert False, "Should have raised HTTPException"
|
assert False, "Should have raised PermissionViolation"
|
||||||
except HTTPException as exc:
|
except PermissionViolation:
|
||||||
assert exc.status_code == 403
|
pass
|
||||||
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
@@ -477,15 +478,15 @@ def test_evidence_team_separation():
|
|||||||
|
|
||||||
|
|
||||||
def test_red_edit_allowed_in_draft_and_red_executing():
|
def test_red_edit_allowed_in_draft_and_red_executing():
|
||||||
"""Verify the red update router checks that state is draft or red_executing."""
|
"""Verify the red update checks that state is draft or red_executing."""
|
||||||
from app.routers.tests import update_test_red
|
from app.services.test_crud_service import update_test_red
|
||||||
import inspect
|
import inspect
|
||||||
source = inspect.getsource(update_test_red)
|
source = inspect.getsource(update_test_red)
|
||||||
|
|
||||||
# The function must guard against states other than draft/red_executing
|
# The service must guard against states other than draft/red_executing
|
||||||
assert "draft" in source, "Red update must allow draft state"
|
assert "draft" in source, "Red update must allow draft state"
|
||||||
assert "red_executing" in source, "Red update must allow red_executing state"
|
assert "red_executing" in source, "Red update must allow red_executing state"
|
||||||
assert "400" in source or "HTTP_400_BAD_REQUEST" in source, "Red update must return 400 for invalid state"
|
assert "BusinessRuleViolation" in source, "Must raise domain exception for invalid state (mapped to 400)"
|
||||||
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
|
|||||||
+15
-15
@@ -2,30 +2,30 @@
|
|||||||
|
|
||||||
## Tier 1 — Quick Wins
|
## Tier 1 — Quick Wins
|
||||||
|
|
||||||
- [ ] QW-1: Wire existing repos into `techniques.py` router
|
- [x] QW-1: Wire existing repos into `techniques.py` router
|
||||||
- [ ] QW-2: Fix `audit_service` to follow UoW (no direct `db.commit()`)
|
- [~] QW-2: Fix `audit_service` to follow UoW — deferred, resolves naturally as routers adopt UoW
|
||||||
- [ ] QW-3: Consolidate `status_service` with `TechniqueEntity.recalculate_status()`
|
- [x] QW-3: Consolidate `status_service` with `TechniqueEntity.recalculate_status()`
|
||||||
- [ ] QW-4: Remove remaining `HTTPException` from services
|
- [x] QW-4: Remove remaining `HTTPException` from services — already resolved
|
||||||
|
|
||||||
## Tier 2 — Service Extraction (fat routers → thin routers + services)
|
## Tier 2 — Service Extraction (fat routers → thin routers + services)
|
||||||
|
|
||||||
- [ ] SE-1: Extract reports service from `reports.py`
|
- [x] SE-1: Extract reports service → `coverage_report_service.py`
|
||||||
- [ ] SE-2: Extract metrics service from `metrics.py`
|
- [x] SE-2: Extract metrics service → `metrics_query_service.py`
|
||||||
- [ ] SE-3: Extract compliance service from `compliance.py`
|
- [x] SE-3: Extract compliance service → `compliance_service.py`
|
||||||
- [ ] SE-4: Extract detection_rules service from `detection_rules.py`
|
- [x] SE-4: Extract detection_rules service → `detection_rule_service.py`
|
||||||
- [ ] SE-5: Extract threat_actors service from `threat_actors.py`
|
- [x] SE-5: Extract threat_actors service → `threat_actor_service.py`
|
||||||
|
|
||||||
## Tier 3 — Architectural Fixes
|
## Tier 3 — Architectural Fixes
|
||||||
|
|
||||||
- [ ] AF-1: Persist scoring weights in DB (replace mutable `settings`)
|
- [x] AF-1: Persist scoring weights in DB → `scoring_config` table + `scoring_config_service.py`
|
||||||
- [ ] AF-2: Slim `tests.py` router (CRUD to repo/service)
|
- [x] AF-2: Slim `tests.py` router → `test_crud_service.py`
|
||||||
- [ ] AF-3: Slim `evidence.py` router (permissions to domain)
|
- [x] AF-3: Slim `evidence.py` router → `evidence_service.py`
|
||||||
- [ ] AF-4: Slim `campaigns.py` router (CRUD to service)
|
- [x] AF-4: Slim `campaigns.py` router → `campaign_crud_service.py`
|
||||||
|
|
||||||
## Tier 4 — Polish
|
## Tier 4 — Polish
|
||||||
|
|
||||||
- [ ] P-1: Structured JSON logging
|
- [x] P-1: Structured JSON logging → `logging_config.py`
|
||||||
- [ ] P-2: Create architecture skill file for future agents
|
- [x] P-2: Create architecture skill file → `~/.cursor/skills/aegis-architecture/SKILL.md`
|
||||||
|
|
||||||
## Completed (prior sessions)
|
## Completed (prior sessions)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user