diff --git a/backend/app/services/report_generation_service.py b/backend/app/services/report_generation_service.py index 447443b..1c6ec64 100644 --- a/backend/app/services/report_generation_service.py +++ b/backend/app/services/report_generation_service.py @@ -2,6 +2,7 @@ import logging from datetime import datetime, timedelta +from uuid import UUID from sqlalchemy.orm import Session @@ -22,14 +23,15 @@ def generate_purple_campaign_report( output_format: str = "pdf", ) -> str: """Generate the full Purple Team campaign report.""" - campaign = db.query(Campaign).filter(Campaign.id == campaign_id).first() + cid = campaign_id if isinstance(campaign_id, UUID) else UUID(str(campaign_id)) + campaign = db.query(Campaign).filter(Campaign.id == cid).first() if not campaign: raise EntityNotFoundError("Campaign", campaign_id) campaign_tests = ( db.query(Test) .join(CampaignTest, CampaignTest.test_id == Test.id) - .filter(CampaignTest.campaign_id == campaign_id) + .filter(CampaignTest.campaign_id == cid) .all() ) @@ -227,6 +229,123 @@ def generate_executive_summary( return _generate(output_format, "executive_summary", context) +def generate_quarterly_summary( + db: Session, + output_format: str = "pdf", +) -> str: + """Quarterly summary — reuses executive metrics plus snapshot trend rows.""" + from sqlalchemy import case as sql_case, func + + org_score = _safe_org_score(db) + quarter_ago = datetime.utcnow() - timedelta(days=90) + tests_this_quarter = ( + db.query(func.count(Test.id)).filter(Test.created_at >= quarter_ago).scalar() or 0 + ) + + techniques = db.query(Technique).all() + validated_count = sum( + 1 for t in techniques if t.status_global and t.status_global.value == "validated" + ) + detected_count = ( + db.query(func.count(Test.id)) + .filter(Test.state == "validated", Test.detection_result == "detected") + .scalar() or 0 + ) + detection_rate = ( + round((detected_count / validated_count) * 100, 1) if validated_count > 0 else 0 + ) + + tactic_rows = ( + db.query( + Technique.tactic, + func.count(Technique.id).label("total"), + func.sum(sql_case((Technique.status_global == "validated", 1), else_=0)).label( + "validated", + ), + ) + .group_by(Technique.tactic) + .all() + ) + top_gaps = sorted( + [ + { + "tactic": r[0] or "Unknown", + "coverage_pct": round((int(r[2]) / r[1]) * 100, 1) if r[1] > 0 else 0, + } + for r in tactic_rows + ], + key=lambda x: x["coverage_pct"], + )[:5] + + snapshots = ( + db.query(CoverageSnapshot) + .filter(CoverageSnapshot.created_at >= quarter_ago) + .order_by(CoverageSnapshot.created_at) + .all() + ) + trend_rows = [ + { + "date": s.created_at.strftime("%Y-%m-%d") if s.created_at else "", + "validated_count": s.validated_count, + "total_techniques": s.total_techniques, + "organization_score": round(s.organization_score, 1), + } + for s in snapshots + ] + + now = datetime.utcnow() + quarter_label = f"Q{((now.month - 1) // 3) + 1} {now.year}" + + context = { + "quarter_label": quarter_label, + "org_score": org_score, + "tests_this_quarter": tests_this_quarter, + "detection_rate": detection_rate, + "trend_rows": trend_rows, + "top_gaps": top_gaps, + } + return _generate(output_format, "quarterly_summary", context) + + +def generate_technique_detail_report( + db: Session, + technique_id: str, + output_format: str = "pdf", +) -> str: + """Detailed report for a single MITRE technique and its tests.""" + tid = technique_id if isinstance(technique_id, UUID) else UUID(str(technique_id)) + technique = db.query(Technique).filter(Technique.id == tid).first() + if not technique: + raise EntityNotFoundError("Technique", str(technique_id)) + + related_tests = ( + db.query(Test) + .filter(Test.technique_id == tid) + .order_by(Test.created_at.desc()) + .all() + ) + tests_data = [ + { + "name": t.name, + "state": t.state.value if t.state else "draft", + "detection_result": ( + t.detection_result.value if t.detection_result else "pending" + ), + "created_at": t.created_at.strftime("%Y-%m-%d") if t.created_at else "", + } + for t in related_tests + ] + + context = { + "technique": technique, + "technique_status": ( + technique.status_global.value if technique.status_global else "not_evaluated" + ), + "tests": tests_data, + } + return _generate(output_format, "technique_detail", context) + + # ── Helpers ────────────────────────────────────────────────────────── diff --git a/backend/tests/test_report_generation_service.py b/backend/tests/test_report_generation_service.py new file mode 100644 index 0000000..2edf430 --- /dev/null +++ b/backend/tests/test_report_generation_service.py @@ -0,0 +1,58 @@ +"""Report generation service tests (FASE-2.3).""" + +import uuid +from unittest.mock import patch + +import pytest + +from app.domain.exceptions import EntityNotFoundError +from app.models.campaign import Campaign, CampaignTest +from app.models.enums import TestState +from app.models.technique import Technique +from app.models.test import Test + + +@patch("app.services.report_generation_service.report_engine.generate_pdf") +def test_generate_purple_campaign_report_pdf(mock_pdf, db, admin_user): + mock_pdf.return_value = "/tmp/fake.pdf" + + technique = Technique( + mitre_id="T1059.001", + name="PowerShell", + tactic="execution", + ) + db.add(technique) + db.flush() + + campaign = Campaign(name="Q1 Purple", description="Scope", status="active") + db.add(campaign) + db.flush() + + test = Test( + name="PS test", + technique_id=technique.id, + state=TestState.validated, + created_by=admin_user.id, + ) + db.add(test) + db.flush() + db.add(CampaignTest(campaign_id=campaign.id, test_id=test.id)) + db.commit() + + path = __import__( + "app.services.report_generation_service", + fromlist=["generate_purple_campaign_report"], + ).generate_purple_campaign_report(db, str(campaign.id), "pdf") + + assert path == "/tmp/fake.pdf" + mock_pdf.assert_called_once() + context = mock_pdf.call_args[0][1] + assert context["tests_validated"] == 1 + assert len(context["tests"]) == 1 + + +def test_generate_technique_detail_not_found(db): + from app.services.report_generation_service import generate_technique_detail_report + + with pytest.raises(EntityNotFoundError): + generate_technique_detail_report(db, str(uuid.uuid4()), "html")