Files
Aegis/backend/app/services/osint_enrichment_service.py

278 lines
8.4 KiB
Python

"""OSINT enrichment service — automatically discovers CVEs, advisories, and
related intelligence for MITRE ATT&CK techniques using the NVD API.
Designed to run as a weekly background job. Respects NVD rate limits
(5 requests per 30 seconds without an API key, 50/30s with a key).
"""
import logging
import time
from typing import Optional
from uuid import UUID
import requests
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.config import settings
from app.domain.errors import EntityNotFoundError
from app.models.osint_item import OsintItem
from app.models.technique import Technique
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NVD CVE API endpoint queried by enrich_technique_with_cves.
NVD_API_BASE = "https://services.nvd.nist.gov/rest/json/cves/2.0"
# Number of techniques processed between rate-limit pauses; equals the keyless
# NVD limit of 5 requests per 30 seconds noted in the module docstring.
NVD_RATE_LIMIT_BATCH = 5
# NOTE(review): presumably the 30-second NVD window plus a one-second safety
# margin — confirm.
NVD_RATE_LIMIT_WAIT = 31 # seconds to wait after each batch
def _english_description(cve: dict) -> str:
    """Return the English description text of an NVD CVE record, or ``""``."""
    for entry in cve.get("descriptions", []):
        # .get() keeps one malformed entry from aborting the whole technique.
        if entry.get("lang") == "en":
            return entry.get("value") or ""
    return ""


def _cvss_summary(cve: dict) -> tuple[str, float | None]:
    """Return ``(baseSeverity, baseScore)`` from the first CVSS v3.1 metric,
    falling back to v3.0; ``("UNKNOWN", None)`` when neither is present.
    """
    metrics = cve.get("metrics", {})
    entries = metrics.get("cvssMetricV31") or metrics.get("cvssMetricV30") or []
    cvss_data = entries[0].get("cvssData", {}) if entries else {}
    return cvss_data.get("baseSeverity", "UNKNOWN"), cvss_data.get("baseScore")


def enrich_technique_with_cves(db: Session, technique: Technique) -> int:
    """Search NVD for CVEs related to a technique and store the new ones.

    The technique name is sent as the NVD ``keywordSearch`` term (10 results
    per page, first page only). Every returned CVE becomes an ``OsintItem``
    unless an item with exactly the same NVD detail URL already exists for
    the technique, so re-runs are safe. When at least one item is added the
    technique is flagged ``review_required`` and the session is committed.

    Args:
        db: Session used for the duplicate lookup, the inserts and the commit.
        technique: Technique to enrich; ``id``, ``name`` and ``mitre_id`` are
            read and ``review_required`` may be set.

    Returns:
        Number of new CVE items added. ``0`` on a non-200 response or on any
        handled error; errors are logged rather than raised.
    """
    # Captured in a local so the error handlers below never have to touch the
    # ORM instance again once the session may be in a failed state.
    mitre_id = "<unknown>"
    try:
        mitre_id = technique.mitre_id
        technique_id = technique.id

        headers = {}
        api_key = getattr(settings, "NVD_API_KEY", "")
        if api_key:
            headers["apiKey"] = api_key

        resp = requests.get(
            NVD_API_BASE,
            params={
                "keywordSearch": technique.name,
                "resultsPerPage": 10,
            },
            headers=headers,
            timeout=30,
        )
        if resp.status_code != 200:
            logger.warning(
                "NVD API error for %s: HTTP %d",
                mitre_id,
                resp.status_code,
            )
            return 0

        # Key candidates by their detail URL; setdefault keeps the first
        # occurrence so a CVE repeated within one response is added only once.
        cves_by_url: dict[str, dict] = {}
        for vuln in resp.json().get("vulnerabilities", []):
            cve = vuln.get("cve", {})
            cve_id = cve.get("id")
            if cve_id:
                cves_by_url.setdefault(
                    f"https://nvd.nist.gov/vuln/detail/{cve_id}", cve
                )
        if not cves_by_url:
            return 0

        # One exact-match lookup for all candidates. A substring match would
        # let a stored "CVE-2021-12345" wrongly suppress "CVE-2021-1234".
        known_urls = {
            url
            for (url,) in db.query(OsintItem.source_url)
            .filter(
                OsintItem.technique_id == technique_id,
                OsintItem.source_url.in_(list(cves_by_url)),
            )
            .all()
        }

        count = 0
        for url, cve in cves_by_url.items():
            if url in known_urls:
                continue
            cve_id = cve["id"]
            desc = _english_description(cve)
            severity, score = _cvss_summary(cve)
            db.add(
                OsintItem(
                    technique_id=technique_id,
                    source_type="cve",
                    source_url=url,
                    title=cve_id,
                    # Stored description is capped at 500 characters.
                    description=desc[:500] if desc else None,
                    severity=severity,
                    metadata_={"cvss_score": score, "cve_id": cve_id},
                )
            )
            count += 1

        if count > 0:
            technique.review_required = True
            db.commit()
            logger.info("Added %d CVEs for %s", count, mitre_id)
        return count
    except requests.RequestException as e:
        # Raised before anything is added to the session, so there is nothing
        # to roll back here.
        logger.error(
            "HTTP error during OSINT enrichment for %s: %s",
            mitre_id,
            e,
        )
        return 0
    except Exception as e:
        logger.error(
            "OSINT enrichment failed for %s: %s",
            mitre_id,
            e,
            exc_info=True,
        )
        # Discard half-added items and reset the session; otherwise pending
        # rows would be committed by a later call, and a failed commit would
        # leave the shared session unusable for the remaining techniques.
        db.rollback()
        return 0
def enrich_all_techniques(db: Session) -> int:
    """Run CVE enrichment for every technique, ordered by MITRE id.

    After each group of *NVD_RATE_LIMIT_BATCH* techniques the function sleeps
    for *NVD_RATE_LIMIT_WAIT* seconds, except after the final technique.

    Returns:
        Total number of new OSINT items added across all techniques.
    """
    all_techniques = db.query(Technique).order_by(Technique.mitre_id).all()
    technique_count = len(all_techniques)
    logger.info(
        "Starting OSINT enrichment for %d techniques...",
        technique_count,
    )

    added = 0
    for processed, technique in enumerate(all_techniques, start=1):
        added += enrich_technique_with_cves(db, technique)

        # Pause only at a batch boundary, and never after the last technique.
        batch_finished = processed % NVD_RATE_LIMIT_BATCH == 0
        if not batch_finished or processed >= technique_count:
            continue
        logger.debug(
            "Rate limit pause after %d techniques (%ds)...",
            processed,
            NVD_RATE_LIMIT_WAIT,
        )
        time.sleep(NVD_RATE_LIMIT_WAIT)

    logger.info(
        "OSINT enrichment complete — %d new items across %d techniques",
        added,
        technique_count,
    )
    return added
def get_osint_items_for_technique(
    db: Session,
    technique_id: str,
    source_type: str | None = None,
    reviewed: bool | None = None,
) -> list[OsintItem]:
    """Return a technique's OSINT items, most recently discovered first.

    Args:
        db: Session to query.
        technique_id: Technique whose items are returned.
        source_type: When truthy, keep only items of this source type.
        reviewed: When not ``None``, keep only items with this reviewed flag.
    """
    conditions = [OsintItem.technique_id == technique_id]
    if source_type:
        conditions.append(OsintItem.source_type == source_type)
    if reviewed is not None:
        conditions.append(OsintItem.reviewed == reviewed)

    matching = db.query(OsintItem)
    for condition in conditions:
        matching = matching.filter(condition)

    newest_first = OsintItem.discovered_at.desc()
    return matching.order_by(newest_first).all()
def mark_osint_reviewed(db: Session, item_id: str) -> OsintItem | None:
    """Set ``reviewed`` on the OSINT item with the given id.

    The session is not committed here; that is left to the caller.

    Returns:
        The updated item, or ``None`` when no item has that id.
    """
    record = db.query(OsintItem).filter(OsintItem.id == item_id).first()
    if not record:
        return record
    record.reviewed = True
    return record
def get_unreviewed_count(db: Session) -> int:
    """Count the OSINT items whose ``reviewed`` flag is false."""
    # `== False` is intentional: it builds a SQL comparison, not a Python one.
    not_reviewed = OsintItem.reviewed == False  # noqa: E712
    pending = db.query(OsintItem).filter(not_reviewed)
    return pending.count()
def list_osint_items(
    db: Session,
    *,
    technique_id: Optional[UUID] = None,
    source_type: Optional[str] = None,
    reviewed: Optional[bool] = None,
    offset: int = 0,
    limit: int = 50,
) -> dict:
    """List OSINT items with optional filters and pagination.

    Args:
        db: Session to query.
        technique_id: When truthy, keep only items of this technique.
        source_type: When truthy, keep only items of this source type.
        reviewed: When not ``None``, keep only items with this reviewed flag.
        offset: Number of matching rows to skip.
        limit: Maximum number of rows to return.

    Returns:
        ``{"total": <count of all matching rows>, "items": [<dict>, ...]}``
        where each item dict holds string ids, the stored fields, an
        ISO-formatted ``discovered_at`` (or ``None``) and ``metadata``.
    """
    query = db.query(OsintItem)
    if technique_id:
        query = query.filter(OsintItem.technique_id == technique_id)
    if source_type:
        query = query.filter(OsintItem.source_type == source_type)
    if reviewed is not None:
        query = query.filter(OsintItem.reviewed == reviewed)

    total = query.count()
    items = (
        # discovered_at alone is not unique, so the id is added as a
        # tiebreaker; without a total order, rows with equal timestamps can
        # repeat or go missing between pages.
        query.order_by(OsintItem.discovered_at.desc(), OsintItem.id)
        .offset(offset)
        .limit(limit)
        .all()
    )
    return {
        "total": total,
        "items": [
            {
                "id": str(item.id),
                "technique_id": str(item.technique_id),
                "source_type": item.source_type,
                "source_url": item.source_url,
                "title": item.title,
                "description": item.description,
                "severity": item.severity,
                "discovered_at": (
                    item.discovered_at.isoformat() if item.discovered_at else None
                ),
                "reviewed": item.reviewed,
                "metadata": item.metadata_,
            }
            for item in items
        ],
    }
def get_osint_summary(db: Session) -> dict:
    """Build aggregate counts over all OSINT items.

    Returns:
        A dict with the total item count, the unreviewed count, the number of
        distinct techniques that have items, and per-severity / per-source-type
        count mappings.
    """
    item_total = db.query(func.count(OsintItem.id)).scalar() or 0
    pending_review = get_unreviewed_count(db)

    severity_rows = (
        db.query(OsintItem.severity, func.count(OsintItem.id))
        .group_by(OsintItem.severity)
        .all()
    )
    type_rows = (
        db.query(OsintItem.source_type, func.count(OsintItem.id))
        .group_by(OsintItem.source_type)
        .all()
    )

    distinct_techniques = func.count(func.distinct(OsintItem.technique_id))
    covered_techniques = db.query(distinct_techniques).scalar() or 0

    summary = {
        "total_items": item_total,
        "unreviewed": pending_review,
        "techniques_with_items": covered_techniques,
        # Each row is a (value, count) pair, so dict() maps value -> count.
        "by_severity": dict(severity_rows),
        "by_type": dict(type_rows),
    }
    return summary
def get_technique_or_raise(db: Session, technique_id: UUID) -> Technique:
    """Fetch the technique with the given id.

    Raises:
        EntityNotFoundError: When no technique has that id.
    """
    found = db.query(Technique).filter(Technique.id == technique_id).first()
    if found:
        return found
    raise EntityNotFoundError("Technique", str(technique_id))