# Source code for sqlspec.adapters.cockroach_psycopg.config

"""CockroachDB configuration using psycopg."""

from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

from psycopg import crdb as psycopg_crdb
from psycopg_pool import AsyncConnectionPool, ConnectionPool
from typing_extensions import NotRequired

from sqlspec.adapters.cockroach_psycopg._typing import (
    CockroachAsyncConnection,
    CockroachPsycopgAsyncSessionContext,
    CockroachPsycopgSyncSessionContext,
    CockroachSyncConnection,
)
from sqlspec.adapters.cockroach_psycopg.core import (
    CockroachPsycopgRetryConfig,
    apply_driver_features,
    build_statement_config,
)
from sqlspec.adapters.cockroach_psycopg.driver import (
    CockroachPsycopgAsyncDriver,
    CockroachPsycopgAsyncExceptionHandler,
    CockroachPsycopgSyncDriver,
    CockroachPsycopgSyncExceptionHandler,
)
from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs, SyncDatabaseConfig
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.utils.config_tools import normalize_connection_config

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

    from sqlspec.core import StatementConfig
    from sqlspec.observability import ObservabilityConfig

# Public names exported by this module. The connection-context and
# session-handler helper classes defined below are not listed here.
__all__ = (
    "CockroachPsycopgAsyncConfig",
    "CockroachPsycopgConnectionConfig",
    "CockroachPsycopgDriverFeatures",
    "CockroachPsycopgPoolConfig",
    "CockroachPsycopgSyncConfig",
)


class CockroachPsycopgConnectionConfig(TypedDict):
    """CockroachDB connection parameters.

    The config classes' pool builders (``_create_pool``) treat these keys in
    one of two ways:

    * If ``conninfo`` is set (truthy), the pool connects with that string and
      the other connection keys here are not forwarded.
    * Otherwise the remaining keys are passed to the pool as connection keyword
      arguments. This happens after pool-only keys are removed and a ``kwargs``
      dict, if present, is merged in.

    The dict is passed through ``normalize_connection_config`` first, which may
    reshape it; see that helper.
    """

    conninfo: NotRequired[str]  # connection string; when truthy it is used instead of the discrete keys
    host: NotRequired[str]
    port: NotRequired[int]
    user: NotRequired[str]
    password: NotRequired[str]
    dbname: NotRequired[str]
    connect_timeout: NotRequired[int]
    options: NotRequired[str]
    application_name: NotRequired[str]
    sslmode: NotRequired[str]
    sslcert: NotRequired[str]
    sslkey: NotRequired[str]
    sslrootcert: NotRequired[str]
    # Also applied to every new pooled connection by the built-in configure callback.
    autocommit: NotRequired[bool]
    # NOTE(review): not consumed explicitly in this file; when no conninfo is given it is
    # forwarded with the other connection kwargs - confirm it is handled upstream.
    cluster: NotRequired[str]
    # NOTE(review): presumably flattened by normalize_connection_config - TODO confirm.
    extra: NotRequired["dict[str, Any]"]


class CockroachPsycopgPoolConfig(CockroachPsycopgConnectionConfig):
    """CockroachDB pool parameters.

    The config classes' ``_create_pool`` pops these keys out of the connection
    parameters and passes them to the psycopg_pool pool constructor. A value
    of ``None``, either explicit or from an absent key with no default, means
    the argument is left out, so psycopg_pool's own default applies.
    """

    min_size: NotRequired[int]  # defaults to 4 when omitted
    max_size: NotRequired[int]  # left out when unset, so psycopg_pool's default applies
    name: NotRequired[str]  # left out when unset
    timeout: NotRequired[float]  # defaults to 30.0
    max_waiting: NotRequired[int]  # defaults to 0
    max_lifetime: NotRequired[float]  # defaults to 3600.0
    max_idle: NotRequired[float]  # defaults to 600.0
    reconnect_timeout: NotRequired[float]  # defaults to 300.0
    num_workers: NotRequired[int]  # defaults to 3
    # Replaces the built-in configure callback (autocommit + on_connection_create hook).
    configure: NotRequired["Callable[..., Any]"]
    # Merged into the connection keyword arguments; ignored when conninfo is set.
    kwargs: NotRequired["dict[str, Any]"]


class CockroachPsycopgDriverFeatures(TypedDict):
    """CockroachDB driver feature configuration.

    on_connection_create: Callback executed when a connection is acquired from pool.
        For sync: Callable[[CockroachSyncConnection], None]
        For async: Callable[[CockroachAsyncConnection], Awaitable[None]]
        Called after internal setup.
    """

    enable_auto_retry: NotRequired[bool]
    max_retries: NotRequired[int]
    retry_delay_base_ms: NotRequired[float]
    retry_delay_max_ms: NotRequired[float]
    enable_retry_logging: NotRequired[bool]
    enable_follower_reads: NotRequired[bool]
    default_staleness: NotRequired[str]
    prefer_uuid_keys: NotRequired[bool]
    json_serializer: NotRequired["Callable[[Any], str]"]
    json_deserializer: NotRequired["Callable[[str], Any]"]
    on_connection_create: "NotRequired[Callable[..., Any]]"
    enable_events: NotRequired[bool]
    events_backend: NotRequired[str]


class CockroachPsycopgSyncConnectionContext:
    """Context manager for CockroachDB psycopg connections."""

    __slots__ = ("_config", "_ctx")

    def __init__(self, config: "CockroachPsycopgSyncConfig") -> None:
        self._config = config
        self._ctx: Any = None

    def __enter__(self) -> "CockroachSyncConnection":
        if self._config.connection_instance:
            self._ctx = self._config.connection_instance.connection()
            return cast("CockroachSyncConnection", self._ctx.__enter__())
        self._ctx = self._config.create_connection()
        return cast("CockroachSyncConnection", self._ctx)

    def __exit__(
        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any
    ) -> bool | None:
        if self._config.connection_instance and self._ctx:
            return cast("bool | None", self._ctx.__exit__(exc_type, exc_val, exc_tb))
        if self._ctx:
            self._ctx.close()
        return None


class _CockroachPsycopgSyncSessionConnectionHandler:
    """Acquire/release callbacks backing ``CockroachPsycopgSyncSessionContext``.

    Connections are always checked out through ``ConnectionPool.connection()``,
    so ``release_connection`` hands them back to the pool. The pool is created
    with ``config.create_pool()`` and stored on the config if it does not
    exist yet. This mirrors the async handler.
    """

    __slots__ = ("_config", "_ctx")

    def __init__(self, config: "CockroachPsycopgSyncConfig") -> None:
        self._config = config
        self._ctx: Any = None

    def acquire_connection(self) -> "CockroachSyncConnection":
        """Check a connection out of the (lazily created) pool."""
        pool = self._config.connection_instance
        if pool is None:
            pool = self._config.create_pool()
            self._config.connection_instance = pool
        # Using getconn() + close() here bypassed putconn(). By default, close()
        # drops the socket without returning the slot, so the pool leaked it.
        ctx = pool.connection()
        conn = ctx.__enter__()
        # Record the context only once it was entered, so release never exits
        # a context that failed to enter.
        self._ctx = ctx
        return cast("CockroachSyncConnection", conn)

    def release_connection(self, _conn: "CockroachSyncConnection") -> None:
        """Return the checked-out connection to the pool; calling it again is a no-op."""
        ctx, self._ctx = self._ctx, None
        if ctx is not None:
            ctx.__exit__(None, None, None)


class CockroachPsycopgSyncConfig(
    SyncDatabaseConfig[CockroachSyncConnection, ConnectionPool, CockroachPsycopgSyncDriver]
):
    """Configuration for CockroachDB synchronous connections using psycopg.

    Connections come from a ``psycopg_pool.ConnectionPool`` whose connection
    class is ``psycopg.crdb.CrdbConnection``. The pool is built on demand from
    ``connection_config`` unless an existing pool is supplied as
    ``connection_instance``.
    """

    driver_type: "ClassVar[type[CockroachPsycopgSyncDriver]]" = CockroachPsycopgSyncDriver
    connection_type: "ClassVar[type[CockroachSyncConnection]]" = CockroachSyncConnection
    supports_transactional_ddl: "ClassVar[bool]" = True
    supports_native_arrow_export: "ClassVar[bool]" = True
    supports_native_arrow_import: "ClassVar[bool]" = True
    supports_native_parquet_export: "ClassVar[bool]" = True
    supports_native_parquet_import: "ClassVar[bool]" = True

    def __init__(
        self,
        *,
        connection_config: "CockroachPsycopgPoolConfig | dict[str, Any] | None" = None,
        connection_instance: "ConnectionPool | None" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "CockroachPsycopgDriverFeatures | dict[str, Any] | None" = None,
        bind_key: "str | None" = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Normalize the inputs and pass them on to ``SyncDatabaseConfig``.

        Args:
            connection_config: Connection and pool parameters; normalized first.
            connection_instance: Existing pool to use instead of building one.
            migration_config: Forwarded unchanged.
            statement_config: Base statement configuration; a default is built when falsy.
            driver_features: Feature flags. ``enable_auto_retry`` defaults to ``True``;
                ``on_connection_create`` is removed and kept as the user connection hook.
            bind_key: Forwarded unchanged.
            extension_config: Forwarded unchanged.
            observability_config: Forwarded unchanged.
            **kwargs: Forwarded unchanged to the base class.
        """
        normalized_connection = normalize_connection_config(connection_config)
        if not statement_config:
            statement_config = build_statement_config()
        resolved_statement_config, features = apply_driver_features(statement_config, driver_features)
        features.setdefault("enable_auto_retry", True)
        # Built only for its side effects (presumably validation of the retry
        # settings); the returned object is not kept.
        CockroachPsycopgRetryConfig.from_features(features)

        stored_features: "dict[str, Any]" = dict(features) if features else {}
        # The hook lives on the config (it runs from the pool's configure
        # callback) instead of travelling with the driver features.
        self._user_connection_hook: "Callable[[CockroachSyncConnection], None] | None" = stored_features.pop(
            "on_connection_create", None
        )

        super().__init__(
            connection_config=normalized_connection,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=resolved_statement_config,
            driver_features=stored_features,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )

    def _create_pool(self) -> "ConnectionPool":
        """Build an opened ``ConnectionPool`` from ``connection_config``.

        Pool-level keys are split off, using the defaults listed below; a
        ``None`` value leaves the argument out. With a truthy ``conninfo`` the
        pool connects with that string alone. Otherwise the leftover keys,
        merged with ``kwargs``, become the connection keyword arguments.
        """
        leftover = dict(self.connection_config)
        pool_options: "dict[str, Any]" = {}
        for option, default in (
            ("min_size", 4),
            ("max_size", None),
            ("name", None),
            ("timeout", 30.0),
            ("max_waiting", 0),
            ("max_lifetime", 3600.0),
            ("max_idle", 600.0),
            ("reconnect_timeout", 300.0),
            ("num_workers", 3),
        ):
            chosen = leftover.pop(option, default)
            if chosen is not None:
                pool_options[option] = chosen
        configure = leftover.pop("configure", self._configure_connection)
        if configure is not None:
            pool_options["configure"] = configure

        conninfo = leftover.pop("conninfo", None)
        if not conninfo:
            extra_connect_kwargs = leftover.pop("kwargs", {})
            leftover.update(extra_connect_kwargs)
            return ConnectionPool(
                "", kwargs=leftover, open=True, connection_class=psycopg_crdb.CrdbConnection, **pool_options
            )
        return ConnectionPool(conninfo, open=True, connection_class=psycopg_crdb.CrdbConnection, **pool_options)

    def _configure_connection(self, conn: "CockroachSyncConnection") -> None:
        """Default pool ``configure`` callback: apply autocommit, then run the user hook."""
        autocommit = self.connection_config.get("autocommit")
        if autocommit is not None:
            conn.autocommit = autocommit
        hook = self._user_connection_hook
        if hook is not None:
            hook(conn)

    def _close_pool(self) -> None:
        """Close the pool, if any, and always forget it afterwards."""
        pool = self.connection_instance
        if not pool:
            return
        try:
            pool.close()
        finally:
            self.connection_instance = None

    def create_connection(self) -> "CockroachSyncConnection":
        """Check out a raw connection with ``getconn()``, creating the pool if needed.

        The caller owns the connection and should give it back with
        ``connection_instance.putconn(conn)``.
        """
        if self.connection_instance is None:
            self.connection_instance = self.create_pool()
        checked_out = self.connection_instance.getconn()
        return cast("CockroachSyncConnection", checked_out)

    def provide_connection(self, *args: Any, **kwargs: Any) -> "CockroachPsycopgSyncConnectionContext":
        """Return a context manager yielding a pooled connection."""
        return CockroachPsycopgSyncConnectionContext(self)

    def provide_session(
        self,
        *_args: Any,
        statement_config: "StatementConfig | None" = None,
        follower_reads: bool | None = None,
        staleness: str | None = None,
        **_kwargs: Any,
    ) -> "CockroachPsycopgSyncSessionContext":
        """Return a session context whose driver features can be tweaked per call.

        Args:
            statement_config: Overrides the configured statement config when truthy.
            follower_reads: When not ``None``, overrides ``enable_follower_reads``.
            staleness: When not ``None``, overrides ``default_staleness``.
        """
        session_features = dict(self.driver_features)
        if follower_reads is not None:
            session_features["enable_follower_reads"] = follower_reads
        if staleness is not None:
            session_features["default_staleness"] = staleness
        effective_statement_config = statement_config or self.statement_config or build_statement_config()
        connection_handler = _CockroachPsycopgSyncSessionConnectionHandler(self)
        return CockroachPsycopgSyncSessionContext(
            acquire_connection=connection_handler.acquire_connection,
            release_connection=connection_handler.release_connection,
            statement_config=effective_statement_config,
            driver_features=session_features,
            prepare_driver=self._prepare_driver,
        )

    def provide_pool(self, *args: Any, **kwargs: Any) -> "ConnectionPool":
        """Return the pool, creating it on first use."""
        if not self.connection_instance:
            self.connection_instance = self.create_pool()
        return self.connection_instance

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Extend the base signature namespace with this adapter's public types."""
        namespace = super().get_signature_namespace()
        namespace.update(
            CockroachPsycopgConnectionConfig=CockroachPsycopgConnectionConfig,
            CockroachPsycopgPoolConfig=CockroachPsycopgPoolConfig,
            CockroachSyncConnection=CockroachSyncConnection,
            CockroachPsycopgSyncDriver=CockroachPsycopgSyncDriver,
            CockroachPsycopgSyncExceptionHandler=CockroachPsycopgSyncExceptionHandler,
            CockroachPsycopgSyncSessionContext=CockroachPsycopgSyncSessionContext,
        )
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Event-runtime hints: 0.5s polling with ``SELECT ... FOR UPDATE SKIP LOCKED``."""
        return EventRuntimeHints(poll_interval=0.5, select_for_update=True, skip_locked=True)
class CockroachPsycopgAsyncConnectionContext: """Async context manager for CockroachDB psycopg connections.""" __slots__ = ("_config", "_ctx") def __init__(self, config: "CockroachPsycopgAsyncConfig") -> None: self._config = config self._ctx: Any = None async def __aenter__(self) -> "CockroachAsyncConnection": if self._config.connection_instance is None: self._config.connection_instance = await self._config.create_pool() if self._config.connection_instance: self._ctx = self._config.connection_instance.connection() return cast("CockroachAsyncConnection", await self._ctx.__aenter__()) msg = "Connection pool is not initialized" raise ImproperConfigurationError(msg) async def __aexit__( self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any ) -> bool | None: if self._ctx: return cast("bool | None", await self._ctx.__aexit__(exc_type, exc_val, exc_tb)) return None class _CockroachPsycopgAsyncSessionConnectionHandler: __slots__ = ("_config", "_ctx") def __init__(self, config: "CockroachPsycopgAsyncConfig") -> None: self._config = config self._ctx: Any = None async def acquire_connection(self) -> "CockroachAsyncConnection": pool = self._config.connection_instance if pool is None: pool = await self._config.create_pool() self._config.connection_instance = pool ctx = pool.connection() self._ctx = ctx return cast("CockroachAsyncConnection", await ctx.__aenter__()) async def release_connection(self, _conn: "CockroachAsyncConnection") -> None: if self._ctx is not None: await self._ctx.__aexit__(None, None, None) self._ctx = None
class CockroachPsycopgAsyncConfig(
    AsyncDatabaseConfig[CockroachAsyncConnection, AsyncConnectionPool, CockroachPsycopgAsyncDriver]
):
    """Configuration for CockroachDB async connections using psycopg.

    Connections come from a ``psycopg_pool.AsyncConnectionPool`` whose
    connection class is ``psycopg.crdb.AsyncCrdbConnection``. The pool is
    built on demand from ``connection_config`` unless an existing pool is
    supplied as ``connection_instance``.
    """

    driver_type: "ClassVar[type[CockroachPsycopgAsyncDriver]]" = CockroachPsycopgAsyncDriver
    connection_type: "ClassVar[type[CockroachAsyncConnection]]" = CockroachAsyncConnection
    supports_transactional_ddl: "ClassVar[bool]" = True
    supports_native_arrow_export: "ClassVar[bool]" = True
    supports_native_arrow_import: "ClassVar[bool]" = True
    supports_native_parquet_export: "ClassVar[bool]" = True
    supports_native_parquet_import: "ClassVar[bool]" = True

    def __init__(
        self,
        *,
        connection_config: "CockroachPsycopgPoolConfig | dict[str, Any] | None" = None,
        connection_instance: "AsyncConnectionPool | None" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "CockroachPsycopgDriverFeatures | dict[str, Any] | None" = None,
        bind_key: "str | None" = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Normalize the inputs and pass them on to ``AsyncDatabaseConfig``.

        Args:
            connection_config: Connection and pool parameters; normalized first.
            connection_instance: Existing pool to use instead of building one.
            migration_config: Forwarded unchanged.
            statement_config: Base statement configuration; a default is built when falsy.
            driver_features: Feature flags. ``enable_auto_retry`` defaults to ``True``;
                ``on_connection_create`` (an async callable) is removed and kept as
                the user connection hook.
            bind_key: Forwarded unchanged.
            extension_config: Forwarded unchanged.
            observability_config: Forwarded unchanged.
            **kwargs: Forwarded unchanged to the base class.
        """
        normalized_connection = normalize_connection_config(connection_config)
        if not statement_config:
            statement_config = build_statement_config()
        resolved_statement_config, features = apply_driver_features(statement_config, driver_features)
        features.setdefault("enable_auto_retry", True)
        # Built only for its side effects (presumably validation of the retry
        # settings); the returned object is not kept.
        CockroachPsycopgRetryConfig.from_features(features)

        stored_features: "dict[str, Any]" = dict(features) if features else {}
        # The hook lives on the config (it runs from the pool's configure
        # callback) instead of travelling with the driver features.
        self._user_connection_hook: "Callable[[CockroachAsyncConnection], Awaitable[None]] | None" = (
            stored_features.pop("on_connection_create", None)
        )

        super().__init__(
            connection_config=normalized_connection,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=resolved_statement_config,
            driver_features=stored_features,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )

    async def _create_pool(self) -> "AsyncConnectionPool":
        """Build and open an ``AsyncConnectionPool`` from ``connection_config``.

        Pool-level keys are split off, using the defaults listed below; a
        ``None`` value leaves the argument out. With a truthy ``conninfo`` the
        pool connects with that string alone. Otherwise the leftover keys,
        merged with ``kwargs``, become the connection keyword arguments. The
        pool is created closed and then awaited open.
        """
        leftover = dict(self.connection_config)
        pool_options: "dict[str, Any]" = {}
        for option, default in (
            ("min_size", 4),
            ("max_size", None),
            ("name", None),
            ("timeout", 30.0),
            ("max_waiting", 0),
            ("max_lifetime", 3600.0),
            ("max_idle", 600.0),
            ("reconnect_timeout", 300.0),
            ("num_workers", 3),
        ):
            chosen = leftover.pop(option, default)
            if chosen is not None:
                pool_options[option] = chosen
        configure = leftover.pop("configure", self._configure_async_connection)
        if configure is not None:
            pool_options["configure"] = configure

        conninfo = leftover.pop("conninfo", None)
        if not conninfo:
            extra_connect_kwargs = leftover.pop("kwargs", {})
            leftover.update(extra_connect_kwargs)
            pool = AsyncConnectionPool(
                "",
                kwargs=leftover,
                open=False,
                connection_class=psycopg_crdb.AsyncCrdbConnection,
                **pool_options,
            )
        else:
            pool = AsyncConnectionPool(
                conninfo, open=False, connection_class=psycopg_crdb.AsyncCrdbConnection, **pool_options
            )
        await pool.open()
        return cast("AsyncConnectionPool", pool)

    async def _configure_async_connection(self, conn: "CockroachAsyncConnection") -> None:
        """Default pool ``configure`` callback: apply autocommit, then await the user hook."""
        autocommit = self.connection_config.get("autocommit")
        if autocommit is not None:
            await conn.set_autocommit(autocommit)
        hook = self._user_connection_hook
        if hook is not None:
            await hook(conn)

    async def _close_pool(self) -> None:
        """Close the pool, if any, and always forget it afterwards."""
        pool = self.connection_instance
        if not pool:
            return
        try:
            await pool.close()
        finally:
            self.connection_instance = None

    async def create_connection(self) -> "CockroachAsyncConnection":
        """Check out a raw connection with ``getconn()``, creating the pool if needed.

        The caller owns the connection and should give it back with
        ``await connection_instance.putconn(conn)``.
        """
        if self.connection_instance is None:
            self.connection_instance = await self.create_pool()
        checked_out = await self.connection_instance.getconn()
        return cast("CockroachAsyncConnection", checked_out)

    def provide_connection(self, *args: Any, **kwargs: Any) -> "CockroachPsycopgAsyncConnectionContext":
        """Return an async context manager yielding a pooled connection."""
        return CockroachPsycopgAsyncConnectionContext(self)

    def provide_session(
        self,
        *_args: Any,
        statement_config: "StatementConfig | None" = None,
        follower_reads: bool | None = None,
        staleness: str | None = None,
        **_kwargs: Any,
    ) -> "CockroachPsycopgAsyncSessionContext":
        """Return a session context whose driver features can be tweaked per call.

        Args:
            statement_config: Overrides the configured statement config when truthy.
            follower_reads: When not ``None``, overrides ``enable_follower_reads``.
            staleness: When not ``None``, overrides ``default_staleness``.
        """
        session_features = dict(self.driver_features)
        if follower_reads is not None:
            session_features["enable_follower_reads"] = follower_reads
        if staleness is not None:
            session_features["default_staleness"] = staleness
        effective_statement_config = statement_config or self.statement_config or build_statement_config()
        connection_handler = _CockroachPsycopgAsyncSessionConnectionHandler(self)
        return CockroachPsycopgAsyncSessionContext(
            acquire_connection=connection_handler.acquire_connection,
            release_connection=connection_handler.release_connection,
            statement_config=effective_statement_config,
            driver_features=session_features,
            prepare_driver=self._prepare_driver,
        )

    async def provide_pool(self, *args: Any, **kwargs: Any) -> "AsyncConnectionPool":
        """Return the pool, creating it on first use."""
        if not self.connection_instance:
            self.connection_instance = await self.create_pool()
        return self.connection_instance

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Extend the base signature namespace with this adapter's public types."""
        namespace = super().get_signature_namespace()
        namespace.update(
            CockroachAsyncConnection=CockroachAsyncConnection,
            CockroachPsycopgAsyncDriver=CockroachPsycopgAsyncDriver,
            CockroachPsycopgAsyncExceptionHandler=CockroachPsycopgAsyncExceptionHandler,
            CockroachPsycopgAsyncSessionContext=CockroachPsycopgAsyncSessionContext,
            CockroachPsycopgConnectionConfig=CockroachPsycopgConnectionConfig,
            CockroachPsycopgPoolConfig=CockroachPsycopgPoolConfig,
        )
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Event-runtime hints: 0.5s polling with ``SELECT ... FOR UPDATE SKIP LOCKED``."""
        return EventRuntimeHints(poll_interval=0.5, select_for_update=True, skip_locked=True)