Source code for sqlspec.observability._runtime

"""Runtime helpers that bundle lifecycle, observer, and span orchestration."""

import re
from typing import TYPE_CHECKING, Any, cast

from sqlspec.observability._common import compute_sql_hash, get_trace_context, resolve_db_system
from sqlspec.observability._config import LoggingConfig, ObservabilityConfig
from sqlspec.observability._dispatcher import LifecycleDispatcher, LifecycleHook
from sqlspec.observability._observer import StatementObserver, create_event, create_statement_observer
from sqlspec.observability._spans import SpanManager
from sqlspec.utils.correlation import CorrelationContext
from sqlspec.utils.type_guards import has_span_attribute

_LITERAL_PATTERN = re.compile(r"'(?:''|[^'])*'")

if TYPE_CHECKING:
    from collections.abc import Iterable

    from sqlspec.storage import StorageTelemetry


[docs] class ObservabilityRuntime: """Aggregates dispatchers, observers, spans, and custom metrics.""" __slots__ = ( "_is_idle_cached", "_metrics", "_redaction", "_statement_observers", "bind_key", "config", "config_name", "lifecycle", "span_manager", ) # Allow test injection with fake span managers (mypyc strict typing workaround) span_manager: "Any"
[docs] def __init__( self, config: ObservabilityConfig | None = None, *, bind_key: str | None = None, config_name: str | None = None ) -> None: config = config.copy() if config else ObservabilityConfig() if config.logging is None: config.logging = LoggingConfig() self.config = config self.bind_key = bind_key self.config_name = config_name or "SQLSpecConfig" lifecycle_config = cast("dict[str, Iterable[LifecycleHook]] | None", config.lifecycle) self.lifecycle = LifecycleDispatcher(lifecycle_config) self.span_manager = SpanManager(config.telemetry) observers: list[StatementObserver] = [] if config.statement_observers: observers.extend(config.statement_observers) if config.print_sql: observers.append(create_statement_observer(config.logging)) self._statement_observers = tuple(observers) self._redaction = config.redaction.copy() if config.redaction else None self._metrics: dict[str, float] = {} # Pre-compute the non-span idle state (lifecycle and observers are immutable) # span_manager can be replaced for testing so we check it separately self._is_idle_cached = not self.lifecycle.is_enabled and not self._statement_observers
@property def is_idle(self) -> bool: """Return True when no observability features are active. A runtime is idle if it has no lifecycle hooks, no statement observers, and telemetry spans are disabled. Drivers can use this to skip expensive context construction. """ # Fast path: lifecycle and observers state is cached (immutable after init) # span_manager is checked each time as it can be replaced for testing return self._is_idle_cached and not self.span_manager.is_enabled @property def has_statement_observers(self) -> bool: """Return True when any observers are registered.""" return bool(self._statement_observers) @property def diagnostics_key(self) -> str: """Derive diagnostics key from bind key or configuration name.""" if self.bind_key: return self.bind_key return self.config_name
[docs] def base_context(self) -> dict[str, Any]: """Return the base payload for lifecycle events.""" context = {"config": self.config_name} if self.bind_key: context["bind_key"] = self.bind_key correlation_id = CorrelationContext.get() if correlation_id: context["correlation_id"] = correlation_id return context
def _build_context(self, **extras: Any) -> dict[str, Any]: context = self.base_context() context.update({key: value for key, value in extras.items() if value is not None}) return context
[docs] def lifecycle_snapshot(self) -> dict[str, int]: """Return lifecycle counters keyed under the diagnostics prefix.""" return self.lifecycle.snapshot(prefix=self.diagnostics_key)
[docs] def metrics_snapshot(self) -> dict[str, float]: """Return accumulated custom metrics with diagnostics prefix.""" if not self._metrics: return {} prefix = self.diagnostics_key return {f"{prefix}.{name}": value for name, value in self._metrics.items()}
[docs] def increment_metric(self, name: str, amount: float = 1.0) -> None: """Increment a custom metric counter.""" self._metrics[name] = self._metrics.get(name, 0.0) + amount
[docs] def record_metric(self, name: str, value: float) -> None: """Set a custom metric to an explicit value.""" self._metrics[name] = value
[docs] def start_migration_span( self, event: str, *, version: "str | None" = None, metadata: "dict[str, Any] | None" = None ) -> Any: """Start a migration span when telemetry is enabled.""" if not self.span_manager.is_enabled: return None attributes: dict[str, Any] = {"sqlspec.migration.event": event, "sqlspec.config": self.config_name} if self.bind_key: attributes["sqlspec.bind_key"] = self.bind_key correlation_id = CorrelationContext.get() if correlation_id: attributes["sqlspec.correlation_id"] = correlation_id if version: attributes["sqlspec.migration.version"] = version if metadata: for key, value in metadata.items(): if value is not None: attributes[f"sqlspec.migration.{key}"] = value return self.span_manager.start_span(f"sqlspec.migration.{event}", attributes)
[docs] def end_migration_span( self, span: Any, *, duration_ms: "int | None" = None, error: "Exception | None" = None ) -> None: """Finish a migration span, attaching optional duration metadata.""" if span is None: return if duration_ms is not None and has_span_attribute(span): span.set_attribute("sqlspec.migration.duration_ms", duration_ms) self.span_manager.end_span(span, error=error)
def emit_pool_create(self, pool: Any) -> None: span = self._start_lifecycle_span("pool.create", subject=pool) try: if self.lifecycle.has_pool_create: self.lifecycle.emit_pool_create(self._build_context(pool=pool)) finally: self.span_manager.end_span(span) def emit_pool_destroy(self, pool: Any) -> None: span = self._start_lifecycle_span("pool.destroy", subject=pool) try: if self.lifecycle.has_pool_destroy: self.lifecycle.emit_pool_destroy(self._build_context(pool=pool)) finally: self.span_manager.end_span(span) def emit_connection_create(self, connection: Any) -> None: span = self._start_lifecycle_span("connection.create", subject=connection) try: if self.lifecycle.has_connection_create: self.lifecycle.emit_connection_create(self._build_context(connection=connection)) finally: self.span_manager.end_span(span) def emit_connection_destroy(self, connection: Any) -> None: span = self._start_lifecycle_span("connection.destroy", subject=connection) try: if self.lifecycle.has_connection_destroy: self.lifecycle.emit_connection_destroy(self._build_context(connection=connection)) finally: self.span_manager.end_span(span) def emit_session_start(self, session: Any) -> None: span = self._start_lifecycle_span("session.start", subject=session) try: if self.lifecycle.has_session_start: self.lifecycle.emit_session_start(self._build_context(session=session)) finally: self.span_manager.end_span(span) def emit_session_end(self, session: Any) -> None: span = self._start_lifecycle_span("session.end", subject=session) try: if self.lifecycle.has_session_end: self.lifecycle.emit_session_end(self._build_context(session=session)) finally: self.span_manager.end_span(span) def emit_query_start(self, **extras: Any) -> None: if self.lifecycle.has_query_start: self.lifecycle.emit_query_start(self._build_context(**extras)) def emit_query_complete(self, **extras: Any) -> None: if self.lifecycle.has_query_complete: self.lifecycle.emit_query_complete(self._build_context(**extras)) def emit_error(self, exception: Exception, **extras: Any) -> None: if self.lifecycle.has_error: payload = self._build_context(exception=exception) payload.update({key: value for key, value in extras.items() if value is not None}) self.lifecycle.emit_error(payload) self.increment_metric("errors", 1.0)
[docs] def emit_statement_event( self, *, sql: str, parameters: Any, driver: str, operation: str, execution_mode: str | None, is_many: bool, is_script: bool, rows_affected: int | None, duration_s: float, storage_backend: str | None, started_at: float | None = None, ) -> None: """Emit a statement event to all registered observers.""" if not self._statement_observers: return sanitized_sql = self._redact_sql(sql) sanitized_params = self._redact_parameters(parameters) correlation_id = CorrelationContext.get() logging_config = self.config.logging db_system = resolve_db_system(self.config_name) sql_hash = None if logging_config and logging_config.include_sql_hash: sql_hash = compute_sql_hash(sanitized_sql) sql_truncation_length = logging_config.sql_truncation_length if logging_config else 2000 sql_original_length = len(sanitized_sql) sql_truncated = sql_original_length > sql_truncation_length trace_id = None span_id = None if logging_config and logging_config.include_trace_context: trace_id, span_id = get_trace_context() event = create_event( sql=sanitized_sql, parameters=sanitized_params, driver=driver, adapter=self.config_name, bind_key=self.bind_key, db_system=db_system, operation=operation, execution_mode=execution_mode, is_many=is_many, is_script=is_script, rows_affected=rows_affected, duration_s=duration_s, correlation_id=correlation_id, storage_backend=storage_backend, started_at=started_at, sql_hash=sql_hash, sql_truncated=sql_truncated, sql_original_length=sql_original_length, trace_id=trace_id, span_id=span_id, ) for observer in self._statement_observers: observer(event)
[docs] def start_query_span(self, sql: str, operation: str, driver: str) -> Any: """Start a query span with runtime metadata.""" if not self.span_manager.is_enabled: return None sql_hash = compute_sql_hash(sql) connection_info = {"sqlspec.statement.hash": sql_hash, "sqlspec.statement.length": len(sql)} sql_payload = "" if self.config.print_sql: sql_payload = self._redact_sql(sql) sql_payload, truncated = _truncate_text(sql_payload, max_chars=4096) if truncated: connection_info["sqlspec.statement.truncated"] = True correlation_id = CorrelationContext.get() return self.span_manager.start_query_span( driver=driver, adapter=self.config_name, bind_key=self.bind_key, sql=sql_payload, operation=operation, connection_info=connection_info, correlation_id=correlation_id, )
[docs] def start_storage_span( self, operation: str, *, destination: str | None = None, format_label: str | None = None ) -> Any: """Start a storage bridge span for read/write operations.""" if not self.span_manager.is_enabled: return None attributes: dict[str, Any] = {"sqlspec.storage.operation": operation, "sqlspec.config": self.config_name} if self.bind_key: attributes["sqlspec.bind_key"] = self.bind_key correlation_id = CorrelationContext.get() if correlation_id: attributes["sqlspec.correlation_id"] = correlation_id if destination: attributes["sqlspec.storage.destination"] = destination if format_label: attributes["sqlspec.storage.format"] = format_label return self.span_manager.start_span(f"sqlspec.storage.{operation}", attributes)
[docs] def start_span(self, name: str, *, attributes: dict[str, Any] | None = None) -> Any: """Start a custom span enriched with configuration context.""" if not self.span_manager.is_enabled: return None merged: dict[str, Any] = attributes.copy() if attributes else {} merged.setdefault("sqlspec.config", self.config_name) if self.bind_key: merged.setdefault("sqlspec.bind_key", self.bind_key) correlation_id = CorrelationContext.get() if correlation_id: merged.setdefault("sqlspec.correlation_id", correlation_id) return self.span_manager.start_span(name, merged)
[docs] def end_span(self, span: Any, *, error: Exception | None = None) -> None: """Finish a custom span.""" self.span_manager.end_span(span, error=error)
[docs] def end_storage_span( self, span: Any, *, telemetry: "StorageTelemetry | None" = None, error: Exception | None = None ) -> None: """Finish a storage span, attaching telemetry metadata when available.""" if span is None: return if telemetry: telemetry = self.annotate_storage_telemetry(telemetry) self._attach_storage_telemetry(span, telemetry) self.span_manager.end_span(span, error=error)
[docs] def annotate_storage_telemetry(self, telemetry: "StorageTelemetry") -> "StorageTelemetry": """Add bind key / config / correlation metadata to telemetry payloads.""" annotated = telemetry base = self.base_context() correlation_id = base.get("correlation_id") if correlation_id and not annotated.get("correlation_id"): annotated["correlation_id"] = correlation_id annotated.setdefault("config", self.config_name) if self.bind_key and not annotated.get("bind_key"): annotated["bind_key"] = self.bind_key return annotated
def _start_lifecycle_span(self, event: str, subject: Any | None = None) -> Any: if not self.span_manager.is_enabled: return None attributes: dict[str, Any] = {"sqlspec.lifecycle.event": event, "sqlspec.config": self.config_name} if self.bind_key: attributes["sqlspec.bind_key"] = self.bind_key correlation_id = CorrelationContext.get() if correlation_id: attributes["sqlspec.correlation_id"] = correlation_id if subject is not None: attributes["sqlspec.lifecycle.subject_type"] = type(subject).__name__ return self.span_manager.start_span(f"sqlspec.lifecycle.{event}", attributes) def _attach_storage_telemetry(self, span: Any, telemetry: "StorageTelemetry") -> None: if not has_span_attribute(span): return if "backend" in telemetry and telemetry["backend"] is not None: span.set_attribute("sqlspec.storage.backend", telemetry["backend"]) if "bytes_processed" in telemetry and telemetry["bytes_processed"] is not None: span.set_attribute("sqlspec.storage.bytes_processed", telemetry["bytes_processed"]) if "rows_processed" in telemetry and telemetry["rows_processed"] is not None: span.set_attribute("sqlspec.storage.rows_processed", telemetry["rows_processed"]) if "destination" in telemetry and telemetry["destination"] is not None: span.set_attribute("sqlspec.storage.destination", telemetry["destination"]) if "format" in telemetry and telemetry["format"] is not None: span.set_attribute("sqlspec.storage.format", telemetry["format"]) if "duration_s" in telemetry and telemetry["duration_s"] is not None: span.set_attribute("sqlspec.storage.duration_s", telemetry["duration_s"]) if "correlation_id" in telemetry and telemetry["correlation_id"] is not None: span.set_attribute("sqlspec.correlation_id", telemetry["correlation_id"]) def _redact_sql(self, sql: str) -> str: config = self._redaction if config is None or not config.mask_literals: return sql return _LITERAL_PATTERN.sub("'***'", sql) def _redact_parameters(self, parameters: Any) -> Any: config = self._redaction if config is None or not config.mask_parameters: return parameters allow_list = set(config.parameter_allow_list or ()) return _mask_parameters(parameters, allow_list)
def _mask_parameters(value: Any, allow_list: set[str]) -> Any: if isinstance(value, dict): masked: dict[str, Any] = {} for key, item in value.items(): if allow_list and key in allow_list: masked[key] = _mask_parameters(item, allow_list) else: masked[key] = "***" return masked if isinstance(value, list): return [_mask_parameters(item, allow_list) for item in value] if isinstance(value, tuple): return tuple(_mask_parameters(item, allow_list) for item in value) return "***" def _truncate_text(value: str, *, max_chars: int) -> tuple[str, bool]: if len(value) <= max_chars: return value, False return value[:max_chars], True __all__ = ("ObservabilityRuntime",)