"""DuckDB driver implementation."""
import contextlib
from typing import TYPE_CHECKING, Any, cast
from uuid import uuid4
import duckdb
from sqlspec.adapters.duckdb._typing import DuckDBCursor, DuckDBSessionContext
from sqlspec.adapters.duckdb.core import (
apply_driver_features,
collect_rows,
create_mapped_exception,
default_statement_config,
driver_profile,
normalize_execute_parameters,
resolve_rowcount,
)
from sqlspec.adapters.duckdb.data_dictionary import DuckDBDataDictionary
from sqlspec.adapters.duckdb.type_converter import DuckDBOutputConverter
from sqlspec.core import SQL, StatementConfig, build_arrow_result_from_table, get_cache_config, register_driver_profile
from sqlspec.driver import BaseSyncExceptionHandler, SyncDriverAdapterBase
from sqlspec.exceptions import DatabaseConnectionError, SQLSpecError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.module_loader import ensure_pyarrow
if TYPE_CHECKING:
from sqlspec.adapters.duckdb._typing import DuckDBConnection
from sqlspec.builder import QueryBuilder
from sqlspec.core import ArrowResult, Statement, StatementFilter
from sqlspec.driver import ExecutionResult
from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry
from sqlspec.typing import ArrowReturnFormat, StatementParameters
__all__ = ("DuckDBCursor", "DuckDBDriver", "DuckDBExceptionHandler", "DuckDBSessionContext")
# Module-level logger for the DuckDB adapter.
logger = get_logger("sqlspec.adapters.duckdb")
# Shared output converter instance. NOTE(review): not referenced in the
# visible portion of this module — presumably consumed elsewhere in the
# adapter; confirm before removing.
_type_converter = DuckDBOutputConverter()
class DuckDBExceptionHandler(BaseSyncExceptionHandler):
    """Context manager that maps DuckDB errors to SQLSpec exceptions.

    Detection is based on the exception type and message so DuckDB failures
    surface as specific SQLSpec exceptions for better error handling.

    The handler uses the deferred exception pattern for mypyc compatibility:
    the mapped exception is stored on ``pending_exception`` rather than being
    raised from ``__exit__``, avoiding ABI boundary violations with compiled
    code.
    """

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        has_error = exc_type is not None
        if has_error:
            # Defer: record the mapped exception instead of raising here.
            self.pending_exception = create_mapped_exception(exc_type, exc_val)
        return has_error
class DuckDBDriver(SyncDriverAdapterBase):
"""Synchronous DuckDB database driver.
Provides SQL statement execution, transaction management, and result handling
for DuckDB databases. Supports multiple parameter styles including QMARK,
NUMERIC, and NAMED_DOLLAR formats.
The driver handles script execution, batch operations, and integrates with
the sqlspec.core modules for statement processing and caching.
"""
__slots__ = ("_data_dictionary",)
dialect = "duckdb"
def __init__(
    self,
    connection: "DuckDBConnection",
    statement_config: "StatementConfig | None" = None,
    driver_features: "dict[str, Any] | None" = None,
) -> None:
    """Initialize the DuckDB driver.

    Args:
        connection: Open DuckDB connection the driver operates on.
        statement_config: Optional statement configuration. When omitted, the
            adapter default is used with caching toggled by the global cache
            configuration.
        driver_features: Optional adapter feature flags applied to the
            statement configuration.
    """
    if statement_config is None:
        # Inherit the process-wide compiled-statement cache setting.
        statement_config = default_statement_config.replace(
            enable_caching=get_cache_config().compiled_cache_enabled
        )
    statement_config = apply_driver_features(statement_config, driver_features)
    super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
    # Created lazily by the `data_dictionary` property.
    self._data_dictionary: DuckDBDataDictionary | None = None
# ─────────────────────────────────────────────────────────────────────────────
# CORE DISPATCH METHODS
# ─────────────────────────────────────────────────────────────────────────────
def dispatch_execute(self, cursor: "DuckDBConnection", statement: SQL) -> "ExecutionResult":
    """Execute a single SQL statement with data handling.

    Compiles the statement, binds parameters, and executes it. Rows are
    materialized for result-producing statements; for data-modification
    statements only the affected row count is resolved.

    Args:
        cursor: DuckDB cursor object.
        statement: SQL statement to execute.

    Returns:
        ExecutionResult with execution metadata.
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
    cursor.execute(sql, normalize_execute_parameters(prepared_parameters))
    # NOTE(review): _should_force_select lets the driver treat additional
    # statements as row-returning beyond statement.returns_rows() — its exact
    # criteria live in the base class; confirm there.
    is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor)
    if is_select_like:
        fetched_data = cursor.fetchall()
        data, column_names = collect_rows(cast("list[Any] | None", fetched_data), cursor.description)
        # collect_rows may yield dict rows or tuple rows; report which.
        row_format = "dict" if data and isinstance(data[0], dict) else "tuple"
        return self.create_execution_result(
            cursor,
            selected_data=data,
            column_names=column_names,
            data_row_count=len(data),
            is_select_result=True,
            row_format=row_format,
        )
    row_count = resolve_rowcount(cursor)
    return self.create_execution_result(cursor, rowcount_override=row_count)
def dispatch_execute_many(self, cursor: "DuckDBConnection", statement: SQL) -> "ExecutionResult":
    """Execute SQL with multiple parameter sets using batch processing.

    Uses DuckDB's ``executemany`` for batch operations and calculates row
    counts for both data-modification and query operations.

    Args:
        cursor: DuckDB cursor object.
        statement: SQL statement with multiple parameter sets.

    Returns:
        ExecutionResult with batch execution metadata.
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
    if prepared_parameters:
        parameter_sets = cast("list[Any]", prepared_parameters)
        cursor.executemany(sql, parameter_sets)
        # For modifying statements the number of parameter sets is used as the
        # row count; NOTE(review): presumably because DuckDB's rowcount is not
        # reliable after executemany — confirm against resolve_rowcount.
        row_count = len(parameter_sets) if statement.is_modifying_operation() else resolve_rowcount(cursor)
    else:
        # Nothing was executed without parameter sets.
        row_count = 0
    return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True)
def dispatch_execute_script(self, cursor: "DuckDBConnection", statement: SQL) -> "ExecutionResult":
    """Execute a SQL script with statement splitting and parameter handling.

    Parses multi-statement scripts and executes each statement sequentially.

    Args:
        cursor: DuckDB cursor object.
        statement: SQL statement with script content.

    Returns:
        ExecutionResult with script execution metadata.
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
    statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
    successful_count = 0
    last_result = None
    # The same prepared parameter set is applied to every statement in the
    # script; only the final statement's cursor result is retained.
    for stmt in statements:
        last_result = cursor.execute(stmt, normalize_execute_parameters(prepared_parameters))
        successful_count += 1
    return self.create_execution_result(
        last_result, statement_count=len(statements), successful_statements=successful_count, is_script_result=True
    )
# ─────────────────────────────────────────────────────────────────────────────
# TRANSACTION MANAGEMENT
# ─────────────────────────────────────────────────────────────────────────────
def begin(self) -> None:
    """Begin a database transaction.

    Raises:
        SQLSpecError: If DuckDB fails to start the transaction.
    """
    try:
        # DuckDB uses explicit transaction control; issue BEGIN directly.
        self.connection.execute("BEGIN TRANSACTION")
    except duckdb.Error as e:
        msg = f"Failed to begin DuckDB transaction: {e}"
        raise SQLSpecError(msg) from e
def commit(self) -> None:
    """Commit the current transaction.

    Raises:
        SQLSpecError: If DuckDB fails to commit the transaction.
    """
    try:
        self.connection.commit()
    except duckdb.Error as e:
        msg = f"Failed to commit DuckDB transaction: {e}"
        raise SQLSpecError(msg) from e
def rollback(self) -> None:
    """Rollback the current transaction.

    Raises:
        SQLSpecError: If DuckDB fails to roll back the transaction.
    """
    try:
        self.connection.rollback()
    except duckdb.Error as e:
        msg = f"Failed to rollback DuckDB transaction: {e}"
        raise SQLSpecError(msg) from e
def with_cursor(self, connection: "DuckDBConnection") -> "DuckDBCursor":
    """Create a context manager for a DuckDB cursor.

    Args:
        connection: DuckDB connection instance.

    Returns:
        DuckDBCursor context manager instance.
    """
    return DuckDBCursor(connection)
def handle_database_exceptions(self) -> "DuckDBExceptionHandler":
    """Create a handler that maps DuckDB exceptions to SQLSpec exceptions.

    Returns:
        Exception handler using the deferred exception pattern for mypyc
        compatibility (the mapped error is stored, not raised from __exit__).
    """
    return DuckDBExceptionHandler()
# ─────────────────────────────────────────────────────────────────────────────
# ARROW API METHODS
# ─────────────────────────────────────────────────────────────────────────────
def select_to_arrow(
    self,
    statement: "Statement | QueryBuilder",
    /,
    *parameters: "StatementParameters | StatementFilter",
    statement_config: "StatementConfig | None" = None,
    return_format: "ArrowReturnFormat" = "table",
    native_only: bool = False,
    batch_size: int | None = None,
    arrow_schema: Any = None,
    **kwargs: Any,
) -> "ArrowResult":
    """Execute query and return results as Apache Arrow (DuckDB native path).

    DuckDB provides native Arrow support via ``cursor.arrow()``; this is the
    fastest path due to DuckDB's columnar architecture.

    Args:
        statement: SQL statement, string, or QueryBuilder.
        *parameters: Query parameters or filters.
        statement_config: Optional statement configuration override.
        return_format: "table" for pyarrow.Table (default), "batch" for
            RecordBatch, "batches" for list of RecordBatch, "reader" for
            RecordBatchReader.
        native_only: Ignored for DuckDB (always uses native path).
        batch_size: Batch size hint (for future streaming implementation).
        arrow_schema: Optional pyarrow.Schema for type casting.
        **kwargs: Additional keyword arguments.

    Returns:
        ArrowResult with native Arrow data.

    Raises:
        DatabaseConnectionError: If a cursor could not be created.

    Example:
        >>> result = driver.select_to_arrow(
        ...     "SELECT * FROM users WHERE age > ?", 18
        ... )
        >>> df = result.to_pandas()  # Fast zero-copy conversion
    """
    ensure_pyarrow()
    config = statement_config or self.statement_config
    prepared_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
    exc_handler = self.handle_database_exceptions()
    arrow_result: ArrowResult | None = None
    with self.with_cursor(self.connection) as cursor, exc_handler:
        if cursor is None:
            msg = "Failed to create cursor"
            raise DatabaseConnectionError(msg)
        sql, driver_params = self._get_compiled_sql(prepared_statement, config)
        cursor.execute(sql, driver_params or ())
        # Native Arrow export; the reader's full contents are materialized
        # into a single table before wrapping in an ArrowResult.
        arrow_reader = cursor.arrow()
        arrow_table = arrow_reader.read_all()
        arrow_result = build_arrow_result_from_table(
            prepared_statement,
            arrow_table,
            return_format=return_format,
            batch_size=batch_size,
            arrow_schema=arrow_schema,
        )
    # Deferred-exception pattern (mypyc compatibility): surface the mapped
    # error only after the context managers have unwound.
    if exc_handler.pending_exception is not None:
        raise exc_handler.pending_exception from None
    if arrow_result is None:
        msg = "Unreachable"
        raise RuntimeError(msg)  # pragma: no cover
    return arrow_result
# ─────────────────────────────────────────────────────────────────────────────
# STORAGE API METHODS
# ─────────────────────────────────────────────────────────────────────────────
def select_to_storage(
    self,
    statement: "Statement | QueryBuilder | SQL | str",
    destination: "StorageDestination",
    /,
    *parameters: "StatementParameters | StatementFilter",
    statement_config: "StatementConfig | None" = None,
    partitioner: "dict[str, object] | None" = None,
    format_hint: "StorageFormat | None" = None,
    telemetry: "StorageTelemetry | None" = None,
    **kwargs: Any,
) -> "StorageBridgeJob":
    """Persist DuckDB query output to a storage backend using Arrow fast paths.

    Args:
        statement: SQL statement, string, or QueryBuilder to execute.
        destination: Storage destination for the exported artifact.
        *parameters: Query parameters or filters.
        statement_config: Optional statement configuration override.
        partitioner: Optional partitioning options recorded in telemetry.
        format_hint: Optional storage format override for the written artifact.
        telemetry: Optional caller-supplied telemetry to merge into the job.
        **kwargs: Additional keyword arguments forwarded to select_to_arrow.

    Returns:
        StorageBridgeJob describing the completed write.
    """
    # Fix: dropped the dead `_ = kwargs` marker — kwargs is genuinely used below.
    self._require_capability("arrow_export_enabled")
    arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
    sync_pipeline = self._storage_pipeline()
    telemetry_payload = self._write_result_to_storage_sync(
        arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline
    )
    self._attach_partition_telemetry(telemetry_payload, partitioner)
    return self._create_storage_job(telemetry_payload, telemetry)
def load_from_arrow(
    self,
    table: str,
    source: "ArrowResult | Any",
    *,
    partitioner: "dict[str, object] | None" = None,
    overwrite: bool = False,
    telemetry: "StorageTelemetry | None" = None,
) -> "StorageBridgeJob":
    """Load Arrow data into DuckDB using temporary table registration.

    Args:
        table: Target table name. NOTE: interpolated directly into SQL —
            identifiers cannot be bound as parameters, so callers must not
            pass untrusted table names.
        source: ArrowResult or Arrow-coercible data to ingest.
        partitioner: Optional partitioning options recorded in telemetry.
        overwrite: When True, TRUNCATE the target table before inserting.
        telemetry: Optional caller-supplied telemetry to merge into the job.

    Returns:
        StorageBridgeJob describing the completed load.
    """
    self._require_capability("arrow_import_enabled")
    arrow_table = self._coerce_arrow_table(source)
    # Collision-proof name for the transient Arrow registration.
    temp_view = f"_sqlspec_arrow_{uuid4().hex}"
    if overwrite:
        self.connection.execute(f"TRUNCATE TABLE {table}")
    self.connection.register(temp_view, arrow_table)
    try:
        self.connection.execute(f"INSERT INTO {table} SELECT * FROM {temp_view}")
    finally:
        # Always drop the temporary registration, even if the INSERT failed.
        with contextlib.suppress(Exception):
            self.connection.unregister(temp_view)
    telemetry_payload = self._build_ingest_telemetry(arrow_table)
    telemetry_payload["destination"] = table
    self._attach_partition_telemetry(telemetry_payload, partitioner)
    return self._create_storage_job(telemetry_payload, telemetry)
def load_from_storage(
    self,
    table: str,
    source: "StorageDestination",
    *,
    file_format: "StorageFormat",
    partitioner: "dict[str, object] | None" = None,
    overwrite: bool = False,
) -> "StorageBridgeJob":
    """Read an artifact from storage and load it into DuckDB.

    Args:
        table: Target table name (see load_from_arrow for identifier caveats).
        source: Storage location of the artifact to read.
        file_format: Format of the stored artifact.
        partitioner: Optional partitioning options recorded in telemetry.
        overwrite: When True, TRUNCATE the target table before inserting.

    Returns:
        StorageBridgeJob describing the completed load.
    """
    # Inbound read telemetry is threaded through to the resulting job.
    arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format)
    return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound)
# ─────────────────────────────────────────────────────────────────────────────
# UTILITY METHODS
# ─────────────────────────────────────────────────────────────────────────────
@property
def data_dictionary(self) -> "DuckDBDataDictionary":
    """Data dictionary for metadata queries, created on first access.

    Returns:
        Data dictionary instance for metadata queries.
    """
    dictionary = self._data_dictionary
    if dictionary is None:
        # Lazy initialization: build and cache on first use.
        dictionary = DuckDBDataDictionary()
        self._data_dictionary = dictionary
    return dictionary
# ─────────────────────────────────────────────────────────────────────────────
# PRIVATE / INTERNAL METHODS
# ─────────────────────────────────────────────────────────────────────────────
def collect_rows(self, cursor: "DuckDBConnection", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
    """Collect DuckDB rows for the direct execution path.

    Args:
        cursor: DuckDB cursor whose description names the result columns.
        fetched: Rows previously fetched from the cursor.

    Returns:
        Tuple of (row data, column names, row count).
    """
    data, column_names = collect_rows(cast("list[Any] | None", fetched), cursor.description)
    return data, column_names, len(data)
def resolve_rowcount(self, cursor: "DuckDBConnection") -> int:
    """Resolve the rowcount from a DuckDB cursor for the direct execution path.

    Args:
        cursor: DuckDB cursor to inspect.

    Returns:
        Number of affected rows as reported by the shared core helper.
    """
    return resolve_rowcount(cursor)
def _connection_in_transaction(self) -> bool:
    """Report whether the connection is inside a transaction.

    DuckDB relies on explicit BEGIN TRANSACTION and does not expose
    transaction state, so this always reports False.

    Returns:
        False — transactions must be managed explicitly.
    """
    return False
# Register this adapter's driver profile with the core registry at import time.
register_driver_profile("duckdb", driver_profile)
# Statement verbs treated as data-modifying. NOTE(review): not referenced in
# the visible portion of this module — confirm external usage before removal.
MODIFYING_OPERATIONS: "tuple[str, ...]" = ("INSERT", "UPDATE", "DELETE")