"""OSINT enrichment service — automatically discovers CVEs, advisories, and related intelligence for MITRE ATT&CK techniques using the NVD API. Designed to run as a weekly background job. Respects NVD rate limits (5 requests per 30 seconds without an API key, 50/30s with a key). """ import logging import time from typing import Optional from uuid import UUID import requests from sqlalchemy import func from sqlalchemy.orm import Session from app.config import settings from app.domain.errors import EntityNotFoundError from app.models.osint_item import OsintItem from app.models.technique import Technique logger = logging.getLogger(__name__) NVD_API_BASE = "https://services.nvd.nist.gov/rest/json/cves/2.0" NVD_RATE_LIMIT_BATCH = 5 NVD_RATE_LIMIT_WAIT = 31 # seconds to wait after each batch def enrich_technique_with_cves(db: Session, technique: Technique) -> int: """Search for CVEs related to a technique via the NVD API. Uses the technique name as a keyword search. Deduplicates against existing OsintItems so re-runs are safe. Returns the number of new CVEs added. """ try: headers = {} if getattr(settings, "NVD_API_KEY", ""): headers["apiKey"] = settings.NVD_API_KEY params = { "keywordSearch": technique.name, "resultsPerPage": 10, } resp = requests.get( NVD_API_BASE, params=params, headers=headers, timeout=30, ) if resp.status_code != 200: logger.warning( "NVD API error for %s: HTTP %d", technique.mitre_id, resp.status_code, ) return 0 data = resp.json() count = 0 for vuln in data.get("vulnerabilities", []): cve = vuln.get("cve", {}) cve_id = cve.get("id") if not cve_id: continue # Deduplicate exists = ( db.query(OsintItem.id) .filter( OsintItem.technique_id == technique.id, OsintItem.source_url.contains(cve_id), ) .first() ) if exists: continue descriptions = cve.get("descriptions", []) desc = next( (d["value"] for d in descriptions if d["lang"] == "en"), "" ) # Extract CVSS severity metrics = cve.get("metrics", {}) cvss_v31 = metrics.get("cvssMetricV31", []) cvss_v30 = metrics.get("cvssMetricV30", []) cvss_entry = (cvss_v31[0] if cvss_v31 else cvss_v30[0]) if (cvss_v31 or cvss_v30) else {} cvss_data = cvss_entry.get("cvssData", {}) if cvss_entry else {} severity = cvss_data.get("baseSeverity", "UNKNOWN") score = cvss_data.get("baseScore") item = OsintItem( technique_id=technique.id, source_type="cve", source_url=f"https://nvd.nist.gov/vuln/detail/{cve_id}", title=cve_id, description=desc[:500] if desc else None, severity=severity, metadata_={"cvss_score": score, "cve_id": cve_id}, ) db.add(item) count += 1 if count > 0: technique.review_required = True db.commit() logger.info("Added %d CVEs for %s", count, technique.mitre_id) return count except requests.RequestException as e: logger.error( "HTTP error during OSINT enrichment for %s: %s", technique.mitre_id, e, ) return 0 except Exception as e: logger.error( "OSINT enrichment failed for %s: %s", technique.mitre_id, e, exc_info=True, ) return 0 def enrich_all_techniques(db: Session) -> int: """Enrich all techniques with CVE data from NVD. Rate-limited: processes *NVD_RATE_LIMIT_BATCH* techniques, then sleeps for *NVD_RATE_LIMIT_WAIT* seconds to stay under NVD limits. Returns total number of new OSINT items added. """ techniques = db.query(Technique).order_by(Technique.mitre_id).all() total = 0 logger.info( "Starting OSINT enrichment for %d techniques...", len(techniques), ) for i, tech in enumerate(techniques): total += enrich_technique_with_cves(db, tech) # Rate limiting: wait after every batch if (i + 1) % NVD_RATE_LIMIT_BATCH == 0 and (i + 1) < len(techniques): logger.debug( "Rate limit pause after %d techniques (%ds)...", i + 1, NVD_RATE_LIMIT_WAIT, ) time.sleep(NVD_RATE_LIMIT_WAIT) logger.info( "OSINT enrichment complete — %d new items across %d techniques", total, len(techniques), ) return total def get_osint_items_for_technique( db: Session, technique_id: str, source_type: str | None = None, reviewed: bool | None = None, ) -> list[OsintItem]: """Retrieve OSINT items for a technique with optional filters.""" query = db.query(OsintItem).filter(OsintItem.technique_id == technique_id) if source_type: query = query.filter(OsintItem.source_type == source_type) if reviewed is not None: query = query.filter(OsintItem.reviewed == reviewed) return query.order_by(OsintItem.discovered_at.desc()).all() def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None: """Mark an OSINT item as reviewed. Does not commit; caller uses UnitOfWork.""" item = db.query(OsintItem).filter(OsintItem.id == item_id).first() if item: item.reviewed = True return item def get_unreviewed_count(db: Session) -> int: """Return the total number of unreviewed OSINT items.""" return db.query(OsintItem).filter(OsintItem.reviewed == False).count() # noqa: E712 def list_osint_items( db: Session, *, technique_id: Optional[UUID] = None, source_type: Optional[str] = None, reviewed: Optional[bool] = None, offset: int = 0, limit: int = 50, ) -> dict: """List OSINT items with optional filters and pagination.""" query = db.query(OsintItem) if technique_id: query = query.filter(OsintItem.technique_id == technique_id) if source_type: query = query.filter(OsintItem.source_type == source_type) if reviewed is not None: query = query.filter(OsintItem.reviewed == reviewed) total = query.count() items = ( query.order_by(OsintItem.discovered_at.desc()) .offset(offset) .limit(limit) .all() ) return { "total": total, "items": [ { "id": str(item.id), "technique_id": str(item.technique_id), "source_type": item.source_type, "source_url": item.source_url, "title": item.title, "description": item.description, "severity": item.severity, "discovered_at": item.discovered_at.isoformat() if item.discovered_at else None, "reviewed": item.reviewed, "metadata": item.metadata_, } for item in items ], } def get_osint_summary(db: Session) -> dict: """Summary statistics for OSINT items.""" total = db.query(func.count(OsintItem.id)).scalar() or 0 unreviewed = get_unreviewed_count(db) by_severity = dict( db.query(OsintItem.severity, func.count(OsintItem.id)) .group_by(OsintItem.severity) .all() ) by_type = dict( db.query(OsintItem.source_type, func.count(OsintItem.id)) .group_by(OsintItem.source_type) .all() ) techniques_with_items = ( db.query(func.count(func.distinct(OsintItem.technique_id))).scalar() or 0 ) return { "total_items": total, "unreviewed": unreviewed, "techniques_with_items": techniques_with_items, "by_severity": by_severity, "by_type": by_type, } def get_technique_or_raise(db: Session, technique_id: UUID) -> Technique: """Get a technique by ID or raise EntityNotFoundError.""" technique = db.query(Technique).filter(Technique.id == technique_id).first() if not technique: raise EntityNotFoundError("Technique", str(technique_id)) return technique