Source code for sqlspec.observability._dispatcher

"""Lifecycle dispatcher for drivers and registry hooks."""

import inspect
from collections.abc import Callable, Iterable
from typing import Any, Literal

from sqlspec.utils.logging import get_logger

__all__ = ("LifecycleContext", "LifecycleDispatcher", "LifecycleHook")


logger = get_logger("sqlspec.observability.lifecycle")

LifecycleContext = dict[str, Any]
LifecycleHook = Callable[[LifecycleContext], Any]

LifecycleEvent = Literal[
    "on_pool_create",
    "on_pool_destroying",
    "on_pool_destroy",
    "on_connection_create",
    "on_connection_destroy",
    "on_session_start",
    "on_session_end",
    "on_query_start",
    "on_query_complete",
    "on_error",
]
EVENT_ATTRS: tuple[LifecycleEvent, ...] = (
    "on_pool_create",
    "on_pool_destroying",
    "on_pool_destroy",
    "on_connection_create",
    "on_connection_destroy",
    "on_session_start",
    "on_session_end",
    "on_query_start",
    "on_query_complete",
    "on_error",
)


[docs] class LifecycleDispatcher: """Dispatches lifecycle hooks with guard flags and diagnostics counters.""" __slots__ = ( "_counters", "_hooks", "_is_enabled", "has_connection_create", "has_connection_destroy", "has_error", "has_pool_create", "has_pool_destroy", "has_pool_destroying", "has_query_complete", "has_query_start", "has_session_end", "has_session_start", )
[docs] def __init__(self, hooks: "dict[str, Iterable[LifecycleHook]] | None" = None) -> None: self.has_pool_create = False self.has_pool_destroying = False self.has_pool_destroy = False self.has_connection_create = False self.has_connection_destroy = False self.has_session_start = False self.has_session_end = False self.has_query_start = False self.has_query_complete = False self.has_error = False normalized: dict[LifecycleEvent, list[LifecycleHook]] = {} for event_name in EVENT_ATTRS: callables = hooks.get(event_name) if hooks else None normalized[event_name] = list(callables) if callables else [] if normalized[event_name]: self._enable_event_guard(event_name) self._hooks: dict[LifecycleEvent, list[LifecycleHook]] = normalized self._counters: dict[LifecycleEvent, int] = dict.fromkeys(EVENT_ATTRS, 0) self._is_enabled = any(self._hooks.values())
@property def is_enabled(self) -> bool: """Return True when at least one hook is registered.""" return self._is_enabled
[docs] def emit_pool_create_sync(self, context: LifecycleContext) -> None: """Fire pool creation hooks synchronously.""" self._emit_sync("on_pool_create", context)
[docs] async def emit_pool_create_async(self, context: LifecycleContext) -> None: """Fire pool creation hooks, awaiting any awaitable return values.""" await self._emit_async("on_pool_create", context)
[docs] def emit_pool_create(self, context: LifecycleContext) -> None: """Fire pool creation hooks synchronously.""" self.emit_pool_create_sync(context)
[docs] def emit_pool_destroying_sync(self, context: LifecycleContext) -> None: """Fire pre-destruction hooks synchronously.""" self._emit_sync("on_pool_destroying", context)
[docs] async def emit_pool_destroying_async(self, context: LifecycleContext) -> None: """Fire pre-destruction hooks, awaiting any awaitable return values.""" await self._emit_async("on_pool_destroying", context)
[docs] def emit_pool_destroying(self, context: LifecycleContext) -> None: """Fire pre-destruction hooks synchronously.""" self.emit_pool_destroying_sync(context)
[docs] def emit_pool_destroy_sync(self, context: LifecycleContext) -> None: """Fire pool destruction hooks synchronously.""" self._emit_sync("on_pool_destroy", context)
[docs] async def emit_pool_destroy_async(self, context: LifecycleContext) -> None: """Fire pool destruction hooks, awaiting any awaitable return values.""" await self._emit_async("on_pool_destroy", context)
[docs] def emit_pool_destroy(self, context: LifecycleContext) -> None: """Fire pool destruction hooks synchronously.""" self.emit_pool_destroy_sync(context)
[docs] def emit_connection_create_sync(self, context: LifecycleContext) -> None: """Fire connection creation hooks synchronously.""" self._emit_sync("on_connection_create", context)
[docs] async def emit_connection_create_async(self, context: LifecycleContext) -> None: """Fire connection creation hooks, awaiting any awaitable return values.""" await self._emit_async("on_connection_create", context)
[docs] def emit_connection_create(self, context: LifecycleContext) -> None: """Fire connection creation hooks synchronously.""" self.emit_connection_create_sync(context)
[docs] def emit_connection_destroy_sync(self, context: LifecycleContext) -> None: """Fire connection teardown hooks synchronously.""" self._emit_sync("on_connection_destroy", context)
[docs] async def emit_connection_destroy_async(self, context: LifecycleContext) -> None: """Fire connection teardown hooks, awaiting any awaitable return values.""" await self._emit_async("on_connection_destroy", context)
[docs] def emit_connection_destroy(self, context: LifecycleContext) -> None: """Fire connection teardown hooks synchronously.""" self.emit_connection_destroy_sync(context)
[docs] def emit_session_start_sync(self, context: LifecycleContext) -> None: """Fire session start hooks synchronously.""" self._emit_sync("on_session_start", context)
[docs] async def emit_session_start_async(self, context: LifecycleContext) -> None: """Fire session start hooks, awaiting any awaitable return values.""" await self._emit_async("on_session_start", context)
[docs] def emit_session_start(self, context: LifecycleContext) -> None: """Fire session start hooks synchronously.""" self.emit_session_start_sync(context)
[docs] def emit_session_end_sync(self, context: LifecycleContext) -> None: """Fire session end hooks synchronously.""" self._emit_sync("on_session_end", context)
[docs] async def emit_session_end_async(self, context: LifecycleContext) -> None: """Fire session end hooks, awaiting any awaitable return values.""" await self._emit_async("on_session_end", context)
[docs] def emit_session_end(self, context: LifecycleContext) -> None: """Fire session end hooks synchronously.""" self.emit_session_end_sync(context)
[docs] def emit_query_start_sync(self, context: LifecycleContext) -> None: """Fire query start hooks synchronously.""" self._emit_sync("on_query_start", context)
[docs] async def emit_query_start_async(self, context: LifecycleContext) -> None: """Fire query start hooks, awaiting any awaitable return values.""" await self._emit_async("on_query_start", context)
[docs] def emit_query_start(self, context: LifecycleContext) -> None: """Fire query start hooks synchronously.""" self.emit_query_start_sync(context)
[docs] def emit_query_complete_sync(self, context: LifecycleContext) -> None: """Fire query completion hooks synchronously.""" self._emit_sync("on_query_complete", context)
[docs] async def emit_query_complete_async(self, context: LifecycleContext) -> None: """Fire query completion hooks, awaiting any awaitable return values.""" await self._emit_async("on_query_complete", context)
[docs] def emit_query_complete(self, context: LifecycleContext) -> None: """Fire query completion hooks synchronously.""" self.emit_query_complete_sync(context)
[docs] def emit_error_sync(self, context: LifecycleContext) -> None: """Fire error hooks synchronously.""" self._emit_sync("on_error", context)
[docs] async def emit_error_async(self, context: LifecycleContext) -> None: """Fire error hooks, awaiting any awaitable return values.""" await self._emit_async("on_error", context)
[docs] def emit_error(self, context: LifecycleContext) -> None: """Fire error hooks synchronously.""" self.emit_error_sync(context)
[docs] def register_hook(self, event: LifecycleEvent, callback: LifecycleHook) -> None: """Append a hook at runtime.""" callbacks = self._hooks.setdefault(event, []) callbacks.append(callback) self._enable_event_guard(event) self._is_enabled = True
[docs] def snapshot(self, *, prefix: str | None = None) -> "dict[str, int]": """Return counter snapshot keyed for diagnostics export.""" metrics: dict[str, int] = {} for event_name, count in self._counters.items(): key = event_name.replace("on_", "lifecycle.") if prefix: key = f"{prefix}.{key}" metrics[key] = count return metrics
def _enable_event_guard(self, event: LifecycleEvent) -> None: match event: case "on_pool_create": self.has_pool_create = True case "on_pool_destroying": self.has_pool_destroying = True case "on_pool_destroy": self.has_pool_destroy = True case "on_connection_create": self.has_connection_create = True case "on_connection_destroy": self.has_connection_destroy = True case "on_session_start": self.has_session_start = True case "on_session_end": self.has_session_end = True case "on_query_start": self.has_query_start = True case "on_query_complete": self.has_query_complete = True case "on_error": self.has_error = True def _emit_sync(self, event: LifecycleEvent, context: LifecycleContext) -> None: callbacks = self._hooks.get(event) if not callbacks: return self._counters[event] += 1 for callback in callbacks: self._invoke_callback(callback, context, event) async def _emit_async(self, event: LifecycleEvent, context: LifecycleContext) -> None: callbacks = self._hooks.get(event) if not callbacks: return self._counters[event] += 1 for callback in callbacks: try: result = callback(context) except Exception as exc: # pragma: no cover logger.warning("Lifecycle hook failed: event=%s error=%s", event, exc) continue if inspect.isawaitable(result): try: await result except Exception as exc: # pragma: no cover logger.warning("Lifecycle hook failed: event=%s error=%s", event, exc) @staticmethod def _invoke_callback(callback: LifecycleHook, context: LifecycleContext, event: LifecycleEvent) -> None: try: callback(context) except Exception as exc: # pragma: no cover logger.warning("Lifecycle hook failed: event=%s error=%s", event, exc)