"""ADBC driver implementation for Arrow Database Connectivity.
Provides database connectivity through ADBC with support for multiple
database dialects, parameter style conversion, and transaction management.
"""
from typing import TYPE_CHECKING, Any, Literal, cast
from sqlspec.adapters.adbc._typing import AdbcCursor, AdbcSessionContext
from sqlspec.adapters.adbc.core import (
collect_rows,
create_mapped_exception,
detect_dialect,
driver_profile,
get_statement_config,
handle_postgres_rollback,
is_postgres_dialect,
normalize_postgres_empty_parameters,
normalize_script_rowcount,
prepare_postgres_parameters,
resolve_column_names,
resolve_dialect_name,
resolve_many_rowcount,
resolve_parameter_casts,
resolve_rowcount,
)
from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary
from sqlspec.core import SQL, StatementConfig, build_arrow_result_from_table, get_cache_config, register_driver_profile
from sqlspec.driver import BaseSyncExceptionHandler, SyncDriverAdapterBase
from sqlspec.exceptions import DatabaseConnectionError, SQLSpecError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.module_loader import ensure_pyarrow
from sqlspec.utils.serializers import to_json
if TYPE_CHECKING:
from collections.abc import Callable
from sqlspec.adapters.adbc._typing import AdbcConnection, AdbcRawCursor
from sqlspec.builder import QueryBuilder
from sqlspec.core import ArrowResult, Statement, StatementFilter
from sqlspec.driver import ExecutionResult
from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry
from sqlspec.typing import ArrowReturnFormat, StatementParameters
__all__ = ("AdbcCursor", "AdbcDriver", "AdbcExceptionHandler", "AdbcSessionContext")
logger = get_logger("sqlspec.adapters.adbc")
class AdbcExceptionHandler(BaseSyncExceptionHandler):
    """Context manager that maps ADBC database exceptions.

    ADBC surfaces the underlying database driver's errors directly;
    mapping to SQLSpec exception types is delegated to
    ``create_mapped_exception`` and varies per ADBC driver.

    The handler uses the deferred-exception pattern for mypyc
    compatibility: the mapped exception is stashed on
    ``pending_exception`` instead of being raised out of ``__exit__``,
    which would cross a compiled-code ABI boundary.
    """

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        # No exception occurred: nothing to defer, let __exit__ proceed normally.
        if exc_type is not None:
            self.pending_exception = create_mapped_exception(exc_val)
            return True
        return False
class AdbcDriver(SyncDriverAdapterBase):
    """ADBC driver for Arrow Database Connectivity.

    Provides database connectivity through ADBC with support for multiple
    database dialects, parameter style conversion, and transaction management.
    """

    __slots__ = (
        "_column_name_cache",
        "_data_dictionary",
        "_detected_dialect",
        "_dialect_name",
        "_is_postgres",
        "_json_serializer",
        "dialect",
    )

    def __init__(
        self,
        connection: "AdbcConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Initialize the driver, detecting the dialect from the connection.

        Args:
            connection: Live ADBC connection.
            statement_config: Optional statement configuration; when None a
                dialect-appropriate default is built and caching is enabled
                per the global cache config.
            driver_features: Optional feature flags (e.g. ``json_serializer``,
                ``enable_cast_detection``).
        """
        self._detected_dialect = detect_dialect(connection, logger)
        if statement_config is None:
            base_config = get_statement_config(self._detected_dialect)
            statement_config = base_config.replace(enable_caching=get_cache_config().compiled_cache_enabled)
        super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
        self.dialect = statement_config.dialect
        self._dialect_name = resolve_dialect_name(self.dialect)
        # PostgreSQL gets special parameter handling (casts, empty-parameter
        # normalization, rollback-on-error) throughout the dispatch methods.
        self._is_postgres = is_postgres_dialect(self._dialect_name)
        self._json_serializer = cast("Callable[[Any], str]", self.driver_features.get("json_serializer", to_json))
        self._data_dictionary: AdbcDataDictionary | None = None
        # Keyed by id(description) to avoid re-resolving column names for
        # repeated cursor descriptions; see resolve_column_names.
        self._column_name_cache: dict[int, tuple[Any, list[str]]] = {}

    # ─────────────────────────────────────────────────────────────────────────────
    # CORE DISPATCH METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def dispatch_execute(self, cursor: "AdbcRawCursor", statement: SQL) -> "ExecutionResult":
        """Execute single SQL statement.

        Args:
            cursor: Database cursor
            statement: SQL statement to execute

        Returns:
            Execution result with data or row count
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        parameter_casts = resolve_parameter_casts(statement) if self._is_postgres else {}
        try:
            if self._is_postgres and parameter_casts:
                # Cast-aware coercion (e.g. JSONB) requires statement metadata.
                execute_parameters = prepare_postgres_parameters(
                    prepared_parameters,
                    parameter_casts,
                    self.statement_config,
                    dialect=self._dialect_name,
                    json_serializer=self._json_serializer,
                )
            else:
                execute_parameters = normalize_postgres_empty_parameters(self._dialect_name, prepared_parameters)
            cursor.execute(sql, parameters=execute_parameters)
        except Exception:
            # PostgreSQL leaves the transaction aborted after an error;
            # roll back so subsequent statements are not rejected.
            handle_postgres_rollback(self._dialect_name, cursor, logger)
            raise
        is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor)
        if is_select_like:
            fetched_data = cursor.fetchall()
            column_names = self._resolve_column_names(cursor.description)
            data, column_names = collect_rows(
                cast("list[Any] | None", fetched_data), cursor.description, column_names=column_names
            )
            row_format = "dict" if data and isinstance(data[0], dict) else "tuple"
            return self.create_execution_result(
                cursor,
                selected_data=data,
                column_names=column_names,
                data_row_count=len(data),
                is_select_result=True,
                row_format=row_format,
            )
        row_count = resolve_rowcount(cursor)
        return self.create_execution_result(cursor, rowcount_override=row_count)

    def dispatch_execute_many(self, cursor: "AdbcRawCursor", statement: SQL) -> "ExecutionResult":
        """Execute SQL with multiple parameter sets.

        Args:
            cursor: Database cursor
            statement: SQL statement to execute

        Returns:
            Execution result with row counts
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        try:
            if not prepared_parameters:
                # No parameter sets: skip execution entirely and report zero
                # rows; _rowcount is forced so downstream reads are consistent.
                cursor._rowcount = 0  # pyright: ignore[reportPrivateUsage]
                row_count = 0
            elif isinstance(prepared_parameters, (list, tuple)) and prepared_parameters:
                parameter_count = len(prepared_parameters)
                if self._is_postgres:
                    parameter_casts = resolve_parameter_casts(statement)
                    processed_params: list[Any] | tuple[Any, ...]
                    if parameter_casts:
                        # Apply cast-aware coercion to each parameter set.
                        processed_params = [
                            prepare_postgres_parameters(
                                param_set,
                                parameter_casts,
                                self.statement_config,
                                dialect=self._dialect_name,
                                json_serializer=self._json_serializer,
                            )
                            for param_set in prepared_parameters
                        ]
                    else:
                        processed_params = prepared_parameters
                    cursor.executemany(sql, processed_params)
                    row_count = resolve_many_rowcount(cursor, processed_params, fallback_count=parameter_count)
                else:
                    cursor.executemany(sql, prepared_parameters)
                    row_count = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count)
            else:
                # Non-sequence parameters: delegate to the driver as-is.
                cursor.executemany(sql, prepared_parameters)
                row_count = resolve_rowcount(cursor)
        except Exception:
            handle_postgres_rollback(self._dialect_name, cursor, logger)
            raise
        return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True)

    def dispatch_execute_script(self, cursor: "AdbcRawCursor", statement: "SQL") -> "ExecutionResult":
        """Execute SQL script containing multiple statements.

        Args:
            cursor: Database cursor
            statement: SQL script to execute

        Returns:
            Execution result with statement counts
        """
        prepared_parameters: Any | None = None
        if statement.is_script:
            # Scripts are executed verbatim; no parameter compilation.
            sql = statement.raw_sql
        else:
            sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        statements = self.split_script_statements(sql, self.statement_config, strip_trailing_semicolon=True)
        successful_count = 0
        last_rowcount = 0
        try:
            for stmt in statements:
                if prepared_parameters:
                    postgres_compatible_params = normalize_postgres_empty_parameters(
                        self._dialect_name, prepared_parameters
                    )
                    cursor.execute(stmt, parameters=postgres_compatible_params)
                else:
                    cursor.execute(stmt)
                successful_count += 1
                last_rowcount = normalize_script_rowcount(last_rowcount, cursor)
        except Exception:
            handle_postgres_rollback(self._dialect_name, cursor, logger)
            raise
        return self.create_execution_result(
            cursor,
            statement_count=len(statements),
            successful_statements=successful_count,
            rowcount_override=last_rowcount,
            is_script_result=True,
        )

    # ─────────────────────────────────────────────────────────────────────────────
    # TRANSACTION MANAGEMENT
    # ─────────────────────────────────────────────────────────────────────────────

    def begin(self) -> None:
        """Begin database transaction.

        Raises:
            SQLSpecError: If the BEGIN statement fails.
        """
        try:
            with self.with_cursor(self.connection) as cursor:
                cursor.execute("BEGIN")
        except Exception as e:
            msg = f"Failed to begin transaction: {e}"
            raise SQLSpecError(msg) from e

    def commit(self) -> None:
        """Commit database transaction.

        Raises:
            SQLSpecError: If the COMMIT statement fails.
        """
        try:
            with self.with_cursor(self.connection) as cursor:
                cursor.execute("COMMIT")
        except Exception as e:
            msg = f"Failed to commit transaction: {e}"
            raise SQLSpecError(msg) from e

    def rollback(self) -> None:
        """Rollback database transaction.

        Raises:
            SQLSpecError: If the ROLLBACK statement fails.
        """
        try:
            with self.with_cursor(self.connection) as cursor:
                cursor.execute("ROLLBACK")
        except Exception as e:
            msg = f"Failed to rollback transaction: {e}"
            raise SQLSpecError(msg) from e

    def with_cursor(self, connection: "AdbcConnection") -> "AdbcCursor":
        """Create context manager for cursor.

        Args:
            connection: Database connection

        Returns:
            Cursor context manager
        """
        return AdbcCursor(connection)

    def handle_database_exceptions(self) -> "AdbcExceptionHandler":
        """Handle database-specific exceptions and wrap them appropriately.

        Returns:
            Exception handler context manager
        """
        return AdbcExceptionHandler()

    # ─────────────────────────────────────────────────────────────────────────────
    # ARROW API METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def select_to_arrow(
        self,
        statement: "Statement | QueryBuilder",
        /,
        *parameters: "StatementParameters | StatementFilter",
        statement_config: "StatementConfig | None" = None,
        return_format: "ArrowReturnFormat" = "table",
        native_only: bool = False,
        batch_size: int | None = None,
        arrow_schema: Any = None,
        **kwargs: Any,
    ) -> "ArrowResult":
        """Execute query and return results as Apache Arrow (ADBC native path).

        ADBC provides zero-copy Arrow support via cursor.fetch_arrow_table().
        This is 5-10x faster than the conversion path for large datasets.

        Args:
            statement: SQL statement, string, or QueryBuilder
            *parameters: Query parameters or filters
            statement_config: Optional statement configuration override
            return_format: "table" for pyarrow.Table (default), "batch" for RecordBatch,
                "batches" for list of RecordBatch, "reader" for RecordBatchReader
            native_only: Ignored for ADBC (always uses native path)
            batch_size: Batch size hint (for future streaming implementation)
            arrow_schema: Optional pyarrow.Schema for type casting
            **kwargs: Additional keyword arguments

        Returns:
            ArrowResult with native Arrow data

        Example:
            >>> result = driver.select_to_arrow(
            ...     "SELECT * FROM users WHERE age > $1", 18
            ... )
            >>> df = result.to_pandas()  # Fast zero-copy conversion
        """
        ensure_pyarrow()
        config = statement_config or self.statement_config
        prepared_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
        exc_handler = self.handle_database_exceptions()
        arrow_result: ArrowResult | None = None
        # Deferred-exception pattern: exc_handler stores any mapped exception
        # rather than raising from __exit__ (mypyc compatibility).
        with self.with_cursor(self.connection) as cursor, exc_handler:
            if cursor is None:
                msg = "Failed to create cursor"
                raise DatabaseConnectionError(msg)
            sql, driver_params = self._get_compiled_sql(prepared_statement, config)
            cursor.execute(sql, driver_params or ())
            # Fetch as Arrow table (zero-copy!)
            arrow_table = cursor.fetch_arrow_table()
            arrow_result = build_arrow_result_from_table(
                prepared_statement,
                arrow_table,
                return_format=return_format,
                batch_size=batch_size,
                arrow_schema=arrow_schema,
            )
        if exc_handler.pending_exception is not None:
            raise exc_handler.pending_exception from None
        if arrow_result is None:
            msg = "Unreachable"
            raise RuntimeError(msg)  # pragma: no cover
        return arrow_result

    # ─────────────────────────────────────────────────────────────────────────────
    # STORAGE API METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def select_to_storage(
        self,
        statement: "Statement | QueryBuilder | SQL | str",
        destination: "StorageDestination",
        /,
        *parameters: "StatementParameters | StatementFilter",
        statement_config: "StatementConfig | None" = None,
        partitioner: "dict[str, object] | None" = None,
        format_hint: "StorageFormat | None" = None,
        telemetry: "StorageTelemetry | None" = None,
        **kwargs: Any,
    ) -> "StorageBridgeJob":
        """Stream query results to storage via the Arrow fast path.

        Args:
            statement: SQL statement, string, or QueryBuilder
            destination: Target storage location
            *parameters: Query parameters or filters
            statement_config: Optional statement configuration override
            partitioner: Optional partitioning metadata for telemetry
            format_hint: Optional storage format override
            telemetry: Optional caller-supplied telemetry to merge
            **kwargs: Forwarded to select_to_arrow

        Returns:
            Storage bridge job describing the write
        """
        self._require_capability("arrow_export_enabled")
        arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
        sync_pipeline = self._storage_pipeline()
        telemetry_payload = self._write_result_to_storage_sync(
            arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline
        )
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    def load_from_arrow(
        self,
        table: str,
        source: "ArrowResult | Any",
        *,
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
        telemetry: "StorageTelemetry | None" = None,
    ) -> "StorageBridgeJob":
        """Ingest an Arrow payload directly through the ADBC cursor.

        Args:
            table: Destination table name
            source: ArrowResult or Arrow-coercible payload
            partitioner: Optional partitioning metadata for telemetry
            overwrite: When True, replace the table instead of appending
            telemetry: Optional caller-supplied telemetry to merge

        Returns:
            Storage bridge job describing the ingest
        """
        self._require_capability("arrow_import_enabled")
        arrow_table = self._coerce_arrow_table(source)
        ingest_mode: Literal["append", "create", "replace", "create_append"]
        ingest_mode = "replace" if overwrite else "create_append"
        exc_handler = self.handle_database_exceptions()
        with self.with_cursor(self.connection) as cursor, exc_handler:
            cursor.adbc_ingest(table, arrow_table, mode=ingest_mode)
        if exc_handler.pending_exception is not None:
            raise exc_handler.pending_exception from None
        telemetry_payload = self._build_ingest_telemetry(arrow_table)
        telemetry_payload["destination"] = table
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    def load_from_storage(
        self,
        table: str,
        source: "StorageDestination",
        *,
        file_format: "StorageFormat",
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
    ) -> "StorageBridgeJob":
        """Read an artifact from storage and ingest it via ADBC.

        Args:
            table: Destination table name
            source: Storage location to read from
            file_format: Format of the stored artifact
            partitioner: Optional partitioning metadata for telemetry
            overwrite: When True, replace the table instead of appending

        Returns:
            Storage bridge job describing the ingest
        """
        arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format)
        return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound)

    # ─────────────────────────────────────────────────────────────────────────────
    # UTILITY METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    @property
    def data_dictionary(self) -> "AdbcDataDictionary":
        """Get the data dictionary for this driver.

        Returns:
            Data dictionary instance for metadata queries
        """
        # Lazily constructed; reused across calls.
        if self._data_dictionary is None:
            self._data_dictionary = AdbcDataDictionary()
        return self._data_dictionary

    # ─────────────────────────────────────────────────────────────────────────────
    # PRIVATE/INTERNAL METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def collect_rows(self, cursor: "AdbcRawCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Collect ADBC rows for the direct execution path.

        Args:
            cursor: Database cursor (provides description for column names)
            fetched: Rows already fetched from the cursor

        Returns:
            Tuple of (rows, column names, row count)
        """
        column_names = self._resolve_column_names(cursor.description)
        data, column_names = collect_rows(
            cast("list[Any] | None", fetched), cursor.description, column_names=column_names
        )
        return data, column_names, len(data)

    def resolve_rowcount(self, cursor: "AdbcRawCursor") -> int:
        """Resolve rowcount from ADBC cursor for the direct execution path."""
        return resolve_rowcount(cursor)

    def _connection_in_transaction(self) -> bool:
        """Check if connection is in transaction.

        ADBC uses explicit BEGIN and does not expose reliable transaction state.

        Returns:
            False - ADBC requires explicit transaction management.
        """
        return False

    def _resolve_column_names(self, description: Any) -> list[str]:
        """Resolve column names from a cursor description, using the per-driver cache."""
        return resolve_column_names(description, self._column_name_cache)

    def prepare_driver_parameters(
        self,
        parameters: Any,
        statement_config: "StatementConfig",
        is_many: bool = False,
        prepared_statement: Any | None = None,
    ) -> Any:
        """Prepare parameters with cast-aware type coercion for ADBC.

        For PostgreSQL, applies cast-aware parameter processing using metadata from the compiled statement.
        This allows proper handling of JSONB casts and other type conversions.
        Respects driver_features['enable_cast_detection'] configuration.

        Args:
            parameters: Parameters in any format
            statement_config: Statement configuration
            is_many: Whether this is for execute_many operation
            prepared_statement: Prepared statement containing the original SQL statement

        Returns:
            Parameters with cast-aware type coercion applied
        """
        enable_cast_detection = self.driver_features.get("enable_cast_detection", True)
        if enable_cast_detection and prepared_statement and self._is_postgres and not is_many:
            parameter_casts = resolve_parameter_casts(prepared_statement)
            return prepare_postgres_parameters(
                parameters,
                parameter_casts,
                statement_config,
                dialect=self._dialect_name,
                json_serializer=self._json_serializer,
            )
        return super().prepare_driver_parameters(parameters, statement_config, is_many, prepared_statement)
register_driver_profile("adbc", driver_profile)