# Source code for the sqlspec.adapters.adbc.config module.

"""ADBC database configuration."""

from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

from typing_extensions import NotRequired

from sqlspec.adapters.adbc._typing import AdbcConnection
from sqlspec.adapters.adbc.core import (
    apply_driver_features,
    build_connection_config,
    detect_postgres_extensions,
    get_statement_config,
    is_postgres_dialect,
    resolve_dialect_from_config,
    resolve_driver_connect_func,
)
from sqlspec.adapters.adbc.driver import AdbcCursor, AdbcDriver, AdbcExceptionHandler, AdbcSessionContext
from sqlspec.config import ExtensionConfigs, NoPoolSyncConfig
from sqlspec.core import StatementConfig
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.utils.config_tools import normalize_connection_config

if TYPE_CHECKING:
    from collections.abc import Callable

    from sqlspec.observability import ObservabilityConfig


# Public names exported by ``from sqlspec.adapters.adbc.config import *``.
# NOTE(review): AdbcConnectionContext has a public name and is registered in
# AdbcConfig.get_signature_namespace, but is not exported here — confirm intentional.
__all__ = ("AdbcConfig", "AdbcConnectionParams", "AdbcDriverFeatures")


class AdbcConnectionParams(TypedDict):
    """ADBC connection parameters.

    Every key is optional (``NotRequired``). ``AdbcConfig`` normalizes the
    supplied mapping with ``normalize_connection_config``. When connecting, it
    reads ``driver_name`` and ``uri`` to resolve the driver's connect function.
    It then passes the whole mapping to ``build_connection_config`` to build
    that function's keyword arguments. How each remaining key reaches the
    driver is decided by that helper (not visible here).
    """

    # Driver selection: both values are handed to resolve_driver_connect_func
    # by AdbcConfig.create_connection; driver_name also appears in its error message.
    uri: NotRequired[str]
    driver_name: NotRequired[str]
    # Free-form option dicts; presumably forwarded to the driver's
    # database/connection objects — TODO confirm in build_connection_config.
    db_kwargs: NotRequired[dict[str, Any]]
    conn_kwargs: NotRequired[dict[str, Any]]
    adbc_driver_manager_entrypoint: NotRequired[str]
    # Session / execution settings (driver-specific; support varies by backend — verify).
    autocommit: NotRequired[bool]
    isolation_level: NotRequired[str]
    batch_size: NotRequired[int]
    query_timeout: NotRequired[float]
    connection_timeout: NotRequired[float]
    # Transport-security settings (names suggest SSL/TLS; semantics are driver-specific).
    ssl_mode: NotRequired[str]
    ssl_cert: NotRequired[str]
    ssl_key: NotRequired[str]
    ssl_ca: NotRequired[str]
    # Credentials.
    username: NotRequired[str]
    password: NotRequired[str]
    token: NotRequired[str]
    # Backend-specific identifiers; presumably only meaningful to the matching
    # driver (e.g. project/dataset vs account/warehouse) — confirm per driver.
    project_id: NotRequired[str]
    dataset_id: NotRequired[str]
    account: NotRequired[str]
    warehouse: NotRequired[str]
    database: NotRequired[str]
    schema: NotRequired[str]
    role: NotRequired[str]
    # Flight SQL–style options (inferred from names only — TODO confirm).
    authorization_header: NotRequired[str]
    grpc_options: NotRequired[dict[str, Any]]
    gizmosql_backend: NotRequired[str]
    tls_skip_verify: NotRequired[bool]
    # Catch-all for additional driver options.
    extra: NotRequired[dict[str, Any]]


class AdbcDriverFeatures(TypedDict):
    """ADBC driver feature configuration.

    Controls optional type handling and serialization behavior for the ADBC adapter.
    These features configure how data is converted between Python and Arrow types.

    All keys are optional. ``AdbcConfig.__init__`` passes the supplied mapping
    through ``apply_driver_features`` together with the statement configuration
    before storing it.

    NOTE(review): ``AdbcConfig._detect_extensions_if_needed`` reads
    ``enable_pgvector``/``enable_paradedb`` with a ``False`` fallback. The
    "Defaults to True" statements below therefore hold only if
    ``apply_driver_features`` populates those keys — confirm there.

    Attributes:
        json_serializer: JSON serialization function to use.
            Callable that takes Any and returns str (JSON string).
            Default: sqlspec.utils.serializers.to_json
        enable_cast_detection: Enable cast-aware parameter processing.
            When True, detects SQL casts (e.g., ::JSONB) and applies appropriate serialization.
            Currently used for PostgreSQL JSONB handling.
            Default: True
        enable_strict_type_coercion: Enforce strict type coercion rules.
            When True, raises errors for unsupported type conversions.
            When False, attempts best-effort conversion.
            Default: False
        strict_type_coercion: Alias for enable_strict_type_coercion.
        enable_arrow_extension_types: Enable PyArrow extension type support.
            When True, preserves Arrow extension type metadata when reading data.
            When False, falls back to storage types.
            Default: True
        arrow_extension_types: Alias for enable_arrow_extension_types.
        enable_pgvector: Enable automatic pgvector extension detection.
            When True and the resolved dialect is PostgreSQL, queries ``pg_extension``
            on the first connection to check for the ``vector`` extension.
            Defaults to True when the ``pgvector`` Python package is installed.
        enable_paradedb: Enable ParadeDB (pg_search) extension detection.
            When True and the resolved dialect is PostgreSQL, queries ``pg_extension``
            on the first connection to check for the ``pg_search`` extension.
            Defaults to True. Independent of enable_pgvector.
        enable_events: Enable database event channel support.
            Defaults to True when extension_config["events"] is configured.
            Provides pub/sub capabilities via table-backed queue (ADBC has no native pub/sub).
            Requires extension_config["events"] for migration setup.
        events_backend: Event channel backend selection.
            Only option: "table_queue" (durable table-backed queue with retries and exactly-once delivery).
            ADBC does not have native pub/sub, so table_queue is the only backend.
            Defaults to "table_queue".
    """

    json_serializer: "NotRequired[Callable[[Any], str]]"
    enable_cast_detection: NotRequired[bool]
    enable_strict_type_coercion: NotRequired[bool]
    strict_type_coercion: NotRequired[bool]
    enable_arrow_extension_types: NotRequired[bool]
    arrow_extension_types: NotRequired[bool]
    enable_pgvector: NotRequired[bool]
    enable_paradedb: NotRequired[bool]
    enable_events: NotRequired[bool]
    events_backend: NotRequired[str]
class AdbcConnectionContext: """Context manager for ADBC connections.""" __slots__ = ("_config", "_connection") def __init__(self, config: "AdbcConfig") -> None: self._config = config self._connection: AdbcConnection | None = None def __enter__(self) -> "AdbcConnection": self._connection = self._config.create_connection() return self._connection def __exit__( self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any ) -> bool | None: if self._connection: self._connection.close() self._connection = None return None class _AdbcSessionConnectionHandler: __slots__ = ("_config", "_connection") def __init__(self, config: "AdbcConfig") -> None: self._config = config self._connection: AdbcConnection | None = None def acquire_connection(self) -> "AdbcConnection": self._connection = self._config.create_connection() return self._connection def release_connection(self, _conn: "AdbcConnection") -> None: if self._connection is None: return self._connection.close() self._connection = None
class AdbcConfig(NoPoolSyncConfig[AdbcConnection, AdbcDriver]):
    """Arrow Database Connectivity (ADBC) adapter configuration.

    This is a no-pool synchronous configuration. Each connection is opened on
    demand through :meth:`create_connection` and closed when the context that
    requested it exits. ``driver_name``/``uri`` in the connection configuration
    select the driver's connect function, so one class can target several
    backends (PostgreSQL, SQLite, DuckDB, BigQuery, Snowflake, ...) with
    Arrow-native data transfer.
    """

    driver_type: ClassVar[type[AdbcDriver]] = AdbcDriver
    connection_type: "ClassVar[type[AdbcConnection]]" = AdbcConnection
    supports_transactional_ddl: ClassVar[bool] = False
    supports_native_arrow_export: "ClassVar[bool]" = True
    supports_native_arrow_import: "ClassVar[bool]" = True
    supports_native_parquet_export: "ClassVar[bool]" = True
    supports_native_parquet_import: "ClassVar[bool]" = True
    storage_partition_strategies: "ClassVar[tuple[str, ...]]" = ("fixed", "rows_per_chunk")

    def __init__(
        self,
        *,
        connection_config: "AdbcConnectionParams | dict[str, Any] | None" = None,
        connection_instance: "Any" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: StatementConfig | None = None,
        driver_features: "AdbcDriverFeatures | dict[str, Any] | None" = None,
        bind_key: str | None = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Build an ADBC configuration.

        Args:
            connection_config: Connection parameters; normalized before being stored.
            connection_instance: Optional pre-built connection, handed to the base configuration.
            migration_config: Migration settings.
            statement_config: Default statement configuration. When omitted, one is
                derived from the dialect resolved from the connection configuration.
            driver_features: Optional ADBC feature flags (see ``AdbcDriverFeatures``).
                Together with the statement configuration, these are passed through
                ``apply_driver_features`` before reaching the base class.
            bind_key: Optional unique identifier for this configuration.
            extension_config: Extension-specific settings (e.g. Litestar plugin options).
            observability_config: Adapter-level observability overrides.
            **kwargs: Forwarded unchanged to the base configuration.
        """
        self.connection_config = normalize_connection_config(connection_config)
        # ``None`` means "extensions not probed yet"; see _detect_extensions_if_needed.
        self._pgvector_available: bool | None = None
        self._paradedb_available: bool | None = None

        if statement_config is not None:
            base_statement_config = statement_config
        else:
            detected_dialect = resolve_dialect_from_config(self.connection_config)
            base_statement_config = get_statement_config(detected_dialect)

        effective_statement_config, effective_features = apply_driver_features(
            base_statement_config, driver_features
        )

        super().__init__(
            connection_config=self.connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=effective_statement_config,
            driver_features=effective_features,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )

    def create_connection(self) -> AdbcConnection:
        """Open a brand-new ADBC connection.

        ``driver_name`` and ``uri`` choose the driver's connect function. The
        whole connection configuration is then translated into that function's
        keyword arguments.

        Returns:
            The newly opened connection.

        Raises:
            ImproperConfigurationError: Wraps any exception raised while resolving
                the driver or connecting.
        """
        try:
            params = self.connection_config
            connect = resolve_driver_connect_func(params.get("driver_name"), params.get("uri"))
            raw_connection = connect(**build_connection_config(params))
        except Exception as e:
            driver_name = self.connection_config.get("driver_name", "Unknown")
            msg = f"Could not configure connection using driver '{driver_name}'. Error: {e}"
            raise ImproperConfigurationError(msg) from e
        return cast("AdbcConnection", raw_connection)

    def _update_dialect_for_extensions(self) -> None:
        """Switch the statement dialect according to the detected extensions.

        This only acts when the current dialect is exactly ``postgres``.
        ParadeDB wins over pgvector. With neither detected, the dialect is left alone.
        """
        if getattr(self.statement_config, "dialect", "postgres") != "postgres":
            return
        if self._paradedb_available:
            target_dialect = "paradedb"
        elif self._pgvector_available:
            target_dialect = "pgvector"
        else:
            return
        self.statement_config = self.statement_config.replace(dialect=target_dialect)

    def _detect_extensions_if_needed(self) -> None:
        """Probe PostgreSQL extensions once and cache the outcome.

        Once ``_pgvector_available`` is set, this is a no-op. For non-PostgreSQL
        dialects, both flags are set to ``False`` without connecting. Otherwise a
        throw-away connection is opened, probed and closed, and the dialect is
        then updated.
        """
        if self._pgvector_available is not None:
            return

        if not is_postgres_dialect(getattr(self.statement_config, "dialect", "")):
            self._pgvector_available = self._paradedb_available = False
            return

        probe = self.create_connection()
        try:
            features = self.driver_features
            pgvector_found, paradedb_found = detect_postgres_extensions(
                probe,
                enable_pgvector=features.get("enable_pgvector", False),
                enable_paradedb=features.get("enable_paradedb", False),
            )
            self._pgvector_available = pgvector_found
            self._paradedb_available = paradedb_found
        finally:
            probe.close()

        self._update_dialect_for_extensions()

    def provide_connection(self, *args: Any, **kwargs: Any) -> "AdbcConnectionContext":
        """Return a context manager that yields a new connection.

        Args:
            *args: Accepted for interface compatibility; ignored.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            An ``AdbcConnectionContext`` bound to this configuration.
        """
        context = AdbcConnectionContext(self)
        return context

    def provide_session(
        self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any
    ) -> "AdbcSessionContext":
        """Return a context manager that yields an ``AdbcDriver`` session.

        The first call on a PostgreSQL backend probes for pgvector/ParadeDB and may
        switch the configured dialect.

        Args:
            *_args: Ignored.
            statement_config: Per-session statement configuration override.
            **_kwargs: Ignored.

        Returns:
            An ``AdbcSessionContext`` whose connections come from a
            ``_AdbcSessionConnectionHandler``.
        """
        self._detect_extensions_if_needed()

        effective_config = (
            statement_config
            or self.statement_config
            or get_statement_config(resolve_dialect_from_config(self.connection_config))
        )

        session_handler = _AdbcSessionConnectionHandler(self)
        return AdbcSessionContext(
            acquire_connection=session_handler.acquire_connection,
            release_connection=session_handler.release_connection,
            statement_config=effective_config,
            driver_features=self.driver_features,
            prepare_driver=self._prepare_driver,
        )

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Extend the base signature namespace with the ADBC adapter's types.

        Returns:
            Mapping of type names to type objects.
        """
        namespace = super().get_signature_namespace()
        adbc_types: dict[str, Any] = {
            "AdbcConnectionContext": AdbcConnectionContext,
            "AdbcConnection": AdbcConnection,
            "AdbcConnectionParams": AdbcConnectionParams,
            "AdbcCursor": AdbcCursor,
            "AdbcDriver": AdbcDriver,
            "AdbcDriverFeatures": AdbcDriverFeatures,
            "AdbcExceptionHandler": AdbcExceptionHandler,
            "AdbcSessionContext": AdbcSessionContext,
        }
        namespace.update(adbc_types)
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Event-queue defaults for ADBC warehouses.

        The poll interval is 2.0 (presumably seconds), leases last 60 seconds,
        and events are retained for two days.
        """
        two_days_in_seconds = 2 * 24 * 60 * 60
        return EventRuntimeHints(poll_interval=2.0, lease_seconds=60, retention_seconds=two_days_in_seconds)