Source code for sqlspec.adapters.sqlite.pool

"""SQLite database configuration with thread-local connections."""

import contextlib
import logging
import sqlite3
import threading
import time
import uuid
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, TypedDict, cast

from typing_extensions import NotRequired

from sqlspec.adapters.sqlite._typing import SqliteConnection
from sqlspec.utils.logging import POOL_LOGGER_NAME, get_logger, log_with_context

if TYPE_CHECKING:
    from collections.abc import Callable, Generator


class SqliteConnectionParams(TypedDict):
    """SQLite connection parameters.

    Every key is optional (``NotRequired``). The key names mirror the keyword
    arguments of the standard-library ``sqlite3.connect()``.
    """

    # Path of the database file, or ":memory:" for a private in-memory database.
    database: NotRequired[str]
    # Seconds to wait on a locked database before raising OperationalError
    # (sqlite3 default: 5.0).
    timeout: NotRequired[float]
    # Bit flags (sqlite3.PARSE_DECLTYPES / sqlite3.PARSE_COLNAMES) that enable
    # type-converter lookup for result columns.
    detect_types: NotRequired[int]
    # Implicit-transaction mode ("DEFERRED", "IMMEDIATE", "EXCLUSIVE"), or None
    # to run in autocommit mode.
    isolation_level: "NotRequired[str | None]"
    # When True, sqlite3 refuses use of the connection from any thread other
    # than the one that created it.
    check_same_thread: NotRequired[bool]
    # Custom Connection subclass to instantiate instead of the default class.
    factory: "NotRequired[type[SqliteConnection] | None]"
    # Size of the per-connection prepared-statement cache (sqlite3 default: 128).
    cached_statements: NotRequired[int]
    # Interpret ``database`` as a "file:" URI (needed for query options such as
    # "mode=memory").
    uri: NotRequired[bool]


# Only the pool class is listed in this module's public export list.
__all__ = ("SqliteConnectionPool",)

# Module logger for pool lifecycle events, obtained under the shared pool logger name.
logger = get_logger(POOL_LOGGER_NAME)
# Value passed as ``adapter=`` in this module's structured pool log records.
_ADAPTER_NAME = "sqlite"


class SqliteConnectionPool:
    """Thread-local connection manager for SQLite.

    SQLite connections aren't thread-safe, so we use thread-local storage to
    ensure each thread has its own connection. This is simpler and more
    efficient than a traditional pool for SQLite's constraints.

    Each thread lazily opens its own connection on first use. The connection
    is replaced when it is older than ``recycle_seconds`` (if positive), or
    when it has been idle longer than ``health_check_interval`` and fails a
    ``SELECT 1`` probe. ``close()`` only affects the calling thread's
    connection.
    """

    __slots__ = (
        "_connection_parameters",
        "_enable_optimizations",
        "_health_check_interval",
        "_on_connection_create",
        "_pool_id",
        "_recycle_seconds",
        "_thread_local",
    )

    def __init__(
        self,
        connection_parameters: "dict[str, Any]",
        enable_optimizations: bool = True,
        recycle_seconds: int = 86400,
        health_check_interval: float = 30.0,
        on_connection_create: "Callable[[SqliteConnection], None] | None" = None,
    ) -> None:
        """Initialize the thread-local connection manager.

        Args:
            connection_parameters: Keyword arguments forwarded to ``sqlite3.connect``.
                The dict is copied; the caller's dict is never modified.
            enable_optimizations: Whether to apply performance PRAGMAs.
            recycle_seconds: Connection recycle time in seconds (default 24h);
                values <= 0 disable age-based recycling.
            health_check_interval: Seconds of idle time before running a health check.
            on_connection_create: Callback executed when a connection is created,
                after the internal PRAGMA setup.
        """
        params = dict(connection_parameters)
        # sqlite3 otherwise raises ProgrammingError when a connection is used
        # from a thread other than the one that created it.
        params.setdefault("check_same_thread", False)
        self._connection_parameters = params
        self._thread_local = threading.local()
        self._enable_optimizations = enable_optimizations
        self._recycle_seconds = recycle_seconds
        self._health_check_interval = health_check_interval
        self._on_connection_create = on_connection_create
        # Short random id used to correlate log records from this pool.
        self._pool_id = str(uuid.uuid4())[:8]

    @staticmethod
    def _is_memory_database(database: Any) -> bool:
        """Return True if ``database`` names an in-memory SQLite database."""
        return database == ":memory:" or "mode=memory" in str(database)

    @property
    def _database_name(self) -> str:
        """Get sanitized database name for logging."""
        db = self._connection_parameters.get("database", ":memory:")
        return ":memory:" if self._is_memory_database(db) else str(db)

    def _apply_optimizations(self, connection: "SqliteConnection") -> None:
        """Run the performance/safety PRAGMAs appropriate for the database kind."""
        database = self._connection_parameters.get("database", ":memory:")
        if self._is_memory_database(database):
            pragmas = [
                "PRAGMA journal_mode = MEMORY",
                "PRAGMA synchronous = OFF",
                "PRAGMA temp_store = MEMORY",
            ]
        else:
            pragmas = ["PRAGMA journal_mode = WAL", "PRAGMA synchronous = NORMAL"]
            # sqlite3.connect() already installs a busy handler from ``timeout``
            # (default 5.0 s). Only fall back to 5000 ms when the caller did not
            # choose a timeout; otherwise this PRAGMA would silently override it.
            if "timeout" not in self._connection_parameters:
                pragmas.append("PRAGMA busy_timeout = 5000")
        pragmas.append("PRAGMA foreign_keys = ON")
        for pragma in pragmas:
            connection.execute(pragma)

    def _create_connection(self) -> "SqliteConnection":
        """Create a new SQLite connection with optimizations.

        Raises:
            Any exception raised by ``sqlite3.connect``, a PRAGMA, or the
            ``on_connection_create`` callback. A connection whose setup fails
            is closed before the exception propagates.
        """
        connection = sqlite3.connect(**self._connection_parameters)
        try:
            if self._enable_optimizations:
                self._apply_optimizations(connection)
            # Call user-provided callback after internal setup
            if self._on_connection_create is not None:
                self._on_connection_create(connection)
        except BaseException:
            # Don't leak a half-initialized connection.
            with contextlib.suppress(Exception):
                connection.close()
            raise
        return connection  # type: ignore[no-any-return]

    def _is_connection_alive(self, connection: "SqliteConnection") -> bool:
        """Check if a connection is still alive and usable.

        Args:
            connection: Connection to check.

        Returns:
            True if ``SELECT 1`` succeeds on the connection, False otherwise.
        """
        try:
            connection.execute("SELECT 1")
        except Exception:
            return False
        return True

    def _open_thread_connection(self) -> "SqliteConnection":
        """Create a connection and store it, with fresh timestamps, for this thread."""
        connection = self._create_connection()
        now = time.time()
        local = self._thread_local
        local.connection = connection
        local.created_at = now
        local.last_used = now
        return connection

    def _replace_thread_connection(self, reason: str, **details: Any) -> "SqliteConnection":
        """Log a recycle event, then close and reopen this thread's connection."""
        log_with_context(
            logger,
            logging.DEBUG,
            "pool.connection.recycle",
            adapter=_ADAPTER_NAME,
            pool_id=self._pool_id,
            database=self._database_name,
            **details,
            reason=reason,
        )
        # Clear the old state first, so a failed re-create does not leave a
        # closed connection cached for this thread.
        self._close_thread_connection()
        return self._open_thread_connection()

    def _get_thread_connection(self) -> "SqliteConnection":
        """Get or create a connection for the current thread."""
        local = self._thread_local
        if "connection" not in local.__dict__:
            return self._open_thread_connection()

        now = time.time()
        if self._recycle_seconds > 0 and now - local.created_at > self._recycle_seconds:
            return self._replace_thread_connection(
                "exceeded_recycle_time", recycle_seconds=self._recycle_seconds
            )

        # Only probe connections that have sat idle; hot connections skip the check.
        idle_time = now - local.__dict__.get("last_used", 0)
        if idle_time > self._health_check_interval and not self._is_connection_alive(local.connection):
            return self._replace_thread_connection(
                "failed_health_check", idle_seconds=round(idle_time, 1)
            )

        local.last_used = now
        return cast("SqliteConnection", local.connection)

    def _close_thread_connection(self) -> None:
        """Close the connection for the current thread and forget its state."""
        state = self._thread_local.__dict__
        if "connection" in state:
            with contextlib.suppress(Exception):
                state["connection"].close()
        for attr in ("connection", "created_at", "last_used"):
            if attr in state:
                delattr(self._thread_local, attr)

    def _commit_pending(self, connection: "SqliteConnection") -> None:
        """Commit an open transaction after a successful ``with`` body (best effort)."""
        try:
            pending = connection.in_transaction
        except sqlite3.ProgrammingError:
            # The connection was closed (or is otherwise unusable) inside the
            # block, so there is nothing left to commit.
            return
        if not pending:
            return
        try:
            connection.commit()
        except Exception as exc:
            # Leaving the context stays non-raising, but a lost commit must
            # not go unnoticed.
            log_with_context(
                logger,
                logging.WARNING,
                "pool.connection.commit_failed",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                database=self._database_name,
                error=str(exc),
            )

    @contextmanager
    def get_connection(self) -> "Generator[SqliteConnection, None, None]":
        """Get a thread-local connection.

        A transaction still open when the block exits normally is committed.
        If the block raises, the open transaction is rolled back and the
        exception propagates.

        Yields:
            SqliteConnection: A thread-local connection.
        """
        connection = self._get_thread_connection()
        try:
            yield connection
        except BaseException:
            # Never persist the partial work of a failed block.
            with contextlib.suppress(Exception):
                if connection.in_transaction:
                    connection.rollback()
            raise
        self._commit_pending(connection)

    def close(self) -> None:
        """Close the calling thread's connection if it exists.

        Connections owned by other threads are not touched.
        """
        self._close_thread_connection()

    def acquire(self) -> "SqliteConnection":
        """Acquire a thread-local connection.

        Returns:
            SqliteConnection: A thread-local connection.
        """
        return self._get_thread_connection()

    def release(self, connection: "SqliteConnection") -> None:
        """Release a connection (no-op for thread-local connections).

        Args:
            connection: The connection to release (ignored).
        """

    def size(self) -> int:
        """Return 1 if the calling thread currently holds a connection, else 0."""
        return 1 if "connection" in self._thread_local.__dict__ else 0

    def checked_out(self) -> int:
        """Get number of checked out connections (always 0)."""
        return 0