"""DuckDB connection pool with thread-local connections."""
import logging
import re
import threading
import time
import uuid
from contextlib import contextmanager, suppress
from typing import TYPE_CHECKING, Any, Final, cast
import duckdb
from typing_extensions import final
from sqlspec.adapters.duckdb._typing import DuckDBConnection
from sqlspec.utils.logging import POOL_LOGGER_NAME, get_logger, log_with_context
if TYPE_CHECKING:
from collections.abc import Callable, Generator
# Public names exported by this module.
__all__ = ("DuckDBConnectionPool",)
# Matches ASCII identifiers only: one leading letter followed by letters, digits, or underscores.
# Used to vet names before they are interpolated into SQL text.
_SQL_IDENTIFIER_RE: Final[re.Pattern[str]] = re.compile(r"^[A-Za-z][A-Za-z0-9_]*$")
# Logger shared by pool log records.
logger = get_logger(POOL_LOGGER_NAME)
# Value passed as the ``adapter`` field of structured log records emitted here.
_ADAPTER_NAME = "duckdb"
# NOTE(review): the next three constants are not referenced in the visible part of this
# module; presumably they are consumed by configuration code elsewhere - confirm.
DEFAULT_MIN_POOL: Final[int] = 1
DEFAULT_MAX_POOL: Final[int] = 4
POOL_TIMEOUT: Final[float] = 30.0
# Default for ``pool_recycle_seconds``: 86400 seconds (24 hours).
POOL_RECYCLE: Final[int] = 86400
# Default for ``health_check_interval``, in seconds.
HEALTH_CHECK_INTERVAL: Final[float] = 30.0
[docs]
@final
class DuckDBConnectionPool:
"""Thread-local connection manager for DuckDB.
Uses thread-local storage to ensure each thread gets its own DuckDB connection,
preventing the thread-safety issues that cause segmentation faults when
multiple cursors share the same connection concurrently.
This design trades traditional pooling for thread safety, which is essential
for DuckDB since connections and cursors are not thread-safe.
"""
__slots__ = (
"_connection_config",
"_extension_flags",
"_extensions",
"_health_check_interval",
"_is_memory_db",
"_lock",
"_on_connection_create",
"_pool_id",
"_recycle",
"_secrets",
"_thread_local",
)
[docs]
def __init__(
self,
connection_config: "dict[str, Any]",
pool_recycle_seconds: int = POOL_RECYCLE,
health_check_interval: float = HEALTH_CHECK_INTERVAL,
extensions: "list[dict[str, Any]] | None" = None,
extension_flags: "dict[str, Any] | None" = None,
secrets: "list[dict[str, Any]] | None" = None,
on_connection_create: "Callable[[DuckDBConnection], DuckDBConnection | None] | None" = None,
) -> None:
"""Initialize the thread-local connection manager.
Args:
connection_config: DuckDB connection configuration
pool_recycle_seconds: Connection recycle time in seconds
health_check_interval: Seconds of idle time before running health check
extensions: List of extensions to install/load
extension_flags: Connection-level SET statements applied after creation
secrets: List of secrets to create
on_connection_create: Callback executed when connection is created
"""
self._connection_config = connection_config
self._recycle = pool_recycle_seconds
self._health_check_interval = health_check_interval
self._extensions = extensions or []
self._extension_flags = extension_flags or {}
self._secrets = secrets or []
self._on_connection_create = on_connection_create
self._thread_local = threading.local()
self._lock = threading.RLock()
self._pool_id = str(uuid.uuid4())[:8]
# Track if this pool uses an in-memory database
# In-memory databases require connections to stay alive to preserve data
database = connection_config.get("database", "")
self._is_memory_db = database.startswith(":memory:") or database == ""
@property
def _database_name(self) -> str:
"""Get sanitized database name for logging."""
db = self._connection_config.get("database", "")
if db.startswith(":memory:") or db == "":
return ":memory:"
return str(db)
def _create_connection(self) -> DuckDBConnection:
"""Create a new DuckDB connection with extensions and secrets."""
connect_parameters = {}
config_dict = {}
for key, value in self._connection_config.items():
if key in {"database", "read_only"}:
connect_parameters[key] = value
elif key == "config" and isinstance(value, dict):
config_dict.update(value)
else:
config_dict[key] = value
if config_dict:
connect_parameters["config"] = config_dict
connection = duckdb.connect(**connect_parameters)
self._apply_extension_flags(connection)
for ext_config in self._extensions:
ext_name = ext_config.get("name")
if not ext_name:
continue
install_kwargs = {}
if "version" in ext_config:
install_kwargs["version"] = ext_config["version"]
if "repository" in ext_config:
install_kwargs["repository"] = ext_config["repository"]
if "repository_url" in ext_config:
install_kwargs["repository_url"] = ext_config["repository_url"]
if ext_config.get("force_install", False):
install_kwargs["force_install"] = True
if install_kwargs:
connection.install_extension(ext_name, **install_kwargs)
else:
connection.install_extension(ext_name)
connection.load_extension(ext_name)
for secret_config in self._secrets:
_create_secret(connection, secret_config)
if self._on_connection_create:
# Let a failing user hook surface its real error instead of silently returning a
# half-configured connection (mirrors the sqlite/aiosqlite pools).
self._on_connection_create(connection)
return connection
def _apply_extension_flags(self, connection: DuckDBConnection) -> None:
"""Apply connection-level extension flags via SET statements."""
if not self._extension_flags:
return
for key, value in self._extension_flags.items():
if not key or not key.replace("_", "").isalnum():
continue
normalized = self._normalize_flag_value(value)
try:
connection.execute(f"SET {key} = {normalized}")
except Exception as exc: # pragma: no cover
log_with_context(
logger,
logging.DEBUG,
"pool.flag.set.failed",
adapter=_ADAPTER_NAME,
pool_id=self._pool_id,
database=self._database_name,
flag=key,
error=str(exc),
)
@staticmethod
def _normalize_flag_value(value: Any) -> str:
"""Convert Python value to DuckDB SET literal."""
if isinstance(value, bool):
return "TRUE" if value else "FALSE"
if isinstance(value, (int, float)):
return str(value)
escaped = str(value).replace("'", "''")
return f"'{escaped}'"
def _get_thread_connection(self) -> DuckDBConnection:
"""Get or create a connection for the current thread.
Each thread gets its own dedicated DuckDB connection to prevent
thread-safety issues with concurrent cursor operations.
"""
thread_state = self._thread_local.__dict__
if "connection" not in thread_state:
self._thread_local.connection = self._create_connection()
self._thread_local.created_at = time.time()
self._thread_local.last_used = time.time()
return cast("DuckDBConnection", self._thread_local.connection)
if self._recycle > 0 and time.time() - self._thread_local.created_at > self._recycle:
with suppress(Exception):
self._thread_local.connection.close()
self._thread_local.connection = self._create_connection()
self._thread_local.created_at = time.time()
self._thread_local.last_used = time.time()
return cast("DuckDBConnection", self._thread_local.connection)
idle_time = time.time() - thread_state.get("last_used", 0)
if idle_time > self._health_check_interval and not self._is_connection_alive(self._thread_local.connection):
log_with_context(
logger,
logging.DEBUG,
"pool.connection.recycle",
adapter=_ADAPTER_NAME,
pool_id=self._pool_id,
database=self._database_name,
idle_seconds=round(idle_time, 1),
reason="failed_health_check",
)
with suppress(Exception):
self._thread_local.connection.close()
self._thread_local.connection = self._create_connection()
self._thread_local.created_at = time.time()
self._thread_local.last_used = time.time()
return cast("DuckDBConnection", self._thread_local.connection)
def _close_thread_connection(self) -> None:
"""Close the connection for the current thread."""
thread_state = self._thread_local.__dict__
if "connection" in thread_state:
with suppress(Exception):
self._thread_local.connection.close()
del self._thread_local.connection
if "created_at" in thread_state:
del self._thread_local.created_at
if "last_used" in thread_state:
del self._thread_local.last_used
def _is_connection_alive(self, connection: DuckDBConnection) -> bool:
"""Check if a connection is still alive and usable.
Args:
connection: Connection to check
Returns:
True if connection is alive, False otherwise
"""
try:
cursor = connection.cursor()
cursor.close()
except Exception:
return False
return True
[docs]
@contextmanager
def get_connection(self) -> "Generator[DuckDBConnection, None, None]":
"""Get a thread-local connection.
Each thread gets its own dedicated DuckDB connection to prevent
thread-safety issues with concurrent cursor operations.
For file-based databases, the connection is closed when the context
manager exits to release DuckDB's file lock, allowing subsequent
connections with different configurations.
For in-memory databases, connections are kept alive to preserve data,
as in-memory data is lost when the last connection closes.
Yields:
DuckDBConnection: A thread-local connection.
"""
connection = self._get_thread_connection()
try:
yield connection
except Exception:
self._close_thread_connection()
raise
else:
# Only close connection for file-based databases to release file locks
# In-memory databases need connections to stay alive to preserve data
if not self._is_memory_db:
self._close_thread_connection()
[docs]
def close(self) -> None:
"""Close the thread-local connection if it exists."""
self._close_thread_connection()
[docs]
def size(self) -> int:
"""Get current pool size (always 1 for thread-local)."""
return 1 if "connection" in self._thread_local.__dict__ else 0
[docs]
def checked_out(self) -> int:
"""Get number of checked out connections (always 0 for thread-local)."""
return 0
[docs]
def acquire(self) -> DuckDBConnection:
"""Acquire a thread-local connection.
Each thread gets its own dedicated DuckDB connection to prevent
thread-safety issues with concurrent cursor operations.
Returns:
DuckDBConnection: A thread-local connection
"""
return self._get_thread_connection()
[docs]
def release(self, connection: DuckDBConnection) -> None:
"""Release a connection (no-op for thread-local connections).
Args:
connection: The connection to release (ignored)
"""
def _validate_sql_identifier(value: str, field_name: str) -> None:
    """Reject values that are unsafe to splice into SQL as a bare identifier.

    Args:
        value: Candidate identifier.
        field_name: Name of the configuration field, used in the error message.

    Raises:
        ValueError: If *value* does not match the module's identifier pattern.
    """
    if _SQL_IDENTIFIER_RE.fullmatch(value) is not None:
        return
    msg = (
        f"Invalid SQL identifier for {field_name!r}: {value!r}. "
        "Must start with a letter and contain only letters, digits, and underscores."
    )
    raise ValueError(msg)
def _create_secret(connection: DuckDBConnection, secret_config: dict[str, Any]) -> None:
    """Create one secret on *connection* and confirm it is visible afterwards.

    Entries lacking a truthy ``name`` or ``secret_type`` are skipped silently.
    Both values are validated as SQL identifiers before any SQL is built.
    """
    name = secret_config.get("name")
    kind = secret_config.get("secret_type")
    if not name or not kind:
        return

    for identifier, label in ((name, "secret_name"), (kind, "secret_type")):
        _validate_sql_identifier(identifier, label)

    connection.execute(_build_secret_sql(secret_config, name, kind))
    _verify_secret(connection, secret_config, name, kind)
def _build_secret_sql(secret_config: dict[str, Any], secret_name: str, secret_type: str) -> str:
    """Render the CREATE SECRET statement described by *secret_config*.

    Clauses are emitted in this order: ``TYPE``, optional ``PROVIDER``, one clause
    per entry of the ``value`` mapping, and optional ``SCOPE``. The statement uses
    ``CREATE PERSISTENT SECRET`` when ``persistent`` is truthy.

    Raises:
        TypeError: If ``value`` is present, truthy, and not a dictionary.
        ValueError: If the provider or a value key is not a valid SQL identifier.
    """
    clauses = [f"TYPE {secret_type}"]

    provider = secret_config.get("provider")
    if provider:
        _validate_sql_identifier(str(provider), "secret_provider")
        clauses.append(f"PROVIDER {provider}")

    options = secret_config.get("value") or {}
    if not isinstance(options, dict):
        msg = "DuckDB secret value must be a dictionary"
        raise TypeError(msg)
    clauses.extend(
        f"{_format_secret_key(option)} {_format_secret_literal(setting)}" for option, setting in options.items()
    )

    scope = secret_config.get("scope")
    if scope is not None:
        clauses.append(f"SCOPE {_format_secret_literal(scope)}")

    if secret_config.get("persistent", False):
        statement = "CREATE PERSISTENT SECRET"
    else:
        statement = "CREATE SECRET"
    joined = ",\n ".join(clauses)
    return f"{statement} {secret_name} (\n {joined}\n)"
def _format_secret_key(key: Any) -> str:
    """Return *key* as an upper-cased SQL keyword after identifier validation.

    Raises:
        ValueError: If the stringified key is not a valid SQL identifier.
    """
    keyword = str(key)
    _validate_sql_identifier(keyword, "secret_value_key")
    return keyword.upper()
def _format_secret_literal(value: Any) -> str:
if isinstance(value, bool):
return "true" if value else "false"
if isinstance(value, (int, float)):
return str(value)
escaped = str(value).replace("'", "''")
return f"'{escaped}'"
def _verify_secret(
    connection: DuckDBConnection, secret_config: dict[str, Any], secret_name: str, secret_type: str
) -> None:
    """Check that the secret just created is listed with the expected attributes.

    Looks the secret up by name in ``duckdb_secrets()`` and compares its name, type
    (case-insensitively), persistence flag, and - when a scope was configured - that
    the configured scope is among the reported scopes.

    Raises:
        RuntimeError: If the secret is not listed or any attribute differs.
    """
    lookup = connection.execute(
        "SELECT name, type, scope, persistent FROM duckdb_secrets() WHERE name = ?", (secret_name,)
    )
    row = lookup.fetchone()
    if not row:
        msg = f"DuckDB secret {secret_name!r} was not visible after creation"
        raise RuntimeError(msg)

    actual_name, actual_type, actual_scope, actual_persistent = row
    wants_persistent = bool(secret_config.get("persistent", False))
    wanted_scope = secret_config.get("scope")
    reported_scopes = list(actual_scope or [])

    matches = (
        actual_name == secret_name
        and str(actual_type).lower() == secret_type.lower()
        and bool(actual_persistent) == wants_persistent
        and (wanted_scope is None or wanted_scope in reported_scopes)
    )
    if not matches:
        msg = f"DuckDB secret {secret_name!r} verification failed"
        raise RuntimeError(msg)