Source code for sqlspec.adapters.aiomysql.driver

"""aiomysql MySQL driver implementation.

Provides MySQL/MariaDB connectivity with parameter style conversion,
type coercion, error handling, and transaction management.

aiomysql is built on top of PyMySQL, so error classes come from pymysql.err
rather than a driver-specific error module.
"""

from collections.abc import Sized
from typing import TYPE_CHECKING, Any, Final, cast

import pymysql.err  # pyright: ignore
from pymysql.constants import FIELD_TYPE as PYMYSQL_FIELD_TYPE  # pyright: ignore

from sqlspec.adapters.aiomysql._typing import AiomysqlCursor, AiomysqlSessionContext
from sqlspec.adapters.aiomysql.core import (
    build_insert_statement,
    collect_rows,
    create_mapped_exception,
    default_statement_config,
    detect_json_columns_from_description,
    driver_profile,
    format_identifier,
    normalize_execute_many_parameters,
    normalize_execute_parameters,
    normalize_lastrowid,
    resolve_column_names,
    resolve_many_rowcount,
    resolve_rowcount,
)
from sqlspec.adapters.aiomysql.data_dictionary import AiomysqlDataDictionary
from sqlspec.core import ArrowResult, get_cache_config, register_driver_profile
from sqlspec.driver import AsyncDriverAdapterBase, BaseAsyncExceptionHandler
from sqlspec.exceptions import SQLSpecError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import from_json
from sqlspec.utils.type_guards import supports_json_type

if TYPE_CHECKING:
    from collections.abc import Callable

    from sqlspec.adapters.aiomysql._typing import AiomysqlConnection
    from sqlspec.core import SQL, StatementConfig
    from sqlspec.driver import ExecutionResult
    from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry

__all__ = ("AiomysqlCursor", "AiomysqlDriver", "AiomysqlExceptionHandler", "AiomysqlSessionContext")

logger = get_logger(__name__)

# PyMySQL field-type code used to recognise JSON columns in cursor descriptions.
# Stays None when ``supports_json_type`` rejects the FIELD_TYPE constants, in
# which case the code set below is empty and no column is treated as JSON.
json_type_value: "int | None"
if PYMYSQL_FIELD_TYPE is not None and supports_json_type(PYMYSQL_FIELD_TYPE):
    json_type_value = PYMYSQL_FIELD_TYPE.JSON
else:
    json_type_value = None

# Type codes handed to ``detect_json_columns_from_description``.
AIOMYSQL_JSON_TYPE_CODES: Final[set[int]] = set() if json_type_value is None else {json_type_value}


class AiomysqlExceptionHandler(BaseAsyncExceptionHandler):
    """Async context manager translating aiomysql (PyMySQL) errors into SQLSpec exceptions.

    MySQL error codes and SQLSTATE values are mapped to specific SQLSpec
    exception types so application code can react to them precisely.

    aiomysql is built on PyMySQL, so the base class for its errors is
    ``pymysql.err.Error`` rather than a driver-specific error module.

    Follows the deferred exception pattern for mypyc compatibility: the mapped
    exception is stored in ``pending_exception`` instead of being raised from
    ``__aexit__``, avoiding ABI boundary violations with compiled code.
    """

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Record a mapped SQLSpec exception for PyMySQL errors.

        Args:
            exc_type: Type of the exception raised in the context, or None.
            exc_val: The exception instance.

        Returns:
            False when no exception occurred or it is not a ``pymysql.err.Error``.
            True otherwise. If the mapper returned anything other than a literal
            ``True``, it is first stored on ``pending_exception``.
        """
        # Non-PyMySQL exceptions (and the no-exception case) are not handled here.
        if exc_type is None or not issubclass(exc_type, pymysql.err.Error):
            return False

        mapped = create_mapped_exception(exc_val, logger=logger)
        # A literal True from the mapper means there is nothing to record.
        if mapped is not True:
            self.pending_exception = cast("Exception", mapped)
        return True


class AiomysqlDriver(AsyncDriverAdapterBase):
    """MySQL/MariaDB database driver using aiomysql client library.

    Implements asynchronous database operations for MySQL and MariaDB servers
    with support for parameter style conversion, type coercion, error handling,
    and transaction management.
    """

    __slots__ = ("_data_dictionary",)
    dialect = "mysql"

    def __init__(
        self,
        connection: "AiomysqlConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Initialize the driver.

        Args:
            connection: aiomysql connection used for all operations.
            statement_config: Statement configuration. When omitted, the adapter's
                ``default_statement_config`` is used with ``enable_caching`` taken
                from the global compiled-cache setting.
            driver_features: Optional driver feature mapping (this class reads the
                ``json_deserializer`` key).
        """
        if statement_config is None:
            statement_config = default_statement_config.replace(
                enable_caching=get_cache_config().compiled_cache_enabled
            )
        super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
        # Created lazily by the ``data_dictionary`` property.
        self._data_dictionary: AiomysqlDataDictionary | None = None

    # ─────────────────────────────────────────────────────────────────────────────
    # CORE DISPATCH METHODS - The Execution Engine
    # ─────────────────────────────────────────────────────────────────────────────

    async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute single SQL statement.

        Args:
            cursor: aiomysql cursor object
            statement: SQL statement to execute

        Returns:
            ExecutionResult: Statement execution results with data or row counts
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        await cursor.execute(sql, normalize_execute_parameters(prepared_parameters))

        if statement.returns_rows():
            fetched_data = await cursor.fetchall()
            rows, column_names, row_format = self._collect_aiomysql_rows(cursor, fetched_data)
            return self.create_execution_result(
                cursor,
                selected_data=rows,
                column_names=column_names,
                data_row_count=len(rows),
                is_select_result=True,
                row_format=row_format,
            )

        # Module-level ``resolve_rowcount`` helper (not the method of the same name:
        # class-scope names are not visible inside method bodies).
        affected_rows = resolve_rowcount(cursor)
        last_id = normalize_lastrowid(cursor)
        return self.create_execution_result(cursor, rowcount_override=affected_rows, last_inserted_id=last_id)

    async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute SQL statement with multiple parameter sets.

        Args:
            cursor: aiomysql cursor object
            statement: SQL statement with multiple parameter sets

        Returns:
            ExecutionResult: Batch execution results

        Raises:
            ValueError: If no parameters provided for executemany operation
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        prepared_parameters = normalize_execute_many_parameters(prepared_parameters)
        # Used as a rowcount fallback when the cursor cannot report one.
        parameter_count = len(prepared_parameters) if isinstance(prepared_parameters, Sized) else None
        await cursor.executemany(sql, prepared_parameters)
        affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count)
        return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True)

    async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute SQL script with statement splitting and parameter handling.

        Args:
            cursor: aiomysql cursor object
            statement: SQL script to execute

        Returns:
            ExecutionResult: Script execution results with statement count
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
        # The same parameters are sent with every split statement, so normalize once.
        execute_parameters = normalize_execute_parameters(prepared_parameters)

        successful_count = 0
        for stmt in statements:
            await cursor.execute(stmt, execute_parameters)
            successful_count += 1

        return self.create_execution_result(
            cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True
        )

    # ─────────────────────────────────────────────────────────────────────────────
    # TRANSACTION MANAGEMENT
    # ─────────────────────────────────────────────────────────────────────────────

    async def begin(self) -> None:
        """Begin a database transaction by issuing ``BEGIN`` on a short-lived cursor.

        Raises:
            SQLSpecError: If transaction initialization fails
        """
        try:
            async with self.with_cursor(self.connection) as cursor:
                await cursor.execute("BEGIN")
        except pymysql.err.MySQLError as e:
            msg = f"Failed to begin MySQL transaction: {e}"
            raise SQLSpecError(msg) from e

    async def commit(self) -> None:
        """Commit the current transaction.

        Raises:
            SQLSpecError: If transaction commit fails
        """
        try:
            await self.connection.commit()
        except pymysql.err.MySQLError as e:
            msg = f"Failed to commit MySQL transaction: {e}"
            raise SQLSpecError(msg) from e

    async def rollback(self) -> None:
        """Rollback the current transaction.

        Raises:
            SQLSpecError: If transaction rollback fails
        """
        try:
            await self.connection.rollback()
        except pymysql.err.MySQLError as e:
            msg = f"Failed to rollback MySQL transaction: {e}"
            raise SQLSpecError(msg) from e

    def with_cursor(self, connection: "AiomysqlConnection") -> "AiomysqlCursor":
        """Create cursor context manager for the connection.

        Args:
            connection: aiomysql database connection

        Returns:
            AiomysqlCursor: Context manager for cursor operations
        """
        return AiomysqlCursor(connection)

    def handle_database_exceptions(self) -> "AiomysqlExceptionHandler":
        """Provide exception handling context manager.

        Returns:
            AiomysqlExceptionHandler: Context manager for aiomysql exception handling
        """
        return AiomysqlExceptionHandler()

    # ─────────────────────────────────────────────────────────────────────────────
    # STORAGE API METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    async def select_to_storage(
        self,
        statement: "SQL | str",
        destination: "StorageDestination",
        /,
        *parameters: Any,
        statement_config: "StatementConfig | None" = None,
        partitioner: "dict[str, object] | None" = None,
        format_hint: "StorageFormat | None" = None,
        telemetry: "StorageTelemetry | None" = None,
        **kwargs: Any,
    ) -> "StorageBridgeJob":
        """Execute a query and stream Arrow-formatted results into storage.

        Requires the ``arrow_export_enabled`` capability.
        """
        self._require_capability("arrow_export_enabled")
        arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
        async_pipeline = self._storage_pipeline()
        telemetry_payload = await self._write_result_to_storage_async(
            arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline
        )
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    async def load_from_arrow(
        self,
        table: str,
        source: "ArrowResult | Any",
        *,
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
        telemetry: "StorageTelemetry | None" = None,
    ) -> "StorageBridgeJob":
        """Load Arrow data into MySQL using batched inserts.

        When ``overwrite`` is true the target table is truncated first. Database
        errors captured by the exception handler are re-raised as their mapped
        SQLSpec exceptions. Requires the ``arrow_import_enabled`` capability.
        """
        self._require_capability("arrow_import_enabled")
        arrow_table = self._coerce_arrow_table(source)

        if overwrite:
            statement = f"TRUNCATE TABLE {format_identifier(table)}"
            exc_handler = self.handle_database_exceptions()
            async with exc_handler, self.with_cursor(self.connection) as cursor:
                await cursor.execute(statement)
            # Deferred-exception pattern: the handler stores rather than raises.
            if exc_handler.pending_exception is not None:
                raise exc_handler.pending_exception from None

        columns, records = self._arrow_table_to_rows(arrow_table)
        if records:
            insert_sql = build_insert_statement(table, columns)
            exc_handler = self.handle_database_exceptions()
            async with exc_handler, self.with_cursor(self.connection) as cursor:
                await cursor.executemany(insert_sql, records)
            if exc_handler.pending_exception is not None:
                raise exc_handler.pending_exception from None

        telemetry_payload = self._build_ingest_telemetry(arrow_table)
        telemetry_payload["destination"] = table
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    async def load_from_storage(
        self,
        table: str,
        source: "StorageDestination",
        *,
        file_format: "StorageFormat",
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
    ) -> "StorageBridgeJob":
        """Load staged artifacts from storage into MySQL.

        Reads ``source`` as an Arrow table and delegates to ``load_from_arrow``,
        forwarding the inbound read telemetry.
        """
        arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format)
        return await self.load_from_arrow(
            table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound
        )

    # ─────────────────────────────────────────────────────────────────────────────
    # UTILITY METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    @property
    def data_dictionary(self) -> "AiomysqlDataDictionary":
        """Get the data dictionary for this driver.

        Returns:
            Data dictionary instance for metadata queries (created on first access)
        """
        if self._data_dictionary is None:
            self._data_dictionary = AiomysqlDataDictionary()
        return self._data_dictionary

    # ─────────────────────────────────────────────────────────────────────────────
    # PRIVATE/INTERNAL METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def _collect_aiomysql_rows(self, cursor: Any, fetched: Any) -> "tuple[list[Any], list[str], Any]":
        """Shape fetched rows using the cursor description.

        Shared by ``dispatch_execute`` and ``collect_rows`` so both execution paths
        resolve column names and decode JSON columns identically.

        Args:
            cursor: aiomysql cursor whose ``description`` describes ``fetched``.
            fetched: Rows returned by the cursor.

        Returns:
            ``(rows, column_names, row_format)`` as produced by the core
            ``collect_rows`` helper.
        """
        # An empty description is normalized to None before reaching the core helpers.
        description = cursor.description or None
        column_names = resolve_column_names(description)
        json_indexes = detect_json_columns_from_description(description, AIOMYSQL_JSON_TYPE_CODES)
        deserializer = cast("Callable[[Any], Any]", self.driver_features.get("json_deserializer", from_json))
        # Module-level ``collect_rows`` from core, not the method of the same name.
        return collect_rows(fetched, description, json_indexes, deserializer, column_names=column_names, logger=logger)

    def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Collect aiomysql rows for the direct execution path.

        Returns:
            ``(rows, column_names, row_count)``.
        """
        rows, column_names, _row_format = self._collect_aiomysql_rows(cursor, fetched)
        return rows, column_names, len(rows)

    def resolve_rowcount(self, cursor: Any) -> int:
        """Resolve rowcount from aiomysql cursor for the direct execution path."""
        # Delegates to the module-level core helper of the same name.
        return resolve_rowcount(cursor)

    def _connection_in_transaction(self) -> bool:
        """Check if connection is in transaction.

        aiomysql does not expose reliable transaction state.

        Returns:
            False - aiomysql requires explicit transaction management.
        """
        return False
# Module import side effect: register this adapter's driver profile under the "aiomysql" key.
register_driver_profile("aiomysql", driver_profile)