Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0eff48c768 | |||
| 764a2f7579 | |||
| f4c74230ec | |||
| 50b70704ae | |||
| 20738d11b3 | |||
| 4e3787d091 |
@@ -0,0 +1,37 @@
|
||||
"""add_scoring_config
|
||||
|
||||
Single-row table to persist scoring weights in the database,
|
||||
replacing the mutable in-process Settings approach.
|
||||
|
||||
Revision ID: b027scorecfg
|
||||
Revises: b026techidx
|
||||
Create Date: 2026-02-19 10:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "b027scorecfg"
|
||||
down_revision: Union[str, None] = "b026techidx"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"scoring_config",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column("weight_tests", sa.Float(), nullable=False, server_default="40.0"),
|
||||
sa.Column("weight_detection_rules", sa.Float(), nullable=False, server_default="20.0"),
|
||||
sa.Column("weight_d3fend", sa.Float(), nullable=False, server_default="15.0"),
|
||||
sa.Column("weight_freshness", sa.Float(), nullable=False, server_default="15.0"),
|
||||
sa.Column("weight_platform_diversity", sa.Float(), nullable=False, server_default="10.0"),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("scoring_config")
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Structured JSON logging configuration.
|
||||
|
||||
In **production** (``AEGIS_ENV=production``), emits one JSON object per
|
||||
line so that log aggregators (ELK, CloudWatch, Datadog) can ingest them
|
||||
without custom parsing.
|
||||
|
||||
In **development** (default), uses a human-readable text format for
|
||||
comfortable local work.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
class _JSONFormatter(logging.Formatter):
|
||||
"""Emit each log record as a single-line JSON object."""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
payload: dict = {
|
||||
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"message": record.getMessage(),
|
||||
}
|
||||
|
||||
if record.exc_info and record.exc_info[1] is not None:
|
||||
payload["exception"] = self.formatException(record.exc_info)
|
||||
|
||||
extra = getattr(record, "_extra", None)
|
||||
if extra:
|
||||
payload.update(extra)
|
||||
|
||||
return json.dumps(payload, default=str)
|
||||
|
||||
|
||||
_DEV_FORMAT = "%(asctime)s %(levelname)-8s %(name)s — %(message)s"
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
"""Configure the root logger based on the environment."""
|
||||
is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
level_name = os.environ.get("LOG_LEVEL", "INFO").upper()
|
||||
level = getattr(logging, level_name, logging.INFO)
|
||||
|
||||
root = logging.getLogger()
|
||||
root.setLevel(level)
|
||||
|
||||
if root.handlers:
|
||||
root.handlers.clear()
|
||||
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setLevel(level)
|
||||
|
||||
if is_production:
|
||||
handler.setFormatter(_JSONFormatter())
|
||||
else:
|
||||
handler.setFormatter(logging.Formatter(_DEV_FORMAT))
|
||||
|
||||
root.addHandler(handler)
|
||||
|
||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
+3
-4
@@ -47,10 +47,9 @@ from app.jobs.mitre_sync_job import start_scheduler, scheduler
|
||||
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
|
||||
|
||||
# ── Logging ───────────────────────────────────────────────────────────────
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)-8s %(name)s — %(message)s",
|
||||
)
|
||||
from app.logging_config import setup_logging
|
||||
|
||||
setup_logging()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
|
||||
@@ -19,6 +19,7 @@ from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueStat
|
||||
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
|
||||
from app.models.worklog import Worklog
|
||||
from app.models.osint_item import OsintItem
|
||||
from app.models.scoring_config import ScoringConfig
|
||||
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
|
||||
|
||||
__all__ = [
|
||||
@@ -31,6 +32,6 @@ __all__ = [
|
||||
"ComplianceFramework", "ComplianceControl", "ComplianceControlMapping",
|
||||
"CoverageSnapshot", "SnapshotTechniqueState",
|
||||
"JiraLink", "JiraLinkEntityType", "JiraSyncDirection",
|
||||
"Worklog", "OsintItem",
|
||||
"Worklog", "OsintItem", "ScoringConfig",
|
||||
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
"""ScoringConfig — single-row table for persisted scoring weights.
|
||||
|
||||
Replaces the mutable-settings approach where PATCH /scores/config
|
||||
mutated the in-process ``Settings`` object (lost on restart).
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Column, Float, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class ScoringConfig(Base):
|
||||
__tablename__ = "scoring_config"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
weight_tests = Column(Float, nullable=False, default=40.0)
|
||||
weight_detection_rules = Column(Float, nullable=False, default=20.0)
|
||||
weight_d3fend = Column(Float, nullable=False, default=15.0)
|
||||
weight_freshness = Column(Float, nullable=False, default=15.0)
|
||||
weight_platform_diversity = Column(Float, nullable=False, default=10.0)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||
@@ -7,26 +7,29 @@ test ordering, progress tracking, and threat actor integration.
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.user import User
|
||||
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES
|
||||
from app.models.test import Test
|
||||
from app.models.technique import Technique
|
||||
from app.models.threat_actor import ThreatActor
|
||||
from app.services.campaign_service import (
|
||||
validate_no_circular_dependency,
|
||||
get_campaign_progress,
|
||||
generate_campaign_from_threat_actor,
|
||||
TACTIC_TO_PHASE,
|
||||
from app.services.campaign_service import generate_campaign_from_threat_actor
|
||||
from app.services.campaign_crud_service import (
|
||||
add_test_to_campaign as crud_add_test,
|
||||
activate_campaign as crud_activate,
|
||||
complete_campaign as crud_complete,
|
||||
create_campaign as crud_create,
|
||||
get_campaign_detail as crud_get_detail,
|
||||
get_campaign_history as crud_get_history,
|
||||
get_campaign_progress_data as crud_get_progress,
|
||||
list_campaigns as crud_list,
|
||||
remove_test_from_campaign as crud_remove_test,
|
||||
schedule_campaign as crud_schedule,
|
||||
serialize_campaign,
|
||||
update_campaign as crud_update,
|
||||
)
|
||||
from app.services.campaign_scheduler_service import calculate_next_run
|
||||
from app.services.notification_service import create_notification
|
||||
from app.services.audit_service import log_action
|
||||
|
||||
@@ -67,89 +70,6 @@ class SchedulePayload(BaseModel):
|
||||
next_run_at: Optional[str] = None
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _serialize_campaign(db: Session, campaign: Campaign) -> dict:
|
||||
"""Serialize a campaign with its tests and progress."""
|
||||
progress = get_campaign_progress(db, campaign.id)
|
||||
|
||||
campaign_tests = (
|
||||
db.query(CampaignTest)
|
||||
.filter(CampaignTest.campaign_id == campaign.id)
|
||||
.order_by(CampaignTest.order_index)
|
||||
.all()
|
||||
)
|
||||
|
||||
tests = []
|
||||
for ct in campaign_tests:
|
||||
test = ct.test
|
||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None
|
||||
|
||||
tests.append({
|
||||
"id": str(ct.id),
|
||||
"test_id": str(ct.test_id),
|
||||
"order_index": ct.order_index,
|
||||
"depends_on": str(ct.depends_on) if ct.depends_on else None,
|
||||
"phase": ct.phase,
|
||||
"test_name": test.name if test else None,
|
||||
"test_state": test.state.value if test and test.state else None,
|
||||
"test_result": test.result.value if test and test.result else None,
|
||||
"technique_mitre_id": technique.mitre_id if technique else None,
|
||||
"technique_name": technique.name if technique else None,
|
||||
"platform": test.platform if test else None,
|
||||
})
|
||||
|
||||
actor = campaign.threat_actor
|
||||
|
||||
return {
|
||||
"id": str(campaign.id),
|
||||
"name": campaign.name,
|
||||
"description": campaign.description,
|
||||
"type": campaign.type,
|
||||
"status": campaign.status,
|
||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
"created_by": str(campaign.created_by) if campaign.created_by else None,
|
||||
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
|
||||
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
|
||||
"target_platform": campaign.target_platform,
|
||||
"tags": campaign.tags or [],
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
"is_recurring": campaign.is_recurring or False,
|
||||
"recurrence_pattern": campaign.recurrence_pattern,
|
||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
||||
"parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
|
||||
"tests": tests,
|
||||
"progress": progress,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
|
||||
"""Lightweight campaign serialization for list views."""
|
||||
progress = get_campaign_progress(db, campaign.id)
|
||||
actor = campaign.threat_actor
|
||||
|
||||
return {
|
||||
"id": str(campaign.id),
|
||||
"name": campaign.name,
|
||||
"description": campaign.description,
|
||||
"type": campaign.type,
|
||||
"status": campaign.status,
|
||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
"target_platform": campaign.target_platform,
|
||||
"tags": campaign.tags or [],
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
"is_recurring": campaign.is_recurring or False,
|
||||
"recurrence_pattern": campaign.recurrence_pattern,
|
||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
||||
"test_count": progress["total"],
|
||||
"completion_pct": progress["completion_pct"],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /campaigns — List campaigns with filters
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -166,28 +86,15 @@ def list_campaigns(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List campaigns with optional filters and pagination."""
|
||||
query = db.query(Campaign)
|
||||
|
||||
if type:
|
||||
query = query.filter(Campaign.type == type)
|
||||
if status:
|
||||
query = query.filter(Campaign.status == status)
|
||||
if threat_actor_id:
|
||||
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
|
||||
if search:
|
||||
from app.utils import escape_like
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
|
||||
|
||||
total = query.count()
|
||||
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [_serialize_campaign_summary(db, c) for c in campaigns],
|
||||
}
|
||||
return crud_list(
|
||||
db,
|
||||
type=type,
|
||||
status=status,
|
||||
threat_actor_id=threat_actor_id,
|
||||
search=search,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -201,30 +108,29 @@ def create_campaign(
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Create a new campaign."""
|
||||
campaign = Campaign(
|
||||
result = crud_create(
|
||||
db,
|
||||
creator_id=current_user.id,
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
type=payload.type,
|
||||
threat_actor_id=payload.threat_actor_id,
|
||||
target_platform=payload.target_platform,
|
||||
tags=payload.tags or [],
|
||||
created_by=current_user.id,
|
||||
scheduled_at=datetime.fromisoformat(payload.scheduled_at) if payload.scheduled_at else None,
|
||||
tags=payload.tags,
|
||||
scheduled_at=payload.scheduled_at,
|
||||
)
|
||||
db.add(campaign)
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="create_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
details={"name": campaign.name, "type": campaign.type},
|
||||
entity_id=result["id"],
|
||||
details={"name": payload.name, "type": payload.type},
|
||||
)
|
||||
db.commit()
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -238,11 +144,7 @@ def get_campaign(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get detailed campaign info including tests and progress."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
return crud_get_detail(db, campaign_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -257,37 +159,26 @@ def update_campaign(
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Update a campaign. Only allowed in draft or active state."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status not in ("draft", "active"):
|
||||
raise HTTPException(status_code=400, detail="Can only update draft or active campaigns")
|
||||
|
||||
# Check ownership or admin
|
||||
if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Only the creator or admin can update this campaign")
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
if "scheduled_at" in update_data and update_data["scheduled_at"]:
|
||||
update_data["scheduled_at"] = datetime.fromisoformat(update_data["scheduled_at"])
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(campaign, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
result = crud_update(
|
||||
db,
|
||||
campaign_id,
|
||||
updater_id=current_user.id,
|
||||
updater_role=current_user.role,
|
||||
**update_data,
|
||||
)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="update_campaign",
|
||||
entity_type="campaign",
|
||||
entity_id=campaign.id,
|
||||
entity_id=campaign_id,
|
||||
details={"updated_fields": list(update_data.keys())},
|
||||
)
|
||||
db.commit()
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -302,63 +193,16 @@ def add_test_to_campaign(
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Add a test to a campaign with optional ordering and dependency."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status not in ("draft", "active"):
|
||||
raise HTTPException(status_code=400, detail="Can only add tests to draft or active campaigns")
|
||||
|
||||
test = db.query(Test).filter(Test.id == payload.test_id).first()
|
||||
if not test:
|
||||
raise HTTPException(status_code=404, detail="Test not found")
|
||||
|
||||
# Calculate order_index if not provided
|
||||
if payload.order_index is not None:
|
||||
order_index = payload.order_index
|
||||
else:
|
||||
max_order = (
|
||||
db.query(CampaignTest.order_index)
|
||||
.filter(CampaignTest.campaign_id == campaign_id)
|
||||
.order_by(CampaignTest.order_index.desc())
|
||||
.first()
|
||||
)
|
||||
order_index = (max_order[0] + 1) if max_order else 0
|
||||
|
||||
depends_on = uuid.UUID(payload.depends_on) if payload.depends_on else None
|
||||
|
||||
# Validate circular dependency
|
||||
ct_id = uuid.uuid4()
|
||||
if depends_on:
|
||||
validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on)
|
||||
|
||||
# Auto-detect kill chain phase from the test's technique tactic if not provided
|
||||
phase = payload.phase
|
||||
if not phase and test.technique_id:
|
||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
||||
if technique and technique.tactic:
|
||||
phase = TACTIC_TO_PHASE.get(technique.tactic, None)
|
||||
|
||||
campaign_test = CampaignTest(
|
||||
id=ct_id,
|
||||
campaign_id=campaign_id,
|
||||
result = crud_add_test(
|
||||
db,
|
||||
campaign_id,
|
||||
test_id=payload.test_id,
|
||||
order_index=order_index,
|
||||
depends_on=depends_on,
|
||||
phase=phase,
|
||||
order_index=payload.order_index,
|
||||
depends_on=payload.depends_on,
|
||||
phase=payload.phase,
|
||||
)
|
||||
db.add(campaign_test)
|
||||
db.commit()
|
||||
db.refresh(campaign_test)
|
||||
|
||||
return {
|
||||
"id": str(campaign_test.id),
|
||||
"campaign_id": str(campaign_test.campaign_id),
|
||||
"test_id": str(campaign_test.test_id),
|
||||
"order_index": campaign_test.order_index,
|
||||
"depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
|
||||
"phase": campaign_test.phase,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -373,36 +217,8 @@ def remove_test_from_campaign(
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Remove a test from a campaign."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status not in ("draft", "active"):
|
||||
raise HTTPException(status_code=400, detail="Can only modify draft or active campaigns")
|
||||
|
||||
ct = (
|
||||
db.query(CampaignTest)
|
||||
.filter(
|
||||
CampaignTest.id == campaign_test_id,
|
||||
CampaignTest.campaign_id == campaign_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not ct:
|
||||
raise HTTPException(status_code=404, detail="Campaign test not found")
|
||||
|
||||
# Clear any references to this campaign_test
|
||||
dependents = (
|
||||
db.query(CampaignTest)
|
||||
.filter(CampaignTest.depends_on == campaign_test_id)
|
||||
.all()
|
||||
)
|
||||
for dep in dependents:
|
||||
dep.depends_on = None
|
||||
|
||||
db.delete(ct)
|
||||
crud_remove_test(db, campaign_id, campaign_test_id)
|
||||
db.commit()
|
||||
|
||||
return {"detail": "Test removed from campaign"}
|
||||
|
||||
|
||||
@@ -417,23 +233,10 @@ def activate_campaign(
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Activate a campaign, moving it from draft to active."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status != "draft":
|
||||
raise HTTPException(status_code=400, detail="Only draft campaigns can be activated")
|
||||
|
||||
# Verify campaign has at least one test
|
||||
test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
|
||||
if test_count == 0:
|
||||
raise HTTPException(status_code=400, detail="Campaign must have at least one test to activate")
|
||||
|
||||
campaign.status = "active"
|
||||
campaign = crud_activate(db, campaign_id)
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
# Notify relevant users
|
||||
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
|
||||
for user in red_techs:
|
||||
create_notification(
|
||||
@@ -455,7 +258,7 @@ def activate_campaign(
|
||||
details={"name": campaign.name},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -469,15 +272,7 @@ def complete_campaign(
|
||||
current_user: User = Depends(require_any_role("red_lead", "admin")),
|
||||
):
|
||||
"""Mark a campaign as completed."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
if campaign.status != "active":
|
||||
raise HTTPException(status_code=400, detail="Only active campaigns can be completed")
|
||||
|
||||
campaign.status = "completed"
|
||||
campaign.completed_at = datetime.utcnow()
|
||||
campaign = crud_complete(db, campaign_id)
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
@@ -490,7 +285,7 @@ def complete_campaign(
|
||||
details={"name": campaign.name},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -504,16 +299,7 @@ def get_campaign_progress_endpoint(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get progress statistics for a campaign."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
progress = get_campaign_progress(db, uuid.UUID(campaign_id))
|
||||
return {
|
||||
"campaign_id": str(campaign.id),
|
||||
"campaign_name": campaign.name,
|
||||
**progress,
|
||||
}
|
||||
return crud_get_progress(db, campaign_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -546,7 +332,7 @@ def generate_campaign_from_actor(
|
||||
details={"actor_id": actor_id, "campaign_name": campaign.name},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -564,31 +350,15 @@ def schedule_campaign(
|
||||
|
||||
Only the campaign creator or admin can change scheduling.
|
||||
"""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
# Check ownership or admin
|
||||
if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Only the creator or admin can configure scheduling")
|
||||
|
||||
campaign.is_recurring = payload.is_recurring
|
||||
|
||||
if payload.is_recurring:
|
||||
if payload.recurrence_pattern not in ("weekly", "monthly", "quarterly"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'",
|
||||
)
|
||||
campaign.recurrence_pattern = payload.recurrence_pattern
|
||||
if payload.next_run_at:
|
||||
campaign.next_run_at = datetime.fromisoformat(payload.next_run_at.replace("Z", "+00:00").replace("+00:00", ""))
|
||||
elif not campaign.next_run_at:
|
||||
campaign.next_run_at = calculate_next_run(datetime.utcnow(), payload.recurrence_pattern)
|
||||
else:
|
||||
campaign.recurrence_pattern = None
|
||||
campaign.next_run_at = None
|
||||
|
||||
campaign = crud_schedule(
|
||||
db,
|
||||
campaign_id,
|
||||
owner_id=current_user.id,
|
||||
owner_role=current_user.role,
|
||||
is_recurring=payload.is_recurring,
|
||||
recurrence_pattern=payload.recurrence_pattern,
|
||||
next_run_at=payload.next_run_at,
|
||||
)
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
|
||||
@@ -605,7 +375,7 @@ def schedule_campaign(
|
||||
},
|
||||
)
|
||||
|
||||
return _serialize_campaign(db, campaign)
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -619,30 +389,4 @@ def get_campaign_history(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List all child campaigns (execution history) of a recurring campaign."""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
children = (
|
||||
db.query(Campaign)
|
||||
.filter(Campaign.parent_campaign_id == campaign_id)
|
||||
.order_by(Campaign.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"campaign_id": str(campaign.id),
|
||||
"campaign_name": campaign.name,
|
||||
"items": [
|
||||
{
|
||||
"id": str(child.id),
|
||||
"name": child.name,
|
||||
"status": child.status,
|
||||
"test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
|
||||
"completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
|
||||
"created_at": child.created_at.isoformat() if child.created_at else None,
|
||||
"completed_at": child.completed_at.isoformat() if child.completed_at else None,
|
||||
}
|
||||
for child in children
|
||||
],
|
||||
}
|
||||
return crud_get_history(db, campaign_id)
|
||||
|
||||
+23
-174
@@ -24,52 +24,32 @@ import os
|
||||
import uuid as _uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.models.enums import TeamSide, TestState
|
||||
from app.models.enums import TeamSide
|
||||
from app.models.evidence import Evidence
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
from app.schemas.evidence import EvidenceOut
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.evidence_service import (
|
||||
get_evidence_or_raise,
|
||||
get_test_or_raise,
|
||||
list_evidence_for_test,
|
||||
MAX_UPLOAD_SIZE,
|
||||
validate_delete_permission,
|
||||
validate_file,
|
||||
validate_upload_permission,
|
||||
)
|
||||
from app.storage import get_presigned_url, upload_file
|
||||
|
||||
router = APIRouter(tags=["evidence"])
|
||||
|
||||
# States where red evidence can be uploaded / deleted
|
||||
_RED_EDITABLE_STATES = (TestState.draft, TestState.red_executing)
|
||||
# States where blue evidence can be uploaded / deleted
|
||||
_BLUE_EDITABLE_STATES = (TestState.blue_evaluating,)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upload safety limits
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Maximum upload size in bytes (default 50 MB)
|
||||
_MAX_UPLOAD_SIZE = 50 * 1024 * 1024
|
||||
|
||||
# Allowed file extensions (lowercase, with leading dot)
|
||||
_ALLOWED_EXTENSIONS: set[str] = {
|
||||
# Images / screenshots
|
||||
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg",
|
||||
# Documents
|
||||
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".csv", ".txt",
|
||||
".md", ".rtf", ".odt", ".ods",
|
||||
# Logs & captures
|
||||
".log", ".pcap", ".pcapng", ".evtx", ".json", ".xml",
|
||||
".yaml", ".yml", ".toml",
|
||||
# Archives (for bundled evidence)
|
||||
".zip", ".tar", ".gz", ".7z",
|
||||
# Other common evidence types
|
||||
".har", ".eml", ".msg",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# Helpers (router-specific: infrastructure / HTTP concerns)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
||||
@@ -87,85 +67,6 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
|
||||
)
|
||||
|
||||
|
||||
def _validate_upload_permission(
|
||||
test: Test,
|
||||
team: TeamSide,
|
||||
user: User,
|
||||
) -> None:
|
||||
"""Raise 403 if the user/team combination is not allowed in the current state."""
|
||||
# Admins bypass all checks
|
||||
if user.role == "admin":
|
||||
return
|
||||
|
||||
if team == TeamSide.red:
|
||||
if user.role not in ("red_tech", "red_lead"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only red_tech, red_lead or admin can upload red evidence",
|
||||
)
|
||||
if test.state not in _RED_EDITABLE_STATES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Cannot upload red evidence in '{test.state.value}' state "
|
||||
f"(allowed in: draft, red_executing)",
|
||||
)
|
||||
elif team == TeamSide.blue:
|
||||
if user.role not in ("blue_tech", "blue_lead"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only blue_tech, blue_lead or admin can upload blue evidence",
|
||||
)
|
||||
if test.state not in _BLUE_EDITABLE_STATES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Cannot upload blue evidence in '{test.state.value}' state "
|
||||
f"(allowed in: blue_evaluating)",
|
||||
)
|
||||
|
||||
|
||||
def _validate_delete_permission(
|
||||
test: Test,
|
||||
evidence: Evidence,
|
||||
user: User,
|
||||
) -> None:
|
||||
"""Raise 403 if the user cannot delete this evidence in the current state."""
|
||||
# No deletions in review / validated / rejected
|
||||
if test.state in (TestState.in_review, TestState.validated, TestState.rejected):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Cannot delete evidence when test is in '{test.state.value}' state",
|
||||
)
|
||||
|
||||
# Admin can delete in editable states
|
||||
if user.role == "admin":
|
||||
return
|
||||
|
||||
ev_team = evidence.team
|
||||
|
||||
if ev_team == TeamSide.red:
|
||||
if test.state not in _RED_EDITABLE_STATES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Cannot delete red evidence outside draft/red_executing",
|
||||
)
|
||||
if user.role not in ("red_tech", "red_lead") and evidence.uploaded_by != user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions to delete this evidence",
|
||||
)
|
||||
elif ev_team == TeamSide.blue:
|
||||
if test.state not in _BLUE_EDITABLE_STATES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Cannot delete blue evidence outside blue_evaluating",
|
||||
)
|
||||
if user.role not in ("blue_tech", "blue_lead") and evidence.uploaded_by != user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions to delete this evidence",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /tests/{test_id}/evidence — upload with team
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -189,36 +90,14 @@ async def upload_evidence(
|
||||
The ``team`` field (sent as form data) determines whether this is
|
||||
Red Team (attack) or Blue Team (detection) evidence.
|
||||
"""
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
if test is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Test not found",
|
||||
)
|
||||
test = get_test_or_raise(db, test_id)
|
||||
validate_upload_permission(test, team, current_user.role)
|
||||
|
||||
# Validate permissions
|
||||
_validate_upload_permission(test, team, current_user)
|
||||
|
||||
# 1. Validate file extension
|
||||
file_name = file.filename or "unnamed"
|
||||
_, ext = os.path.splitext(file_name)
|
||||
if ext.lower() not in _ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File type '{ext}' is not allowed. "
|
||||
f"Permitted types: {', '.join(sorted(_ALLOWED_EXTENSIONS))}",
|
||||
)
|
||||
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||
validate_file(file_name, len(content))
|
||||
|
||||
# 2. Read content with size limit
|
||||
content = await file.read(_MAX_UPLOAD_SIZE + 1)
|
||||
if len(content) > _MAX_UPLOAD_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||
detail=f"File exceeds maximum upload size of "
|
||||
f"{_MAX_UPLOAD_SIZE // (1024 * 1024)} MB",
|
||||
)
|
||||
|
||||
# 3. Hash
|
||||
# Hash
|
||||
sha256 = hashlib.sha256(content).hexdigest()
|
||||
|
||||
# 4. Object key (sanitise filename to prevent path traversal in storage)
|
||||
@@ -273,19 +152,8 @@ def list_evidence(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""List all evidences for a test, optionally filtered by team."""
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
if test is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Test not found",
|
||||
)
|
||||
|
||||
query = db.query(Evidence).filter(Evidence.test_id == test_id)
|
||||
|
||||
if team:
|
||||
query = query.filter(Evidence.team == team)
|
||||
|
||||
evidences = query.order_by(Evidence.uploaded_at.desc()).all()
|
||||
get_test_or_raise(db, test_id)
|
||||
evidences = list_evidence_for_test(db, test_id, team=team)
|
||||
return [_evidence_to_out(e) for e in evidences]
|
||||
|
||||
|
||||
@@ -301,13 +169,7 @@ def get_evidence(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return evidence metadata together with a presigned download URL."""
|
||||
evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
|
||||
if evidence is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Evidence not found",
|
||||
)
|
||||
|
||||
evidence = get_evidence_or_raise(db, evidence_id)
|
||||
return _evidence_to_out(evidence)
|
||||
|
||||
|
||||
@@ -329,22 +191,9 @@ def delete_evidence(
|
||||
- Blue evidence: ``blue_evaluating``
|
||||
- No deletions in ``in_review``, ``validated``, ``rejected``
|
||||
"""
|
||||
evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
|
||||
if evidence is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Evidence not found",
|
||||
)
|
||||
|
||||
test = db.query(Test).filter(Test.id == evidence.test_id).first()
|
||||
if test is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Parent test not found",
|
||||
)
|
||||
|
||||
# Permission checks
|
||||
_validate_delete_permission(test, evidence, current_user)
|
||||
evidence = get_evidence_or_raise(db, evidence_id)
|
||||
test = get_test_or_raise(db, evidence.test_id)
|
||||
validate_delete_permission(test, evidence, current_user.role, current_user.id)
|
||||
|
||||
# Audit before deletion
|
||||
log_action(
|
||||
|
||||
@@ -14,7 +14,6 @@ from app.dependencies.auth import get_current_user, require_role
|
||||
from app.models.user import User
|
||||
from app.models.technique import Technique
|
||||
from app.models.threat_actor import ThreatActor
|
||||
from app.config import settings
|
||||
from app.services.scoring_service import (
|
||||
calculate_technique_score,
|
||||
calculate_tactic_score,
|
||||
@@ -22,6 +21,10 @@ from app.services.scoring_service import (
|
||||
calculate_organization_score,
|
||||
get_score_history,
|
||||
)
|
||||
from app.services.scoring_config_service import (
|
||||
get_weights_dict,
|
||||
update_scoring_weights,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/scores", tags=["scores"])
|
||||
|
||||
@@ -117,79 +120,45 @@ def score_history(
|
||||
|
||||
@router.get("/config")
|
||||
def get_scoring_config(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Get current scoring weights (admin only)."""
|
||||
return {
|
||||
"weights": {
|
||||
"tests": settings.SCORING_WEIGHT_TESTS,
|
||||
"detection_rules": settings.SCORING_WEIGHT_DETECTION_RULES,
|
||||
"d3fend": settings.SCORING_WEIGHT_D3FEND,
|
||||
"freshness": settings.SCORING_WEIGHT_FRESHNESS,
|
||||
"platform_diversity": settings.SCORING_WEIGHT_PLATFORM_DIVERSITY,
|
||||
},
|
||||
"total": (
|
||||
settings.SCORING_WEIGHT_TESTS
|
||||
+ settings.SCORING_WEIGHT_DETECTION_RULES
|
||||
+ settings.SCORING_WEIGHT_D3FEND
|
||||
+ settings.SCORING_WEIGHT_FRESHNESS
|
||||
+ settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
|
||||
),
|
||||
}
|
||||
return get_weights_dict(db)
|
||||
|
||||
|
||||
# ── PATCH /scores/config ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class ScoringConfigUpdate(BaseModel):
|
||||
tests: Optional[int] = None
|
||||
detection_rules: Optional[int] = None
|
||||
d3fend: Optional[int] = None
|
||||
freshness: Optional[int] = None
|
||||
platform_diversity: Optional[int] = None
|
||||
tests: Optional[float] = None
|
||||
detection_rules: Optional[float] = None
|
||||
d3fend: Optional[float] = None
|
||||
freshness: Optional[float] = None
|
||||
platform_diversity: Optional[float] = None
|
||||
|
||||
|
||||
@router.patch("/config")
|
||||
def update_scoring_config(
|
||||
payload: ScoringConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_role("admin")),
|
||||
):
|
||||
"""Update scoring weights (admin only).
|
||||
|
||||
Note: Since we're using Opcion A (env vars / Settings), changes
|
||||
are applied at runtime but won't persist across restarts unless
|
||||
the .env file is also updated. For production, consider migrating
|
||||
to Option B (database table).
|
||||
Weights are persisted in the database and survive restarts.
|
||||
Validation enforces that all weights are non-negative and sum to 100.
|
||||
"""
|
||||
if payload.tests is not None:
|
||||
settings.SCORING_WEIGHT_TESTS = payload.tests
|
||||
if payload.detection_rules is not None:
|
||||
settings.SCORING_WEIGHT_DETECTION_RULES = payload.detection_rules
|
||||
if payload.d3fend is not None:
|
||||
settings.SCORING_WEIGHT_D3FEND = payload.d3fend
|
||||
if payload.freshness is not None:
|
||||
settings.SCORING_WEIGHT_FRESHNESS = payload.freshness
|
||||
if payload.platform_diversity is not None:
|
||||
settings.SCORING_WEIGHT_PLATFORM_DIVERSITY = payload.platform_diversity
|
||||
result = update_scoring_weights(
|
||||
db,
|
||||
tests=payload.tests,
|
||||
detection_rules=payload.detection_rules,
|
||||
d3fend=payload.d3fend,
|
||||
freshness=payload.freshness,
|
||||
platform_diversity=payload.platform_diversity,
|
||||
)
|
||||
|
||||
# Weights changed — bust the score cache
|
||||
from app.services.score_cache import invalidate
|
||||
invalidate()
|
||||
|
||||
return {
|
||||
"message": "Scoring config updated",
|
||||
"weights": {
|
||||
"tests": settings.SCORING_WEIGHT_TESTS,
|
||||
"detection_rules": settings.SCORING_WEIGHT_DETECTION_RULES,
|
||||
"d3fend": settings.SCORING_WEIGHT_D3FEND,
|
||||
"freshness": settings.SCORING_WEIGHT_FRESHNESS,
|
||||
"platform_diversity": settings.SCORING_WEIGHT_PLATFORM_DIVERSITY,
|
||||
},
|
||||
"total": (
|
||||
settings.SCORING_WEIGHT_TESTS
|
||||
+ settings.SCORING_WEIGHT_DETECTION_RULES
|
||||
+ settings.SCORING_WEIGHT_D3FEND
|
||||
+ settings.SCORING_WEIGHT_FRESHNESS
|
||||
+ settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
|
||||
),
|
||||
}
|
||||
return {"message": "Scoring config updated", **result}
|
||||
|
||||
+65
-203
@@ -22,15 +22,11 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user, require_any_role
|
||||
from app.models.audit import AuditLog
|
||||
from app.models.enums import TestState, TeamSide
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.enums import TestState
|
||||
from app.models.user import User
|
||||
from app.schemas.test import (
|
||||
TestCreate,
|
||||
@@ -46,6 +42,18 @@ from app.schemas.test_template import TestTemplateInstantiate
|
||||
from app.domain.unit_of_work import UnitOfWork
|
||||
from app.services.audit_service import log_action
|
||||
from app.services.status_service import recalculate_technique_status
|
||||
from app.services.test_crud_service import (
|
||||
create_test as crud_create_test,
|
||||
create_test_from_template as crud_create_from_template,
|
||||
get_test_detail as crud_get_test_detail,
|
||||
get_test_or_raise as crud_get_test_or_raise,
|
||||
get_test_timeline as crud_get_test_timeline,
|
||||
get_test_with_technique as crud_get_test_with_technique,
|
||||
list_tests as crud_list_tests,
|
||||
update_test as crud_update_test,
|
||||
update_test_blue as crud_update_test_blue,
|
||||
update_test_red as crud_update_test_red,
|
||||
)
|
||||
from app.services.test_workflow_service import (
|
||||
start_execution as wf_start_execution,
|
||||
submit_red_evidence as wf_submit_red,
|
||||
@@ -62,29 +70,6 @@ from app.services.test_workflow_service import (
|
||||
router = APIRouter(prefix="/tests", tags=["tests"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_test_or_404(db: Session, test_id: uuid.UUID) -> Test:
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
if test is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
|
||||
return test
|
||||
|
||||
|
||||
def _get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
|
||||
test = (
|
||||
db.query(Test)
|
||||
.options(joinedload(Test.technique))
|
||||
.filter(Test.id == test_id)
|
||||
.first()
|
||||
)
|
||||
if test is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
|
||||
return test
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /tests — list with filters
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -105,30 +90,16 @@ def list_tests(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return a paginated list of tests, optionally filtered by state, technique, platform or creator."""
|
||||
query = db.query(Test).options(joinedload(Test.technique))
|
||||
|
||||
if state:
|
||||
query = query.filter(Test.state == state)
|
||||
if technique_id:
|
||||
query = query.filter(Test.technique_id == technique_id)
|
||||
if platform:
|
||||
from app.utils import escape_like
|
||||
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
|
||||
if created_by:
|
||||
query = query.filter(Test.created_by == created_by)
|
||||
if pending_validation_side == "red":
|
||||
query = query.filter(
|
||||
Test.state == TestState.in_review,
|
||||
Test.red_validation_status.in_(["pending", None]),
|
||||
)
|
||||
elif pending_validation_side == "blue":
|
||||
query = query.filter(
|
||||
Test.state == TestState.in_review,
|
||||
Test.blue_validation_status.in_(["pending", None]),
|
||||
)
|
||||
|
||||
tests = query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
|
||||
return tests
|
||||
return crud_list_tests(
|
||||
db,
|
||||
state=state,
|
||||
technique_id=technique_id,
|
||||
platform=platform,
|
||||
created_by=created_by,
|
||||
pending_validation_side=pending_validation_side,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -150,20 +121,14 @@ def create_test(
|
||||
|
||||
``created_by`` is set automatically and ``state`` defaults to *draft*.
|
||||
"""
|
||||
technique = db.query(Technique).filter(Technique.id == payload.technique_id).first()
|
||||
if technique is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Technique with id '{payload.technique_id}' not found",
|
||||
with UnitOfWork(db) as uow:
|
||||
test = crud_create_test(
|
||||
db,
|
||||
technique_id=payload.technique_id,
|
||||
creator_id=current_user.id,
|
||||
**payload.model_dump(exclude={"technique_id"}),
|
||||
)
|
||||
|
||||
test = Test(
|
||||
**payload.model_dump(),
|
||||
created_by=current_user.id,
|
||||
state=TestState.draft,
|
||||
)
|
||||
db.add(test)
|
||||
db.commit()
|
||||
uow.commit()
|
||||
db.refresh(test)
|
||||
|
||||
log_action(
|
||||
@@ -197,43 +162,14 @@ def create_test_from_template(
|
||||
|
||||
The template's fields are copied into the new test as starting data.
|
||||
"""
|
||||
template = db.query(TestTemplate).filter(TestTemplate.id == payload.template_id).first()
|
||||
if template is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"TestTemplate with id '{payload.template_id}' not found",
|
||||
with UnitOfWork(db) as uow:
|
||||
test = crud_create_from_template(
|
||||
db,
|
||||
template_id=payload.template_id,
|
||||
technique_id_or_mitre=payload.technique_id,
|
||||
creator_id=current_user.id,
|
||||
)
|
||||
|
||||
# Resolve technique_id: accept both UUID and MITRE ID (e.g. "T1059.001")
|
||||
technique = None
|
||||
try:
|
||||
technique_uuid = uuid.UUID(payload.technique_id)
|
||||
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if technique is None:
|
||||
technique = db.query(Technique).filter(Technique.mitre_id == payload.technique_id).first()
|
||||
|
||||
if technique is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Technique '{payload.technique_id}' not found",
|
||||
)
|
||||
|
||||
test = Test(
|
||||
technique_id=technique.id,
|
||||
name=template.name,
|
||||
description=template.description,
|
||||
platform=template.platform,
|
||||
procedure_text=template.attack_procedure,
|
||||
tool_used=template.tool_suggested,
|
||||
remediation_steps=template.suggested_remediation,
|
||||
created_by=current_user.id,
|
||||
state=TestState.draft,
|
||||
)
|
||||
db.add(test)
|
||||
db.commit()
|
||||
uow.commit()
|
||||
db.refresh(test)
|
||||
|
||||
log_action(
|
||||
@@ -244,7 +180,7 @@ def create_test_from_template(
|
||||
entity_id=test.id,
|
||||
details={
|
||||
"name": test.name,
|
||||
"template_id": str(template.id),
|
||||
"template_id": str(payload.template_id),
|
||||
"technique_id": str(test.technique_id),
|
||||
},
|
||||
)
|
||||
@@ -264,20 +200,7 @@ def get_test(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return full details for a single test, including its evidences."""
|
||||
test = (
|
||||
db.query(Test)
|
||||
.options(joinedload(Test.evidences))
|
||||
.filter(Test.id == test_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if test is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Test not found",
|
||||
)
|
||||
|
||||
return test
|
||||
return crud_get_test_detail(db, test_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -297,29 +220,16 @@ def update_test(
|
||||
Only leads or admins can update general test fields.
|
||||
The test must be in ``draft`` or ``rejected`` state.
|
||||
"""
|
||||
test = _get_test_or_404(db, test_id)
|
||||
|
||||
if current_user.role != "admin" and test.created_by != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={"message": "Only the test creator or an admin can update this test", "code": "FORBIDDEN"},
|
||||
)
|
||||
|
||||
if test.state not in (TestState.draft, TestState.rejected):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"message": f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)",
|
||||
"code": "INVALID_STATE",
|
||||
"current_state": test.state.value,
|
||||
},
|
||||
)
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(test, field, value)
|
||||
|
||||
db.commit()
|
||||
with UnitOfWork(db) as uow:
|
||||
test = crud_update_test(
|
||||
db,
|
||||
test_id,
|
||||
updater_id=current_user.id,
|
||||
updater_role=current_user.role,
|
||||
**update_data,
|
||||
)
|
||||
uow.commit()
|
||||
db.refresh(test)
|
||||
|
||||
log_action(
|
||||
@@ -347,23 +257,10 @@ def update_test_red(
|
||||
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
||||
):
|
||||
"""Red Team updates their fields (allowed in ``draft`` and ``red_executing``)."""
|
||||
test = _get_test_or_404(db, test_id)
|
||||
|
||||
if test.state not in (TestState.draft, TestState.red_executing):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"message": f"Cannot update red fields in '{test.state.value}' state (must be draft or red_executing)",
|
||||
"code": "INVALID_STATE",
|
||||
"current_state": test.state.value,
|
||||
},
|
||||
)
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(test, field, value)
|
||||
|
||||
db.commit()
|
||||
with UnitOfWork(db) as uow:
|
||||
test = crud_update_test_red(db, test_id, **update_data)
|
||||
uow.commit()
|
||||
db.refresh(test)
|
||||
|
||||
log_action(
|
||||
@@ -391,23 +288,10 @@ def update_test_blue(
|
||||
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
||||
):
|
||||
"""Blue Team updates their fields (allowed only in ``blue_evaluating``)."""
|
||||
test = _get_test_or_404(db, test_id)
|
||||
|
||||
if test.state != TestState.blue_evaluating:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"message": f"Cannot update blue fields in '{test.state.value}' state (must be blue_evaluating)",
|
||||
"code": "INVALID_STATE",
|
||||
"current_state": test.state.value,
|
||||
},
|
||||
)
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(test, field, value)
|
||||
|
||||
db.commit()
|
||||
with UnitOfWork(db) as uow:
|
||||
test = crud_update_test_blue(db, test_id, **update_data)
|
||||
uow.commit()
|
||||
db.refresh(test)
|
||||
|
||||
log_action(
|
||||
@@ -434,7 +318,7 @@ def start_execution(
|
||||
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
||||
):
|
||||
"""Move a test from ``draft`` to ``red_executing``."""
|
||||
test = _get_test_or_404(db, test_id)
|
||||
test = crud_get_test_or_raise(db, test_id)
|
||||
with UnitOfWork(db) as uow:
|
||||
test = wf_start_execution(db, test, current_user)
|
||||
uow.commit()
|
||||
@@ -454,7 +338,7 @@ def submit_red(
|
||||
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
|
||||
):
|
||||
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
|
||||
test = _get_test_or_404(db, test_id)
|
||||
test = crud_get_test_or_raise(db, test_id)
|
||||
with UnitOfWork(db) as uow:
|
||||
test = wf_submit_red(db, test, current_user)
|
||||
uow.commit()
|
||||
@@ -474,7 +358,7 @@ def submit_blue(
|
||||
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
|
||||
):
|
||||
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
|
||||
test = _get_test_or_404(db, test_id)
|
||||
test = crud_get_test_or_raise(db, test_id)
|
||||
with UnitOfWork(db) as uow:
|
||||
test = wf_submit_blue(db, test, current_user)
|
||||
uow.commit()
|
||||
@@ -494,7 +378,7 @@ def pause_timer(
|
||||
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
|
||||
test = _get_test_or_404(db, test_id)
|
||||
test = crud_get_test_or_raise(db, test_id)
|
||||
with UnitOfWork(db) as uow:
|
||||
test = wf_pause_timer(db, test, current_user)
|
||||
uow.commit()
|
||||
@@ -514,7 +398,7 @@ def resume_timer(
|
||||
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
|
||||
):
|
||||
"""Resume the paused timer for the current phase."""
|
||||
test = _get_test_or_404(db, test_id)
|
||||
test = crud_get_test_or_raise(db, test_id)
|
||||
with UnitOfWork(db) as uow:
|
||||
test = wf_resume_timer(db, test, current_user)
|
||||
uow.commit()
|
||||
@@ -535,7 +419,7 @@ def validate_red(
|
||||
current_user: User = Depends(require_any_role("red_lead")),
|
||||
):
|
||||
"""Red Lead approves or rejects the red side of a test."""
|
||||
test = _get_test_with_technique(db, test_id)
|
||||
test = crud_get_test_with_technique(db, test_id)
|
||||
with UnitOfWork(db) as uow:
|
||||
test = wf_validate_red(
|
||||
db, test, current_user,
|
||||
@@ -562,7 +446,7 @@ def validate_blue(
|
||||
current_user: User = Depends(require_any_role("blue_lead")),
|
||||
):
|
||||
"""Blue Lead approves or rejects the blue side of a test."""
|
||||
test = _get_test_with_technique(db, test_id)
|
||||
test = crud_get_test_with_technique(db, test_id)
|
||||
with UnitOfWork(db) as uow:
|
||||
test = wf_validate_blue(
|
||||
db, test, current_user,
|
||||
@@ -588,7 +472,7 @@ def reopen(
|
||||
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
|
||||
):
|
||||
"""Reopen a rejected test, moving it back to ``draft``."""
|
||||
test = _get_test_or_404(db, test_id)
|
||||
test = crud_get_test_or_raise(db, test_id)
|
||||
with UnitOfWork(db) as uow:
|
||||
test = wf_reopen(db, test, current_user)
|
||||
uow.commit()
|
||||
@@ -613,7 +497,7 @@ def update_remediation(
|
||||
When ``remediation_status`` transitions to ``'completed'``, an automatic
|
||||
re-test is created (subject to ``MAX_RETEST_COUNT``).
|
||||
"""
|
||||
test = _get_test_or_404(db, test_id)
|
||||
test = crud_get_test_or_raise(db, test_id)
|
||||
|
||||
old_remediation_status = test.remediation_status
|
||||
|
||||
@@ -653,29 +537,7 @@ def get_test_timeline(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return the chronological audit-log history for a test."""
|
||||
# Verify the test exists
|
||||
_get_test_or_404(db, test_id)
|
||||
|
||||
logs = (
|
||||
db.query(AuditLog)
|
||||
.filter(
|
||||
AuditLog.entity_type == "test",
|
||||
AuditLog.entity_id == str(test_id),
|
||||
)
|
||||
.order_by(AuditLog.timestamp.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(log.id),
|
||||
"action": log.action,
|
||||
"user_id": str(log.user_id) if log.user_id else None,
|
||||
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
|
||||
"details": log.details,
|
||||
}
|
||||
for log in logs
|
||||
]
|
||||
return crud_get_test_timeline(db, test_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,460 @@
|
||||
"""Campaign CRUD service — list, create, update, and business logic.
|
||||
|
||||
Framework-agnostic; uses domain exceptions from app.domain.errors.
|
||||
The router is responsible for HTTP concerns, auth, audit logging, and commit.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.errors import (
|
||||
BusinessRuleViolation,
|
||||
EntityNotFoundError,
|
||||
PermissionViolation,
|
||||
)
|
||||
from app.models.campaign import Campaign, CampaignTest
|
||||
from app.models.test import Test
|
||||
from app.models.technique import Technique
|
||||
from app.utils import escape_like
|
||||
from app.services.campaign_service import (
|
||||
get_campaign_progress,
|
||||
validate_no_circular_dependency,
|
||||
TACTIC_TO_PHASE,
|
||||
)
|
||||
from app.services.campaign_scheduler_service import calculate_next_run
|
||||
|
||||
|
||||
# ── Serialization helpers ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def serialize_campaign(db: Session, campaign: Campaign) -> dict:
|
||||
"""Serialize a campaign with its tests and progress."""
|
||||
progress = get_campaign_progress(db, campaign.id)
|
||||
|
||||
campaign_tests = (
|
||||
db.query(CampaignTest)
|
||||
.filter(CampaignTest.campaign_id == campaign.id)
|
||||
.order_by(CampaignTest.order_index)
|
||||
.all()
|
||||
)
|
||||
|
||||
tests = []
|
||||
for ct in campaign_tests:
|
||||
test = ct.test
|
||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None
|
||||
|
||||
tests.append({
|
||||
"id": str(ct.id),
|
||||
"test_id": str(ct.test_id),
|
||||
"order_index": ct.order_index,
|
||||
"depends_on": str(ct.depends_on) if ct.depends_on else None,
|
||||
"phase": ct.phase,
|
||||
"test_name": test.name if test else None,
|
||||
"test_state": test.state.value if test and test.state else None,
|
||||
"test_result": test.result.value if test and test.result else None,
|
||||
"technique_mitre_id": technique.mitre_id if technique else None,
|
||||
"technique_name": technique.name if technique else None,
|
||||
"platform": test.platform if test else None,
|
||||
})
|
||||
|
||||
actor = campaign.threat_actor
|
||||
|
||||
return {
|
||||
"id": str(campaign.id),
|
||||
"name": campaign.name,
|
||||
"description": campaign.description,
|
||||
"type": campaign.type,
|
||||
"status": campaign.status,
|
||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
"created_by": str(campaign.created_by) if campaign.created_by else None,
|
||||
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
|
||||
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
|
||||
"target_platform": campaign.target_platform,
|
||||
"tags": campaign.tags or [],
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
"is_recurring": campaign.is_recurring or False,
|
||||
"recurrence_pattern": campaign.recurrence_pattern,
|
||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
||||
"parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
|
||||
"tests": tests,
|
||||
"progress": progress,
|
||||
}
|
||||
|
||||
|
||||
def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
|
||||
"""Lightweight campaign serialization for list views."""
|
||||
progress = get_campaign_progress(db, campaign.id)
|
||||
actor = campaign.threat_actor
|
||||
|
||||
return {
|
||||
"id": str(campaign.id),
|
||||
"name": campaign.name,
|
||||
"description": campaign.description,
|
||||
"type": campaign.type,
|
||||
"status": campaign.status,
|
||||
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
|
||||
"threat_actor_name": actor.name if actor else None,
|
||||
"target_platform": campaign.target_platform,
|
||||
"tags": campaign.tags or [],
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
"is_recurring": campaign.is_recurring or False,
|
||||
"recurrence_pattern": campaign.recurrence_pattern,
|
||||
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
|
||||
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
|
||||
"test_count": progress["total"],
|
||||
"completion_pct": progress["completion_pct"],
|
||||
}
|
||||
|
||||
|
||||
# ── CRUD operations ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def list_campaigns(
|
||||
db: Session,
|
||||
*,
|
||||
type: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
threat_actor_id: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
) -> dict:
|
||||
"""Return a paginated list of campaigns with optional filters."""
|
||||
query = db.query(Campaign)
|
||||
|
||||
if type:
|
||||
query = query.filter(Campaign.type == type)
|
||||
if status:
|
||||
query = query.filter(Campaign.status == status)
|
||||
if threat_actor_id:
|
||||
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
|
||||
if search:
|
||||
pattern = f"%{escape_like(search)}%"
|
||||
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
|
||||
|
||||
total = query.count()
|
||||
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"items": [serialize_campaign_summary(db, c) for c in campaigns],
|
||||
}
|
||||
|
||||
|
||||
def create_campaign(
|
||||
db: Session,
|
||||
*,
|
||||
creator_id: uuid.UUID,
|
||||
name: str,
|
||||
description: Optional[str] = None,
|
||||
type: str = "custom",
|
||||
threat_actor_id: Optional[str] = None,
|
||||
target_platform: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
scheduled_at: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create a new campaign. Does not commit; caller commits."""
|
||||
campaign = Campaign(
|
||||
name=name,
|
||||
description=description,
|
||||
type=type,
|
||||
threat_actor_id=uuid.UUID(threat_actor_id) if threat_actor_id else None,
|
||||
target_platform=target_platform,
|
||||
tags=tags or [],
|
||||
created_by=creator_id,
|
||||
scheduled_at=datetime.fromisoformat(scheduled_at) if scheduled_at else None,
|
||||
)
|
||||
db.add(campaign)
|
||||
db.flush()
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
def get_campaign_detail(db: Session, campaign_id: str) -> dict:
|
||||
"""Get detailed campaign info including tests and progress.
|
||||
|
||||
Raises EntityNotFoundError if campaign not found.
|
||||
"""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
def update_campaign(
|
||||
db: Session,
|
||||
campaign_id: str,
|
||||
*,
|
||||
updater_id: uuid.UUID,
|
||||
updater_role: str,
|
||||
**fields,
|
||||
) -> dict:
|
||||
"""Update a campaign. Only allowed in draft or active state.
|
||||
|
||||
Raises EntityNotFoundError, BusinessRuleViolation, or PermissionViolation.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
if campaign.status not in ("draft", "active"):
|
||||
raise BusinessRuleViolation("Can only update draft or active campaigns")
|
||||
|
||||
if str(campaign.created_by) != str(updater_id) and updater_role != "admin":
|
||||
raise PermissionViolation("Only the creator or admin can update this campaign")
|
||||
|
||||
if "scheduled_at" in fields and fields["scheduled_at"]:
|
||||
fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"])
|
||||
|
||||
for field, value in fields.items():
|
||||
setattr(campaign, field, value)
|
||||
|
||||
db.flush()
|
||||
return serialize_campaign(db, campaign)
|
||||
|
||||
|
||||
def add_test_to_campaign(
|
||||
db: Session,
|
||||
campaign_id: str,
|
||||
*,
|
||||
test_id: str,
|
||||
order_index: Optional[int] = None,
|
||||
depends_on: Optional[str] = None,
|
||||
phase: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Add a test to a campaign with optional ordering and dependency.
|
||||
|
||||
Raises EntityNotFoundError for missing campaign or test.
|
||||
Raises BusinessRuleViolation for invalid state or circular dependency.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
if campaign.status not in ("draft", "active"):
|
||||
raise BusinessRuleViolation("Can only add tests to draft or active campaigns")
|
||||
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
if not test:
|
||||
raise EntityNotFoundError("Test", test_id)
|
||||
|
||||
if order_index is not None:
|
||||
final_order_index = order_index
|
||||
else:
|
||||
max_order = (
|
||||
db.query(CampaignTest.order_index)
|
||||
.filter(CampaignTest.campaign_id == campaign_id)
|
||||
.order_by(CampaignTest.order_index.desc())
|
||||
.first()
|
||||
)
|
||||
final_order_index = (max_order[0] + 1) if max_order else 0
|
||||
|
||||
depends_on_uuid = uuid.UUID(depends_on) if depends_on else None
|
||||
|
||||
ct_id = uuid.uuid4()
|
||||
if depends_on_uuid:
|
||||
validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on_uuid)
|
||||
|
||||
if not phase and test.technique_id:
|
||||
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
|
||||
if technique and technique.tactic:
|
||||
phase = TACTIC_TO_PHASE.get(technique.tactic, None)
|
||||
|
||||
campaign_test = CampaignTest(
|
||||
id=ct_id,
|
||||
campaign_id=campaign_id,
|
||||
test_id=test_id,
|
||||
order_index=final_order_index,
|
||||
depends_on=depends_on_uuid,
|
||||
phase=phase,
|
||||
)
|
||||
db.add(campaign_test)
|
||||
db.flush()
|
||||
|
||||
return {
|
||||
"id": str(campaign_test.id),
|
||||
"campaign_id": str(campaign_test.campaign_id),
|
||||
"test_id": str(campaign_test.test_id),
|
||||
"order_index": campaign_test.order_index,
|
||||
"depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
|
||||
"phase": campaign_test.phase,
|
||||
}
|
||||
|
||||
|
||||
def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: str) -> None:
|
||||
"""Remove a test from a campaign.
|
||||
|
||||
Raises EntityNotFoundError for missing campaign or campaign test.
|
||||
Raises BusinessRuleViolation for invalid state.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
if campaign.status not in ("draft", "active"):
|
||||
raise BusinessRuleViolation("Can only modify draft or active campaigns")
|
||||
|
||||
ct = (
|
||||
db.query(CampaignTest)
|
||||
.filter(
|
||||
CampaignTest.id == campaign_test_id,
|
||||
CampaignTest.campaign_id == campaign_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not ct:
|
||||
raise EntityNotFoundError("CampaignTest", campaign_test_id)
|
||||
|
||||
dep_id = uuid.UUID(campaign_test_id)
|
||||
dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == dep_id).all()
|
||||
for dep in dependents:
|
||||
dep.depends_on = None
|
||||
|
||||
db.delete(ct)
|
||||
db.flush()
|
||||
|
||||
|
||||
def activate_campaign(db: Session, campaign_id: str) -> Campaign:
|
||||
"""Activate a campaign, moving it from draft to active.
|
||||
|
||||
Raises EntityNotFoundError, BusinessRuleViolation.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
if campaign.status != "draft":
|
||||
raise BusinessRuleViolation("Only draft campaigns can be activated")
|
||||
|
||||
test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
|
||||
if test_count == 0:
|
||||
raise BusinessRuleViolation("Campaign must have at least one test to activate")
|
||||
|
||||
campaign.status = "active"
|
||||
db.flush()
|
||||
return campaign
|
||||
|
||||
|
||||
def complete_campaign(db: Session, campaign_id: str) -> Campaign:
|
||||
"""Mark a campaign as completed.
|
||||
|
||||
Raises EntityNotFoundError, BusinessRuleViolation.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
if campaign.status != "active":
|
||||
raise BusinessRuleViolation("Only active campaigns can be completed")
|
||||
|
||||
campaign.status = "completed"
|
||||
campaign.completed_at = datetime.utcnow()
|
||||
db.flush()
|
||||
return campaign
|
||||
|
||||
|
||||
def get_campaign_progress_data(db: Session, campaign_id: str) -> dict:
|
||||
"""Get progress statistics for a campaign.
|
||||
|
||||
Raises EntityNotFoundError if campaign not found.
|
||||
"""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
progress = get_campaign_progress(db, uuid.UUID(campaign_id))
|
||||
return {
|
||||
"campaign_id": str(campaign.id),
|
||||
"campaign_name": campaign.name,
|
||||
**progress,
|
||||
}
|
||||
|
||||
|
||||
def schedule_campaign(
|
||||
db: Session,
|
||||
campaign_id: str,
|
||||
*,
|
||||
owner_id: uuid.UUID,
|
||||
owner_role: str,
|
||||
is_recurring: bool,
|
||||
recurrence_pattern: Optional[str] = None,
|
||||
next_run_at: Optional[str] = None,
|
||||
) -> Campaign:
|
||||
"""Configure or update the recurrence schedule for a campaign.
|
||||
|
||||
Raises EntityNotFoundError, PermissionViolation, BusinessRuleViolation.
|
||||
Does not commit; caller commits.
|
||||
"""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
if str(campaign.created_by) != str(owner_id) and owner_role != "admin":
|
||||
raise PermissionViolation("Only the creator or admin can configure scheduling")
|
||||
|
||||
campaign.is_recurring = is_recurring
|
||||
|
||||
if is_recurring:
|
||||
if recurrence_pattern not in ("weekly", "monthly", "quarterly"):
|
||||
raise BusinessRuleViolation(
|
||||
"recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'"
|
||||
)
|
||||
campaign.recurrence_pattern = recurrence_pattern
|
||||
if next_run_at:
|
||||
campaign.next_run_at = datetime.fromisoformat(
|
||||
next_run_at.replace("Z", "+00:00").replace("+00:00", "")
|
||||
)
|
||||
elif not campaign.next_run_at:
|
||||
campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern)
|
||||
else:
|
||||
campaign.recurrence_pattern = None
|
||||
campaign.next_run_at = None
|
||||
|
||||
db.flush()
|
||||
return campaign
|
||||
|
||||
|
||||
def get_campaign_history(db: Session, campaign_id: str) -> dict:
|
||||
"""List all child campaigns (execution history) of a recurring campaign.
|
||||
|
||||
Raises EntityNotFoundError if campaign not found.
|
||||
"""
|
||||
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
|
||||
if not campaign:
|
||||
raise EntityNotFoundError("Campaign", campaign_id)
|
||||
|
||||
campaign_uuid = uuid.UUID(campaign_id)
|
||||
children = (
|
||||
db.query(Campaign)
|
||||
.filter(Campaign.parent_campaign_id == campaign_uuid)
|
||||
.order_by(Campaign.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"campaign_id": str(campaign.id),
|
||||
"campaign_name": campaign.name,
|
||||
"items": [
|
||||
{
|
||||
"id": str(child.id),
|
||||
"name": child.name,
|
||||
"status": child.status,
|
||||
"test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
|
||||
"completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
|
||||
"created_at": child.created_at.isoformat() if child.created_at else None,
|
||||
"completed_at": child.completed_at.isoformat() if child.completed_at else None,
|
||||
}
|
||||
for child in children
|
||||
],
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
"""Evidence service — permission validation, file validation, and query logic.
|
||||
|
||||
Framework-agnostic; uses domain exceptions from app.domain.errors.
|
||||
The router is responsible for HTTP concerns, file I/O, MinIO upload,
|
||||
audit logging, and response formatting.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.domain.errors import (
|
||||
BusinessRuleViolation,
|
||||
EntityNotFoundError,
|
||||
PermissionViolation,
|
||||
)
|
||||
from app.models.enums import TeamSide, TestState
|
||||
from app.models.evidence import Evidence
|
||||
from app.models.test import Test
|
||||
|
||||
# States where red evidence can be uploaded / deleted
|
||||
RED_EDITABLE_STATES = (TestState.draft, TestState.red_executing)
|
||||
# States where blue evidence can be uploaded / deleted
|
||||
BLUE_EDITABLE_STATES = (TestState.blue_evaluating,)
|
||||
|
||||
# Maximum upload size in bytes (50 MB)
|
||||
MAX_UPLOAD_SIZE = 50 * 1024 * 1024
|
||||
|
||||
# Allowed file extensions (lowercase, with leading dot)
|
||||
ALLOWED_EXTENSIONS: frozenset[str] = frozenset({
|
||||
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg",
|
||||
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".csv", ".txt",
|
||||
".md", ".rtf", ".odt", ".ods",
|
||||
".log", ".pcap", ".pcapng", ".evtx", ".json", ".xml",
|
||||
".yaml", ".yml", ".toml",
|
||||
".zip", ".tar", ".gz", ".7z",
|
||||
".har", ".eml", ".msg",
|
||||
})
|
||||
|
||||
|
||||
def validate_upload_permission(
|
||||
test: Test,
|
||||
team: TeamSide,
|
||||
user_role: str,
|
||||
) -> None:
|
||||
"""Validate that the user can upload evidence for the given team in the current state.
|
||||
|
||||
Raises:
|
||||
PermissionViolation: If user lacks role to upload for this team.
|
||||
BusinessRuleViolation: If test state does not allow uploading for this team.
|
||||
"""
|
||||
if user_role == "admin":
|
||||
return
|
||||
|
||||
if team == TeamSide.red:
|
||||
if user_role not in ("red_tech", "red_lead"):
|
||||
raise PermissionViolation(
|
||||
"Only red_tech, red_lead or admin can upload red evidence"
|
||||
)
|
||||
if test.state not in RED_EDITABLE_STATES:
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot upload red evidence in '{test.state.value}' state "
|
||||
"(allowed in: draft, red_executing)"
|
||||
)
|
||||
elif team == TeamSide.blue:
|
||||
if user_role not in ("blue_tech", "blue_lead"):
|
||||
raise PermissionViolation(
|
||||
"Only blue_tech, blue_lead or admin can upload blue evidence"
|
||||
)
|
||||
if test.state not in BLUE_EDITABLE_STATES:
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot upload blue evidence in '{test.state.value}' state "
|
||||
"(allowed in: blue_evaluating)"
|
||||
)
|
||||
|
||||
|
||||
def validate_delete_permission(
|
||||
test: Test,
|
||||
evidence: Evidence,
|
||||
user_role: str,
|
||||
user_id: uuid.UUID,
|
||||
) -> None:
|
||||
"""Validate that the user can delete this evidence in the current state.
|
||||
|
||||
Raises:
|
||||
PermissionViolation: If user cannot delete in this state or lacks permission.
|
||||
"""
|
||||
if test.state in (TestState.in_review, TestState.validated, TestState.rejected):
|
||||
raise PermissionViolation(
|
||||
f"Cannot delete evidence when test is in '{test.state.value}' state"
|
||||
)
|
||||
|
||||
if user_role == "admin":
|
||||
return
|
||||
|
||||
ev_team = evidence.team
|
||||
|
||||
if ev_team == TeamSide.red:
|
||||
if test.state not in RED_EDITABLE_STATES:
|
||||
raise PermissionViolation(
|
||||
"Cannot delete red evidence outside draft/red_executing"
|
||||
)
|
||||
if user_role not in ("red_tech", "red_lead") and evidence.uploaded_by != user_id:
|
||||
raise PermissionViolation(
|
||||
"Not enough permissions to delete this evidence"
|
||||
)
|
||||
elif ev_team == TeamSide.blue:
|
||||
if test.state not in BLUE_EDITABLE_STATES:
|
||||
raise PermissionViolation(
|
||||
"Cannot delete blue evidence outside blue_evaluating"
|
||||
)
|
||||
if user_role not in ("blue_tech", "blue_lead") and evidence.uploaded_by != user_id:
|
||||
raise PermissionViolation(
|
||||
"Not enough permissions to delete this evidence"
|
||||
)
|
||||
|
||||
|
||||
def validate_file(file_name: str, content_size: int) -> None:
|
||||
"""Validate file extension and size.
|
||||
|
||||
Raises:
|
||||
BusinessRuleViolation: If extension is not allowed or file exceeds size limit.
|
||||
"""
|
||||
_, ext = os.path.splitext(file_name)
|
||||
ext_lower = ext.lower() if ext else ""
|
||||
if ext_lower not in ALLOWED_EXTENSIONS:
|
||||
raise BusinessRuleViolation(
|
||||
f"File type '{ext}' is not allowed. "
|
||||
f"Permitted types: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
|
||||
)
|
||||
if content_size > MAX_UPLOAD_SIZE:
|
||||
raise BusinessRuleViolation(
|
||||
f"File exceeds maximum upload size of {MAX_UPLOAD_SIZE // (1024 * 1024)} MB"
|
||||
)
|
||||
|
||||
|
||||
def list_evidence_for_test(
|
||||
db: Session,
|
||||
test_id: uuid.UUID,
|
||||
*,
|
||||
team: TeamSide | str | None = None,
|
||||
) -> list[Evidence]:
|
||||
"""Return evidence for a test, optionally filtered by team."""
|
||||
query = db.query(Evidence).filter(Evidence.test_id == test_id)
|
||||
if team is not None:
|
||||
team_enum = TeamSide(team) if isinstance(team, str) else team
|
||||
query = query.filter(Evidence.team == team_enum)
|
||||
return query.order_by(Evidence.uploaded_at.desc()).all()
|
||||
|
||||
|
||||
def get_evidence_or_raise(db: Session, evidence_id: uuid.UUID) -> Evidence:
|
||||
"""Fetch evidence by ID. Raises EntityNotFoundError if not found."""
|
||||
evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
|
||||
if evidence is None:
|
||||
raise EntityNotFoundError("Evidence", str(evidence_id))
|
||||
return evidence
|
||||
|
||||
|
||||
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
|
||||
"""Fetch test by ID. Raises EntityNotFoundError if not found."""
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
if test is None:
|
||||
raise EntityNotFoundError("Test", str(test_id))
|
||||
return test
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Scoring configuration persistence service.
|
||||
|
||||
Reads and writes scoring weights from the ``scoring_config`` table.
|
||||
Falls back to environment-variable defaults (from ``Settings``) when
|
||||
no row has been persisted yet.
|
||||
|
||||
This module is framework-agnostic: no FastAPI imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.domain.value_objects.scoring_weights import ScoringWeights
|
||||
from app.models.scoring_config import ScoringConfig
|
||||
|
||||
|
||||
def get_scoring_weights(db: Session) -> ScoringWeights:
|
||||
"""Return the active scoring weights.
|
||||
|
||||
Reads the single ``scoring_config`` row. If the table is empty
|
||||
(first run or migration just applied), falls back to the values
|
||||
from the environment / ``Settings``.
|
||||
"""
|
||||
row = db.query(ScoringConfig).first()
|
||||
if row is not None:
|
||||
return ScoringWeights(
|
||||
tests=row.weight_tests,
|
||||
detection_rules=row.weight_detection_rules,
|
||||
d3fend=row.weight_d3fend,
|
||||
freshness=row.weight_freshness,
|
||||
platform_diversity=row.weight_platform_diversity,
|
||||
)
|
||||
|
||||
return ScoringWeights(
|
||||
tests=float(settings.SCORING_WEIGHT_TESTS),
|
||||
detection_rules=float(settings.SCORING_WEIGHT_DETECTION_RULES),
|
||||
d3fend=float(settings.SCORING_WEIGHT_D3FEND),
|
||||
freshness=float(settings.SCORING_WEIGHT_FRESHNESS),
|
||||
platform_diversity=float(settings.SCORING_WEIGHT_PLATFORM_DIVERSITY),
|
||||
)
|
||||
|
||||
|
||||
def update_scoring_weights(
|
||||
db: Session,
|
||||
*,
|
||||
tests: float | None = None,
|
||||
detection_rules: float | None = None,
|
||||
d3fend: float | None = None,
|
||||
freshness: float | None = None,
|
||||
platform_diversity: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Upsert scoring weights into the database.
|
||||
|
||||
Only provided fields are overwritten; ``None`` values keep the
|
||||
current (or default) value. Validates via ``ScoringWeights``
|
||||
before persisting.
|
||||
|
||||
Returns a dict with ``weights`` and ``total``.
|
||||
"""
|
||||
current = get_scoring_weights(db)
|
||||
|
||||
new = ScoringWeights(
|
||||
tests=tests if tests is not None else current.tests,
|
||||
detection_rules=detection_rules if detection_rules is not None else current.detection_rules,
|
||||
d3fend=d3fend if d3fend is not None else current.d3fend,
|
||||
freshness=freshness if freshness is not None else current.freshness,
|
||||
platform_diversity=platform_diversity if platform_diversity is not None else current.platform_diversity,
|
||||
)
|
||||
|
||||
row = db.query(ScoringConfig).first()
|
||||
if row is None:
|
||||
row = ScoringConfig()
|
||||
db.add(row)
|
||||
|
||||
row.weight_tests = new.tests
|
||||
row.weight_detection_rules = new.detection_rules
|
||||
row.weight_d3fend = new.d3fend
|
||||
row.weight_freshness = new.freshness
|
||||
row.weight_platform_diversity = new.platform_diversity
|
||||
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
|
||||
return _weights_dict(new)
|
||||
|
||||
|
||||
def get_weights_dict(db: Session) -> dict[str, Any]:
|
||||
"""Return current weights as a serialisable dict."""
|
||||
return _weights_dict(get_scoring_weights(db))
|
||||
|
||||
|
||||
def _weights_dict(w: ScoringWeights) -> dict[str, Any]:
|
||||
weights = {
|
||||
"tests": w.tests,
|
||||
"detection_rules": w.detection_rules,
|
||||
"d3fend": w.d3fend,
|
||||
"freshness": w.freshness,
|
||||
"platform_diversity": w.platform_diversity,
|
||||
}
|
||||
return {
|
||||
"weights": weights,
|
||||
"total": sum(weights.values()),
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Scoring service — granular 0-100 scoring for techniques, tactics, actors, and org.
|
||||
|
||||
Uses configurable weights from Settings to compute coverage scores with
|
||||
detailed breakdowns.
|
||||
Reads configurable weights from the ``scoring_config`` table (falling
|
||||
back to env-var defaults) to compute coverage scores with detailed
|
||||
breakdowns.
|
||||
|
||||
Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a
|
||||
fixed number of aggregated queries so that organisation-wide calculations
|
||||
@@ -14,7 +15,6 @@ from typing import Optional
|
||||
from sqlalchemy import case, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.detection_rule import DetectionRule
|
||||
@@ -22,20 +22,12 @@ from app.models.test_detection_result import TestDetectionResult
|
||||
from app.models.defensive_technique import DefensiveTechniqueMapping
|
||||
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
|
||||
from app.models.enums import TestState, TestResult
|
||||
from app.services.scoring_config_service import get_scoring_weights
|
||||
|
||||
|
||||
# ── Bulk scoring helpers (5 queries for ALL techniques) ───────────────
|
||||
|
||||
|
||||
def _build_empty_stats():
|
||||
return {
|
||||
"validated": 0,
|
||||
"detected": 0,
|
||||
"platforms": set(),
|
||||
"latest_validated_at": None,
|
||||
}
|
||||
|
||||
|
||||
def bulk_technique_scores(db: Session) -> dict:
|
||||
"""Pre-fetch all scoring data and compute per-technique scores in memory.
|
||||
|
||||
@@ -48,11 +40,12 @@ def bulk_technique_scores(db: Session) -> dict:
|
||||
|
||||
Returns ``{technique_id: {"total_score": float, "breakdown": dict}}``.
|
||||
"""
|
||||
w_tests = settings.SCORING_WEIGHT_TESTS
|
||||
w_detection = settings.SCORING_WEIGHT_DETECTION_RULES
|
||||
w_d3fend = settings.SCORING_WEIGHT_D3FEND
|
||||
w_freshness = settings.SCORING_WEIGHT_FRESHNESS
|
||||
w_diversity = settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
|
||||
w = get_scoring_weights(db)
|
||||
w_tests = w.tests
|
||||
w_detection = w.detection_rules
|
||||
w_d3fend = w.d3fend
|
||||
w_freshness = w.freshness
|
||||
w_diversity = w.platform_diversity
|
||||
|
||||
# Q1: test stats grouped by technique_id
|
||||
test_rows = (
|
||||
@@ -242,18 +235,14 @@ def bulk_technique_scores(db: Session) -> dict:
|
||||
def calculate_technique_score(technique: Technique, db: Session) -> dict:
|
||||
"""Calculate a 0-100 score for a technique with detailed breakdown.
|
||||
|
||||
Weights (configurable via settings):
|
||||
- tests_validated: weight from SCORING_WEIGHT_TESTS
|
||||
- detection_rules: weight from SCORING_WEIGHT_DETECTION_RULES
|
||||
- d3fend_coverage: weight from SCORING_WEIGHT_D3FEND
|
||||
- freshness: weight from SCORING_WEIGHT_FRESHNESS
|
||||
- platform_diversity: weight from SCORING_WEIGHT_PLATFORM_DIVERSITY
|
||||
Weights are read from the ``scoring_config`` table (or env defaults).
|
||||
"""
|
||||
w_tests = settings.SCORING_WEIGHT_TESTS
|
||||
w_detection = settings.SCORING_WEIGHT_DETECTION_RULES
|
||||
w_d3fend = settings.SCORING_WEIGHT_D3FEND
|
||||
w_freshness = settings.SCORING_WEIGHT_FRESHNESS
|
||||
w_diversity = settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
|
||||
w = get_scoring_weights(db)
|
||||
w_tests = w.tests
|
||||
w_detection = w.detection_rules
|
||||
w_d3fend = w.d3fend
|
||||
w_freshness = w.freshness
|
||||
w_diversity = w.platform_diversity
|
||||
|
||||
breakdown = {}
|
||||
|
||||
|
||||
@@ -0,0 +1,277 @@
|
||||
"""Test CRUD service — list, create, update, and query logic for security tests.
|
||||
|
||||
Framework-agnostic; uses domain exceptions from app.domain.errors.
|
||||
The router is responsible for HTTP concerns, auth, audit logging, and commit.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.domain.errors import (
|
||||
BusinessRuleViolation,
|
||||
EntityNotFoundError,
|
||||
PermissionViolation,
|
||||
)
|
||||
from app.models.enums import TestState
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.audit import AuditLog
|
||||
from app.utils import escape_like
|
||||
|
||||
|
||||
def list_tests(
|
||||
db: Session,
|
||||
*,
|
||||
state: str | None = None,
|
||||
technique_id: uuid.UUID | None = None,
|
||||
platform: str | None = None,
|
||||
created_by: uuid.UUID | None = None,
|
||||
pending_validation_side: str | None = None,
|
||||
offset: int = 0,
|
||||
limit: int = 50,
|
||||
) -> list[Test]:
|
||||
"""Return a paginated list of tests with optional filters."""
|
||||
query = db.query(Test).options(joinedload(Test.technique))
|
||||
|
||||
if state:
|
||||
query = query.filter(Test.state == state)
|
||||
if technique_id:
|
||||
query = query.filter(Test.technique_id == technique_id)
|
||||
if platform:
|
||||
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
|
||||
if created_by:
|
||||
query = query.filter(Test.created_by == created_by)
|
||||
if pending_validation_side == "red":
|
||||
query = query.filter(
|
||||
Test.state == TestState.in_review,
|
||||
Test.red_validation_status.in_(["pending", None]),
|
||||
)
|
||||
elif pending_validation_side == "blue":
|
||||
query = query.filter(
|
||||
Test.state == TestState.in_review,
|
||||
Test.blue_validation_status.in_(["pending", None]),
|
||||
)
|
||||
|
||||
return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
|
||||
def create_test(
|
||||
db: Session,
|
||||
*,
|
||||
technique_id: uuid.UUID,
|
||||
creator_id: uuid.UUID,
|
||||
**fields: Any,
|
||||
) -> Test:
|
||||
"""Create a new test linked to an existing technique.
|
||||
|
||||
Raises EntityNotFoundError if the technique does not exist.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
"""
|
||||
technique = db.query(Technique).filter(Technique.id == technique_id).first()
|
||||
if technique is None:
|
||||
raise EntityNotFoundError("Technique", str(technique_id))
|
||||
|
||||
test = Test(
|
||||
technique_id=technique_id,
|
||||
created_by=creator_id,
|
||||
state=TestState.draft,
|
||||
**fields,
|
||||
)
|
||||
db.add(test)
|
||||
db.flush()
|
||||
return test
|
||||
|
||||
|
||||
def create_test_from_template(
|
||||
db: Session,
|
||||
*,
|
||||
template_id: uuid.UUID,
|
||||
technique_id_or_mitre: str,
|
||||
creator_id: uuid.UUID,
|
||||
) -> Test:
|
||||
"""Instantiate a Test from a TestTemplate.
|
||||
|
||||
technique_id_or_mitre can be a UUID string or MITRE ID (e.g. T1059.001).
|
||||
Raises EntityNotFoundError if template or technique not found.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
"""
|
||||
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
|
||||
if template is None:
|
||||
raise EntityNotFoundError("TestTemplate", str(template_id))
|
||||
|
||||
technique = None
|
||||
try:
|
||||
technique_uuid = uuid.UUID(technique_id_or_mitre)
|
||||
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if technique is None:
|
||||
technique = db.query(Technique).filter(
|
||||
Technique.mitre_id == technique_id_or_mitre
|
||||
).first()
|
||||
|
||||
if technique is None:
|
||||
raise EntityNotFoundError("Technique", technique_id_or_mitre)
|
||||
|
||||
test = Test(
|
||||
technique_id=technique.id,
|
||||
name=template.name,
|
||||
description=template.description,
|
||||
platform=template.platform,
|
||||
procedure_text=template.attack_procedure,
|
||||
tool_used=template.tool_suggested,
|
||||
remediation_steps=template.suggested_remediation,
|
||||
created_by=creator_id,
|
||||
state=TestState.draft,
|
||||
)
|
||||
db.add(test)
|
||||
db.flush()
|
||||
return test
|
||||
|
||||
|
||||
def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
|
||||
"""Fetch a test with evidences eager-loaded.
|
||||
|
||||
Raises EntityNotFoundError if the test does not exist.
|
||||
"""
|
||||
test = (
|
||||
db.query(Test)
|
||||
.options(joinedload(Test.evidences))
|
||||
.filter(Test.id == test_id)
|
||||
.first()
|
||||
)
|
||||
if test is None:
|
||||
raise EntityNotFoundError("Test", str(test_id))
|
||||
return test
|
||||
|
||||
|
||||
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
|
||||
"""Fetch a test by ID. Raises EntityNotFoundError if not found."""
|
||||
test = db.query(Test).filter(Test.id == test_id).first()
|
||||
if test is None:
|
||||
raise EntityNotFoundError("Test", str(test_id))
|
||||
return test
|
||||
|
||||
|
||||
def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
|
||||
"""Fetch a test with technique joined. Raises EntityNotFoundError if not found."""
|
||||
test = (
|
||||
db.query(Test)
|
||||
.options(joinedload(Test.technique))
|
||||
.filter(Test.id == test_id)
|
||||
.first()
|
||||
)
|
||||
if test is None:
|
||||
raise EntityNotFoundError("Test", str(test_id))
|
||||
return test
|
||||
|
||||
|
||||
def update_test(
|
||||
db: Session,
|
||||
test_id: uuid.UUID,
|
||||
*,
|
||||
updater_id: uuid.UUID,
|
||||
updater_role: str,
|
||||
**fields: Any,
|
||||
) -> Test:
|
||||
"""Update general test fields (draft or rejected only).
|
||||
|
||||
Raises PermissionViolation if not creator or admin.
|
||||
Raises BusinessRuleViolation if state is not draft or rejected.
|
||||
Raises EntityNotFoundError if test not found.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
"""
|
||||
test = get_test_or_raise(db, test_id)
|
||||
|
||||
if updater_role != "admin" and test.created_by != updater_id:
|
||||
raise PermissionViolation(
|
||||
"Only the test creator or an admin can update this test"
|
||||
)
|
||||
|
||||
if test.state not in (TestState.draft, TestState.rejected):
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)"
|
||||
)
|
||||
|
||||
for field, value in fields.items():
|
||||
setattr(test, field, value)
|
||||
|
||||
db.flush()
|
||||
return test
|
||||
|
||||
|
||||
def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
|
||||
"""Update Red Team fields (draft or red_executing only).
|
||||
|
||||
Raises BusinessRuleViolation if state not in (draft, red_executing).
|
||||
Raises EntityNotFoundError if test not found.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
"""
|
||||
test = get_test_or_raise(db, test_id)
|
||||
|
||||
if test.state not in (TestState.draft, TestState.red_executing):
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot update red fields in '{test.state.value}' state "
|
||||
"(must be draft or red_executing)"
|
||||
)
|
||||
|
||||
for field, value in fields.items():
|
||||
setattr(test, field, value)
|
||||
|
||||
db.flush()
|
||||
return test
|
||||
|
||||
|
||||
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
|
||||
"""Update Blue Team fields (blue_evaluating only).
|
||||
|
||||
Raises BusinessRuleViolation if state is not blue_evaluating.
|
||||
Raises EntityNotFoundError if test not found.
|
||||
Does not commit; caller uses UnitOfWork.
|
||||
"""
|
||||
test = get_test_or_raise(db, test_id)
|
||||
|
||||
if test.state != TestState.blue_evaluating:
|
||||
raise BusinessRuleViolation(
|
||||
f"Cannot update blue fields in '{test.state.value}' state "
|
||||
"(must be blue_evaluating)"
|
||||
)
|
||||
|
||||
for field, value in fields.items():
|
||||
setattr(test, field, value)
|
||||
|
||||
db.flush()
|
||||
return test
|
||||
|
||||
|
||||
def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
|
||||
"""Return chronological audit-log history for a test.
|
||||
|
||||
Raises EntityNotFoundError if the test does not exist.
|
||||
"""
|
||||
get_test_or_raise(db, test_id)
|
||||
|
||||
logs = (
|
||||
db.query(AuditLog)
|
||||
.filter(
|
||||
AuditLog.entity_type == "test",
|
||||
AuditLog.entity_id == str(test_id),
|
||||
)
|
||||
.order_by(AuditLog.timestamp.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": str(log.id),
|
||||
"action": log.action,
|
||||
"user_id": str(log.user_id) if log.user_id else None,
|
||||
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
|
||||
"details": log.details,
|
||||
}
|
||||
for log in logs
|
||||
]
|
||||
@@ -89,12 +89,12 @@ for mod_name in [
|
||||
# Imports
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from fastapi import HTTPException
|
||||
from app.domain.errors import PermissionViolation
|
||||
from app.models.enums import TeamSide, TestState
|
||||
from app.routers.evidence import (
|
||||
router,
|
||||
_validate_upload_permission,
|
||||
_validate_delete_permission,
|
||||
from app.routers.evidence import router
|
||||
from app.services.evidence_service import (
|
||||
validate_delete_permission,
|
||||
validate_upload_permission,
|
||||
)
|
||||
|
||||
|
||||
@@ -132,7 +132,7 @@ def test_red_tech_upload_red_in_red_executing():
|
||||
test = _make_test(TestState.red_executing)
|
||||
user = _make_user("red_tech")
|
||||
# Should not raise
|
||||
_validate_upload_permission(test, TeamSide.red, user)
|
||||
validate_upload_permission(test, TeamSide.red, user.role)
|
||||
print(" [PASS] red_tech can upload team=red in red_executing")
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ def test_red_tech_upload_red_in_red_executing():
|
||||
def test_red_tech_upload_red_in_draft():
|
||||
test = _make_test(TestState.draft)
|
||||
user = _make_user("red_tech")
|
||||
_validate_upload_permission(test, TeamSide.red, user)
|
||||
validate_upload_permission(test, TeamSide.red, user.role)
|
||||
print(" [PASS] red_tech can upload team=red in draft")
|
||||
|
||||
|
||||
@@ -157,10 +157,10 @@ def test_red_tech_cannot_upload_blue():
|
||||
test = _make_test(TestState.red_executing)
|
||||
user = _make_user("red_tech")
|
||||
try:
|
||||
_validate_upload_permission(test, TeamSide.blue, user)
|
||||
assert False, "Should have raised HTTPException"
|
||||
except HTTPException as exc:
|
||||
assert exc.status_code == 403
|
||||
validate_upload_permission(test, TeamSide.blue, user.role)
|
||||
assert False, "Should have raised PermissionViolation"
|
||||
except PermissionViolation:
|
||||
pass
|
||||
print(" [PASS] red_tech CANNOT upload team=blue (403)")
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ def test_red_tech_cannot_upload_blue():
|
||||
def test_blue_tech_upload_blue_in_blue_evaluating():
|
||||
test = _make_test(TestState.blue_evaluating)
|
||||
user = _make_user("blue_tech")
|
||||
_validate_upload_permission(test, TeamSide.blue, user)
|
||||
validate_upload_permission(test, TeamSide.blue, user.role)
|
||||
print(" [PASS] blue_tech can upload team=blue in blue_evaluating")
|
||||
|
||||
|
||||
@@ -185,10 +185,10 @@ def test_blue_tech_cannot_upload_red():
|
||||
test = _make_test(TestState.blue_evaluating)
|
||||
user = _make_user("blue_tech")
|
||||
try:
|
||||
_validate_upload_permission(test, TeamSide.red, user)
|
||||
assert False, "Should have raised HTTPException"
|
||||
except HTTPException as exc:
|
||||
assert exc.status_code == 403
|
||||
validate_upload_permission(test, TeamSide.red, user.role)
|
||||
assert False, "Should have raised PermissionViolation"
|
||||
except PermissionViolation:
|
||||
pass
|
||||
print(" [PASS] blue_tech CANNOT upload team=red (403)")
|
||||
|
||||
|
||||
@@ -223,10 +223,10 @@ def test_delete_in_review_fails():
|
||||
user = _make_user("red_tech")
|
||||
evidence = _make_evidence(TeamSide.red, uploaded_by=user.id)
|
||||
try:
|
||||
_validate_delete_permission(test, evidence, user)
|
||||
assert False, "Should have raised HTTPException"
|
||||
except HTTPException as exc:
|
||||
assert exc.status_code == 403
|
||||
validate_delete_permission(test, evidence, user.role, user.id)
|
||||
assert False, "Should have raised PermissionViolation"
|
||||
except PermissionViolation:
|
||||
pass
|
||||
print(" [PASS] DELETE in in_review -> 403")
|
||||
|
||||
|
||||
@@ -240,7 +240,7 @@ def test_delete_red_evidence_in_red_executing():
|
||||
user = _make_user("red_tech")
|
||||
evidence = _make_evidence(TeamSide.red, uploaded_by=user.id)
|
||||
# Should not raise
|
||||
_validate_delete_permission(test, evidence, user)
|
||||
validate_delete_permission(test, evidence, user.role, user.id)
|
||||
print(" [PASS] DELETE red evidence in red_executing -> allowed")
|
||||
|
||||
|
||||
@@ -254,11 +254,11 @@ def test_admin_bypass():
|
||||
|
||||
# Red in blue_evaluating (normally blocked)
|
||||
test1 = _make_test(TestState.blue_evaluating)
|
||||
_validate_upload_permission(test1, TeamSide.red, admin)
|
||||
validate_upload_permission(test1, TeamSide.red, admin.role)
|
||||
|
||||
# Blue in draft (normally blocked)
|
||||
test2 = _make_test(TestState.draft)
|
||||
_validate_upload_permission(test2, TeamSide.blue, admin)
|
||||
validate_upload_permission(test2, TeamSide.blue, admin.role)
|
||||
|
||||
print(" [PASS] Admin can upload any team in any state")
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ from app.routers.test_templates import (
|
||||
toggle_template_active,
|
||||
template_stats,
|
||||
)
|
||||
from app.routers.tests import create_test_from_template
|
||||
from app.services.test_crud_service import create_test_from_template as crud_create_from_template
|
||||
from app.schemas.test_template import TestTemplateCreate
|
||||
|
||||
|
||||
@@ -174,7 +174,8 @@ def test_get_templates_by_technique():
|
||||
|
||||
def test_instantiate_template():
|
||||
"""POST /tests/from-template creates a test pre-filled from template data."""
|
||||
source = inspect.getsource(create_test_from_template)
|
||||
# Template field copying lives in the service; router delegates to it
|
||||
source = inspect.getsource(crud_create_from_template)
|
||||
|
||||
# Verify it reads from template and copies fields
|
||||
assert "template" in source, "Must reference template"
|
||||
|
||||
@@ -419,56 +419,57 @@ def test_dual_validation_red_approves_blue_rejects(mock_log):
|
||||
|
||||
def test_evidence_team_separation():
|
||||
"""Verify evidence router logic separates red and blue evidence correctly."""
|
||||
from app.routers.evidence import _validate_upload_permission, _RED_EDITABLE_STATES, _BLUE_EDITABLE_STATES
|
||||
from app.domain.errors import BusinessRuleViolation, PermissionViolation
|
||||
from app.models.enums import TeamSide
|
||||
from app.services.evidence_service import validate_upload_permission
|
||||
|
||||
# Red tech can upload red evidence in draft
|
||||
test = _make_test(TestState.draft)
|
||||
red_user = _make_user("red_tech")
|
||||
red_user.role = "red_tech"
|
||||
from app.models.enums import TeamSide
|
||||
_validate_upload_permission(test, TeamSide.red, red_user) # should not raise
|
||||
validate_upload_permission(test, TeamSide.red, red_user.role) # should not raise
|
||||
|
||||
# Red tech can upload red evidence in red_executing
|
||||
test.state = TestState.red_executing
|
||||
_validate_upload_permission(test, TeamSide.red, red_user) # should not raise
|
||||
validate_upload_permission(test, TeamSide.red, red_user.role) # should not raise
|
||||
|
||||
# Red tech CANNOT upload red evidence in blue_evaluating
|
||||
# Red tech CANNOT upload red evidence in blue_evaluating (state violation -> 400)
|
||||
test.state = TestState.blue_evaluating
|
||||
try:
|
||||
_validate_upload_permission(test, TeamSide.red, red_user)
|
||||
assert False, "Should have raised HTTPException"
|
||||
except HTTPException as exc:
|
||||
assert exc.status_code == 400
|
||||
validate_upload_permission(test, TeamSide.red, red_user.role)
|
||||
assert False, "Should have raised BusinessRuleViolation"
|
||||
except BusinessRuleViolation:
|
||||
pass
|
||||
|
||||
# Red tech CANNOT upload blue evidence
|
||||
# Red tech CANNOT upload blue evidence (role violation -> 403)
|
||||
test.state = TestState.blue_evaluating
|
||||
try:
|
||||
_validate_upload_permission(test, TeamSide.blue, red_user)
|
||||
assert False, "Should have raised HTTPException"
|
||||
except HTTPException as exc:
|
||||
assert exc.status_code == 403
|
||||
validate_upload_permission(test, TeamSide.blue, red_user.role)
|
||||
assert False, "Should have raised PermissionViolation"
|
||||
except PermissionViolation:
|
||||
pass
|
||||
|
||||
# Blue tech can upload blue evidence in blue_evaluating
|
||||
test.state = TestState.blue_evaluating
|
||||
blue_user = _make_user("blue_tech")
|
||||
blue_user.role = "blue_tech"
|
||||
_validate_upload_permission(test, TeamSide.blue, blue_user) # should not raise
|
||||
validate_upload_permission(test, TeamSide.blue, blue_user.role) # should not raise
|
||||
|
||||
# Blue tech CANNOT upload blue evidence in draft
|
||||
# Blue tech CANNOT upload blue evidence in draft (state violation -> 400)
|
||||
test.state = TestState.draft
|
||||
try:
|
||||
_validate_upload_permission(test, TeamSide.blue, blue_user)
|
||||
assert False, "Should have raised HTTPException"
|
||||
except HTTPException as exc:
|
||||
assert exc.status_code == 400
|
||||
validate_upload_permission(test, TeamSide.blue, blue_user.role)
|
||||
assert False, "Should have raised BusinessRuleViolation"
|
||||
except BusinessRuleViolation:
|
||||
pass
|
||||
|
||||
# Blue tech CANNOT upload red evidence
|
||||
# Blue tech CANNOT upload red evidence (role violation -> 403)
|
||||
test.state = TestState.draft
|
||||
try:
|
||||
_validate_upload_permission(test, TeamSide.red, blue_user)
|
||||
assert False, "Should have raised HTTPException"
|
||||
except HTTPException as exc:
|
||||
assert exc.status_code == 403
|
||||
validate_upload_permission(test, TeamSide.red, blue_user.role)
|
||||
assert False, "Should have raised PermissionViolation"
|
||||
except PermissionViolation:
|
||||
pass
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
@@ -477,15 +478,15 @@ def test_evidence_team_separation():
|
||||
|
||||
|
||||
def test_red_edit_allowed_in_draft_and_red_executing():
|
||||
"""Verify the red update router checks that state is draft or red_executing."""
|
||||
from app.routers.tests import update_test_red
|
||||
"""Verify the red update checks that state is draft or red_executing."""
|
||||
from app.services.test_crud_service import update_test_red
|
||||
import inspect
|
||||
source = inspect.getsource(update_test_red)
|
||||
|
||||
# The function must guard against states other than draft/red_executing
|
||||
# The service must guard against states other than draft/red_executing
|
||||
assert "draft" in source, "Red update must allow draft state"
|
||||
assert "red_executing" in source, "Red update must allow red_executing state"
|
||||
assert "400" in source or "HTTP_400_BAD_REQUEST" in source, "Red update must return 400 for invalid state"
|
||||
assert "BusinessRuleViolation" in source, "Must raise domain exception for invalid state (mapped to 400)"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
|
||||
+15
-15
@@ -2,30 +2,30 @@
|
||||
|
||||
## Tier 1 — Quick Wins
|
||||
|
||||
- [ ] QW-1: Wire existing repos into `techniques.py` router
|
||||
- [ ] QW-2: Fix `audit_service` to follow UoW (no direct `db.commit()`)
|
||||
- [ ] QW-3: Consolidate `status_service` with `TechniqueEntity.recalculate_status()`
|
||||
- [ ] QW-4: Remove remaining `HTTPException` from services
|
||||
- [x] QW-1: Wire existing repos into `techniques.py` router
|
||||
- [~] QW-2: Fix `audit_service` to follow UoW — deferred, resolves naturally as routers adopt UoW
|
||||
- [x] QW-3: Consolidate `status_service` with `TechniqueEntity.recalculate_status()`
|
||||
- [x] QW-4: Remove remaining `HTTPException` from services — already resolved
|
||||
|
||||
## Tier 2 — Service Extraction (fat routers → thin routers + services)
|
||||
|
||||
- [ ] SE-1: Extract reports service from `reports.py`
|
||||
- [ ] SE-2: Extract metrics service from `metrics.py`
|
||||
- [ ] SE-3: Extract compliance service from `compliance.py`
|
||||
- [ ] SE-4: Extract detection_rules service from `detection_rules.py`
|
||||
- [ ] SE-5: Extract threat_actors service from `threat_actors.py`
|
||||
- [x] SE-1: Extract reports service → `coverage_report_service.py`
|
||||
- [x] SE-2: Extract metrics service → `metrics_query_service.py`
|
||||
- [x] SE-3: Extract compliance service → `compliance_service.py`
|
||||
- [x] SE-4: Extract detection_rules service → `detection_rule_service.py`
|
||||
- [x] SE-5: Extract threat_actors service → `threat_actor_service.py`
|
||||
|
||||
## Tier 3 — Architectural Fixes
|
||||
|
||||
- [ ] AF-1: Persist scoring weights in DB (replace mutable `settings`)
|
||||
- [ ] AF-2: Slim `tests.py` router (CRUD to repo/service)
|
||||
- [ ] AF-3: Slim `evidence.py` router (permissions to domain)
|
||||
- [ ] AF-4: Slim `campaigns.py` router (CRUD to service)
|
||||
- [x] AF-1: Persist scoring weights in DB → `scoring_config` table + `scoring_config_service.py`
|
||||
- [x] AF-2: Slim `tests.py` router → `test_crud_service.py`
|
||||
- [x] AF-3: Slim `evidence.py` router → `evidence_service.py`
|
||||
- [x] AF-4: Slim `campaigns.py` router → `campaign_crud_service.py`
|
||||
|
||||
## Tier 4 — Polish
|
||||
|
||||
- [ ] P-1: Structured JSON logging
|
||||
- [ ] P-2: Create architecture skill file for future agents
|
||||
- [x] P-1: Structured JSON logging → `logging_config.py`
|
||||
- [x] P-2: Create architecture skill file → `~/.cursor/skills/aegis-architecture/SKILL.md`
|
||||
|
||||
## Completed (prior sessions)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user