Source code for sqlspec.adapters.arrow_odbc.config

"""arrow-odbc database configuration."""

from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

from typing_extensions import NotRequired

from sqlspec.adapters.arrow_odbc._typing import ArrowOdbcConnection, ArrowOdbcSessionContext, arrow_odbc_connect
from sqlspec.adapters.arrow_odbc.core import apply_driver_features, build_connection_config, default_statement_config
from sqlspec.adapters.arrow_odbc.driver import ArrowOdbcDriver
from sqlspec.config import ExtensionConfigs, NoPoolSyncConfig
from sqlspec.driver._sync import SyncPoolConnectionContext, SyncPoolSessionFactory
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.utils.config_tools import normalize_connection_config

if TYPE_CHECKING:
    from types import TracebackType

    from sqlspec.core import StatementConfig
    from sqlspec.observability import ObservabilityConfig

# Public names exported by ``from sqlspec.adapters.arrow_odbc.config import *``.
__all__ = ("ArrowOdbcConfig", "ArrowOdbcConnectionParams", "ArrowOdbcDriverFeatures")


class ArrowOdbcConnectionParams(TypedDict):
    """arrow-odbc connection parameters.

    Accepted as ``connection_config`` by ``ArrowOdbcConfig``; every key is
    optional. ``ArrowOdbcConfig`` reads ``connection_string`` and ``driver``
    directly, and passes the whole mapping to ``build_connection_config``,
    which produces the connection string and keyword arguments handed to
    ``arrow_odbc_connect``.
    """

    # Complete ODBC connection string. When present, ArrowOdbcConfig also uses
    # it as the default ``connection_string`` driver feature.
    connection_string: NotRequired[str]
    # NOTE(review): the discrete keys below are presumably assembled into a
    # connection string by build_connection_config when no connection_string
    # is given — confirm against that helper.
    dsn: NotRequired[str]
    # ODBC driver name. When no connection_string is set, ArrowOdbcConfig uses
    # it as the default ``dbms_name`` driver feature.
    driver: NotRequired[str]
    # server/host, uid/user, pwd/password and login_timeout/login_timeout_sec
    # look like alternate spellings of the same setting — confirm which wins.
    server: NotRequired[str]
    host: NotRequired[str]
    database: NotRequired[str]
    uid: NotRequired[str]
    pwd: NotRequired[str]
    user: NotRequired[str]
    password: NotRequired[str]
    login_timeout: NotRequired[int]
    login_timeout_sec: NotRequired[int]
    packet_size: NotRequired[int]
    autocommit: NotRequired[bool]
    # Additional free-form settings; how they are merged is decided by
    # build_connection_config (not visible here).
    extra: NotRequired[dict[str, Any]]
class ArrowOdbcDriverFeatures(TypedDict):
    """arrow-odbc driver feature flags.

    Accepted as ``driver_features`` by ``ArrowOdbcConfig``. The mapping is
    normalized by ``apply_driver_features`` and then handed to each driver
    session. Every key is optional.
    """

    # NOTE(review): the next six keys mirror arrow-odbc's batch-reader options
    # and are presumably forwarded to it by the driver — confirm in driver.py.
    chunk_size: NotRequired[int]
    max_bytes_per_batch: NotRequired[int]
    max_text_size: NotRequired[int]
    max_binary_size: NotRequired[int]
    fetch_concurrently: NotRequired[bool]
    query_timeout_sec: NotRequired[int]
    # Defaulted by ArrowOdbcConfig from connection_config["connection_string"]
    # when that key is present and this one was not supplied.
    connection_string: NotRequired[str]
    # Defaulted by ArrowOdbcConfig from connection_config["driver"] when no
    # connection_string is configured and this key was not supplied.
    dbms_name: NotRequired[str]
    # Flag consumed outside this module (semantics not visible here).
    enable_events: NotRequired[bool]
def _release_arrow_odbc_connection(config: "ArrowOdbcConfig", connection: "ArrowOdbcConnection") -> None:
    """Close a connection obtained from ``config`` unless the caller owns it.

    ``ArrowOdbcConfig.create_connection`` returns ``config.connection_instance``
    unchanged when one was supplied. That object is shared across every context
    and session built from the config, so it must not be closed here.

    Args:
        config: The configuration the connection was obtained from.
        connection: The connection to release.
    """
    if connection is getattr(config, "connection_instance", None):
        # Caller-supplied connection: leave its lifecycle to the caller.
        return
    _close_arrow_odbc_connection(connection)


class ArrowOdbcConnectionContext(SyncPoolConnectionContext):
    """Context manager yielding one arrow-odbc connection for a ``with`` body.

    The connection comes from ``config.create_connection()`` on entry. On exit
    it is closed, unless it is the config's caller-supplied
    ``connection_instance``.
    """

    __slots__ = ("_connection",)

    def __init__(self, config: "ArrowOdbcConfig") -> None:
        super().__init__(config)
        self._connection: ArrowOdbcConnection | None = None

    def __enter__(self) -> "ArrowOdbcConnection":
        self._connection = self._config.create_connection()
        return cast("ArrowOdbcConnection", self._connection)

    def __exit__(
        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None"
    ) -> "bool | None":
        # Drop the reference before closing so a failing close() cannot leave
        # a stale handle behind for a second close attempt.
        connection, self._connection = self._connection, None
        if connection is not None:
            _release_arrow_odbc_connection(self._config, connection)
        # Never suppress exceptions raised inside the ``with`` body.
        return None


class _ArrowOdbcSessionConnectionHandler(SyncPoolSessionFactory):
    """Session connection handler for no-pool arrow-odbc sessions.

    ``acquire_connection`` opens a connection through the config.
    ``release_connection`` closes it again, unless it is the config's
    caller-supplied ``connection_instance``.
    """

    __slots__ = ("_connection",)

    def __init__(self, config: "ArrowOdbcConfig") -> None:
        super().__init__(config)
        self._connection: ArrowOdbcConnection | None = None

    def acquire_connection(self) -> "ArrowOdbcConnection":
        """Create (or reuse the configured) connection and remember it for release."""
        self._connection = self._config.create_connection()
        return cast("ArrowOdbcConnection", self._connection)

    def release_connection(self, _conn: "ArrowOdbcConnection", **kwargs: Any) -> None:
        """Release the connection previously returned by ``acquire_connection``.

        The ``_conn`` argument is ignored; the handler releases the connection
        it tracked itself. Extra keyword arguments are accepted for interface
        compatibility and ignored.
        """
        connection, self._connection = self._connection, None
        if connection is None:
            return
        _release_arrow_odbc_connection(self._config, connection)
class ArrowOdbcConfig(NoPoolSyncConfig[ArrowOdbcConnection, ArrowOdbcDriver]):
    """Configuration for synchronous arrow-odbc connections.

    No pool is maintained. Each connection or session context obtains its
    connection through :meth:`create_connection`, which opens a new one unless
    a ``connection_instance`` was supplied.
    """

    driver_type: "ClassVar[type[ArrowOdbcDriver]]" = ArrowOdbcDriver
    connection_type: "ClassVar[type[ArrowOdbcConnection]]" = ArrowOdbcConnection
    supports_transactional_ddl: "ClassVar[bool]" = False
    supports_native_arrow_export: "ClassVar[bool]" = True
    supports_native_arrow_import: "ClassVar[bool]" = True
    supports_native_parquet_export: "ClassVar[bool]" = False
    supports_native_parquet_import: "ClassVar[bool]" = False
    _connection_context_class: "ClassVar[type[ArrowOdbcConnectionContext]]" = ArrowOdbcConnectionContext
    _session_factory_class: "ClassVar[type[_ArrowOdbcSessionConnectionHandler]]" = _ArrowOdbcSessionConnectionHandler
    _session_context_class: "ClassVar[type[ArrowOdbcSessionContext]]" = ArrowOdbcSessionContext
    _default_statement_config = default_statement_config

    def __init__(
        self,
        *,
        connection_config: "ArrowOdbcConnectionParams | dict[str, Any] | None" = None,
        connection_instance: "Any" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "ArrowOdbcDriverFeatures | dict[str, Any] | None" = None,
        bind_key: str | None = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Initialize arrow-odbc configuration.

        Args:
            connection_config: Connection parameters, normalized with
                ``normalize_connection_config``.
            connection_instance: Existing connection. When set,
                :meth:`create_connection` returns it instead of connecting.
            migration_config: Migration settings forwarded to the base config.
            statement_config: Statement configuration. Defaults to the
                adapter's ``default_statement_config``.
            driver_features: Driver feature flags, normalized with
                ``apply_driver_features``. ``connection_string`` defaults from
                ``connection_config["connection_string"]``. Otherwise
                ``dbms_name`` defaults from ``connection_config["driver"]``.
            bind_key: Optional binding key forwarded to the base config.
            extension_config: Extension settings forwarded to the base config.
            observability_config: Observability settings forwarded to the base config.
            **kwargs: Additional arguments forwarded to ``NoPoolSyncConfig``.
        """
        self.connection_config = normalize_connection_config(connection_config)
        # Work on a copy so the caller's driver_features mapping is not mutated here.
        features = apply_driver_features(dict(driver_features or {}))
        connection_string = self.connection_config.get("connection_string")
        if connection_string is not None:
            features.setdefault("connection_string", str(connection_string))
        elif self.connection_config.get("driver") is not None:
            # No connection string: record the ODBC driver name as the default dbms_name.
            features.setdefault("dbms_name", str(self.connection_config["driver"]))
        super().__init__(
            connection_config=self.connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=statement_config or default_statement_config,
            driver_features=features,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )

    def create_connection(self) -> "ArrowOdbcConnection":
        """Create and return a new arrow-odbc connection.

        Returns:
            ``connection_instance`` when one was configured, otherwise a new
            connection opened with ``arrow_odbc_connect``.

        Raises:
            ImproperConfigurationError: If ``arrow_odbc_connect`` raises.
        """
        if self.connection_instance is not None:
            return cast("ArrowOdbcConnection", self.connection_instance)
        connection_string, connect_kwargs = build_connection_config(self.connection_config)
        # Keep the try body to the one call that talks to the ODBC driver manager.
        try:
            connection = arrow_odbc_connect(connection_string, **connect_kwargs)
        except Exception as exc:
            msg = f"Could not configure arrow-odbc connection. Error: {exc}"
            raise ImproperConfigurationError(msg) from exc
        return cast("ArrowOdbcConnection", connection)

    def provide_connection(self, *args: Any, **kwargs: Any) -> "ArrowOdbcConnectionContext":
        """Provide a connection context manager.

        Positional and keyword arguments are accepted for interface
        compatibility and ignored.
        """
        return ArrowOdbcConnectionContext(self)

    def provide_session(
        self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any
    ) -> "ArrowOdbcSessionContext":
        """Provide a driver session context manager.

        Args:
            statement_config: Per-session override. Falls back to this
                config's ``statement_config``.

        Returns:
            A session context whose connection is acquired and released by a
            fresh ``_ArrowOdbcSessionConnectionHandler``.
        """
        handler = _ArrowOdbcSessionConnectionHandler(self)
        return ArrowOdbcSessionContext(
            acquire_connection=handler.acquire_connection,
            release_connection=handler.release_connection,
            statement_config=statement_config or self.statement_config,
            driver_features=self.driver_features,
            prepare_driver=self._prepare_driver,
        )

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Get the signature namespace for ArrowOdbcConfig types.

        Extends the base namespace with this adapter's public types so string
        annotations referring to them can be resolved.
        """
        namespace = super().get_signature_namespace()
        namespace.update({
            "ArrowOdbcConfig": ArrowOdbcConfig,
            "ArrowOdbcConnection": ArrowOdbcConnection,
            "ArrowOdbcConnectionContext": ArrowOdbcConnectionContext,
            "ArrowOdbcConnectionParams": ArrowOdbcConnectionParams,
            "ArrowOdbcDriver": ArrowOdbcDriver,
            "ArrowOdbcDriverFeatures": ArrowOdbcDriverFeatures,
            "ArrowOdbcSessionContext": ArrowOdbcSessionContext,
        })
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return polling defaults suitable for generic ODBC sources.

        Poll interval 2.0 (presumably seconds), 60-second lease, and
        172_800-second (two-day) retention.
        """
        return EventRuntimeHints(poll_interval=2.0, lease_seconds=60, retention_seconds=172_800)
def _close_arrow_odbc_connection(connection: "ArrowOdbcConnection") -> None: """Close connection objects from compatible wrappers when they expose close().""" close = getattr(connection, "close", None) if close is not None: close()