# Source code for sqlspec.adapters.cockroach_asyncpg.config

"""CockroachDB AsyncPG configuration."""

from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

from asyncpg import create_pool as asyncpg_create_pool
from typing_extensions import NotRequired

from sqlspec.adapters.asyncpg.core import apply_driver_features, build_connection_config, default_statement_config
from sqlspec.adapters.cockroach_asyncpg._typing import (
    CockroachAsyncpgConnection,
    CockroachAsyncpgPool,
    CockroachAsyncpgSessionContext,
)
from sqlspec.adapters.cockroach_asyncpg.core import CockroachAsyncpgRetryConfig
from sqlspec.adapters.cockroach_asyncpg.driver import CockroachAsyncpgDriver, CockroachAsyncpgExceptionHandler
from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.utils.config_tools import normalize_connection_config

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

    from sqlspec.core import StatementConfig
    from sqlspec.observability import ObservabilityConfig

# Public names re-exported by ``from sqlspec.adapters.cockroach_asyncpg.config import *``.
# NOTE(review): CockroachAsyncpgConnectionContext (returned by
# CockroachAsyncpgConfig.provide_connection) has no leading underscore but is not
# listed here — confirm whether that omission is intentional.
__all__ = (
    "CockroachAsyncpgConfig",
    "CockroachAsyncpgConnectionConfig",
    "CockroachAsyncpgDriverFeatures",
    "CockroachAsyncpgPoolConfig",
)


class CockroachAsyncpgConnectionConfig(TypedDict):
    """AsyncPG connection parameters for CockroachDB.

    Every key is optional (``NotRequired``). ``CockroachAsyncpgConfig`` normalizes
    this mapping with ``normalize_connection_config`` and converts it into pool
    keyword arguments with ``build_connection_config``; presumably most keys are
    forwarded to asyncpg under the same name — confirm against that helper.
    """

    dsn: NotRequired[str]  # connection URI; presumably asyncpg's ``dsn`` argument
    host: NotRequired[str]  # server host name
    port: NotRequired[int]  # server port
    user: NotRequired[str]  # login role name
    password: NotRequired[str]  # login password
    database: NotRequired[str]  # database to connect to
    ssl: NotRequired[Any]  # typed Any: accepted forms (mode string / SSLContext / bool) TODO confirm
    passfile: NotRequired[str]  # path to a password file; presumably asyncpg's ``passfile``
    direct_tls: NotRequired[bool]  # presumably asyncpg's ``direct_tls`` flag
    # NOTE(review): asyncpg.connect() calls its connect timeout ``timeout``; presumably
    # build_connection_config maps this key — confirm.
    connect_timeout: NotRequired[float]
    command_timeout: NotRequired[float]  # presumably asyncpg's default per-command timeout
    statement_cache_size: NotRequired[int]  # prepared-statement cache size (asyncpg option, assumed)
    max_cached_statement_lifetime: NotRequired[int]  # asyncpg statement-cache option (assumed)
    max_cacheable_statement_size: NotRequired[int]  # asyncpg statement-cache option (assumed)
    server_settings: NotRequired["dict[str, str]"]  # server setting name -> value sent on connect (assumed)


class CockroachAsyncpgPoolConfig(CockroachAsyncpgConnectionConfig):
    """AsyncPG pool parameters for CockroachDB.

    Extends the connection parameters with pool-level options. Every key is
    optional (``NotRequired``). Presumably these are forwarded to
    ``asyncpg.create_pool`` by ``build_connection_config`` — confirm there.
    """

    min_size: NotRequired[int]  # lower bound on pooled connections (assumed asyncpg semantics)
    max_size: NotRequired[int]  # upper bound on pooled connections (assumed asyncpg semantics)
    max_queries: NotRequired[int]  # queries before a pooled connection is replaced (assumed)
    max_inactive_connection_lifetime: NotRequired[float]  # idle lifetime of a pooled connection (assumed)
    # Async callback run by the pool around acquisition (asyncpg ``setup``, assumed).
    setup: NotRequired["Callable[[CockroachAsyncpgConnection], Awaitable[None]]"]
    # Async callback for each new pool connection. CockroachAsyncpgConfig installs its
    # own ``init`` (which runs ``on_connection_create``) when this key is not supplied.
    init: NotRequired["Callable[[CockroachAsyncpgConnection], Awaitable[None]]"]
    # NOTE(review): not an asyncpg.create_pool keyword; presumably merged into the
    # pool arguments by build_connection_config — confirm.
    extra: NotRequired["dict[str, Any]"]
class CockroachAsyncpgDriverFeatures(TypedDict): """Driver feature flags for CockroachDB AsyncPG adapter. on_connection_create: Async callback executed when a connection is acquired from pool. Receives the raw asyncpg connection for low-level driver configuration. Called after internal setup (JSON codecs, pgvector registration). """ enable_auto_retry: NotRequired[bool] max_retries: NotRequired[int] retry_delay_base_ms: NotRequired[float] retry_delay_max_ms: NotRequired[float] enable_retry_logging: NotRequired[bool] enable_follower_reads: NotRequired[bool] default_staleness: NotRequired[str] json_serializer: NotRequired["Callable[[Any], str]"] json_deserializer: NotRequired["Callable[[str], Any]"] enable_json_codecs: NotRequired[bool] enable_pgvector: NotRequired[bool] on_connection_create: "NotRequired[Callable[[CockroachAsyncpgConnection], Awaitable[None]]]" enable_events: NotRequired[bool] events_backend: NotRequired[str] class _CockroachAsyncpgSessionFactory: __slots__ = ("_config", "_ctx") def __init__(self, config: "CockroachAsyncpgConfig") -> None: self._config = config self._ctx: Any | None = None async def acquire_connection(self) -> "CockroachAsyncpgConnection": pool = self._config.connection_instance if pool is None: pool = await self._config.create_pool() self._config.connection_instance = pool ctx = pool.acquire() self._ctx = ctx return cast("CockroachAsyncpgConnection", await ctx.__aenter__()) async def release_connection(self, _conn: "CockroachAsyncpgConnection") -> None: if self._ctx is not None: await self._ctx.__aexit__(None, None, None) self._ctx = None class CockroachAsyncpgConnectionContext: """Async context manager for CockroachDB AsyncPG connections.""" __slots__ = ("_config", "_connection") def __init__(self, config: "CockroachAsyncpgConfig") -> None: self._config = config self._connection: CockroachAsyncpgConnection | None = None async def __aenter__(self) -> "CockroachAsyncpgConnection": pool = self._config.connection_instance if pool is 
None: pool = await self._config.create_pool() self._config.connection_instance = pool self._connection = await pool.acquire() return self._connection async def __aexit__( self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any ) -> bool | None: if self._connection is not None: if self._config.connection_instance: await self._config.connection_instance.release(self._connection) # type: ignore[arg-type] self._connection = None return None
class CockroachAsyncpgConfig(
    AsyncDatabaseConfig[CockroachAsyncpgConnection, CockroachAsyncpgPool, CockroachAsyncpgDriver]
):
    """Configuration for CockroachDB using AsyncPG.

    Holds connection/pool settings, statement configuration and driver features,
    and lazily creates one asyncpg pool (stored on ``connection_instance``) from
    which connections and driver sessions are drawn.
    """

    driver_type: "ClassVar[type[CockroachAsyncpgDriver]]" = CockroachAsyncpgDriver
    connection_type: "ClassVar[type[CockroachAsyncpgConnection]]" = CockroachAsyncpgConnection  # type: ignore[assignment]
    supports_transactional_ddl: "ClassVar[bool]" = True
    supports_native_arrow_export: "ClassVar[bool]" = True
    supports_native_arrow_import: "ClassVar[bool]" = True
    supports_native_parquet_export: "ClassVar[bool]" = True
    supports_native_parquet_import: "ClassVar[bool]" = True

    def __init__(
        self,
        *,
        connection_config: "CockroachAsyncpgPoolConfig | dict[str, Any] | None" = None,
        connection_instance: "CockroachAsyncpgPool | None" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "CockroachAsyncpgDriverFeatures | dict[str, Any] | None" = None,
        bind_key: "str | None" = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the CockroachDB AsyncPG configuration.

        Args:
            connection_config: Connection and pool parameters; normalized before storing.
            connection_instance: Existing pool to use instead of creating one lazily.
            migration_config: Migration settings forwarded to the base config.
            statement_config: Statement configuration; falls back to the AsyncPG
                adapter's ``default_statement_config`` and may be adjusted by
                ``apply_driver_features``.
            driver_features: Driver feature flags. ``enable_auto_retry`` defaults to
                True; ``on_connection_create`` is removed from the stored features and
                run from the pool's ``init`` callback instead.
            bind_key: Optional key identifying this configuration.
            extension_config: Extension settings forwarded to the base config.
            observability_config: Observability settings forwarded to the base config.
            **kwargs: Forwarded to ``AsyncDatabaseConfig.__init__``.
        """
        connection_config = normalize_connection_config(connection_config)
        statement_config = statement_config or default_statement_config
        statement_config, driver_features = apply_driver_features(statement_config, driver_features)
        driver_features.setdefault("enable_auto_retry", True)
        # The result is discarded; presumably built only to fail fast on invalid
        # retry settings.
        _ = CockroachAsyncpgRetryConfig.from_features(driver_features)

        # Extract user connection hook before storing driver_features
        features_dict = dict(driver_features) if driver_features else {}
        self._user_connection_hook: Callable[[CockroachAsyncpgConnection], Awaitable[None]] | None = features_dict.pop(
            "on_connection_create", None
        )
        super().__init__(
            connection_config=connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=statement_config,
            driver_features=features_dict,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )

    async def _create_pool(self) -> "CockroachAsyncpgPool":
        """Create the asyncpg pool from ``connection_config``.

        The pool's ``init`` callback runs the ``on_connection_create`` hook. When the
        caller also supplied an ``init`` in ``connection_config``, both run: the
        caller's ``init`` first, then the hook.
        """
        config = build_connection_config(self.connection_config)
        user_init = config.get("init")
        if user_init is None:
            # Covers both a missing key and an explicit None value.
            config["init"] = self._init_connection
        elif self._user_connection_hook is not None:
            # Chain instead of letting one callback silently replace the other.
            run_hook = self._init_connection

            async def _chained_init(connection: "CockroachAsyncpgConnection") -> None:
                await user_init(connection)
                await run_hook(connection)

            config["init"] = _chained_init
        return await asyncpg_create_pool(**config)

    async def _init_connection(self, connection: "CockroachAsyncpgConnection") -> None:
        """Initialize connection with user callback if provided."""
        if self._user_connection_hook is not None:
            await self._user_connection_hook(connection)

    async def _close_pool(self) -> None:
        """Close the pool, if any, and clear ``connection_instance``."""
        if not self.connection_instance:
            return
        await self.connection_instance.close()
        self.connection_instance = None

    async def create_connection(self) -> "CockroachAsyncpgConnection":
        """Acquire a connection from the pool, creating the pool on first use.

        The caller is responsible for releasing the returned connection.
        """
        if self.connection_instance is None:
            self.connection_instance = await self.create_pool()
        return cast("CockroachAsyncpgConnection", await self.connection_instance.acquire())

    def provide_connection(self, *args: Any, **kwargs: Any) -> "CockroachAsyncpgConnectionContext":
        """Return an async context manager yielding a pooled connection (arguments are ignored)."""
        return CockroachAsyncpgConnectionContext(self)

    def provide_session(
        self,
        *_args: Any,
        statement_config: "StatementConfig | None" = None,
        follower_reads: bool | None = None,
        staleness: str | None = None,
        **_kwargs: Any,
    ) -> "CockroachAsyncpgSessionContext":
        """Return a session context that acquires a pooled connection for a driver.

        Args:
            statement_config: Overrides the configured statement config for this session.
            follower_reads: When not None, overrides ``enable_follower_reads`` for this session.
            staleness: When not None, overrides ``default_staleness`` for this session.
        """
        factory = _CockroachAsyncpgSessionFactory(self)
        # Copy so per-session overrides never mutate the shared feature dict.
        driver_features = dict(self.driver_features)
        if follower_reads is not None:
            driver_features["enable_follower_reads"] = follower_reads
        if staleness is not None:
            driver_features["default_staleness"] = staleness
        return CockroachAsyncpgSessionContext(
            acquire_connection=factory.acquire_connection,
            release_connection=factory.release_connection,
            statement_config=statement_config or self.statement_config or default_statement_config,
            driver_features=driver_features,
            prepare_driver=self._prepare_driver,
        )

    async def provide_pool(self, *args: Any, **kwargs: Any) -> "CockroachAsyncpgPool":
        """Return the pool, creating it on first use (arguments are ignored)."""
        if not self.connection_instance:
            self.connection_instance = await self.create_pool()
        return self.connection_instance

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Extend the base signature namespace with this adapter's public types."""
        namespace = super().get_signature_namespace()
        namespace.update({
            "CockroachAsyncpgConnectionConfig": CockroachAsyncpgConnectionConfig,
            "CockroachAsyncpgPoolConfig": CockroachAsyncpgPoolConfig,
            "CockroachAsyncpgDriver": CockroachAsyncpgDriver,
            "CockroachAsyncpgExceptionHandler": CockroachAsyncpgExceptionHandler,
            "CockroachAsyncpgSessionContext": CockroachAsyncpgSessionContext,
        })
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return event-runtime hints: poll interval 0.5 (presumably seconds), row locking with SKIP LOCKED."""
        return EventRuntimeHints(poll_interval=0.5, select_for_update=True, skip_locked=True)