"""arrow-odbc sync driver."""
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from sqlspec.adapters.arrow_odbc._typing import ArrowOdbcConnection, ArrowOdbcCursor, ArrowOdbcError, ArrowOdbcRawCursor
from sqlspec.adapters.arrow_odbc.core import (
build_statement_config,
create_mapped_exception,
driver_profile,
resolve_dialect_from_dbms_name,
)
from sqlspec.adapters.arrow_odbc.data_dictionary import ArrowOdbcDataDictionary
from sqlspec.core import SQL, build_arrow_result_from_table, get_cache_config, register_driver_profile
from sqlspec.driver import BaseSyncExceptionHandler, SyncDriverAdapterBase
from sqlspec.exceptions import ImproperConfigurationError, SQLSpecError
from sqlspec.utils.module_loader import ensure_pyarrow
if TYPE_CHECKING:
from sqlspec.builder import QueryBuilder
from sqlspec.core import ArrowResult, Statement, StatementConfig, StatementFilter
from sqlspec.driver import ExecutionResult
from sqlspec.typing import ArrowReturnFormat, StatementParameters
# Public names exported by this module. ``ArrowOdbcCursor`` and
# ``resolve_dialect_from_dbms_name`` are imported above and re-exported here.
__all__ = ("ArrowOdbcCursor", "ArrowOdbcDriver", "ArrowOdbcExceptionHandler", "resolve_dialect_from_dbms_name")
class ArrowOdbcExceptionHandler(BaseSyncExceptionHandler):
    """Sync context manager that translates arrow-odbc errors.

    Only ``ArrowOdbcError`` instances are claimed by this handler; every
    other exception is reported as unhandled.
    """

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Record a mapped exception for arrow-odbc failures.

        Args:
            exc_type: Type of the exception being handled, or ``None`` when
                nothing was raised.
            exc_val: The exception instance being handled.

        Returns:
            ``True`` when ``exc_val`` is an ``ArrowOdbcError`` (its mapped
            counterpart is stored on ``self.pending_exception``), ``False``
            in every other case.
        """
        is_driver_error = exc_type is not None and isinstance(exc_val, ArrowOdbcError)
        if not is_driver_error:
            return False
        # Stash the translated error; the code here only records it, it does not raise.
        self.pending_exception = create_mapped_exception(exc_val)
        return True
class ArrowOdbcDriver(SyncDriverAdapterBase):
    """Sync driver for generic ODBC connections with Arrow-native transfer.

    The SQL dialect is resolved from the connection's DBMS name (see
    ``_resolve_dbms_name``). The ``mssql`` dialect is mapped to ``tsql``
    for statement compilation.
    """

    __slots__ = ("_data_dictionary", "_dbms_name", "_dialect", "_statement_dialect", "dialect")

    def __init__(
        self,
        connection: "ArrowOdbcConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Initialise the driver and pin the statement dialect.

        Args:
            connection: The arrow-odbc connection to operate on.
            statement_config: Optional statement configuration. When omitted,
                a default one is built for the resolved dialect with caching
                taken from the global cache config. When supplied, its
                dialect is replaced by the resolved statement dialect.
            driver_features: Optional feature mapping; it is copied, so the
                caller's dict is not mutated by this constructor.
        """
        features = dict(driver_features or {})
        self._dbms_name = self._resolve_dbms_name(connection, features)
        self._dialect = resolve_dialect_from_dbms_name(self._dbms_name)
        self._statement_dialect = _statement_dialect_for(self._dialect)
        if statement_config is None:
            statement_config = build_statement_config(dialect=self._statement_dialect).replace(
                enable_caching=get_cache_config().compiled_cache_enabled
            )
        else:
            statement_config = statement_config.replace(dialect=self._statement_dialect)
        super().__init__(connection=connection, statement_config=statement_config, driver_features=features)
        self.dialect = self._statement_dialect
        self._data_dictionary: ArrowOdbcDataDictionary | None = None

    @property
    def data_dictionary(self) -> "ArrowOdbcDataDictionary":
        """Return the data dictionary for the resolved dialect, creating it on first access."""
        if self._data_dictionary is None:
            self._data_dictionary = ArrowOdbcDataDictionary(self._dialect)
        return self._data_dictionary

    def dispatch_execute(self, cursor: "ArrowOdbcRawCursor", statement: "SQL") -> "ExecutionResult":
        """Execute a single statement.

        Row-returning statements are read as Arrow batches through the
        connection and converted to a list of row dicts; all other statements
        are executed through ``cursor`` and reported with a row count of 0.

        Args:
            cursor: Raw cursor used for non-row-returning statements.
            statement: The statement to compile and execute.

        Returns:
            The execution result built by ``create_execution_result``.
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        parameters = _odbc_parameters(prepared_parameters)
        if statement.returns_rows():
            # Reads go through the connection's Arrow reader rather than the cursor.
            reader = self._read_arrow_batches(sql, parameters, self._chunk_size())
            table = _reader_to_table(reader)
            rows = table.to_pylist()
            column_names = table.column_names
            return self.create_execution_result(
                cursor,
                selected_data=rows,
                column_names=column_names,
                data_row_count=table.num_rows,
                is_select_result=True,
                row_format="dict",
            )
        cursor.execute(query=sql, parameters=parameters)
        # No affected-row count is read back here; 0 is always reported.
        return self.create_execution_result(cursor, rowcount_override=0)

    def dispatch_execute_many(self, cursor: "ArrowOdbcRawCursor", statement: "SQL") -> "ExecutionResult":
        """Reject ``execute_many``; this driver has no row-oriented batch execution.

        Raises:
            NotImplementedError: Always. Use ``bulk_insert_arrow`` instead.
        """
        msg = "arrow-odbc does not expose a row-oriented executemany API; use bulk_insert_arrow() for Arrow ingestion."
        raise NotImplementedError(msg)

    def dispatch_execute_script(self, cursor: "ArrowOdbcRawCursor", statement: "SQL") -> "ExecutionResult":
        """Execute a multi-statement script, one statement at a time.

        The compiled SQL is split into individual statements (trailing
        semicolons stripped) and each is executed with the same parameter
        list.

        Args:
            cursor: Raw cursor used to execute each statement.
            statement: The script to compile, split and execute.

        Returns:
            A script execution result carrying the total and successful
            statement counts.
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
        parameters = _odbc_parameters(prepared_parameters)
        successful_count = 0
        for stmt in statements:
            cursor.execute(query=stmt, parameters=parameters)
            successful_count += 1
        return self.create_execution_result(
            cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True
        )

    def collect_rows(self, cursor: "ArrowOdbcRawCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Return ``fetched`` unchanged, with an empty column-name list and its length."""
        return fetched, [], len(fetched)

    def resolve_rowcount(self, cursor: "ArrowOdbcRawCursor") -> int:
        """Return 0; this driver does not derive a row count from the cursor."""
        return 0

    def begin(self) -> None:
        """Start a transaction using ``BEGIN TRANSACTION`` on mssql and ``BEGIN`` elsewhere."""
        statement = "BEGIN TRANSACTION" if self._dialect == "mssql" else "BEGIN"
        self.connection.execute(statement)

    def commit(self) -> None:
        """Commit the current transaction.

        Raises:
            SQLSpecError: If the underlying connection fails to commit.
        """
        try:
            self.connection.commit()
        except Exception as exc:
            msg = f"Failed to commit transaction: {exc}"
            raise SQLSpecError(msg) from exc

    def rollback(self) -> None:
        """Roll back the current transaction.

        Raises:
            SQLSpecError: If the underlying connection fails to roll back.
        """
        try:
            self.connection.rollback()
        except Exception as exc:
            msg = f"Failed to rollback transaction: {exc}"
            raise SQLSpecError(msg) from exc

    def with_cursor(self, connection: "ArrowOdbcConnection") -> "ArrowOdbcCursor":
        """Wrap ``connection`` in an ``ArrowOdbcCursor``."""
        return ArrowOdbcCursor(connection)

    def handle_database_exceptions(self) -> "ArrowOdbcExceptionHandler":
        """Return a fresh exception handler for arrow-odbc errors."""
        return ArrowOdbcExceptionHandler()

    # NOTE: the savepoint helpers interpolate ``name`` directly into SQL text;
    # callers must only pass trusted identifiers.

    def create_savepoint(self, name: str) -> None:
        """Create a savepoint (``SAVE TRANSACTION`` on mssql, ``SAVEPOINT`` otherwise)."""
        if self._dialect == "mssql":
            self.execute_script(f"SAVE TRANSACTION {name}")
            return
        self.execute_script(f"SAVEPOINT {name}")

    def release_savepoint(self, name: str) -> None:
        """Release a savepoint; nothing is executed on mssql."""
        if self._dialect == "mssql":
            # T-SQL has no RELEASE SAVEPOINT statement, so this is a no-op there.
            return
        self.execute_script(f"RELEASE SAVEPOINT {name}")

    def rollback_to_savepoint(self, name: str) -> None:
        """Roll back to a savepoint (``ROLLBACK TRANSACTION`` on mssql, ``ROLLBACK TO SAVEPOINT`` otherwise)."""
        if self._dialect == "mssql":
            self.execute_script(f"ROLLBACK TRANSACTION {name}")
            return
        self.execute_script(f"ROLLBACK TO SAVEPOINT {name}")

    def select_to_arrow(
        self,
        statement: "Statement | QueryBuilder",
        /,
        *parameters: "StatementParameters | StatementFilter",
        statement_config: "StatementConfig | None" = None,
        return_format: "ArrowReturnFormat" = "table",
        native_only: bool = False,
        batch_size: int | None = None,
        arrow_schema: Any = None,
        **kwargs: Any,
    ) -> "ArrowResult":
        """Execute a query and return native Arrow results.

        Args:
            statement: Statement or query builder to execute.
            *parameters: Statement parameters and/or filters.
            statement_config: Optional override; defaults to the driver's
                statement configuration.
            return_format: Arrow return format forwarded to the result builder.
            native_only: Accepted for interface compatibility; not read by
                this implementation.
            batch_size: Batch size for the Arrow reader; a falsy value falls
                back to the driver's chunk size.
            arrow_schema: Optional schema forwarded to the result builder.
            **kwargs: Extra keyword arguments forwarded to ``prepare_statement``.

        Returns:
            The Arrow result built from the fetched table.

        Raises:
            SQLSpecError: If no Arrow table was produced.
        """
        ensure_pyarrow()
        config = statement_config or self.statement_config
        prepared_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
        sql, prepared_parameters = self._get_compiled_sql(prepared_statement, config)
        resolved_batch_size = batch_size or self._chunk_size()
        table: Any | None = None
        exc_handler = self.handle_database_exceptions()
        with exc_handler, self.with_cursor(self.connection):
            reader = self._read_arrow_batches(sql, _odbc_parameters(prepared_parameters), resolved_batch_size)
            table = _reader_to_table(reader)
        # Surface any error the handler recorded while the block was exiting.
        self._check_pending_exception(exc_handler)
        if table is None:
            msg = "arrow-odbc did not return an Arrow table."
            raise SQLSpecError(msg)
        return build_arrow_result_from_table(
            prepared_statement,
            table,
            return_format=return_format,
            batch_size=resolved_batch_size,
            arrow_schema=arrow_schema,
        )

    def bulk_insert_arrow(self, target_table: str, source: Any, *, chunk_size: int | None = None) -> None:
        """Insert an Arrow table or reader into a database table.

        A ``pyarrow.Table`` is passed to the connection's ``from_table_to_db``
        when that method exists. Otherwise the source (converted to a reader
        if it is a table) is passed to ``insert_into_table``.

        Args:
            target_table: Name of the destination table.
            source: A ``pyarrow.Table`` or a reader object accepted by the
                connection's ``insert_into_table``.
            chunk_size: Rows per chunk; a falsy value falls back to the
                driver's chunk size.

        Raises:
            ImproperConfigurationError: If the connection exposes neither
                import method usable for ``source``.
        """
        ensure_pyarrow()
        import pyarrow as pa

        resolved_chunk_size = chunk_size or self._chunk_size()
        exc_handler = self.handle_database_exceptions()
        with exc_handler, self.with_cursor(self.connection):
            # NOTE(review): the checks inside this block run before the handler
            # exits; confirm whether they can ever observe a pending exception.
            if isinstance(source, pa.Table) and hasattr(self.connection, "from_table_to_db"):
                self.connection.from_table_to_db(source=source, target=target_table, chunk_size=resolved_chunk_size)
                self._check_pending_exception(exc_handler)
                return
            reader = _table_to_reader(source, resolved_chunk_size) if isinstance(source, pa.Table) else source
            if hasattr(self.connection, "insert_into_table"):
                self.connection.insert_into_table(reader=reader, table=target_table, chunk_size=resolved_chunk_size)
                self._check_pending_exception(exc_handler)
                return
        # Reached when the block raised an error the handler recorded, or when
        # no import API was available on the connection.
        self._check_pending_exception(exc_handler)
        msg = "arrow-odbc connection does not expose table import APIs."
        raise ImproperConfigurationError(msg)

    def _read_arrow_batches(self, sql: str, parameters: "list[str | None] | None", batch_size: int) -> Any:
        """Call the connection's ``read_arrow_batches`` with driver-feature tuning options.

        ``query_timeout_sec`` is only forwarded when it is configured; the
        other size limits are forwarded even when unset (as ``None``).
        """
        kwargs: dict[str, Any] = {
            "query": sql,
            "batch_size": batch_size,
            "parameters": parameters,
            "max_bytes_per_batch": self.driver_features.get("max_bytes_per_batch"),
            "max_text_size": self.driver_features.get("max_text_size"),
            "max_binary_size": self.driver_features.get("max_binary_size"),
            "fetch_concurrently": self.driver_features.get("fetch_concurrently", True),
        }
        query_timeout_sec = self.driver_features.get("query_timeout_sec")
        if query_timeout_sec is not None:
            kwargs["query_timeout_sec"] = query_timeout_sec
        return self.connection.read_arrow_batches(**kwargs)

    def _chunk_size(self) -> int:
        """Return the configured ``chunk_size`` feature, or 65,536 when unset or falsy."""
        return int(self.driver_features.get("chunk_size") or 65_536)

    @staticmethod
    def _resolve_dbms_name(connection: "ArrowOdbcConnection", features: "dict[str, Any]") -> str | None:
        """Find a string from which the dialect can be resolved.

        Lookup order: the connection's ``dbms_name`` attribute, then the
        ``dbms_name`` feature, then the ``connection_string`` feature
        (returned verbatim). Falsy values are skipped at each step.

        Returns:
            The first truthy candidate as ``str``, or ``None`` if there is none.
        """
        dbms_name = getattr(connection, "dbms_name", None)
        if dbms_name:
            return str(dbms_name)
        dbms_name = features.get("dbms_name")
        if dbms_name:
            return str(dbms_name)
        # NOTE(review): presumably resolve_dialect_from_dbms_name can sniff a
        # dialect from a raw connection string; confirm against its implementation.
        connection_string = features.get("connection_string")
        if connection_string:
            return str(connection_string)
        return None
def _statement_dialect_for(dialect: str) -> str:
if dialect == "mssql":
return "tsql"
return dialect
def _unwrap_parameter(value: Any) -> Any:
wrapped = getattr(value, "value", value)
return None if wrapped is None else str(wrapped)
def _odbc_parameters(parameters: Any) -> "list[str | None] | None":
if parameters is None:
return None
if isinstance(parameters, Mapping):
return [_unwrap_parameter(value) for value in parameters.values()]
if isinstance(parameters, (list, tuple)):
if not parameters:
return None
return [_unwrap_parameter(value) for value in parameters]
return [_unwrap_parameter(parameters)]
def _reader_to_table(reader: Any) -> Any:
    """Materialise an Arrow reader (or table) into a ``pyarrow.Table``.

    A ``pyarrow.Table`` is returned as-is. Objects exposing ``read_all`` are
    drained through it. Anything else is iterated as a sequence of record
    batches, and an empty sequence produces an empty, column-less table.
    """
    ensure_pyarrow()
    import pyarrow as pa

    if isinstance(reader, pa.Table):
        return reader
    if hasattr(reader, "read_all"):
        return reader.read_all()
    collected = list(reader)
    # from_batches cannot infer a schema from zero batches, hence the explicit empty table.
    return pa.Table.from_batches(collected) if collected else pa.table({})
def _table_to_reader(table: Any, chunk_size: int) -> Any:
    """Wrap a table in a ``RecordBatchReader`` whose batches hold at most ``chunk_size`` rows."""
    ensure_pyarrow()
    import pyarrow as pa

    schema = table.schema
    batches = table.to_batches(max_chunksize=chunk_size)
    return pa.RecordBatchReader.from_batches(schema, batches)
# Import-time side effect: register this adapter's profile under the "arrow_odbc" key.
register_driver_profile("arrow_odbc", driver_profile)