Files
Aegis/backend/app/seed_demo.py
T
kitos 394d5d9056 refactor(types): add comprehensive type annotations across backend Python codebase
Enable ANN rules in ruff.toml (flake8-annotations) and resolve all 221 violations:

ANN201/ANN202 — return types on 168 public/private functions:
- All 28 FastAPI routers: endpoints annotated with dict/list/specific schema/
  StreamingResponse/FileResponse/JSONResponse as appropriate
- main.py: lifespan→AsyncGenerator[None,None], exception handlers→JSONResponse
- database.py: get_db→Generator[Session,None,None], proxy methods→correct types
- middleware/request_context.py: dispatch→Response with Callable call_next type

ANN001/ANN002/ANN003 — 32 missing argument types:
- seed_demo.py: all db parameters typed as Session
- domain/unit_of_work.py: __aexit__ exc_type/exc_val/exc_tb typed with TracebackType
- services: audit_service user_id→UUID|None, heatmap_service query/model/builder,
  notification_service test→Test, tempo_service test→Test/user→User,
  test_workflow_service test_id→UUID, campaign_crud **fields→object,
  test_crud **fields→object (4 sites)

ANN401 — 16 Any usages resolved:
- Domain entities (campaign/technique/threat_actor/test_entity): replaced Any with
  actual ORM types via TYPE_CHECKING guards to avoid circular imports
- detection_rule_service: test_id/detection_rule_id/evaluator_id→UUID
- score_cache: kept Any with # noqa: ANN401 (genuinely generic cache)
- jira_service/tempo_service: kept Any with # noqa: ANN401 (lazy optional deps)
- d3fend_import_service: _to_str(v: Any) kept with # noqa: ANN401

ANN204/ANN205/ANN206 — special/static/class methods:
- database.py proxy __call__/__getattr__: *args: object/**kwargs: object
- schemas/test.py model_validate: obj→object, **kwargs→object
- sa_technique_repository._int_type→type

All 439 unit tests pass. ruff check app/ → All checks passed!

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-09 17:04:51 +02:00

442 lines
15 KiB
Python

"""
Seed script — generates a realistic volume of demo data for V3 validation.
Usage:
python -m app.seed_demo
**Prerequisite**: The MITRE sync must have been completed first so that
real techniques exist in the database.
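Running twice is safe — the script detects existing demo data (users with the
username prefix ``demo_`` together with their notifications, audit logs and
tests, plus templates whose source is ``demo``) and deletes it before
re-creating it. Note that the random statuses written onto real techniques
are not reverted by this cleanup.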
"""
import logging
import random
import uuid
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from app.auth import hash_password
from app.database import SessionLocal
from app.models.audit import AuditLog
from app.models.enums import TeamSide, TechniqueStatus, TestResult, TestState
from app.models.evidence import Evidence
from app.models.notification import Notification
from app.models.technique import Technique
from app.models.test import Test
from app.models.test_template import TestTemplate
from app.models.user import User
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Every seeded username starts with this prefix; the cleanup step uses it to
# find the users (and their data) left behind by a previous run.
DEMO_PREFIX = "demo_"

# Roles for which demo users are created.
ROLES = ["red_tech", "blue_tech", "red_lead", "blue_lead", "admin"]

# Value pools the seeders draw from at random.
TECHNIQUE_STATUSES = [
    TechniqueStatus.validated, TechniqueStatus.partial, TechniqueStatus.not_covered,
    TechniqueStatus.in_progress, TechniqueStatus.not_evaluated,
]
TEST_STATES = [
    TestState.draft, TestState.red_executing, TestState.blue_evaluating,
    TestState.in_review, TestState.validated, TestState.rejected,
]
TEST_RESULTS = [TestResult.detected, TestResult.not_detected, TestResult.partially_detected]

NOTIFICATION_TYPES = [
    "test_assigned", "validation_needed", "test_rejected",
    "test_validated", "test_state_changed",
]
AUDIT_ACTIONS = [
    "create_test", "update_test", "validate_technique", "upload_evidence",
    "create_user", "import_atomic_red_team", "sync_mitre", "login",
    "reject_test", "approve_test",
]
PLATFORMS = ["windows", "linux", "macos"]

# Names of the manual demo templates; the template seeder slices this list,
# so it also caps how many templates can be created.
TEMPLATE_NAMES = [
    "Manual Credential Dumping Test", "Custom Phishing Payload Delivery",
    "Lateral Movement via RDP", "Persistence via Registry Run Keys",
    "Data Exfiltration over DNS", "Process Injection via DLL",
    "Privilege Escalation with Token Impersonation", "Custom C2 Beacon Communication Test",
    "Kerberoasting Attack Procedure", "Living Off The Land Binaries Test",
]
# ---------------------------------------------------------------------------
# Cleanup
# ---------------------------------------------------------------------------
def _cleanup_demo_data(db: Session) -> None:
    """Remove all previously seeded demo data and commit.

    Demo data is identified as:

    * users whose username starts with ``DEMO_PREFIX``;
    * notifications and audit logs belonging to those users;
    * tests created by those users, together with their evidences;
    * test templates whose ``source`` is ``"demo"``.

    Rows are deleted children-first so foreign-key constraints are respected.
    Technique statuses changed by the seeder are not restored here.
    """
    # ``_`` is a single-character wildcard in SQL LIKE, so a raw
    # ``LIKE 'demo_%'`` would also match real users such as ``demo1`` or
    # ``demos`` and wipe their data. ``autoescape=True`` escapes it.
    # Only ids are selected: the rows are bulk-deleted below without session
    # synchronisation, so loading full objects would just leave stale copies.
    demo_user_ids = [
        user_id
        for (user_id,) in db.query(User.id).filter(
            User.username.startswith(DEMO_PREFIX, autoescape=True)
        )
    ]
    if demo_user_ids:
        db.query(Notification).filter(
            Notification.user_id.in_(demo_user_ids)
        ).delete(synchronize_session=False)
        db.query(AuditLog).filter(
            AuditLog.user_id.in_(demo_user_ids)
        ).delete(synchronize_session=False)
        demo_test_ids = [
            test_id
            for (test_id,) in db.query(Test.id).filter(
                Test.created_by.in_(demo_user_ids)
            )
        ]
        if demo_test_ids:
            # Evidences reference tests, so they have to go first.
            db.query(Evidence).filter(
                Evidence.test_id.in_(demo_test_ids)
            ).delete(synchronize_session=False)
            db.query(Test).filter(
                Test.id.in_(demo_test_ids)
            ).delete(synchronize_session=False)
    # Demo templates are tagged by source rather than owned by a demo user.
    db.query(TestTemplate).filter(
        TestTemplate.source == "demo"
    ).delete(synchronize_session=False)
    # Users last: the rows deleted above reference them.
    if demo_user_ids:
        db.query(User).filter(
            User.id.in_(demo_user_ids)
        ).delete(synchronize_session=False)
    db.commit()
    logger.info("Cleaned up existing demo data.")
# ---------------------------------------------------------------------------
# Seeders
# ---------------------------------------------------------------------------
def _seed_users(db: Session, per_role: int = 5) -> list[User]:
    """Create *per_role* active demo users for every role in ``ROLES``.

    With the default this is 5 users per role (25 total). Usernames are
    ``<DEMO_PREFIX><role>_<n>`` (n starting at 1) and every account uses the
    password ``"demo123"``.

    Args:
        db: Open session; users are added and flushed, not committed.
        per_role: Number of users to create for each role.

    Returns:
        The created users, grouped in ``ROLES`` order.
    """
    # The password is the same for every account, so hash it once rather
    # than once per user. hash_password is presumably a deliberately slow,
    # salted hash (TODO confirm), which made the per-user call the dominant
    # cost. Side effect: all demo accounts share one identical hash string,
    # which is acceptable for throwaway demo credentials.
    hashed = hash_password("demo123")
    users = [
        User(
            username=f"{DEMO_PREFIX}{role}_{i}",
            email=f"{DEMO_PREFIX}{role}_{i}@aegis-demo.local",
            hashed_password=hashed,
            role=role,
            is_active=True,
        )
        for role in ROLES
        for i in range(1, per_role + 1)
    ]
    db.add_all(users)
    db.flush()
    logger.info("Created %d demo users.", len(users))
    return users
def _seed_technique_statuses(db: Session, count: int = 50) -> list[Technique]:
    """Assign a random global status to up to *count* techniques.

    These are the real technique rows from the MITRE sync, not demo copies.
    Techniques that draw ``validated`` also get a ``last_review_date`` 1-30
    days in the past. Changes are flushed, not committed.

    Args:
        db: Open session.
        count: Maximum number of techniques to update.

    Returns:
        The updated techniques, or ``[]`` if the table is empty.
    """
    # LIMIT without ORDER BY may return any subset, so successive runs could
    # scatter demo statuses across different real techniques. A fixed order
    # touches the same subset each run (while the technique table is unchanged).
    techniques = (
        db.query(Technique).order_by(Technique.mitre_id).limit(count).all()
    )
    if not techniques:
        logger.warning("No techniques found — run MITRE sync first!")
        return []
    now = datetime.utcnow()
    for tech in techniques:
        tech.status_global = random.choice(TECHNIQUE_STATUSES)
        if tech.status_global == TechniqueStatus.validated:
            tech.last_review_date = now - timedelta(days=random.randint(1, 30))
    db.flush()
    logger.info("Updated status on %d techniques.", len(techniques))
    return techniques
def _seed_tests(db: Session, users: list[User], techniques: list[Technique], count: int = 100) -> list[Test]:
    """Create *count* demo tests spread across the pipeline states.

    Each test targets a random technique and is created by a random red or
    blue technician. Fields owned by later pipeline stages are filled in only
    for states that have reached them (see ``_apply_state_fields``).
    Timestamps are generated in causal order:
    ``created_at <= execution_date <= validation timestamps <= now``.

    Args:
        db: Open session; tests are added and flushed, not committed.
        users: Demo users. Must contain at least one ``red_tech``/``blue_tech``
            (creators), plus ``red_lead`` and ``blue_lead`` users whenever
            ``validated``/``rejected`` states are drawn.
        techniques: Candidate techniques; nothing is created if empty.
        count: Number of tests to create.

    Returns:
        The created tests, or ``[]`` when *techniques* is empty.
    """
    if not techniques:
        logger.warning("No techniques available — skipping test seeding.")
        return []
    # Role pools are loop-invariant; build them once.
    creators = [u for u in users if u.role == "red_tech"] + [
        u for u in users if u.role == "blue_tech"
    ]
    red_leads = [u for u in users if u.role == "red_lead"]
    blue_leads = [u for u in users if u.role == "blue_lead"]
    now = datetime.utcnow()
    tests = []
    for i in range(count):
        technique = random.choice(techniques)
        state = random.choice(TEST_STATES)
        creator = random.choice(creators)
        # Draw the test's age first, then place its execution between
        # creation and now, so no test is executed before it exists.
        age_days = random.randint(0, 90)
        created_at = now - timedelta(days=age_days)
        executed_at = created_at + timedelta(days=random.randint(0, age_days))
        test = Test(
            technique_id=technique.id,
            name=f"Demo Test {i + 1} — {technique.name[:40]}",
            description=f"Automated demo test #{i + 1} for {technique.mitre_id}.",
            platform=random.choice(PLATFORMS),
            procedure_text=(
                f"Step 1: Prepare environment.\n"
                f"Step 2: Execute {technique.mitre_id} procedure.\n"
                f"Step 3: Observe results."
            ),
            tool_used=random.choice(["powershell", "bash", "cmd", "python", "caldera", "metasploit"]),
            execution_date=executed_at,
            created_by=creator.id,
            # No result until the red team has finished executing.
            result=(
                None
                if state in (TestState.draft, TestState.red_executing)
                else random.choice(TEST_RESULTS)
            ),
            state=state,
            created_at=created_at,
        )
        _apply_state_fields(test, state, executed_at, now, red_leads, blue_leads)
        db.add(test)
        tests.append(test)
    db.flush()
    logger.info("Created %d demo tests.", len(tests))
    return tests


def _random_time_between(start: datetime, end: datetime) -> datetime:
    """Return *start* plus a random whole number of days, never later than *end*."""
    return start + timedelta(days=random.randint(0, max((end - start).days, 0)))


def _apply_state_fields(
    test: Test,
    state: TestState,
    executed_at: datetime,
    now: datetime,
    red_leads: list[User],
    blue_leads: list[User],
) -> None:
    """Fill in the red/blue results and lead validations implied by *state*."""
    if state in (TestState.blue_evaluating, TestState.in_review, TestState.validated, TestState.rejected):
        test.attack_success = random.choice([True, True, True, False])  # 3-in-4 success
        # Keep the narrative consistent with the success flag.
        if test.attack_success:
            test.red_summary = f"Attack executed successfully using {test.tool_used}."
        else:
            test.red_summary = f"Attack attempted using {test.tool_used} but did not succeed."
    if state in (TestState.in_review, TestState.validated, TestState.rejected):
        test.detection_result = random.choice(TEST_RESULTS)
        # Keep the narrative consistent with the detection result.
        if test.detection_result == TestResult.detected:
            test.blue_summary = "Detection observed in SIEM. Alert fired."
        elif test.detection_result == TestResult.partially_detected:
            test.blue_summary = "Partial detection in SIEM: some activity raised no alert."
        else:
            test.blue_summary = "No detection observed in SIEM."
    if state == TestState.validated:
        # Both leads approve, some time after execution.
        test.red_validated_by = random.choice(red_leads).id
        test.red_validated_at = _random_time_between(executed_at, now)
        test.red_validation_status = "approved"
        test.blue_validated_by = random.choice(blue_leads).id
        test.blue_validated_at = _random_time_between(executed_at, now)
        test.blue_validation_status = "approved"
    elif state == TestState.rejected:
        # A single lead, red or blue, rejects the test.
        rejector = random.choice(red_leads + blue_leads)
        rejected_at = _random_time_between(executed_at, now)
        if rejector.role == "red_lead":
            test.red_validated_by = rejector.id
            test.red_validated_at = rejected_at
            test.red_validation_status = "rejected"
            test.red_validation_notes = "Insufficient evidence of attack success."
        else:
            test.blue_validated_by = rejector.id
            test.blue_validated_at = rejected_at
            test.blue_validation_status = "rejected"
            test.blue_validation_notes = "Detection evidence not conclusive."
def _seed_evidences(db: Session, tests: list[Test], users: list[User], count: int = 50) -> list[Evidence]:
    """Create *count* dummy evidence records attached to random demo tests.

    Only database rows are created: nothing is written at ``file_path``, and
    ``sha256_hash`` is random hex rather than a real digest. Each record is
    uploaded by a random red/blue technician, and ``team`` follows the
    uploader's role.

    Args:
        db: Open session; evidences are added and flushed, not committed.
        tests: Candidate tests; non-draft tests are preferred.
        users: Demo users; must contain at least one ``red_tech``/``blue_tech``.
        count: Number of evidence records to create.

    Returns:
        The created evidences, or ``[]`` when *tests* is empty.
    """
    if not tests:
        return []
    # Prefer tests past the draft state; fall back to all tests if every
    # test is still a draft.
    eligible = [t for t in tests if t.state != TestState.draft] or tests
    uploaders = [u for u in users if u.role in ("red_tech", "blue_tech")]
    now = datetime.utcnow()
    evidences = []
    for i in range(count):
        test = random.choice(eligible)
        uploader = random.choice(uploaders)
        team = TeamSide.red if uploader.role == "red_tech" else TeamSide.blue
        ext = random.choice(["png", "log", "pcap", "csv", "txt", "json"])
        fname = f"evidence_{i + 1}.{ext}"
        uploaded_at = now - timedelta(days=random.randint(0, 30))
        # Evidence cannot predate its test. Assumes created_at is a naive
        # UTC datetime, like the values set by the test seeder.
        if test.created_at is not None and uploaded_at < test.created_at:
            uploaded_at = test.created_at
        evidence = Evidence(
            test_id=test.id,
            file_name=fname,
            file_path=f"{test.id}/{uuid.uuid4()}_{fname}",
            # 64 random hex chars: same length as a SHA-256 hex digest.
            sha256_hash=uuid.uuid4().hex + uuid.uuid4().hex,
            uploaded_by=uploader.id,
            uploaded_at=uploaded_at,
            team=team,
            notes=f"Auto-generated demo evidence #{i + 1}.",
        )
        db.add(evidence)
        evidences.append(evidence)
    db.flush()
    logger.info("Created %d demo evidences.", len(evidences))
    return evidences
def _seed_audit_logs(db: Session, users: list[User], count: int = 20) -> None:
    """Add *count* randomized audit-log rows attributed to demo users.

    Each row has a random user, action and entity type, a fresh string
    entity id, a timestamp up to 60 days back, and
    ``{"demo": True, "index": <n>}`` as details. Rows are flushed, not
    committed.
    """
    entity_kinds = ["test", "technique", "user", "test_template"]
    for index in range(count):
        actor = random.choice(users)
        action = random.choice(AUDIT_ACTIONS)
        entity_kind = random.choice(entity_kinds)
        entity_ref = str(uuid.uuid4())
        stamp = datetime.utcnow() - timedelta(days=random.randint(0, 60))
        db.add(
            AuditLog(
                user_id=actor.id,
                action=action,
                entity_type=entity_kind,
                entity_id=entity_ref,
                timestamp=stamp,
                details={"demo": True, "index": index},
            )
        )
    db.flush()
    logger.info("Created %d demo audit logs.", count)
def _seed_notifications(db: Session, users: list[User], count: int = 30) -> None:
    """Add *count* randomized notifications addressed to demo users.

    Each notification has a random type from ``NOTIFICATION_TYPES``. It points
    at a fresh random ``test`` entity id, is randomly read or unread, and is
    dated up to 30 days back. Rows are flushed, not committed.
    """
    for index in range(count):
        recipient = random.choice(users)
        kind = random.choice(NOTIFICATION_TYPES)
        label = kind.replace("_", " ").title()
        target_id = uuid.uuid4()
        is_read = random.choice([True, False])
        sent_at = datetime.utcnow() - timedelta(days=random.randint(0, 30))
        db.add(
            Notification(
                user_id=recipient.id,
                type=kind,
                title=f"Demo notification: {label} #{index + 1}",
                message=f"This is an auto-generated demo notification ({kind}).",
                entity_type="test",
                entity_id=target_id,
                read=is_read,
                created_at=sent_at,
            )
        )
    db.flush()
    logger.info("Created %d demo notifications.", count)
def _seed_templates(db: Session, techniques: list[Technique], count: int = 10) -> None:
    """Create up to *count* manual demo templates tagged ``source="demo"``.

    At most ``len(TEMPLATE_NAMES)`` templates can be created. Template *i*
    targets ``techniques[i % len(techniques)]``, and nothing is created when
    *techniques* is empty. Rows are flushed, not committed.
    """
    if not techniques:
        logger.warning("No techniques available — skipping template seeding.")
        return
    # Slicing caps the number of templates at the names available, which can
    # be fewer than *count*; log what is actually created.
    names = TEMPLATE_NAMES[:count]
    for i, name in enumerate(names):
        technique = techniques[i % len(techniques)]
        template = TestTemplate(
            mitre_technique_id=technique.mitre_id,
            name=name,
            description=f"Demo template: {name}. Targets {technique.mitre_id} ({technique.name}).",
            source="demo",
            source_url=None,
            attack_procedure=(
                f"1. Set up environment for {technique.mitre_id}.\n"
                "2. Execute the procedure.\n"
                "3. Record observations."
            ),
            expected_detection=f"SIEM should alert on {technique.mitre_id} indicators.",
            platform=random.choice(PLATFORMS),
            tool_suggested=random.choice(["powershell", "cmd", "bash", "python"]),
            severity=random.choice(["low", "medium", "high", "critical"]),
            is_active=True,
        )
        db.add(template)
    db.flush()
    logger.info("Created %d demo templates.", len(names))
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
def seed_demo() -> dict[str, int]:
    """Generate all demo data in one transaction and return a summary.

    Previous demo data is removed (and committed) first. All seeding steps
    then share one session and are committed together; on any exception the
    pending work is rolled back and the exception re-raised. The session is
    always closed.

    Returns:
        Number of rows created or updated per category.
    """
    audit_log_count = 20
    notification_count = 30
    template_count = 10
    db = SessionLocal()
    try:
        logger.info("=== Starting V3 demo seed ===")
        # Step 0: cleanup previous run
        _cleanup_demo_data(db)
        # Step 1: users
        users = _seed_users(db)
        # Step 2: technique statuses
        techniques = _seed_technique_statuses(db, count=50)
        # Step 3: tests
        tests = _seed_tests(db, users, techniques, count=100)
        # Step 4: evidences
        evidences = _seed_evidences(db, tests, users, count=50)
        # Step 5: audit logs
        _seed_audit_logs(db, users, count=audit_log_count)
        # Step 6: notifications
        _seed_notifications(db, users, count=notification_count)
        # Step 7: templates
        _seed_templates(db, techniques, count=template_count)
        db.commit()
        # _seed_templates returns nothing and may create fewer templates than
        # requested (none without techniques), so count what was stored.
        templates_created = (
            db.query(TestTemplate).filter(TestTemplate.source == "demo").count()
        )
        summary = {
            "users": len(users),
            "techniques_updated": len(techniques),
            "tests": len(tests),
            "evidences": len(evidences),
            "audit_logs": audit_log_count,
            "notifications": notification_count,
            "templates": templates_created,
        }
        logger.info("=== Demo seed complete: %s ===", summary)
        return summary
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        # The separator keeps the logger name from running into the message.
        format="%(asctime)s %(levelname)-8s %(name)s - %(message)s",
    )
    result = seed_demo()
    print(f"\nSeed complete: {result}")