Compare commits

...

6 Commits

18 changed files with 1412 additions and 858 deletions
@@ -0,0 +1,37 @@
"""add_scoring_config
Single-row table to persist scoring weights in the database,
replacing the mutable in-process Settings approach.
Revision ID: b027scorecfg
Revises: b026techidx
Create Date: 2026-02-19 10:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision: str = "b027scorecfg"
down_revision: Union[str, None] = "b026techidx"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"scoring_config",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("weight_tests", sa.Float(), nullable=False, server_default="40.0"),
sa.Column("weight_detection_rules", sa.Float(), nullable=False, server_default="20.0"),
sa.Column("weight_d3fend", sa.Float(), nullable=False, server_default="15.0"),
sa.Column("weight_freshness", sa.Float(), nullable=False, server_default="15.0"),
sa.Column("weight_platform_diversity", sa.Float(), nullable=False, server_default="10.0"),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
def downgrade() -> None:
op.drop_table("scoring_config")
+67
View File
@@ -0,0 +1,67 @@
"""Structured JSON logging configuration.
In **production** (``AEGIS_ENV=production``), emits one JSON object per
line so that log aggregators (ELK, CloudWatch, Datadog) can ingest them
without custom parsing.
In **development** (default), uses a human-readable text format for
comfortable local work.
"""
from __future__ import annotations
import json
import logging
import os
import sys
from datetime import datetime, timezone
class _JSONFormatter(logging.Formatter):
"""Emit each log record as a single-line JSON object."""
def format(self, record: logging.LogRecord) -> str:
payload: dict = {
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
}
if record.exc_info and record.exc_info[1] is not None:
payload["exception"] = self.formatException(record.exc_info)
extra = getattr(record, "_extra", None)
if extra:
payload.update(extra)
return json.dumps(payload, default=str)
_DEV_FORMAT = "%(asctime)s %(levelname)-8s %(name)s%(message)s"
def setup_logging() -> None:
"""Configure the root logger based on the environment."""
is_production = os.environ.get("AEGIS_ENV", "").lower() == "production"
level_name = os.environ.get("LOG_LEVEL", "INFO").upper()
level = getattr(logging, level_name, logging.INFO)
root = logging.getLogger()
root.setLevel(level)
if root.handlers:
root.handlers.clear()
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(level)
if is_production:
handler.setFormatter(_JSONFormatter())
else:
handler.setFormatter(logging.Formatter(_DEV_FORMAT))
root.addHandler(handler)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
+3 -4
View File
@@ -47,10 +47,9 @@ from app.jobs.mitre_sync_job import start_scheduler, scheduler
_IS_PRODUCTION = os.environ.get("AEGIS_ENV", "").lower() == "production"
# ── Logging ───────────────────────────────────────────────────────────────
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)-8s %(name)s%(message)s",
)
from app.logging_config import setup_logging
setup_logging()
@asynccontextmanager
async def lifespan(app: FastAPI):
+2 -1
View File
@@ -19,6 +19,7 @@ from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueStat
from app.models.jira_link import JiraLink, JiraLinkEntityType, JiraSyncDirection
from app.models.worklog import Worklog
from app.models.osint_item import OsintItem
from app.models.scoring_config import ScoringConfig
from app.models.enums import TechniqueStatus, TestState, TestResult, TeamSide
__all__ = [
@@ -31,6 +32,6 @@ __all__ = [
"ComplianceFramework", "ComplianceControl", "ComplianceControlMapping",
"CoverageSnapshot", "SnapshotTechniqueState",
"JiraLink", "JiraLinkEntityType", "JiraSyncDirection",
"Worklog", "OsintItem",
"Worklog", "OsintItem", "ScoringConfig",
"TechniqueStatus", "TestState", "TestResult", "TeamSide",
]
+24
View File
@@ -0,0 +1,24 @@
"""ScoringConfig — single-row table for persisted scoring weights.
Replaces the mutable-settings approach where PATCH /scores/config
mutated the in-process ``Settings`` object (lost on restart).
"""
import uuid
from sqlalchemy import Column, Float, DateTime, func
from sqlalchemy.dialects.postgresql import UUID
from app.database import Base
class ScoringConfig(Base):
__tablename__ = "scoring_config"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
weight_tests = Column(Float, nullable=False, default=40.0)
weight_detection_rules = Column(Float, nullable=False, default=20.0)
weight_d3fend = Column(Float, nullable=False, default=15.0)
weight_freshness = Column(Float, nullable=False, default=15.0)
weight_platform_diversity = Column(Float, nullable=False, default=10.0)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
+69 -325
View File
@@ -7,26 +7,29 @@ test ordering, progress tracking, and threat actor integration.
import logging
import uuid
from typing import Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.models.user import User
from app.models.campaign import Campaign, CampaignTest, KILL_CHAIN_PHASES
from app.models.test import Test
from app.models.technique import Technique
from app.models.threat_actor import ThreatActor
from app.services.campaign_service import (
validate_no_circular_dependency,
get_campaign_progress,
generate_campaign_from_threat_actor,
TACTIC_TO_PHASE,
from app.services.campaign_service import generate_campaign_from_threat_actor
from app.services.campaign_crud_service import (
add_test_to_campaign as crud_add_test,
activate_campaign as crud_activate,
complete_campaign as crud_complete,
create_campaign as crud_create,
get_campaign_detail as crud_get_detail,
get_campaign_history as crud_get_history,
get_campaign_progress_data as crud_get_progress,
list_campaigns as crud_list,
remove_test_from_campaign as crud_remove_test,
schedule_campaign as crud_schedule,
serialize_campaign,
update_campaign as crud_update,
)
from app.services.campaign_scheduler_service import calculate_next_run
from app.services.notification_service import create_notification
from app.services.audit_service import log_action
@@ -67,89 +70,6 @@ class SchedulePayload(BaseModel):
next_run_at: Optional[str] = None
# ── Helpers ──────────────────────────────────────────────────────────
def _serialize_campaign(db: Session, campaign: Campaign) -> dict:
"""Serialize a campaign with its tests and progress."""
progress = get_campaign_progress(db, campaign.id)
campaign_tests = (
db.query(CampaignTest)
.filter(CampaignTest.campaign_id == campaign.id)
.order_by(CampaignTest.order_index)
.all()
)
tests = []
for ct in campaign_tests:
test = ct.test
technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None
tests.append({
"id": str(ct.id),
"test_id": str(ct.test_id),
"order_index": ct.order_index,
"depends_on": str(ct.depends_on) if ct.depends_on else None,
"phase": ct.phase,
"test_name": test.name if test else None,
"test_state": test.state.value if test and test.state else None,
"test_result": test.result.value if test and test.result else None,
"technique_mitre_id": technique.mitre_id if technique else None,
"technique_name": technique.name if technique else None,
"platform": test.platform if test else None,
})
actor = campaign.threat_actor
return {
"id": str(campaign.id),
"name": campaign.name,
"description": campaign.description,
"type": campaign.type,
"status": campaign.status,
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
"threat_actor_name": actor.name if actor else None,
"created_by": str(campaign.created_by) if campaign.created_by else None,
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
"target_platform": campaign.target_platform,
"tags": campaign.tags or [],
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
"is_recurring": campaign.is_recurring or False,
"recurrence_pattern": campaign.recurrence_pattern,
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
"parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
"tests": tests,
"progress": progress,
}
def _serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
"""Lightweight campaign serialization for list views."""
progress = get_campaign_progress(db, campaign.id)
actor = campaign.threat_actor
return {
"id": str(campaign.id),
"name": campaign.name,
"description": campaign.description,
"type": campaign.type,
"status": campaign.status,
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
"threat_actor_name": actor.name if actor else None,
"target_platform": campaign.target_platform,
"tags": campaign.tags or [],
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
"is_recurring": campaign.is_recurring or False,
"recurrence_pattern": campaign.recurrence_pattern,
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
"test_count": progress["total"],
"completion_pct": progress["completion_pct"],
}
# ---------------------------------------------------------------------------
# GET /campaigns — List campaigns with filters
# ---------------------------------------------------------------------------
@@ -166,28 +86,15 @@ def list_campaigns(
current_user: User = Depends(get_current_user),
):
"""List campaigns with optional filters and pagination."""
query = db.query(Campaign)
if type:
query = query.filter(Campaign.type == type)
if status:
query = query.filter(Campaign.status == status)
if threat_actor_id:
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
if search:
from app.utils import escape_like
pattern = f"%{escape_like(search)}%"
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
total = query.count()
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [_serialize_campaign_summary(db, c) for c in campaigns],
}
return crud_list(
db,
type=type,
status=status,
threat_actor_id=threat_actor_id,
search=search,
offset=offset,
limit=limit,
)
# ---------------------------------------------------------------------------
@@ -201,30 +108,29 @@ def create_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Create a new campaign."""
campaign = Campaign(
result = crud_create(
db,
creator_id=current_user.id,
name=payload.name,
description=payload.description,
type=payload.type,
threat_actor_id=payload.threat_actor_id,
target_platform=payload.target_platform,
tags=payload.tags or [],
created_by=current_user.id,
scheduled_at=datetime.fromisoformat(payload.scheduled_at) if payload.scheduled_at else None,
tags=payload.tags,
scheduled_at=payload.scheduled_at,
)
db.add(campaign)
db.commit()
db.refresh(campaign)
log_action(
db,
user_id=current_user.id,
action="create_campaign",
entity_type="campaign",
entity_id=campaign.id,
details={"name": campaign.name, "type": campaign.type},
entity_id=result["id"],
details={"name": payload.name, "type": payload.type},
)
db.commit()
return _serialize_campaign(db, campaign)
return result
# ---------------------------------------------------------------------------
@@ -238,11 +144,7 @@ def get_campaign(
current_user: User = Depends(get_current_user),
):
"""Get detailed campaign info including tests and progress."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
return _serialize_campaign(db, campaign)
return crud_get_detail(db, campaign_id)
# ---------------------------------------------------------------------------
@@ -257,37 +159,26 @@ def update_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Update a campaign. Only allowed in draft or active state."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
if campaign.status not in ("draft", "active"):
raise HTTPException(status_code=400, detail="Can only update draft or active campaigns")
# Check ownership or admin
if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin":
raise HTTPException(status_code=403, detail="Only the creator or admin can update this campaign")
update_data = payload.model_dump(exclude_unset=True)
if "scheduled_at" in update_data and update_data["scheduled_at"]:
update_data["scheduled_at"] = datetime.fromisoformat(update_data["scheduled_at"])
for field, value in update_data.items():
setattr(campaign, field, value)
db.commit()
db.refresh(campaign)
result = crud_update(
db,
campaign_id,
updater_id=current_user.id,
updater_role=current_user.role,
**update_data,
)
log_action(
db,
user_id=current_user.id,
action="update_campaign",
entity_type="campaign",
entity_id=campaign.id,
entity_id=campaign_id,
details={"updated_fields": list(update_data.keys())},
)
db.commit()
return _serialize_campaign(db, campaign)
return result
# ---------------------------------------------------------------------------
@@ -302,63 +193,16 @@ def add_test_to_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Add a test to a campaign with optional ordering and dependency."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
if campaign.status not in ("draft", "active"):
raise HTTPException(status_code=400, detail="Can only add tests to draft or active campaigns")
test = db.query(Test).filter(Test.id == payload.test_id).first()
if not test:
raise HTTPException(status_code=404, detail="Test not found")
# Calculate order_index if not provided
if payload.order_index is not None:
order_index = payload.order_index
else:
max_order = (
db.query(CampaignTest.order_index)
.filter(CampaignTest.campaign_id == campaign_id)
.order_by(CampaignTest.order_index.desc())
.first()
)
order_index = (max_order[0] + 1) if max_order else 0
depends_on = uuid.UUID(payload.depends_on) if payload.depends_on else None
# Validate circular dependency
ct_id = uuid.uuid4()
if depends_on:
validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on)
# Auto-detect kill chain phase from the test's technique tactic if not provided
phase = payload.phase
if not phase and test.technique_id:
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
if technique and technique.tactic:
phase = TACTIC_TO_PHASE.get(technique.tactic, None)
campaign_test = CampaignTest(
id=ct_id,
campaign_id=campaign_id,
result = crud_add_test(
db,
campaign_id,
test_id=payload.test_id,
order_index=order_index,
depends_on=depends_on,
phase=phase,
order_index=payload.order_index,
depends_on=payload.depends_on,
phase=payload.phase,
)
db.add(campaign_test)
db.commit()
db.refresh(campaign_test)
return {
"id": str(campaign_test.id),
"campaign_id": str(campaign_test.campaign_id),
"test_id": str(campaign_test.test_id),
"order_index": campaign_test.order_index,
"depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
"phase": campaign_test.phase,
}
return result
# ---------------------------------------------------------------------------
@@ -373,36 +217,8 @@ def remove_test_from_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Remove a test from a campaign."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
if campaign.status not in ("draft", "active"):
raise HTTPException(status_code=400, detail="Can only modify draft or active campaigns")
ct = (
db.query(CampaignTest)
.filter(
CampaignTest.id == campaign_test_id,
CampaignTest.campaign_id == campaign_id,
)
.first()
)
if not ct:
raise HTTPException(status_code=404, detail="Campaign test not found")
# Clear any references to this campaign_test
dependents = (
db.query(CampaignTest)
.filter(CampaignTest.depends_on == campaign_test_id)
.all()
)
for dep in dependents:
dep.depends_on = None
db.delete(ct)
crud_remove_test(db, campaign_id, campaign_test_id)
db.commit()
return {"detail": "Test removed from campaign"}
@@ -417,23 +233,10 @@ def activate_campaign(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Activate a campaign, moving it from draft to active."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
if campaign.status != "draft":
raise HTTPException(status_code=400, detail="Only draft campaigns can be activated")
# Verify campaign has at least one test
test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
if test_count == 0:
raise HTTPException(status_code=400, detail="Campaign must have at least one test to activate")
campaign.status = "active"
campaign = crud_activate(db, campaign_id)
db.commit()
db.refresh(campaign)
# Notify relevant users
red_techs = db.query(User).filter(User.role == "red_tech", User.is_active == True).all() # noqa: E712
for user in red_techs:
create_notification(
@@ -455,7 +258,7 @@ def activate_campaign(
details={"name": campaign.name},
)
return _serialize_campaign(db, campaign)
return serialize_campaign(db, campaign)
# ---------------------------------------------------------------------------
@@ -469,15 +272,7 @@ def complete_campaign(
current_user: User = Depends(require_any_role("red_lead", "admin")),
):
"""Mark a campaign as completed."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
if campaign.status != "active":
raise HTTPException(status_code=400, detail="Only active campaigns can be completed")
campaign.status = "completed"
campaign.completed_at = datetime.utcnow()
campaign = crud_complete(db, campaign_id)
db.commit()
db.refresh(campaign)
@@ -490,7 +285,7 @@ def complete_campaign(
details={"name": campaign.name},
)
return _serialize_campaign(db, campaign)
return serialize_campaign(db, campaign)
# ---------------------------------------------------------------------------
@@ -504,16 +299,7 @@ def get_campaign_progress_endpoint(
current_user: User = Depends(get_current_user),
):
"""Get progress statistics for a campaign."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
progress = get_campaign_progress(db, uuid.UUID(campaign_id))
return {
"campaign_id": str(campaign.id),
"campaign_name": campaign.name,
**progress,
}
return crud_get_progress(db, campaign_id)
# ---------------------------------------------------------------------------
@@ -546,7 +332,7 @@ def generate_campaign_from_actor(
details={"actor_id": actor_id, "campaign_name": campaign.name},
)
return _serialize_campaign(db, campaign)
return serialize_campaign(db, campaign)
# ---------------------------------------------------------------------------
@@ -564,31 +350,15 @@ def schedule_campaign(
Only the campaign creator or admin can change scheduling.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
# Check ownership or admin
if str(campaign.created_by) != str(current_user.id) and current_user.role != "admin":
raise HTTPException(status_code=403, detail="Only the creator or admin can configure scheduling")
campaign.is_recurring = payload.is_recurring
if payload.is_recurring:
if payload.recurrence_pattern not in ("weekly", "monthly", "quarterly"):
raise HTTPException(
status_code=400,
detail="recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'",
)
campaign.recurrence_pattern = payload.recurrence_pattern
if payload.next_run_at:
campaign.next_run_at = datetime.fromisoformat(payload.next_run_at.replace("Z", "+00:00").replace("+00:00", ""))
elif not campaign.next_run_at:
campaign.next_run_at = calculate_next_run(datetime.utcnow(), payload.recurrence_pattern)
else:
campaign.recurrence_pattern = None
campaign.next_run_at = None
campaign = crud_schedule(
db,
campaign_id,
owner_id=current_user.id,
owner_role=current_user.role,
is_recurring=payload.is_recurring,
recurrence_pattern=payload.recurrence_pattern,
next_run_at=payload.next_run_at,
)
db.commit()
db.refresh(campaign)
@@ -605,7 +375,7 @@ def schedule_campaign(
},
)
return _serialize_campaign(db, campaign)
return serialize_campaign(db, campaign)
# ---------------------------------------------------------------------------
@@ -619,30 +389,4 @@ def get_campaign_history(
current_user: User = Depends(get_current_user),
):
"""List all child campaigns (execution history) of a recurring campaign."""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
children = (
db.query(Campaign)
.filter(Campaign.parent_campaign_id == campaign_id)
.order_by(Campaign.created_at.desc())
.all()
)
return {
"campaign_id": str(campaign.id),
"campaign_name": campaign.name,
"items": [
{
"id": str(child.id),
"name": child.name,
"status": child.status,
"test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
"completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
"created_at": child.created_at.isoformat() if child.created_at else None,
"completed_at": child.completed_at.isoformat() if child.completed_at else None,
}
for child in children
],
}
return crud_get_history(db, campaign_id)
+23 -174
View File
@@ -24,52 +24,32 @@ import os
import uuid as _uuid
from typing import Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user
from app.models.enums import TeamSide, TestState
from app.models.enums import TeamSide
from app.models.evidence import Evidence
from app.models.test import Test
from app.models.user import User
from app.schemas.evidence import EvidenceOut
from app.services.audit_service import log_action
from app.services.evidence_service import (
get_evidence_or_raise,
get_test_or_raise,
list_evidence_for_test,
MAX_UPLOAD_SIZE,
validate_delete_permission,
validate_file,
validate_upload_permission,
)
from app.storage import get_presigned_url, upload_file
router = APIRouter(tags=["evidence"])
# States where red evidence can be uploaded / deleted
_RED_EDITABLE_STATES = (TestState.draft, TestState.red_executing)
# States where blue evidence can be uploaded / deleted
_BLUE_EDITABLE_STATES = (TestState.blue_evaluating,)
# ---------------------------------------------------------------------------
# Upload safety limits
# ---------------------------------------------------------------------------
# Maximum upload size in bytes (default 50 MB)
_MAX_UPLOAD_SIZE = 50 * 1024 * 1024
# Allowed file extensions (lowercase, with leading dot)
_ALLOWED_EXTENSIONS: set[str] = {
# Images / screenshots
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg",
# Documents
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".csv", ".txt",
".md", ".rtf", ".odt", ".ods",
# Logs & captures
".log", ".pcap", ".pcapng", ".evtx", ".json", ".xml",
".yaml", ".yml", ".toml",
# Archives (for bundled evidence)
".zip", ".tar", ".gz", ".7z",
# Other common evidence types
".har", ".eml", ".msg",
}
# ---------------------------------------------------------------------------
# Helpers
# Helpers (router-specific: infrastructure / HTTP concerns)
# ---------------------------------------------------------------------------
def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
@@ -87,85 +67,6 @@ def _evidence_to_out(evidence: Evidence) -> EvidenceOut:
)
def _validate_upload_permission(
test: Test,
team: TeamSide,
user: User,
) -> None:
"""Raise 403 if the user/team combination is not allowed in the current state."""
# Admins bypass all checks
if user.role == "admin":
return
if team == TeamSide.red:
if user.role not in ("red_tech", "red_lead"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only red_tech, red_lead or admin can upload red evidence",
)
if test.state not in _RED_EDITABLE_STATES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot upload red evidence in '{test.state.value}' state "
f"(allowed in: draft, red_executing)",
)
elif team == TeamSide.blue:
if user.role not in ("blue_tech", "blue_lead"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only blue_tech, blue_lead or admin can upload blue evidence",
)
if test.state not in _BLUE_EDITABLE_STATES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot upload blue evidence in '{test.state.value}' state "
f"(allowed in: blue_evaluating)",
)
def _validate_delete_permission(
test: Test,
evidence: Evidence,
user: User,
) -> None:
"""Raise 403 if the user cannot delete this evidence in the current state."""
# No deletions in review / validated / rejected
if test.state in (TestState.in_review, TestState.validated, TestState.rejected):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Cannot delete evidence when test is in '{test.state.value}' state",
)
# Admin can delete in editable states
if user.role == "admin":
return
ev_team = evidence.team
if ev_team == TeamSide.red:
if test.state not in _RED_EDITABLE_STATES:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Cannot delete red evidence outside draft/red_executing",
)
if user.role not in ("red_tech", "red_lead") and evidence.uploaded_by != user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions to delete this evidence",
)
elif ev_team == TeamSide.blue:
if test.state not in _BLUE_EDITABLE_STATES:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Cannot delete blue evidence outside blue_evaluating",
)
if user.role not in ("blue_tech", "blue_lead") and evidence.uploaded_by != user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions to delete this evidence",
)
# ---------------------------------------------------------------------------
# POST /tests/{test_id}/evidence — upload with team
# ---------------------------------------------------------------------------
@@ -189,36 +90,14 @@ async def upload_evidence(
The ``team`` field (sent as form data) determines whether this is
Red Team (attack) or Blue Team (detection) evidence.
"""
test = db.query(Test).filter(Test.id == test_id).first()
if test is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test not found",
)
test = get_test_or_raise(db, test_id)
validate_upload_permission(test, team, current_user.role)
# Validate permissions
_validate_upload_permission(test, team, current_user)
# 1. Validate file extension
file_name = file.filename or "unnamed"
_, ext = os.path.splitext(file_name)
if ext.lower() not in _ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File type '{ext}' is not allowed. "
f"Permitted types: {', '.join(sorted(_ALLOWED_EXTENSIONS))}",
)
content = await file.read(MAX_UPLOAD_SIZE + 1)
validate_file(file_name, len(content))
# 2. Read content with size limit
content = await file.read(_MAX_UPLOAD_SIZE + 1)
if len(content) > _MAX_UPLOAD_SIZE:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f"File exceeds maximum upload size of "
f"{_MAX_UPLOAD_SIZE // (1024 * 1024)} MB",
)
# 3. Hash
# Hash
sha256 = hashlib.sha256(content).hexdigest()
# 4. Object key (sanitise filename to prevent path traversal in storage)
@@ -273,19 +152,8 @@ def list_evidence(
current_user: User = Depends(get_current_user),
):
"""List all evidences for a test, optionally filtered by team."""
test = db.query(Test).filter(Test.id == test_id).first()
if test is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test not found",
)
query = db.query(Evidence).filter(Evidence.test_id == test_id)
if team:
query = query.filter(Evidence.team == team)
evidences = query.order_by(Evidence.uploaded_at.desc()).all()
get_test_or_raise(db, test_id)
evidences = list_evidence_for_test(db, test_id, team=team)
return [_evidence_to_out(e) for e in evidences]
@@ -301,13 +169,7 @@ def get_evidence(
current_user: User = Depends(get_current_user),
):
"""Return evidence metadata together with a presigned download URL."""
evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
if evidence is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Evidence not found",
)
evidence = get_evidence_or_raise(db, evidence_id)
return _evidence_to_out(evidence)
@@ -329,22 +191,9 @@ def delete_evidence(
- Blue evidence: ``blue_evaluating``
- No deletions in ``in_review``, ``validated``, ``rejected``
"""
evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
if evidence is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Evidence not found",
)
test = db.query(Test).filter(Test.id == evidence.test_id).first()
if test is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Parent test not found",
)
# Permission checks
_validate_delete_permission(test, evidence, current_user)
evidence = get_evidence_or_raise(db, evidence_id)
test = get_test_or_raise(db, evidence.test_id)
validate_delete_permission(test, evidence, current_user.role, current_user.id)
# Audit before deletion
log_action(
+23 -54
View File
@@ -14,7 +14,6 @@ from app.dependencies.auth import get_current_user, require_role
from app.models.user import User
from app.models.technique import Technique
from app.models.threat_actor import ThreatActor
from app.config import settings
from app.services.scoring_service import (
calculate_technique_score,
calculate_tactic_score,
@@ -22,6 +21,10 @@ from app.services.scoring_service import (
calculate_organization_score,
get_score_history,
)
from app.services.scoring_config_service import (
get_weights_dict,
update_scoring_weights,
)
router = APIRouter(prefix="/scores", tags=["scores"])
@@ -117,79 +120,45 @@ def score_history(
@router.get("/config")
def get_scoring_config(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Get current scoring weights (admin only)."""
return {
"weights": {
"tests": settings.SCORING_WEIGHT_TESTS,
"detection_rules": settings.SCORING_WEIGHT_DETECTION_RULES,
"d3fend": settings.SCORING_WEIGHT_D3FEND,
"freshness": settings.SCORING_WEIGHT_FRESHNESS,
"platform_diversity": settings.SCORING_WEIGHT_PLATFORM_DIVERSITY,
},
"total": (
settings.SCORING_WEIGHT_TESTS
+ settings.SCORING_WEIGHT_DETECTION_RULES
+ settings.SCORING_WEIGHT_D3FEND
+ settings.SCORING_WEIGHT_FRESHNESS
+ settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
),
}
return get_weights_dict(db)
# ── PATCH /scores/config ─────────────────────────────────────────────
class ScoringConfigUpdate(BaseModel):
tests: Optional[int] = None
detection_rules: Optional[int] = None
d3fend: Optional[int] = None
freshness: Optional[int] = None
platform_diversity: Optional[int] = None
tests: Optional[float] = None
detection_rules: Optional[float] = None
d3fend: Optional[float] = None
freshness: Optional[float] = None
platform_diversity: Optional[float] = None
@router.patch("/config")
def update_scoring_config(
payload: ScoringConfigUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
"""Update scoring weights (admin only).
Note: Since we're using Opcion A (env vars / Settings), changes
are applied at runtime but won't persist across restarts unless
the .env file is also updated. For production, consider migrating
to Option B (database table).
Weights are persisted in the database and survive restarts.
Validation enforces that all weights are non-negative and sum to 100.
"""
if payload.tests is not None:
settings.SCORING_WEIGHT_TESTS = payload.tests
if payload.detection_rules is not None:
settings.SCORING_WEIGHT_DETECTION_RULES = payload.detection_rules
if payload.d3fend is not None:
settings.SCORING_WEIGHT_D3FEND = payload.d3fend
if payload.freshness is not None:
settings.SCORING_WEIGHT_FRESHNESS = payload.freshness
if payload.platform_diversity is not None:
settings.SCORING_WEIGHT_PLATFORM_DIVERSITY = payload.platform_diversity
result = update_scoring_weights(
db,
tests=payload.tests,
detection_rules=payload.detection_rules,
d3fend=payload.d3fend,
freshness=payload.freshness,
platform_diversity=payload.platform_diversity,
)
# Weights changed — bust the score cache
from app.services.score_cache import invalidate
invalidate()
return {
"message": "Scoring config updated",
"weights": {
"tests": settings.SCORING_WEIGHT_TESTS,
"detection_rules": settings.SCORING_WEIGHT_DETECTION_RULES,
"d3fend": settings.SCORING_WEIGHT_D3FEND,
"freshness": settings.SCORING_WEIGHT_FRESHNESS,
"platform_diversity": settings.SCORING_WEIGHT_PLATFORM_DIVERSITY,
},
"total": (
settings.SCORING_WEIGHT_TESTS
+ settings.SCORING_WEIGHT_DETECTION_RULES
+ settings.SCORING_WEIGHT_D3FEND
+ settings.SCORING_WEIGHT_FRESHNESS
+ settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
),
}
return {"message": "Scoring config updated", **result}
+65 -203
View File
@@ -22,15 +22,11 @@ import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm import Session
from app.database import get_db
from app.dependencies.auth import get_current_user, require_any_role
from app.models.audit import AuditLog
from app.models.enums import TestState, TeamSide
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.enums import TestState
from app.models.user import User
from app.schemas.test import (
TestCreate,
@@ -46,6 +42,18 @@ from app.schemas.test_template import TestTemplateInstantiate
from app.domain.unit_of_work import UnitOfWork
from app.services.audit_service import log_action
from app.services.status_service import recalculate_technique_status
from app.services.test_crud_service import (
create_test as crud_create_test,
create_test_from_template as crud_create_from_template,
get_test_detail as crud_get_test_detail,
get_test_or_raise as crud_get_test_or_raise,
get_test_timeline as crud_get_test_timeline,
get_test_with_technique as crud_get_test_with_technique,
list_tests as crud_list_tests,
update_test as crud_update_test,
update_test_blue as crud_update_test_blue,
update_test_red as crud_update_test_red,
)
from app.services.test_workflow_service import (
start_execution as wf_start_execution,
submit_red_evidence as wf_submit_red,
@@ -62,29 +70,6 @@ from app.services.test_workflow_service import (
router = APIRouter(prefix="/tests", tags=["tests"])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_test_or_404(db: Session, test_id: uuid.UUID) -> Test:
test = db.query(Test).filter(Test.id == test_id).first()
if test is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
return test
def _get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
test = (
db.query(Test)
.options(joinedload(Test.technique))
.filter(Test.id == test_id)
.first()
)
if test is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Test not found")
return test
# ---------------------------------------------------------------------------
# GET /tests — list with filters
# ---------------------------------------------------------------------------
@@ -105,30 +90,16 @@ def list_tests(
current_user: User = Depends(get_current_user),
):
"""Return a paginated list of tests, optionally filtered by state, technique, platform or creator."""
query = db.query(Test).options(joinedload(Test.technique))
if state:
query = query.filter(Test.state == state)
if technique_id:
query = query.filter(Test.technique_id == technique_id)
if platform:
from app.utils import escape_like
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
if created_by:
query = query.filter(Test.created_by == created_by)
if pending_validation_side == "red":
query = query.filter(
Test.state == TestState.in_review,
Test.red_validation_status.in_(["pending", None]),
)
elif pending_validation_side == "blue":
query = query.filter(
Test.state == TestState.in_review,
Test.blue_validation_status.in_(["pending", None]),
)
tests = query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
return tests
return crud_list_tests(
db,
state=state,
technique_id=technique_id,
platform=platform,
created_by=created_by,
pending_validation_side=pending_validation_side,
offset=offset,
limit=limit,
)
# ---------------------------------------------------------------------------
@@ -150,20 +121,14 @@ def create_test(
``created_by`` is set automatically and ``state`` defaults to *draft*.
"""
technique = db.query(Technique).filter(Technique.id == payload.technique_id).first()
if technique is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Technique with id '{payload.technique_id}' not found",
with UnitOfWork(db) as uow:
test = crud_create_test(
db,
technique_id=payload.technique_id,
creator_id=current_user.id,
**payload.model_dump(exclude={"technique_id"}),
)
test = Test(
**payload.model_dump(),
created_by=current_user.id,
state=TestState.draft,
)
db.add(test)
db.commit()
uow.commit()
db.refresh(test)
log_action(
@@ -197,43 +162,14 @@ def create_test_from_template(
The template's fields are copied into the new test as starting data.
"""
template = db.query(TestTemplate).filter(TestTemplate.id == payload.template_id).first()
if template is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"TestTemplate with id '{payload.template_id}' not found",
with UnitOfWork(db) as uow:
test = crud_create_from_template(
db,
template_id=payload.template_id,
technique_id_or_mitre=payload.technique_id,
creator_id=current_user.id,
)
# Resolve technique_id: accept both UUID and MITRE ID (e.g. "T1059.001")
technique = None
try:
technique_uuid = uuid.UUID(payload.technique_id)
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
except ValueError:
pass
if technique is None:
technique = db.query(Technique).filter(Technique.mitre_id == payload.technique_id).first()
if technique is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Technique '{payload.technique_id}' not found",
)
test = Test(
technique_id=technique.id,
name=template.name,
description=template.description,
platform=template.platform,
procedure_text=template.attack_procedure,
tool_used=template.tool_suggested,
remediation_steps=template.suggested_remediation,
created_by=current_user.id,
state=TestState.draft,
)
db.add(test)
db.commit()
uow.commit()
db.refresh(test)
log_action(
@@ -244,7 +180,7 @@ def create_test_from_template(
entity_id=test.id,
details={
"name": test.name,
"template_id": str(template.id),
"template_id": str(payload.template_id),
"technique_id": str(test.technique_id),
},
)
@@ -264,20 +200,7 @@ def get_test(
current_user: User = Depends(get_current_user),
):
"""Return full details for a single test, including its evidences."""
test = (
db.query(Test)
.options(joinedload(Test.evidences))
.filter(Test.id == test_id)
.first()
)
if test is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Test not found",
)
return test
return crud_get_test_detail(db, test_id)
# ---------------------------------------------------------------------------
@@ -297,29 +220,16 @@ def update_test(
Only leads or admins can update general test fields.
The test must be in ``draft`` or ``rejected`` state.
"""
test = _get_test_or_404(db, test_id)
if current_user.role != "admin" and test.created_by != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={"message": "Only the test creator or an admin can update this test", "code": "FORBIDDEN"},
)
if test.state not in (TestState.draft, TestState.rejected):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)",
"code": "INVALID_STATE",
"current_state": test.state.value,
},
)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
db.commit()
with UnitOfWork(db) as uow:
test = crud_update_test(
db,
test_id,
updater_id=current_user.id,
updater_role=current_user.role,
**update_data,
)
uow.commit()
db.refresh(test)
log_action(
@@ -347,23 +257,10 @@ def update_test_red(
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
"""Red Team updates their fields (allowed in ``draft`` and ``red_executing``)."""
test = _get_test_or_404(db, test_id)
if test.state not in (TestState.draft, TestState.red_executing):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": f"Cannot update red fields in '{test.state.value}' state (must be draft or red_executing)",
"code": "INVALID_STATE",
"current_state": test.state.value,
},
)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
db.commit()
with UnitOfWork(db) as uow:
test = crud_update_test_red(db, test_id, **update_data)
uow.commit()
db.refresh(test)
log_action(
@@ -391,23 +288,10 @@ def update_test_blue(
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
"""Blue Team updates their fields (allowed only in ``blue_evaluating``)."""
test = _get_test_or_404(db, test_id)
if test.state != TestState.blue_evaluating:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": f"Cannot update blue fields in '{test.state.value}' state (must be blue_evaluating)",
"code": "INVALID_STATE",
"current_state": test.state.value,
},
)
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(test, field, value)
db.commit()
with UnitOfWork(db) as uow:
test = crud_update_test_blue(db, test_id, **update_data)
uow.commit()
db.refresh(test)
log_action(
@@ -434,7 +318,7 @@ def start_execution(
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
"""Move a test from ``draft`` to ``red_executing``."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_start_execution(db, test, current_user)
uow.commit()
@@ -454,7 +338,7 @@ def submit_red(
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_submit_red(db, test, current_user)
uow.commit()
@@ -474,7 +358,7 @@ def submit_blue(
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_submit_blue(db, test, current_user)
uow.commit()
@@ -494,7 +378,7 @@ def pause_timer(
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
):
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_pause_timer(db, test, current_user)
uow.commit()
@@ -514,7 +398,7 @@ def resume_timer(
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
):
"""Resume the paused timer for the current phase."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_resume_timer(db, test, current_user)
uow.commit()
@@ -535,7 +419,7 @@ def validate_red(
current_user: User = Depends(require_any_role("red_lead")),
):
"""Red Lead approves or rejects the red side of a test."""
test = _get_test_with_technique(db, test_id)
test = crud_get_test_with_technique(db, test_id)
with UnitOfWork(db) as uow:
test = wf_validate_red(
db, test, current_user,
@@ -562,7 +446,7 @@ def validate_blue(
current_user: User = Depends(require_any_role("blue_lead")),
):
"""Blue Lead approves or rejects the blue side of a test."""
test = _get_test_with_technique(db, test_id)
test = crud_get_test_with_technique(db, test_id)
with UnitOfWork(db) as uow:
test = wf_validate_blue(
db, test, current_user,
@@ -588,7 +472,7 @@ def reopen(
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
"""Reopen a rejected test, moving it back to ``draft``."""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
test = wf_reopen(db, test, current_user)
uow.commit()
@@ -613,7 +497,7 @@ def update_remediation(
When ``remediation_status`` transitions to ``'completed'``, an automatic
re-test is created (subject to ``MAX_RETEST_COUNT``).
"""
test = _get_test_or_404(db, test_id)
test = crud_get_test_or_raise(db, test_id)
old_remediation_status = test.remediation_status
@@ -653,29 +537,7 @@ def get_test_timeline(
current_user: User = Depends(get_current_user),
):
"""Return the chronological audit-log history for a test."""
# Verify the test exists
_get_test_or_404(db, test_id)
logs = (
db.query(AuditLog)
.filter(
AuditLog.entity_type == "test",
AuditLog.entity_id == str(test_id),
)
.order_by(AuditLog.timestamp.asc())
.all()
)
return [
{
"id": str(log.id),
"action": log.action,
"user_id": str(log.user_id) if log.user_id else None,
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
"details": log.details,
}
for log in logs
]
return crud_get_test_timeline(db, test_id)
# ---------------------------------------------------------------------------
@@ -0,0 +1,460 @@
"""Campaign CRUD service — list, create, update, and business logic.
Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""
import uuid
from datetime import datetime
from typing import Optional
from sqlalchemy.orm import Session
from app.domain.errors import (
BusinessRuleViolation,
EntityNotFoundError,
PermissionViolation,
)
from app.models.campaign import Campaign, CampaignTest
from app.models.test import Test
from app.models.technique import Technique
from app.utils import escape_like
from app.services.campaign_service import (
get_campaign_progress,
validate_no_circular_dependency,
TACTIC_TO_PHASE,
)
from app.services.campaign_scheduler_service import calculate_next_run
# ── Serialization helpers ────────────────────────────────────────────────
def serialize_campaign(db: Session, campaign: Campaign) -> dict:
"""Serialize a campaign with its tests and progress."""
progress = get_campaign_progress(db, campaign.id)
campaign_tests = (
db.query(CampaignTest)
.filter(CampaignTest.campaign_id == campaign.id)
.order_by(CampaignTest.order_index)
.all()
)
tests = []
for ct in campaign_tests:
test = ct.test
technique = db.query(Technique).filter(Technique.id == test.technique_id).first() if test else None
tests.append({
"id": str(ct.id),
"test_id": str(ct.test_id),
"order_index": ct.order_index,
"depends_on": str(ct.depends_on) if ct.depends_on else None,
"phase": ct.phase,
"test_name": test.name if test else None,
"test_state": test.state.value if test and test.state else None,
"test_result": test.result.value if test and test.result else None,
"technique_mitre_id": technique.mitre_id if technique else None,
"technique_name": technique.name if technique else None,
"platform": test.platform if test else None,
})
actor = campaign.threat_actor
return {
"id": str(campaign.id),
"name": campaign.name,
"description": campaign.description,
"type": campaign.type,
"status": campaign.status,
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
"threat_actor_name": actor.name if actor else None,
"created_by": str(campaign.created_by) if campaign.created_by else None,
"scheduled_at": campaign.scheduled_at.isoformat() if campaign.scheduled_at else None,
"completed_at": campaign.completed_at.isoformat() if campaign.completed_at else None,
"target_platform": campaign.target_platform,
"tags": campaign.tags or [],
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
"is_recurring": campaign.is_recurring or False,
"recurrence_pattern": campaign.recurrence_pattern,
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
"parent_campaign_id": str(campaign.parent_campaign_id) if campaign.parent_campaign_id else None,
"tests": tests,
"progress": progress,
}
def serialize_campaign_summary(db: Session, campaign: Campaign) -> dict:
"""Lightweight campaign serialization for list views."""
progress = get_campaign_progress(db, campaign.id)
actor = campaign.threat_actor
return {
"id": str(campaign.id),
"name": campaign.name,
"description": campaign.description,
"type": campaign.type,
"status": campaign.status,
"threat_actor_id": str(campaign.threat_actor_id) if campaign.threat_actor_id else None,
"threat_actor_name": actor.name if actor else None,
"target_platform": campaign.target_platform,
"tags": campaign.tags or [],
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
"is_recurring": campaign.is_recurring or False,
"recurrence_pattern": campaign.recurrence_pattern,
"next_run_at": campaign.next_run_at.isoformat() if campaign.next_run_at else None,
"last_run_at": campaign.last_run_at.isoformat() if campaign.last_run_at else None,
"test_count": progress["total"],
"completion_pct": progress["completion_pct"],
}
# ── CRUD operations ───────────────────────────────────────────────────────
def list_campaigns(
db: Session,
*,
type: Optional[str] = None,
status: Optional[str] = None,
threat_actor_id: Optional[str] = None,
search: Optional[str] = None,
offset: int = 0,
limit: int = 50,
) -> dict:
"""Return a paginated list of campaigns with optional filters."""
query = db.query(Campaign)
if type:
query = query.filter(Campaign.type == type)
if status:
query = query.filter(Campaign.status == status)
if threat_actor_id:
query = query.filter(Campaign.threat_actor_id == threat_actor_id)
if search:
pattern = f"%{escape_like(search)}%"
query = query.filter(Campaign.name.ilike(pattern) | Campaign.description.ilike(pattern))
total = query.count()
campaigns = query.order_by(Campaign.created_at.desc()).offset(offset).limit(limit).all()
return {
"total": total,
"offset": offset,
"limit": limit,
"items": [serialize_campaign_summary(db, c) for c in campaigns],
}
def create_campaign(
db: Session,
*,
creator_id: uuid.UUID,
name: str,
description: Optional[str] = None,
type: str = "custom",
threat_actor_id: Optional[str] = None,
target_platform: Optional[str] = None,
tags: Optional[list[str]] = None,
scheduled_at: Optional[str] = None,
) -> dict:
"""Create a new campaign. Does not commit; caller commits."""
campaign = Campaign(
name=name,
description=description,
type=type,
threat_actor_id=uuid.UUID(threat_actor_id) if threat_actor_id else None,
target_platform=target_platform,
tags=tags or [],
created_by=creator_id,
scheduled_at=datetime.fromisoformat(scheduled_at) if scheduled_at else None,
)
db.add(campaign)
db.flush()
return serialize_campaign(db, campaign)
def get_campaign_detail(db: Session, campaign_id: str) -> dict:
"""Get detailed campaign info including tests and progress.
Raises EntityNotFoundError if campaign not found.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
return serialize_campaign(db, campaign)
def update_campaign(
db: Session,
campaign_id: str,
*,
updater_id: uuid.UUID,
updater_role: str,
**fields,
) -> dict:
"""Update a campaign. Only allowed in draft or active state.
Raises EntityNotFoundError, BusinessRuleViolation, or PermissionViolation.
Does not commit; caller commits.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
if campaign.status not in ("draft", "active"):
raise BusinessRuleViolation("Can only update draft or active campaigns")
if str(campaign.created_by) != str(updater_id) and updater_role != "admin":
raise PermissionViolation("Only the creator or admin can update this campaign")
if "scheduled_at" in fields and fields["scheduled_at"]:
fields["scheduled_at"] = datetime.fromisoformat(fields["scheduled_at"])
for field, value in fields.items():
setattr(campaign, field, value)
db.flush()
return serialize_campaign(db, campaign)
def add_test_to_campaign(
db: Session,
campaign_id: str,
*,
test_id: str,
order_index: Optional[int] = None,
depends_on: Optional[str] = None,
phase: Optional[str] = None,
) -> dict:
"""Add a test to a campaign with optional ordering and dependency.
Raises EntityNotFoundError for missing campaign or test.
Raises BusinessRuleViolation for invalid state or circular dependency.
Does not commit; caller commits.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
if campaign.status not in ("draft", "active"):
raise BusinessRuleViolation("Can only add tests to draft or active campaigns")
test = db.query(Test).filter(Test.id == test_id).first()
if not test:
raise EntityNotFoundError("Test", test_id)
if order_index is not None:
final_order_index = order_index
else:
max_order = (
db.query(CampaignTest.order_index)
.filter(CampaignTest.campaign_id == campaign_id)
.order_by(CampaignTest.order_index.desc())
.first()
)
final_order_index = (max_order[0] + 1) if max_order else 0
depends_on_uuid = uuid.UUID(depends_on) if depends_on else None
ct_id = uuid.uuid4()
if depends_on_uuid:
validate_no_circular_dependency(db, uuid.UUID(campaign_id), ct_id, depends_on_uuid)
if not phase and test.technique_id:
technique = db.query(Technique).filter(Technique.id == test.technique_id).first()
if technique and technique.tactic:
phase = TACTIC_TO_PHASE.get(technique.tactic, None)
campaign_test = CampaignTest(
id=ct_id,
campaign_id=campaign_id,
test_id=test_id,
order_index=final_order_index,
depends_on=depends_on_uuid,
phase=phase,
)
db.add(campaign_test)
db.flush()
return {
"id": str(campaign_test.id),
"campaign_id": str(campaign_test.campaign_id),
"test_id": str(campaign_test.test_id),
"order_index": campaign_test.order_index,
"depends_on": str(campaign_test.depends_on) if campaign_test.depends_on else None,
"phase": campaign_test.phase,
}
def remove_test_from_campaign(db: Session, campaign_id: str, campaign_test_id: str) -> None:
"""Remove a test from a campaign.
Raises EntityNotFoundError for missing campaign or campaign test.
Raises BusinessRuleViolation for invalid state.
Does not commit; caller commits.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
if campaign.status not in ("draft", "active"):
raise BusinessRuleViolation("Can only modify draft or active campaigns")
ct = (
db.query(CampaignTest)
.filter(
CampaignTest.id == campaign_test_id,
CampaignTest.campaign_id == campaign_id,
)
.first()
)
if not ct:
raise EntityNotFoundError("CampaignTest", campaign_test_id)
dep_id = uuid.UUID(campaign_test_id)
dependents = db.query(CampaignTest).filter(CampaignTest.depends_on == dep_id).all()
for dep in dependents:
dep.depends_on = None
db.delete(ct)
db.flush()
def activate_campaign(db: Session, campaign_id: str) -> Campaign:
"""Activate a campaign, moving it from draft to active.
Raises EntityNotFoundError, BusinessRuleViolation.
Does not commit; caller commits.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
if campaign.status != "draft":
raise BusinessRuleViolation("Only draft campaigns can be activated")
test_count = db.query(CampaignTest).filter(CampaignTest.campaign_id == campaign_id).count()
if test_count == 0:
raise BusinessRuleViolation("Campaign must have at least one test to activate")
campaign.status = "active"
db.flush()
return campaign
def complete_campaign(db: Session, campaign_id: str) -> Campaign:
"""Mark a campaign as completed.
Raises EntityNotFoundError, BusinessRuleViolation.
Does not commit; caller commits.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
if campaign.status != "active":
raise BusinessRuleViolation("Only active campaigns can be completed")
campaign.status = "completed"
campaign.completed_at = datetime.utcnow()
db.flush()
return campaign
def get_campaign_progress_data(db: Session, campaign_id: str) -> dict:
"""Get progress statistics for a campaign.
Raises EntityNotFoundError if campaign not found.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
progress = get_campaign_progress(db, uuid.UUID(campaign_id))
return {
"campaign_id": str(campaign.id),
"campaign_name": campaign.name,
**progress,
}
def schedule_campaign(
db: Session,
campaign_id: str,
*,
owner_id: uuid.UUID,
owner_role: str,
is_recurring: bool,
recurrence_pattern: Optional[str] = None,
next_run_at: Optional[str] = None,
) -> Campaign:
"""Configure or update the recurrence schedule for a campaign.
Raises EntityNotFoundError, PermissionViolation, BusinessRuleViolation.
Does not commit; caller commits.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
if str(campaign.created_by) != str(owner_id) and owner_role != "admin":
raise PermissionViolation("Only the creator or admin can configure scheduling")
campaign.is_recurring = is_recurring
if is_recurring:
if recurrence_pattern not in ("weekly", "monthly", "quarterly"):
raise BusinessRuleViolation(
"recurrence_pattern must be 'weekly', 'monthly', or 'quarterly'"
)
campaign.recurrence_pattern = recurrence_pattern
if next_run_at:
campaign.next_run_at = datetime.fromisoformat(
next_run_at.replace("Z", "+00:00").replace("+00:00", "")
)
elif not campaign.next_run_at:
campaign.next_run_at = calculate_next_run(datetime.utcnow(), recurrence_pattern)
else:
campaign.recurrence_pattern = None
campaign.next_run_at = None
db.flush()
return campaign
def get_campaign_history(db: Session, campaign_id: str) -> dict:
"""List all child campaigns (execution history) of a recurring campaign.
Raises EntityNotFoundError if campaign not found.
"""
campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first()
if not campaign:
raise EntityNotFoundError("Campaign", campaign_id)
campaign_uuid = uuid.UUID(campaign_id)
children = (
db.query(Campaign)
.filter(Campaign.parent_campaign_id == campaign_uuid)
.order_by(Campaign.created_at.desc())
.all()
)
return {
"campaign_id": str(campaign.id),
"campaign_name": campaign.name,
"items": [
{
"id": str(child.id),
"name": child.name,
"status": child.status,
"test_count": db.query(CampaignTest).filter(CampaignTest.campaign_id == child.id).count(),
"completion_pct": get_campaign_progress(db, child.id)["completion_pct"],
"created_at": child.created_at.isoformat() if child.created_at else None,
"completed_at": child.completed_at.isoformat() if child.completed_at else None,
}
for child in children
],
}
+167
View File
@@ -0,0 +1,167 @@
"""Evidence service — permission validation, file validation, and query logic.
Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, file I/O, MinIO upload,
audit logging, and response formatting.
"""
from __future__ import annotations
import os
import uuid
from sqlalchemy.orm import Session
from app.domain.errors import (
BusinessRuleViolation,
EntityNotFoundError,
PermissionViolation,
)
from app.models.enums import TeamSide, TestState
from app.models.evidence import Evidence
from app.models.test import Test
# States where red evidence can be uploaded / deleted
RED_EDITABLE_STATES = (TestState.draft, TestState.red_executing)
# States where blue evidence can be uploaded / deleted
BLUE_EDITABLE_STATES = (TestState.blue_evaluating,)
# Maximum upload size in bytes (50 MB)
MAX_UPLOAD_SIZE = 50 * 1024 * 1024
# Allowed file extensions (lowercase, with leading dot)
ALLOWED_EXTENSIONS: frozenset[str] = frozenset({
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg",
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".csv", ".txt",
".md", ".rtf", ".odt", ".ods",
".log", ".pcap", ".pcapng", ".evtx", ".json", ".xml",
".yaml", ".yml", ".toml",
".zip", ".tar", ".gz", ".7z",
".har", ".eml", ".msg",
})
def validate_upload_permission(
test: Test,
team: TeamSide,
user_role: str,
) -> None:
"""Validate that the user can upload evidence for the given team in the current state.
Raises:
PermissionViolation: If user lacks role to upload for this team.
BusinessRuleViolation: If test state does not allow uploading for this team.
"""
if user_role == "admin":
return
if team == TeamSide.red:
if user_role not in ("red_tech", "red_lead"):
raise PermissionViolation(
"Only red_tech, red_lead or admin can upload red evidence"
)
if test.state not in RED_EDITABLE_STATES:
raise BusinessRuleViolation(
f"Cannot upload red evidence in '{test.state.value}' state "
"(allowed in: draft, red_executing)"
)
elif team == TeamSide.blue:
if user_role not in ("blue_tech", "blue_lead"):
raise PermissionViolation(
"Only blue_tech, blue_lead or admin can upload blue evidence"
)
if test.state not in BLUE_EDITABLE_STATES:
raise BusinessRuleViolation(
f"Cannot upload blue evidence in '{test.state.value}' state "
"(allowed in: blue_evaluating)"
)
def validate_delete_permission(
test: Test,
evidence: Evidence,
user_role: str,
user_id: uuid.UUID,
) -> None:
"""Validate that the user can delete this evidence in the current state.
Raises:
PermissionViolation: If user cannot delete in this state or lacks permission.
"""
if test.state in (TestState.in_review, TestState.validated, TestState.rejected):
raise PermissionViolation(
f"Cannot delete evidence when test is in '{test.state.value}' state"
)
if user_role == "admin":
return
ev_team = evidence.team
if ev_team == TeamSide.red:
if test.state not in RED_EDITABLE_STATES:
raise PermissionViolation(
"Cannot delete red evidence outside draft/red_executing"
)
if user_role not in ("red_tech", "red_lead") and evidence.uploaded_by != user_id:
raise PermissionViolation(
"Not enough permissions to delete this evidence"
)
elif ev_team == TeamSide.blue:
if test.state not in BLUE_EDITABLE_STATES:
raise PermissionViolation(
"Cannot delete blue evidence outside blue_evaluating"
)
if user_role not in ("blue_tech", "blue_lead") and evidence.uploaded_by != user_id:
raise PermissionViolation(
"Not enough permissions to delete this evidence"
)
def validate_file(file_name: str, content_size: int) -> None:
"""Validate file extension and size.
Raises:
BusinessRuleViolation: If extension is not allowed or file exceeds size limit.
"""
_, ext = os.path.splitext(file_name)
ext_lower = ext.lower() if ext else ""
if ext_lower not in ALLOWED_EXTENSIONS:
raise BusinessRuleViolation(
f"File type '{ext}' is not allowed. "
f"Permitted types: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
)
if content_size > MAX_UPLOAD_SIZE:
raise BusinessRuleViolation(
f"File exceeds maximum upload size of {MAX_UPLOAD_SIZE // (1024 * 1024)} MB"
)
def list_evidence_for_test(
db: Session,
test_id: uuid.UUID,
*,
team: TeamSide | str | None = None,
) -> list[Evidence]:
"""Return evidence for a test, optionally filtered by team."""
query = db.query(Evidence).filter(Evidence.test_id == test_id)
if team is not None:
team_enum = TeamSide(team) if isinstance(team, str) else team
query = query.filter(Evidence.team == team_enum)
return query.order_by(Evidence.uploaded_at.desc()).all()
def get_evidence_or_raise(db: Session, evidence_id: uuid.UUID) -> Evidence:
"""Fetch evidence by ID. Raises EntityNotFoundError if not found."""
evidence = db.query(Evidence).filter(Evidence.id == evidence_id).first()
if evidence is None:
raise EntityNotFoundError("Evidence", str(evidence_id))
return evidence
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
"""Fetch test by ID. Raises EntityNotFoundError if not found."""
test = db.query(Test).filter(Test.id == test_id).first()
if test is None:
raise EntityNotFoundError("Test", str(test_id))
return test
@@ -0,0 +1,107 @@
"""Scoring configuration persistence service.
Reads and writes scoring weights from the ``scoring_config`` table.
Falls back to environment-variable defaults (from ``Settings``) when
no row has been persisted yet.
This module is framework-agnostic: no FastAPI imports.
"""
from __future__ import annotations
from typing import Any
from sqlalchemy.orm import Session
from app.config import settings
from app.domain.value_objects.scoring_weights import ScoringWeights
from app.models.scoring_config import ScoringConfig
def get_scoring_weights(db: Session) -> ScoringWeights:
"""Return the active scoring weights.
Reads the single ``scoring_config`` row. If the table is empty
(first run or migration just applied), falls back to the values
from the environment / ``Settings``.
"""
row = db.query(ScoringConfig).first()
if row is not None:
return ScoringWeights(
tests=row.weight_tests,
detection_rules=row.weight_detection_rules,
d3fend=row.weight_d3fend,
freshness=row.weight_freshness,
platform_diversity=row.weight_platform_diversity,
)
return ScoringWeights(
tests=float(settings.SCORING_WEIGHT_TESTS),
detection_rules=float(settings.SCORING_WEIGHT_DETECTION_RULES),
d3fend=float(settings.SCORING_WEIGHT_D3FEND),
freshness=float(settings.SCORING_WEIGHT_FRESHNESS),
platform_diversity=float(settings.SCORING_WEIGHT_PLATFORM_DIVERSITY),
)
def update_scoring_weights(
db: Session,
*,
tests: float | None = None,
detection_rules: float | None = None,
d3fend: float | None = None,
freshness: float | None = None,
platform_diversity: float | None = None,
) -> dict[str, Any]:
"""Upsert scoring weights into the database.
Only provided fields are overwritten; ``None`` values keep the
current (or default) value. Validates via ``ScoringWeights``
before persisting.
Returns a dict with ``weights`` and ``total``.
"""
current = get_scoring_weights(db)
new = ScoringWeights(
tests=tests if tests is not None else current.tests,
detection_rules=detection_rules if detection_rules is not None else current.detection_rules,
d3fend=d3fend if d3fend is not None else current.d3fend,
freshness=freshness if freshness is not None else current.freshness,
platform_diversity=platform_diversity if platform_diversity is not None else current.platform_diversity,
)
row = db.query(ScoringConfig).first()
if row is None:
row = ScoringConfig()
db.add(row)
row.weight_tests = new.tests
row.weight_detection_rules = new.detection_rules
row.weight_d3fend = new.d3fend
row.weight_freshness = new.freshness
row.weight_platform_diversity = new.platform_diversity
db.commit()
db.refresh(row)
return _weights_dict(new)
def get_weights_dict(db: Session) -> dict[str, Any]:
"""Return current weights as a serialisable dict."""
return _weights_dict(get_scoring_weights(db))
def _weights_dict(w: ScoringWeights) -> dict[str, Any]:
weights = {
"tests": w.tests,
"detection_rules": w.detection_rules,
"d3fend": w.d3fend,
"freshness": w.freshness,
"platform_diversity": w.platform_diversity,
}
return {
"weights": weights,
"total": sum(weights.values()),
}
+17 -28
View File
@@ -1,7 +1,8 @@
"""Scoring service — granular 0-100 scoring for techniques, tactics, actors, and org.
Uses configurable weights from Settings to compute coverage scores with
detailed breakdowns.
Reads configurable weights from the ``scoring_config`` table (falling
back to env-var defaults) to compute coverage scores with detailed
breakdowns.
Bulk helpers (``bulk_technique_scores``) pre-fetch all scoring data in a
fixed number of aggregated queries so that organisation-wide calculations
@@ -14,7 +15,6 @@ from typing import Optional
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from app.config import settings
from app.models.technique import Technique
from app.models.test import Test
from app.models.detection_rule import DetectionRule
@@ -22,20 +22,12 @@ from app.models.test_detection_result import TestDetectionResult
from app.models.defensive_technique import DefensiveTechniqueMapping
from app.models.threat_actor import ThreatActor, ThreatActorTechnique
from app.models.enums import TestState, TestResult
from app.services.scoring_config_service import get_scoring_weights
# ── Bulk scoring helpers (5 queries for ALL techniques) ───────────────
def _build_empty_stats():
return {
"validated": 0,
"detected": 0,
"platforms": set(),
"latest_validated_at": None,
}
def bulk_technique_scores(db: Session) -> dict:
"""Pre-fetch all scoring data and compute per-technique scores in memory.
@@ -48,11 +40,12 @@ def bulk_technique_scores(db: Session) -> dict:
Returns ``{technique_id: {"total_score": float, "breakdown": dict}}``.
"""
w_tests = settings.SCORING_WEIGHT_TESTS
w_detection = settings.SCORING_WEIGHT_DETECTION_RULES
w_d3fend = settings.SCORING_WEIGHT_D3FEND
w_freshness = settings.SCORING_WEIGHT_FRESHNESS
w_diversity = settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
w = get_scoring_weights(db)
w_tests = w.tests
w_detection = w.detection_rules
w_d3fend = w.d3fend
w_freshness = w.freshness
w_diversity = w.platform_diversity
# Q1: test stats grouped by technique_id
test_rows = (
@@ -242,18 +235,14 @@ def bulk_technique_scores(db: Session) -> dict:
def calculate_technique_score(technique: Technique, db: Session) -> dict:
"""Calculate a 0-100 score for a technique with detailed breakdown.
Weights (configurable via settings):
- tests_validated: weight from SCORING_WEIGHT_TESTS
- detection_rules: weight from SCORING_WEIGHT_DETECTION_RULES
- d3fend_coverage: weight from SCORING_WEIGHT_D3FEND
- freshness: weight from SCORING_WEIGHT_FRESHNESS
- platform_diversity: weight from SCORING_WEIGHT_PLATFORM_DIVERSITY
Weights are read from the ``scoring_config`` table (or env defaults).
"""
w_tests = settings.SCORING_WEIGHT_TESTS
w_detection = settings.SCORING_WEIGHT_DETECTION_RULES
w_d3fend = settings.SCORING_WEIGHT_D3FEND
w_freshness = settings.SCORING_WEIGHT_FRESHNESS
w_diversity = settings.SCORING_WEIGHT_PLATFORM_DIVERSITY
w = get_scoring_weights(db)
w_tests = w.tests
w_detection = w.detection_rules
w_d3fend = w.d3fend
w_freshness = w.freshness
w_diversity = w.platform_diversity
breakdown = {}
+277
View File
@@ -0,0 +1,277 @@
"""Test CRUD service — list, create, update, and query logic for security tests.
Framework-agnostic; uses domain exceptions from app.domain.errors.
The router is responsible for HTTP concerns, auth, audit logging, and commit.
"""
import uuid
from typing import Any
from sqlalchemy.orm import Session, joinedload
from app.domain.errors import (
BusinessRuleViolation,
EntityNotFoundError,
PermissionViolation,
)
from app.models.enums import TestState
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.audit import AuditLog
from app.utils import escape_like
def list_tests(
db: Session,
*,
state: str | None = None,
technique_id: uuid.UUID | None = None,
platform: str | None = None,
created_by: uuid.UUID | None = None,
pending_validation_side: str | None = None,
offset: int = 0,
limit: int = 50,
) -> list[Test]:
"""Return a paginated list of tests with optional filters."""
query = db.query(Test).options(joinedload(Test.technique))
if state:
query = query.filter(Test.state == state)
if technique_id:
query = query.filter(Test.technique_id == technique_id)
if platform:
query = query.filter(Test.platform.ilike(f"%{escape_like(platform)}%"))
if created_by:
query = query.filter(Test.created_by == created_by)
if pending_validation_side == "red":
query = query.filter(
Test.state == TestState.in_review,
Test.red_validation_status.in_(["pending", None]),
)
elif pending_validation_side == "blue":
query = query.filter(
Test.state == TestState.in_review,
Test.blue_validation_status.in_(["pending", None]),
)
return query.order_by(Test.created_at.desc()).offset(offset).limit(limit).all()
def create_test(
db: Session,
*,
technique_id: uuid.UUID,
creator_id: uuid.UUID,
**fields: Any,
) -> Test:
"""Create a new test linked to an existing technique.
Raises EntityNotFoundError if the technique does not exist.
Does not commit; caller uses UnitOfWork.
"""
technique = db.query(Technique).filter(Technique.id == technique_id).first()
if technique is None:
raise EntityNotFoundError("Technique", str(technique_id))
test = Test(
technique_id=technique_id,
created_by=creator_id,
state=TestState.draft,
**fields,
)
db.add(test)
db.flush()
return test
def create_test_from_template(
db: Session,
*,
template_id: uuid.UUID,
technique_id_or_mitre: str,
creator_id: uuid.UUID,
) -> Test:
"""Instantiate a Test from a TestTemplate.
technique_id_or_mitre can be a UUID string or MITRE ID (e.g. T1059.001).
Raises EntityNotFoundError if template or technique not found.
Does not commit; caller uses UnitOfWork.
"""
template = db.query(TestTemplate).filter(TestTemplate.id == template_id).first()
if template is None:
raise EntityNotFoundError("TestTemplate", str(template_id))
technique = None
try:
technique_uuid = uuid.UUID(technique_id_or_mitre)
technique = db.query(Technique).filter(Technique.id == technique_uuid).first()
except ValueError:
pass
if technique is None:
technique = db.query(Technique).filter(
Technique.mitre_id == technique_id_or_mitre
).first()
if technique is None:
raise EntityNotFoundError("Technique", technique_id_or_mitre)
test = Test(
technique_id=technique.id,
name=template.name,
description=template.description,
platform=template.platform,
procedure_text=template.attack_procedure,
tool_used=template.tool_suggested,
remediation_steps=template.suggested_remediation,
created_by=creator_id,
state=TestState.draft,
)
db.add(test)
db.flush()
return test
def get_test_detail(db: Session, test_id: uuid.UUID) -> Test:
"""Fetch a test with evidences eager-loaded.
Raises EntityNotFoundError if the test does not exist.
"""
test = (
db.query(Test)
.options(joinedload(Test.evidences))
.filter(Test.id == test_id)
.first()
)
if test is None:
raise EntityNotFoundError("Test", str(test_id))
return test
def get_test_or_raise(db: Session, test_id: uuid.UUID) -> Test:
"""Fetch a test by ID. Raises EntityNotFoundError if not found."""
test = db.query(Test).filter(Test.id == test_id).first()
if test is None:
raise EntityNotFoundError("Test", str(test_id))
return test
def get_test_with_technique(db: Session, test_id: uuid.UUID) -> Test:
"""Fetch a test with technique joined. Raises EntityNotFoundError if not found."""
test = (
db.query(Test)
.options(joinedload(Test.technique))
.filter(Test.id == test_id)
.first()
)
if test is None:
raise EntityNotFoundError("Test", str(test_id))
return test
def update_test(
db: Session,
test_id: uuid.UUID,
*,
updater_id: uuid.UUID,
updater_role: str,
**fields: Any,
) -> Test:
"""Update general test fields (draft or rejected only).
Raises PermissionViolation if not creator or admin.
Raises BusinessRuleViolation if state is not draft or rejected.
Raises EntityNotFoundError if test not found.
Does not commit; caller uses UnitOfWork.
"""
test = get_test_or_raise(db, test_id)
if updater_role != "admin" and test.created_by != updater_id:
raise PermissionViolation(
"Only the test creator or an admin can update this test"
)
if test.state not in (TestState.draft, TestState.rejected):
raise BusinessRuleViolation(
f"Cannot update a test in '{test.state.value}' state (must be draft or rejected)"
)
for field, value in fields.items():
setattr(test, field, value)
db.flush()
return test
def update_test_red(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
"""Update Red Team fields (draft or red_executing only).
Raises BusinessRuleViolation if state not in (draft, red_executing).
Raises EntityNotFoundError if test not found.
Does not commit; caller uses UnitOfWork.
"""
test = get_test_or_raise(db, test_id)
if test.state not in (TestState.draft, TestState.red_executing):
raise BusinessRuleViolation(
f"Cannot update red fields in '{test.state.value}' state "
"(must be draft or red_executing)"
)
for field, value in fields.items():
setattr(test, field, value)
db.flush()
return test
def update_test_blue(db: Session, test_id: uuid.UUID, **fields: Any) -> Test:
"""Update Blue Team fields (blue_evaluating only).
Raises BusinessRuleViolation if state is not blue_evaluating.
Raises EntityNotFoundError if test not found.
Does not commit; caller uses UnitOfWork.
"""
test = get_test_or_raise(db, test_id)
if test.state != TestState.blue_evaluating:
raise BusinessRuleViolation(
f"Cannot update blue fields in '{test.state.value}' state "
"(must be blue_evaluating)"
)
for field, value in fields.items():
setattr(test, field, value)
db.flush()
return test
def get_test_timeline(db: Session, test_id: uuid.UUID) -> list[dict[str, Any]]:
"""Return chronological audit-log history for a test.
Raises EntityNotFoundError if the test does not exist.
"""
get_test_or_raise(db, test_id)
logs = (
db.query(AuditLog)
.filter(
AuditLog.entity_type == "test",
AuditLog.entity_id == str(test_id),
)
.order_by(AuditLog.timestamp.asc())
.all()
)
return [
{
"id": str(log.id),
"action": log.action,
"user_id": str(log.user_id) if log.user_id else None,
"timestamp": log.timestamp.isoformat() if log.timestamp else None,
"details": log.details,
}
for log in logs
]
+23 -23
View File
@@ -89,12 +89,12 @@ for mod_name in [
# Imports
# ---------------------------------------------------------------------------
from fastapi import HTTPException
from app.domain.errors import PermissionViolation
from app.models.enums import TeamSide, TestState
from app.routers.evidence import (
router,
_validate_upload_permission,
_validate_delete_permission,
from app.routers.evidence import router
from app.services.evidence_service import (
validate_delete_permission,
validate_upload_permission,
)
@@ -132,7 +132,7 @@ def test_red_tech_upload_red_in_red_executing():
test = _make_test(TestState.red_executing)
user = _make_user("red_tech")
# Should not raise
_validate_upload_permission(test, TeamSide.red, user)
validate_upload_permission(test, TeamSide.red, user.role)
print(" [PASS] red_tech can upload team=red in red_executing")
@@ -144,7 +144,7 @@ def test_red_tech_upload_red_in_red_executing():
def test_red_tech_upload_red_in_draft():
test = _make_test(TestState.draft)
user = _make_user("red_tech")
_validate_upload_permission(test, TeamSide.red, user)
validate_upload_permission(test, TeamSide.red, user.role)
print(" [PASS] red_tech can upload team=red in draft")
@@ -157,10 +157,10 @@ def test_red_tech_cannot_upload_blue():
test = _make_test(TestState.red_executing)
user = _make_user("red_tech")
try:
_validate_upload_permission(test, TeamSide.blue, user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 403
validate_upload_permission(test, TeamSide.blue, user.role)
assert False, "Should have raised PermissionViolation"
except PermissionViolation:
pass
print(" [PASS] red_tech CANNOT upload team=blue (403)")
@@ -172,7 +172,7 @@ def test_red_tech_cannot_upload_blue():
def test_blue_tech_upload_blue_in_blue_evaluating():
test = _make_test(TestState.blue_evaluating)
user = _make_user("blue_tech")
_validate_upload_permission(test, TeamSide.blue, user)
validate_upload_permission(test, TeamSide.blue, user.role)
print(" [PASS] blue_tech can upload team=blue in blue_evaluating")
@@ -185,10 +185,10 @@ def test_blue_tech_cannot_upload_red():
test = _make_test(TestState.blue_evaluating)
user = _make_user("blue_tech")
try:
_validate_upload_permission(test, TeamSide.red, user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 403
validate_upload_permission(test, TeamSide.red, user.role)
assert False, "Should have raised PermissionViolation"
except PermissionViolation:
pass
print(" [PASS] blue_tech CANNOT upload team=red (403)")
@@ -223,10 +223,10 @@ def test_delete_in_review_fails():
user = _make_user("red_tech")
evidence = _make_evidence(TeamSide.red, uploaded_by=user.id)
try:
_validate_delete_permission(test, evidence, user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 403
validate_delete_permission(test, evidence, user.role, user.id)
assert False, "Should have raised PermissionViolation"
except PermissionViolation:
pass
print(" [PASS] DELETE in in_review -> 403")
@@ -240,7 +240,7 @@ def test_delete_red_evidence_in_red_executing():
user = _make_user("red_tech")
evidence = _make_evidence(TeamSide.red, uploaded_by=user.id)
# Should not raise
_validate_delete_permission(test, evidence, user)
validate_delete_permission(test, evidence, user.role, user.id)
print(" [PASS] DELETE red evidence in red_executing -> allowed")
@@ -254,11 +254,11 @@ def test_admin_bypass():
# Red in blue_evaluating (normally blocked)
test1 = _make_test(TestState.blue_evaluating)
_validate_upload_permission(test1, TeamSide.red, admin)
validate_upload_permission(test1, TeamSide.red, admin.role)
# Blue in draft (normally blocked)
test2 = _make_test(TestState.draft)
_validate_upload_permission(test2, TeamSide.blue, admin)
validate_upload_permission(test2, TeamSide.blue, admin.role)
print(" [PASS] Admin can upload any team in any state")
+3 -2
View File
@@ -101,7 +101,7 @@ from app.routers.test_templates import (
toggle_template_active,
template_stats,
)
from app.routers.tests import create_test_from_template
from app.services.test_crud_service import create_test_from_template as crud_create_from_template
from app.schemas.test_template import TestTemplateCreate
@@ -174,7 +174,8 @@ def test_get_templates_by_technique():
def test_instantiate_template():
"""POST /tests/from-template creates a test pre-filled from template data."""
source = inspect.getsource(create_test_from_template)
# Template field copying lives in the service; router delegates to it
source = inspect.getsource(crud_create_from_template)
# Verify it reads from template and copies fields
assert "template" in source, "Must reference template"
+30 -29
View File
@@ -419,56 +419,57 @@ def test_dual_validation_red_approves_blue_rejects(mock_log):
def test_evidence_team_separation():
"""Verify evidence router logic separates red and blue evidence correctly."""
from app.routers.evidence import _validate_upload_permission, _RED_EDITABLE_STATES, _BLUE_EDITABLE_STATES
from app.domain.errors import BusinessRuleViolation, PermissionViolation
from app.models.enums import TeamSide
from app.services.evidence_service import validate_upload_permission
# Red tech can upload red evidence in draft
test = _make_test(TestState.draft)
red_user = _make_user("red_tech")
red_user.role = "red_tech"
from app.models.enums import TeamSide
_validate_upload_permission(test, TeamSide.red, red_user) # should not raise
validate_upload_permission(test, TeamSide.red, red_user.role) # should not raise
# Red tech can upload red evidence in red_executing
test.state = TestState.red_executing
_validate_upload_permission(test, TeamSide.red, red_user) # should not raise
validate_upload_permission(test, TeamSide.red, red_user.role) # should not raise
# Red tech CANNOT upload red evidence in blue_evaluating
# Red tech CANNOT upload red evidence in blue_evaluating (state violation -> 400)
test.state = TestState.blue_evaluating
try:
_validate_upload_permission(test, TeamSide.red, red_user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
validate_upload_permission(test, TeamSide.red, red_user.role)
assert False, "Should have raised BusinessRuleViolation"
except BusinessRuleViolation:
pass
# Red tech CANNOT upload blue evidence
# Red tech CANNOT upload blue evidence (role violation -> 403)
test.state = TestState.blue_evaluating
try:
_validate_upload_permission(test, TeamSide.blue, red_user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 403
validate_upload_permission(test, TeamSide.blue, red_user.role)
assert False, "Should have raised PermissionViolation"
except PermissionViolation:
pass
# Blue tech can upload blue evidence in blue_evaluating
test.state = TestState.blue_evaluating
blue_user = _make_user("blue_tech")
blue_user.role = "blue_tech"
_validate_upload_permission(test, TeamSide.blue, blue_user) # should not raise
validate_upload_permission(test, TeamSide.blue, blue_user.role) # should not raise
# Blue tech CANNOT upload blue evidence in draft
# Blue tech CANNOT upload blue evidence in draft (state violation -> 400)
test.state = TestState.draft
try:
_validate_upload_permission(test, TeamSide.blue, blue_user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 400
validate_upload_permission(test, TeamSide.blue, blue_user.role)
assert False, "Should have raised BusinessRuleViolation"
except BusinessRuleViolation:
pass
# Blue tech CANNOT upload red evidence
# Blue tech CANNOT upload red evidence (role violation -> 403)
test.state = TestState.draft
try:
_validate_upload_permission(test, TeamSide.red, blue_user)
assert False, "Should have raised HTTPException"
except HTTPException as exc:
assert exc.status_code == 403
validate_upload_permission(test, TeamSide.red, blue_user.role)
assert False, "Should have raised PermissionViolation"
except PermissionViolation:
pass
# ===========================================================================
@@ -477,15 +478,15 @@ def test_evidence_team_separation():
def test_red_edit_allowed_in_draft_and_red_executing():
"""Verify the red update router checks that state is draft or red_executing."""
from app.routers.tests import update_test_red
"""Verify the red update checks that state is draft or red_executing."""
from app.services.test_crud_service import update_test_red
import inspect
source = inspect.getsource(update_test_red)
# The function must guard against states other than draft/red_executing
# The service must guard against states other than draft/red_executing
assert "draft" in source, "Red update must allow draft state"
assert "red_executing" in source, "Red update must allow red_executing state"
assert "400" in source or "HTTP_400_BAD_REQUEST" in source, "Red update must return 400 for invalid state"
assert "BusinessRuleViolation" in source, "Must raise domain exception for invalid state (mapped to 400)"
# ===========================================================================
+15 -15
View File
@@ -2,30 +2,30 @@
## Tier 1 — Quick Wins
- [ ] QW-1: Wire existing repos into `techniques.py` router
- [ ] QW-2: Fix `audit_service` to follow UoW (no direct `db.commit()`)
- [ ] QW-3: Consolidate `status_service` with `TechniqueEntity.recalculate_status()`
- [ ] QW-4: Remove remaining `HTTPException` from services
- [x] QW-1: Wire existing repos into `techniques.py` router
- [~] QW-2: Fix `audit_service` to follow UoW — deferred, resolves naturally as routers adopt UoW
- [x] QW-3: Consolidate `status_service` with `TechniqueEntity.recalculate_status()`
- [x] QW-4: Remove remaining `HTTPException` from services — already resolved
## Tier 2 — Service Extraction (fat routers → thin routers + services)
- [ ] SE-1: Extract reports service from `reports.py`
- [ ] SE-2: Extract metrics service from `metrics.py`
- [ ] SE-3: Extract compliance service from `compliance.py`
- [ ] SE-4: Extract detection_rules service from `detection_rules.py`
- [ ] SE-5: Extract threat_actors service from `threat_actors.py`
- [x] SE-1: Extract reports service `coverage_report_service.py`
- [x] SE-2: Extract metrics service `metrics_query_service.py`
- [x] SE-3: Extract compliance service `compliance_service.py`
- [x] SE-4: Extract detection_rules service `detection_rule_service.py`
- [x] SE-5: Extract threat_actors service `threat_actor_service.py`
## Tier 3 — Architectural Fixes
- [ ] AF-1: Persist scoring weights in DB (replace mutable `settings`)
- [ ] AF-2: Slim `tests.py` router (CRUD to repo/service)
- [ ] AF-3: Slim `evidence.py` router (permissions to domain)
- [ ] AF-4: Slim `campaigns.py` router (CRUD to service)
- [x] AF-1: Persist scoring weights in DB `scoring_config` table + `scoring_config_service.py`
- [x] AF-2: Slim `tests.py` router `test_crud_service.py`
- [x] AF-3: Slim `evidence.py` router `evidence_service.py`
- [x] AF-4: Slim `campaigns.py` router `campaign_crud_service.py`
## Tier 4 — Polish
- [ ] P-1: Structured JSON logging
- [ ] P-2: Create architecture skill file for future agents
- [x] P-1: Structured JSON logging`logging_config.py`
- [x] P-2: Create architecture skill file `~/.cursor/skills/aegis-architecture/SKILL.md`
## Completed (prior sessions)