"""OSINT enrichment service — automatically discovers CVEs, advisories, and
related intelligence for MITRE ATT&CK techniques using the NVD API.

Designed to run as a weekly background job. Respects NVD rate limits
(5 requests per 30 seconds without an API key, 50/30s with a key).
"""

import logging
import time

import requests
from sqlalchemy.orm import Session

from app.config import settings
from app.models.osint_item import OsintItem
from app.models.technique import Technique

logger = logging.getLogger(__name__)

NVD_API_BASE = "https://services.nvd.nist.gov/rest/json/cves/2.0"
# Canonical per-CVE page; stored as OsintItem.source_url and used for dedup.
NVD_DETAIL_URL = "https://nvd.nist.gov/vuln/detail/{cve_id}"
NVD_RATE_LIMIT_BATCH = 5  # requests per 30s window without an API key
NVD_RATE_LIMIT_BATCH_WITH_KEY = 50  # requests per 30s window with an API key
NVD_RATE_LIMIT_WAIT = 31  # seconds to wait after each batch

# CVSS metric keys in NVD 2.0 CVE records, in order of preference.
# V3.1/V3.0 come first so severities match those produced by earlier runs;
# V4.0 and V2 are fallbacks so CVEs scored only under those versions are not
# reported as UNKNOWN.
_CVSS_METRIC_KEYS = (
    "cvssMetricV31",
    "cvssMetricV30",
    "cvssMetricV40",
    "cvssMetricV2",
)


def _nvd_api_key() -> str:
    """Return the configured NVD API key, or "" when none is set."""
    return getattr(settings, "NVD_API_KEY", "") or ""


def _english_description(cve: dict) -> str:
    """Return the first English description of an NVD CVE record, or ""."""
    return next(
        (
            entry.get("value", "")
            for entry in cve.get("descriptions") or []
            if entry.get("lang") == "en"
        ),
        "",
    )


def _extract_cvss(cve: dict) -> tuple[str, float | None]:
    """Return ``(severity, base_score)`` from an NVD CVE record's metrics.

    Uses the first metric version listed in ``_CVSS_METRIC_KEYS`` that has
    any entries. Within that version, the entry whose ``type`` is
    ``"Primary"`` is preferred, otherwise the first entry is used.
    Returns ``("UNKNOWN", None)`` when the record has no CVSS metrics.
    """
    metrics = cve.get("metrics") or {}
    for key in _CVSS_METRIC_KEYS:
        entries = metrics.get(key) or []
        if not entries:
            continue
        entry = next(
            (e for e in entries if e.get("type") == "Primary"), entries[0]
        )
        cvss_data = entry.get("cvssData") or {}
        # CVSS v2 metric entries carry baseSeverity on the entry itself
        # rather than inside cvssData, so check both places.
        severity = (
            cvss_data.get("baseSeverity")
            or entry.get("baseSeverity")
            or "UNKNOWN"
        )
        return severity, cvss_data.get("baseScore")
    return "UNKNOWN", None


def enrich_technique_with_cves(db: Session, technique: Technique) -> int:
    """Search for CVEs related to a technique via the NVD API.

    Uses the technique name as a keyword search (first 10 results).
    Deduplicates against existing OsintItems for this technique by exact
    NVD detail URL, so re-runs are safe. When at least one CVE is added,
    the technique is flagged ``review_required`` and the session is
    committed.

    Args:
        db: SQLAlchemy session. It is committed on success and rolled back
            if a non-HTTP error occurs, so it stays usable for the next
            technique.
        technique: The technique to enrich.

    Returns:
        The number of new CVEs added. Returns 0 on any error; errors are
        logged, not raised.
    """
    # Capture identifiers before any DB work: after a failed flush/rollback
    # the ORM instance is expired, and reading attributes would hit the DB.
    mitre_id = technique.mitre_id
    technique_id = technique.id

    try:
        headers = {}
        api_key = _nvd_api_key()
        if api_key:
            headers["apiKey"] = api_key

        resp = requests.get(
            NVD_API_BASE,
            params={"keywordSearch": technique.name, "resultsPerPage": 10},
            headers=headers,
            timeout=30,
        )
        if resp.status_code != 200:
            logger.warning(
                "NVD API error for %s: HTTP %d",
                mitre_id,
                resp.status_code,
            )
            return 0

        data = resp.json()

        # source_url -> (cve_id, cve record); dict keeps response order and
        # collapses any repeated CVE in the same response.
        candidates: dict[str, tuple[str, dict]] = {}
        for vuln in data.get("vulnerabilities") or []:
            cve = vuln.get("cve") or {}
            cve_id = cve.get("id")
            if cve_id:
                url = NVD_DETAIL_URL.format(cve_id=cve_id)
                candidates.setdefault(url, (cve_id, cve))
        if not candidates:
            return 0

        # Exact-URL match: a substring match on the CVE id would treat
        # CVE-2021-1234 as a duplicate of an existing CVE-2021-12345.
        existing = {
            url
            for (url,) in db.query(OsintItem.source_url).filter(
                OsintItem.technique_id == technique_id,
                OsintItem.source_url.in_(list(candidates)),
            )
        }

        count = 0
        for url, (cve_id, cve) in candidates.items():
            if url in existing:
                continue
            desc = _english_description(cve)
            severity, score = _extract_cvss(cve)
            db.add(
                OsintItem(
                    technique_id=technique_id,
                    source_type="cve",
                    source_url=url,
                    title=cve_id,
                    description=desc[:500] if desc else None,
                    severity=severity,
                    metadata_={"cvss_score": score, "cve_id": cve_id},
                )
            )
            count += 1

        if count > 0:
            technique.review_required = True
            db.commit()
            logger.info("Added %d CVEs for %s", count, mitre_id)

        return count

    except requests.RequestException as e:
        # Raised before any db.add(), so there is no session state to undo.
        logger.error(
            "HTTP error during OSINT enrichment for %s: %s",
            mitre_id,
            e,
        )
        return 0
    except Exception as e:
        # Discard any half-added items and clear a failed flush/commit so the
        # shared session remains usable for subsequent techniques.
        db.rollback()
        logger.error(
            "OSINT enrichment failed for %s: %s",
            mitre_id,
            e,
            exc_info=True,
        )
        return 0


def enrich_all_techniques(db: Session) -> int:
    """Enrich all techniques with CVE data from NVD.

    Rate-limited: processes a batch of techniques (one NVD request each),
    then sleeps for *NVD_RATE_LIMIT_WAIT* seconds. The batch is
    *NVD_RATE_LIMIT_BATCH* without an API key and
    *NVD_RATE_LIMIT_BATCH_WITH_KEY* when ``settings.NVD_API_KEY`` is set,
    matching NVD's documented limits.

    Returns total number of new OSINT items added.
    """
    techniques = db.query(Technique).order_by(Technique.mitre_id).all()
    n_techniques = len(techniques)
    batch_size = (
        NVD_RATE_LIMIT_BATCH_WITH_KEY if _nvd_api_key() else NVD_RATE_LIMIT_BATCH
    )
    total = 0

    logger.info(
        "Starting OSINT enrichment for %d techniques...",
        n_techniques,
    )

    for done, tech in enumerate(techniques, start=1):
        total += enrich_technique_with_cves(db, tech)

        # Rate limiting: wait after every full batch, but not after the last.
        if done % batch_size == 0 and done < n_techniques:
            logger.debug(
                "Rate limit pause after %d techniques (%ds)...",
                done,
                NVD_RATE_LIMIT_WAIT,
            )
            time.sleep(NVD_RATE_LIMIT_WAIT)

    logger.info(
        "OSINT enrichment complete — %d new items across %d techniques",
        total,
        n_techniques,
    )
    return total


def get_osint_items_for_technique(
    db: Session,
    technique_id: str,
    source_type: str | None = None,
    reviewed: bool | None = None,
) -> list[OsintItem]:
    """Retrieve OSINT items for a technique with optional filters.

    Args:
        db: SQLAlchemy session.
        technique_id: Technique whose items to return.
        source_type: If truthy, only items with this ``source_type``.
        reviewed: If not None, only items whose ``reviewed`` flag equals it.

    Returns:
        Matching items, newest ``discovered_at`` first.
    """
    query = db.query(OsintItem).filter(OsintItem.technique_id == technique_id)
    if source_type:
        query = query.filter(OsintItem.source_type == source_type)
    if reviewed is not None:
        query = query.filter(OsintItem.reviewed == reviewed)
    return query.order_by(OsintItem.discovered_at.desc()).all()


def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None:
    """Mark an OSINT item as reviewed.

    Commits and refreshes the item. Returns None if no item has *item_id*.
    """
    item = db.query(OsintItem).filter(OsintItem.id == item_id).first()
    if item:
        item.reviewed = True
        db.commit()
        db.refresh(item)
    return item


def get_unreviewed_count(db: Session) -> int:
    """Return the total number of unreviewed OSINT items."""
    return db.query(OsintItem).filter(OsintItem.reviewed == False).count()  # noqa: E712