# Source code for sqlspec.adapters.psqlpy.config

"""Psqlpy database configuration."""

from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

from mypy_extensions import mypyc_attr
from typing_extensions import NotRequired

from sqlspec.adapters.psqlpy._typing import PsqlpyConnection, PsqlpyCursor, PsqlpySessionContext
from sqlspec.adapters.psqlpy.core import (
    apply_driver_features,
    build_connection_config,
    build_postgres_extension_probe_names,
    default_statement_config,
    resolve_postgres_extension_state,
    resolve_runtime_statement_config,
)
from sqlspec.adapters.psqlpy.driver import PsqlpyDriver, PsqlpyExceptionHandler
from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs
from sqlspec.driver._async import AsyncPoolConnectionContext, AsyncPoolSessionFactory
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.utils.config_tools import normalize_connection_config

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable
    from types import TracebackType

    from psqlpy import ConnectionPool

    from sqlspec.core import StatementConfig

__all__ = ("PsqlpyConfig", "PsqlpyConnectionParams", "PsqlpyCursor", "PsqlpyDriverFeatures", "PsqlpyPoolParams")


[docs] class PsqlpyConnectionParams(TypedDict): """Psqlpy connection parameters.""" dsn: NotRequired[str] username: NotRequired[str] password: NotRequired[str] db_name: NotRequired[str] host: NotRequired[str] port: NotRequired[int] connect_timeout_sec: NotRequired[int] connect_timeout_nanosec: NotRequired[int] tcp_user_timeout_sec: NotRequired[int] tcp_user_timeout_nanosec: NotRequired[int] keepalives: NotRequired[bool] keepalives_idle_sec: NotRequired[int] keepalives_idle_nanosec: NotRequired[int] keepalives_interval_sec: NotRequired[int] keepalives_interval_nanosec: NotRequired[int] keepalives_retries: NotRequired[int] ssl_mode: NotRequired[str] ca_file: NotRequired[str] target_session_attrs: NotRequired[str] options: NotRequired[str] application_name: NotRequired[str] client_encoding: NotRequired[str] gssencmode: NotRequired[str] sslnegotiation: NotRequired[str] sslcompression: NotRequired[str] sslcert: NotRequired[str] sslkey: NotRequired[str] sslpassword: NotRequired[str] sslrootcert: NotRequired[str] sslcrl: NotRequired[str] require_auth: NotRequired[str] channel_binding: NotRequired[str] krbsrvname: NotRequired[str] gsslib: NotRequired[str] gssdelegation: NotRequired[str] service: NotRequired[str] load_balance_hosts: NotRequired[str]
class PsqlpyPoolParams(PsqlpyConnectionParams):
    """Psqlpy pool parameters.

    Extends the connection parameters with pool-level keys; every key is
    declared ``NotRequired``.
    """

    hosts: NotRequired[list[str]]
    ports: NotRequired[list[int]]
    conn_recycling_method: NotRequired[str]
    max_db_pool_size: NotRequired[int]
    configure: NotRequired["Callable[..., Any]"]
    extra: NotRequired["dict[str, Any]"]


class PsqlpyDriverFeatures(TypedDict):
    """Psqlpy driver feature flags.

    enable_cast_detection: Enable cast-aware parameter processing.
        Defaults to True. When enabled, SQL casts in prepared statements guide
        psqlpy parameter coercion for JSON, UUID, and timestamp-like values.
    enable_pgvector: Enable pgvector extension detection for vector similarity search.
        Defaults to True when the pgvector Python package is installed.
        Detects the PostgreSQL vector extension and promotes the dialect to "pgvector".
        It does not register psqlpy type handlers; use psqlpy.extra_types.PgVector
        or explicit SQL casts for vector values.
    enable_paradedb: Enable ParadeDB (pg_search) extension detection.
        When enabled and the pg_search extension is detected, the SQL dialect switches
        to "paradedb" which supports search operators (@@@, &&&, etc.) and inherits
        all pgvector distance operators.
        Defaults to True. Independent of enable_pgvector.
    json_serializer: Custom JSON serializer applied to the statement configuration.
    json_deserializer: Custom JSON deserializer retained alongside the serializer for parity with asyncpg.
    on_connection_create: Async callback executed when a connection is acquired from pool.
        Receives the raw psqlpy connection for low-level driver configuration.
        Invoked once per connection object; already-initialized connections are
        tracked by ``id()`` in a plain set because psqlpy connections do not
        support weak references.
    enable_events: Enable database event channel support.
        Defaults to True when extension_config["events"] is configured.
        Provides pub/sub capabilities via LISTEN/NOTIFY or table-backed fallback.
        Requires extension_config["events"] for migration setup when using table_queue backend.
    events_backend: Event channel backend selection.
        Options: "listen_notify", "table_queue", "listen_notify_durable"
        - "listen_notify": Zero-copy PostgreSQL LISTEN/NOTIFY (ephemeral, real-time) - coming soon
        - "table_queue": Durable table-backed queue with retries and exactly-once delivery (current default)
        - "listen_notify_durable": Hybrid - real-time + durable (available when native support lands)
        Defaults to "table_queue" until native LISTEN/NOTIFY support is implemented.
    """

    enable_cast_detection: NotRequired[bool]
    enable_pgvector: NotRequired[bool]
    enable_paradedb: NotRequired[bool]
    json_serializer: NotRequired["Callable[[Any], str]"]
    json_deserializer: NotRequired["Callable[[str], Any]"]
    on_connection_create: "NotRequired[Callable[[PsqlpyConnection], Awaitable[None]]]"
    enable_events: NotRequired[bool]
    events_backend: NotRequired[str]


class _PsqlpySessionFactory(AsyncPoolSessionFactory):
    """Acquire/release callbacks that borrow one connection from the config's pool.

    The pool's ``acquire()`` context object is kept in ``_ctx`` between
    ``acquire_connection`` and ``release_connection`` so it can be exited later.
    """

    __slots__ = ("_ctx",)

    def __init__(self, config: "PsqlpyConfig") -> None:
        super().__init__(config)
        self._ctx: Any | None = None

    async def acquire_connection(self) -> "PsqlpyConnection":
        """Enter a pool acquire context and return the initialized connection.

        Creates the pool lazily when the config has none yet.

        Returns:
            The connection yielded by the pool's acquire context.

        Raises:
            BaseException: Whatever connection initialization raised; the
                connection is handed back to the pool before re-raising.
        """
        pool = self._config.connection_instance
        if pool is None:
            pool = await self._config.create_pool()
            self._config.connection_instance = pool
        ctx = pool.acquire()
        self._ctx = ctx
        connection = cast("PsqlpyConnection", await ctx.__aenter__())
        try:
            await self._config._ensure_connection_initialized(connection)  # pyright: ignore[reportPrivateUsage]
        except BaseException as exc:
            # The acquire context is already entered; exit it here so the
            # connection is not leaked, and drop the reference so a later
            # release_connection() call does not exit it a second time.
            self._ctx = None
            await ctx.__aexit__(type(exc), exc, exc.__traceback__)
            raise
        return connection

    async def release_connection(self, _conn: "PsqlpyConnection", **kwargs: Any) -> None:
        """Exit the stored acquire context, if any, and forget it."""
        if self._ctx is not None:
            await self._ctx.__aexit__(None, None, None)
            self._ctx = None


class PsqlpyConnectionContext(AsyncPoolConnectionContext):
    """Async context manager for Psqlpy connections."""

    __slots__ = ("_ctx",)

    def __init__(self, config: "PsqlpyConfig") -> None:
        super().__init__(config)
        self._ctx: Any = None

    async def __aenter__(self) -> PsqlpyConnection:
        """Acquire a connection from the pool, creating the pool if needed.

        Returns:
            The connection yielded by the pool's acquire context.

        Raises:
            BaseException: Whatever connection initialization raised; the
                connection is handed back to the pool before re-raising.
        """
        pool = self._config.connection_instance
        if pool is None:
            pool = await self._config.create_pool()
            self._config.connection_instance = pool
        ctx = pool.acquire()
        self._ctx = ctx
        connection = await ctx.__aenter__()
        try:
            await self._config._ensure_connection_initialized(connection)  # pyright: ignore[reportPrivateUsage]
        except BaseException as exc:
            # __aexit__ of this object is not invoked when __aenter__ raises,
            # so the inner acquire context has to be exited here.
            self._ctx = None
            await ctx.__aexit__(type(exc), exc, exc.__traceback__)
            raise
        return connection  # type: ignore[no-any-return]

    async def __aexit__(
        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None"
    ) -> bool | None:
        """Forward the exit to the pool's acquire context when one is held."""
        if self._ctx is not None:
            return await self._ctx.__aexit__(exc_type, exc_val, exc_tb)  # type: ignore[no-any-return]
        return None
@mypyc_attr(native_class=False)
class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, "ConnectionPool", PsqlpyDriver]):
    """Configuration for Psqlpy asynchronous database connections."""

    driver_type: ClassVar[type[PsqlpyDriver]] = PsqlpyDriver
    connection_type: "ClassVar[type[PsqlpyConnection]]" = PsqlpyConnection
    supports_transactional_ddl: "ClassVar[bool]" = True
    supports_migration_schemas: "ClassVar[bool]" = True
    supports_native_arrow_export: ClassVar[bool] = True
    supports_native_arrow_import: ClassVar[bool] = True
    supports_native_parquet_export: ClassVar[bool] = True
    supports_native_parquet_import: ClassVar[bool] = True
    supports_native_row_streaming: ClassVar[bool] = True
    _connection_context_class: "ClassVar[type[PsqlpyConnectionContext]]" = PsqlpyConnectionContext
    _session_factory_class: "ClassVar[type[_PsqlpySessionFactory]]" = _PsqlpySessionFactory
    _session_context_class: "ClassVar[type[PsqlpySessionContext]]" = PsqlpySessionContext
    _default_statement_config = default_statement_config

    def __init__(
        self,
        *,
        connection_config: "PsqlpyPoolParams | dict[str, Any] | None" = None,
        connection_instance: "ConnectionPool | None" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "PsqlpyDriverFeatures | dict[str, Any] | None" = None,
        bind_key: str | None = None,
        extension_config: "ExtensionConfigs | None" = None,
        **kwargs: Any,
    ) -> None:
        """Initialize Psqlpy configuration.

        Extracts the 'on_connection_create' hook from driver_features before
        storing them. Initializes a set to track initialized connection IDs
        because psqlpy connections do not support weak references.

        Args:
            connection_config: Connection and pool configuration parameters.
            connection_instance: Existing connection pool instance to use.
            migration_config: Migration configuration.
            statement_config: SQL statement configuration.
            driver_features: Driver feature configuration (TypedDict or dict).
            bind_key: Optional unique identifier for this configuration.
            extension_config: Extension-specific configuration.
            **kwargs: Additional keyword arguments.
        """
        connection_config = normalize_connection_config(connection_config)

        statement_config = statement_config or default_statement_config
        statement_config, driver_features = apply_driver_features(statement_config, driver_features)

        # Work on a copy so popping the hook does not mutate the mapping
        # returned by apply_driver_features.
        features_dict = dict(driver_features) if driver_features else {}
        self._user_connection_hook: Callable[[PsqlpyConnection], Awaitable[None]] | None = features_dict.pop(
            "on_connection_create", None
        )
        self._initialized_connection_ids: set[int] = set()
        # None means "extension probe has not run yet"; see
        # _ensure_connection_initialized.
        self._pgvector_available: bool | None = None
        self._paradedb_available: bool | None = None

        super().__init__(
            connection_config=connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=statement_config,
            driver_features=features_dict,
            bind_key=bind_key,
            extension_config=extension_config,
            **kwargs,
        )

    async def _ensure_connection_initialized(self, connection: "PsqlpyConnection") -> None:
        """Ensure connection callback has been called exactly once for this connection.

        Detects PostgreSQL extensions on first connection and updates dialect accordingly.

        Args:
            connection: The psqlpy connection that was just acquired.
        """
        if self._pgvector_available is None:
            detected_extensions: set[str] = set()
            extensions = build_postgres_extension_probe_names(self.driver_features)
            if extensions:
                try:
                    result = await connection.fetch(
                        "SELECT extname FROM pg_extension WHERE extname = ANY($1::text[])", [extensions]
                    )
                    rows = result.result() if result else []
                    detected_extensions = {r["extname"] for r in rows}
                except Exception:
                    # Best-effort probe: any failure is treated as
                    # "no extensions detected" rather than failing the acquire.
                    detected_extensions = set()
            self.statement_config, self._pgvector_available, self._paradedb_available = (
                resolve_postgres_extension_state(self.statement_config, self.driver_features, detected_extensions)
            )

        conn_id = id(connection)
        if conn_id in self._initialized_connection_ids:
            return

        if self._user_connection_hook is not None:
            await self._user_connection_hook(connection)
        # Recorded only after the hook returns, so a failing hook is retried
        # the next time this connection object is seen.
        self._initialized_connection_ids.add(conn_id)

    async def _create_pool(self) -> "ConnectionPool":
        """Create the actual async connection pool."""
        # Imported lazily; at module level ConnectionPool is only imported
        # under TYPE_CHECKING.
        from psqlpy import ConnectionPool

        return ConnectionPool(**build_connection_config(self.connection_config))

    async def _close_pool(self) -> None:
        """Close the actual async connection pool.

        Also forgets which connections were already initialized: the recorded
        values are ``id()`` results of objects that belonged to the closed
        pool, and a stale id could match a connection from a later pool and
        wrongly skip its ``on_connection_create`` hook.
        """
        if self.connection_instance is None:
            return

        self.connection_instance.close()
        self.connection_instance = None
        self._initialized_connection_ids.clear()

    async def create_connection(self) -> "PsqlpyConnection":
        """Return a connection obtained via the pool's ``connection()`` method.

        The pool is created first when the config has none yet. This path does
        not call ``_ensure_connection_initialized``, so the extension probe and
        the ``on_connection_create`` hook are not triggered here.

        Returns:
            A psqlpy Connection instance.
        """
        pool = self.connection_instance
        if pool is None:
            pool = await self.create_pool()
            self.connection_instance = pool

        return await pool.connection()

    def provide_session(
        self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any
    ) -> "PsqlpySessionContext":
        """Provide an async driver session context manager.

        Args:
            *_args: Additional arguments.
            statement_config: Optional statement configuration override.
            **_kwargs: Additional keyword arguments.

        Returns:
            A PsqlpyDriver session context manager.
        """
        factory = _PsqlpySessionFactory(self)

        return PsqlpySessionContext(
            acquire_connection=factory.acquire_connection,
            release_connection=factory.release_connection,
            # Without an explicit override a callable is passed, so
            # self.statement_config is read when the callable runs rather
            # than when this method is called.
            statement_config=statement_config
            or (lambda: resolve_runtime_statement_config(None, self.statement_config, default_statement_config)),
            driver_features=self.driver_features,
            prepare_driver=self._prepare_driver,
        )

    async def provide_pool(self, *args: Any, **kwargs: Any) -> "ConnectionPool":
        """Provide async pool instance.

        Returns:
            The async connection pool.
        """
        if self.connection_instance is None:
            self.connection_instance = await self.create_pool()
        return self.connection_instance

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Get the signature namespace for Psqlpy types.

        Returns:
            Dictionary mapping type names to types.
        """
        namespace = super().get_signature_namespace()
        namespace.update({
            "PsqlpyConnectionContext": PsqlpyConnectionContext,
            "PsqlpyConnection": PsqlpyConnection,
            "PsqlpyConnectionParams": PsqlpyConnectionParams,
            "PsqlpyCursor": PsqlpyCursor,
            "PsqlpyDriver": PsqlpyDriver,
            "PsqlpyDriverFeatures": PsqlpyDriverFeatures,
            "PsqlpyExceptionHandler": PsqlpyExceptionHandler,
            "PsqlpyPoolParams": PsqlpyPoolParams,
            "PsqlpySessionContext": PsqlpySessionContext,
        })
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return LISTEN/NOTIFY defaults for Psqlpy adapters."""
        return EventRuntimeHints(poll_interval=0.5, select_for_update=True, skip_locked=True, json_passthrough=True)