Source code for sqlspec.adapters.psycopg.config

"""Psycopg database configuration with direct field-based configuration."""

from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

from mypy_extensions import mypyc_attr
from psycopg_pool import AsyncConnectionPool, ConnectionPool
from typing_extensions import NotRequired

from sqlspec.adapters.psycopg._typing import PsycopgAsyncConnection, PsycopgSyncConnection
from sqlspec.adapters.psycopg.core import apply_driver_features, default_statement_config
from sqlspec.adapters.psycopg.driver import (
    PsycopgAsyncCursor,
    PsycopgAsyncDriver,
    PsycopgAsyncExceptionHandler,
    PsycopgAsyncSessionContext,
    PsycopgSyncCursor,
    PsycopgSyncDriver,
    PsycopgSyncExceptionHandler,
    PsycopgSyncSessionContext,
)
from sqlspec.adapters.psycopg.type_converter import register_pgvector_async, register_pgvector_sync
from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs, SyncDatabaseConfig
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.utils.config_tools import normalize_connection_config

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

    from sqlspec.core import StatementConfig


[docs] class PsycopgConnectionParams(TypedDict): """Psycopg connection parameters.""" conninfo: NotRequired[str] host: NotRequired[str] port: NotRequired[int] user: NotRequired[str] password: NotRequired[str] dbname: NotRequired[str] connect_timeout: NotRequired[int] options: NotRequired[str] application_name: NotRequired[str] sslmode: NotRequired[str] sslcert: NotRequired[str] sslkey: NotRequired[str] sslrootcert: NotRequired[str] autocommit: NotRequired[bool] extra: NotRequired["dict[str, Any]"]
class PsycopgPoolParams(PsycopgConnectionParams):
    """Connection keys plus the optional keys that size and tune the pool.

    Inherits every key of ``PsycopgConnectionParams``; all keys added here are
    ``NotRequired`` as well.
    """

    # Pool sizing and identification.
    min_size: NotRequired[int]
    max_size: NotRequired[int]
    name: NotRequired[str]

    # Timing knobs.
    timeout: NotRequired[float]
    max_waiting: NotRequired[int]
    max_lifetime: NotRequired[float]
    max_idle: NotRequired[float]
    reconnect_timeout: NotRequired[float]

    # Worker count, per-connection configure callback and extra connect kwargs.
    num_workers: NotRequired[int]
    configure: NotRequired["Callable[..., Any]"]
    kwargs: NotRequired["dict[str, Any]"]
class PsycopgDriverFeatures(TypedDict):
    """Psycopg driver feature flags.

    enable_pgvector: Enable automatic pgvector extension support for vector similarity search.
        Requires pgvector-python package (pip install pgvector) and PostgreSQL with pgvector extension.
        Defaults to True when pgvector-python is installed.
        Provides automatic conversion between Python objects and PostgreSQL vector types.
        Enables vector similarity operations and index support.
        Set to False to disable pgvector support even when package is available.
    enable_paradedb: Enable ParadeDB (pg_search) extension detection.
        When enabled and the pg_search extension is detected, the SQL dialect switches to
        "paradedb" which supports search operators (@@@, &&&, etc.) and inherits all
        pgvector distance operators.
        Defaults to True. Independent of enable_pgvector.
    json_serializer: Custom JSON serializer for StatementConfig parameter handling.
    json_deserializer: Custom JSON deserializer reference stored alongside the serializer for parity with asyncpg.
    on_connection_create: Callback executed when a connection is created/acquired from the pool.
        Receives the raw psycopg connection for low-level driver configuration.
        Runs after internal setup (pgvector registration).
        For sync config: Callable[[PsycopgSyncConnection], None]
        For async config: Callable[[PsycopgAsyncConnection], Awaitable[None]]
    enable_events: Enable database event channel support.
        Defaults to True when extension_config["events"] is configured.
        Provides pub/sub capabilities via LISTEN/NOTIFY or table-backed fallback.
        Requires extension_config["events"] for migration setup when using table_queue backend.
    events_backend: Event channel backend selection.
        Options: "listen_notify", "table_queue", "listen_notify_durable"
        - "listen_notify": Zero-copy PostgreSQL LISTEN/NOTIFY (ephemeral, real-time) - coming soon
        - "table_queue": Durable table-backed queue with retries and exactly-once delivery (current default)
        - "listen_notify_durable": Hybrid - real-time + durable (available when native support lands)
        Defaults to "table_queue" until native LISTEN/NOTIFY support is implemented.
    """

    enable_pgvector: NotRequired[bool]
    enable_paradedb: NotRequired[bool]
    json_serializer: NotRequired["Callable[[Any], str]"]
    json_deserializer: NotRequired["Callable[[str], Any]"]
    on_connection_create: NotRequired["Callable[..., Any]"]
    enable_events: NotRequired[bool]
    events_backend: NotRequired[str]


__all__ = (
    "PsycopgAsyncConfig",
    "PsycopgAsyncCursor",
    "PsycopgConnectionParams",
    "PsycopgDriverFeatures",
    "PsycopgPoolParams",
    "PsycopgSyncConfig",
    "PsycopgSyncCursor",
)


class PsycopgSyncConnectionContext:
    """Context manager that lends out a pooled Psycopg connection.

    The pool is created lazily on first use. The connection is always borrowed
    through ``pool.connection()``, so leaving the ``with`` block hands it back
    to the pool.
    """

    __slots__ = ("_config", "_ctx")

    def __init__(self, config: "PsycopgSyncConfig") -> None:
        self._config = config
        # Holds the pool's connection context manager while a connection is checked out.
        self._ctx: Any = None

    def __enter__(self) -> "PsycopgSyncConnection":
        """Borrow a connection from the pool, creating the pool if needed.

        Returns:
            The connection yielded by ``pool.connection()``.
        """
        config = self._config
        if config.connection_instance is None:
            # The previous fallback called create_connection(), which checks a
            # connection out with getconn(); it was later closed without being
            # returned to the pool. Creating the pool and borrowing through
            # pool.connection() keeps checkout and return paired.
            config.connection_instance = config.create_pool()
        self._ctx = config.connection_instance.connection()
        return cast("PsycopgSyncConnection", self._ctx.__enter__())

    def __exit__(
        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any
    ) -> "bool | None":
        """Return the borrowed connection to the pool.

        Returns:
            Whatever the pool's connection context manager returns from
            ``__exit__``, or None when nothing was borrowed.
        """
        ctx = self._ctx
        if ctx is None:
            return None
        # Clear the reference first so a failing __exit__ cannot be replayed.
        self._ctx = None
        return cast("bool | None", ctx.__exit__(exc_type, exc_val, exc_tb))


class _PsycopgSyncSessionConnectionHandler:
    """Acquire/release pair used by sync driver sessions.

    Mirrors ``PsycopgSyncConnectionContext``: connections are always borrowed
    via ``pool.connection()`` so that release returns them to the pool.
    """

    # "_conn" is no longer assigned a connection; the slot is kept so the
    # attribute layout stays unchanged.
    __slots__ = ("_config", "_conn", "_ctx")

    def __init__(self, config: "PsycopgSyncConfig") -> None:
        self._config = config
        self._ctx: Any = None
        self._conn: "PsycopgSyncConnection | None" = None

    def acquire_connection(self) -> "PsycopgSyncConnection":
        """Borrow a connection from the pool, creating the pool if needed."""
        config = self._config
        if config.connection_instance is None:
            # Same fix as in PsycopgSyncConnectionContext.__enter__: avoid the
            # getconn()-then-close() path that never returned the connection.
            config.connection_instance = config.create_pool()
        self._ctx = config.connection_instance.connection()
        return cast("PsycopgSyncConnection", self._ctx.__enter__())

    def release_connection(self, _conn: "PsycopgSyncConnection") -> None:
        """Hand the borrowed connection back to the pool; no-op if none is held."""
        ctx = self._ctx
        if ctx is None:
            return
        self._ctx = None
        ctx.__exit__(None, None, None)
class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool, PsycopgSyncDriver]):
    """Pool-backed configuration for synchronous Psycopg database connections."""

    driver_type: "ClassVar[type[PsycopgSyncDriver]]" = PsycopgSyncDriver
    connection_type: "ClassVar[type[PsycopgSyncConnection]]" = PsycopgSyncConnection
    supports_transactional_ddl: "ClassVar[bool]" = True
    supports_native_arrow_export: "ClassVar[bool]" = True
    supports_native_arrow_import: "ClassVar[bool]" = True
    supports_native_parquet_export: "ClassVar[bool]" = True
    supports_native_parquet_import: "ClassVar[bool]" = True

    def __init__(
        self,
        *,
        connection_config: "PsycopgPoolParams | dict[str, Any] | None" = None,
        connection_instance: "ConnectionPool | None" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "PsycopgDriverFeatures | dict[str, Any] | None" = None,
        bind_key: "str | None" = None,
        extension_config: "ExtensionConfigs | None" = None,
        **kwargs: Any,
    ) -> None:
        """Initialize Psycopg synchronous configuration.

        Args:
            connection_config: Connection and pool configuration parameters (TypedDict or dict)
            connection_instance: Existing pool instance to use
            migration_config: Migration configuration
            statement_config: Default SQL statement configuration
            driver_features: Optional driver feature configuration
            bind_key: Optional unique identifier for this configuration
            extension_config: Extension-specific configuration (e.g., Litestar plugin settings)
            **kwargs: Additional keyword arguments
        """
        normalized_connection_config = normalize_connection_config(connection_config)
        resolved_statement_config, resolved_features = apply_driver_features(
            statement_config or default_statement_config, driver_features
        )

        # The user hook is kept on the instance and removed from the features
        # mapping that is handed to the base class.
        stored_features = dict(resolved_features) if resolved_features else {}
        self._user_connection_hook: "Callable[[PsycopgSyncConnection], None] | None" = stored_features.pop(
            "on_connection_create", None
        )

        # None means "extension detection has not run yet".
        self._pgvector_available: "bool | None" = None
        self._paradedb_available: "bool | None" = None

        super().__init__(
            connection_config=normalized_connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=resolved_statement_config,
            driver_features=stored_features,
            bind_key=bind_key,
            extension_config=extension_config,
            **kwargs,
        )

    def _create_pool(self) -> "ConnectionPool":
        """Build and open a ``ConnectionPool`` from ``connection_config``."""
        pool_defaults = (
            ("min_size", 4),
            ("max_size", None),
            ("name", None),
            ("timeout", 30.0),
            ("max_waiting", 0),
            ("max_lifetime", 3600.0),
            ("max_idle", 600.0),
            ("reconnect_timeout", 300.0),
            ("num_workers", 3),
        )
        remaining = dict(self.connection_config)
        pool_options = {key: remaining.pop(key, fallback) for key, fallback in pool_defaults}
        pool_options["configure"] = remaining.pop("configure", self._configure_connection)
        # Options left as None are omitted from the pool constructor call.
        pool_options = {key: value for key, value in pool_options.items() if value is not None}

        conninfo = remaining.pop("conninfo", None)
        if not conninfo:
            # No connection string: every leftover key, plus the nested
            # "kwargs" mapping, becomes a connection keyword argument.
            remaining.update(remaining.pop("kwargs", {}))
            return ConnectionPool("", kwargs=remaining, open=True, **pool_options)

        # NOTE(review): with a conninfo string the leftover keys (including
        # "kwargs") are not forwarded to the pool - confirm this is intended.
        return ConnectionPool(conninfo, open=True, **pool_options)

    def _update_dialect_for_extensions(self) -> None:
        """Update statement_config dialect based on detected extensions.

        Priority: paradedb > pgvector > postgres (default).
        """
        if getattr(self.statement_config, "dialect", "postgres") != "postgres":
            return
        if self._paradedb_available:
            detected_dialect = "paradedb"
        elif self._pgvector_available:
            detected_dialect = "pgvector"
        else:
            return
        self.statement_config = self.statement_config.replace(dialect=detected_dialect)

    def _configure_connection(self, conn: "PsycopgSyncConnection") -> None:
        """Pool ``configure`` callback: apply autocommit, detect extensions, run the user hook."""
        autocommit = self.connection_config.get("autocommit")
        if autocommit is not None:
            conn.autocommit = autocommit

        # Extension detection runs only once, for the first configured connection.
        if self._pgvector_available is None:
            features = self.driver_features
            wanted = []
            if features.get("enable_pgvector", False):
                wanted.append("vector")
            if features.get("enable_paradedb", False):
                wanted.append("pg_search")

            found = set()
            if wanted:
                try:
                    rows = conn.execute(
                        "SELECT extname FROM pg_extension WHERE extname = ANY(%s::text[])", (wanted,)
                    ).fetchall()
                    found = {row[0] for row in rows}  # type: ignore[index]
                except Exception:
                    # Best effort: a failed probe is treated as "no extensions".
                    found = set()
            self._pgvector_available = "vector" in found
            self._paradedb_available = "pg_search" in found
            self._update_dialect_for_extensions()

        if self._pgvector_available:
            register_pgvector_sync(conn)

        # Ensure connection is not left in INTRANS state from extension detection or registration
        if not conn.autocommit:
            conn.rollback()

        # The user callback runs last, after the internal setup above.
        hook = self._user_connection_hook
        if hook is not None:
            hook(conn)

    def _close_pool(self) -> None:
        """Close the pool, if any, and always clear ``connection_instance``."""
        pool = self.connection_instance
        if not pool:
            return
        try:
            pool.close()
        finally:
            self.connection_instance = None

    def create_connection(self) -> "PsycopgSyncConnection":
        """Check a connection out of the pool, creating the pool first if needed.

        NOTE(review): the connection is obtained with ``getconn()`` and nothing
        in this method returns it to the pool - presumably the caller must do
        so; confirm against callers.

        Returns:
            A psycopg Connection instance.
        """
        pool = self.connection_instance
        if pool is None:
            pool = self.connection_instance = self.create_pool()
        return cast("PsycopgSyncConnection", pool.getconn())  # pyright: ignore

    def provide_connection(self, *args: Any, **kwargs: Any) -> "PsycopgSyncConnectionContext":
        """Provide a connection context manager.

        Args:
            *args: Additional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            A psycopg Connection context manager.
        """
        return PsycopgSyncConnectionContext(self)

    def provide_session(
        self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any
    ) -> "PsycopgSyncSessionContext":
        """Provide a driver session context manager.

        Args:
            *_args: Additional arguments.
            statement_config: Optional statement configuration override.
            **_kwargs: Additional keyword arguments.

        Returns:
            A PsycopgSyncDriver session context manager.
        """
        connection_handler = _PsycopgSyncSessionConnectionHandler(self)
        # Without an override, pass a callable so the config's statement_config
        # is read when the callable is invoked rather than now.
        session_statement_config = statement_config or (
            lambda: self.statement_config or default_statement_config
        )
        return PsycopgSyncSessionContext(
            acquire_connection=connection_handler.acquire_connection,
            release_connection=connection_handler.release_connection,
            statement_config=session_statement_config,
            driver_features=self.driver_features,
            prepare_driver=self._prepare_driver,
        )

    def provide_pool(self, *args: Any, **kwargs: Any) -> "ConnectionPool":
        """Provide pool instance, creating it on first use.

        Returns:
            The connection pool.
        """
        if self.connection_instance:
            return self.connection_instance
        self.connection_instance = self.create_pool()
        return self.connection_instance

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Get the signature namespace for Psycopg types.

        This provides all Psycopg-specific types that Litestar needs to recognize
        to avoid serialization attempts.

        Returns:
            Dictionary mapping type names to types.
        """
        namespace = super().get_signature_namespace()
        namespace.update(
            PsycopgConnectionParams=PsycopgConnectionParams,
            PsycopgPoolParams=PsycopgPoolParams,
            PsycopgSyncConnectionContext=PsycopgSyncConnectionContext,
            PsycopgSyncConnection=PsycopgSyncConnection,
            PsycopgSyncCursor=PsycopgSyncCursor,
            PsycopgSyncDriver=PsycopgSyncDriver,
            PsycopgSyncExceptionHandler=PsycopgSyncExceptionHandler,
            PsycopgSyncSessionContext=PsycopgSyncSessionContext,
        )
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return polling defaults for PostgreSQL queue fallback."""
        hints = EventRuntimeHints(poll_interval=0.5, select_for_update=True, skip_locked=True)
        return hints
class PsycopgAsyncConnectionContext:
    """Async context manager that lends out a pooled Psycopg connection."""

    __slots__ = ("_config", "_ctx")

    def __init__(self, config: "PsycopgAsyncConfig") -> None:
        self._config = config
        # Holds the pool's async connection context manager while checked out.
        self._ctx: Any = None

    async def __aenter__(self) -> "PsycopgAsyncConnection":
        """Borrow a connection from the pool, creating the pool on first use.

        Raises:
            ImproperConfigurationError: If no usable pool is available after
                the creation attempt.
        """
        config = self._config
        if config.connection_instance is None:
            config.connection_instance = await config.create_pool()

        pool = config.connection_instance
        if not pool:
            msg = "Connection pool not initialized"
            raise ImproperConfigurationError(msg)

        # pool.connection() returns an async context manager
        self._ctx = pool.connection()
        return cast("PsycopgAsyncConnection", await self._ctx.__aenter__())

    async def __aexit__(
        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any
    ) -> "bool | None":
        """Delegate to the pool's connection context manager, if one is held."""
        if not self._ctx:
            return None
        return cast("bool | None", await self._ctx.__aexit__(exc_type, exc_val, exc_tb))


class _PsycopgAsyncSessionConnectionHandler:
    """Acquire/release pair used by async driver sessions."""

    __slots__ = ("_config", "_ctx")

    def __init__(self, config: "PsycopgAsyncConfig") -> None:
        self._config = config
        self._ctx: Any = None

    async def acquire_connection(self) -> "PsycopgAsyncConnection":
        """Borrow a connection from the pool, creating the pool on first use."""
        config = self._config
        if config.connection_instance is None:
            config.connection_instance = await config.create_pool()
        self._ctx = config.connection_instance.connection()
        return cast("PsycopgAsyncConnection", await self._ctx.__aenter__())

    async def release_connection(self, _conn: "PsycopgAsyncConnection") -> None:
        """Exit the held connection context manager; no-op when none is held."""
        if self._ctx is not None:
            await self._ctx.__aexit__(None, None, None)
            self._ctx = None
@mypyc_attr(native_class=False)
class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnectionPool, PsycopgAsyncDriver]):
    """Pool-backed configuration for asynchronous Psycopg database connections."""

    driver_type: ClassVar[type[PsycopgAsyncDriver]] = PsycopgAsyncDriver
    connection_type: "ClassVar[type[PsycopgAsyncConnection]]" = PsycopgAsyncConnection
    supports_transactional_ddl: "ClassVar[bool]" = True
    supports_native_arrow_export: ClassVar[bool] = True
    supports_native_arrow_import: ClassVar[bool] = True
    supports_native_parquet_export: ClassVar[bool] = True
    supports_native_parquet_import: ClassVar[bool] = True

    def __init__(
        self,
        *,
        connection_config: "PsycopgPoolParams | dict[str, Any] | None" = None,
        connection_instance: "AsyncConnectionPool | None" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "PsycopgDriverFeatures | dict[str, Any] | None" = None,
        bind_key: "str | None" = None,
        extension_config: "ExtensionConfigs | None" = None,
        **kwargs: Any,
    ) -> None:
        """Initialize Psycopg asynchronous configuration.

        Args:
            connection_config: Connection and pool configuration parameters (TypedDict or dict)
            connection_instance: Existing pool instance to use
            migration_config: Migration configuration
            statement_config: Default SQL statement configuration
            driver_features: Optional driver feature configuration
            bind_key: Optional unique identifier for this configuration
            extension_config: Extension-specific configuration (e.g., Litestar plugin settings)
            **kwargs: Additional keyword arguments
        """
        normalized_connection_config = normalize_connection_config(connection_config)
        resolved_statement_config, resolved_features = apply_driver_features(
            statement_config or default_statement_config, driver_features
        )

        # The user hook is kept on the instance and removed from the features
        # mapping that is handed to the base class.
        stored_features = dict(resolved_features) if resolved_features else {}
        self._user_connection_hook: "Callable[[PsycopgAsyncConnection], Awaitable[None]] | None" = (
            stored_features.pop("on_connection_create", None)
        )

        # None means "extension detection has not run yet".
        self._pgvector_available: "bool | None" = None
        self._paradedb_available: "bool | None" = None

        super().__init__(
            connection_config=normalized_connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=resolved_statement_config,
            driver_features=stored_features,
            bind_key=bind_key,
            extension_config=extension_config,
            **kwargs,
        )

    async def _create_pool(self) -> "AsyncConnectionPool":
        """Build an ``AsyncConnectionPool`` from ``connection_config`` and open it."""
        pool_defaults = (
            ("min_size", 4),
            ("max_size", None),
            ("name", None),
            ("timeout", 30.0),
            ("max_waiting", 0),
            ("max_lifetime", 3600.0),
            ("max_idle", 600.0),
            ("reconnect_timeout", 300.0),
            ("num_workers", 3),
        )
        remaining = dict(self.connection_config)
        pool_options = {key: remaining.pop(key, fallback) for key, fallback in pool_defaults}
        pool_options["configure"] = remaining.pop("configure", self._configure_async_connection)
        # Options left as None are omitted from the pool constructor call.
        pool_options = {key: value for key, value in pool_options.items() if value is not None}

        conninfo = remaining.pop("conninfo", None)
        if conninfo:
            # NOTE(review): with a conninfo string the leftover keys (including
            # "kwargs") are not forwarded to the pool - confirm this is intended.
            pool = AsyncConnectionPool(conninfo, open=False, **pool_options)
        else:
            # No connection string: every leftover key, plus the nested
            # "kwargs" mapping, becomes a connection keyword argument.
            remaining.update(remaining.pop("kwargs", {}))
            pool = AsyncConnectionPool("", kwargs=remaining, open=False, **pool_options)

        # The pool is constructed closed and opened explicitly here.
        await pool.open()
        return pool

    def _update_dialect_for_extensions(self) -> None:
        """Update statement_config dialect based on detected extensions.

        Priority: paradedb > pgvector > postgres (default).
        """
        if getattr(self.statement_config, "dialect", "postgres") != "postgres":
            return
        if self._paradedb_available:
            detected_dialect = "paradedb"
        elif self._pgvector_available:
            detected_dialect = "pgvector"
        else:
            return
        self.statement_config = self.statement_config.replace(dialect=detected_dialect)

    async def _configure_async_connection(self, conn: "PsycopgAsyncConnection") -> None:
        """Pool ``configure`` callback: apply autocommit, detect extensions, run the user hook."""
        autocommit = self.connection_config.get("autocommit")
        if autocommit is not None:
            await conn.set_autocommit(autocommit)

        # Extension detection runs only once, for the first configured connection.
        if self._pgvector_available is None:
            features = self.driver_features
            wanted = []
            if features.get("enable_pgvector", False):
                wanted.append("vector")
            if features.get("enable_paradedb", False):
                wanted.append("pg_search")

            found = set()
            if wanted:
                try:
                    probe = await conn.execute(
                        "SELECT extname FROM pg_extension WHERE extname = ANY(%s::text[])", (wanted,)
                    )
                    rows = await probe.fetchall()
                    found = {row[0] for row in rows}  # type: ignore[index]
                except Exception:
                    # Best effort: a failed probe is treated as "no extensions".
                    found = set()
            self._pgvector_available = "vector" in found
            self._paradedb_available = "pg_search" in found
            self._update_dialect_for_extensions()

        if self._pgvector_available:
            await register_pgvector_async(conn)

        # Ensure connection is not left in INTRANS state from extension detection or registration
        if not conn.autocommit:
            await conn.rollback()

        # The user callback runs last, after the internal setup above.
        hook = self._user_connection_hook
        if hook is not None:
            await hook(conn)

    async def _close_pool(self) -> None:
        """Close the pool, if any, and always clear ``connection_instance``."""
        pool = self.connection_instance
        if not pool:
            return
        try:
            await pool.close()
        finally:
            self.connection_instance = None

    async def create_connection(self) -> "PsycopgAsyncConnection":  # pyright: ignore
        """Check a connection out of the pool, creating the pool first if needed.

        NOTE(review): the connection is obtained with ``getconn()`` and nothing
        in this method returns it to the pool - presumably the caller must do
        so; confirm against callers.

        Returns:
            A psycopg AsyncConnection instance.
        """
        pool = self.connection_instance
        if pool is None:
            pool = self.connection_instance = await self.create_pool()
        return cast("PsycopgAsyncConnection", await pool.getconn())  # pyright: ignore

    def provide_connection(self, *args: Any, **kwargs: Any) -> "PsycopgAsyncConnectionContext":  # pyright: ignore
        """Provide an async connection context manager.

        Args:
            *args: Additional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            A psycopg AsyncConnection context manager.
        """
        return PsycopgAsyncConnectionContext(self)

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Get the signature namespace for PsycopgAsyncConfig types.

        Returns:
            Dictionary mapping type names to types.
        """
        namespace = super().get_signature_namespace()
        namespace.update(
            PsycopgAsyncConnectionContext=PsycopgAsyncConnectionContext,
            PsycopgAsyncConnection=PsycopgAsyncConnection,
            PsycopgAsyncCursor=PsycopgAsyncCursor,
            PsycopgAsyncDriver=PsycopgAsyncDriver,
            PsycopgAsyncExceptionHandler=PsycopgAsyncExceptionHandler,
            PsycopgAsyncSessionContext=PsycopgAsyncSessionContext,
            PsycopgConnectionParams=PsycopgConnectionParams,
            PsycopgPoolParams=PsycopgPoolParams,
        )
        return namespace

    def provide_session(
        self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any
    ) -> "PsycopgAsyncSessionContext":
        """Provide an async driver session context manager.

        Args:
            *_args: Additional arguments.
            statement_config: Optional statement configuration override.
            **_kwargs: Additional keyword arguments.

        Returns:
            A PsycopgAsyncDriver session context manager.
        """
        connection_handler = _PsycopgAsyncSessionConnectionHandler(self)
        # Without an override, pass a callable so the config's statement_config
        # is read when the callable is invoked rather than now.
        session_statement_config = statement_config or (
            lambda: self.statement_config or default_statement_config
        )
        return PsycopgAsyncSessionContext(
            acquire_connection=connection_handler.acquire_connection,
            release_connection=connection_handler.release_connection,
            statement_config=session_statement_config,
            driver_features=self.driver_features,
            prepare_driver=self._prepare_driver,
        )

    async def provide_pool(self, *args: Any, **kwargs: Any) -> "AsyncConnectionPool":
        """Provide async pool instance, creating it on first use.

        Returns:
            The async connection pool.
        """
        if self.connection_instance:
            return self.connection_instance
        self.connection_instance = await self.create_pool()
        return self.connection_instance

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return polling defaults for PostgreSQL queue fallback."""
        hints = EventRuntimeHints(poll_interval=0.5, select_for_update=True, skip_locked=True)
        return hints