"""Event channel API with separate sync and async implementations."""
import asyncio
import importlib
import inspect
import logging
import threading
from collections.abc import AsyncIterator, Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from sqlspec.exceptions import ImproperConfigurationError, MissingDependencyError
from sqlspec.extensions.events._hints import get_runtime_hints, resolve_adapter_name
from sqlspec.extensions.events._models import EventMessage
from sqlspec.extensions.events._protocols import AsyncEventBackendProtocol, SyncEventBackendProtocol
from sqlspec.extensions.events._queue import build_queue_backend
from sqlspec.extensions.events._store import normalize_event_channel_name
from sqlspec.utils.logging import get_logger, log_with_context
from sqlspec.utils.type_guards import has_span_attribute
from sqlspec.utils.uuids import uuid4
if TYPE_CHECKING:
    # Referenced only from string annotations in this module; TYPE_CHECKING is False at runtime.
    from sqlspec.config import AsyncDatabaseConfig, SyncDatabaseConfig
    from sqlspec.extensions.events._protocols import AsyncEventHandler, SyncEventHandler
    from sqlspec.observability import ObservabilityRuntime

# Public names of this module (``EventMessage`` is re-exported from ``_models``).
__all__ = (
    "AsyncEventChannel",
    "AsyncEventListener",
    "EventMessage",
    "SyncEventChannel",
    "SyncEventListener",
    "load_native_backend",
    "resolve_poll_interval",
)

logger = get_logger("sqlspec.events.channel")

# A config class defined in ``sqlspec.adapters.<adapter>...`` has at least three dotted
# module parts; ``load_native_backend`` ignores configs whose module path is shorter.
_ADAPTER_MODULE_PARTS = 3
[docs]
@dataclass(slots=True)
class AsyncEventListener:
"""Represents a running async listener task."""
id: str
channel: str
task: "asyncio.Task[Any]"
stop_event: "asyncio.Event"
poll_interval: float
[docs]
async def stop(self) -> None:
"""Signal the listener to stop and await task completion."""
self.stop_event.set()
if not self.task.done():
await self.task
[docs]
@dataclass(slots=True)
class SyncEventListener:
"""Represents a running sync listener thread."""
id: str
channel: str
thread: threading.Thread
stop_event: threading.Event
poll_interval: float
[docs]
def stop(self) -> None:
"""Signal the listener to stop and join the thread."""
self.stop_event.set()
self.thread.join()
def resolve_poll_interval(poll_interval: "float | None", default: float) -> float:
    """Resolve the effective poll interval, validating explicit values.

    Args:
        poll_interval: Requested interval, or ``None`` to use ``default``.
        default: Value returned unchanged when ``poll_interval`` is ``None``.

    Returns:
        ``default`` when ``poll_interval`` is ``None``, otherwise ``poll_interval``.

    Raises:
        ImproperConfigurationError: If ``poll_interval`` is not strictly greater than
            zero. This includes NaN, which compares false against every number.
    """
    if poll_interval is None:
        return default
    # ``not x > 0`` rather than ``x <= 0`` so that NaN is rejected as well.
    if not poll_interval > 0:
        msg = "poll_interval must be greater than zero"
        raise ImproperConfigurationError(msg)
    return poll_interval
def _resolve_event_type(payload: "dict[str, Any]", metadata: "dict[str, Any] | None") -> "str | None":
"""Resolve event type from payload or metadata."""
if metadata and metadata.get("event_type"):
return str(metadata["event_type"])
if payload.get("event_type") is not None:
return str(payload["event_type"])
if payload.get("type") is not None:
return str(payload["type"])
return None
_POSTGRES_ADAPTERS = frozenset({"asyncpg", "psycopg", "psqlpy"})
def _get_default_backend(adapter_name: "str | None") -> str:
"""Return the default events backend for an adapter."""
if adapter_name in _POSTGRES_ADAPTERS:
return "listen_notify"
return "table_queue"
[docs]
def load_native_backend(config: Any, backend_name: str | None, extension_settings: "dict[str, Any]") -> Any | None:
"""Load adapter-specific native backend if available."""
if backend_name in {None, "table_queue"}:
return None
module_name = type(config).__module__
parts = module_name.split(".")
if len(parts) < _ADAPTER_MODULE_PARTS or parts[0] != "sqlspec" or parts[1] != "adapters":
return None
adapter_name = parts[2]
backend_module_name = f"sqlspec.adapters.{adapter_name}.events.backend"
try:
backend_module = importlib.import_module(backend_module_name)
except ModuleNotFoundError:
log_with_context(
logger,
logging.DEBUG,
"event.listen",
adapter_name=adapter_name,
backend_module=backend_module_name,
status="backend_missing",
)
return None
except ImportError as error:
log_with_context(
logger,
logging.WARNING,
"event.listen",
adapter_name=adapter_name,
backend_module=backend_module_name,
error_type=type(error).__name__,
status="backend_import_failed",
)
return None
try:
factory = backend_module.create_event_backend
except AttributeError:
log_with_context(
logger,
logging.DEBUG,
"event.listen",
adapter_name=adapter_name,
backend_module=backend_module_name,
status="backend_factory_missing",
)
return None
try:
backend = factory(config, backend_name, extension_settings)
except MissingDependencyError as error:
log_with_context(
logger,
logging.WARNING,
"event.listen",
adapter_name=adapter_name,
backend_name=backend_name,
error_type=type(error).__name__,
status="backend_dependency_missing",
)
return None
except ImproperConfigurationError as error:
log_with_context(
logger,
logging.WARNING,
"event.listen",
adapter_name=adapter_name,
backend_name=backend_name,
error_type=type(error).__name__,
status="backend_config_rejected",
)
return None
return backend
def _start_event_span(
runtime: "ObservabilityRuntime",
operation: str,
backend_name: str,
adapter_name: "str | None",
channel: "str | None" = None,
mode: str = "sync",
) -> Any:
"""Start an observability span for event operations."""
if not runtime.span_manager.is_enabled:
return None
attributes: dict[str, Any] = {
"sqlspec.events.operation": operation,
"sqlspec.events.backend": backend_name,
"sqlspec.events.mode": mode,
}
if adapter_name:
attributes["sqlspec.events.adapter"] = adapter_name
if channel:
attributes["sqlspec.events.channel"] = channel
return runtime.start_span(f"sqlspec.events.{operation}", attributes=attributes)
def _end_event_span(
runtime: "ObservabilityRuntime", span: Any, *, error: "Exception | None" = None, result: "str | None" = None
) -> None:
"""End an observability span."""
if span is None:
return
if result is not None and has_span_attribute(span):
span.set_attribute("sqlspec.events.result", result)
runtime.end_span(span, error=error)
class SyncEventChannel:
    """Event channel for synchronous database configurations.

    The backend is chosen once, in ``__init__``. It is the adapter's native backend
    when ``load_native_backend`` returns one, otherwise the table-queue backend from
    ``build_queue_backend``. Backend calls are wrapped in observability spans and
    logged through ``log_with_context``.
    """

    __slots__ = (
        "_adapter_name",
        "_backend",
        "_backend_name",
        "_config",
        "_listeners",
        "_poll_interval_default",
        "_runtime",
    )

    _backend: "SyncEventBackendProtocol"

    def __init__(self, config: "SyncDatabaseConfig[Any, Any, Any]") -> None:
        """Create a channel bound to ``config``.

        Args:
            config: A synchronous database configuration. Its
                ``extension_config["events"]`` mapping supplies the optional
                ``backend`` and ``poll_interval`` settings.

        Raises:
            ImproperConfigurationError: If ``config`` is async, or if the resulting
                default poll interval is not greater than zero.
        """
        if config.is_async:
            msg = "SyncEventChannel requires a sync configuration"
            raise ImproperConfigurationError(msg)
        extension_settings: dict[str, Any] = dict(config.extension_config.get("events", {}))
        self._adapter_name = resolve_adapter_name(config)
        hints = get_runtime_hints(self._adapter_name, config)
        # A falsy configured value (missing, None, 0) falls back to the adapter hint. The
        # result is validated like per-call intervals, so a negative setting fails here
        # instead of being passed to every dequeue() call.
        default_interval = float(extension_settings.get("poll_interval") or hints.poll_interval)
        self._poll_interval_default = resolve_poll_interval(default_interval, default_interval)
        queue_backend = build_queue_backend(config, extension_settings, adapter_name=self._adapter_name, hints=hints)
        backend_name = extension_settings.get("backend") or _get_default_backend(self._adapter_name)
        native_backend = load_native_backend(config, backend_name, extension_settings)
        if native_backend is None:
            if backend_name not in {None, "table_queue"}:
                # A native backend was requested (or defaulted) but could not be loaded.
                log_with_context(
                    logger,
                    logging.WARNING,
                    "event.listen",
                    adapter_name=self._adapter_name,
                    backend_name=backend_name,
                    fallback_backend="table_queue",
                    status="backend_unavailable",
                )
            self._backend = cast("SyncEventBackendProtocol", queue_backend)
            backend_label = "table_queue"
        else:
            self._backend = cast("SyncEventBackendProtocol", native_backend)
            if isinstance(native_backend, SyncEventBackendProtocol):
                backend_label = native_backend.backend_name
            else:
                backend_label = backend_name or "table_queue"
        self._config = config
        self._backend_name = backend_label
        self._runtime = config.get_observability_runtime()
        self._listeners: dict[str, SyncEventListener] = {}

    def publish(self, channel: str, payload: "dict[str, Any]", metadata: "dict[str, Any] | None" = None) -> str:
        """Publish an event to ``channel``.

        Args:
            channel: Channel name, normalized with ``normalize_event_channel_name``.
            payload: Event payload handed to the backend.
            metadata: Optional metadata handed to the backend.

        Returns:
            The event id returned by the backend.

        Raises:
            ImproperConfigurationError: If the backend does not support sync operation.
        """
        channel = normalize_event_channel_name(channel)
        if not self._backend.supports_sync:
            msg = "Current events backend does not support sync publishing"
            raise ImproperConfigurationError(msg)
        span = _start_event_span(self._runtime, "publish", self._backend_name, self._adapter_name, channel, mode="sync")
        try:
            event_id = self._backend.publish(channel, payload, metadata)
        except Exception as error:
            _end_event_span(self._runtime, span, error=error)
            raise
        _end_event_span(self._runtime, span, result="published")
        log_with_context(
            logger,
            logging.DEBUG,
            "event.publish",
            adapter_name=self._adapter_name,
            backend_name=self._backend_name,
            channel=channel,
            event_id=event_id,
            event_type=_resolve_event_type(payload, metadata),
            mode="sync",
        )
        return event_id

    def iter_events(self, channel: str, *, poll_interval: float | None = None) -> Iterator[EventMessage]:
        """Yield events from ``channel`` indefinitely.

        Each dequeue waits up to the resolved poll interval. Empty polls are retried,
        and every delivered event increments the ``events.deliver`` metric.

        Raises:
            ImproperConfigurationError: If the backend does not support sync operation
                or ``poll_interval`` is not greater than zero.
        """
        channel = normalize_event_channel_name(channel)
        if not self._backend.supports_sync:
            msg = "Current events backend does not support sync consumption"
            raise ImproperConfigurationError(msg)
        interval = resolve_poll_interval(poll_interval, self._poll_interval_default)
        while True:
            span = _start_event_span(
                self._runtime, "dequeue", self._backend_name, self._adapter_name, channel, mode="sync"
            )
            try:
                event = self._backend.dequeue(channel, interval)
            except Exception as error:
                _end_event_span(self._runtime, span, error=error)
                raise
            if event is None:
                _end_event_span(self._runtime, span, result="empty")
                continue
            _end_event_span(self._runtime, span, result="delivered")
            self._runtime.increment_metric("events.deliver")
            log_with_context(
                logger,
                logging.DEBUG,
                "event.receive",
                adapter_name=self._adapter_name,
                backend_name=self._backend_name,
                channel=channel,
                event_id=event.event_id,
                event_type=_resolve_event_type(event.payload, event.metadata),
                mode="sync",
            )
            yield event

    def listen(
        self, channel: str, handler: "SyncEventHandler", *, poll_interval: float | None = None, auto_ack: bool = True
    ) -> SyncEventListener:
        """Start a daemon thread that calls ``handler`` for each event on ``channel``.

        Args:
            channel: Channel name, normalized with ``normalize_event_channel_name``.
            handler: Callable invoked with each delivered event.
            poll_interval: Optional per-listener poll interval.
            auto_ack: Ack each event after ``handler`` returns without raising.

        Returns:
            The registered listener handle.

        Raises:
            ImproperConfigurationError: If the backend does not support sync operation
                or ``poll_interval`` is not greater than zero.
        """
        channel = normalize_event_channel_name(channel)
        if not self._backend.supports_sync:
            msg = "Current events backend does not support sync listeners"
            raise ImproperConfigurationError(msg)
        interval = resolve_poll_interval(poll_interval, self._poll_interval_default)
        listener_id = uuid4().hex
        stop_event = threading.Event()
        thread = threading.Thread(
            target=self._run_listener, args=(listener_id, channel, handler, stop_event, interval, auto_ack), daemon=True
        )
        listener = SyncEventListener(listener_id, channel, thread, stop_event, interval)
        # Register before starting so the thread's cleanup always finds its entry.
        self._listeners[listener_id] = listener
        self._runtime.increment_metric("events.listener.start")
        log_with_context(
            logger,
            logging.DEBUG,
            "event.listen",
            adapter_name=self._adapter_name,
            backend_name=self._backend_name,
            channel=channel,
            listener_id=listener_id,
            mode="sync",
            status="start",
        )
        thread.start()
        return listener

    def stop_listener(self, listener_id: str) -> None:
        """Stop and unregister the listener ``listener_id``; unknown ids are ignored."""
        listener = self._listeners.pop(listener_id, None)
        if listener is None:
            return
        listener.stop()
        self._runtime.increment_metric("events.listener.stop")
        log_with_context(
            logger,
            logging.DEBUG,
            "event.listen",
            adapter_name=self._adapter_name,
            backend_name=self._backend_name,
            channel=listener.channel,
            listener_id=listener_id,
            mode="sync",
            status="stop",
        )

    def ack(self, event_id: str) -> None:
        """Acknowledge an event.

        Raises:
            ImproperConfigurationError: If the backend does not support sync operation.
        """
        if not self._backend.supports_sync:
            msg = "Current events backend does not support sync ack"
            raise ImproperConfigurationError(msg)
        span = _start_event_span(self._runtime, "ack", self._backend_name, self._adapter_name, mode="sync")
        try:
            self._backend.ack(event_id)
        except Exception as error:
            _end_event_span(self._runtime, span, error=error)
            raise
        _end_event_span(self._runtime, span, result="acked")

    def nack(self, event_id: str) -> None:
        """Return an event to the queue for redelivery.

        Raises:
            ImproperConfigurationError: If the backend does not support sync operation.
                Without this guard, an async-only backend's ``nack`` would just return
                an un-awaited coroutine and the nack would be silently lost.
        """
        if not self._backend.supports_sync:
            msg = "Current events backend does not support sync nack"
            raise ImproperConfigurationError(msg)
        span = _start_event_span(self._runtime, "nack", self._backend_name, self._adapter_name, mode="sync")
        try:
            self._backend.nack(event_id)
        except Exception as error:
            _end_event_span(self._runtime, span, error=error)
            raise
        _end_event_span(self._runtime, span, result="nacked")

    def shutdown(self) -> None:
        """Stop every registered listener, then shut the backend down."""
        span = _start_event_span(self._runtime, "shutdown", self._backend_name, self._adapter_name, mode="sync")
        try:
            # Snapshot the ids: stop_listener and listener threads both mutate the dict.
            for listener_id in list(self._listeners):
                self.stop_listener(listener_id)
            self._backend.shutdown()
        except Exception as error:
            _end_event_span(self._runtime, span, error=error)
            raise
        _end_event_span(self._runtime, span, result="shutdown")
        self._runtime.increment_metric("events.shutdown")

    def _run_listener(
        self,
        listener_id: str,
        channel: str,
        handler: "SyncEventHandler",
        stop_event: threading.Event,
        poll_interval: float,
        auto_ack: bool,
    ) -> None:
        """Listener thread body: dequeue, dispatch to ``handler``, optionally ack.

        Handler (and auto-ack) failures are logged and the loop continues; the event
        is left un-acked. A dequeue failure ends the thread. The listener is always
        unregistered on exit.
        """
        try:
            while not stop_event.is_set():
                span = _start_event_span(
                    self._runtime, "dequeue", self._backend_name, self._adapter_name, channel, mode="sync"
                )
                try:
                    event = self._backend.dequeue(channel, poll_interval)
                except Exception as error:
                    _end_event_span(self._runtime, span, error=error)
                    raise
                if event is None:
                    _end_event_span(self._runtime, span, result="empty")
                    continue
                _end_event_span(self._runtime, span, result="delivered")
                # Same delivery accounting as iter_events().
                self._runtime.increment_metric("events.deliver")
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "event.receive",
                    adapter_name=self._adapter_name,
                    backend_name=self._backend_name,
                    channel=channel,
                    listener_id=listener_id,
                    event_id=event.event_id,
                    event_type=_resolve_event_type(event.payload, event.metadata),
                    mode="sync",
                )
                try:
                    handler(event)
                    if auto_ack:
                        self._backend.ack(event.event_id)
                except Exception as error:
                    log_with_context(
                        logger,
                        logging.WARNING,
                        "event.listen",
                        adapter_name=self._adapter_name,
                        backend_name=self._backend_name,
                        channel=channel,
                        listener_id=listener_id,
                        mode="sync",
                        error_type=type(error).__name__,
                        status="handler_error",
                        event_id=event.event_id,
                        event_type=_resolve_event_type(event.payload, event.metadata),
                    )
        finally:
            self._listeners.pop(listener_id, None)
class AsyncEventChannel:
    """Event channel for asynchronous database configurations.

    The backend is chosen once, in ``__init__``. It is the adapter's native backend
    when ``load_native_backend`` returns one, otherwise the table-queue backend from
    ``build_queue_backend``. Backend calls are wrapped in observability spans and
    logged through ``log_with_context``.
    """

    __slots__ = (
        "_adapter_name",
        "_backend",
        "_backend_name",
        "_config",
        "_listeners",
        "_poll_interval_default",
        "_runtime",
    )

    _backend: "AsyncEventBackendProtocol"

    def __init__(self, config: "AsyncDatabaseConfig[Any, Any, Any]") -> None:
        """Create a channel bound to ``config``.

        Args:
            config: An asynchronous database configuration. Its
                ``extension_config["events"]`` mapping supplies the optional
                ``backend`` and ``poll_interval`` settings.

        Raises:
            ImproperConfigurationError: If ``config`` is not async, or if the resulting
                default poll interval is not greater than zero.
        """
        if not config.is_async:
            msg = "AsyncEventChannel requires an async configuration"
            raise ImproperConfigurationError(msg)
        extension_settings: dict[str, Any] = dict(config.extension_config.get("events", {}))
        self._adapter_name = resolve_adapter_name(config)
        hints = get_runtime_hints(self._adapter_name, config)
        # A falsy configured value (missing, None, 0) falls back to the adapter hint. The
        # result is validated like per-call intervals, so a negative setting fails here
        # instead of being passed to every dequeue() call.
        default_interval = float(extension_settings.get("poll_interval") or hints.poll_interval)
        self._poll_interval_default = resolve_poll_interval(default_interval, default_interval)
        queue_backend = build_queue_backend(config, extension_settings, adapter_name=self._adapter_name, hints=hints)
        backend_name = extension_settings.get("backend") or _get_default_backend(self._adapter_name)
        native_backend = load_native_backend(config, backend_name, extension_settings)
        if native_backend is None:
            if backend_name not in {None, "table_queue"}:
                # A native backend was requested (or defaulted) but could not be loaded.
                log_with_context(
                    logger,
                    logging.WARNING,
                    "event.listen",
                    adapter_name=self._adapter_name,
                    backend_name=backend_name,
                    fallback_backend="table_queue",
                    status="backend_unavailable",
                )
            self._backend = cast("AsyncEventBackendProtocol", queue_backend)
            backend_label = "table_queue"
        else:
            self._backend = cast("AsyncEventBackendProtocol", native_backend)
            if isinstance(native_backend, AsyncEventBackendProtocol):
                backend_label = native_backend.backend_name
            else:
                backend_label = backend_name or "table_queue"
        self._config = config
        self._backend_name = backend_label
        self._runtime = config.get_observability_runtime()
        self._listeners: dict[str, AsyncEventListener] = {}

    async def publish(self, channel: str, payload: "dict[str, Any]", metadata: "dict[str, Any] | None" = None) -> str:
        """Publish an event to ``channel``.

        Args:
            channel: Channel name, normalized with ``normalize_event_channel_name``.
            payload: Event payload handed to the backend.
            metadata: Optional metadata handed to the backend.

        Returns:
            The event id returned by the backend.

        Raises:
            ImproperConfigurationError: If the backend does not support async operation.
        """
        channel = normalize_event_channel_name(channel)
        if not self._backend.supports_async:
            msg = "Current events backend does not support async publishing"
            raise ImproperConfigurationError(msg)
        span = _start_event_span(
            self._runtime, "publish", self._backend_name, self._adapter_name, channel, mode="async"
        )
        try:
            event_id = await self._backend.publish(channel, payload, metadata)
        except Exception as error:
            _end_event_span(self._runtime, span, error=error)
            raise
        _end_event_span(self._runtime, span, result="published")
        log_with_context(
            logger,
            logging.DEBUG,
            "event.publish",
            adapter_name=self._adapter_name,
            backend_name=self._backend_name,
            channel=channel,
            event_id=event_id,
            event_type=_resolve_event_type(payload, metadata),
            mode="async",
        )
        return event_id

    async def iter_events(self, channel: str, *, poll_interval: float | None = None) -> AsyncIterator[EventMessage]:
        """Yield events from ``channel`` indefinitely.

        Each dequeue waits up to the resolved poll interval. Empty polls are retried,
        and every delivered event increments the ``events.deliver`` metric.

        Raises:
            ImproperConfigurationError: If the backend does not support async operation
                or ``poll_interval`` is not greater than zero.
        """
        channel = normalize_event_channel_name(channel)
        if not self._backend.supports_async:
            msg = "Current events backend does not support async consumption"
            raise ImproperConfigurationError(msg)
        interval = resolve_poll_interval(poll_interval, self._poll_interval_default)
        while True:
            span = _start_event_span(
                self._runtime, "dequeue", self._backend_name, self._adapter_name, channel, mode="async"
            )
            try:
                event = await self._backend.dequeue(channel, interval)
            except Exception as error:
                _end_event_span(self._runtime, span, error=error)
                raise
            if event is None:
                _end_event_span(self._runtime, span, result="empty")
                continue
            _end_event_span(self._runtime, span, result="delivered")
            self._runtime.increment_metric("events.deliver")
            log_with_context(
                logger,
                logging.DEBUG,
                "event.receive",
                adapter_name=self._adapter_name,
                backend_name=self._backend_name,
                channel=channel,
                event_id=event.event_id,
                event_type=_resolve_event_type(event.payload, event.metadata),
                mode="async",
            )
            yield event

    def listen(
        self,
        channel: str,
        handler: "AsyncEventHandler | SyncEventHandler",
        *,
        poll_interval: float | None = None,
        auto_ack: bool = True,
    ) -> AsyncEventListener:
        """Start a task on the running loop that delivers ``channel`` events to ``handler``.

        Args:
            channel: Channel name, normalized with ``normalize_event_channel_name``.
            handler: Sync or async callable. An awaitable result is awaited.
            poll_interval: Optional per-listener poll interval.
            auto_ack: Ack each event after ``handler`` completes without raising.

        Returns:
            The registered listener handle.

        Raises:
            ImproperConfigurationError: If the backend does not support async operation
                or ``poll_interval`` is not greater than zero.
            RuntimeError: If called without a running event loop
                (from ``asyncio.get_running_loop``).
        """
        channel = normalize_event_channel_name(channel)
        if not self._backend.supports_async:
            msg = "Current events backend does not support async listeners"
            raise ImproperConfigurationError(msg)
        loop = asyncio.get_running_loop()
        stop_event = asyncio.Event()
        interval = resolve_poll_interval(poll_interval, self._poll_interval_default)
        listener_id = uuid4().hex
        task = loop.create_task(self._run_listener(listener_id, channel, handler, stop_event, interval, auto_ack))
        listener = AsyncEventListener(listener_id, channel, task, stop_event, interval)
        # The task cannot run before this method returns control to the loop, so the
        # entry is registered before the task's cleanup could remove it.
        self._listeners[listener_id] = listener
        self._runtime.increment_metric("events.listener.start")
        log_with_context(
            logger,
            logging.DEBUG,
            "event.listen",
            adapter_name=self._adapter_name,
            backend_name=self._backend_name,
            channel=channel,
            listener_id=listener_id,
            mode="async",
            status="start",
        )
        return listener

    async def stop_listener(self, listener_id: str) -> None:
        """Stop and unregister the listener ``listener_id``; unknown ids are ignored."""
        listener = self._listeners.pop(listener_id, None)
        if listener is None:
            return
        await listener.stop()
        self._runtime.increment_metric("events.listener.stop")
        log_with_context(
            logger,
            logging.DEBUG,
            "event.listen",
            adapter_name=self._adapter_name,
            backend_name=self._backend_name,
            channel=listener.channel,
            listener_id=listener_id,
            mode="async",
            status="stop",
        )

    async def ack(self, event_id: str) -> None:
        """Acknowledge an event.

        Raises:
            ImproperConfigurationError: If the backend does not support async operation.
        """
        if not self._backend.supports_async:
            msg = "Current events backend does not support async ack"
            raise ImproperConfigurationError(msg)
        span = _start_event_span(self._runtime, "ack", self._backend_name, self._adapter_name, mode="async")
        try:
            await self._backend.ack(event_id)
        except Exception as error:
            _end_event_span(self._runtime, span, error=error)
            raise
        _end_event_span(self._runtime, span, result="acked")

    async def nack(self, event_id: str) -> None:
        """Return an event to the queue for redelivery.

        Raises:
            ImproperConfigurationError: If the backend does not support async operation,
                consistent with ``ack`` and ``publish``.
        """
        if not self._backend.supports_async:
            msg = "Current events backend does not support async nack"
            raise ImproperConfigurationError(msg)
        span = _start_event_span(self._runtime, "nack", self._backend_name, self._adapter_name, mode="async")
        try:
            await self._backend.nack(event_id)
        except Exception as error:
            _end_event_span(self._runtime, span, error=error)
            raise
        _end_event_span(self._runtime, span, result="nacked")

    async def shutdown(self) -> None:
        """Stop every registered listener, then shut the backend down."""
        span = _start_event_span(self._runtime, "shutdown", self._backend_name, self._adapter_name, mode="async")
        try:
            # Snapshot the ids: stop_listener and listener tasks both mutate the dict.
            for listener_id in list(self._listeners):
                await self.stop_listener(listener_id)
            await self._backend.shutdown()
        except Exception as error:
            _end_event_span(self._runtime, span, error=error)
            raise
        _end_event_span(self._runtime, span, result="shutdown")
        self._runtime.increment_metric("events.shutdown")

    async def _run_listener(
        self,
        listener_id: str,
        channel: str,
        handler: "AsyncEventHandler | SyncEventHandler",
        stop_event: "asyncio.Event",
        poll_interval: float,
        auto_ack: bool,
    ) -> None:
        """Listener task body: dequeue, dispatch to ``handler``, optionally ack.

        Handler (and auto-ack) failures are logged and the loop continues; the event
        is left un-acked. A dequeue failure ends the task. The listener is always
        unregistered on exit.
        """
        try:
            while not stop_event.is_set():
                span = _start_event_span(
                    self._runtime, "dequeue", self._backend_name, self._adapter_name, channel, mode="async"
                )
                try:
                    event = await self._backend.dequeue(channel, poll_interval)
                except Exception as error:
                    _end_event_span(self._runtime, span, error=error)
                    raise
                if event is None:
                    _end_event_span(self._runtime, span, result="empty")
                    continue
                _end_event_span(self._runtime, span, result="delivered")
                # Same delivery accounting as iter_events().
                self._runtime.increment_metric("events.deliver")
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "event.receive",
                    adapter_name=self._adapter_name,
                    backend_name=self._backend_name,
                    channel=channel,
                    listener_id=listener_id,
                    event_id=event.event_id,
                    event_type=_resolve_event_type(event.payload, event.metadata),
                    mode="async",
                )
                try:
                    result = handler(event)
                    # Sync handlers return a plain value; async handlers return an awaitable.
                    if inspect.isawaitable(result):
                        await result
                    if auto_ack:
                        await self._backend.ack(event.event_id)
                except Exception as error:
                    log_with_context(
                        logger,
                        logging.WARNING,
                        "event.listen",
                        adapter_name=self._adapter_name,
                        backend_name=self._backend_name,
                        channel=channel,
                        listener_id=listener_id,
                        mode="async",
                        error_type=type(error).__name__,
                        status="handler_error",
                        event_id=event.event_id,
                        event_type=_resolve_event_type(event.payload, event.metadata),
                    )
        finally:
            self._listeners.pop(listener_id, None)