"""PyMySQL database configuration."""
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
from typing_extensions import NotRequired
from sqlspec.adapters.pymysql._typing import PyMysqlConnection, PyMysqlSessionContext
from sqlspec.adapters.pymysql.core import apply_driver_features, default_statement_config
from sqlspec.adapters.pymysql.driver import PyMysqlCursor, PyMysqlDriver, PyMysqlExceptionHandler
from sqlspec.adapters.pymysql.pool import PyMysqlConnectionPool
from sqlspec.config import ExtensionConfigs, SyncDatabaseConfig
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.utils.config_tools import normalize_connection_config
if TYPE_CHECKING:
from collections.abc import Callable
from sqlspec.core import StatementConfig
from sqlspec.observability import ObservabilityConfig
# Names exported by a star-import of this module.
__all__ = (
    "PyMysqlConfig",
    "PyMysqlConnectionParams",
    "PyMysqlDriverFeatures",
    "PyMysqlPoolParams",
)
class PyMysqlConnectionParams(TypedDict):
"""PyMySQL connection parameters."""
host: NotRequired[str]
user: NotRequired[str]
password: NotRequired[str]
database: NotRequired[str]
port: NotRequired[int]
unix_socket: NotRequired[str]
charset: NotRequired[str]
connect_timeout: NotRequired[int]
read_timeout: NotRequired[int]
write_timeout: NotRequired[int]
autocommit: NotRequired[bool]
ssl: NotRequired["dict[str, Any]"]
client_flag: NotRequired[int]
cursorclass: NotRequired[type]
init_command: NotRequired[str]
sql_mode: NotRequired[str]
extra: NotRequired["dict[str, Any]"]
class PyMysqlPoolParams(PyMysqlConnectionParams):
    """Connection parameters extended with pool-only tuning keys.

    Both keys are optional. ``PyMysqlConfig._create_pool`` removes them from
    the connection settings and hands them to the pool separately.
    """

    # Passed to the pool as ``recycle_seconds``; 86400 is used when absent.
    pool_recycle_seconds: NotRequired[int]
    # Passed to the pool as ``health_check_interval``; 30.0 is used when absent.
    health_check_interval: NotRequired[float]
class PyMysqlDriverFeatures(TypedDict):
"""PyMySQL driver feature flags.
json_serializer: Custom JSON serializer function.
Defaults to sqlspec.utils.serializers.to_json.
json_deserializer: Custom JSON deserializer function.
Defaults to sqlspec.utils.serializers.from_json.
on_connection_create: Callback executed when a connection is created.
Receives the raw pymysql connection for low-level driver configuration.
Runs after connection creation.
enable_events: Enable database event channel support.
events_backend: Event channel backend selection.
"""
json_serializer: NotRequired["Callable[[Any], str]"]
json_deserializer: NotRequired["Callable[[str], Any]"]
on_connection_create: "NotRequired[Callable[[PyMysqlConnection], None]]"
enable_events: NotRequired[bool]
events_backend: NotRequired[str]
class PyMysqlConnectionContext:
    """Context manager that borrows a PyMySQL connection from the config's pool.

    Entering obtains the pool from ``config.provide_pool()``, asks it for a
    connection context via ``get_connection()`` and enters that context.
    Exiting forwards the exception details to the pool context and returns
    whatever it returns.
    """

    __slots__ = ("_config", "_ctx")

    def __init__(self, config: "PyMysqlConfig") -> None:
        """Store the configuration whose pool will supply the connection.

        Args:
            config: Configuration object exposing ``provide_pool()``.
        """
        self._config = config
        # Pool-level context manager for the borrowed connection; None while idle.
        self._ctx: Any = None

    def __enter__(self) -> "PyMysqlConnection":
        """Borrow a connection from the pool.

        Returns:
            The value produced by entering the pool's connection context.
        """
        pooled = self._config.provide_pool().get_connection()
        connection = cast("PyMysqlConnection", pooled.__enter__())
        # Record the pool context only after it was entered successfully, so a
        # failed acquisition never leaves a context behind to be exited.
        self._ctx = pooled
        return connection

    def __exit__(
        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any
    ) -> "bool | None":
        """Release the borrowed connection, if any.

        Args:
            exc_type: Exception class raised inside the ``with`` body, or None.
            exc_val: Exception instance raised inside the ``with`` body, or None.
            exc_tb: Traceback associated with the exception, or None.

        Returns:
            The pool context's ``__exit__`` result, or None when nothing is held.
        """
        # Detach first: a repeated __exit__ (or one that raises) must not
        # release the same pool context a second time.
        pooled, self._ctx = self._ctx, None
        if pooled is None:
            return None
        return cast("bool | None", pooled.__exit__(exc_type, exc_val, exc_tb))
class _PyMysqlSessionConnectionHandler:
    """Acquire/release pair that backs a session with a pooled connection.

    ``acquire_connection`` enters a pool connection context and
    ``release_connection`` exits it. The two bound methods are handed to
    ``PyMysqlSessionContext`` by ``PyMysqlConfig.provide_session``.
    """

    __slots__ = ("_config", "_ctx")

    def __init__(self, config: "PyMysqlConfig") -> None:
        """Store the configuration whose pool will supply the connection.

        Args:
            config: Configuration object exposing ``provide_pool()``.
        """
        self._config = config
        # Pool-level context manager currently held; None while idle.
        self._ctx: Any = None

    def acquire_connection(self) -> "PyMysqlConnection":
        """Borrow a connection from the pool.

        Returns:
            The value produced by entering the pool's connection context.
        """
        pooled = self._config.provide_pool().get_connection()
        connection = cast("PyMysqlConnection", pooled.__enter__())
        # Record the pool context only after it was entered successfully, so a
        # release following a failed acquire has nothing to exit.
        self._ctx = pooled
        return connection

    def release_connection(self, _conn: "PyMysqlConnection") -> None:
        """Exit the held pool context; do nothing when none is held.

        Args:
            _conn: The connection being released. Unused; the pool context
                recorded by ``acquire_connection`` is exited instead.
        """
        # Detach before exiting so the handler is left idle even if the pool
        # context's __exit__ raises, and a repeated release is a no-op.
        pooled, self._ctx = self._ctx, None
        if pooled is not None:
            pooled.__exit__(None, None, None)
class PyMysqlConfig(SyncDatabaseConfig[PyMysqlConnection, PyMysqlConnectionPool, PyMysqlDriver]):
    """Configuration for PyMySQL synchronous connections.

    Normalizes the connection settings, applies driver features to the
    statement configuration, and builds a ``PyMysqlConnectionPool`` from which
    connections and sessions are served.
    """

    driver_type: "ClassVar[type[PyMysqlDriver]]" = PyMysqlDriver
    connection_type: "ClassVar[type[PyMysqlConnection]]" = cast("type[PyMysqlConnection]", PyMysqlConnection)
    # Capability flags advertised by this adapter.
    supports_transactional_ddl: "ClassVar[bool]" = False
    supports_native_arrow_export: "ClassVar[bool]" = True
    supports_native_arrow_import: "ClassVar[bool]" = True
    supports_native_parquet_export: "ClassVar[bool]" = True
    supports_native_parquet_import: "ClassVar[bool]" = True

    def __init__(
        self,
        *,
        connection_config: "PyMysqlPoolParams | dict[str, Any] | None" = None,
        connection_instance: "PyMysqlConnectionPool | None" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "PyMysqlDriverFeatures | dict[str, Any] | None" = None,
        bind_key: "str | None" = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the PyMySQL configuration.

        Args:
            connection_config: Connection and pool settings. ``host`` defaults
                to ``"localhost"`` and ``port`` to ``3306`` when not supplied.
            connection_instance: Existing pool to use, forwarded to the base class.
            migration_config: Migration settings, forwarded to the base class.
            statement_config: Statement configuration; ``default_statement_config``
                is used when omitted or falsy.
            driver_features: Driver feature flags. ``on_connection_create`` is
                removed from them and kept as the pool's connection hook.
            bind_key: Forwarded to the base class.
            extension_config: Forwarded to the base class.
            observability_config: Forwarded to the base class.
            **kwargs: Additional keyword arguments forwarded to the base class.
        """
        connection_config = normalize_connection_config(connection_config)
        connection_config.setdefault("host", "localhost")
        connection_config.setdefault("port", 3306)

        statement_config = statement_config or default_statement_config
        statement_config, driver_features = apply_driver_features(statement_config, driver_features)

        # The user hook is handed to the pool directly, so it is taken out of
        # the feature mapping before that mapping is stored by the base class.
        features_dict = dict(driver_features) if driver_features else {}
        self._user_connection_hook: "Callable[[PyMysqlConnection], None] | None" = features_dict.pop(
            "on_connection_create", None
        )

        super().__init__(
            connection_config=connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=statement_config,
            driver_features=features_dict,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )

    def _create_pool(self) -> "PyMysqlConnectionPool":
        """Build a connection pool from the stored connection settings.

        Returns:
            A new ``PyMysqlConnectionPool`` configured with the connection
            settings, pool tuning values and the user connection hook.
        """
        # Work on a copy so the stored settings keep their pool-only keys.
        config = dict(self.connection_config)
        pool_recycle = config.pop("pool_recycle_seconds", 86400)
        health_check = config.pop("health_check_interval", 30.0)
        # ``extra`` may be missing or explicitly None; both mean nothing to merge.
        extra = config.pop("extra", None)
        if extra:
            config.update(extra)
        return PyMysqlConnectionPool(
            config,
            recycle_seconds=pool_recycle,
            health_check_interval=health_check,
            on_connection_create=self._user_connection_hook,
        )

    def _close_pool(self) -> None:
        """Close the pool held in ``connection_instance``, if there is one."""
        # Compare against None explicitly so a pool object that happens to be
        # falsy is still closed.
        if self.connection_instance is not None:
            self.connection_instance.close()

    def create_connection(self) -> PyMysqlConnection:
        """Acquire a connection directly from the pool.

        NOTE(review): nothing here hands the connection back to the pool;
        presumably the caller is responsible for that -- confirm.

        Returns:
            The connection returned by the pool's ``acquire()``.
        """
        pool = self.provide_pool()
        return pool.acquire()

    def provide_connection(self, *args: Any, **kwargs: Any) -> "PyMysqlConnectionContext":
        """Return a context manager that yields a pooled connection.

        Args:
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            A ``PyMysqlConnectionContext`` bound to this configuration.
        """
        return PyMysqlConnectionContext(self)

    def provide_session(
        self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any
    ) -> "PyMysqlSessionContext":
        """Return a session context backed by a pooled connection.

        Args:
            *_args: Accepted and ignored.
            statement_config: Optional override. Falls back to this
                configuration's ``statement_config`` and then to
                ``default_statement_config``.
            **_kwargs: Accepted and ignored.

        Returns:
            A ``PyMysqlSessionContext`` wired to acquire and release a
            connection through this configuration's pool.
        """
        handler = _PyMysqlSessionConnectionHandler(self)
        return PyMysqlSessionContext(
            acquire_connection=handler.acquire_connection,
            release_connection=handler.release_connection,
            statement_config=statement_config or self.statement_config or default_statement_config,
            driver_features=self.driver_features,
            prepare_driver=self._prepare_driver,
        )

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Return the base signature namespace extended with PyMySQL types.

        Returns:
            The mapping from ``super().get_signature_namespace()`` updated with
            this adapter's classes, keyed by their names.
        """
        namespace = super().get_signature_namespace()
        namespace.update({
            "PyMysqlConnectionContext": PyMysqlConnectionContext,
            "PyMysqlConnection": PyMysqlConnection,
            "PyMysqlConnectionParams": PyMysqlConnectionParams,
            "PyMysqlConnectionPool": PyMysqlConnectionPool,
            "PyMysqlCursor": PyMysqlCursor,
            "PyMysqlDriver": PyMysqlDriver,
            "PyMysqlDriverFeatures": PyMysqlDriverFeatures,
            "PyMysqlExceptionHandler": PyMysqlExceptionHandler,
            "PyMysqlPoolParams": PyMysqlPoolParams,
            "PyMysqlSessionContext": PyMysqlSessionContext,
        })
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return the event runtime hints for this adapter.

        Returns:
            ``EventRuntimeHints`` with ``poll_interval=0.25``,
            ``lease_seconds=5``, ``select_for_update=True`` and
            ``skip_locked=True``.
        """
        return EventRuntimeHints(poll_interval=0.25, lease_seconds=5, select_for_update=True, skip_locked=True)