Source code for sqlspec.adapters.pymssql.pool

"""pymssql database configuration with thread-local connections."""

import contextlib
import logging
import threading
import time
import uuid
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, cast

from sqlspec.adapters.pymssql._typing import PYMSSQL_MODULE, PymssqlConnection
from sqlspec.utils.logging import POOL_LOGGER_NAME, get_logger, log_with_context

if TYPE_CHECKING:
    from collections.abc import Callable, Generator

__all__ = ("PymssqlConnectionPool",)


# Adapter identifier attached to the structured log events emitted by the pool.
_ADAPTER_NAME = "pymssql"
# Module-local alias for the driver module imported from the adapter's `_typing` module.
pymssql = PYMSSQL_MODULE
# Shared logger for pool lifecycle events such as connection recycling.
logger = get_logger(POOL_LOGGER_NAME)


class PymssqlConnectionPool:
    """Thread-local connection manager for pymssql.

    Each thread lazily opens, then reuses, its own connection stored on a
    ``threading.local``. A cached connection is replaced when it is older than
    ``recycle_seconds`` (when that is > 0), or when it has been idle longer than
    ``health_check_interval`` and fails a ``SELECT 1`` probe.

    Note:
        ``close()`` only closes the calling thread's connection; connections held
        by other threads are not tracked by this object.
    """

    __slots__ = (
        "_connection_parameters",
        "_health_check_interval",
        "_on_connection_create",
        "_pool_id",
        "_recycle_seconds",
        "_thread_local",
    )

    def __init__(
        self,
        connection_parameters: "dict[str, Any]",
        recycle_seconds: int = 86400,
        health_check_interval: float = 30.0,
        on_connection_create: "Callable[[PymssqlConnection], None] | None" = None,
    ) -> None:
        """Initialize the thread-local connection manager.

        Args:
            connection_parameters: pymssql connection parameters (passed to ``pymssql.connect``).
            recycle_seconds: Connection recycle time in seconds (default 24h); ``<= 0`` disables recycling.
            health_check_interval: Seconds of idle time before running a health check.
            on_connection_create: Callback executed when a connection is created.
        """
        self._connection_parameters = connection_parameters
        self._thread_local = threading.local()
        self._recycle_seconds = recycle_seconds
        self._health_check_interval = health_check_interval
        self._on_connection_create = on_connection_create
        # Short random id used only to correlate log events from this pool.
        self._pool_id = str(uuid.uuid4())[:8]

    @property
    def _database_name(self) -> str:
        """Database name for logging, or ``"unknown"`` when not configured."""
        return str(self._connection_parameters.get("database", "unknown"))

    def _create_connection(self) -> "PymssqlConnection":
        """Open a new pymssql connection and run the user callback on it.

        Raises:
            Whatever ``pymssql.connect`` or the ``on_connection_create`` callback raises.
            If the callback fails, the new connection is closed before re-raising.
        """
        connection = pymssql.connect(**self._connection_parameters)
        if self._on_connection_create is not None:
            try:
                self._on_connection_create(connection)
            except BaseException:
                # The caller never receives this connection, so close it here
                # instead of leaking an open server session.
                with contextlib.suppress(Exception):
                    connection.close()
                raise
        return cast("PymssqlConnection", connection)

    def _is_connection_alive(self, connection: "PymssqlConnection") -> bool:
        """Return True if ``SELECT 1`` succeeds on ``connection``, False on any error."""
        try:
            cursor = connection.cursor()
            try:
                cursor.execute("SELECT 1")
                cursor.fetchone()
            finally:
                cursor.close()
        except Exception:
            return False
        return True

    def _open_thread_connection(self) -> "PymssqlConnection":
        """Create a connection and record it, with fresh timestamps, for this thread."""
        connection = self._create_connection()
        now = time.monotonic()
        self._thread_local.connection = connection
        self._thread_local.created_at = now
        self._thread_local.last_used = now
        return connection

    def _replace_thread_connection(self, **log_fields: Any) -> "PymssqlConnection":
        """Log a recycle event, discard this thread's connection and open a new one.

        Args:
            **log_fields: Extra structured fields (e.g. ``reason``) for the log event.
        """
        log_with_context(
            logger,
            logging.DEBUG,
            "pool.connection.recycle",
            adapter=_ADAPTER_NAME,
            pool_id=self._pool_id,
            database=self._database_name,
            **log_fields,
        )
        # Clear the cached state before reconnecting so that a failed reconnect
        # does not leave an already-closed connection cached for this thread.
        self._close_thread_connection()
        return self._open_thread_connection()

    def _get_thread_connection(self) -> "PymssqlConnection":
        """Return this thread's connection, opening or replacing it as needed."""
        state = self._thread_local
        if "connection" not in state.__dict__:
            return self._open_thread_connection()

        # Monotonic clock: elapsed-time checks must not be affected by wall-clock jumps.
        if self._recycle_seconds > 0 and time.monotonic() - state.created_at > self._recycle_seconds:
            return self._replace_thread_connection(
                recycle_seconds=self._recycle_seconds,
                reason="exceeded_recycle_time",
            )

        # A missing timestamp is treated as "idle forever", forcing a health check.
        idle_time = time.monotonic() - getattr(state, "last_used", float("-inf"))
        if idle_time > self._health_check_interval and not self._is_connection_alive(state.connection):
            return self._replace_thread_connection(
                idle_seconds=round(idle_time, 1),
                reason="failed_health_check",
            )

        # Refresh the idle timer so the health check runs only after real inactivity,
        # not on every call once the interval has elapsed since creation.
        state.last_used = time.monotonic()
        return cast("PymssqlConnection", state.connection)

    def _close_thread_connection(self) -> None:
        """Close (best-effort) and forget this thread's connection and its timestamps."""
        state = self._thread_local.__dict__
        connection = state.pop("connection", None)
        state.pop("created_at", None)
        state.pop("last_used", None)
        if connection is not None:
            with contextlib.suppress(Exception):
                connection.close()

    @contextmanager
    def get_connection(self) -> "Generator[PymssqlConnection, None, None]":
        """Get a thread-local connection.

        If the managed block raises, this thread's connection is closed and
        discarded (so the next call opens a fresh one) and the exception is re-raised.
        """
        connection = self._get_thread_connection()
        try:
            yield connection
        except Exception:
            with contextlib.suppress(Exception):
                self._close_thread_connection()
            raise

    def close(self) -> None:
        """Close the calling thread's connection, if any."""
        self._close_thread_connection()

    def acquire(self) -> "PymssqlConnection":
        """Return the calling thread's connection (opening/replacing it if needed)."""
        return self._get_thread_connection()

    def release(self, connection: "PymssqlConnection") -> None:
        """No-op: thread-local connections stay cached for reuse by the same thread."""
        _ = connection

    def size(self) -> int:
        """Return 1 if the calling thread currently holds a connection, else 0."""
        return 1 if "connection" in self._thread_local.__dict__ else 0

    def checked_out(self) -> int:
        """Always 0: connections are never tracked as checked out."""
        return 0