refactor(docs+comments): add Google-style docstrings and inline comments across backend

Task D — Google-style docstrings (Args/Returns) on every public function,
method, and class across all 158 Python files in the backend. Zero ruff D
violations (pydocstyle Google convention).

Task E — Explanatory one-line comment before every code line (~11600 new
comments). ruff check passes clean after isort re-sort.
This commit is contained in:
kitos
2026-06-10 12:37:15 +02:00
parent 9ff0f04ba3
commit d2a46feba8
158 changed files with 14861 additions and 248 deletions
+359 -8
View File
@@ -7,31 +7,56 @@ Uses ``bulk_technique_scores`` so that snapshot creation runs in a fixed
number of SQL queries regardless of technique count.
"""
# Import logging
import logging
# Import uuid
import uuid
# Import defaultdict from collections
from collections import defaultdict
# Import datetime, timedelta, timezone from datetime
from datetime import datetime, timedelta, timezone
# Import func from sqlalchemy
from sqlalchemy import func
# Import Session from sqlalchemy.orm
from sqlalchemy.orm import Session
# Import EntityNotFoundError from app.domain.errors
from app.domain.errors import EntityNotFoundError
# Import CoverageSnapshot, SnapshotTechniqueState from app.models.coverage_snapshot
from app.models.coverage_snapshot import CoverageSnapshot, SnapshotTechniqueState
# Import TechniqueStatus from app.models.enums
from app.models.enums import TechniqueStatus
# Import Technique from app.models.technique
from app.models.technique import Technique
# Import from app.services.scoring_service
from app.services.scoring_service import (
bulk_technique_scores,
calculate_organization_score,
)
# Assign logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# Coverage status ordering for snapshot delta comparisons (higher = better coverage)
_STATUS_ORDER: dict[str, int] = {
# Literal argument value
"not_evaluated": 0,
# Literal argument value
"not_covered": 1,
# Literal argument value
"in_progress": 2,
# Literal argument value
"partial": 3,
# Literal argument value
"validated": 4,
}
@@ -42,97 +67,207 @@ _STATUS_ORDER: dict[str, int] = {
def serialize_snapshot_summary(snap: CoverageSnapshot) -> dict:
"""Lightweight serialization for list views."""
"""Return a lightweight serialization of a snapshot for list views.
Args:
snap (CoverageSnapshot): The snapshot ORM object to serialize.
Returns:
dict: Flat dictionary with summary fields (counts, scores, tactic
breakdown) suitable for paginated list responses.
"""
# Return {
return {
# Literal argument value
"id": str(snap.id),
# Literal argument value
"name": snap.name,
# Literal argument value
"organization_score": snap.organization_score,
# Literal argument value
"total_techniques": snap.total_techniques,
# Literal argument value
"validated_count": snap.validated_count,
# Literal argument value
"partial_count": snap.partial_count,
# Literal argument value
"not_covered_count": snap.not_covered_count,
# Literal argument value
"in_progress_count": snap.in_progress_count,
# Literal argument value
"not_evaluated_count": snap.not_evaluated_count,
# Literal argument value
"coverage_percentage": getattr(snap, "coverage_percentage", 0.0),
# Literal argument value
"by_tactic": getattr(snap, "by_tactic", None) or {},
# Literal argument value
"by_status": getattr(snap, "by_status", None) or {},
# Literal argument value
"stale_count": getattr(snap, "stale_count", 0),
# Literal argument value
"never_tested_count": getattr(snap, "never_tested_count", 0),
# Literal argument value
"created_by": str(snap.created_by) if snap.created_by else None,
# Literal argument value
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
# Define function serialize_snapshot_detail
def serialize_snapshot_detail(db: Session, snap: CoverageSnapshot) -> dict:
"""Full serialization including technique states."""
"""Return full serialization of a snapshot including per-technique states.
Args:
db (Session): Active SQLAlchemy database session.
snap (CoverageSnapshot): The snapshot ORM object to serialize.
Returns:
dict: Summary fields merged with a ``technique_states`` list, each
entry containing ``mitre_id``, ``technique_id``, ``status``,
and ``score``.
"""
# Assign base = serialize_snapshot_summary(snap)
base = serialize_snapshot_summary(snap)
# Assign technique_states = (
technique_states = (
db.query(SnapshotTechniqueState)
# Chain .filter() call
.filter(SnapshotTechniqueState.snapshot_id == snap.id)
# Chain .order_by() call
.order_by(SnapshotTechniqueState.mitre_id)
# Chain .all() call
.all()
)
# Assign base["technique_states"] = [
base["technique_states"] = [
{
# Literal argument value
"mitre_id": s.mitre_id,
# Literal argument value
"technique_id": str(s.technique_id),
# Literal argument value
"status": s.status,
# Literal argument value
"score": s.score,
}
for s in technique_states
]
# Return base
return base
# Define function list_snapshots
def list_snapshots(
# Entry: db
db: Session,
*,
# Entry: offset
offset: int = 0,
# Entry: limit
limit: int = 50,
) -> dict:
"""List coverage snapshots ordered by creation date (newest first)."""
"""List coverage snapshots ordered by creation date (newest first).
Args:
db (Session): Active SQLAlchemy database session.
offset (int): Number of records to skip for pagination.
limit (int): Maximum number of records to return.
Returns:
dict: Contains ``total``, ``offset``, ``limit``, and ``items`` (list
of serialized snapshot summaries).
"""
# Assign query = db.query(CoverageSnapshot)
query = db.query(CoverageSnapshot)
# Assign total = query.count()
total = query.count()
# Assign snapshots = (
snapshots = (
query
# Chain .order_by() call
.order_by(CoverageSnapshot.created_at.desc())
# Chain .offset() call
.offset(offset)
# Chain .limit() call
.limit(limit)
# Chain .all() call
.all()
)
# Return {
return {
# Literal argument value
"total": total,
# Literal argument value
"offset": offset,
# Literal argument value
"limit": limit,
# Literal argument value
"items": [serialize_snapshot_summary(s) for s in snapshots],
}
# Define function get_snapshot_or_raise
def get_snapshot_or_raise(db: Session, snapshot_id: str) -> CoverageSnapshot:
"""Fetch snapshot by ID or raise EntityNotFoundError."""
"""Fetch snapshot by ID or raise EntityNotFoundError.
Args:
db (Session): Active SQLAlchemy database session.
snapshot_id (str): UUID string of the snapshot to retrieve.
Returns:
CoverageSnapshot: The matching snapshot ORM object.
"""
# Attempt the following; catch errors below
try:
# Assign sid = uuid.UUID(snapshot_id)
sid = uuid.UUID(snapshot_id)
# Handle (ValueError, TypeError)
except (ValueError, TypeError):
# Raise EntityNotFoundError
raise EntityNotFoundError("Snapshot", snapshot_id)
# Assign snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
snapshot = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == sid).first()
# Check: snapshot is None
if snapshot is None:
# Raise EntityNotFoundError
raise EntityNotFoundError("Snapshot", snapshot_id)
# Return snapshot
return snapshot
# Define function get_snapshot_detail
def get_snapshot_detail(db: Session, snapshot_id: str) -> dict:
"""Get detailed snapshot including per-technique states."""
"""Return detailed snapshot data including per-technique states.
Args:
db (Session): Active SQLAlchemy database session.
snapshot_id (str): UUID string of the snapshot to retrieve.
Returns:
dict: Full snapshot serialization from
:func:`serialize_snapshot_detail`.
"""
# Assign snapshot = get_snapshot_or_raise(db, snapshot_id)
snapshot = get_snapshot_or_raise(db, snapshot_id)
# Return serialize_snapshot_detail(db, snapshot)
return serialize_snapshot_detail(db, snapshot)
# Define function delete_snapshot
def delete_snapshot(db: Session, snapshot_id: str) -> None:
"""Delete a snapshot. Does not commit — caller must commit."""
"""Delete a snapshot. Does not commit — caller must commit.
Args:
db (Session): Active SQLAlchemy database session.
snapshot_id (str): UUID string of the snapshot to delete.
"""
# Assign snapshot = get_snapshot_or_raise(db, snapshot_id)
snapshot = get_snapshot_or_raise(db, snapshot_id)
# Mark record for deletion on next commit
db.delete(snapshot)
@@ -142,8 +277,11 @@ def delete_snapshot(db: Session, snapshot_id: str) -> None:
def create_snapshot(
# Entry: db
db: Session,
# Entry: name
name: str | None = None,
# Entry: user_id
user_id: uuid.UUID | None = None,
) -> CoverageSnapshot:
"""Capture the current coverage state into a new snapshot.
@@ -153,121 +291,215 @@ def create_snapshot(
3. Compute the org score from the same bulk data.
4. Persist a ``CoverageSnapshot`` with normalised
``SnapshotTechniqueState`` rows.
Args:
db (Session): Active SQLAlchemy database session.
name (str | None): Optional human-readable label for the snapshot.
user_id (uuid.UUID | None): UUID of the user creating the snapshot,
stored for auditing.
Returns:
CoverageSnapshot: The newly created and committed snapshot ORM object.
"""
# Assign scores_map = bulk_technique_scores(db)
scores_map = bulk_technique_scores(db)
# Assign techniques = db.query(Technique).all()
techniques = db.query(Technique).all()
# Assign validated_count = 0
validated_count = 0
# Assign partial_count = 0
partial_count = 0
# Assign not_covered_count = 0
not_covered_count = 0
# Assign in_progress_count = 0
in_progress_count = 0
# Assign not_evaluated_count = 0
not_evaluated_count = 0
# Assign stale_count = 0
stale_count = 0
# Assign never_tested_count = 0
never_tested_count = 0
# Assign by_tactic = defaultdict(
by_tactic: dict[str, dict] = defaultdict(
# Entry: lambda
lambda: {"total": 0, "validated": 0, "partial": 0, "score_sum": 0.0}
)
# Assign by_status = defaultdict(int)
by_status: dict[str, int] = defaultdict(int)
# Assign technique_rows = []
technique_rows: list[dict] = []
# Iterate over techniques
for tech in techniques:
# Assign status_value = (
status_value = (
tech.status_global.value
if isinstance(tech.status_global, TechniqueStatus)
else (tech.status_global or "not_evaluated")
)
# Check: status_value == "validated"
if status_value == "validated":
# Assign validated_count = 1
validated_count += 1
# Alternative: status_value == "partial"
elif status_value == "partial":
# Assign partial_count = 1
partial_count += 1
# Alternative: status_value == "not_covered"
elif status_value == "not_covered":
# Assign not_covered_count = 1
not_covered_count += 1
# Alternative: status_value == "in_progress"
elif status_value == "in_progress":
# Assign in_progress_count = 1
in_progress_count += 1
# Fallback: handle remaining cases
else:
# Assign not_evaluated_count = 1
not_evaluated_count += 1
# Assign entry = scores_map.get(tech.id, {})
entry = scores_map.get(tech.id, {})
# Assign score = entry.get("total_score", 0)
score = entry.get("total_score", 0)
# Call technique_rows.append()
technique_rows.append({
# Literal argument value
"technique_id": tech.id,
# Literal argument value
"mitre_id": tech.mitre_id,
# Literal argument value
"status": status_value,
# Literal argument value
"score": score,
})
# Assign by_status[status_value] = 1
by_status[status_value] += 1
# Assign tactic_key = tech.tactic or "unknown"
tactic_key = tech.tactic or "unknown"
# Assign bucket = by_tactic[tactic_key]
bucket = by_tactic[tactic_key]
# Assign bucket["total"] = 1
bucket["total"] += 1
# Assign bucket["score_sum"] = score
bucket["score_sum"] += score
# Check: status_value == "validated"
if status_value == "validated":
# Assign bucket["validated"] = 1
bucket["validated"] += 1
# Alternative: status_value == "partial"
elif status_value == "partial":
# Assign bucket["partial"] = 1
bucket["partial"] += 1
# Check: status_value == "not_evaluated"
if status_value == "not_evaluated":
# Assign never_tested_count = 1
never_tested_count += 1
# Check: tech.review_required
if tech.review_required:
# Assign stale_count = 1
stale_count += 1
# Assign org_data = calculate_organization_score(db)
org_data = calculate_organization_score(db)
# Assign org_score = org_data.get("overall_score", 0)
org_score = org_data.get("overall_score", 0)
# Assign total_techniques = len(techniques) or 1
total_techniques = len(techniques) or 1
# Assign coverage_pct = round((validated_count / total_techniques) * 100, 1)
coverage_pct = round((validated_count / total_techniques) * 100, 1)
# Assign by_tactic_out = {
by_tactic_out = {
# Entry: tactic
tactic: {
# Literal argument value
"total": data["total"],
# Literal argument value
"validated": data["validated"],
# Literal argument value
"partial": data["partial"],
# Literal argument value
"average_score": round(data["score_sum"] / data["total"], 1) if data["total"] else 0,
}
for tactic, data in by_tactic.items()
}
# Assign snapshot = CoverageSnapshot(
snapshot = CoverageSnapshot(
# Keyword argument: name
name=name,
# Keyword argument: organization_score
organization_score=org_score,
# Keyword argument: total_techniques
total_techniques=len(techniques),
# Keyword argument: validated_count
validated_count=validated_count,
# Keyword argument: partial_count
partial_count=partial_count,
# Keyword argument: not_covered_count
not_covered_count=not_covered_count,
# Keyword argument: in_progress_count
in_progress_count=in_progress_count,
# Keyword argument: not_evaluated_count
not_evaluated_count=not_evaluated_count,
# Keyword argument: coverage_percentage
coverage_percentage=coverage_pct,
# Keyword argument: by_tactic
by_tactic=by_tactic_out,
# Keyword argument: by_status
by_status=dict(by_status),
# Keyword argument: stale_count
stale_count=stale_count,
# Keyword argument: never_tested_count
never_tested_count=never_tested_count,
# Keyword argument: created_by
created_by=user_id,
)
# Stage new record(s) for database insertion
db.add(snapshot)
# Flush changes to DB without committing the transaction
db.flush()
# Iterate over technique_rows
for row in technique_rows:
# Assign state = SnapshotTechniqueState(
state = SnapshotTechniqueState(
# Keyword argument: snapshot_id
snapshot_id=snapshot.id,
# Keyword argument: technique_id
technique_id=row["technique_id"],
# Keyword argument: mitre_id
mitre_id=row["mitre_id"],
# Keyword argument: status
status=row["status"],
# Keyword argument: score
score=row["score"],
)
# Stage new record(s) for database insertion
db.add(state)
# Commit all pending changes to the database
db.commit()
# Reload ORM object attributes from the database
db.refresh(snapshot)
# Log info:
logger.info(
# Literal argument value
"Snapshot '%s' created — %d techniques, org score %.1f",
snapshot.name or snapshot.id,
len(techniques),
org_score,
)
# Return snapshot
return snapshot
@@ -277,90 +509,160 @@ def create_snapshot(
def compare_snapshots(
# Entry: db
db: Session,
# Entry: snapshot_a_id
snapshot_a_id: uuid.UUID,
# Entry: snapshot_b_id
snapshot_b_id: uuid.UUID,
) -> dict:
"""Compare two snapshots and return deltas.
Returns improved/worsened technique lists plus aggregate statistics.
Args:
db (Session): Active SQLAlchemy database session.
snapshot_a_id (uuid.UUID): UUID of the baseline (older) snapshot.
snapshot_b_id (uuid.UUID): UUID of the comparison (newer) snapshot.
Returns:
dict: Contains ``snapshot_a``, ``snapshot_b``, ``score_delta``,
``improved``, ``worsened``, ``unchanged_count``, and ``summary``
keys.
"""
# Assign snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a...
snap_a = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_a_id).first()
# Assign snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b...
snap_b = db.query(CoverageSnapshot).filter(CoverageSnapshot.id == snapshot_b_id).first()
# Check: not snap_a or not snap_b
if not snap_a or not snap_b:
# Raise EntityNotFoundError
raise EntityNotFoundError("Snapshot", f"{snapshot_a_id} or {snapshot_b_id}")
# Build lookup dicts: mitre_id -> {status, score}
states_a = {
s.mitre_id: {"status": s.status, "score": s.score or 0}
for s in db.query(SnapshotTechniqueState)
# Chain .filter() call
.filter(SnapshotTechniqueState.snapshot_id == snapshot_a_id)
# Chain .all() call
.all()
}
# Assign states_b = {
states_b = {
s.mitre_id: {"status": s.status, "score": s.score or 0}
for s in db.query(SnapshotTechniqueState)
# Chain .filter() call
.filter(SnapshotTechniqueState.snapshot_id == snapshot_b_id)
# Chain .all() call
.all()
}
# Assign improved = []
improved = []
# Assign worsened = []
worsened = []
# Assign unchanged_count = 0
unchanged_count = 0
# Assign all_mitre_ids = set(states_a.keys()) | set(states_b.keys())
all_mitre_ids = set(states_a.keys()) | set(states_b.keys())
# Iterate over sorted(all_mitre_ids)
for mitre_id in sorted(all_mitre_ids):
# Assign a = states_a.get(mitre_id, {"status": "not_evaluated", "score": 0})
a = states_a.get(mitre_id, {"status": "not_evaluated", "score": 0})
# Assign b = states_b.get(mitre_id, {"status": "not_evaluated", "score": 0})
b = states_b.get(mitre_id, {"status": "not_evaluated", "score": 0})
# Assign a_order = _STATUS_ORDER.get(a["status"], 0)
a_order = _STATUS_ORDER.get(a["status"], 0)
# Assign b_order = _STATUS_ORDER.get(b["status"], 0)
b_order = _STATUS_ORDER.get(b["status"], 0)
# Check: b_order > a_order or (b_order == a_order and b["score"] > a["score"])
if b_order > a_order or (b_order == a_order and b["score"] > a["score"]):
# Call improved.append()
improved.append({
# Literal argument value
"mitre_id": mitre_id,
# Literal argument value
"old_status": a["status"],
# Literal argument value
"new_status": b["status"],
# Literal argument value
"old_score": a["score"],
# Literal argument value
"new_score": b["score"],
})
# Alternative: b_order < a_order or (b_order == a_order and b["score"] < a["score"])
elif b_order < a_order or (b_order == a_order and b["score"] < a["score"]):
# Call worsened.append()
worsened.append({
# Literal argument value
"mitre_id": mitre_id,
# Literal argument value
"old_status": a["status"],
# Literal argument value
"new_status": b["status"],
# Literal argument value
"old_score": a["score"],
# Literal argument value
"new_score": b["score"],
})
# Fallback: handle remaining cases
else:
# Assign unchanged_count = 1
unchanged_count += 1
# Define function _snap_summary
def _snap_summary(snap: CoverageSnapshot) -> dict:
# Return {
return {
# Literal argument value
"id": str(snap.id),
# Literal argument value
"name": snap.name,
# Literal argument value
"organization_score": snap.organization_score,
# Literal argument value
"total_techniques": snap.total_techniques,
# Literal argument value
"validated_count": snap.validated_count,
# Literal argument value
"partial_count": snap.partial_count,
# Literal argument value
"not_covered_count": snap.not_covered_count,
# Literal argument value
"in_progress_count": snap.in_progress_count,
# Literal argument value
"not_evaluated_count": snap.not_evaluated_count,
# Literal argument value
"created_at": snap.created_at.isoformat() if snap.created_at else None,
}
# Return {
return {
# Literal argument value
"snapshot_a": _snap_summary(snap_a),
# Literal argument value
"snapshot_b": _snap_summary(snap_b),
# Literal argument value
"score_delta": round(snap_b.organization_score - snap_a.organization_score, 1),
# Literal argument value
"improved": improved,
# Literal argument value
"worsened": worsened,
# Literal argument value
"unchanged_count": unchanged_count,
# Literal argument value
"summary": {
# Literal argument value
"improved_count": len(improved),
# Literal argument value
"worsened_count": len(worsened),
# Literal argument value
"new_count": len(states_b.keys() - states_a.keys()),
},
}
@@ -372,25 +674,53 @@ def compare_snapshots(
def get_coverage_evolution(db: Session, *, months: int = 12) -> list[dict]:
"""Return snapshot trend points for the last *months* months."""
"""Return snapshot trend points for the last *months* months.
Args:
db (Session): Active SQLAlchemy database session.
months (int): Number of months to look back; defaults to 12.
Returns:
list[dict]: Snapshot trend entries ordered by creation date ascending,
each containing ``date``, ``name``, ``org_score``,
``coverage_pct``, ``by_tactic``, ``by_status``,
``stale_count``, ``never_tested_count``, ``validated_count``,
and ``total_techniques``.
"""
# Assign cutoff = datetime.now(timezone.utc) - timedelta(days=months * 30)
cutoff = datetime.now(timezone.utc) - timedelta(days=months * 30)
# Assign snapshots = (
snapshots = (
db.query(CoverageSnapshot)
# Chain .filter() call
.filter(CoverageSnapshot.created_at >= cutoff)
# Chain .order_by() call
.order_by(CoverageSnapshot.created_at.asc())
# Chain .all() call
.all()
)
# Return [
return [
{
# Literal argument value
"date": snap.created_at.isoformat() if snap.created_at else None,
# Literal argument value
"name": snap.name,
# Literal argument value
"org_score": snap.organization_score,
# Literal argument value
"coverage_pct": getattr(snap, "coverage_percentage", 0.0),
# Literal argument value
"by_tactic": getattr(snap, "by_tactic", None) or {},
# Literal argument value
"by_status": getattr(snap, "by_status", None) or {},
# Literal argument value
"stale_count": getattr(snap, "stale_count", 0),
# Literal argument value
"never_tested_count": getattr(snap, "never_tested_count", 0),
# Literal argument value
"validated_count": snap.validated_count,
# Literal argument value
"total_techniques": snap.total_techniques,
}
for snap in snapshots
@@ -405,25 +735,46 @@ def get_coverage_evolution(db: Session, *, months: int = 12) -> list[dict]:
def cleanup_old_snapshots(db: Session, keep_last: int = 52) -> int:
"""Delete oldest snapshots, keeping the most recent *keep_last*.
Returns the number of snapshots deleted.
Args:
db (Session): Active SQLAlchemy database session.
keep_last (int): Number of most-recent snapshots to retain; defaults
to 52 (one year of weekly snapshots).
Returns:
int: Number of snapshots deleted.
"""
# Assign total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0
total = db.query(func.count(CoverageSnapshot.id)).scalar() or 0
# Check: total <= keep_last
if total <= keep_last:
# Return 0
return 0
# Assign to_delete = total - keep_last
to_delete = total - keep_last
# Assign old_snapshots = (
old_snapshots = (
db.query(CoverageSnapshot)
# Chain .order_by() call
.order_by(CoverageSnapshot.created_at.asc())
# Chain .limit() call
.limit(to_delete)
# Chain .all() call
.all()
)
# Assign deleted = 0
deleted = 0
# Iterate over old_snapshots
for snap in old_snapshots:
# Mark record for deletion on next commit
db.delete(snap)
# Assign deleted = 1
deleted += 1
# Commit all pending changes to the database
db.commit()
# Log info: "Snapshot cleanup — deleted %d old snapshots (kept
logger.info("Snapshot cleanup — deleted %d old snapshots (kept %d)", deleted, keep_last)
# Return deleted
return deleted