Source code for sqlspec.adapters.mysqlconnector.driver

"""MysqlConnector MySQL driver implementation.

Provides MySQL/MariaDB connectivity with parameter style conversion,
type coercion, error handling, and transaction management.
"""

from collections.abc import Sized
from typing import TYPE_CHECKING, Any, Final, cast

import mysql.connector
from mysql.connector.constants import FieldType

from sqlspec.adapters.mysqlconnector.core import (
    build_insert_statement,
    collect_rows,
    create_mapped_exception,
    default_statement_config,
    detect_json_columns_from_description,
    driver_profile,
    format_identifier,
    normalize_execute_many_parameters,
    normalize_execute_parameters,
    normalize_lastrowid,
    resolve_column_names,
    resolve_many_rowcount,
    resolve_rowcount,
)
from sqlspec.adapters.mysqlconnector.data_dictionary import (
    MysqlConnectorAsyncDataDictionary,
    MysqlConnectorSyncDataDictionary,
)
from sqlspec.core import ArrowResult, get_cache_config, register_driver_profile
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
from sqlspec.exceptions import SQLSpecError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import from_json
from sqlspec.utils.type_guards import supports_json_type

if TYPE_CHECKING:
    from collections.abc import Callable

    from sqlspec.adapters.mysqlconnector._typing import MysqlConnectorAsyncConnection, MysqlConnectorSyncConnection
    from sqlspec.core import SQL, StatementConfig
    from sqlspec.driver import ExecutionResult
    from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry

from typing_extensions import Self

from sqlspec.adapters.mysqlconnector._typing import MysqlConnectorAsyncSessionContext, MysqlConnectorSyncSessionContext

# Public API of this module. The two session-context classes are not defined here;
# they are re-exported from ``sqlspec.adapters.mysqlconnector._typing``.
__all__ = (
    "MysqlConnectorAsyncCursor",
    "MysqlConnectorAsyncDriver",
    "MysqlConnectorAsyncExceptionHandler",
    "MysqlConnectorAsyncSessionContext",
    "MysqlConnectorSyncCursor",
    "MysqlConnectorSyncDriver",
    "MysqlConnectorSyncExceptionHandler",
    "MysqlConnectorSyncSessionContext",
)

logger = get_logger("sqlspec.adapters.mysqlconnector")

# Column type codes that the row collectors treat as JSON and run through the
# configured deserializer. ``supports_json_type`` presumably checks that this
# mysql-connector build exposes ``FieldType.JSON``; when it does not, the set stays empty.
json_type_value = None
if supports_json_type(FieldType):
    json_type_value = FieldType.JSON
MYSQLCONNECTOR_JSON_TYPE_CODES: Final[set[int]] = set()
if json_type_value is not None:
    MYSQLCONNECTOR_JSON_TYPE_CODES.add(json_type_value)


class MysqlConnectorSyncCursor:
    """Context manager that opens a mysql-connector cursor and closes it on exit.

    The opened cursor stays referenced on ``self.cursor`` after the block exits.
    """

    __slots__ = ("connection", "cursor")

    def __init__(self, connection: "MysqlConnectorSyncConnection") -> None:
        self.connection = connection
        # Set by ``__enter__``; ``None`` until a cursor has been opened.
        self.cursor: Any | None = None

    def __enter__(self) -> Any:
        opened = self.connection.cursor()
        self.cursor = opened
        return opened

    def __exit__(self, *_: Any) -> None:
        current = self.cursor
        if current is None:
            # Nothing was opened, so there is nothing to release.
            return
        current.close()


class MysqlConnectorSyncExceptionHandler:
    """Context manager for handling mysql-connector sync exceptions."""

    __slots__ = ("pending_exception",)

    def __init__(self) -> None:
        self.pending_exception: Exception | None = None

    def __enter__(self) -> Self:
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        if exc_type is None:
            return False
        if issubclass(exc_type, mysql.connector.Error):
            result = create_mapped_exception(exc_val, logger=logger)
            if result is True:
                return True
            self.pending_exception = cast("Exception", result)
            return True
        return False


class MysqlConnectorSyncDriver(SyncDriverAdapterBase):
    """MySQL/MariaDB database driver using mysql-connector sync library.

    Executes compiled SQL through mysql-connector cursors, decodes JSON columns with the
    configured ``json_deserializer`` and maps driver errors through
    ``MysqlConnectorSyncExceptionHandler``. It also implements the Arrow/storage bridge helpers.
    """

    __slots__ = ("_data_dictionary",)
    dialect = "mysql"

    def __init__(
        self,
        connection: "MysqlConnectorSyncConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Create a driver bound to ``connection``.

        Args:
            connection: Open mysql-connector connection.
            statement_config: Statement configuration. When omitted, the adapter's
                ``default_statement_config`` is used with caching taken from the global cache config.
            driver_features: Optional driver feature flags (e.g. ``json_deserializer``).
        """
        if statement_config is None:
            statement_config = default_statement_config.replace(
                enable_caching=get_cache_config().compiled_cache_enabled
            )
        super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
        # Created lazily by the ``data_dictionary`` property.
        self._data_dictionary: MysqlConnectorSyncDataDictionary | None = None

    def _collect_mysqlconnector_rows(self, cursor: Any, fetched: Any) -> "tuple[list[Any], list[str], Any]":
        """Shape fetched rows into result rows, decoding JSON columns.

        Shared by ``dispatch_execute`` and ``collect_rows`` so both paths resolve column
        names, JSON columns and the deserializer identically.

        Returns:
            ``(rows, column_names, row_format)`` as produced by the core ``collect_rows`` helper.
        """
        description = cursor.description or None
        column_names = resolve_column_names(description)
        json_indexes = detect_json_columns_from_description(description, MYSQLCONNECTOR_JSON_TYPE_CODES)
        deserializer = cast("Callable[[Any], Any]", self.driver_features.get("json_deserializer", from_json))
        rows, column_names, row_format = collect_rows(
            fetched, description, json_indexes, deserializer, column_names=column_names, logger=logger
        )
        return rows, column_names, row_format

    def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute a single statement and build its execution result.

        Row-returning statements are fully fetched. Otherwise the affected row count
        and last insert id are reported.
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        cursor.execute(sql, normalize_execute_parameters(prepared_parameters))

        if statement.returns_rows():
            fetched_data = cursor.fetchall()
            rows, column_names, row_format = self._collect_mysqlconnector_rows(cursor, fetched_data)
            return self.create_execution_result(
                cursor,
                selected_data=rows,
                column_names=column_names,
                data_row_count=len(rows),
                is_select_result=True,
                row_format=row_format,
            )

        affected_rows = resolve_rowcount(cursor)
        last_id = normalize_lastrowid(cursor)
        return self.create_execution_result(cursor, rowcount_override=affected_rows, last_inserted_id=last_id)

    def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute a statement once per parameter set via ``cursor.executemany``."""
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        prepared_parameters = normalize_execute_many_parameters(prepared_parameters)
        # Used as the rowcount fallback when the cursor cannot report one.
        parameter_count = len(prepared_parameters) if isinstance(prepared_parameters, Sized) else None
        cursor.executemany(sql, prepared_parameters)
        affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count)
        return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True)

    def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Split a script into statements and execute them sequentially on ``cursor``.

        Each statement is executed with the script's (normalized) parameters. The first
        failure propagates immediately.
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)

        successful_count = 0
        for stmt in statements:
            cursor.execute(stmt, normalize_execute_parameters(prepared_parameters))
            successful_count += 1

        return self.create_execution_result(
            cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True
        )

    def begin(self) -> None:
        """Start a transaction by issuing ``BEGIN`` on a short-lived cursor.

        Raises:
            SQLSpecError: If mysql-connector reports an error.
        """
        try:
            with MysqlConnectorSyncCursor(self.connection) as cursor:
                cursor.execute("BEGIN")
        except mysql.connector.Error as e:
            msg = f"Failed to begin MySQL transaction: {e}"
            raise SQLSpecError(msg) from e

    def commit(self) -> None:
        """Commit the current transaction.

        Raises:
            SQLSpecError: If mysql-connector reports an error.
        """
        try:
            self.connection.commit()
        except mysql.connector.Error as e:
            msg = f"Failed to commit MySQL transaction: {e}"
            raise SQLSpecError(msg) from e

    def rollback(self) -> None:
        """Roll back the current transaction.

        Raises:
            SQLSpecError: If mysql-connector reports an error.
        """
        try:
            self.connection.rollback()
        except mysql.connector.Error as e:
            msg = f"Failed to rollback MySQL transaction: {e}"
            raise SQLSpecError(msg) from e

    def with_cursor(self, connection: "MysqlConnectorSyncConnection") -> "MysqlConnectorSyncCursor":
        """Return a cursor context manager for ``connection``."""
        return MysqlConnectorSyncCursor(connection)

    def handle_database_exceptions(self) -> "MysqlConnectorSyncExceptionHandler":
        """Return a fresh exception-mapping context manager."""
        return MysqlConnectorSyncExceptionHandler()

    def select_to_storage(
        self,
        statement: "SQL | str",
        destination: "StorageDestination",
        /,
        *parameters: Any,
        statement_config: "StatementConfig | None" = None,
        partitioner: "dict[str, object] | None" = None,
        format_hint: "StorageFormat | None" = None,
        telemetry: "StorageTelemetry | None" = None,
        **kwargs: Any,
    ) -> "StorageBridgeJob":
        """Run a query, convert it to Arrow and write the result to ``destination``."""
        self._require_capability("arrow_export_enabled")
        arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
        pipeline = self._storage_pipeline()
        telemetry_payload = self._write_result_to_storage_sync(
            arrow_result, destination, format_hint=format_hint, pipeline=pipeline
        )
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    def load_from_arrow(
        self,
        table: str,
        source: "ArrowResult | Any",
        *,
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
        telemetry: "StorageTelemetry | None" = None,
    ) -> "StorageBridgeJob":
        """Insert Arrow data into ``table``, truncating it first when ``overwrite`` is set.

        Raises:
            Exception: The mapped database exception when TRUNCATE or the bulk insert fails.
        """
        self._require_capability("arrow_import_enabled")
        arrow_table = self._coerce_arrow_table(source)
        if overwrite:
            statement = f"TRUNCATE TABLE {format_identifier(table)}"
            exc_handler = self.handle_database_exceptions()
            with exc_handler, self.with_cursor(self.connection) as cursor:
                cursor.execute(statement)
            if exc_handler.pending_exception is not None:
                raise exc_handler.pending_exception from None

        columns, records = self._arrow_table_to_rows(arrow_table)
        if records:
            insert_sql = build_insert_statement(table, columns)
            exc_handler = self.handle_database_exceptions()
            with exc_handler, self.with_cursor(self.connection) as cursor:
                cursor.executemany(insert_sql, records)
            if exc_handler.pending_exception is not None:
                raise exc_handler.pending_exception from None

        telemetry_payload = self._build_ingest_telemetry(arrow_table)
        telemetry_payload["destination"] = table
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    def load_from_storage(
        self,
        table: str,
        source: "StorageDestination",
        *,
        file_format: "StorageFormat",
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
    ) -> "StorageBridgeJob":
        """Read an Arrow table from storage and load it into ``table``."""
        arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format)
        return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound)

    @property
    def data_dictionary(self) -> "MysqlConnectorSyncDataDictionary":
        """Lazily created data dictionary for this driver."""
        if self._data_dictionary is None:
            self._data_dictionary = MysqlConnectorSyncDataDictionary()
        return self._data_dictionary

    def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Collect mysql-connector sync rows for the direct execution path."""
        rows, column_names, _row_format = self._collect_mysqlconnector_rows(cursor, fetched)
        return rows, column_names, len(rows)

    def resolve_rowcount(self, cursor: Any) -> int:
        """Resolve rowcount from mysql-connector cursor for the direct execution path."""
        return resolve_rowcount(cursor)

    def _connection_in_transaction(self) -> bool:
        """Best-effort check whether the connection should be treated as inside a transaction.

        With autocommit disabled (mysql-connector's default), every statement implicitly
        opens a transaction, so "autocommit off" is reported as "in a transaction".
        Returns ``False`` when the state cannot be determined.
        """
        try:
            # On mysql-connector connection classes ``autocommit`` is a property that queries
            # the server (``SELECT @@session.autocommit``), so the lookup itself can raise.
            # It must sit inside the guard for the fallback to be effective.
            autocommit = getattr(self.connection, "autocommit", None)
            if autocommit is None:
                return False
            return not bool(autocommit)
        except Exception:
            return False
class MysqlConnectorAsyncCursor: """Async context manager for mysql-connector async cursor operations.""" __slots__ = ("connection", "cursor") def __init__(self, connection: "MysqlConnectorAsyncConnection") -> None: self.connection = connection self.cursor: Any | None = None async def __aenter__(self) -> Any: self.cursor = await self.connection.cursor() return self.cursor async def __aexit__(self, *_: Any) -> None: if self.cursor is not None: await self.cursor.close() class MysqlConnectorAsyncExceptionHandler: """Async context manager for handling mysql-connector exceptions.""" __slots__ = ("pending_exception",) def __init__(self) -> None: self.pending_exception: Exception | None = None async def __aenter__(self) -> Self: return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: if exc_type is None: return False if issubclass(exc_type, mysql.connector.Error): result = create_mapped_exception(exc_val, logger=logger) if result is True: return True self.pending_exception = cast("Exception", result) return True return False
class MysqlConnectorAsyncDriver(AsyncDriverAdapterBase):
    """MySQL/MariaDB database driver using mysql-connector async library.

    Awaits mysql-connector async cursors for execution and decodes JSON columns with
    the configured ``json_deserializer``. Driver errors are mapped through
    ``MysqlConnectorAsyncExceptionHandler``.
    """

    __slots__ = ("_data_dictionary",)
    dialect = "mysql"

    def __init__(
        self,
        connection: "MysqlConnectorAsyncConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Bind the driver to an open async connection.

        When no ``statement_config`` is supplied, the adapter default is used, with
        caching following the global compiled-cache setting.
        """
        config = statement_config
        if config is None:
            config = default_statement_config.replace(enable_caching=get_cache_config().compiled_cache_enabled)
        super().__init__(connection=connection, statement_config=config, driver_features=driver_features)
        # Populated on first access of ``data_dictionary``.
        self._data_dictionary: MysqlConnectorAsyncDataDictionary | None = None

    async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Run one statement; fetch rows for queries, otherwise report rowcount and last id."""
        compiled_sql, params = self._get_compiled_sql(statement, self.statement_config)
        await cursor.execute(compiled_sql, normalize_execute_parameters(params))

        if not statement.returns_rows():
            return self.create_execution_result(
                cursor, rowcount_override=resolve_rowcount(cursor), last_inserted_id=normalize_lastrowid(cursor)
            )

        fetched = await cursor.fetchall()
        desc = cursor.description or None
        initial_names = resolve_column_names(desc)
        json_columns = detect_json_columns_from_description(desc, MYSQLCONNECTOR_JSON_TYPE_CODES)
        decode_json = cast("Callable[[Any], Any]", self.driver_features.get("json_deserializer", from_json))
        rows, names, row_format = collect_rows(
            fetched, desc, json_columns, decode_json, column_names=initial_names, logger=logger
        )
        return self.create_execution_result(
            cursor,
            selected_data=rows,
            column_names=names,
            data_row_count=len(rows),
            is_select_result=True,
            row_format=row_format,
        )

    async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Run a statement against every parameter set using ``executemany``."""
        compiled_sql, params = self._get_compiled_sql(statement, self.statement_config)
        batch = normalize_execute_many_parameters(params)
        # Length of the batch serves as a rowcount fallback when the cursor reports none.
        expected = len(batch) if isinstance(batch, Sized) else None
        await cursor.executemany(compiled_sql, batch)
        return self.create_execution_result(
            cursor,
            rowcount_override=resolve_many_rowcount(cursor, batch, fallback_count=expected),
            is_many_result=True,
        )

    async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Split a script and await each piece in order; the first failure propagates."""
        compiled_sql, params = self._get_compiled_sql(statement, self.statement_config)
        pieces = self.split_script_statements(compiled_sql, statement.statement_config, strip_trailing_semicolon=True)

        executed = 0
        for piece in pieces:
            await cursor.execute(piece, normalize_execute_parameters(params))
            executed += 1

        return self.create_execution_result(
            cursor, statement_count=len(pieces), successful_statements=executed, is_script_result=True
        )

    async def begin(self) -> None:
        """Issue ``BEGIN`` on a throwaway cursor; driver errors become ``SQLSpecError``."""
        try:
            async with MysqlConnectorAsyncCursor(self.connection) as txn_cursor:
                await txn_cursor.execute("BEGIN")
        except mysql.connector.Error as exc:
            message = f"Failed to begin MySQL transaction: {exc}"
            raise SQLSpecError(message) from exc

    async def commit(self) -> None:
        """Commit on the connection; driver errors become ``SQLSpecError``."""
        try:
            await self.connection.commit()
        except mysql.connector.Error as exc:
            message = f"Failed to commit MySQL transaction: {exc}"
            raise SQLSpecError(message) from exc

    async def rollback(self) -> None:
        """Roll back on the connection; driver errors become ``SQLSpecError``."""
        try:
            await self.connection.rollback()
        except mysql.connector.Error as exc:
            message = f"Failed to rollback MySQL transaction: {exc}"
            raise SQLSpecError(message) from exc

    def with_cursor(self, connection: "MysqlConnectorAsyncConnection") -> "MysqlConnectorAsyncCursor":
        """Wrap ``connection`` in an async cursor context manager."""
        return MysqlConnectorAsyncCursor(connection)

    def handle_database_exceptions(self) -> "MysqlConnectorAsyncExceptionHandler":
        """Build a new async exception-mapping context manager."""
        return MysqlConnectorAsyncExceptionHandler()

    async def select_to_storage(
        self,
        statement: "SQL | str",
        destination: "StorageDestination",
        /,
        *parameters: Any,
        statement_config: "StatementConfig | None" = None,
        partitioner: "dict[str, object] | None" = None,
        format_hint: "StorageFormat | None" = None,
        telemetry: "StorageTelemetry | None" = None,
        **kwargs: Any,
    ) -> "StorageBridgeJob":
        """Export a query result through Arrow into ``destination``."""
        self._require_capability("arrow_export_enabled")
        result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
        payload = await self._write_result_to_storage_async(
            result, destination, format_hint=format_hint, pipeline=self._storage_pipeline()
        )
        self._attach_partition_telemetry(payload, partitioner)
        return self._create_storage_job(payload, telemetry)

    async def load_from_arrow(
        self,
        table: str,
        source: "ArrowResult | Any",
        *,
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
        telemetry: "StorageTelemetry | None" = None,
    ) -> "StorageBridgeJob":
        """Bulk-insert Arrow data into ``table``; ``overwrite`` truncates the table first.

        Mapped database errors captured during TRUNCATE or the insert are re-raised.
        """
        self._require_capability("arrow_import_enabled")
        arrow_table = self._coerce_arrow_table(source)

        if overwrite:
            truncate_sql = f"TRUNCATE TABLE {format_identifier(table)}"
            guard = self.handle_database_exceptions()
            async with guard, self.with_cursor(self.connection) as db_cursor:
                await db_cursor.execute(truncate_sql)
            if guard.pending_exception is not None:
                raise guard.pending_exception from None

        columns, records = self._arrow_table_to_rows(arrow_table)
        if records:
            insert_sql = build_insert_statement(table, columns)
            guard = self.handle_database_exceptions()
            async with guard, self.with_cursor(self.connection) as db_cursor:
                await db_cursor.executemany(insert_sql, records)
            if guard.pending_exception is not None:
                raise guard.pending_exception from None

        payload = self._build_ingest_telemetry(arrow_table)
        payload["destination"] = table
        self._attach_partition_telemetry(payload, partitioner)
        return self._create_storage_job(payload, telemetry)

    async def load_from_storage(
        self,
        table: str,
        source: "StorageDestination",
        *,
        file_format: "StorageFormat",
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
    ) -> "StorageBridgeJob":
        """Pull an Arrow table from storage and hand it to ``load_from_arrow``."""
        arrow_table, inbound_telemetry = await self._read_arrow_from_storage_async(source, file_format=file_format)
        return await self.load_from_arrow(
            table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound_telemetry
        )

    @property
    def data_dictionary(self) -> "MysqlConnectorAsyncDataDictionary":
        """Data dictionary instance, created on first access and then reused."""
        dictionary = self._data_dictionary
        if dictionary is None:
            dictionary = MysqlConnectorAsyncDataDictionary()
            self._data_dictionary = dictionary
        return dictionary

    def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Collect mysql-connector async rows for the direct execution path."""
        desc = cursor.description or None
        initial_names = resolve_column_names(desc)
        json_columns = detect_json_columns_from_description(desc, MYSQLCONNECTOR_JSON_TYPE_CODES)
        decode_json = cast("Callable[[Any], Any]", self.driver_features.get("json_deserializer", from_json))
        rows, names, _ = collect_rows(fetched, desc, json_columns, decode_json, column_names=initial_names, logger=logger)
        return rows, names, len(rows)

    def resolve_rowcount(self, cursor: Any) -> int:
        """Resolve rowcount from mysql-connector cursor for the direct execution path."""
        return resolve_rowcount(cursor)

    def _connection_in_transaction(self) -> bool:
        """Report the connection's ``in_transaction`` flag, or ``False`` when unavailable."""
        flag = getattr(self.connection, "in_transaction", None)
        if flag is None:
            return False
        try:
            return bool(flag)
        except Exception:
            return False
# Register this adapter's driver profile under the "mysql-connector" key when the module is imported.
register_driver_profile("mysql-connector", driver_profile)