Source code for sqlspec.adapters.duckdb.pool

"""DuckDB connection pool with thread-local connections."""

import logging
import threading
import time
import uuid
from contextlib import contextmanager, suppress
from typing import TYPE_CHECKING, Any, Final, cast

import duckdb

from sqlspec.adapters.duckdb._typing import DuckDBConnection
from sqlspec.utils.logging import POOL_LOGGER_NAME, get_logger, log_with_context

if TYPE_CHECKING:
    from collections.abc import Callable, Generator


# Logger shared by the pool code; every log event below is emitted through it.
logger = get_logger(POOL_LOGGER_NAME)
# Value reported as the ``adapter`` field of each structured log event.
_ADAPTER_NAME = "duckdb"

# NOTE(review): DEFAULT_MIN_POOL, DEFAULT_MAX_POOL and POOL_TIMEOUT are not
# referenced by the pool class in this module; presumably they are consumed
# by the adapter configuration elsewhere - confirm before changing them.
DEFAULT_MIN_POOL: Final[int] = 1
DEFAULT_MAX_POOL: Final[int] = 4
POOL_TIMEOUT: Final[float] = 30.0
# Default for ``pool_recycle_seconds``: connection age (seconds) after which
# a thread's connection is replaced. 86400 seconds is 24 hours.
POOL_RECYCLE: Final[int] = 86400
# Default for ``health_check_interval``: idle seconds after which a
# connection is health-checked before being handed out again.
HEALTH_CHECK_INTERVAL: Final[float] = 30.0

# Public API of this module.
__all__ = ("DuckDBConnectionPool",)


class DuckDBConnectionPool:
    """Thread-local connection manager for DuckDB.

    Uses thread-local storage to ensure each thread gets its own DuckDB
    connection, preventing the thread-safety issues that cause segmentation
    faults when multiple cursors share the same connection concurrently.

    This design trades traditional pooling for thread safety, which is
    essential for DuckDB since connections and cursors are not thread-safe.
    """

    __slots__ = (
        "_connection_config",
        "_connection_times",
        "_created_connections",
        "_extension_flags",
        "_extensions",
        "_health_check_interval",
        "_is_memory_db",
        "_lock",
        "_on_connection_create",
        "_pool_id",
        "_recycle",
        "_secrets",
        "_thread_local",
    )

    def __init__(
        self,
        connection_config: "dict[str, Any]",
        pool_recycle_seconds: int = POOL_RECYCLE,
        health_check_interval: float = HEALTH_CHECK_INTERVAL,
        extensions: "list[dict[str, Any]] | None" = None,
        extension_flags: "dict[str, Any] | None" = None,
        secrets: "list[dict[str, Any]] | None" = None,
        on_connection_create: "Callable[[DuckDBConnection], DuckDBConnection | None] | None" = None,
    ) -> None:
        """Initialize the thread-local connection manager.

        Args:
            connection_config: DuckDB connection configuration. The
                ``database`` and ``read_only`` keys are passed to
                ``duckdb.connect`` directly; every other key is forwarded
                inside its ``config`` dictionary.
            pool_recycle_seconds: Connection recycle time in seconds. A value
                of ``0`` or less disables age-based recycling.
            health_check_interval: Seconds of idle time before running health check
            extensions: List of extensions to install/load
            extension_flags: Connection-level SET statements applied after creation
            secrets: List of secrets to create
            on_connection_create: Callback executed when connection is created
        """
        self._connection_config = connection_config
        self._recycle = pool_recycle_seconds
        self._health_check_interval = health_check_interval
        self._extensions = extensions or []
        self._extension_flags = extension_flags or {}
        self._secrets = secrets or []
        self._on_connection_create = on_connection_create
        self._thread_local = threading.local()
        self._lock = threading.RLock()
        self._created_connections = 0
        # Creation timestamp of every currently open connection, keyed by id().
        self._connection_times: dict[int, float] = {}
        self._pool_id = str(uuid.uuid4())[:8]

        # Track if this pool uses an in-memory database
        # In-memory databases require connections to stay alive to preserve data
        self._is_memory_db = self._is_memory_database(connection_config.get("database", ""))

    @staticmethod
    def _is_memory_database(database: Any) -> bool:
        """Return True when ``database`` designates an in-memory database.

        The value is stringified first so path-like objects do not break the
        ``:memory:`` prefix test.
        """
        name = str(database or "")
        return name == "" or name.startswith(":memory:")

    @staticmethod
    def _escape_literal(value: Any) -> str:
        """Return ``value`` as text with single quotes doubled for SQL literals."""
        return str(value).replace("'", "''")

    @property
    def _database_name(self) -> str:
        """Get sanitized database name for logging."""
        db = self._connection_config.get("database", "")
        if self._is_memory_database(db):
            return ":memory:"
        return str(db)

    def _create_connection(self) -> DuckDBConnection:
        """Create a new DuckDB connection with extensions and secrets.

        Extension flags, extensions, secrets and the ``on_connection_create``
        callback are all best-effort: a failure in any of them is logged at
        DEBUG level and does not prevent the connection from being returned.

        Returns:
            The newly opened connection, already registered in the pool's
            bookkeeping.
        """
        connect_parameters: dict[str, Any] = {}
        config_dict: dict[str, Any] = {}
        for key, value in self._connection_config.items():
            if key in {"database", "read_only"}:
                connect_parameters[key] = value
            else:
                config_dict[key] = value
        if config_dict:
            connect_parameters["config"] = config_dict

        connection = duckdb.connect(**connect_parameters)

        self._apply_extension_flags(connection)
        self._load_extensions(connection)
        self._create_secrets(connection)

        if self._on_connection_create:
            try:
                self._on_connection_create(connection)
            except Exception as exc:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "pool.connection.callback.failed",
                    adapter=_ADAPTER_NAME,
                    pool_id=self._pool_id,
                    database=self._database_name,
                    error=str(exc),
                )

        with self._lock:
            self._created_connections += 1
            self._connection_times[id(connection)] = time.time()

        return connection

    def _load_extensions(self, connection: DuckDBConnection) -> None:
        """Install (when install options are given) and load each configured extension."""
        for ext_config in self._extensions:
            ext_name = ext_config.get("name")
            if not ext_name:
                continue

            install_kwargs: dict[str, Any] = {
                option: ext_config[option] for option in ("version", "repository") if option in ext_config
            }
            if ext_config.get("force_install", False):
                install_kwargs["force_install"] = True

            try:
                # Installation only happens when at least one install option is set.
                if install_kwargs:
                    connection.install_extension(ext_name, **install_kwargs)
                connection.load_extension(ext_name)
            except Exception as e:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "pool.extension.load.failed",
                    adapter=_ADAPTER_NAME,
                    pool_id=self._pool_id,
                    database=self._database_name,
                    extension=ext_name,
                    error=str(e),
                )

    def _create_secrets(self, connection: DuckDBConnection) -> None:
        """Issue a CREATE SECRET statement for each fully specified secret.

        Entries missing ``secret_type``, ``name`` or ``value`` are skipped.
        """
        for secret_config in self._secrets:
            secret_type = secret_config.get("secret_type")
            secret_name = secret_config.get("name")
            secret_value = secret_config.get("value")
            if not (secret_type and secret_name and secret_value):
                continue

            # NOTE(review): DuckDB's documented form is ``KEY 'value'``; this
            # emits ``'key' = 'value'`` - confirm the statement parses.
            value_string = ", ".join(
                f"'{key}' = '{self._escape_literal(value)}'" for key, value in secret_value.items()
            )

            scope_clause = ""
            if "scope" in secret_config:
                # Escape quotes exactly as is done for the secret values.
                scope_clause = f" SCOPE '{self._escape_literal(secret_config['scope'])}'"

            sql = f"""
                CREATE SECRET {secret_name} (
                    TYPE {secret_type},
                    {value_string}
                ){scope_clause}
            """
            try:
                connection.execute(sql)
            except Exception as exc:
                # Only the exception type is logged: the message may echo the
                # statement text, which contains the secret values.
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "pool.secret.create.failed",
                    adapter=_ADAPTER_NAME,
                    pool_id=self._pool_id,
                    database=self._database_name,
                    secret=secret_name,
                    error=type(exc).__name__,
                )

    def _apply_extension_flags(self, connection: DuckDBConnection) -> None:
        """Apply connection-level extension flags via SET statements."""
        if not self._extension_flags:
            return

        for key, value in self._extension_flags.items():
            # Only identifier-like names are interpolated into the statement.
            if not key or not key.replace("_", "").isalnum():
                continue

            normalized = self._normalize_flag_value(value)
            try:
                connection.execute(f"SET {key} = {normalized}")
            except Exception as exc:  # pragma: no cover - best-effort guard
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "pool.flag.set.failed",
                    adapter=_ADAPTER_NAME,
                    pool_id=self._pool_id,
                    database=self._database_name,
                    flag=key,
                    error=str(exc),
                )

    @staticmethod
    def _normalize_flag_value(value: Any) -> str:
        """Convert Python value to DuckDB SET literal."""
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return "TRUE" if value else "FALSE"
        if isinstance(value, (int, float)):
            return str(value)
        escaped = str(value).replace("'", "''")
        return f"'{escaped}'"

    def _discard_connection(self, connection: DuckDBConnection) -> None:
        """Close ``connection`` (ignoring errors) and drop its bookkeeping entry."""
        with suppress(Exception):
            connection.close()
        with self._lock:
            self._connection_times.pop(id(connection), None)

    def _open_thread_connection(self) -> DuckDBConnection:
        """Create a connection and store it as the current thread's connection."""
        connection = self._create_connection()
        now = time.time()
        self._thread_local.connection = connection
        self._thread_local.created_at = now
        self._thread_local.last_used = now
        return connection

    def _get_thread_connection(self) -> DuckDBConnection:
        """Get or create a connection for the current thread.

        Each thread gets its own dedicated DuckDB connection to prevent
        thread-safety issues with concurrent cursor operations.

        The existing connection is replaced when it is older than the recycle
        time, or when it has been idle longer than the health-check interval
        and fails the health check.
        """
        thread_state = self._thread_local.__dict__
        if "connection" not in thread_state:
            return self._open_thread_connection()

        now = time.time()
        if self._recycle > 0 and now - self._thread_local.created_at > self._recycle:
            # Clear thread state first so a failed re-create never leaves a
            # closed connection behind.
            self._close_thread_connection()
            return self._open_thread_connection()

        connection = cast("DuckDBConnection", self._thread_local.connection)
        idle_time = now - thread_state.get("last_used", 0)
        if idle_time > self._health_check_interval and not self._is_connection_alive(connection):
            log_with_context(
                logger,
                logging.DEBUG,
                "pool.connection.recycle",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                database=self._database_name,
                idle_seconds=round(idle_time, 1),
                reason="failed_health_check",
            )
            self._close_thread_connection()
            return self._open_thread_connection()

        # Refresh on every hand-out so idle time measures time since last
        # use rather than time since creation.
        self._thread_local.last_used = time.time()
        return connection

    def _close_thread_connection(self) -> None:
        """Close the connection for the current thread."""
        thread_state = self._thread_local.__dict__
        if "connection" in thread_state:
            self._discard_connection(self._thread_local.connection)
            del self._thread_local.connection
        for attr in ("created_at", "last_used"):
            if attr in thread_state:
                delattr(self._thread_local, attr)

    def _is_connection_alive(self, connection: DuckDBConnection) -> bool:
        """Check if a connection is still alive and usable.

        Args:
            connection: Connection to check

        Returns:
            True if a cursor could be opened and closed on the connection,
            False if doing so raised.
        """
        try:
            cursor = connection.cursor()
            cursor.close()
        except Exception:
            return False
        return True

    @contextmanager
    def get_connection(self) -> "Generator[DuckDBConnection, None, None]":
        """Get a thread-local connection.

        Each thread gets its own dedicated DuckDB connection to prevent
        thread-safety issues with concurrent cursor operations.

        For file-based databases, the connection is closed when the context
        manager exits to release DuckDB's file lock, allowing subsequent
        connections with different configurations.

        For in-memory databases, connections are kept alive to preserve data,
        as in-memory data is lost when the last connection closes. If the
        managed block raises, the connection is closed for both kinds of
        database and the exception is re-raised.

        Yields:
            DuckDBConnection: A thread-local connection.
        """
        connection = self._get_thread_connection()
        try:
            yield connection
        except Exception:
            self._close_thread_connection()
            raise
        else:
            # Only close connection for file-based databases to release file locks
            # In-memory databases need connections to stay alive to preserve data
            if not self._is_memory_db:
                self._close_thread_connection()

    def close(self) -> None:
        """Close the thread-local connection if it exists.

        Only the calling thread's connection is affected.
        """
        self._close_thread_connection()

    def size(self) -> int:
        """Get current pool size as seen by the calling thread.

        Returns:
            1 if the calling thread currently holds a connection, otherwise 0.
        """
        return 1 if "connection" in self._thread_local.__dict__ else 0

    def checked_out(self) -> int:
        """Get number of checked out connections (always 0 for thread-local)."""
        return 0

    def acquire(self) -> DuckDBConnection:
        """Acquire a thread-local connection.

        Each thread gets its own dedicated DuckDB connection to prevent
        thread-safety issues with concurrent cursor operations.

        Returns:
            DuckDBConnection: A thread-local connection
        """
        return self._get_thread_connection()

    def release(self, connection: DuckDBConnection) -> None:
        """Release a connection (no-op for thread-local connections).

        Args:
            connection: The connection to release (ignored)
        """