"""arrow-odbc database configuration."""
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
from typing_extensions import NotRequired
from sqlspec.adapters.arrow_odbc._typing import ArrowOdbcConnection, ArrowOdbcSessionContext, arrow_odbc_connect
from sqlspec.adapters.arrow_odbc.core import apply_driver_features, build_connection_config, default_statement_config
from sqlspec.adapters.arrow_odbc.driver import ArrowOdbcDriver
from sqlspec.config import ExtensionConfigs, NoPoolSyncConfig
from sqlspec.driver._sync import SyncPoolConnectionContext, SyncPoolSessionFactory
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.utils.config_tools import normalize_connection_config
if TYPE_CHECKING:
from types import TracebackType
from sqlspec.core import StatementConfig
from sqlspec.observability import ObservabilityConfig
__all__ = ("ArrowOdbcConfig", "ArrowOdbcConnectionParams", "ArrowOdbcDriverFeatures")
[docs]
class ArrowOdbcConnectionParams(TypedDict):
"""arrow-odbc connection parameters."""
connection_string: NotRequired[str]
dsn: NotRequired[str]
driver: NotRequired[str]
server: NotRequired[str]
host: NotRequired[str]
database: NotRequired[str]
uid: NotRequired[str]
pwd: NotRequired[str]
user: NotRequired[str]
password: NotRequired[str]
login_timeout: NotRequired[int]
login_timeout_sec: NotRequired[int]
packet_size: NotRequired[int]
autocommit: NotRequired[bool]
extra: NotRequired[dict[str, Any]]
[docs]
class ArrowOdbcDriverFeatures(TypedDict):
"""arrow-odbc driver feature flags."""
chunk_size: NotRequired[int]
max_bytes_per_batch: NotRequired[int]
max_text_size: NotRequired[int]
max_binary_size: NotRequired[int]
fetch_concurrently: NotRequired[bool]
query_timeout_sec: NotRequired[int]
connection_string: NotRequired[str]
dbms_name: NotRequired[str]
enable_events: NotRequired[bool]
class ArrowOdbcConnectionContext(SyncPoolConnectionContext):
    """Context manager yielding an arrow-odbc connection for the body of a ``with`` block.

    On entry a connection is obtained from ``config.create_connection()``. On exit that
    connection is closed, unless it is ``config.connection_instance``. That object is
    caller-owned, and ``create_connection()`` hands it out again on every call, so closing
    it here would break every later context.
    """

    __slots__ = ("_connection",)

    def __init__(self, config: "ArrowOdbcConfig") -> None:
        """Bind the context to *config*; no connection is opened until ``__enter__``."""
        super().__init__(config)
        self._connection: ArrowOdbcConnection | None = None

    def __enter__(self) -> "ArrowOdbcConnection":
        """Obtain a connection from the config and return it."""
        self._connection = self._config.create_connection()
        return cast("ArrowOdbcConnection", self._connection)

    def __exit__(
        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None"
    ) -> "bool | None":
        """Close the connection obtained in ``__enter__`` unless the config owns it.

        Returns:
            ``None``, so exceptions raised inside the ``with`` body are never suppressed.
        """
        connection = self._connection
        # Drop the reference before closing so a failing close() cannot leave a stale handle.
        self._connection = None
        if connection is not None and connection is not self._config.connection_instance:
            _close_arrow_odbc_connection(connection)
        return None
class _ArrowOdbcSessionConnectionHandler(SyncPoolSessionFactory):
    """Acquire/release callbacks backing a no-pool arrow-odbc driver session.

    ``acquire_connection`` obtains a connection from ``config.create_connection()``.
    ``release_connection`` closes it, unless it is ``config.connection_instance``. That
    connection is caller-owned and is returned again by every ``create_connection()`` call,
    so it must stay open.
    """

    __slots__ = ("_connection",)

    def __init__(self, config: "ArrowOdbcConfig") -> None:
        """Bind the handler to *config*; no connection is opened yet."""
        super().__init__(config)
        self._connection: ArrowOdbcConnection | None = None

    def acquire_connection(self) -> "ArrowOdbcConnection":
        """Obtain a connection from the config, remember it, and return it."""
        self._connection = self._config.create_connection()
        return cast("ArrowOdbcConnection", self._connection)

    def release_connection(self, _conn: "ArrowOdbcConnection", **kwargs: Any) -> None:
        """Release the connection remembered by ``acquire_connection``.

        Args:
            _conn: Ignored; the handler releases the connection it tracked itself.
            **kwargs: Ignored.
        """
        connection = self._connection
        if connection is None:
            return
        # Clear first so a failing close() cannot leave a stale reference behind.
        self._connection = None
        if connection is not self._config.connection_instance:
            _close_arrow_odbc_connection(connection)
[docs]
class ArrowOdbcConfig(NoPoolSyncConfig[ArrowOdbcConnection, ArrowOdbcDriver]):
    """Configuration for synchronous arrow-odbc connections.

    No connection pool is maintained. Connections are obtained on demand through
    :meth:`create_connection`. That method returns ``connection_instance`` as-is when one
    was supplied, and otherwise opens a new arrow-odbc connection.
    """

    driver_type: "ClassVar[type[ArrowOdbcDriver]]" = ArrowOdbcDriver
    connection_type: "ClassVar[type[ArrowOdbcConnection]]" = ArrowOdbcConnection

    # Capability flags advertised by this adapter.
    supports_transactional_ddl: "ClassVar[bool]" = False
    supports_native_arrow_export: "ClassVar[bool]" = True
    supports_native_arrow_import: "ClassVar[bool]" = True
    supports_native_parquet_export: "ClassVar[bool]" = False
    supports_native_parquet_import: "ClassVar[bool]" = False

    # Helper classes instantiated by provide_connection() / provide_session(); subclasses may override.
    _connection_context_class: "ClassVar[type[ArrowOdbcConnectionContext]]" = ArrowOdbcConnectionContext
    _session_factory_class: "ClassVar[type[_ArrowOdbcSessionConnectionHandler]]" = _ArrowOdbcSessionConnectionHandler
    _session_context_class: "ClassVar[type[ArrowOdbcSessionContext]]" = ArrowOdbcSessionContext
    _default_statement_config = default_statement_config

    def __init__(
        self,
        *,
        connection_config: "ArrowOdbcConnectionParams | dict[str, Any] | None" = None,
        connection_instance: "Any" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "ArrowOdbcDriverFeatures | dict[str, Any] | None" = None,
        bind_key: str | None = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Initialize arrow-odbc configuration.

        Args:
            connection_config: ODBC connection parameters, normalized with
                ``normalize_connection_config``.
            connection_instance: Optional pre-created connection. When set,
                :meth:`create_connection` returns it instead of opening a new one.
            migration_config: Migration settings forwarded to the base config.
            statement_config: Statement configuration. Defaults to the adapter's
                ``default_statement_config``.
            driver_features: Driver feature flags, processed by ``apply_driver_features``.
                ``connection_string`` is filled in from ``connection_config`` if not already
                present. When no connection string is configured, ``dbms_name`` is filled
                in from ``connection_config["driver"]`` instead.
            bind_key: Optional key identifying this config; forwarded to the base config.
            extension_config: Extension settings forwarded to the base config.
            observability_config: Observability settings forwarded to the base config.
            **kwargs: Additional keyword arguments forwarded to the base config.
        """
        self.connection_config = normalize_connection_config(connection_config)
        features = apply_driver_features(dict(driver_features or {}))

        # Expose connection identity to the driver without overriding explicit feature values.
        connection_string = self.connection_config.get("connection_string")
        if connection_string is not None:
            features.setdefault("connection_string", str(connection_string))
        elif self.connection_config.get("driver") is not None:
            features.setdefault("dbms_name", str(self.connection_config["driver"]))

        super().__init__(
            connection_config=self.connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=statement_config or default_statement_config,
            driver_features=features,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )

    def create_connection(self) -> "ArrowOdbcConnection":
        """Create and return a new arrow-odbc connection.

        Returns:
            ``self.connection_instance`` when one was supplied, otherwise a freshly
            opened connection.

        Raises:
            ImproperConfigurationError: If ``arrow_odbc_connect`` raises any exception.
                The original exception is chained as the cause.
        """
        if self.connection_instance is not None:
            return cast("ArrowOdbcConnection", self.connection_instance)

        connection_string, connect_kwargs = build_connection_config(self.connection_config)
        try:
            connection = arrow_odbc_connect(connection_string, **connect_kwargs)
        except Exception as exc:
            msg = f"Could not configure arrow-odbc connection. Error: {exc}"
            raise ImproperConfigurationError(msg) from exc
        return cast("ArrowOdbcConnection", connection)

    def provide_connection(self, *args: Any, **kwargs: Any) -> "ArrowOdbcConnectionContext":
        """Provide a connection context manager.

        Positional and keyword arguments are accepted for interface compatibility and ignored.
        """
        return self._connection_context_class(self)

    def provide_session(
        self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any
    ) -> "ArrowOdbcSessionContext":
        """Provide a driver session context manager.

        Args:
            *_args: Ignored.
            statement_config: Overrides ``self.statement_config`` for this session when given.
            **_kwargs: Ignored.

        Returns:
            A session context. Its connection acquire/release is delegated to a fresh
            session handler bound to this config.
        """
        handler = self._session_factory_class(self)
        return self._session_context_class(
            acquire_connection=handler.acquire_connection,
            release_connection=handler.release_connection,
            statement_config=statement_config or self.statement_config,
            driver_features=self.driver_features,
            prepare_driver=self._prepare_driver,
        )

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Get the signature namespace for ArrowOdbcConfig types.

        Extends the base namespace with this adapter's public config, connection,
        driver and context types.
        """
        namespace = super().get_signature_namespace()
        namespace.update({
            "ArrowOdbcConfig": ArrowOdbcConfig,
            "ArrowOdbcConnection": ArrowOdbcConnection,
            "ArrowOdbcConnectionContext": ArrowOdbcConnectionContext,
            "ArrowOdbcConnectionParams": ArrowOdbcConnectionParams,
            "ArrowOdbcDriver": ArrowOdbcDriver,
            "ArrowOdbcDriverFeatures": ArrowOdbcDriverFeatures,
            "ArrowOdbcSessionContext": ArrowOdbcSessionContext,
        })
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return polling defaults suitable for generic ODBC sources.

        The defaults are:

        - ``poll_interval=2.0`` (presumably seconds -- confirm in ``EventRuntimeHints``).
        - ``lease_seconds=60``.
        - ``retention_seconds=172_800``, i.e. 48 hours.
        """
        return EventRuntimeHints(poll_interval=2.0, lease_seconds=60, retention_seconds=172_800)
def _close_arrow_odbc_connection(connection: "ArrowOdbcConnection") -> None:
    """Close *connection* if it exposes a ``close`` attribute.

    Objects with no ``close`` attribute, or whose ``close`` is ``None``, are left
    untouched. Otherwise ``close()`` is called once, and any exception it raises
    propagates to the caller.
    """
    closer = getattr(connection, "close", None)
    if closer is None:
        return
    closer()