refactor(types): add comprehensive type annotations across backend Python codebase

Enable ANN rules in ruff.toml (flake8-annotations) and resolve all 221 violations:

ANN201/ANN202 — return types on 168 public/private functions:
- All 28 FastAPI routers: endpoints annotated with dict/list/specific schema/
  StreamingResponse/FileResponse/JSONResponse as appropriate
- main.py: lifespan→AsyncGenerator[None,None], exception handlers→JSONResponse
- database.py: get_db→Generator[Session,None,None], proxy methods→correct types
- middleware/request_context.py: dispatch→Response with Callable call_next type

ANN001/ANN002/ANN003 — 32 missing argument types:
- seed_demo.py: all db parameters typed as Session
- domain/unit_of_work.py: __aexit__ exc_type/exc_val/exc_tb typed with TracebackType
- services: audit_service user_id→UUID|None, heatmap_service query/model/builder,
  notification_service test→Test, tempo_service test→Test/user→User,
  test_workflow_service test_id→UUID, campaign_crud **fields→object,
  test_crud **fields→object (4 sites)

ANN401 — 16 Any usages resolved:
- Domain entities (campaign/technique/threat_actor/test_entity): replaced Any with
  actual ORM types via TYPE_CHECKING guards to avoid circular imports
- detection_rule_service: test_id/detection_rule_id/evaluator_id→UUID
- score_cache: kept Any with # noqa: ANN401 (genuinely generic cache)
- jira_service/tempo_service: kept Any with # noqa: ANN401 (lazy optional deps)
- d3fend_import_service: _to_str(v: Any) kept with # noqa: ANN401

ANN204/ANN205/ANN206 — special/static/class methods:
- database.py proxy __call__/__getattr__: *args: object/**kwargs: object
- schemas/test.py model_validate: obj→object, **kwargs→object
- sa_technique_repository._int_type→type

All 439 unit tests pass. ruff check app/ → All checks passed!
This commit is contained in:
kitos
2026-06-09 17:04:51 +02:00
parent 8f98bdd273
commit 9ff0f04ba3
51 changed files with 267 additions and 223 deletions
+4 -4
View File
@@ -15,7 +15,7 @@ router = APIRouter(prefix="/metrics/advanced", tags=["advanced-metrics"])
def coverage_by_tactic(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""Coverage percentage broken down by MITRE ATT&CK tactic."""
return advanced_metrics_service.get_coverage_by_tactic(db)
@@ -24,7 +24,7 @@ def coverage_by_tactic(
def never_tested_techniques(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""Techniques that have never had a test created."""
return advanced_metrics_service.get_never_tested_techniques(db)
@@ -33,7 +33,7 @@ def never_tested_techniques(
def avg_validation_time(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> dict:
"""Average time from test creation to validation, computed from audit logs.
Returns overall average and per-phase averages where data is available.
@@ -45,6 +45,6 @@ def avg_validation_time(
def detection_rate_trend(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""Monthly detection rate trend for the last 12 months."""
return advanced_metrics_service.get_detection_rate_trend(db)
+4 -4
View File
@@ -19,7 +19,7 @@ router = APIRouter(prefix="/analytics", tags=["analytics"])
def analytics_coverage(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""Coverage per technique — flat format for BI dashboards."""
return analytics_service.get_coverage_analytics(db)
@@ -30,7 +30,7 @@ def analytics_tests(
date_to: str = Query(None, description="ISO date filter (<=)"),
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""All tests with timestamps — flat format for BI dashboards."""
return analytics_service.get_tests_analytics(
db, date_from=date_from, date_to=date_to
@@ -41,7 +41,7 @@ def analytics_tests(
def analytics_trends(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""Historical coverage snapshots for trend visualization."""
return analytics_service.get_trends_analytics(db)
@@ -50,6 +50,6 @@ def analytics_trends(
def analytics_operators(
db: Session = Depends(get_db),
user: User = Depends(require_role("admin")),
):
) -> list:
"""Per-operator metrics — for workload management dashboards."""
return analytics_service.get_operators_analytics(db)
+3 -3
View File
@@ -30,7 +30,7 @@ def list_audit_logs(
limit: int = Query(50, ge=1, le=100, description="Max records to return"),
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> AuditLogPage:
"""Return paginated audit logs with optional filters.
**Requires admin role.**
@@ -57,7 +57,7 @@ def list_audit_logs(
def list_actions(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> list[str]:
"""Return a list of distinct action types in the audit log.
**Requires admin role.**
@@ -69,7 +69,7 @@ def list_actions(
def list_entity_types(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> list[str]:
"""Return a list of distinct entity types in the audit log.
**Requires admin role.**
+4 -4
View File
@@ -46,7 +46,7 @@ def login(
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db),
):
) -> TokenResponse:
"""Authenticate a user and return a JWT access token.
Rate-limited to **5 attempts per minute per IP**. Failed and successful
@@ -110,7 +110,7 @@ def logout(
request: Request,
response: Response,
aegis_token: str | None = Cookie(None),
):
) -> dict:
"""Clear the authentication cookie and revoke the current token."""
bearer = (
request.headers.get("Authorization")
@@ -148,7 +148,7 @@ def logout(
@router.get("/me", response_model=UserOut)
def read_current_user(current_user: User = Depends(get_current_user)):
def read_current_user(current_user: User = Depends(get_current_user)) -> UserOut:
"""Return the profile of the currently authenticated user."""
return current_user
@@ -158,7 +158,7 @@ def change_password(
body: PasswordChange,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Change the current user's password."""
auth_change_password(
db,
+12 -12
View File
@@ -107,7 +107,7 @@ def list_campaigns(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List campaigns with optional filters and pagination."""
return crud_list(
db,
@@ -129,7 +129,7 @@ def create_campaign(
payload: CampaignCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Create a new campaign."""
with UnitOfWork(db) as uow:
result = crud_create(
@@ -165,7 +165,7 @@ def get_campaign(
campaign_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get detailed campaign info including tests and progress."""
return crud_get_detail(db, campaign_id)
@@ -180,7 +180,7 @@ def update_campaign(
payload: CampaignUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Update a campaign. Only allowed in draft or active state."""
update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow:
@@ -214,7 +214,7 @@ def add_test_to_campaign(
payload: AddTestPayload,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Add a test to a campaign with optional ordering and dependency."""
with UnitOfWork(db) as uow:
result = crud_add_test(
@@ -239,7 +239,7 @@ def remove_test_from_campaign(
campaign_test_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Remove a test from a campaign."""
with UnitOfWork(db) as uow:
crud_remove_test(db, campaign_id, campaign_test_id)
@@ -256,7 +256,7 @@ def activate_campaign(
campaign_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Activate a campaign, moving it from draft to active."""
with UnitOfWork(db) as uow:
campaign = crud_activate(db, campaign_id)
@@ -292,7 +292,7 @@ def complete_campaign(
campaign_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "admin")),
):
) -> dict:
"""Mark a campaign as completed."""
with UnitOfWork(db) as uow:
campaign = crud_complete(db, campaign_id)
@@ -319,7 +319,7 @@ def get_campaign_progress_endpoint(
campaign_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get progress statistics for a campaign."""
return crud_get_progress(db, campaign_id)
@@ -333,7 +333,7 @@ def generate_campaign_from_actor(
actor_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Auto-generate a campaign from a threat actor's uncovered techniques.
Creates tests from the best available templates and orders them
@@ -369,7 +369,7 @@ def schedule_campaign(
payload: SchedulePayload,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Configure or update the recurrence schedule for a campaign.
Only the campaign creator or admin can change scheduling.
@@ -411,6 +411,6 @@ def get_campaign_history(
campaign_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List all child campaigns (execution history) of a recurring campaign."""
return crud_get_history(db, campaign_id)
+7 -7
View File
@@ -34,7 +34,7 @@ router = APIRouter(prefix="/compliance", tags=["compliance"])
def list_frameworks_endpoint(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List all available compliance frameworks."""
return list_frameworks(db)
@@ -47,7 +47,7 @@ def framework_status(
framework_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get compliance status for each control in a framework."""
return get_framework_status(db, framework_id)
@@ -60,7 +60,7 @@ def framework_report(
framework_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get the full compliance report (same as status but marked as report)."""
return get_framework_status(db, framework_id)
@@ -73,7 +73,7 @@ def framework_report_csv(
framework_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> StreamingResponse:
"""Export compliance report as CSV."""
csv_bytes, filename = build_framework_report_csv(db, framework_id)
return StreamingResponse(
@@ -93,7 +93,7 @@ def framework_gaps(
framework_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get controls with techniques that are not adequately covered."""
return get_framework_gaps(db, framework_id)
@@ -105,7 +105,7 @@ def framework_gaps(
def import_nist(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Import NIST 800-53 Rev 5 mappings (admin only)."""
result = import_nist_800_53_mappings(db)
return result
@@ -115,7 +115,7 @@ def import_nist(
def import_cis(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Import CIS Controls v8 mappings (admin only)."""
result = import_cis_controls_v8_mappings(db)
return result
+4 -4
View File
@@ -38,7 +38,7 @@ def list_defensive_techniques(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List all D3FEND defensive techniques with optional filters."""
return list_defensive_techniques_svc(
db, tactic=tactic, search=search, offset=offset, limit=limit
@@ -53,7 +53,7 @@ def list_defensive_techniques(
def list_d3fend_tactics_endpoint(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return a list of all D3FEND tactics with counts."""
return list_d3fend_tactics(db)
@@ -67,7 +67,7 @@ def get_defenses_for_attack_technique_endpoint(
mitre_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Get all D3FEND defensive techniques mapped to a given ATT&CK technique."""
return get_defenses_for_attack_technique(db, mitre_id)
@@ -80,7 +80,7 @@ def get_defenses_for_attack_technique_endpoint(
def trigger_d3fend_import(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Import D3FEND techniques and ATT&CK mappings. Admin only."""
tech_result = import_d3fend_techniques(db)
mapping_result = import_d3fend_mappings(db)
+5 -5
View File
@@ -47,7 +47,7 @@ router = APIRouter(prefix="/data-sources", tags=["data-sources"])
def list_data_sources(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> list:
"""List all registered data sources.
**Requires** the ``admin`` role.
@@ -61,7 +61,7 @@ def update_data_source(
body: DataSourceUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Update a data source (enable/disable, change config).
**Requires** the ``admin`` role.
@@ -87,7 +87,7 @@ def sync_data_source(
source_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Trigger sync/import for a specific data source.
**Requires** the ``admin`` role.
@@ -99,7 +99,7 @@ def sync_data_source(
def sync_all_data_sources(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Trigger sync for all enabled data sources (sequentially).
**Requires** the ``admin`` role.
@@ -125,7 +125,7 @@ def get_data_source_stats(
source_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Get detailed statistics for a specific data source.
**Requires** the ``admin`` role.
+5 -5
View File
@@ -51,7 +51,7 @@ def list_detection_rules(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List detection rules with optional filters and pagination."""
return list_rules(
db,
@@ -72,7 +72,7 @@ def get_detection_rules_for_template(
template_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Get detection rules associated with a test template."""
return get_rules_for_template(db, template_id)
@@ -84,7 +84,7 @@ def get_detection_rules_for_template(
def auto_associate_detection_rules(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Auto-associate test templates with detection rules by MITRE technique ID.
For each active template, find all active detection rules for the same
@@ -102,7 +102,7 @@ def get_detection_rules_for_test(
test_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Get detection rules relevant to a test, along with their evaluation results.
Finds rules by matching the test's technique_id to detection rules,
@@ -119,7 +119,7 @@ def evaluate_detection_rule(
payload: DetectionRuleEvaluate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
) -> dict:
"""Save or update the evaluation result for a detection rule on a test."""
return evaluate_rule(
db,
+4 -4
View File
@@ -88,7 +88,7 @@ async def upload_evidence(
notes: Optional[str] = Form(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> EvidenceOut:
"""Upload a file as evidence for the given test.
The ``team`` field (sent as form data) determines whether this is
@@ -154,7 +154,7 @@ def list_evidence(
team: Optional[str] = Query(None, description="Filter by team: red or blue"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list[EvidenceOut]:
"""List all evidences for a test, optionally filtered by team."""
get_test_or_raise(db, test_id)
evidences = list_evidence_for_test(db, test_id, team=team)
@@ -171,7 +171,7 @@ def get_evidence(
evidence_id: _uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> EvidenceOut:
"""Return evidence metadata together with a presigned download URL."""
evidence = get_evidence_or_raise(db, evidence_id)
return _evidence_to_out(evidence)
@@ -187,7 +187,7 @@ def delete_evidence(
evidence_id: _uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Delete an evidence record.
Only allowed in editable states:
+5 -5
View File
@@ -28,7 +28,7 @@ def heatmap_coverage(
min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Coverage layer — score based on status_global of each technique."""
return heatmap_service.build_coverage_layer(
db, platforms=platforms, tactics=tactics, min_score=min_score,
@@ -43,7 +43,7 @@ def heatmap_threat_actor(
min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Threat actor layer — techniques used by an actor with coverage color."""
return heatmap_service.build_threat_actor_layer(
db, actor_id, platforms=platforms, tactics=tactics, min_score=min_score,
@@ -57,7 +57,7 @@ def heatmap_detection_rules(
min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Detection rules layer — score based on ratio of rules available vs total."""
return heatmap_service.build_detection_rules_layer(
db, platforms=platforms, tactics=tactics, min_score=min_score,
@@ -72,7 +72,7 @@ def heatmap_campaign(
min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Campaign layer — only techniques in the campaign, colored by test state."""
return heatmap_service.build_campaign_layer(
db, campaign_id, platforms=platforms, tactics=tactics, min_score=min_score,
@@ -88,7 +88,7 @@ def export_navigator(
min_score: int = Query(0, ge=0, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> StreamingResponse:
"""Export a heatmap layer as a downloadable JSON file for ATT&CK Navigator."""
data = heatmap_service.build_navigator_export(
db, layer, layer_id=layer_id,
+6 -6
View File
@@ -29,7 +29,7 @@ def search_issues(
q: str = Query(..., min_length=2),
max_results: int = Query(10, le=50),
user: User = Depends(get_current_user),
):
) -> list[JiraIssueResult]:
"""Search Jira issues by JQL or free text."""
return jira_service.search_jira_issues(q, max_results)
@@ -39,7 +39,7 @@ def create_link(
body: JiraLinkCreate,
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> JiraLinkOut:
"""Associate an Aegis entity with a Jira issue."""
with UnitOfWork(db) as uow:
link = jira_service.create_link(
@@ -74,7 +74,7 @@ def list_links(
entity_id: Optional[UUID] = None,
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list[JiraLinkOut]:
"""List Jira links, optionally filtered by entity."""
return jira_service.list_links(
db,
@@ -88,7 +88,7 @@ def sync_link(
link_id: UUID,
db: Session = Depends(get_db),
user: User = Depends(require_role("admin")),
):
) -> dict:
"""Force bidirectional sync for a specific Jira link."""
with UnitOfWork(db) as uow:
link = jira_service.get_link_or_raise(db, link_id)
@@ -102,7 +102,7 @@ def delete_link(
link_id: UUID,
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> None:
"""Remove a Jira link."""
with UnitOfWork(db) as uow:
link = jira_service.delete_link(db, link_id)
@@ -123,7 +123,7 @@ def create_issue_from_entity(
entity_id: UUID,
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> dict:
"""Auto-create a Jira issue from an Aegis entity and link them."""
with UnitOfWork(db) as uow:
result = jira_service.create_issue_and_link(
+6 -6
View File
@@ -42,7 +42,7 @@ router = APIRouter(prefix="/metrics", tags=["metrics"])
def coverage_summary(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> CoverageSummary:
"""Return a global coverage summary across all techniques."""
return get_coverage_summary(db)
@@ -56,7 +56,7 @@ def coverage_summary(
def coverage_by_tactic(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list[TacticCoverage]:
"""Return coverage breakdown grouped by tactic."""
return get_coverage_by_tactic(db)
@@ -70,7 +70,7 @@ def coverage_by_tactic(
def test_pipeline(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> TestPipelineCounts:
"""Return how many tests are in each pipeline state."""
return get_test_pipeline_counts(db)
@@ -84,7 +84,7 @@ def test_pipeline(
def team_activity(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list[TeamActivity]:
"""Return activity summary for Red and Blue teams."""
return get_team_activity(db)
@@ -98,7 +98,7 @@ def team_activity(
def validation_rate(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list[ValidationRate]:
"""Return approval and rejection rates for Red Lead and Blue Lead."""
return get_validation_rate(db)
@@ -112,6 +112,6 @@ def validation_rate(
def recent_tests(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list[RecentTestItem]:
"""Return the 10 most recently created tests."""
return get_recent_tests(db, limit=10)
+4 -4
View File
@@ -39,7 +39,7 @@ def list_notifications_endpoint(
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list[NotificationOut]:
"""Return paginated notifications for the current user, newest first."""
return list_notifications(db, current_user.id, offset=offset, limit=limit)
@@ -53,7 +53,7 @@ def list_notifications_endpoint(
def unread_count(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> UnreadCountOut:
"""Return the number of unread notifications for the current user."""
count = get_unread_count(db, current_user.id)
return UnreadCountOut(unread_count=count)
@@ -69,7 +69,7 @@ def read_notification(
notification_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> NotificationOut:
"""Mark a single notification as read."""
with UnitOfWork(db) as uow:
notif = mark_as_read(db, notification_id, current_user.id)
@@ -86,7 +86,7 @@ def read_notification(
def read_all_notifications(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Mark all notifications for the current user as read."""
with UnitOfWork(db) as uow:
count = mark_all_as_read(db, current_user.id)
+3 -3
View File
@@ -25,7 +25,7 @@ router = APIRouter(prefix="/metrics/operational", tags=["operational-metrics"])
def operational_metrics(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get all operational metrics (MTTD, MTTR, etc.) — cached for 5 min."""
from app.services.score_cache import get_operational_metrics_cached
@@ -40,7 +40,7 @@ def operational_trend(
period: str = Query("90d", pattern="^(30d|90d|1y)$"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get weekly trend data for operational metrics."""
return get_operational_trend(db, period)
@@ -52,6 +52,6 @@ def operational_trend(
def metrics_by_team(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get metrics broken down by Red Team vs Blue Team."""
return get_metrics_by_team(db)
+5 -5
View File
@@ -57,7 +57,7 @@ def list_osint_items(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""List OSINT items with optional filters."""
return service_list_osint_items(
db,
@@ -73,7 +73,7 @@ def list_osint_items(
def osint_summary(
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> dict:
"""Summary statistics for OSINT items."""
return get_osint_summary(db)
@@ -83,7 +83,7 @@ def review_osint_item(
item_id: UUID,
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> dict:
"""Mark an OSINT item as reviewed."""
with UnitOfWork(db) as uow:
item = mark_osint_reviewed(db, str(item_id))
@@ -101,7 +101,7 @@ def trigger_technique_enrichment(
technique_id: UUID,
db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Manually trigger OSINT enrichment for a single technique."""
technique = get_technique_or_raise(db, technique_id)
count = enrich_technique_with_cves(db, technique)
@@ -119,7 +119,7 @@ def get_technique_osint(
reviewed: bool | None = Query(None),
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> list:
"""Get all OSINT items for a specific technique."""
items = get_osint_items_for_technique(
db,
+5 -5
View File
@@ -29,7 +29,7 @@ def generate_purple_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
):
) -> FileResponse:
"""Generate a Purple Team campaign assessment report."""
filepath = report_generation_service.generate_purple_campaign_report(
db, str(campaign_id), output_format=format,
@@ -48,7 +48,7 @@ def generate_coverage_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
):
) -> FileResponse:
"""Generate an organization-wide MITRE ATT&CK coverage report."""
filepath = report_generation_service.generate_coverage_report(
db, output_format=format,
@@ -67,7 +67,7 @@ def generate_executive_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
):
) -> FileResponse:
"""Generate an executive security summary report."""
filepath = report_generation_service.generate_executive_summary(
db, output_format=format,
@@ -86,7 +86,7 @@ def generate_quarterly_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_lead", "blue_lead", "viewer")),
):
) -> FileResponse:
"""Generate a quarterly security summary report."""
filepath = report_generation_service.generate_quarterly_summary(
db, output_format=format,
@@ -106,7 +106,7 @@ def generate_technique_report(
format: str = Query("pdf", pattern="^(pdf|docx|html)$"),
db: Session = Depends(get_db),
user: User = Depends(get_current_user),
):
) -> FileResponse:
"""Generate a detailed report for one MITRE technique."""
filepath = report_generation_service.generate_technique_detail_report(
db, str(technique_id), output_format=format,
+4 -4
View File
@@ -38,7 +38,7 @@ def coverage_summary(
platform: Optional[str] = Query(None, description="Filter by platform (in techniques)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Full coverage report as JSON — technique-by-technique with test counts."""
return build_coverage_summary(db, tactic=tactic, platform=platform)
@@ -49,7 +49,7 @@ def coverage_csv(
platform: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> StreamingResponse:
"""Export coverage as a downloadable CSV."""
rows = build_coverage_csv_rows(db, tactic=tactic, platform=platform)
@@ -74,7 +74,7 @@ def test_results(
date_to: Optional[str] = Query(None, description="ISO date string YYYY-MM-DD"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Report of test results with optional filters."""
return build_test_results_report(db, state=state, date_from=date_from, date_to=date_to)
@@ -84,6 +84,6 @@ def remediation_status(
status: Optional[str] = Query(None, description="Filter by remediation status"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Report of remediation status across all tests."""
return build_remediation_status_report(db, status=status)
+7 -7
View File
@@ -35,7 +35,7 @@ def score_technique(
mitre_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get detailed score with breakdown for a specific technique."""
return score_technique_by_mitre_id(db, mitre_id)
@@ -48,7 +48,7 @@ def score_tactic(
tactic: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get average score for a tactic."""
return calculate_tactic_score(tactic, db)
@@ -61,7 +61,7 @@ def score_threat_actor(
actor_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get coverage score against a specific threat actor."""
return score_actor_by_id(db, actor_id)
@@ -73,7 +73,7 @@ def score_threat_actor(
def score_organization(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get the overall organization security score (cached for 5 min)."""
from app.services.score_cache import get_organization_score_cached
@@ -88,7 +88,7 @@ def score_history(
period: str = Query("90d", pattern="^(30d|90d|1y)$"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get historical score data points (weekly)."""
return get_score_history(db, period)
@@ -100,7 +100,7 @@ def score_history(
def get_scoring_config(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Get current scoring weights (admin only)."""
return get_weights_dict(db)
@@ -123,7 +123,7 @@ def update_scoring_config(
payload: ScoringConfigUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Update scoring weights (admin only).
Weights are persisted in the database and survive restarts.
+6 -6
View File
@@ -52,7 +52,7 @@ def list_snapshots(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List coverage snapshots ordered by creation date (newest first)."""
return list_snapshots_svc(db, offset=offset, limit=limit)
@@ -66,7 +66,7 @@ def create_snapshot_endpoint(
payload: SnapshotCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead", "admin")),
):
) -> dict:
"""Create a manual coverage snapshot with an optional name."""
snapshot = create_snapshot(db, name=payload.name, user_id=current_user.id)
@@ -94,7 +94,7 @@ def coverage_evolution(
months: int = Query(12, ge=1, le=36),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return coverage snapshots for trend charts (last *months* months)."""
return get_coverage_evolution(db, months=months)
@@ -109,7 +109,7 @@ def compare_snapshots_endpoint(
b: str = Query(..., description="Snapshot B ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Compare two snapshots showing improved, worsened, and unchanged techniques."""
try:
a_id = uuid.UUID(a)
@@ -129,7 +129,7 @@ def get_snapshot(
snapshot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get detailed snapshot information including per-technique states."""
return get_snapshot_detail(db, snapshot_id)
@@ -143,7 +143,7 @@ def delete_snapshot_endpoint(
snapshot_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Delete a snapshot (admin only)."""
snapshot = get_snapshot_or_raise(db, snapshot_id)
+4 -4
View File
@@ -30,7 +30,7 @@ def trigger_mitre_sync(
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Manually trigger a MITRE ATT&CK synchronisation.
**Requires** the ``admin`` role.
@@ -50,7 +50,7 @@ def trigger_mitre_sync(
def trigger_intel_scan(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Manually trigger a threat-intelligence scan.
**Requires** the ``admin`` role.
@@ -71,7 +71,7 @@ def trigger_atomic_import(
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Trigger an import of Atomic Red Team tests as TestTemplates.
**Requires** the ``admin`` role.
@@ -101,7 +101,7 @@ def trigger_atomic_import(
@router.get("/scheduler-status")
def scheduler_status(
current_user: User = Depends(require_role("admin")),
):
) -> dict:
"""Return the current state of the background scheduler.
**Requires** the ``admin`` role.
+5 -5
View File
@@ -45,7 +45,7 @@ def list_techniques(
review_required: bool | None = Query(None, description="Filter by review flag"),
repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return a lightweight list of techniques, optionally filtered."""
return repo.list_all(
tactic=tactic,
@@ -64,7 +64,7 @@ def get_technique(
mitre_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Return full details for a single technique, including its tests and D3FEND defenses."""
return get_technique_detail(db, mitre_id)
@@ -84,7 +84,7 @@ def create_technique(
db: Session = Depends(get_db),
repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(require_role("admin")),
):
) -> TechniqueOut:
"""Create a new technique manually."""
if repo.exists_by_mitre_id(payload.mitre_id):
raise DuplicateEntityError("Technique", "mitre_id", payload.mitre_id)
@@ -124,7 +124,7 @@ def update_technique(
db: Session = Depends(get_db),
repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(require_role("admin")),
):
) -> TechniqueOut:
"""Update one or more fields of an existing technique."""
entity = repo.find_by_mitre_id(mitre_id)
if entity is None:
@@ -160,7 +160,7 @@ def review_technique(
db: Session = Depends(get_db),
repo: SATechniqueRepository = Depends(get_technique_repository),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TechniqueOut:
"""Mark a technique as reviewed.
Sets ``review_required`` to *False* and records the current timestamp
+9 -9
View File
@@ -78,7 +78,7 @@ def _list_templates_handler(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return a paginated, filterable list of test templates."""
return list_templates(
db,
@@ -102,7 +102,7 @@ def _list_templates_handler(
def template_stats(
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Return catalog statistics: active, by_source, by_platform."""
return get_template_stats(db)
@@ -117,7 +117,7 @@ def bulk_activate_templates(
activate: bool = Query(True, description="True to activate all, False to deactivate all"),
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Set all templates to active or inactive."""
count = bulk_activate(db, activate=activate)
with UnitOfWork(db) as uow:
@@ -148,7 +148,7 @@ def _templates_by_technique_handler(
mitre_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return all active templates mapped to a specific MITRE technique."""
return templates_by_technique(db, mitre_id)
@@ -163,7 +163,7 @@ def get_template(
template_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> TestTemplateOut:
"""Return full details for a single test template."""
return get_template_or_raise(db, template_id)
@@ -182,7 +182,7 @@ def create_template(
payload: TestTemplateCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestTemplateOut:
"""Create a custom test template."""
template = create_template_svc(db, **payload.model_dump())
with UnitOfWork(db) as uow:
@@ -215,7 +215,7 @@ def update_template(
payload: TestTemplateCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestTemplateOut:
"""Update fields of an existing test template."""
template = update_template_svc(db, template_id, **payload.model_dump(exclude_unset=True))
with UnitOfWork(db) as uow:
@@ -243,7 +243,7 @@ def toggle_template_active(
template_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestTemplateOut:
"""Toggle a template between active and inactive (is_active = not is_active)."""
template = toggle_template_active_svc(db, template_id)
with UnitOfWork(db) as uow:
@@ -271,7 +271,7 @@ def delete_template(
template_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> dict:
"""Soft-delete a test template by setting ``is_active=False``."""
template = get_template_or_raise(db, template_id)
soft_delete_template(db, template_id)
+19 -19
View File
@@ -126,7 +126,7 @@ def list_tests(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return a paginated list of tests, optionally filtered by state, technique, platform or creator."""
return crud_list_tests(
db,
@@ -156,7 +156,7 @@ def create_test(
payload: TestCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestOut:
"""Create a new test linked to an existing technique.
``created_by`` is set automatically and ``state`` defaults to *draft*.
@@ -198,7 +198,7 @@ def create_test_from_template(
payload: TestTemplateInstantiate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestOut:
"""Instantiate a real Test from an existing TestTemplate.
The template's fields are copied into the new test as starting data.
@@ -238,7 +238,7 @@ def get_test(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> TestOut:
"""Return full details for a single test, including its evidences."""
return crud_get_test_detail(db, test_id)
@@ -254,7 +254,7 @@ def update_test(
payload: TestUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestOut:
"""Update one or more fields of an existing test.
Only leads or admins can update general test fields.
@@ -294,7 +294,7 @@ def update_test_classification(
payload: TestClassificationUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> TestOut:
"""Update the data classification label for a test (admin only)."""
with UnitOfWork(db) as uow:
test = crud_get_test_or_raise(db, test_id)
@@ -324,7 +324,7 @@ def update_test_red(
payload: TestRedUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
) -> TestOut:
"""Red Team updates their fields (allowed in ``draft`` and ``red_executing``)."""
update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow:
@@ -354,7 +354,7 @@ def update_test_blue(
payload: TestBlueUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
) -> TestOut:
"""Blue Team updates their fields (allowed only in ``blue_evaluating``)."""
update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow:
@@ -383,7 +383,7 @@ def start_execution(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
) -> TestOut:
"""Move a test from ``draft`` to ``red_executing``."""
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
@@ -403,7 +403,7 @@ def submit_red(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "red_lead")),
):
) -> TestOut:
"""Red Team finalises — move from ``red_executing`` to ``blue_evaluating``."""
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
@@ -423,7 +423,7 @@ def submit_blue(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_tech", "blue_lead")),
):
) -> TestOut:
"""Blue Team finalises — move from ``blue_evaluating`` to ``in_review``."""
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
@@ -443,7 +443,7 @@ def pause_timer(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
):
) -> TestOut:
"""Pause the running timer for the current phase (red_executing or blue_evaluating)."""
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
@@ -463,7 +463,7 @@ def resume_timer(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
):
) -> TestOut:
"""Resume the paused timer for the current phase."""
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
@@ -484,7 +484,7 @@ def validate_red(
payload: TestRedValidate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead")),
):
) -> TestOut:
"""Red Lead approves or rejects the red side of a test."""
test = crud_get_test_with_technique(db, test_id)
with UnitOfWork(db) as uow:
@@ -511,7 +511,7 @@ def validate_blue(
payload: TestBlueValidate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("blue_lead")),
):
) -> TestOut:
"""Blue Lead approves or rejects the blue side of a test."""
test = crud_get_test_with_technique(db, test_id)
with UnitOfWork(db) as uow:
@@ -537,7 +537,7 @@ def reopen(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestOut:
"""Reopen a rejected test, moving it back to ``draft``."""
test = crud_get_test_or_raise(db, test_id)
with UnitOfWork(db) as uow:
@@ -558,7 +558,7 @@ def update_remediation(
payload: TestRemediationUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_any_role("red_lead", "blue_lead")),
):
) -> TestOut:
"""Update remediation fields on a test.
When ``remediation_status`` transitions to ``'completed'``, an automatic
@@ -602,7 +602,7 @@ def get_test_timeline(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return the chronological audit-log history for a test."""
return crud_get_test_timeline(db, test_id)
@@ -617,7 +617,7 @@ def get_retest_chain(
test_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Return the full chain of retests (original + all retests) for a test."""
chain = wf_get_retest_chain(db, test_id)
if not chain:
+4 -4
View File
@@ -36,7 +36,7 @@ def list_threat_actors(
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""List threat actors with optional filters and pagination.
**Requires** authentication (any role).
@@ -58,7 +58,7 @@ def get_threat_actor(
actor_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Get detailed info about a threat actor including techniques.
**Requires** authentication (any role).
@@ -71,7 +71,7 @@ def get_threat_actor_coverage(
actor_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> dict:
"""Calculate coverage percentage against a specific threat actor.
**Requires** authentication (any role).
@@ -87,7 +87,7 @@ def get_threat_actor_gaps(
actor_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
) -> list:
"""Identify techniques of this actor that are NOT fully validated.
**Requires** authentication (any role).
+4 -4
View File
@@ -30,7 +30,7 @@ router = APIRouter(prefix="/users", tags=["users"])
def list_users_route(
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> list[UserOut]:
"""Return a list of all users. **Requires admin role.**"""
return list_users(db)
@@ -45,7 +45,7 @@ def create_user_route(
payload: UserCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> UserOut:
"""Create a new user. **Requires admin role.**"""
with UnitOfWork(db) as uow:
user = create_user(
@@ -79,7 +79,7 @@ def get_user(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> UserOut:
"""Return a single user by ID. **Requires admin role.**"""
return get_user_or_raise(db, user_id)
@@ -95,7 +95,7 @@ def update_user_route(
payload: UserUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_role("admin")),
):
) -> UserOut:
"""Update one or more fields of an existing user. **Requires admin role.**"""
update_data = payload.model_dump(exclude_unset=True)
with UnitOfWork(db) as uow:
+4 -4
View File
@@ -56,7 +56,7 @@ def create(
body: WorklogCreate,
db: Session = Depends(get_db),
user: User = Depends(require_any_role("red_tech", "blue_tech", "red_lead", "blue_lead")),
):
) -> WorklogOut:
"""Create a manually-logged worklog entry."""
with UnitOfWork(db) as uow:
wl = worklog_service.create_worklog(
@@ -82,7 +82,7 @@ def list_all(
user_id: Optional[UUID] = None,
db: Session = Depends(get_db),
_user: User = Depends(get_current_user),
):
) -> list[WorklogOut]:
"""List worklogs with optional filters."""
return worklog_service.list_worklogs(
db,
@@ -97,7 +97,7 @@ def get_one(
worklog_id: UUID,
db: Session = Depends(get_db),
_user: User = Depends(get_current_user),
):
) -> WorklogOut:
"""Get a single worklog by ID."""
return worklog_service.get_worklog_or_raise(db, worklog_id)
@@ -107,7 +107,7 @@ def verify_integrity(
worklog_id: UUID,
db: Session = Depends(get_db),
_user: User = Depends(get_current_user),
):
) -> dict:
"""Check whether a worklog's integrity hash is still valid."""
wl = worklog_service.get_worklog_or_raise(db, worklog_id)
return {