feat(refactor): PEP8, type annotations, docstrings and PyJWT security fix
This commit is contained in:
@@ -6,111 +6,184 @@ that the router remains a thin HTTP adapter.
|
||||
This module is framework-agnostic: no FastAPI imports.
|
||||
"""
|
||||
|
||||
# Enable future language features for compatibility
|
||||
from __future__ import annotations
|
||||
|
||||
# Import csv
|
||||
import csv
|
||||
|
||||
# Import io
|
||||
import io
|
||||
|
||||
# Import Any from typing
|
||||
from typing import Any
|
||||
|
||||
# Import Session from sqlalchemy.orm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import EntityNotFoundError from app.domain.errors
|
||||
from app.domain.errors import EntityNotFoundError
|
||||
|
||||
# Import from app.models.compliance
|
||||
from app.models.compliance import (
|
||||
ComplianceFramework,
|
||||
ComplianceControl,
|
||||
ComplianceControlMapping,
|
||||
ComplianceFramework,
|
||||
)
|
||||
from app.models.technique import Technique
|
||||
from app.models.test_template import TestTemplate
|
||||
from app.models.threat_actor import ThreatActorTechnique
|
||||
from app.services.scoring_service import calculate_technique_score
|
||||
|
||||
# Import Technique from app.models.technique
|
||||
from app.models.technique import Technique
|
||||
|
||||
# Import TestTemplate from app.models.test_template
|
||||
from app.models.test_template import TestTemplate
|
||||
|
||||
# Import ThreatActorTechnique from app.models.threat_actor
|
||||
from app.models.threat_actor import ThreatActorTechnique
|
||||
|
||||
# Import calculate_technique_score from app.services.scoring_service
|
||||
from app.services.scoring_service import calculate_technique_score
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _classify_control(technique_scores: list[float]) -> str:
|
||||
"""Classify a control status based on its technique scores."""
|
||||
# Check: not technique_scores
|
||||
if not technique_scores:
|
||||
# Return "not_evaluated"
|
||||
return "not_evaluated"
|
||||
|
||||
# Assign all_above_70 = all(s >= 70 for s in technique_scores)
|
||||
all_above_70 = all(s >= 70 for s in technique_scores)
|
||||
# Assign any_above_30 = any(s >= 30 for s in technique_scores)
|
||||
any_above_30 = any(s >= 30 for s in technique_scores)
|
||||
# Assign all_below_30 = all(s < 30 for s in technique_scores)
|
||||
all_below_30 = all(s < 30 for s in technique_scores)
|
||||
# Assign all_zero = all(s == 0 for s in technique_scores)
|
||||
all_zero = all(s == 0 for s in technique_scores)
|
||||
|
||||
# Check: all_zero
|
||||
if all_zero:
|
||||
# Return "not_evaluated"
|
||||
return "not_evaluated"
|
||||
# Check: all_above_70
|
||||
if all_above_70:
|
||||
# Return "covered"
|
||||
return "covered"
|
||||
# Check: all_below_30
|
||||
if all_below_30:
|
||||
# Return "not_covered"
|
||||
return "not_covered"
|
||||
# Check: any_above_30
|
||||
if any_above_30:
|
||||
# Return "partially_covered"
|
||||
return "partially_covered"
|
||||
# Return "not_covered"
|
||||
return "not_covered"
|
||||
|
||||
|
||||
# Define function _get_control_status
|
||||
def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, Any]:
|
||||
"""Compute the status and score for a single control."""
|
||||
# Assign mappings = (
|
||||
mappings = (
|
||||
db.query(ComplianceControlMapping)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceControlMapping.compliance_control_id == control.id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Check: not mappings
|
||||
if not mappings:
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"control_id": control.control_id,
|
||||
# Literal argument value
|
||||
"title": control.title,
|
||||
"description": control.description,
|
||||
"category": control.category,
|
||||
# Literal argument value
|
||||
"status": "not_evaluated",
|
||||
# Literal argument value
|
||||
"score": 0,
|
||||
# Literal argument value
|
||||
"techniques_count": 0,
|
||||
# Literal argument value
|
||||
"techniques_covered": 0,
|
||||
# Literal argument value
|
||||
"techniques": [],
|
||||
}
|
||||
|
||||
# Assign technique_ids = [m.technique_id for m in mappings]
|
||||
technique_ids = [m.technique_id for m in mappings]
|
||||
# Assign techniques = (
|
||||
techniques = (
|
||||
db.query(Technique)
|
||||
# Chain .filter() call
|
||||
.filter(Technique.id.in_(technique_ids))
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign tech_details = []
|
||||
tech_details = []
|
||||
# Assign scores = []
|
||||
scores = []
|
||||
# Assign covered_count = 0
|
||||
covered_count = 0
|
||||
|
||||
# Iterate over techniques
|
||||
for tech in techniques:
|
||||
# Assign result = calculate_technique_score(tech, db)
|
||||
result = calculate_technique_score(tech, db)
|
||||
# Assign score = result["total_score"]
|
||||
score = result["total_score"]
|
||||
# Call scores.append()
|
||||
scores.append(score)
|
||||
# Check: score >= 50
|
||||
if score >= 50:
|
||||
# Assign covered_count = 1
|
||||
covered_count += 1
|
||||
|
||||
# Call tech_details.append()
|
||||
tech_details.append({
|
||||
# Literal argument value
|
||||
"mitre_id": tech.mitre_id,
|
||||
# Literal argument value
|
||||
"name": tech.name,
|
||||
# Literal argument value
|
||||
"score": score,
|
||||
# Literal argument value
|
||||
"status": tech.status_global.value if tech.status_global else "not_evaluated",
|
||||
})
|
||||
|
||||
# Sort techniques by score ascending (worst first for priority)
|
||||
tech_details.sort(key=lambda t: t["score"])
|
||||
|
||||
# Assign avg_score = round(sum(scores) / len(scores), 1) if scores else 0
|
||||
avg_score = round(sum(scores) / len(scores), 1) if scores else 0
|
||||
# Assign status = _classify_control(scores)
|
||||
status = _classify_control(scores)
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"control_id": control.control_id,
|
||||
# Literal argument value
|
||||
"title": control.title,
|
||||
"description": control.description,
|
||||
"category": control.category,
|
||||
# Literal argument value
|
||||
"status": status,
|
||||
# Literal argument value
|
||||
"score": avg_score,
|
||||
# Literal argument value
|
||||
"techniques_count": len(techniques),
|
||||
# Literal argument value
|
||||
"techniques_covered": covered_count,
|
||||
# Literal argument value
|
||||
"techniques": tech_details,
|
||||
}
|
||||
|
||||
@@ -120,95 +193,150 @@ def _get_control_status(control: ComplianceControl, db: Session) -> dict[str, An
|
||||
|
||||
def list_frameworks(db: Session) -> list[dict[str, Any]]:
|
||||
"""List all available compliance frameworks with control counts."""
|
||||
# Assign frameworks = (
|
||||
frameworks = (
|
||||
db.query(ComplianceFramework)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceFramework.is_active == True)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign result = []
|
||||
result = []
|
||||
# Iterate over frameworks
|
||||
for fw in frameworks:
|
||||
# Assign control_count = (
|
||||
control_count = (
|
||||
db.query(ComplianceControl)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceControl.framework_id == fw.id)
|
||||
# Chain .count() call
|
||||
.count()
|
||||
)
|
||||
# Call result.append()
|
||||
result.append({
|
||||
# Literal argument value
|
||||
"id": str(fw.id),
|
||||
# Literal argument value
|
||||
"name": fw.name,
|
||||
# Literal argument value
|
||||
"version": fw.version,
|
||||
# Literal argument value
|
||||
"description": fw.description,
|
||||
# Literal argument value
|
||||
"url": fw.url,
|
||||
# Literal argument value
|
||||
"is_active": fw.is_active,
|
||||
# Literal argument value
|
||||
"controls_count": control_count,
|
||||
})
|
||||
|
||||
# Return result
|
||||
return result
|
||||
|
||||
|
||||
# Define function get_framework
|
||||
def get_framework(db: Session, framework_id: str) -> ComplianceFramework | None:
|
||||
"""Get a framework by ID, or None if not found."""
|
||||
# Return (
|
||||
return (
|
||||
db.query(ComplianceFramework)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceFramework.id == framework_id)
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
# Define function get_framework_status
|
||||
def get_framework_status(db: Session, framework_id: str) -> dict[str, Any]:
|
||||
"""Get compliance status for each control in a framework.
|
||||
|
||||
Raises EntityNotFoundError if the framework does not exist.
|
||||
"""
|
||||
# Assign framework = get_framework(db, framework_id)
|
||||
framework = get_framework(db, framework_id)
|
||||
# Check: not framework
|
||||
if not framework:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Framework", framework_id)
|
||||
|
||||
# Assign controls = (
|
||||
controls = (
|
||||
db.query(ComplianceControl)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceControl.framework_id == framework.id)
|
||||
# Chain .order_by() call
|
||||
.order_by(ComplianceControl.control_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign control_statuses = []
|
||||
control_statuses = []
|
||||
# Assign summary = {
|
||||
summary = {
|
||||
# Literal argument value
|
||||
"total_controls": len(controls),
|
||||
# Literal argument value
|
||||
"covered": 0,
|
||||
# Literal argument value
|
||||
"partially_covered": 0,
|
||||
# Literal argument value
|
||||
"not_covered": 0,
|
||||
# Literal argument value
|
||||
"not_evaluated": 0,
|
||||
}
|
||||
|
||||
# Iterate over controls
|
||||
for control in controls:
|
||||
# Assign status_data = _get_control_status(control, db)
|
||||
status_data = _get_control_status(control, db)
|
||||
# Call control_statuses.append()
|
||||
control_statuses.append(status_data)
|
||||
|
||||
# Assign status = status_data["status"]
|
||||
status = status_data["status"]
|
||||
# Check: status in summary
|
||||
if status in summary:
|
||||
# Assign summary[status] = 1
|
||||
summary[status] += 1
|
||||
|
||||
# Compliance percentage: (covered + partially_covered*0.5) / total * 100
|
||||
total = summary["total_controls"]
|
||||
# Check: total > 0
|
||||
if total > 0:
|
||||
# Assign compliance_pct = round(
|
||||
compliance_pct = round(
|
||||
(summary["covered"] + summary["partially_covered"] * 0.5) / total * 100,
|
||||
# Literal argument value
|
||||
1,
|
||||
)
|
||||
# Fallback: handle remaining cases
|
||||
else:
|
||||
# Assign compliance_pct = 0
|
||||
compliance_pct = 0
|
||||
|
||||
# Assign summary["compliance_percentage"] = compliance_pct
|
||||
summary["compliance_percentage"] = compliance_pct
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"framework": {"id": str(framework.id), "name": framework.name},
|
||||
# Literal argument value
|
||||
"summary": summary,
|
||||
# Literal argument value
|
||||
"controls": control_statuses,
|
||||
}
|
||||
|
||||
|
||||
# Define function build_framework_report_csv
|
||||
def build_framework_report_csv(
|
||||
# Entry: db
|
||||
db: Session,
|
||||
# Entry: framework_id
|
||||
framework_id: str,
|
||||
) -> tuple[bytes, str]:
|
||||
"""Build the compliance report CSV content and filename.
|
||||
@@ -217,33 +345,55 @@ def build_framework_report_csv(
|
||||
|
||||
Raises EntityNotFoundError if the framework does not exist.
|
||||
"""
|
||||
# Assign framework = get_framework(db, framework_id)
|
||||
framework = get_framework(db, framework_id)
|
||||
# Check: not framework
|
||||
if not framework:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Framework", framework_id)
|
||||
|
||||
# Assign controls = (
|
||||
controls = (
|
||||
db.query(ComplianceControl)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceControl.framework_id == framework.id)
|
||||
# Chain .order_by() call
|
||||
.order_by(ComplianceControl.control_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign output = io.StringIO()
|
||||
output = io.StringIO()
|
||||
# Assign writer = csv.writer(output)
|
||||
writer = csv.writer(output)
|
||||
# Call writer.writerow()
|
||||
writer.writerow([
|
||||
# Literal argument value
|
||||
"control_id",
|
||||
# Literal argument value
|
||||
"title",
|
||||
# Literal argument value
|
||||
"category",
|
||||
# Literal argument value
|
||||
"status",
|
||||
# Literal argument value
|
||||
"score",
|
||||
# Literal argument value
|
||||
"techniques_total",
|
||||
# Literal argument value
|
||||
"techniques_covered",
|
||||
# Literal argument value
|
||||
"technique_ids",
|
||||
])
|
||||
|
||||
# Iterate over controls
|
||||
for control in controls:
|
||||
# Assign status_data = _get_control_status(control, db)
|
||||
status_data = _get_control_status(control, db)
|
||||
# Assign technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"])
|
||||
technique_ids = ",".join(t["mitre_id"] for t in status_data["techniques"])
|
||||
# Call writer.writerow()
|
||||
writer.writerow([
|
||||
status_data["control_id"],
|
||||
status_data["title"],
|
||||
@@ -255,75 +405,116 @@ def build_framework_report_csv(
|
||||
technique_ids,
|
||||
])
|
||||
|
||||
# Call output.seek()
|
||||
output.seek(0)
|
||||
# Assign filename = f"compliance_{framework.name.replace(' ', '_')}.csv"
|
||||
filename = f"compliance_{framework.name.replace(' ', '_')}.csv"
|
||||
# Return output.getvalue().encode("utf-8"), filename
|
||||
return output.getvalue().encode("utf-8"), filename
|
||||
|
||||
|
||||
# Define function get_framework_gaps
|
||||
def get_framework_gaps(db: Session, framework_id: str) -> dict[str, Any]:
|
||||
"""Get controls with techniques that are not adequately covered.
|
||||
|
||||
Raises EntityNotFoundError if the framework does not exist.
|
||||
"""
|
||||
# Assign framework = get_framework(db, framework_id)
|
||||
framework = get_framework(db, framework_id)
|
||||
# Check: not framework
|
||||
if not framework:
|
||||
# Raise EntityNotFoundError
|
||||
raise EntityNotFoundError("Framework", framework_id)
|
||||
|
||||
# Assign controls = (
|
||||
controls = (
|
||||
db.query(ComplianceControl)
|
||||
# Chain .filter() call
|
||||
.filter(ComplianceControl.framework_id == framework.id)
|
||||
# Chain .order_by() call
|
||||
.order_by(ComplianceControl.control_id)
|
||||
# Chain .all() call
|
||||
.all()
|
||||
)
|
||||
|
||||
# Assign gaps = []
|
||||
gaps = []
|
||||
# Iterate over controls
|
||||
for control in controls:
|
||||
# Assign status_data = _get_control_status(control, db)
|
||||
status_data = _get_control_status(control, db)
|
||||
|
||||
# Check: status_data["status"] in ("not_covered", "partially_covered")
|
||||
if status_data["status"] in ("not_covered", "partially_covered"):
|
||||
# Find uncovered techniques
|
||||
uncovered_techniques = []
|
||||
# Iterate over status_data["techniques"]
|
||||
for tech_info in status_data["techniques"]:
|
||||
# Check: tech_info["score"] < 70
|
||||
if tech_info["score"] < 70:
|
||||
# Count available templates
|
||||
template_count = (
|
||||
db.query(TestTemplate)
|
||||
# Chain .filter() call
|
||||
.filter(TestTemplate.mitre_technique_id == tech_info["mitre_id"])
|
||||
# Chain .count() call
|
||||
.count()
|
||||
)
|
||||
|
||||
# Count threat actors using this technique
|
||||
technique = (
|
||||
db.query(Technique)
|
||||
# Chain .filter() call
|
||||
.filter(Technique.mitre_id == tech_info["mitre_id"])
|
||||
# Chain .first() call
|
||||
.first()
|
||||
)
|
||||
# Assign actor_count = 0
|
||||
actor_count = 0
|
||||
# Check: technique
|
||||
if technique:
|
||||
# Assign actor_count = (
|
||||
actor_count = (
|
||||
db.query(ThreatActorTechnique)
|
||||
# Chain .filter() call
|
||||
.filter(ThreatActorTechnique.technique_id == technique.id)
|
||||
# Chain .count() call
|
||||
.count()
|
||||
)
|
||||
|
||||
# Call uncovered_techniques.append()
|
||||
uncovered_techniques.append({
|
||||
**tech_info,
|
||||
# Literal argument value
|
||||
"templates_available": template_count,
|
||||
# Literal argument value
|
||||
"threat_actors_using": actor_count,
|
||||
})
|
||||
|
||||
# Check: uncovered_techniques
|
||||
if uncovered_techniques:
|
||||
# Call gaps.append()
|
||||
gaps.append({
|
||||
# Literal argument value
|
||||
"control_id": status_data["control_id"],
|
||||
# Literal argument value
|
||||
"title": status_data["title"],
|
||||
# Literal argument value
|
||||
"category": status_data["category"],
|
||||
# Literal argument value
|
||||
"status": status_data["status"],
|
||||
# Literal argument value
|
||||
"score": status_data["score"],
|
||||
# Literal argument value
|
||||
"uncovered_techniques": uncovered_techniques,
|
||||
})
|
||||
|
||||
# Return {
|
||||
return {
|
||||
# Literal argument value
|
||||
"framework": {"id": str(framework.id), "name": framework.name},
|
||||
# Literal argument value
|
||||
"total_gaps": len(gaps),
|
||||
# Literal argument value
|
||||
"gaps": gaps,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user