"""ADBC driver implementation for Arrow Database Connectivity.
Provides database connectivity through ADBC with support for multiple
database dialects, parameter style conversion, and transaction management.
"""
import contextlib
from typing import TYPE_CHECKING, Any, Literal, cast
from typing_extensions import Self
from sqlspec.adapters.adbc._typing import AdbcSessionContext
from sqlspec.adapters.adbc.core import (
collect_rows,
create_mapped_exception,
detect_dialect,
driver_profile,
get_statement_config,
handle_postgres_rollback,
is_postgres_dialect,
normalize_postgres_empty_parameters,
normalize_script_rowcount,
prepare_postgres_parameters,
resolve_column_names,
resolve_dialect_name,
resolve_many_rowcount,
resolve_parameter_casts,
resolve_rowcount,
)
from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary
from sqlspec.core import SQL, StatementConfig, build_arrow_result_from_table, get_cache_config, register_driver_profile
from sqlspec.driver import SyncDriverAdapterBase
from sqlspec.exceptions import DatabaseConnectionError, SQLSpecError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.module_loader import ensure_pyarrow
from sqlspec.utils.serializers import to_json
if TYPE_CHECKING:
from collections.abc import Callable
from adbc_driver_manager.dbapi import Cursor
from sqlspec.adapters.adbc._typing import AdbcConnection
from sqlspec.builder import QueryBuilder
from sqlspec.core import ArrowResult, Statement, StatementFilter
from sqlspec.driver import ExecutionResult
from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry
from sqlspec.typing import ArrowReturnFormat, StatementParameters
__all__ = ("AdbcCursor", "AdbcDriver", "AdbcExceptionHandler", "AdbcSessionContext")
logger = get_logger("sqlspec.adapters.adbc")
class AdbcCursor:
    """Context manager that opens an ADBC cursor on enter and closes it on exit."""

    __slots__ = ("connection", "cursor")

    def __init__(self, connection: "AdbcConnection") -> None:
        self.connection = connection
        self.cursor: Cursor | None = None

    def __enter__(self) -> "Cursor":
        self.cursor = self.connection.cursor()
        return self.cursor

    def __exit__(self, *_: Any) -> None:
        cursor = self.cursor
        if cursor is None:
            return
        try:
            cursor.close()  # type: ignore[no-untyped-call]
        except Exception:
            # Close failures are deliberately ignored: the connection owns the
            # underlying resources and an error here is not actionable.
            pass
class AdbcExceptionHandler:
"""Context manager for handling ADBC database exceptions.
ADBC propagates underlying database errors. Exception mapping
depends on the specific ADBC driver being used.
Uses deferred exception pattern for mypyc compatibility: exceptions
are stored in pending_exception rather than raised from __exit__
to avoid ABI boundary violations with compiled code.
"""
__slots__ = ("pending_exception",)
def __init__(self) -> None:
self.pending_exception: Exception | None = None
def __enter__(self) -> Self:
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
_ = exc_tb
if exc_type is None:
return False
self.pending_exception = create_mapped_exception(exc_val)
return True
class AdbcDriver(SyncDriverAdapterBase):
    """ADBC driver for Arrow Database Connectivity.

    Provides database connectivity through ADBC with support for multiple
    database dialects, parameter style conversion, and transaction management.
    """

    __slots__ = (
        "_column_name_cache",
        "_data_dictionary",
        "_detected_dialect",
        "_dialect_name",
        "_is_postgres",
        "_json_serializer",
        "dialect",
    )
[docs]
def __init__(
self,
connection: "AdbcConnection",
statement_config: "StatementConfig | None" = None,
driver_features: "dict[str, Any] | None" = None,
) -> None:
self._detected_dialect = detect_dialect(connection, logger)
if statement_config is None:
base_config = get_statement_config(self._detected_dialect)
statement_config = base_config.replace(enable_caching=get_cache_config().compiled_cache_enabled)
super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
self.dialect = statement_config.dialect
self._dialect_name = resolve_dialect_name(self.dialect)
self._is_postgres = is_postgres_dialect(self._dialect_name)
self._json_serializer = cast("Callable[[Any], str]", self.driver_features.get("json_serializer", to_json))
self._data_dictionary: AdbcDataDictionary | None = None
self._column_name_cache: dict[int, tuple[Any, list[str]]] = {}
# ─────────────────────────────────────────────────────────────────────────────
# CORE DISPATCH METHODS
# ─────────────────────────────────────────────────────────────────────────────
[docs]
def dispatch_execute(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
"""Execute single SQL statement.
Args:
cursor: Database cursor
statement: SQL statement to execute
Returns:
Execution result with data or row count
"""
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
parameter_casts = resolve_parameter_casts(statement) if self._is_postgres else {}
try:
if self._is_postgres and parameter_casts:
execute_parameters = prepare_postgres_parameters(
prepared_parameters,
parameter_casts,
self.statement_config,
dialect=self._dialect_name,
json_serializer=self._json_serializer,
)
else:
execute_parameters = normalize_postgres_empty_parameters(self._dialect_name, prepared_parameters)
cursor.execute(sql, parameters=execute_parameters)
except Exception:
handle_postgres_rollback(self._dialect_name, cursor, logger)
raise
is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor)
if is_select_like:
fetched_data = cursor.fetchall()
column_names = self._resolve_column_names(cursor.description)
data, column_names = collect_rows(
cast("list[Any] | None", fetched_data), cursor.description, column_names=column_names
)
row_format = "dict" if data and isinstance(data[0], dict) else "tuple"
return self.create_execution_result(
cursor,
selected_data=data,
column_names=column_names,
data_row_count=len(data),
is_select_result=True,
row_format=row_format,
)
row_count = resolve_rowcount(cursor)
return self.create_execution_result(cursor, rowcount_override=row_count)
[docs]
def dispatch_execute_many(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
"""Execute SQL with multiple parameter sets.
Args:
cursor: Database cursor
statement: SQL statement to execute
Returns:
Execution result with row counts
"""
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
try:
if not prepared_parameters:
cursor._rowcount = 0 # pyright: ignore[reportPrivateUsage]
row_count = 0
elif isinstance(prepared_parameters, (list, tuple)) and prepared_parameters:
parameter_count = len(prepared_parameters)
if self._is_postgres:
parameter_casts = resolve_parameter_casts(statement)
processed_params: list[Any] | tuple[Any, ...]
if parameter_casts:
processed_params = [
prepare_postgres_parameters(
param_set,
parameter_casts,
self.statement_config,
dialect=self._dialect_name,
json_serializer=self._json_serializer,
)
for param_set in prepared_parameters
]
else:
processed_params = prepared_parameters
cursor.executemany(sql, processed_params)
row_count = resolve_many_rowcount(cursor, processed_params, fallback_count=parameter_count)
else:
cursor.executemany(sql, prepared_parameters)
row_count = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count)
else:
cursor.executemany(sql, prepared_parameters)
row_count = resolve_rowcount(cursor)
except Exception:
handle_postgres_rollback(self._dialect_name, cursor, logger)
raise
return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True)
[docs]
def dispatch_execute_script(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResult":
"""Execute SQL script containing multiple statements.
Args:
cursor: Database cursor
statement: SQL script to execute
Returns:
Execution result with statement counts
"""
prepared_parameters: Any | None = None
if statement.is_script:
sql = statement.raw_sql
else:
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
statements = self.split_script_statements(sql, self.statement_config, strip_trailing_semicolon=True)
successful_count = 0
last_rowcount = 0
try:
for stmt in statements:
if prepared_parameters:
postgres_compatible_params = normalize_postgres_empty_parameters(
self._dialect_name, prepared_parameters
)
cursor.execute(stmt, parameters=postgres_compatible_params)
else:
cursor.execute(stmt)
successful_count += 1
last_rowcount = normalize_script_rowcount(last_rowcount, cursor)
except Exception:
handle_postgres_rollback(self._dialect_name, cursor, logger)
raise
return self.create_execution_result(
cursor,
statement_count=len(statements),
successful_statements=successful_count,
rowcount_override=last_rowcount,
is_script_result=True,
)
# ─────────────────────────────────────────────────────────────────────────────
# TRANSACTION MANAGEMENT
# ─────────────────────────────────────────────────────────────────────────────
[docs]
def begin(self) -> None:
"""Begin database transaction."""
try:
with self.with_cursor(self.connection) as cursor:
cursor.execute("BEGIN")
except Exception as e:
msg = f"Failed to begin transaction: {e}"
raise SQLSpecError(msg) from e
[docs]
def commit(self) -> None:
"""Commit database transaction."""
try:
with self.with_cursor(self.connection) as cursor:
cursor.execute("COMMIT")
except Exception as e:
msg = f"Failed to commit transaction: {e}"
raise SQLSpecError(msg) from e
[docs]
def rollback(self) -> None:
"""Rollback database transaction."""
try:
with self.with_cursor(self.connection) as cursor:
cursor.execute("ROLLBACK")
except Exception as e:
msg = f"Failed to rollback transaction: {e}"
raise SQLSpecError(msg) from e
[docs]
def with_cursor(self, connection: "AdbcConnection") -> "AdbcCursor":
"""Create context manager for cursor.
Args:
connection: Database connection
Returns:
Cursor context manager
"""
return AdbcCursor(connection)
[docs]
def handle_database_exceptions(self) -> "AdbcExceptionHandler":
"""Handle database-specific exceptions and wrap them appropriately.
Returns:
Exception handler context manager
"""
return AdbcExceptionHandler()
# ─────────────────────────────────────────────────────────────────────────────
# ARROW API METHODS
# ─────────────────────────────────────────────────────────────────────────────
[docs]
def select_to_arrow(
self,
statement: "Statement | QueryBuilder",
/,
*parameters: "StatementParameters | StatementFilter",
statement_config: "StatementConfig | None" = None,
return_format: "ArrowReturnFormat" = "table",
native_only: bool = False,
batch_size: int | None = None,
arrow_schema: Any = None,
**kwargs: Any,
) -> "ArrowResult":
"""Execute query and return results as Apache Arrow (ADBC native path).
ADBC provides zero-copy Arrow support via cursor.fetch_arrow_table().
This is 5-10x faster than the conversion path for large datasets.
Args:
statement: SQL statement, string, or QueryBuilder
*parameters: Query parameters or filters
statement_config: Optional statement configuration override
return_format: "table" for pyarrow.Table (default), "batch" for RecordBatch,
"batches" for list of RecordBatch, "reader" for RecordBatchReader
native_only: Ignored for ADBC (always uses native path)
batch_size: Batch size hint (for future streaming implementation)
arrow_schema: Optional pyarrow.Schema for type casting
**kwargs: Additional keyword arguments
Returns:
ArrowResult with native Arrow data
Example:
>>> result = driver.select_to_arrow(
... "SELECT * FROM users WHERE age > $1", 18
... )
>>> df = result.to_pandas() # Fast zero-copy conversion
"""
ensure_pyarrow()
# Prepare statement
config = statement_config or self.statement_config
prepared_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
exc_handler = self.handle_database_exceptions()
arrow_result: ArrowResult | None = None
# Use ADBC cursor for native Arrow
with self.with_cursor(self.connection) as cursor, exc_handler:
if cursor is None:
msg = "Failed to create cursor"
raise DatabaseConnectionError(msg)
# Get compiled SQL and parameters
sql, driver_params = self._get_compiled_sql(prepared_statement, config)
# Execute query
cursor.execute(sql, driver_params or ())
# Fetch as Arrow table (zero-copy!)
arrow_table = cursor.fetch_arrow_table()
arrow_result = build_arrow_result_from_table(
prepared_statement,
arrow_table,
return_format=return_format,
batch_size=batch_size,
arrow_schema=arrow_schema,
)
if exc_handler.pending_exception is not None:
raise exc_handler.pending_exception from None
if arrow_result is None:
msg = "Unreachable"
raise RuntimeError(msg) # pragma: no cover
return arrow_result
# ─────────────────────────────────────────────────────────────────────────────
# STORAGE API METHODS
# ─────────────────────────────────────────────────────────────────────────────
[docs]
def select_to_storage(
self,
statement: "Statement | QueryBuilder | SQL | str",
destination: "StorageDestination",
/,
*parameters: "StatementParameters | StatementFilter",
statement_config: "StatementConfig | None" = None,
partitioner: "dict[str, object] | None" = None,
format_hint: "StorageFormat | None" = None,
telemetry: "StorageTelemetry | None" = None,
**kwargs: Any,
) -> "StorageBridgeJob":
"""Stream query results to storage via the Arrow fast path."""
_ = kwargs
self._require_capability("arrow_export_enabled")
arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
sync_pipeline = self._storage_pipeline()
telemetry_payload = self._write_result_to_storage_sync(
arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline
)
self._attach_partition_telemetry(telemetry_payload, partitioner)
return self._create_storage_job(telemetry_payload, telemetry)
[docs]
def load_from_arrow(
self,
table: str,
source: "ArrowResult | Any",
*,
partitioner: "dict[str, object] | None" = None,
overwrite: bool = False,
telemetry: "StorageTelemetry | None" = None,
) -> "StorageBridgeJob":
"""Ingest an Arrow payload directly through the ADBC cursor."""
self._require_capability("arrow_import_enabled")
arrow_table = self._coerce_arrow_table(source)
ingest_mode: Literal["append", "create", "replace", "create_append"]
ingest_mode = "replace" if overwrite else "create_append"
exc_handler = self.handle_database_exceptions()
with self.with_cursor(self.connection) as cursor, exc_handler:
cursor.adbc_ingest(table, arrow_table, mode=ingest_mode)
if exc_handler.pending_exception is not None:
raise exc_handler.pending_exception from None
telemetry_payload = self._build_ingest_telemetry(arrow_table)
telemetry_payload["destination"] = table
self._attach_partition_telemetry(telemetry_payload, partitioner)
return self._create_storage_job(telemetry_payload, telemetry)
[docs]
def load_from_storage(
self,
table: str,
source: "StorageDestination",
*,
file_format: "StorageFormat",
partitioner: "dict[str, object] | None" = None,
overwrite: bool = False,
) -> "StorageBridgeJob":
"""Read an artifact from storage and ingest it via ADBC."""
arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format)
return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound)
# ─────────────────────────────────────────────────────────────────────────────
# UTILITY METHODS
# ─────────────────────────────────────────────────────────────────────────────
@property
def data_dictionary(self) -> "AdbcDataDictionary":
"""Get the data dictionary for this driver.
Returns:
Data dictionary instance for metadata queries
"""
if self._data_dictionary is None:
self._data_dictionary = AdbcDataDictionary()
return self._data_dictionary
# ─────────────────────────────────────────────────────────────────────────────
# PRIVATE/INTERNAL METHODS
# ─────────────────────────────────────────────────────────────────────────────
[docs]
def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
"""Collect ADBC rows for the direct execution path."""
column_names = self._resolve_column_names(cursor.description)
data, column_names = collect_rows(
cast("list[Any] | None", fetched), cursor.description, column_names=column_names
)
return data, column_names, len(data)
[docs]
def resolve_rowcount(self, cursor: Any) -> int:
"""Resolve rowcount from ADBC cursor for the direct execution path."""
return resolve_rowcount(cursor)
def _connection_in_transaction(self) -> bool:
"""Check if connection is in transaction.
ADBC uses explicit BEGIN and does not expose reliable transaction state.
Returns:
False - ADBC requires explicit transaction management.
"""
return False
def _resolve_column_names(self, description: Any) -> list[str]:
return resolve_column_names(description, self._column_name_cache)
[docs]
def prepare_driver_parameters(
self,
parameters: Any,
statement_config: "StatementConfig",
is_many: bool = False,
prepared_statement: Any | None = None,
) -> Any:
"""Prepare parameters with cast-aware type coercion for ADBC.
For PostgreSQL, applies cast-aware parameter processing using metadata from the compiled statement.
This allows proper handling of JSONB casts and other type conversions.
Respects driver_features['enable_cast_detection'] configuration.
Args:
parameters: Parameters in any format
statement_config: Statement configuration
is_many: Whether this is for execute_many operation
prepared_statement: Prepared statement containing the original SQL statement
Returns:
Parameters with cast-aware type coercion applied
"""
enable_cast_detection = self.driver_features.get("enable_cast_detection", True)
if enable_cast_detection and prepared_statement and self._is_postgres and not is_many:
parameter_casts = resolve_parameter_casts(prepared_statement)
return prepare_postgres_parameters(
parameters,
parameter_casts,
statement_config,
dialect=self._dialect_name,
json_serializer=self._json_serializer,
)
return super().prepare_driver_parameters(parameters, statement_config, is_many, prepared_statement)
register_driver_profile("adbc", driver_profile)