"""DuckDB driver implementation."""
import contextlib
from typing import TYPE_CHECKING, Any, cast
from uuid import uuid4
import duckdb
from sqlspec.adapters.duckdb.core import (
apply_driver_features,
collect_rows,
create_mapped_exception,
default_statement_config,
driver_profile,
normalize_execute_parameters,
resolve_rowcount,
)
from sqlspec.adapters.duckdb.data_dictionary import DuckDBDataDictionary
from sqlspec.adapters.duckdb.type_converter import DuckDBOutputConverter
from sqlspec.core import SQL, StatementConfig, build_arrow_result_from_table, get_cache_config, register_driver_profile
from sqlspec.driver import SyncDriverAdapterBase
from sqlspec.exceptions import DatabaseConnectionError, SQLSpecError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.module_loader import ensure_pyarrow
if TYPE_CHECKING:
from sqlspec.adapters.duckdb._typing import DuckDBConnection
from sqlspec.builder import QueryBuilder
from sqlspec.core import ArrowResult, Statement, StatementFilter
from sqlspec.driver import ExecutionResult
from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry
from sqlspec.typing import ArrowReturnFormat, StatementParameters
from typing_extensions import Self
from sqlspec.adapters.duckdb._typing import DuckDBSessionContext
__all__ = ("DuckDBCursor", "DuckDBDriver", "DuckDBExceptionHandler", "DuckDBSessionContext")
logger = get_logger("sqlspec.adapters.duckdb")
_type_converter = DuckDBOutputConverter()
class DuckDBCursor:
    """Context manager that yields a DuckDB connection as its own cursor.

    DuckDB connections implement the cursor interface and preserve
    variable state. Handing the connection out directly avoids cursor
    overhead and fixes SET VARIABLE persistence.
    See: https://github.com/litestar-org/sqlspec/issues/341
    """

    __slots__ = ("connection",)

    def __init__(self, connection: "DuckDBConnection") -> None:
        self.connection = connection

    def __enter__(self) -> "DuckDBConnection":
        return self.connection

    def __exit__(self, *exc_info: Any) -> None:
        # No cleanup: the pool/session owns the connection lifecycle.
        return None
class DuckDBExceptionHandler:
    """Context manager for handling DuckDB database exceptions.

    Uses exception type and message-based detection to map DuckDB errors
    to specific SQLSpec exceptions for better error handling.

    Uses deferred exception pattern for mypyc compatibility: exceptions
    are stored in pending_exception rather than raised from __exit__
    to avoid ABI boundary violations with compiled code.
    """

    __slots__ = ("pending_exception",)

    def __init__(self) -> None:
        # Holds the mapped exception after __exit__ suppresses the original.
        self.pending_exception: Exception | None = None

    # "Self" must stay a string: it is imported only under TYPE_CHECKING and
    # this module has no `from __future__ import annotations`, so an unquoted
    # annotation would raise NameError at class-definition time.
    def __enter__(self) -> "Self":
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        _ = exc_tb
        if exc_type is None:
            return False
        # Map the DuckDB error to a SQLSpec exception and suppress the original;
        # the caller re-raises pending_exception after the block.
        self.pending_exception = create_mapped_exception(exc_type, exc_val)
        return True
class DuckDBDriver(SyncDriverAdapterBase):
"""Synchronous DuckDB database driver.
Provides SQL statement execution, transaction management, and result handling
for DuckDB databases. Supports multiple parameter styles including QMARK,
NUMERIC, and NAMED_DOLLAR formats.
The driver handles script execution, batch operations, and integrates with
the sqlspec.core modules for statement processing and caching.
"""
__slots__ = ("_data_dictionary",)
dialect = "duckdb"
def __init__(
    self,
    connection: "DuckDBConnection",
    statement_config: "StatementConfig | None" = None,
    driver_features: "dict[str, Any] | None" = None,
) -> None:
    """Initialize the driver with a connection and optional configuration.

    Args:
        connection: Live DuckDB connection to execute against.
        statement_config: Statement configuration; when omitted, a default is
            derived from the global cache configuration.
        driver_features: Optional driver feature flags merged into the config.
    """
    config = statement_config
    if config is None:
        caching_enabled = get_cache_config().compiled_cache_enabled
        config = default_statement_config.replace(enable_caching=caching_enabled)
    config = apply_driver_features(config, driver_features)
    super().__init__(connection=connection, statement_config=config, driver_features=driver_features)
    self._data_dictionary: DuckDBDataDictionary | None = None
# ─────────────────────────────────────────────────────────────────────────────
# CORE DISPATCH METHODS
# ─────────────────────────────────────────────────────────────────────────────
def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult":
    """Execute a single SQL statement and package its results.

    Compiles the statement, binds parameters, and distinguishes
    row-returning queries from data-modification operations.

    Args:
        cursor: DuckDB cursor object.
        statement: SQL statement to execute.

    Returns:
        ExecutionResult with execution metadata.
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
    cursor.execute(sql, normalize_execute_parameters(prepared_parameters))
    produces_rows = statement.returns_rows() or self._should_force_select(statement, cursor)
    if not produces_rows:
        # Data-modification path: only the affected rowcount is reported.
        return self.create_execution_result(cursor, rowcount_override=resolve_rowcount(cursor))
    raw_rows = cursor.fetchall()
    data, column_names = collect_rows(cast("list[Any] | None", raw_rows), cursor.description)
    row_format = "dict" if data and isinstance(data[0], dict) else "tuple"
    return self.create_execution_result(
        cursor,
        selected_data=data,
        column_names=column_names,
        data_row_count=len(data),
        is_select_result=True,
        row_format=row_format,
    )
def dispatch_execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult":
    """Execute SQL once per parameter set using DuckDB's executemany.

    Args:
        cursor: DuckDB cursor object.
        statement: SQL statement with multiple parameter sets.

    Returns:
        ExecutionResult with batch execution metadata.
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
    if not prepared_parameters:
        # Nothing to bind: report zero rows without touching the cursor.
        return self.create_execution_result(cursor, rowcount_override=0, is_many_result=True)
    parameter_sets = cast("list[Any]", prepared_parameters)
    cursor.executemany(sql, parameter_sets)
    if statement.is_modifying_operation():
        # DuckDB does not aggregate rowcounts for executemany; assume one row per set.
        row_count = len(parameter_sets)
    else:
        row_count = resolve_rowcount(cursor)
    return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True)
def dispatch_execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult":
    """Split a multi-statement script and run each statement in order.

    Args:
        cursor: DuckDB cursor object.
        statement: SQL statement with script content.

    Returns:
        ExecutionResult with script execution metadata.
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
    statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
    executed = 0
    last_result = None
    for single_statement in statements:
        # The same bound parameters are applied to every statement in the script.
        last_result = cursor.execute(single_statement, normalize_execute_parameters(prepared_parameters))
        executed += 1
    return self.create_execution_result(
        last_result, statement_count=len(statements), successful_statements=executed, is_script_result=True
    )
# ─────────────────────────────────────────────────────────────────────────────
# TRANSACTION MANAGEMENT
# ─────────────────────────────────────────────────────────────────────────────
def begin(self) -> None:
    """Begin a database transaction.

    Raises:
        SQLSpecError: If DuckDB rejects the BEGIN statement.
    """
    try:
        self.connection.execute("BEGIN TRANSACTION")
    except duckdb.Error as error:
        raise SQLSpecError(f"Failed to begin DuckDB transaction: {error}") from error
def commit(self) -> None:
    """Commit the current transaction.

    Raises:
        SQLSpecError: If DuckDB fails to commit.
    """
    try:
        self.connection.commit()
    except duckdb.Error as error:
        raise SQLSpecError(f"Failed to commit DuckDB transaction: {error}") from error
def rollback(self) -> None:
    """Rollback the current transaction.

    Raises:
        SQLSpecError: If DuckDB fails to roll back.
    """
    try:
        self.connection.rollback()
    except duckdb.Error as error:
        raise SQLSpecError(f"Failed to rollback DuckDB transaction: {error}") from error
def with_cursor(self, connection: "DuckDBConnection") -> "DuckDBCursor":
    """Wrap a DuckDB connection in a cursor context manager.

    Args:
        connection: DuckDB connection instance.

    Returns:
        DuckDBCursor context manager instance.
    """
    return DuckDBCursor(connection)
def handle_database_exceptions(self) -> "DuckDBExceptionHandler":
    """Create a handler that maps DuckDB errors to SQLSpec exceptions.

    Returns:
        Exception handler using the deferred-exception pattern for
        mypyc compatibility.
    """
    return DuckDBExceptionHandler()
# ─────────────────────────────────────────────────────────────────────────────
# ARROW API METHODS
# ─────────────────────────────────────────────────────────────────────────────
def select_to_arrow(
    self,
    statement: "Statement | QueryBuilder",
    /,
    *parameters: "StatementParameters | StatementFilter",
    statement_config: "StatementConfig | None" = None,
    return_format: "ArrowReturnFormat" = "table",
    native_only: bool = False,
    batch_size: int | None = None,
    arrow_schema: Any = None,
    **kwargs: Any,
) -> "ArrowResult":
    """Execute query and return results as Apache Arrow (DuckDB native path).

    DuckDB provides native Arrow support via cursor.arrow().
    This is the fastest path due to DuckDB's columnar architecture.

    Args:
        statement: SQL statement, string, or QueryBuilder
        *parameters: Query parameters or filters
        statement_config: Optional statement configuration override
        return_format: "table" for pyarrow.Table (default), "batch" for RecordBatch,
            "batches" for list of RecordBatch, "reader" for RecordBatchReader
        native_only: Ignored for DuckDB (always uses native path)
        batch_size: Batch size hint (for future streaming implementation)
        arrow_schema: Optional pyarrow.Schema for type casting

    Returns:
        ArrowResult with native Arrow data

    Raises:
        DatabaseConnectionError: If a cursor cannot be created.

    Example:
        >>> result = driver.select_to_arrow(
        ...     "SELECT * FROM users WHERE age > ?", 18
        ... )
        >>> df = result.to_pandas()  # Fast zero-copy conversion
    """
    ensure_pyarrow()
    # Prepare statement
    config = statement_config or self.statement_config
    prepared_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
    # Deferred-exception pattern: the handler stores mapped errors instead of
    # raising from __exit__, so we re-raise after the with-block below.
    exc_handler = self.handle_database_exceptions()
    arrow_result: ArrowResult | None = None
    # Execute query and get native Arrow
    with self.with_cursor(self.connection) as cursor, exc_handler:
        if cursor is None:
            msg = "Failed to create cursor"
            raise DatabaseConnectionError(msg)
        # Get compiled SQL and parameters
        sql, driver_params = self._get_compiled_sql(prepared_statement, config)
        # Execute query
        cursor.execute(sql, driver_params or ())
        # NOTE(review): cursor.arrow() is treated as returning a reader with
        # read_all(); some duckdb versions return a pyarrow.Table directly from
        # .arrow() — confirm against the pinned duckdb version.
        arrow_reader = cursor.arrow()
        arrow_table = arrow_reader.read_all()
        arrow_result = build_arrow_result_from_table(
            prepared_statement,
            arrow_table,
            return_format=return_format,
            batch_size=batch_size,
            arrow_schema=arrow_schema,
        )
    # Surface any error the handler captured inside the with-block.
    if exc_handler.pending_exception is not None:
        raise exc_handler.pending_exception from None
    if arrow_result is None:
        msg = "Unreachable"
        raise RuntimeError(msg)  # pragma: no cover
    return arrow_result
# ─────────────────────────────────────────────────────────────────────────────
# STORAGE API METHODS
# ─────────────────────────────────────────────────────────────────────────────
def select_to_storage(
    self,
    statement: "Statement | QueryBuilder | SQL | str",
    destination: "StorageDestination",
    /,
    *parameters: "StatementParameters | StatementFilter",
    statement_config: "StatementConfig | None" = None,
    partitioner: "dict[str, object] | None" = None,
    format_hint: "StorageFormat | None" = None,
    telemetry: "StorageTelemetry | None" = None,
    **kwargs: Any,
) -> "StorageBridgeJob":
    """Persist DuckDB query output to a storage backend using Arrow fast paths.

    Args:
        statement: Query to execute.
        destination: Storage destination for the Arrow output.
        *parameters: Query parameters or filters.
        statement_config: Optional statement configuration override.
        partitioner: Optional partition metadata attached to telemetry.
        format_hint: Optional storage format override.
        telemetry: Optional telemetry merged into the resulting job.
        **kwargs: Forwarded to select_to_arrow.

    Returns:
        StorageBridgeJob describing the completed write.
    """
    # Removed the dead `_ = kwargs` discard: kwargs is genuinely used below,
    # so marking it unused was contradictory.
    self._require_capability("arrow_export_enabled")
    arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
    sync_pipeline = self._storage_pipeline()
    telemetry_payload = self._write_result_to_storage_sync(
        arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline
    )
    self._attach_partition_telemetry(telemetry_payload, partitioner)
    return self._create_storage_job(telemetry_payload, telemetry)
def load_from_arrow(
    self,
    table: str,
    source: "ArrowResult | Any",
    *,
    partitioner: "dict[str, object] | None" = None,
    overwrite: bool = False,
    telemetry: "StorageTelemetry | None" = None,
) -> "StorageBridgeJob":
    """Load Arrow data into DuckDB using temporary table registration.

    Args:
        table: Target table name. NOTE(review): interpolated directly into
            SQL text below; callers must supply a trusted identifier —
            confirm upstream validation.
        source: Arrow data (ArrowResult or coercible object).
        partitioner: Optional partition metadata attached to telemetry.
        overwrite: When True, TRUNCATE the target table before inserting.
        telemetry: Optional telemetry merged into the resulting job.

    Returns:
        StorageBridgeJob describing the completed load.
    """
    self._require_capability("arrow_import_enabled")
    arrow_table = self._coerce_arrow_table(source)
    # Random suffix avoids view-name collisions on a shared connection.
    temp_view = f"_sqlspec_arrow_{uuid4().hex}"
    if overwrite:
        # Truncation happens before the INSERT; if the INSERT later fails,
        # the old rows are already gone (no compensating rollback here).
        self.connection.execute(f"TRUNCATE TABLE {table}")
    self.connection.register(temp_view, arrow_table)
    try:
        self.connection.execute(f"INSERT INTO {table} SELECT * FROM {temp_view}")
    finally:
        # Best-effort cleanup: always try to drop the temporary registration.
        with contextlib.suppress(Exception):
            self.connection.unregister(temp_view)
    telemetry_payload = self._build_ingest_telemetry(arrow_table)
    telemetry_payload["destination"] = table
    self._attach_partition_telemetry(telemetry_payload, partitioner)
    return self._create_storage_job(telemetry_payload, telemetry)
def load_from_storage(
    self,
    table: str,
    source: "StorageDestination",
    *,
    file_format: "StorageFormat",
    partitioner: "dict[str, object] | None" = None,
    overwrite: bool = False,
) -> "StorageBridgeJob":
    """Read an artifact from storage and load it into DuckDB.

    Args:
        table: Target table name.
        source: Storage location of the artifact.
        file_format: Format of the stored artifact.
        partitioner: Optional partition metadata attached to telemetry.
        overwrite: When True, the target table is truncated before loading.

    Returns:
        StorageBridgeJob describing the completed load.
    """
    arrow_table, inbound_telemetry = self._read_arrow_from_storage_sync(source, file_format=file_format)
    return self.load_from_arrow(
        table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound_telemetry
    )
# ─────────────────────────────────────────────────────────────────────────────
# UTILITY METHODS
# ─────────────────────────────────────────────────────────────────────────────
@property
def data_dictionary(self) -> "DuckDBDataDictionary":
    """Lazily created data dictionary for metadata queries.

    Returns:
        Data dictionary instance, constructed on first access and cached.
    """
    if self._data_dictionary is None:
        self._data_dictionary = DuckDBDataDictionary()
    return self._data_dictionary
# ─────────────────────────────────────────────────────────────────────────────
# PRIVATE / INTERNAL METHODS
# ─────────────────────────────────────────────────────────────────────────────
def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
    """Collect DuckDB rows for the direct execution path.

    Returns:
        Tuple of (rows, column names, row count).
    """
    rows, column_names = collect_rows(cast("list[Any] | None", fetched), cursor.description)
    return rows, column_names, len(rows)
def resolve_rowcount(self, cursor: Any) -> int:
    """Resolve the affected rowcount from a DuckDB cursor.

    Delegates to the shared core helper used by the direct execution path.
    """
    return resolve_rowcount(cursor)
def _connection_in_transaction(self) -> bool:
    """Report whether the connection is inside a transaction.

    DuckDB uses explicit BEGIN TRANSACTION and does not expose its
    transaction state, so this always reports False.
    """
    return False
# Register the DuckDB driver profile with the core registry at import time.
register_driver_profile("duckdb", driver_profile)
# SQL verbs considered data-modifying. NOTE(review): not referenced within
# this module — presumably consumed elsewhere in the package; confirm.
MODIFYING_OPERATIONS: "tuple[str, ...]" = ("INSERT", "UPDATE", "DELETE")