# Source code for sqlspec.adapters.adbc.config

"""ADBC database configuration."""

from collections.abc import Callable
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

from typing_extensions import NotRequired

from sqlspec.adapters.adbc._types import AdbcConnection
from sqlspec.adapters.adbc.driver import AdbcCursor, AdbcDriver, AdbcExceptionHandler, get_adbc_statement_config
from sqlspec.config import ExtensionConfigs, NoPoolSyncConfig
from sqlspec.core import StatementConfig
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.events._hints import EventRuntimeHints
from sqlspec.utils.config_normalization import normalize_connection_config
from sqlspec.utils.logging import get_logger
from sqlspec.utils.module_loader import import_string
from sqlspec.utils.serializers import to_json

if TYPE_CHECKING:
    from collections.abc import Generator
    from contextlib import AbstractContextManager

    from sqlglot.dialects.dialect import DialectType

    from sqlspec.observability import ObservabilityConfig

# Module logger; AdbcConfig._get_parameter_styles uses it for debug output.
logger = get_logger("adapters.adbc")

# Short driver names mapped to the dotted import path of the matching driver's
# ``dbapi.connect`` function. Lookups lower-case the user-supplied name first,
# so keys here must stay lower-case.
_DRIVER_ALIASES: dict[str, str] = {
    "sqlite": "adbc_driver_sqlite.dbapi.connect",
    "sqlite3": "adbc_driver_sqlite.dbapi.connect",
    "duckdb": "adbc_driver_duckdb.dbapi.connect",
    "postgres": "adbc_driver_postgresql.dbapi.connect",
    "postgresql": "adbc_driver_postgresql.dbapi.connect",
    "pg": "adbc_driver_postgresql.dbapi.connect",
    "snowflake": "adbc_driver_snowflake.dbapi.connect",
    "sf": "adbc_driver_snowflake.dbapi.connect",
    "bigquery": "adbc_driver_bigquery.dbapi.connect",
    "bq": "adbc_driver_bigquery.dbapi.connect",
    "flightsql": "adbc_driver_flightsql.dbapi.connect",
    "grpc": "adbc_driver_flightsql.dbapi.connect",
}

# (URI prefix, connect-function path) pairs, tested in order with
# ``str.startswith``. Consulted only when no ``driver_name`` is configured.
_URI_PREFIX_DRIVER: tuple[tuple[str, str], ...] = (
    ("postgresql://", "adbc_driver_postgresql.dbapi.connect"),
    ("postgres://", "adbc_driver_postgresql.dbapi.connect"),
    ("sqlite://", "adbc_driver_sqlite.dbapi.connect"),
    ("duckdb://", "adbc_driver_duckdb.dbapi.connect"),
    ("grpc://", "adbc_driver_flightsql.dbapi.connect"),
    ("snowflake://", "adbc_driver_snowflake.dbapi.connect"),
    ("bigquery://", "adbc_driver_bigquery.dbapi.connect"),
)

# (keyword, dialect) pairs. The first keyword found as a substring of the
# resolved driver path selects the SQL dialect; the keyword itself doubles as
# the canonical "driver kind" returned by _driver_kind_from_driver_name.
# Flight SQL / gRPC drivers are mapped to the "sqlite" dialect here.
_DRIVER_PATH_KEYWORDS_TO_DIALECT: tuple[tuple[str, str], ...] = (
    ("postgresql", "postgres"),
    ("sqlite", "sqlite"),
    ("duckdb", "duckdb"),
    ("bigquery", "bigquery"),
    ("snowflake", "snowflake"),
    ("flightsql", "sqlite"),
    ("grpc", "sqlite"),
)

# (keyword, (supported parameter styles, default parameter style)) pairs,
# matched by substring against the resolved driver path. Drivers with no entry
# fall back to (("qmark",), "qmark") in AdbcConfig._get_parameter_styles.
_PARAMETER_STYLES_BY_KEYWORD: tuple[tuple[str, tuple[tuple[str, ...], str]], ...] = (
    ("postgresql", (("numeric",), "numeric")),
    ("sqlite", (("qmark", "named_colon"), "qmark")),
    ("duckdb", (("qmark", "numeric"), "qmark")),
    ("bigquery", (("named_at",), "named_at")),
    ("snowflake", (("qmark", "numeric"), "qmark")),
)

# Top-level connection keys that AdbcConfig._get_connection_config_dict moves
# into ``db_kwargs`` when the configured driver is BigQuery.
_BIGQUERY_DB_KWARGS_FIELDS: tuple[str, ...] = ("project_id", "dataset_id", "token")


class AdbcConnectionParams(TypedDict):
    """ADBC connection parameters.

    Every key is optional. ``AdbcConfig`` itself interprets ``driver_name``,
    ``uri`` and ``db_kwargs`` (plus ``project_id`` / ``dataset_id`` / ``token``
    for BigQuery); ``driver_name`` is removed and the remaining entries are
    passed as keyword arguments to the resolved driver connect function.
    """

    # Connection URI. When ``driver_name`` is absent, its scheme prefix selects
    # the driver. ``sqlite://`` is stripped from the value and ``duckdb://``
    # URIs are converted into a ``path`` argument before connecting.
    uri: NotRequired[str]
    # Either a short alias (e.g. "pg", "bq") or a dotted import path to the
    # driver package / its ``dbapi.connect`` function.
    driver_name: NotRequired[str]
    # With a BigQuery ``driver_name`` this dict is kept and receives
    # project_id/dataset_id/token; with any other ``driver_name`` its entries
    # are flattened into the top-level connect keyword arguments.
    db_kwargs: NotRequired[dict[str, Any]]
    # NOTE(review): the keys from here down are not interpreted in this module
    # (except token/project_id/dataset_id for BigQuery, see above); presumably
    # they are forwarded to the driver as-is - confirm per driver.
    conn_kwargs: NotRequired[dict[str, Any]]
    adbc_driver_manager_entrypoint: NotRequired[str]
    autocommit: NotRequired[bool]
    isolation_level: NotRequired[str]
    batch_size: NotRequired[int]
    query_timeout: NotRequired[float]
    connection_timeout: NotRequired[float]
    ssl_mode: NotRequired[str]
    ssl_cert: NotRequired[str]
    ssl_key: NotRequired[str]
    ssl_ca: NotRequired[str]
    username: NotRequired[str]
    password: NotRequired[str]
    token: NotRequired[str]
    project_id: NotRequired[str]
    dataset_id: NotRequired[str]
    account: NotRequired[str]
    warehouse: NotRequired[str]
    database: NotRequired[str]
    schema: NotRequired[str]
    role: NotRequired[str]
    authorization_header: NotRequired[str]
    grpc_options: NotRequired[dict[str, Any]]
    # NOTE(review): presumably merged into the top level by
    # normalize_connection_config - confirm against that helper.
    extra: NotRequired[dict[str, Any]]


class AdbcDriverFeatures(TypedDict):
    """ADBC driver feature configuration.

    Controls optional type handling and serialization behavior for the ADBC adapter.
    These features configure how data is converted between Python and Arrow types.

    Attributes:
        json_serializer: JSON serialization function to use.
            Callable that takes Any and returns str (JSON string).
            Default: sqlspec.utils.serializers.to_json
        enable_cast_detection: Enable cast-aware parameter processing.
            When True, detects SQL casts (e.g., ::JSONB) and applies appropriate
            serialization. Currently used for PostgreSQL JSONB handling.
            Default: True
        strict_type_coercion: Enforce strict type coercion rules.
            When True, raises errors for unsupported type conversions.
            When False, attempts best-effort conversion.
            Default: False
        arrow_extension_types: Enable PyArrow extension type support.
            When True, preserves Arrow extension type metadata when reading data.
            When False, falls back to storage types.
            Default: True
        enable_events: Enable database event channel support.
            Defaults to True when extension_config["events"] is configured.
            Provides pub/sub capabilities via table-backed queue (ADBC has no native pub/sub).
            Requires extension_config["events"] for migration setup.
        events_backend: Event channel backend selection.
            Only option: "table_queue" (durable table-backed queue with retries and exactly-once delivery).
            ADBC does not have native pub/sub, so table_queue is the only backend.
            Defaults to "table_queue".
    """

    json_serializer: "NotRequired[Callable[[Any], str]]"
    enable_cast_detection: NotRequired[bool]
    strict_type_coercion: NotRequired[bool]
    arrow_extension_types: NotRequired[bool]


# Names exported by ``from sqlspec.adapters.adbc.config import *``.
__all__ = ("AdbcConfig", "AdbcConnectionParams", "AdbcDriverFeatures")


class AdbcConfig(NoPoolSyncConfig[AdbcConnection, AdbcDriver]):
    """ADBC configuration for Arrow Database Connectivity.

    ADBC provides an interface for connecting to multiple database systems
    with Arrow-native data transfer.

    Supports multiple database backends including PostgreSQL, SQLite, DuckDB,
    BigQuery, and Snowflake with automatic driver detection and loading.
    """

    driver_type: ClassVar[type[AdbcDriver]] = AdbcDriver
    connection_type: "ClassVar[type[AdbcConnection]]" = AdbcConnection
    supports_transactional_ddl: ClassVar[bool] = False
    supports_native_arrow_export: "ClassVar[bool]" = True
    supports_native_arrow_import: "ClassVar[bool]" = True
    supports_native_parquet_export: "ClassVar[bool]" = True
    supports_native_parquet_import: "ClassVar[bool]" = True
    storage_partition_strategies: "ClassVar[tuple[str, ...]]" = ("fixed", "rows_per_chunk")

    def __init__(
        self,
        *,
        connection_config: AdbcConnectionParams | dict[str, Any] | None = None,
        connection_instance: "Any" = None,
        migration_config: dict[str, Any] | None = None,
        statement_config: StatementConfig | None = None,
        driver_features: "AdbcDriverFeatures | dict[str, Any] | None" = None,
        bind_key: str | None = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Initialize configuration.

        Args:
            connection_config: Connection configuration parameters
            connection_instance: Pre-created connection instance to use instead of creating new one
            migration_config: Migration configuration
            statement_config: Default SQL statement configuration
            driver_features: Driver feature configuration (AdbcDriverFeatures)
            bind_key: Optional unique identifier for this configuration
            extension_config: Extension-specific configuration (e.g., Litestar plugin settings)
            observability_config: Adapter-level observability overrides for lifecycle hooks and observers
            **kwargs: Additional keyword arguments passed to the base configuration.
        """
        # Must be assigned before _get_dialect(), which reads self.connection_config.
        self.connection_config = normalize_connection_config(connection_config)

        if statement_config is None:
            detected_dialect = str(self._get_dialect() or "sqlite")
            statement_config = get_adbc_statement_config(detected_dialect)

        # Copy so the caller's mapping is never mutated by the defaults below.
        processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
        json_serializer = processed_driver_features.setdefault("json_serializer", to_json)
        processed_driver_features.setdefault("enable_cast_detection", True)
        processed_driver_features.setdefault("strict_type_coercion", False)
        processed_driver_features.setdefault("arrow_extension_types", True)

        # An explicit ``json_serializer=None`` leaves the statement config untouched.
        if json_serializer is not None:
            statement_config = _apply_json_serializer_to_statement_config(statement_config, json_serializer)

        super().__init__(
            connection_config=self.connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=statement_config,
            driver_features=processed_driver_features,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )

    def _resolve_driver_name(self) -> str:
        """Resolve and normalize the driver name.

        Resolution order: an explicit ``driver_name`` (known alias first,
        otherwise normalized as a dotted path), then the URI scheme, then the
        SQLite driver as the final default.

        Returns:
            The normalized driver connect function path.
        """
        driver_name = self.connection_config.get("driver_name")
        uri = self.connection_config.get("uri")

        if isinstance(driver_name, str):
            # Strip before the alias lookup so padded aliases (" pg ") resolve
            # the same way _normalize_driver_path already treats padded paths.
            alias = _DRIVER_ALIASES.get(driver_name.strip().lower())
            if alias is not None:
                return alias
            return _normalize_driver_path(driver_name)

        if isinstance(uri, str):
            resolved = _driver_from_uri(uri)
            if resolved is not None:
                return resolved

        return "adbc_driver_sqlite.dbapi.connect"

    def _get_connect_func(self) -> Callable[..., AdbcConnection]:
        """Get the driver connect function.

        Returns:
            The driver connect function.

        Raises:
            ImproperConfigurationError: If driver cannot be loaded.
        """
        driver_path = self._resolve_driver_name()
        try:
            connect_func = import_string(driver_path)
        except ImportError as e:
            msg = f"Failed to import connect function from '{driver_path}'. Is the driver installed? Error: {e}"
            raise ImproperConfigurationError(msg) from e

        if not callable(connect_func):
            msg = f"The path '{driver_path}' did not resolve to a callable function."
            raise ImproperConfigurationError(msg)

        return cast("Callable[..., AdbcConnection]", connect_func)

    def _get_dialect(self) -> "DialectType":
        """Get the SQL dialect type based on the driver.

        Returns:
            The SQL dialect type for the driver, or None when no known keyword
            occurs in the resolved driver path.
        """
        driver_path = self._resolve_driver_name()
        for keyword, dialect in _DRIVER_PATH_KEYWORDS_TO_DIALECT:
            if keyword in driver_path:
                return dialect
        return None

    def _get_parameter_styles(self) -> tuple[tuple[str, ...], str]:
        """Get parameter styles based on the underlying driver.

        Returns:
            Tuple of (supported_parameter_styles, default_parameter_style)
        """
        # Deliberately best-effort: any failure falls back to qmark defaults.
        try:
            driver_path = self._resolve_driver_name()
            for keyword, styles in _PARAMETER_STYLES_BY_KEYWORD:
                if keyword in driver_path:
                    return styles
        except Exception:  # pylint: disable=broad-exception-caught
            logger.debug("Error resolving parameter styles, using defaults", exc_info=True)
        return (("qmark",), "qmark")

    def create_connection(self) -> AdbcConnection:
        """Create and return a new connection using the specified driver.

        Returns:
            A new connection instance.

        Raises:
            ImproperConfigurationError: If the connection could not be established.
        """
        try:
            connect_func = self._get_connect_func()
            connection_config_dict = self._get_connection_config_dict()
            connection = connect_func(**connection_config_dict)
        except Exception as e:
            driver_name = self.connection_config.get("driver_name", "Unknown")
            msg = f"Could not configure connection using driver '{driver_name}'. Error: {e}"
            raise ImproperConfigurationError(msg) from e
        return connection

    @contextmanager
    def provide_connection(self, *args: Any, **kwargs: Any) -> "Generator[AdbcConnection, None, None]":
        """Provide a connection context manager.

        A fresh connection is created on entry and closed on exit, including
        when the managed block raises.

        Args:
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Yields:
            A connection instance.
        """
        connection = self.create_connection()
        try:
            yield connection
        finally:
            connection.close()

    def provide_session(
        self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any
    ) -> "AbstractContextManager[AdbcDriver]":
        """Provide a driver session context manager.

        Args:
            *args: Additional arguments.
            statement_config: Optional statement configuration override.
            **kwargs: Additional keyword arguments.

        Returns:
            A context manager that yields an AdbcDriver instance.
        """

        @contextmanager
        def session_manager() -> "Generator[AdbcDriver, None, None]":
            with self.provide_connection(*args, **kwargs) as connection:
                # Precedence: per-call override, then instance config, then a
                # config derived from the detected dialect.
                final_statement_config = (
                    statement_config
                    or self.statement_config
                    or get_adbc_statement_config(str(self._get_dialect() or "sqlite"))
                )
                driver = self.driver_type(
                    connection=connection, statement_config=final_statement_config, driver_features=self.driver_features
                )
                yield self._prepare_driver(driver)

        return session_manager()

    def _get_connection_config_dict(self) -> dict[str, Any]:
        """Get the connection configuration dictionary.

        Works on a copy of ``self.connection_config`` and reshapes it for the
        target driver: SQLite/DuckDB URI schemes are stripped, BigQuery
        identity fields are folded into ``db_kwargs``, and for any other named
        driver ``db_kwargs`` is flattened into the top level. ``driver_name``
        is always removed.

        Returns:
            The connection configuration dictionary.
        """
        config = dict(self.connection_config)

        driver_name = config.get("driver_name")
        uri = config.get("uri")

        # Prefer the explicit driver name; fall back to the URI scheme.
        driver_kind: str | None = None
        if isinstance(driver_name, str):
            driver_kind = _driver_kind_from_driver_name(driver_name)
        if driver_kind is None and isinstance(uri, str):
            driver_kind = _driver_kind_from_uri(uri)

        if isinstance(uri, str):
            if driver_kind == "sqlite" and uri.startswith("sqlite://"):
                config["uri"] = uri.removeprefix("sqlite://")
            if driver_kind == "duckdb" and uri.startswith("duckdb://"):
                # The value is handed over as ``path`` instead of ``uri``.
                config["path"] = uri.removeprefix("duckdb://")
                config.pop("uri", None)

        # NOTE(review): db_kwargs reshaping only runs when driver_name is set,
        # not when the driver was inferred from the URI - confirm intended.
        if isinstance(driver_name, str):
            if driver_kind == "bigquery":
                db_kwargs = config.get("db_kwargs")
                db_kwargs_dict: dict[str, Any] = dict(db_kwargs) if isinstance(db_kwargs, dict) else {}
                for param in _BIGQUERY_DB_KWARGS_FIELDS:
                    if param in config:
                        db_kwargs_dict[param] = config.pop(param)
                if db_kwargs_dict:
                    config["db_kwargs"] = db_kwargs_dict
            elif "db_kwargs" in config:
                db_kwargs = config.pop("db_kwargs")
                if isinstance(db_kwargs, dict):
                    config.update(db_kwargs)

        config.pop("driver_name", None)
        return config

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Get the signature namespace for types.

        Returns:
            Dictionary mapping type names to types.
        """
        namespace = super().get_signature_namespace()
        namespace.update({
            "AdbcConnection": AdbcConnection,
            "AdbcConnectionParams": AdbcConnectionParams,
            "AdbcCursor": AdbcCursor,
            "AdbcDriver": AdbcDriver,
            "AdbcExceptionHandler": AdbcExceptionHandler,
        })
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return polling defaults suitable for ADBC warehouses."""
        return EventRuntimeHints(poll_interval=2.0, lease_seconds=60, retention_seconds=172_800)
def _apply_json_serializer_to_statement_config( statement_config: "StatementConfig", json_serializer: "Callable[[Any], str]" ) -> "StatementConfig": """Apply a JSON serializer to statement config while preserving list/tuple converters. Args: statement_config: Base statement configuration to update. json_serializer: JSON serializer function. Returns: Updated statement configuration. """ parameter_config = statement_config.parameter_config previous_list_converter = parameter_config.type_coercion_map.get(list) previous_tuple_converter = parameter_config.type_coercion_map.get(tuple) updated_parameter_config = parameter_config.with_json_serializers(json_serializer) updated_map = dict(updated_parameter_config.type_coercion_map) if previous_list_converter is not None: updated_map[list] = previous_list_converter if previous_tuple_converter is not None: updated_map[tuple] = previous_tuple_converter return statement_config.replace(parameter_config=updated_parameter_config.replace(type_coercion_map=updated_map)) def _normalize_driver_path(driver_name: str) -> str: """Normalize a driver name to an importable connect function path. Args: driver_name: Driver name or dotted import path. Returns: A dotted path to a driver connect function. """ stripped = driver_name.strip() if stripped.endswith(".dbapi.connect"): return stripped if stripped.endswith(".dbapi"): return f"{stripped}.connect" if "." in stripped: return stripped return f"{stripped}.dbapi.connect" def _driver_from_uri(uri: str) -> str | None: """Resolve a default driver connect path from a URI. Args: uri: Connection URI. Returns: Dotted connect function path if a scheme matches, otherwise None. """ for prefix, driver_path in _URI_PREFIX_DRIVER: if uri.startswith(prefix): return driver_path return None def _driver_kind_from_driver_name(driver_name: str) -> str | None: """Return a canonical driver kind based on driver name content. Args: driver_name: Driver name or dotted path. 
Returns: Canonical driver kind string or None. """ resolved = _DRIVER_ALIASES.get(driver_name.lower(), driver_name) lowered = resolved.lower() for keyword, _dialect in _DRIVER_PATH_KEYWORDS_TO_DIALECT: if keyword in lowered: return keyword return None def _driver_kind_from_uri(uri: str) -> str | None: """Return a canonical driver kind based on URI scheme. Args: uri: Connection URI. Returns: Canonical driver kind string or None. """ for prefix, driver_path in _URI_PREFIX_DRIVER: if uri.startswith(prefix): return _driver_kind_from_driver_name(driver_path) return None