feat(phase-19): add remediation fields and reports system (T-130, T-131)
This commit is contained in:
44
backend/alembic/versions/b007_add_remediation_fields.py
Normal file
44
backend/alembic/versions/b007_add_remediation_fields.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""add_remediation_fields
|
||||
|
||||
Revision ID: b007remediation
|
||||
Revises: b006notifications
|
||||
Create Date: 2026-02-09 11:30:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b007remediation'
|
||||
down_revision: Union[str, Sequence[str], None] = 'b006notifications'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add remediation fields to tests and test_templates."""
|
||||
# Tests — remediation fields
|
||||
op.add_column('tests', sa.Column('remediation_steps', sa.Text(), nullable=True))
|
||||
op.add_column('tests', sa.Column('remediation_status', sa.String(), nullable=True))
|
||||
op.add_column('tests', sa.Column('remediation_assignee', UUID(as_uuid=True), nullable=True))
|
||||
op.create_foreign_key(
|
||||
'fk_tests_remediation_assignee',
|
||||
'tests', 'users',
|
||||
['remediation_assignee'], ['id'],
|
||||
)
|
||||
|
||||
# TestTemplates — suggested_remediation
|
||||
op.add_column('test_templates', sa.Column('suggested_remediation', sa.Text(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove remediation fields."""
|
||||
op.drop_column('test_templates', 'suggested_remediation')
|
||||
op.drop_constraint('fk_tests_remediation_assignee', 'tests', type_='foreignkey')
|
||||
op.drop_column('tests', 'remediation_assignee')
|
||||
op.drop_column('tests', 'remediation_status')
|
||||
op.drop_column('tests', 'remediation_steps')
|
||||
@@ -17,6 +17,7 @@ from app.routers import metrics as metrics_router
|
||||
from app.routers import users as users_router
|
||||
from app.routers import audit as audit_router
|
||||
from app.routers import notifications as notifications_router
|
||||
from app.routers import reports as reports_router
|
||||
from app.storage import ensure_bucket_exists
|
||||
from app.jobs.mitre_sync_job import start_scheduler, scheduler
|
||||
|
||||
@@ -58,6 +59,7 @@ app.include_router(metrics_router.router, prefix="/api/v1")
|
||||
app.include_router(users_router.router, prefix="/api/v1")
|
||||
app.include_router(audit_router.router, prefix="/api/v1")
|
||||
app.include_router(notifications_router.router, prefix="/api/v1")
|
||||
app.include_router(reports_router.router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
||||
@@ -49,9 +49,15 @@ class Test(Base):
|
||||
blue_validation_status = Column(String, nullable=True) # pending / approved / rejected
|
||||
blue_validation_notes = Column(Text, nullable=True)
|
||||
|
||||
# ── Remediation fields ───────────────────────────────────────────
|
||||
remediation_steps = Column(Text, nullable=True)
|
||||
remediation_status = Column(String, nullable=True) # pending / in_progress / completed / not_applicable
|
||||
remediation_assignee = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
|
||||
|
||||
# ── Relationships ───────────────────────────────────────────────
|
||||
technique = relationship("Technique", back_populates="tests")
|
||||
evidences = relationship("Evidence", back_populates="test")
|
||||
creator = relationship("User", foreign_keys=[created_by])
|
||||
red_validator = relationship("User", foreign_keys=[red_validated_by])
|
||||
blue_validator = relationship("User", foreign_keys=[blue_validated_by])
|
||||
remediation_user = relationship("User", foreign_keys=[remediation_assignee])
|
||||
|
||||
@@ -34,6 +34,7 @@ class TestTemplate(Base):
|
||||
tool_suggested = Column(String, nullable=True)
|
||||
severity = Column(String, nullable=True) # low / medium / high / critical
|
||||
atomic_test_id = Column(String, nullable=True) # ID in Atomic Red Team repo
|
||||
suggested_remediation = Column(Text, nullable=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
|
||||
270
backend/app/routers/reports.py
Normal file
270
backend/app/routers/reports.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""Reports endpoints — export coverage summaries and test results.
|
||||
|
||||
Endpoints
|
||||
---------
|
||||
GET /reports/coverage-summary — full coverage JSON report
|
||||
GET /reports/coverage-csv — CSV export of coverage
|
||||
GET /reports/test-results — test results report (JSON)
|
||||
GET /reports/remediation-status — remediation status report (JSON)
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.dependencies.auth import get_current_user
|
||||
from app.models.enums import TestState
|
||||
from app.models.technique import Technique
|
||||
from app.models.test import Test
|
||||
from app.models.user import User
|
||||
|
||||
router = APIRouter(prefix="/reports", tags=["reports"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /reports/coverage-summary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/coverage-summary")
|
||||
def coverage_summary(
|
||||
tactic: Optional[str] = Query(None, description="Filter by tactic"),
|
||||
platform: Optional[str] = Query(None, description="Filter by platform (in techniques)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Full coverage report as JSON — technique-by-technique with test counts."""
|
||||
query = db.query(Technique)
|
||||
if tactic:
|
||||
query = query.filter(Technique.tactic.ilike(f"%{tactic}%"))
|
||||
|
||||
techniques = query.order_by(Technique.mitre_id).all()
|
||||
|
||||
rows = []
|
||||
for t in techniques:
|
||||
# Count tests per state for this technique
|
||||
test_counts = (
|
||||
db.query(Test.state, func.count(Test.id))
|
||||
.filter(Test.technique_id == t.id)
|
||||
.group_by(Test.state)
|
||||
.all()
|
||||
)
|
||||
counts = {str(state): count for state, count in test_counts}
|
||||
|
||||
# Filter by platform if requested (check if technique platforms contain it)
|
||||
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
|
||||
continue
|
||||
|
||||
rows.append({
|
||||
"mitre_id": t.mitre_id,
|
||||
"name": t.name,
|
||||
"tactic": t.tactic,
|
||||
"platforms": t.platforms,
|
||||
"status_global": t.status_global,
|
||||
"total_tests": sum(counts.values()),
|
||||
"tests_by_state": counts,
|
||||
})
|
||||
|
||||
total = len(rows)
|
||||
validated = sum(1 for r in rows if r["status_global"] == "validated")
|
||||
partial = sum(1 for r in rows if r["status_global"] == "partial")
|
||||
not_covered = sum(1 for r in rows if r["status_global"] == "not_covered")
|
||||
in_progress = sum(1 for r in rows if r["status_global"] == "in_progress")
|
||||
not_evaluated = sum(1 for r in rows if r["status_global"] == "not_evaluated")
|
||||
|
||||
return {
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"summary": {
|
||||
"total_techniques": total,
|
||||
"validated": validated,
|
||||
"partial": partial,
|
||||
"not_covered": not_covered,
|
||||
"in_progress": in_progress,
|
||||
"not_evaluated": not_evaluated,
|
||||
"coverage_percentage": round((validated / total * 100) if total > 0 else 0, 1),
|
||||
},
|
||||
"techniques": rows,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /reports/coverage-csv
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/coverage-csv")
|
||||
def coverage_csv(
|
||||
tactic: Optional[str] = Query(None),
|
||||
platform: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Export coverage as a downloadable CSV."""
|
||||
query = db.query(Technique)
|
||||
if tactic:
|
||||
query = query.filter(Technique.tactic.ilike(f"%{tactic}%"))
|
||||
|
||||
techniques = query.order_by(Technique.mitre_id).all()
|
||||
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerow([
|
||||
"MITRE ID", "Name", "Tactic", "Platforms", "Status",
|
||||
"Total Tests", "Validated", "In Progress", "Not Covered",
|
||||
])
|
||||
|
||||
for t in techniques:
|
||||
if platform and platform.lower() not in [p.lower() for p in (t.platforms or [])]:
|
||||
continue
|
||||
|
||||
test_counts = (
|
||||
db.query(Test.state, func.count(Test.id))
|
||||
.filter(Test.technique_id == t.id)
|
||||
.group_by(Test.state)
|
||||
.all()
|
||||
)
|
||||
counts = {str(state): count for state, count in test_counts}
|
||||
|
||||
writer.writerow([
|
||||
t.mitre_id,
|
||||
t.name,
|
||||
t.tactic,
|
||||
", ".join(t.platforms or []),
|
||||
t.status_global,
|
||||
sum(counts.values()),
|
||||
counts.get("validated", 0),
|
||||
sum(counts.get(s, 0) for s in ["draft", "red_executing", "blue_evaluating", "in_review"]),
|
||||
counts.get("rejected", 0),
|
||||
])
|
||||
|
||||
output.seek(0)
|
||||
return StreamingResponse(
|
||||
iter([output.getvalue()]),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f"attachment; filename=aegis_coverage_{datetime.utcnow().strftime('%Y%m%d')}.csv"},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /reports/test-results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/test-results")
|
||||
def test_results(
|
||||
state: Optional[str] = Query(None),
|
||||
date_from: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
|
||||
date_to: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Report of test results with optional filters."""
|
||||
query = db.query(Test)
|
||||
|
||||
if state:
|
||||
query = query.filter(Test.state == state)
|
||||
if date_from:
|
||||
try:
|
||||
dt = datetime.fromisoformat(date_from)
|
||||
query = query.filter(Test.created_at >= dt)
|
||||
except ValueError:
|
||||
pass
|
||||
if date_to:
|
||||
try:
|
||||
dt = datetime.fromisoformat(date_to)
|
||||
query = query.filter(Test.created_at <= dt)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
tests = query.order_by(Test.created_at.desc()).all()
|
||||
|
||||
# Summary
|
||||
total = len(tests)
|
||||
by_state = {}
|
||||
by_result = {}
|
||||
for t in tests:
|
||||
s = t.state.value if hasattr(t.state, "value") else str(t.state)
|
||||
by_state[s] = by_state.get(s, 0) + 1
|
||||
if t.detection_result:
|
||||
r = t.detection_result.value if hasattr(t.detection_result, "value") else str(t.detection_result)
|
||||
by_result[r] = by_result.get(r, 0) + 1
|
||||
|
||||
return {
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"filters": {"state": state, "date_from": date_from, "date_to": date_to},
|
||||
"summary": {
|
||||
"total_tests": total,
|
||||
"by_state": by_state,
|
||||
"by_detection_result": by_result,
|
||||
},
|
||||
"tests": [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"name": t.name,
|
||||
"technique_id": str(t.technique_id),
|
||||
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
|
||||
"platform": t.platform,
|
||||
"attack_success": t.attack_success,
|
||||
"detection_result": (
|
||||
t.detection_result.value if t.detection_result and hasattr(t.detection_result, "value")
|
||||
else str(t.detection_result) if t.detection_result else None
|
||||
),
|
||||
"red_validation_status": t.red_validation_status,
|
||||
"blue_validation_status": t.blue_validation_status,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in tests
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /reports/remediation-status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/remediation-status")
|
||||
def remediation_status(
|
||||
status: Optional[str] = Query(None, description="Filter by remediation status"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Report of remediation status across all tests."""
|
||||
query = db.query(Test).filter(Test.remediation_steps.isnot(None))
|
||||
|
||||
if status:
|
||||
query = query.filter(Test.remediation_status == status)
|
||||
|
||||
tests = query.order_by(Test.created_at.desc()).all()
|
||||
|
||||
by_status = {}
|
||||
for t in tests:
|
||||
s = t.remediation_status or "unset"
|
||||
by_status[s] = by_status.get(s, 0) + 1
|
||||
|
||||
return {
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"summary": {
|
||||
"total_with_remediation": len(tests),
|
||||
"by_status": by_status,
|
||||
},
|
||||
"tests": [
|
||||
{
|
||||
"id": str(t.id),
|
||||
"name": t.name,
|
||||
"technique_id": str(t.technique_id),
|
||||
"state": t.state.value if hasattr(t.state, "value") else str(t.state),
|
||||
"remediation_status": t.remediation_status,
|
||||
"remediation_steps": t.remediation_steps,
|
||||
"remediation_assignee": str(t.remediation_assignee) if t.remediation_assignee else None,
|
||||
}
|
||||
for t in tests
|
||||
],
|
||||
}
|
||||
@@ -40,6 +40,7 @@ from app.schemas.test import (
|
||||
TestBlueUpdate,
|
||||
TestRedValidate,
|
||||
TestBlueValidate,
|
||||
TestRemediationUpdate,
|
||||
)
|
||||
from app.schemas.test_template import TestTemplateInstantiate
|
||||
from app.services.audit_service import log_action
|
||||
@@ -211,6 +212,7 @@ def create_test_from_template(
|
||||
platform=template.platform,
|
||||
procedure_text=template.attack_procedure,
|
||||
tool_used=template.tool_suggested,
|
||||
remediation_steps=template.suggested_remediation,
|
||||
created_by=current_user.id,
|
||||
state=TestState.draft,
|
||||
)
|
||||
@@ -520,6 +522,40 @@ def reopen(
|
||||
return test
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PATCH /tests/{id}/remediation — update remediation fields
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.patch("/{test_id}/remediation", response_model=TestOut)
|
||||
def update_remediation(
|
||||
test_id: uuid.UUID,
|
||||
payload: TestRemediationUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Update remediation fields on a test (any authenticated user)."""
|
||||
test = _get_test_or_404(db, test_id)
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(test, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(test)
|
||||
|
||||
log_action(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
action="update_remediation",
|
||||
entity_type="test",
|
||||
entity_id=test.id,
|
||||
details={"updated_fields": list(update_data.keys())},
|
||||
)
|
||||
|
||||
return test
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /tests/{id}/timeline — audit history for this test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -81,6 +81,17 @@ class TestBlueValidate(BaseModel):
|
||||
blue_validation_notes: str | None = None
|
||||
|
||||
|
||||
# ── Remediation update ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRemediationUpdate(BaseModel):
|
||||
"""Payload for updating remediation fields."""
|
||||
|
||||
remediation_steps: str | None = None
|
||||
remediation_status: str | None = None # pending / in_progress / completed / not_applicable
|
||||
remediation_assignee: uuid.UUID | None = None
|
||||
|
||||
|
||||
# ── Legacy validate (kept for backwards compat) ────────────────────
|
||||
|
||||
|
||||
@@ -126,6 +137,11 @@ class TestOut(BaseModel):
|
||||
blue_validation_status: str | None = None
|
||||
blue_validation_notes: str | None = None
|
||||
|
||||
# Remediation fields
|
||||
remediation_steps: str | None = None
|
||||
remediation_status: str | None = None
|
||||
remediation_assignee: uuid.UUID | None = None
|
||||
|
||||
# Technique info (populated when joined)
|
||||
technique_mitre_id: str | None = None
|
||||
technique_name: str | None = None
|
||||
|
||||
@@ -24,6 +24,7 @@ class TestTemplateOut(BaseModel):
|
||||
tool_suggested: str | None = None
|
||||
severity: str | None = None
|
||||
atomic_test_id: str | None = None
|
||||
suggested_remediation: str | None = None
|
||||
is_active: bool = True
|
||||
created_at: datetime | None = None
|
||||
|
||||
@@ -47,6 +48,7 @@ class TestTemplateCreate(BaseModel):
|
||||
tool_suggested: str | None = None
|
||||
severity: str | None = None
|
||||
atomic_test_id: str | None = None
|
||||
suggested_remediation: str | None = None
|
||||
|
||||
|
||||
# ── Summary (for listings) ─────────────────────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user