diff --git a/backend/app/routers/test_templates.py b/backend/app/routers/test_templates.py index ae55756..e569570 100644 --- a/backend/app/routers/test_templates.py +++ b/backend/app/routers/test_templates.py @@ -31,6 +31,7 @@ from sqlalchemy.orm import Session from app.database import get_db from app.dependencies.auth import get_current_user, require_any_role from app.domain.unit_of_work import UnitOfWork +from app.models.technique import Technique from app.models.user import User from app.schemas.test_template import ( TestTemplateCreate, @@ -178,6 +179,15 @@ def create_template( """Create a custom test template.""" template = create_template_svc(db, **payload.model_dump()) with UnitOfWork(db) as uow: + # Flag the associated technique for review — new template available + if template.mitre_technique_id: + technique = ( + db.query(Technique) + .filter(Technique.mitre_id == template.mitre_technique_id) + .first() + ) + if technique: + technique.review_required = True log_action( db, user_id=current_user.id, diff --git a/backend/app/services/atomic_import_service.py b/backend/app/services/atomic_import_service.py index 9f88278..870980d 100644 --- a/backend/app/services/atomic_import_service.py +++ b/backend/app/services/atomic_import_service.py @@ -35,6 +35,7 @@ import yaml from sqlalchemy.orm import Session from app.models.test_template import TestTemplate +from app.models.technique import Technique from app.services.audit_service import log_action logger = logging.getLogger(__name__) @@ -218,6 +219,7 @@ def import_atomic_red_team(db: Session) -> dict: created = 0 skipped = 0 + new_technique_ids: set[str] = set() for item in parsed_tests: if item["atomic_test_id"] in existing_ids: @@ -238,8 +240,14 @@ def import_atomic_red_team(db: Session) -> dict: ) db.add(template) existing_ids.add(item["atomic_test_id"]) + new_technique_ids.add(item["technique_id"]) created += 1 + if new_technique_ids: + db.query(Technique).filter( + Technique.mitre_id.in_(new_technique_ids) + ).update({"review_required": True}, synchronize_session=False) + db.commit() # Count distinct YAML files by technique_id diff --git a/backend/app/services/caldera_import_service.py b/backend/app/services/caldera_import_service.py index 5c64978..cdea97f 100644 --- a/backend/app/services/caldera_import_service.py +++ b/backend/app/services/caldera_import_service.py @@ -35,6 +35,7 @@ from sqlalchemy.orm import Session from app.models.test_template import TestTemplate from app.models.data_source import DataSource +from app.models.technique import Technique from app.services.audit_service import log_action logger = logging.getLogger(__name__) @@ -237,6 +238,7 @@ def sync(db: Session) -> dict: created = 0 skipped = 0 + new_technique_ids: set[str] = set() for item in parsed: if item["atomic_test_id"] in existing_ids: @@ -257,8 +259,14 @@ def sync(db: Session) -> dict: ) db.add(template) existing_ids.add(item["atomic_test_id"]) + new_technique_ids.add(item["mitre_technique_id"]) created += 1 + if new_technique_ids: + db.query(Technique).filter( + Technique.mitre_id.in_(new_technique_ids) + ).update({"review_required": True}, synchronize_session=False) + db.commit() summary = { diff --git a/backend/app/services/lolbas_import_service.py b/backend/app/services/lolbas_import_service.py index 24ab76e..bac41c0 100644 --- a/backend/app/services/lolbas_import_service.py +++ b/backend/app/services/lolbas_import_service.py @@ -39,6 +39,7 @@ from sqlalchemy.orm import Session from app.models.test_template import TestTemplate from app.models.data_source import DataSource +from app.models.technique import Technique from app.services.audit_service import log_action logger = logging.getLogger(__name__) @@ -295,6 +296,7 @@ def _upsert_templates(db: Session, items: list[dict], source_name: str) -> dict: created = 0 skipped = 0 + new_technique_ids: set[str] = set() for item in items: if item["atomic_test_id"] in existing_ids: @@ -315,8 +317,14 @@ def _upsert_templates(db: Session, items: list[dict], source_name: str) -> dict: ) db.add(template) existing_ids.add(item["atomic_test_id"]) + new_technique_ids.add(item["mitre_technique_id"]) created += 1 + if new_technique_ids: + db.query(Technique).filter( + Technique.mitre_id.in_(new_technique_ids) + ).update({"review_required": True}, synchronize_session=False) + db.commit() return {"created": created, "skipped_existing": skipped, "total_parsed": len(items)}