Source code for sqlspec.observability._common

"""Shared utilities for observability instrumentation."""

import hashlib
from typing import Any

__all__ = ("compute_sql_hash", "get_trace_context", "resolve_db_system")

_DB_SYSTEM_MAP: "tuple[tuple[str, str], ...]" = (
    ("asyncpg", "postgresql"),
    ("psycopg", "postgresql"),
    ("psqlpy", "postgresql"),
    ("postgres", "postgresql"),
    ("asyncmy", "mysql"),
    ("mysql", "mysql"),
    ("mariadb", "mysql"),
    ("aiosqlite", "sqlite"),
    ("sqlite", "sqlite"),
    ("duckdb", "duckdb"),
    ("bigquery", "bigquery"),
    ("spanner", "spanner"),
    ("oracle", "oracle"),
    ("oracledb", "oracle"),
    ("adbc", "adbc"),
)


[docs] def resolve_db_system(adapter_or_driver: str) -> str: """Resolve adapter/driver name to OTel db.system value. Args: adapter_or_driver: Adapter or driver name. Returns: OTel db.system value. """ normalized = adapter_or_driver.lower() for needle, system in _DB_SYSTEM_MAP: if needle in normalized: return system return "other_sql"
[docs] def get_trace_context() -> "tuple[str | None, str | None]": """Extract trace_id and span_id from the current OTel context. Returns: Tuple of trace_id and span_id if available, otherwise (None, None). """ try: from opentelemetry import trace except ImportError: return None, None span: Any | None = trace.get_current_span() if span is None or not span.is_recording(): return None, None ctx = span.get_span_context() if not ctx or not ctx.is_valid: return None, None if not ctx.trace_id or not ctx.span_id: return None, None return format(ctx.trace_id, "032x"), format(ctx.span_id, "016x")
[docs] def compute_sql_hash(sql: str) -> str: """Return the 16-character SHA256 hash for SQL text. Args: sql: SQL statement text. Returns: SHA256 hash prefix. """ return hashlib.sha256(sql.encode("utf-8")).hexdigest()[:16]