"""Shared utilities for observability instrumentation."""
import hashlib
from typing import Any
__all__ = ("compute_sql_hash", "get_trace_context", "resolve_db_system")
_DB_SYSTEM_MAP: "tuple[tuple[str, str], ...]" = (
("asyncpg", "postgresql"),
("psycopg", "postgresql"),
("psqlpy", "postgresql"),
("postgres", "postgresql"),
("asyncmy", "mysql"),
("mysql", "mysql"),
("mariadb", "mysql"),
("aiosqlite", "sqlite"),
("sqlite", "sqlite"),
("duckdb", "duckdb"),
("bigquery", "bigquery"),
("spanner", "spanner"),
("oracle", "oracle"),
("oracledb", "oracle"),
("adbc", "adbc"),
)
[docs]
def resolve_db_system(adapter_or_driver: str) -> str:
"""Resolve adapter/driver name to OTel db.system value.
Args:
adapter_or_driver: Adapter or driver name.
Returns:
OTel db.system value.
"""
normalized = adapter_or_driver.lower()
for needle, system in _DB_SYSTEM_MAP:
if needle in normalized:
return system
return "other_sql"
[docs]
def get_trace_context() -> "tuple[str | None, str | None]":
"""Extract trace_id and span_id from the current OTel context.
Returns:
Tuple of trace_id and span_id if available, otherwise (None, None).
"""
try:
from opentelemetry import trace
except ImportError:
return None, None
span: Any | None = trace.get_current_span()
if span is None or not span.is_recording():
return None, None
ctx = span.get_span_context()
if not ctx or not ctx.is_valid:
return None, None
if not ctx.trace_id or not ctx.span_id:
return None, None
return format(ctx.trace_id, "032x"), format(ctx.span_id, "016x")
[docs]
def compute_sql_hash(sql: str) -> str:
"""Return the 16-character SHA256 hash for SQL text.
Args:
sql: SQL statement text.
Returns:
SHA256 hash prefix.
"""
return hashlib.sha256(sql.encode("utf-8")).hexdigest()[:16]