"""DuckDB driver implementation."""
import contextlib
import typing
from datetime import date, datetime
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Final, cast
from uuid import uuid4
import duckdb
from sqlspec.adapters.duckdb.data_dictionary import DuckDBSyncDataDictionary
from sqlspec.adapters.duckdb.type_converter import DuckDBTypeConverter
from sqlspec.core import (
SQL,
DriverParameterProfile,
ParameterStyle,
StatementConfig,
build_statement_config_from_profile,
get_cache_config,
register_driver_profile,
)
from sqlspec.driver import SyncDriverAdapterBase
from sqlspec.exceptions import (
CheckViolationError,
DatabaseConnectionError,
DataError,
ForeignKeyViolationError,
IntegrityError,
NotFoundError,
NotNullViolationError,
OperationalError,
SQLParsingError,
SQLSpecError,
UniqueViolationError,
)
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import to_json
from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter
if TYPE_CHECKING:
from contextlib import AbstractContextManager
from sqlspec.adapters.duckdb._types import DuckDBConnection
from sqlspec.builder import QueryBuilder
from sqlspec.core import ArrowResult, SQLResult, Statement, StatementFilter
from sqlspec.driver import ExecutionResult
from sqlspec.driver._sync import SyncDataDictionaryBase
from sqlspec.storage import (
StorageBridgeJob,
StorageDestination,
StorageFormat,
StorageTelemetry,
SyncStoragePipeline,
)
from sqlspec.typing import ArrowReturnFormat, StatementParameters
__all__ = (
"DuckDBCursor",
"DuckDBDriver",
"DuckDBExceptionHandler",
"build_duckdb_statement_config",
"duckdb_statement_config",
)
logger = get_logger("adapters.duckdb")
_TIME_TO_ISO = build_time_iso_converter()
_DECIMAL_TO_STRING = build_decimal_converter(mode="string")
_type_converter = DuckDBTypeConverter()
class DuckDBCursor:
"""Context manager for DuckDB cursor management."""
__slots__ = ("connection", "cursor")
def __init__(self, connection: "DuckDBConnection") -> None:
self.connection = connection
self.cursor: Any | None = None
def __enter__(self) -> Any:
self.cursor = self.connection.cursor()
return self.cursor
def __exit__(self, *_: Any) -> None:
if self.cursor is not None:
self.cursor.close()
class DuckDBExceptionHandler:
"""Context manager for handling DuckDB database exceptions.
Uses exception type and message-based detection to map DuckDB errors
    to specific SQLSpec exceptions for more precise error handling.
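    Example:
        A minimal usage sketch; ``connection`` here stands for any open
        DuckDB connection:

        >>> with DuckDBExceptionHandler():
        ...     connection.execute("INSERT INTO users VALUES (1)")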
"""
__slots__ = ()
def __enter__(self) -> None:
return None
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
_ = exc_tb
if exc_type is None:
return
self._map_duckdb_exception(exc_type, exc_val)
def _map_duckdb_exception(self, exc_type: Any, e: Any) -> None:
"""Map DuckDB exception to SQLSpec exception.
Uses exception type and message-based detection.
Args:
exc_type: Exception type
e: Exception instance
"""
error_msg = str(e).lower()
exc_name = exc_type.__name__ if hasattr(exc_type, "__name__") else str(exc_type)
if "constraintexception" in exc_name.lower():
self._handle_constraint_exception(e, error_msg)
elif "catalogexception" in exc_name.lower():
self._raise_not_found_error(e)
elif "parserexception" in exc_name.lower() or "binderexception" in exc_name.lower():
self._raise_parsing_error(e)
elif "ioexception" in exc_name.lower():
self._raise_operational_error(e)
elif "conversionexception" in exc_name.lower() or "type mismatch" in error_msg:
self._raise_data_error(e)
else:
self._raise_generic_error(e)
def _handle_constraint_exception(self, e: Any, error_msg: str) -> None:
"""Handle constraint exceptions using message-based detection.
Args:
e: Exception instance
error_msg: Lowercase error message
"""
if "unique" in error_msg or "duplicate" in error_msg:
self._raise_unique_violation(e)
elif "foreign key" in error_msg or "violates foreign key" in error_msg:
self._raise_foreign_key_violation(e)
elif "not null" in error_msg or "null value" in error_msg:
self._raise_not_null_violation(e)
elif "check constraint" in error_msg or "check condition" in error_msg:
self._raise_check_violation(e)
else:
self._raise_integrity_error(e)
def _raise_unique_violation(self, e: Any) -> None:
msg = f"DuckDB unique constraint violation: {e}"
raise UniqueViolationError(msg) from e
def _raise_foreign_key_violation(self, e: Any) -> None:
msg = f"DuckDB foreign key constraint violation: {e}"
raise ForeignKeyViolationError(msg) from e
def _raise_not_null_violation(self, e: Any) -> None:
msg = f"DuckDB not-null constraint violation: {e}"
raise NotNullViolationError(msg) from e
def _raise_check_violation(self, e: Any) -> None:
msg = f"DuckDB check constraint violation: {e}"
raise CheckViolationError(msg) from e
def _raise_integrity_error(self, e: Any) -> None:
msg = f"DuckDB integrity constraint violation: {e}"
raise IntegrityError(msg) from e
def _raise_not_found_error(self, e: Any) -> None:
msg = f"DuckDB catalog error: {e}"
raise NotFoundError(msg) from e
def _raise_parsing_error(self, e: Any) -> None:
msg = f"DuckDB SQL parsing error: {e}"
raise SQLParsingError(msg) from e
def _raise_operational_error(self, e: Any) -> None:
msg = f"DuckDB operational error: {e}"
raise OperationalError(msg) from e
def _raise_data_error(self, e: Any) -> None:
msg = f"DuckDB data error: {e}"
raise DataError(msg) from e
def _raise_generic_error(self, e: Any) -> None:
msg = f"DuckDB database error: {e}"
raise SQLSpecError(msg) from e
class DuckDBDriver(SyncDriverAdapterBase):
"""Synchronous DuckDB database driver.
Provides SQL statement execution, transaction management, and result handling
for DuckDB databases. Supports multiple parameter styles including QMARK,
NUMERIC, and NAMED_DOLLAR formats.
The driver handles script execution, batch operations, and integrates with
the sqlspec.core modules for statement processing and caching.
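    Example:
        A minimal construction sketch against an in-memory database:

        >>> import duckdb
        >>> driver = DuckDBDriver(connection=duckdb.connect(":memory:"))
        >>> arrow_result = driver.select_to_arrow("SELECT 42 AS answer")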
"""
__slots__ = ("_data_dictionary",)
dialect = "duckdb"
def __init__(
self,
connection: "DuckDBConnection",
statement_config: "StatementConfig | None" = None,
driver_features: "dict[str, Any] | None" = None,
) -> None:
if statement_config is None:
cache_config = get_cache_config()
updated_config = duckdb_statement_config.replace(
enable_caching=cache_config.compiled_cache_enabled,
enable_parsing=True,
enable_validation=True,
dialect="duckdb",
)
statement_config = updated_config
if driver_features:
param_config = statement_config.parameter_config
json_serializer = driver_features.get("json_serializer")
if json_serializer:
param_config = param_config.with_json_serializers(json_serializer, tuple_strategy="tuple")
enable_uuid_conversion = driver_features.get("enable_uuid_conversion", True)
if not enable_uuid_conversion:
type_converter = DuckDBTypeConverter(enable_uuid_conversion=enable_uuid_conversion)
type_coercion_map = dict(param_config.type_coercion_map)
type_coercion_map[str] = type_converter.convert_if_detected
param_config = param_config.replace(type_coercion_map=type_coercion_map)
if param_config is not statement_config.parameter_config:
statement_config = statement_config.replace(parameter_config=param_config)
super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
self._data_dictionary: SyncDataDictionaryBase | None = None
def with_cursor(self, connection: "DuckDBConnection") -> "DuckDBCursor":
"""Create context manager for DuckDB cursor.
Args:
connection: DuckDB connection instance
Returns:
DuckDBCursor context manager instance
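        Example:
            A standalone sketch; the driver itself pairs the cursor with
            ``handle_database_exceptions()``:

            >>> with driver.with_cursor(driver.connection) as cur:
            ...     cur.execute("SELECT 1")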
"""
return DuckDBCursor(connection)
def handle_database_exceptions(self) -> "AbstractContextManager[None]":
"""Handle database-specific exceptions and wrap them appropriately.
Returns:
Context manager that catches and converts DuckDB exceptions
"""
return DuckDBExceptionHandler()
def _try_special_handling(self, cursor: Any, statement: SQL) -> "SQLResult | None":
"""Handle DuckDB-specific special operations.
DuckDB does not require special operation handling, so this method
returns None to indicate standard execution should proceed.
Args:
cursor: DuckDB cursor object
statement: SQL statement to analyze
Returns:
None to indicate no special handling required
"""
_ = (cursor, statement)
return None
def _execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult":
"""Execute SQL script with statement splitting and parameter handling.
Parses multi-statement scripts and executes each statement sequentially
with the provided parameters.
Args:
cursor: DuckDB cursor object
statement: SQL statement with script content
Returns:
ExecutionResult with script execution metadata
"""
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
successful_count = 0
last_result = None
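        # Each split statement runs against the same prepared parameter set.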
for stmt in statements:
last_result = cursor.execute(stmt, prepared_parameters or ())
successful_count += 1
return self.create_execution_result(
last_result, statement_count=len(statements), successful_statements=successful_count, is_script_result=True
)
def _execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult":
"""Execute SQL with multiple parameter sets using batch processing.
Uses DuckDB's executemany method for batch operations and calculates
row counts for both data modification and query operations.
Args:
cursor: DuckDB cursor object
statement: SQL statement with multiple parameter sets
Returns:
ExecutionResult with batch execution metadata
"""
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
if prepared_parameters:
cursor.executemany(sql, prepared_parameters)
if statement.is_modifying_operation():
row_count = len(prepared_parameters)
else:
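                # Non-modifying statement: try to read a single-cell count, then fall back to cursor.rowcount.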
try:
result = cursor.fetchone()
row_count = int(result[0]) if result and isinstance(result, tuple) and len(result) == 1 else 0
except Exception:
row_count = max(cursor.rowcount, 0) if hasattr(cursor, "rowcount") else 0
else:
row_count = 0
return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True)
def _execute_statement(self, cursor: Any, statement: SQL) -> "ExecutionResult":
"""Execute single SQL statement with data handling.
Executes a SQL statement with parameter binding and processes the results.
Handles both data-returning queries and data modification operations.
Args:
cursor: DuckDB cursor object
statement: SQL statement to execute
Returns:
ExecutionResult with execution metadata
"""
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
cursor.execute(sql, prepared_parameters or ())
if statement.returns_rows():
fetched_data = cursor.fetchall()
column_names = [col[0] for col in cursor.description or []]
if fetched_data and isinstance(fetched_data[0], tuple):
dict_data = [dict(zip(column_names, row, strict=False)) for row in fetched_data]
else:
dict_data = fetched_data
return self.create_execution_result(
cursor,
selected_data=dict_data,
column_names=column_names,
data_row_count=len(dict_data),
is_select_result=True,
)
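        # DuckDB typically reports DML affected-row counts as a single-cell result row; fall back to cursor.rowcount.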
try:
result = cursor.fetchone()
row_count = int(result[0]) if result and isinstance(result, tuple) and len(result) == 1 else 0
except Exception:
row_count = max(cursor.rowcount, 0) if hasattr(cursor, "rowcount") else 0
return self.create_execution_result(cursor, rowcount_override=row_count)
def begin(self) -> None:
"""Begin a database transaction."""
try:
self.connection.execute("BEGIN TRANSACTION")
except duckdb.Error as e:
msg = f"Failed to begin DuckDB transaction: {e}"
raise SQLSpecError(msg) from e
def rollback(self) -> None:
"""Rollback the current transaction."""
try:
self.connection.rollback()
except duckdb.Error as e:
msg = f"Failed to rollback DuckDB transaction: {e}"
raise SQLSpecError(msg) from e
def commit(self) -> None:
"""Commit the current transaction."""
try:
self.connection.commit()
except duckdb.Error as e:
msg = f"Failed to commit DuckDB transaction: {e}"
raise SQLSpecError(msg) from e
@property
def data_dictionary(self) -> "SyncDataDictionaryBase":
"""Get the data dictionary for this driver.
Returns:
Data dictionary instance for metadata queries
"""
if self._data_dictionary is None:
self._data_dictionary = DuckDBSyncDataDictionary()
return self._data_dictionary
def select_to_arrow(
self,
statement: "Statement | QueryBuilder",
/,
*parameters: "StatementParameters | StatementFilter",
statement_config: "StatementConfig | None" = None,
return_format: "ArrowReturnFormat" = "table",
native_only: bool = False,
batch_size: int | None = None,
arrow_schema: Any = None,
**kwargs: Any,
) -> "ArrowResult":
"""Execute query and return results as Apache Arrow (DuckDB native path).
DuckDB provides native Arrow support via cursor.arrow().
This is the fastest path due to DuckDB's columnar architecture.
Args:
statement: SQL statement, string, or QueryBuilder
*parameters: Query parameters or filters
statement_config: Optional statement configuration override
return_format: "table" for pyarrow.Table (default), "batch" for RecordBatch
native_only: Ignored for DuckDB (always uses native path)
batch_size: Batch size hint (for future streaming implementation)
arrow_schema: Optional pyarrow.Schema for type casting
**kwargs: Additional keyword arguments
Returns:
ArrowResult with native Arrow data
Example:
>>> result = driver.select_to_arrow(
... "SELECT * FROM users WHERE age > ?", 18
... )
>>> df = result.to_pandas() # Fast zero-copy conversion
"""
from sqlspec.utils.module_loader import ensure_pyarrow
ensure_pyarrow()
import pyarrow as pa
from sqlspec.core import create_arrow_result
# Prepare statement
config = statement_config or self.statement_config
prepared_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs)
# Execute query and get native Arrow
with self.with_cursor(self.connection) as cursor, self.handle_database_exceptions():
if cursor is None:
msg = "Failed to create cursor"
raise DatabaseConnectionError(msg)
# Get compiled SQL and parameters
sql, driver_params = self._get_compiled_sql(prepared_statement, config)
# Execute query
cursor.execute(sql, driver_params or ())
# DuckDB native Arrow (zero-copy!)
arrow_reader = cursor.arrow()
arrow_table = arrow_reader.read_all()
# Apply schema casting if requested
if arrow_schema is not None:
arrow_table = arrow_table.cast(arrow_schema)
# Convert to batch if requested
if return_format == "batch":
batches = arrow_table.to_batches()
arrow_data: Any = batches[0] if batches else pa.RecordBatch.from_pydict({})
else:
arrow_data = arrow_table
# Create ArrowResult
return create_arrow_result(statement=prepared_statement, data=arrow_data, rows_affected=arrow_data.num_rows)
def select_to_storage(
self,
statement: "Statement | QueryBuilder | SQL | str",
destination: "StorageDestination",
/,
*parameters: "StatementParameters | StatementFilter",
statement_config: "StatementConfig | None" = None,
partitioner: "dict[str, Any] | None" = None,
format_hint: "StorageFormat | None" = None,
telemetry: "StorageTelemetry | None" = None,
**kwargs: Any,
) -> "StorageBridgeJob":
"""Persist DuckDB query output to a storage backend using Arrow fast paths."""
        self._require_capability("arrow_export_enabled")
        arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
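        # Write the Arrow result through the synchronous storage pipeline, collecting telemetry for the job.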
sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline())
telemetry_payload = self._write_result_to_storage_sync(
arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline
)
self._attach_partition_telemetry(telemetry_payload, partitioner)
return self._create_storage_job(telemetry_payload, telemetry)
def load_from_arrow(
self,
table: str,
source: "ArrowResult | Any",
*,
partitioner: "dict[str, Any] | None" = None,
overwrite: bool = False,
telemetry: "StorageTelemetry | None" = None,
) -> "StorageBridgeJob":
"""Load Arrow data into DuckDB using temporary table registration."""
self._require_capability("arrow_import_enabled")
arrow_table = self._coerce_arrow_table(source)
temp_view = f"_sqlspec_arrow_{uuid4().hex}"
if overwrite:
self.connection.execute(f"TRUNCATE TABLE {table}")
self.connection.register(temp_view, arrow_table)
try:
self.connection.execute(f"INSERT INTO {table} SELECT * FROM {temp_view}")
finally:
with contextlib.suppress(Exception):
self.connection.unregister(temp_view)
telemetry_payload = self._build_ingest_telemetry(arrow_table)
telemetry_payload["destination"] = table
self._attach_partition_telemetry(telemetry_payload, partitioner)
return self._create_storage_job(telemetry_payload, telemetry)
def load_from_storage(
self,
table: str,
source: "StorageDestination",
*,
file_format: "StorageFormat",
partitioner: "dict[str, Any] | None" = None,
overwrite: bool = False,
) -> "StorageBridgeJob":
"""Read an artifact from storage and load it into DuckDB."""
arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format)
return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound)
def _bool_to_int(value: bool) -> int:
return int(value)
def _build_duckdb_profile() -> DriverParameterProfile:
"""Create the DuckDB driver parameter profile."""
return DriverParameterProfile(
name="DuckDB",
default_style=ParameterStyle.QMARK,
supported_styles={ParameterStyle.QMARK, ParameterStyle.NUMERIC, ParameterStyle.NAMED_DOLLAR},
default_execution_style=ParameterStyle.QMARK,
supported_execution_styles={ParameterStyle.QMARK},
has_native_list_expansion=True,
preserve_parameter_format=True,
needs_static_script_compilation=False,
allow_mixed_parameter_styles=False,
preserve_original_params_for_many=False,
json_serializer_strategy="helper",
custom_type_coercions={
bool: _bool_to_int,
datetime: _TIME_TO_ISO,
date: _TIME_TO_ISO,
Decimal: _DECIMAL_TO_STRING,
},
default_dialect="duckdb",
)
_DUCKDB_PROFILE = _build_duckdb_profile()
register_driver_profile("duckdb", _DUCKDB_PROFILE)
def build_duckdb_statement_config(*, json_serializer: "typing.Callable[[Any], str] | None" = None) -> StatementConfig:
"""Construct the DuckDB statement configuration with optional JSON serializer."""
serializer = json_serializer or to_json
return build_statement_config_from_profile(
_DUCKDB_PROFILE, statement_overrides={"dialect": "duckdb"}, json_serializer=serializer
)
duckdb_statement_config = build_duckdb_statement_config()
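# A custom JSON serializer may be supplied instead of the default ``to_json``, e.g.
# ``build_duckdb_statement_config(json_serializer=my_json_dumps)``, where ``my_json_dumps``
# is any ``Callable[[Any], str]`` (the name is illustrative).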
MODIFYING_OPERATIONS: Final[tuple[str, ...]] = ("INSERT", "UPDATE", "DELETE")