"""MysqlConnector database configuration."""
import contextlib
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
from weakref import WeakSet

import mysql.connector
from mysql.connector import pooling
from typing_extensions import NotRequired

from sqlspec.adapters.mysqlconnector._typing import (
    MysqlConnectorAsyncConnection,
    MysqlConnectorAsyncSessionContext,
    MysqlConnectorSyncConnection,
    MysqlConnectorSyncSessionContext,
)
from sqlspec.adapters.mysqlconnector.core import apply_driver_features, default_statement_config
from sqlspec.adapters.mysqlconnector.driver import (
    MysqlConnectorAsyncCursor,
    MysqlConnectorAsyncDriver,
    MysqlConnectorAsyncExceptionHandler,
    MysqlConnectorSyncCursor,
    MysqlConnectorSyncDriver,
    MysqlConnectorSyncExceptionHandler,
)
from sqlspec.config import ExtensionConfigs, NoPoolAsyncConfig, SyncDatabaseConfig
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.utils.config_tools import normalize_connection_config

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

    from mysql.connector.pooling import MySQLConnectionPool

    from sqlspec.core import StatementConfig
    from sqlspec.observability import ObservabilityConfig
# Public API of this module: the sync/async configuration classes and the
# TypedDicts describing their constructor arguments (listed alphabetically).
__all__ = (
    "MysqlConnectorAsyncConfig",
    "MysqlConnectorAsyncConnectionParams",
    "MysqlConnectorDriverFeatures",
    "MysqlConnectorPoolParams",
    "MysqlConnectorSyncConfig",
    "MysqlConnectorSyncConnectionParams",
)
class MysqlConnectorSyncConnectionParams(TypedDict):
"""MysqlConnector sync connection parameters."""
host: NotRequired[str]
user: NotRequired[str]
password: NotRequired[str]
database: NotRequired[str]
port: NotRequired[int]
unix_socket: NotRequired[str]
charset: NotRequired[str]
connection_timeout: NotRequired[int]
autocommit: NotRequired[bool]
use_pure: NotRequired[bool]
ssl_ca: NotRequired[str]
ssl_cert: NotRequired[str]
ssl_key: NotRequired[str]
ssl_verify_cert: NotRequired[bool]
ssl_verify_identity: NotRequired[bool]
client_flags: NotRequired[int]
pool_name: NotRequired[str]
pool_size: NotRequired[int]
pool_reset_session: NotRequired[bool]
extra: NotRequired["dict[str, Any]"]
class MysqlConnectorPoolParams(MysqlConnectorSyncConnectionParams):
    """Connection-pool options for mysql-connector.

    This TypedDict adds no keys of its own. ``pool_name``, ``pool_size`` and
    ``pool_reset_session`` are declared on ``MysqlConnectorSyncConnectionParams``
    and simply inherited here; the subclass gives pool-oriented configuration a
    dedicated name for annotations.
    """
class MysqlConnectorAsyncConnectionParams(TypedDict):
"""MysqlConnector async connection parameters."""
host: NotRequired[str]
user: NotRequired[str]
password: NotRequired[str]
database: NotRequired[str]
port: NotRequired[int]
unix_socket: NotRequired[str]
charset: NotRequired[str]
connection_timeout: NotRequired[int]
autocommit: NotRequired[bool]
use_pure: NotRequired[bool]
ssl_ca: NotRequired[str]
ssl_cert: NotRequired[str]
ssl_key: NotRequired[str]
ssl_verify_cert: NotRequired[bool]
ssl_verify_identity: NotRequired[bool]
client_flags: NotRequired[int]
extra: NotRequired["dict[str, Any]"]
class MysqlConnectorDriverFeatures(TypedDict):
"""MysqlConnector driver feature flags.
json_serializer: Custom JSON serializer function.
Defaults to sqlspec.utils.serializers.to_json.
json_deserializer: Custom JSON deserializer function.
Defaults to sqlspec.utils.serializers.from_json.
on_connection_create: Callback executed when a connection is acquired.
For sync: Callable[[MysqlConnectorSyncConnection], None]
For async: Callable[[MysqlConnectorAsyncConnection], Awaitable[None]]
Called exactly once per physical connection using WeakSet tracking.
enable_events: Enable database event channel support.
Defaults to True when extension_config["events"] is configured.
events_backend: Event channel backend selection.
Only option: "table_queue".
"""
json_serializer: NotRequired["Callable[[Any], str]"]
json_deserializer: NotRequired["Callable[[str], Any]"]
on_connection_create: "NotRequired[Callable[..., Any]]"
enable_events: NotRequired[bool]
events_backend: NotRequired[str]
class MysqlConnectorSyncConnectionContext:
"""Context manager for mysql-connector sync connections."""
__slots__ = ("_config", "_connection")
def __init__(self, config: "MysqlConnectorSyncConfig") -> None:
self._config = config
self._connection: MysqlConnectorSyncConnection | None = None
def __enter__(self) -> MysqlConnectorSyncConnection:
if self._config.connection_instance is not None:
self._connection = cast("MysqlConnectorSyncConnection", self._config.connection_instance.get_connection())
else:
self._connection = self._config.create_connection()
self._config._ensure_connection_initialized(self._connection) # pyright: ignore[reportPrivateUsage]
return self._connection
def __exit__(
self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any
) -> bool | None:
if self._connection is not None:
self._connection.close()
self._connection = None
return None
class _MysqlConnectorSyncSessionConnectionHandler:
    """Acquire/release callbacks used by a sync session context.

    A connection is checked out of ``config.connection_instance`` when that is
    set, otherwise opened with ``config.create_connection()``. Release closes the
    connection this handler stored.
    """

    __slots__ = ("_config", "_connection")

    def __init__(self, config: "MysqlConnectorSyncConfig") -> None:
        self._config = config
        self._connection: MysqlConnectorSyncConnection | None = None

    def acquire_connection(self) -> "MysqlConnectorSyncConnection":
        """Acquire a connection and run the config's connection-initialization hook.

        Returns:
            The acquired connection, also remembered for ``release_connection``.

        Raises:
            Anything raised while acquiring the connection or by the hook. When
            the hook fails, the connection is closed before the error propagates.
        """
        pool = self._config.connection_instance
        if pool is not None:
            connection = cast("MysqlConnectorSyncConnection", pool.get_connection())
        else:
            connection = self._config.create_connection()
        try:
            self._config._ensure_connection_initialized(connection)  # pyright: ignore[reportPrivateUsage]
        except BaseException:
            # The caller presumably never receives this connection to release it,
            # so close it here instead of leaking it.
            with contextlib.suppress(Exception):
                connection.close()
            raise
        self._connection = connection
        return connection

    def release_connection(self, _conn: "MysqlConnectorSyncConnection") -> None:
        """Close the connection held by this handler; a no-op if none is held.

        The ``_conn`` argument is ignored: the handler closes the connection it
        stored in ``acquire_connection``.
        """
        if self._connection is None:
            return
        self._connection.close()
        self._connection = None
class MysqlConnectorAsyncConnectionContext:
    """Async context manager yielding a freshly created mysql-connector connection.

    Entry awaits ``config.create_connection()``; exit awaits ``close()`` on that
    connection. Exceptions raised inside the block are never suppressed.
    """

    __slots__ = ("_config", "_connection")

    def __init__(self, config: "MysqlConnectorAsyncConfig") -> None:
        self._connection: MysqlConnectorAsyncConnection | None = None
        self._config = config

    async def __aenter__(self) -> MysqlConnectorAsyncConnection:
        fresh = await self._config.create_connection()
        self._connection = fresh
        return fresh

    async def __aexit__(
        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any
    ) -> bool | None:
        held = self._connection
        if held is None:
            return None
        await held.close()
        self._connection = None
        return None
class _MysqlConnectorAsyncSessionConnectionHandler:
    """Acquire/release callbacks used by an async session context.

    Every acquisition awaits ``config.create_connection()``; release awaits
    ``close()`` on the connection this handler last acquired.
    """

    __slots__ = ("_config", "_connection")

    def __init__(self, config: "MysqlConnectorAsyncConfig") -> None:
        self._connection: MysqlConnectorAsyncConnection | None = None
        self._config = config

    async def acquire_connection(self) -> MysqlConnectorAsyncConnection:
        fresh = await self._config.create_connection()
        self._connection = fresh
        return fresh

    async def release_connection(self, _conn: MysqlConnectorAsyncConnection) -> None:
        # The argument is ignored; the handler closes the connection it stored.
        held = self._connection
        if held is not None:
            await held.close()
            self._connection = None
class MysqlConnectorSyncConfig(
    SyncDatabaseConfig[MysqlConnectorSyncConnection, "MySQLConnectionPool", MysqlConnectorSyncDriver]
):
    """Configuration for mysql-connector synchronous MySQL connections."""

    driver_type: ClassVar[type[MysqlConnectorSyncDriver]] = MysqlConnectorSyncDriver
    connection_type: ClassVar[type[MysqlConnectorSyncConnection]] = MysqlConnectorSyncConnection
    supports_transactional_ddl: ClassVar[bool] = False
    supports_native_arrow_export: ClassVar[bool] = True
    supports_native_parquet_export: ClassVar[bool] = True
    supports_native_arrow_import: ClassVar[bool] = True
    supports_native_parquet_import: ClassVar[bool] = True

    # Config keys that belong to MySQLConnectionPool rather than to a single
    # connection. Per the Connector/Python pooling docs, passing any of them to
    # mysql.connector.connect() makes it hand out a connection from an implicit
    # pool, so they are stripped before opening a direct connection.
    _POOL_ONLY_KEYS: ClassVar[tuple[str, ...]] = ("pool_name", "pool_size", "pool_reset_session")

    def __init__(
        self,
        *,
        connection_config: "MysqlConnectorPoolParams | dict[str, Any] | None" = None,
        connection_instance: "MySQLConnectionPool | None" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "MysqlConnectorDriverFeatures | dict[str, Any] | None" = None,
        bind_key: "str | None" = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Build the sync configuration.

        Args:
            connection_config: Connection/pool options; normalized, then ``host``
                defaults to ``"localhost"`` and ``port`` to ``3306``.
            connection_instance: Optional pre-built ``MySQLConnectionPool``.
            migration_config: Migration settings forwarded to the base config.
            statement_config: Statement configuration; defaults to
                ``default_statement_config`` and is adjusted by
                ``apply_driver_features``.
            driver_features: Feature flags. ``on_connection_create`` is removed
                from them and kept as the connection hook.
            bind_key: Optional binding key forwarded to the base config.
            extension_config: Extension settings forwarded to the base config.
            observability_config: Observability settings forwarded to the base config.
            **kwargs: Forwarded unchanged to the base config.
        """
        connection_config = normalize_connection_config(connection_config)
        connection_config.setdefault("host", "localhost")
        connection_config.setdefault("port", 3306)
        statement_config = statement_config or default_statement_config
        statement_config, driver_features = apply_driver_features(statement_config, driver_features)
        # Extract the user connection hook so it is not stored among driver_features.
        features_dict = dict(driver_features) if driver_features else {}
        self._user_connection_hook: Callable[[MysqlConnectorSyncConnection], None] | None = features_dict.pop(
            "on_connection_create", None
        )
        # Connection objects the hook has already run on; weak references let
        # discarded connections be garbage-collected.
        self._initialized_connections: WeakSet[Any] = WeakSet()
        super().__init__(
            connection_config=connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=statement_config,
            driver_features=features_dict,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )

    def _ensure_connection_initialized(self, connection: "MysqlConnectorSyncConnection") -> None:
        """Run the user hook on ``connection`` unless it already ran for this object.

        The connection is recorded only after the hook succeeds, so a failing
        hook is retried on the next acquisition.
        NOTE(review): tracking is by object identity; if the pool returns a new
        wrapper object per checkout, the hook runs on every checkout — confirm.
        """
        if self._user_connection_hook is None:
            return
        if connection not in self._initialized_connections:
            self._user_connection_hook(connection)
            self._initialized_connections.add(connection)

    def _create_pool(self) -> "MySQLConnectionPool":
        """Build a ``MySQLConnectionPool`` from ``connection_config``.

        ``pool_name``, ``pool_size`` (falls back to 5 when missing or falsy) and
        ``pool_reset_session`` (defaults to True) become pool arguments; all other
        keys are forwarded as connection arguments.
        """
        config = dict(self.connection_config)
        pool_name = config.pop("pool_name", None)
        pool_size = config.pop("pool_size", None)
        pool_reset = config.pop("pool_reset_session", True)
        return pooling.MySQLConnectionPool(
            pool_name=pool_name, pool_size=pool_size or 5, pool_reset_session=pool_reset, **config
        )

    def _close_pool(self) -> None:
        """Close the idle connections held by the pool and forget the pool.

        ``MySQLConnectionPool`` has no public close method; its private
        ``_remove_connections()`` closes the connections currently queued in the
        pool. It is looked up defensively and any failure is suppressed so that
        closing stays best-effort.
        """
        pool = self.connection_instance
        if pool is None:
            return
        remove_connections = getattr(pool, "_remove_connections", None)
        if callable(remove_connections):
            with contextlib.suppress(Exception):
                remove_connections()
        self.connection_instance = None

    def create_connection(self) -> MysqlConnectorSyncConnection:
        """Open a new direct (unpooled) mysql-connector connection.

        Pool-only keys are dropped so ``connect()`` does not route the request
        through an implicit pool. A configured ``autocommit`` value is also
        applied on the connection object, best-effort (errors are suppressed).
        """
        connect_kwargs = {
            key: value for key, value in self.connection_config.items() if key not in self._POOL_ONLY_KEYS
        }
        connection = cast("MysqlConnectorSyncConnection", mysql.connector.connect(**connect_kwargs))
        autocommit = connect_kwargs.get("autocommit")
        if autocommit is not None and hasattr(connection, "autocommit"):
            with contextlib.suppress(Exception):
                setattr(connection, "autocommit", bool(autocommit))
        return connection

    def provide_connection(self, *args: Any, **kwargs: Any) -> "MysqlConnectorSyncConnectionContext":
        """Return a context manager that yields a connection; arguments are ignored."""
        return MysqlConnectorSyncConnectionContext(self)

    def provide_session(
        self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any
    ) -> "MysqlConnectorSyncSessionContext":
        """Return a session context backed by a fresh connection handler.

        ``statement_config`` falls back to ``self.statement_config`` and then to
        ``default_statement_config``.
        """
        statement_config = statement_config or self.statement_config or default_statement_config
        handler = _MysqlConnectorSyncSessionConnectionHandler(self)
        return MysqlConnectorSyncSessionContext(
            acquire_connection=handler.acquire_connection,
            release_connection=handler.release_connection,
            statement_config=statement_config,
            driver_features=self.driver_features,
            prepare_driver=self._prepare_driver,
        )

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Extend the base signature namespace with this adapter's sync types."""
        namespace = super().get_signature_namespace()
        namespace.update({
            "MysqlConnectorSyncConfig": MysqlConnectorSyncConfig,
            "MysqlConnectorSyncConnection": MysqlConnectorSyncConnection,
            "MysqlConnectorSyncConnectionParams": MysqlConnectorSyncConnectionParams,
            "MysqlConnectorSyncCursor": MysqlConnectorSyncCursor,
            "MysqlConnectorSyncDriver": MysqlConnectorSyncDriver,
            "MysqlConnectorSyncExceptionHandler": MysqlConnectorSyncExceptionHandler,
            "MysqlConnectorSyncSessionContext": MysqlConnectorSyncSessionContext,
        })
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return event-channel hints: poll_interval=0.25, lease_seconds=5, row locking with skip-locked."""
        return EventRuntimeHints(poll_interval=0.25, lease_seconds=5, select_for_update=True, skip_locked=True)
class MysqlConnectorAsyncConfig(NoPoolAsyncConfig[MysqlConnectorAsyncConnection, MysqlConnectorAsyncDriver]):
    """Configuration for mysql-connector async MySQL connections."""

    driver_type: ClassVar[type[MysqlConnectorAsyncDriver]] = MysqlConnectorAsyncDriver
    connection_type: "ClassVar[type[Any]]" = cast("type[Any]", MysqlConnectorAsyncConnection)
    supports_transactional_ddl: ClassVar[bool] = False
    supports_native_arrow_export: ClassVar[bool] = True
    supports_native_parquet_export: ClassVar[bool] = True
    supports_native_arrow_import: ClassVar[bool] = True
    supports_native_parquet_import: ClassVar[bool] = True

    def __init__(
        self,
        *,
        connection_config: "MysqlConnectorAsyncConnectionParams | dict[str, Any] | None" = None,
        connection_instance: Any = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "MysqlConnectorDriverFeatures | dict[str, Any] | None" = None,
        bind_key: "str | None" = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Build the async configuration.

        Args:
            connection_config: Connection options; normalized, then ``host``
                defaults to ``"localhost"`` and ``port`` to ``3306``.
            connection_instance: Forwarded to the base config.
            migration_config: Migration settings forwarded to the base config.
            statement_config: Statement configuration; defaults to
                ``default_statement_config`` and is adjusted by
                ``apply_driver_features``.
            driver_features: Feature flags. ``on_connection_create`` is removed
                from them and kept as the async connection hook.
            bind_key: Optional binding key forwarded to the base config.
            extension_config: Extension settings forwarded to the base config.
            observability_config: Observability settings forwarded to the base config.
            **kwargs: Forwarded unchanged to the base config.
        """
        self.connection_config = normalize_connection_config(connection_config)
        self.connection_config.setdefault("host", "localhost")
        self.connection_config.setdefault("port", 3306)
        statement_config = statement_config or default_statement_config
        statement_config, driver_features = apply_driver_features(statement_config, driver_features)
        # Extract the user connection hook so it is not stored among driver_features.
        features_dict = dict(driver_features) if driver_features else {}
        self._user_connection_hook: Callable[[MysqlConnectorAsyncConnection], Awaitable[None]] | None = (
            features_dict.pop("on_connection_create", None)
        )
        super().__init__(
            connection_config=self.connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=statement_config,
            driver_features=features_dict,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )

    async def create_connection(self) -> MysqlConnectorAsyncConnection:
        """Open a new async connection and run the user connection hook on it.

        A configured ``autocommit`` value is applied via ``set_autocommit`` when
        the connection offers it (best-effort; errors are suppressed).

        Raises:
            Anything raised by ``aio.connect`` or by the ``on_connection_create``
            hook. If the hook fails, the new connection is closed first so it is
            not leaked (the caller never receives it).
        """
        from mysql.connector import aio

        connection = await aio.connect(**self.connection_config)
        autocommit = self.connection_config.get("autocommit")
        if autocommit is not None and hasattr(connection, "set_autocommit"):
            with contextlib.suppress(Exception):
                await connection.set_autocommit(bool(autocommit))
        if self._user_connection_hook is not None:
            try:
                await self._user_connection_hook(cast("MysqlConnectorAsyncConnection", connection))
            except BaseException:
                with contextlib.suppress(Exception):
                    await connection.close()
                raise
        return cast("MysqlConnectorAsyncConnection", connection)

    def provide_connection(self, *args: Any, **kwargs: Any) -> "MysqlConnectorAsyncConnectionContext":
        """Return an async context manager that yields a new connection; arguments are ignored."""
        return MysqlConnectorAsyncConnectionContext(self)

    def provide_session(
        self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any
    ) -> "MysqlConnectorAsyncSessionContext":
        """Return an async session context backed by a fresh connection handler.

        ``statement_config`` falls back to ``self.statement_config`` and then to
        ``default_statement_config``.
        """
        statement_config = statement_config or self.statement_config or default_statement_config
        handler = _MysqlConnectorAsyncSessionConnectionHandler(self)
        return MysqlConnectorAsyncSessionContext(
            acquire_connection=handler.acquire_connection,
            release_connection=handler.release_connection,
            statement_config=statement_config,
            driver_features=self.driver_features,
            prepare_driver=self._prepare_driver,
        )

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Extend the base signature namespace with this adapter's async types."""
        namespace = super().get_signature_namespace()
        namespace.update({
            "MysqlConnectorAsyncConfig": MysqlConnectorAsyncConfig,
            "MysqlConnectorAsyncConnection": MysqlConnectorAsyncConnection,
            "MysqlConnectorAsyncConnectionParams": MysqlConnectorAsyncConnectionParams,
            "MysqlConnectorAsyncCursor": MysqlConnectorAsyncCursor,
            "MysqlConnectorAsyncDriver": MysqlConnectorAsyncDriver,
            "MysqlConnectorAsyncExceptionHandler": MysqlConnectorAsyncExceptionHandler,
            "MysqlConnectorAsyncSessionContext": MysqlConnectorAsyncSessionContext,
        })
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return event-channel hints: poll_interval=0.25, lease_seconds=5, row locking with skip-locked."""
        return EventRuntimeHints(poll_interval=0.25, lease_seconds=5, select_for_update=True, skip_locked=True)