"""pymssql database configuration with thread-local connections."""
import contextlib
import logging
import threading
import time
import uuid
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, cast
from sqlspec.adapters.pymssql._typing import PYMSSQL_MODULE, PymssqlConnection
from sqlspec.utils.logging import POOL_LOGGER_NAME, get_logger, log_with_context
if TYPE_CHECKING:
from collections.abc import Callable, Generator
# Public API of this module: only the pool class is exported.
__all__ = ("PymssqlConnectionPool",)
# Logger obtained from the project's logging helper under the shared pool logger name.
logger = get_logger(POOL_LOGGER_NAME)
# Value passed as the ``adapter`` field of the pool's structured log events.
_ADAPTER_NAME = "pymssql"
# Module-level handle to the driver; ``pymssql.connect(...)`` is called on it when a
# connection is created. NOTE(review): presumably the imported pymssql module as
# resolved by ``_typing`` — confirm how it behaves when the driver is not installed.
pymssql = PYMSSQL_MODULE
class PymssqlConnectionPool:
    """Thread-local connection manager for pymssql.

    Each thread that asks for a connection gets its own, stored in a
    ``threading.local``.  A thread's connection is replaced when it is older
    than ``recycle_seconds`` or when it has been idle for longer than
    ``health_check_interval`` and a ``SELECT 1`` probe fails.
    """

    __slots__ = (
        "_connection_parameters",
        "_health_check_interval",
        "_on_connection_create",
        "_pool_id",
        "_recycle_seconds",
        "_thread_local",
    )

    def __init__(
        self,
        connection_parameters: "dict[str, Any]",
        recycle_seconds: int = 86400,
        health_check_interval: float = 30.0,
        on_connection_create: "Callable[[PymssqlConnection], None] | None" = None,
    ) -> None:
        """Initialize the thread-local connection manager.

        Args:
            connection_parameters: Keyword arguments forwarded to ``pymssql.connect``.
            recycle_seconds: Maximum connection age in seconds (default 24h).
                A value of ``0`` or less disables age-based recycling.
            health_check_interval: Seconds of idle time after which the
                connection is probed before being handed out again.
            on_connection_create: Optional callback invoked with every newly
                created connection.
        """
        self._connection_parameters = connection_parameters
        self._thread_local = threading.local()
        self._recycle_seconds = recycle_seconds
        self._health_check_interval = health_check_interval
        self._on_connection_create = on_connection_create
        # Short random identifier included in this pool's log events.
        self._pool_id = str(uuid.uuid4())[:8]

    @property
    def _database_name(self) -> str:
        """Return the configured ``database`` parameter as a string for logging.

        Falls back to ``"unknown"`` when the parameter is absent.
        """
        return str(self._connection_parameters.get("database", "unknown"))

    def _create_connection(self) -> "PymssqlConnection":
        """Open a new connection and run the creation callback, if any.

        Returns:
            The newly opened connection.

        Raises:
            Exception: Whatever ``pymssql.connect`` or the callback raises.  If
                the callback fails, the new connection is closed before the
                error propagates so it is not leaked.
        """
        connection = pymssql.connect(**self._connection_parameters)
        if self._on_connection_create is not None:
            try:
                self._on_connection_create(connection)
            except Exception:
                # Nobody else holds a reference yet; close it or it leaks.
                with contextlib.suppress(Exception):
                    connection.close()
                raise
        return cast("PymssqlConnection", connection)

    def _is_connection_alive(self, connection: "PymssqlConnection") -> bool:
        """Probe ``connection`` with ``SELECT 1``.

        Returns:
            ``True`` if the probe ran without raising, ``False`` otherwise.
        """
        try:
            cursor = connection.cursor()
            try:
                cursor.execute("SELECT 1")
                cursor.fetchone()
            finally:
                cursor.close()
        except Exception:
            return False
        return True

    def _open_thread_connection(self) -> "PymssqlConnection":
        """Create a connection and register it for the calling thread."""
        connection = self._create_connection()
        # Only touch thread-local state once creation has succeeded.
        now = time.time()
        self._thread_local.connection = connection
        self._thread_local.created_at = now
        self._thread_local.last_used = now
        return connection

    def _recycle_thread_connection(self, reason: str, **details: Any) -> "PymssqlConnection":
        """Log a recycle event, drop the current connection and open a new one.

        Args:
            reason: Value of the ``reason`` field in the log event.
            **details: Extra fields added to the log event.
        """
        log_with_context(
            logger,
            logging.DEBUG,
            "pool.connection.recycle",
            adapter=_ADAPTER_NAME,
            pool_id=self._pool_id,
            database=self._database_name,
            **details,
            reason=reason,
        )
        # Forget the old connection *before* creating the replacement so a
        # failed reconnect does not leave a closed connection registered.
        self._close_thread_connection()
        return self._open_thread_connection()

    def _get_thread_connection(self) -> "PymssqlConnection":
        """Return the calling thread's connection, creating or replacing it as needed."""
        thread_state = self._thread_local.__dict__
        if "connection" not in thread_state:
            return self._open_thread_connection()

        now = time.time()
        if self._recycle_seconds > 0 and now - self._thread_local.created_at > self._recycle_seconds:
            return self._recycle_thread_connection(
                "exceeded_recycle_time", recycle_seconds=self._recycle_seconds
            )

        idle_time = now - thread_state.get("last_used", 0)
        # The probe costs a round trip, so only run it after an idle period.
        if idle_time > self._health_check_interval and not self._is_connection_alive(
            self._thread_local.connection
        ):
            return self._recycle_thread_connection("failed_health_check", idle_seconds=round(idle_time, 1))

        self._thread_local.last_used = time.time()
        return cast("PymssqlConnection", self._thread_local.connection)

    def _close_thread_connection(self) -> None:
        """Close the calling thread's connection and clear its bookkeeping.

        Errors raised by ``close()`` are suppressed.  Does nothing if the
        thread has no connection.
        """
        thread_state = self._thread_local.__dict__
        if "connection" in thread_state:
            with contextlib.suppress(Exception):
                self._thread_local.connection.close()
            del self._thread_local.connection
        for stamp in ("created_at", "last_used"):
            if stamp in thread_state:
                delattr(self._thread_local, stamp)

    @contextmanager
    def get_connection(self) -> "Generator[PymssqlConnection, None, None]":
        """Get a thread-local connection.

        Yields:
            The calling thread's connection.  It stays open after a normal
            exit from the ``with`` block.

        Raises:
            Exception: Any exception raised inside the ``with`` block is
                re-raised after the thread's connection has been closed and
                discarded.
        """
        connection = self._get_thread_connection()
        try:
            yield connection
        except Exception:
            with contextlib.suppress(Exception):
                self._close_thread_connection()
            raise

    def close(self) -> None:
        """Close the calling thread's connection, if it has one."""
        self._close_thread_connection()

    def acquire(self) -> "PymssqlConnection":
        """Return the calling thread's connection, creating it if necessary."""
        return self._get_thread_connection()

    def release(self, connection: "PymssqlConnection") -> None:
        """Accept a connection back; a no-op, the connection stays with its thread."""
        _ = connection

    def size(self) -> int:
        """Return 1 if the calling thread currently has a connection, else 0."""
        try:
            _ = self._thread_local.connection
        except AttributeError:
            return 0
        else:
            return 1

    def checked_out(self) -> int:
        """Return the number of checked-out connections; always 0 for this manager."""
        return 0