# Source code for sqlspec.adapters.asyncmy.driver

"""AsyncMy MySQL driver implementation.

Provides MySQL/MariaDB connectivity with parameter style conversion,
type coercion, error handling, and transaction management.
"""

from collections.abc import Sized
from typing import TYPE_CHECKING, Any, Final, cast

import asyncmy.errors  # pyright: ignore
from asyncmy.constants import FIELD_TYPE as ASYNC_MY_FIELD_TYPE  # pyright: ignore

from sqlspec.adapters.asyncmy._typing import AsyncmyCursor, AsyncmySessionContext
from sqlspec.adapters.asyncmy.core import (
    build_insert_statement,
    collect_rows,
    create_mapped_exception,
    default_statement_config,
    detect_json_columns_from_description,
    driver_profile,
    format_identifier,
    normalize_execute_many_parameters,
    normalize_execute_parameters,
    normalize_lastrowid,
    resolve_column_names,
    resolve_many_rowcount,
    resolve_rowcount,
)
from sqlspec.adapters.asyncmy.data_dictionary import AsyncmyDataDictionary
from sqlspec.core import ArrowResult, get_cache_config, register_driver_profile
from sqlspec.driver import AsyncDriverAdapterBase, BaseAsyncExceptionHandler
from sqlspec.exceptions import SQLSpecError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import from_json
from sqlspec.utils.type_guards import supports_json_type

if TYPE_CHECKING:
    from collections.abc import Callable

    from sqlspec.adapters.asyncmy._typing import AsyncmyConnection
    from sqlspec.core import SQL, StatementConfig
    from sqlspec.driver import ExecutionResult
    from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry

__all__ = ("AsyncmyCursor", "AsyncmyDriver", "AsyncmyExceptionHandler", "AsyncmySessionContext")

logger = get_logger(__name__)

# asyncmy's FIELD_TYPE code for JSON columns, or None when the imported FIELD_TYPE
# is missing or does not pass ``supports_json_type``.
if ASYNC_MY_FIELD_TYPE is not None and supports_json_type(ASYNC_MY_FIELD_TYPE):
    json_type_value = ASYNC_MY_FIELD_TYPE.JSON
else:
    json_type_value = None

# Type codes treated as JSON when scanning a cursor description; empty when no JSON code is known.
ASYNCMY_JSON_TYPE_CODES: Final[set[int]] = set() if json_type_value is None else {json_type_value}


class AsyncmyExceptionHandler(BaseAsyncExceptionHandler):
    """Async context manager that translates asyncmy (MySQL) errors into SQLSpec exceptions.

    Mapping of MySQL error codes / SQLSTATE to SQLSpec exception types is delegated
    to ``create_mapped_exception``.

    Uses the deferred exception pattern for mypyc compatibility: the mapped
    exception is stored in ``pending_exception`` instead of being raised from
    ``__aexit__``, avoiding ABI boundary violations with compiled code.
    """

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Decide whether to suppress ``exc_val`` and record a mapped replacement.

        Args:
            exc_type: Type of the in-flight exception, or ``None`` if the block exited cleanly.
            exc_val: The in-flight exception instance.

        Returns:
            ``True`` when the exception is an asyncmy error (it is suppressed, and
            ``pending_exception`` is set unless the mapper returned ``True``);
            ``False`` otherwise, letting the original exception propagate.
        """
        # Guard clause: nothing to do on clean exit or for non-asyncmy errors.
        if exc_type is None or not issubclass(exc_type, asyncmy.errors.Error):
            return False

        mapped = create_mapped_exception(exc_val, logger=logger)
        # A literal ``True`` from the mapper means "suppress without a replacement".
        if mapped is not True:
            self.pending_exception = cast("Exception", mapped)
        return True


class AsyncmyDriver(AsyncDriverAdapterBase):
    """MySQL/MariaDB database driver using AsyncMy client library.

    Implements asynchronous database operations for MySQL and MariaDB servers
    with support for parameter style conversion, type coercion, error handling,
    and transaction management.
    """

    __slots__ = ("_data_dictionary",)
    dialect = "mysql"

    def __init__(
        self,
        connection: "AsyncmyConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Initialize the driver.

        Args:
            connection: asyncmy connection used for all operations.
            statement_config: Statement configuration. When omitted, the adapter's
                ``default_statement_config`` is used with ``enable_caching`` taken
                from the global cache configuration.
            driver_features: Optional feature overrides (e.g. ``json_deserializer``).
        """
        if statement_config is None:
            statement_config = default_statement_config.replace(
                enable_caching=get_cache_config().compiled_cache_enabled
            )
        super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
        # Created lazily by the ``data_dictionary`` property.
        self._data_dictionary: AsyncmyDataDictionary | None = None

    # ─────────────────────────────────────────────────────────────────────────────
    # CORE DISPATCH METHODS - The Execution Engine
    # ─────────────────────────────────────────────────────────────────────────────

    async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute single SQL statement.

        Handles parameter processing, result fetching, and data transformation
        for MySQL/MariaDB operations.

        Args:
            cursor: AsyncMy cursor object
            statement: SQL statement to execute

        Returns:
            ExecutionResult: Statement execution results with data or row counts
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        await cursor.execute(sql, normalize_execute_parameters(prepared_parameters))

        if statement.returns_rows():
            fetched_data = await cursor.fetchall()
            rows, column_names, row_format = self._collect_rows_from_cursor(cursor, fetched_data)
            return self.create_execution_result(
                cursor,
                selected_data=rows,
                column_names=column_names,
                data_row_count=len(rows),
                is_select_result=True,
                row_format=row_format,
            )

        affected_rows = resolve_rowcount(cursor)
        last_id = normalize_lastrowid(cursor)
        return self.create_execution_result(cursor, rowcount_override=affected_rows, last_inserted_id=last_id)

    async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute SQL statement with multiple parameter sets.

        Uses AsyncMy's executemany for batch operations with MySQL type conversion
        and parameter processing.

        Args:
            cursor: AsyncMy cursor object
            statement: SQL statement with multiple parameter sets

        Returns:
            ExecutionResult: Batch execution results

        Raises:
            ValueError: If no parameters provided for executemany operation
                (expected to come from ``normalize_execute_many_parameters``).
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        prepared_parameters = normalize_execute_many_parameters(prepared_parameters)
        # Only sized containers give a usable fallback row count.
        parameter_count = len(prepared_parameters) if isinstance(prepared_parameters, Sized) else None
        await cursor.executemany(sql, prepared_parameters)
        affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count)
        return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True)

    async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute SQL script with statement splitting.

        Splits a multi-statement script (stripping trailing semicolons) and
        executes each statement sequentially on the same cursor. Every statement
        is executed with the same normalized compiled parameters.

        Args:
            cursor: AsyncMy cursor object
            statement: SQL script to execute

        Returns:
            ExecutionResult: Script execution results with statement count
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)

        successful_count = 0
        for stmt in statements:
            # NOTE(review): the same parameters are passed to every split statement; this
            # presumes script compilation left no per-statement placeholders — confirm.
            await cursor.execute(stmt, normalize_execute_parameters(prepared_parameters))
            successful_count += 1

        return self.create_execution_result(
            cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True
        )

    # ─────────────────────────────────────────────────────────────────────────────
    # TRANSACTION MANAGEMENT
    # ─────────────────────────────────────────────────────────────────────────────

    async def begin(self) -> None:
        """Begin a database transaction.

        Explicitly starts a MySQL transaction by executing ``BEGIN`` on a
        short-lived cursor.

        Raises:
            SQLSpecError: If transaction initialization fails
        """
        try:
            async with AsyncmyCursor(self.connection) as cursor:
                await cursor.execute("BEGIN")
        except asyncmy.errors.MySQLError as e:
            msg = f"Failed to begin MySQL transaction: {e}"
            raise SQLSpecError(msg) from e

    async def commit(self) -> None:
        """Commit the current transaction.

        Raises:
            SQLSpecError: If transaction commit fails
        """
        try:
            await self.connection.commit()
        except asyncmy.errors.MySQLError as e:
            msg = f"Failed to commit MySQL transaction: {e}"
            raise SQLSpecError(msg) from e

    async def rollback(self) -> None:
        """Rollback the current transaction.

        Raises:
            SQLSpecError: If transaction rollback fails
        """
        try:
            await self.connection.rollback()
        except asyncmy.errors.MySQLError as e:
            msg = f"Failed to rollback MySQL transaction: {e}"
            raise SQLSpecError(msg) from e

    def with_cursor(self, connection: "AsyncmyConnection") -> "AsyncmyCursor":
        """Create cursor context manager for the connection.

        Args:
            connection: AsyncMy database connection

        Returns:
            AsyncmyCursor: Context manager for cursor operations
        """
        return AsyncmyCursor(connection)

    def handle_database_exceptions(self) -> "AsyncmyExceptionHandler":
        """Provide exception handling context manager.

        Returns:
            AsyncmyExceptionHandler: Context manager for AsyncMy exception handling
        """
        return AsyncmyExceptionHandler()

    # ─────────────────────────────────────────────────────────────────────────────
    # STORAGE API METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    async def select_to_storage(
        self,
        statement: "SQL | str",
        destination: "StorageDestination",
        /,
        *parameters: Any,
        statement_config: "StatementConfig | None" = None,
        partitioner: "dict[str, object] | None" = None,
        format_hint: "StorageFormat | None" = None,
        telemetry: "StorageTelemetry | None" = None,
        **kwargs: Any,
    ) -> "StorageBridgeJob":
        """Execute a query and stream Arrow-formatted results into storage."""
        self._require_capability("arrow_export_enabled")
        arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
        async_pipeline = self._storage_pipeline()
        telemetry_payload = await self._write_result_to_storage_async(
            arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline
        )
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    async def load_from_arrow(
        self,
        table: str,
        source: "ArrowResult | Any",
        *,
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
        telemetry: "StorageTelemetry | None" = None,
    ) -> "StorageBridgeJob":
        """Load Arrow data into MySQL using batched inserts.

        When ``overwrite`` is true the target table is truncated first. Rows are
        inserted with a single ``executemany`` call; nothing is inserted when the
        Arrow table yields no records.
        """
        self._require_capability("arrow_import_enabled")
        arrow_table = self._coerce_arrow_table(source)
        if overwrite:
            statement = f"TRUNCATE TABLE {format_identifier(table)}"
            exc_handler = self.handle_database_exceptions()
            async with exc_handler, self.with_cursor(self.connection) as cursor:
                await cursor.execute(statement)
            # The handler defers the mapped exception instead of raising from __aexit__.
            if exc_handler.pending_exception is not None:
                raise exc_handler.pending_exception from None

        columns, records = self._arrow_table_to_rows(arrow_table)
        if records:
            insert_sql = build_insert_statement(table, columns)
            exc_handler = self.handle_database_exceptions()
            async with exc_handler, self.with_cursor(self.connection) as cursor:
                await cursor.executemany(insert_sql, records)
            if exc_handler.pending_exception is not None:
                raise exc_handler.pending_exception from None

        telemetry_payload = self._build_ingest_telemetry(arrow_table)
        telemetry_payload["destination"] = table
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    async def load_from_storage(
        self,
        table: str,
        source: "StorageDestination",
        *,
        file_format: "StorageFormat",
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
    ) -> "StorageBridgeJob":
        """Load staged artifacts from storage into MySQL."""
        arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format)
        return await self.load_from_arrow(
            table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound
        )

    # ─────────────────────────────────────────────────────────────────────────────
    # UTILITY METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    @property
    def data_dictionary(self) -> "AsyncmyDataDictionary":
        """Get the data dictionary for this driver.

        Returns:
            Data dictionary instance for metadata queries (created on first access).
        """
        if self._data_dictionary is None:
            self._data_dictionary = AsyncmyDataDictionary()
        return self._data_dictionary

    # ─────────────────────────────────────────────────────────────────────────────
    # PRIVATE/INTERNAL METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Collect asyncmy rows for the direct execution path.

        Returns:
            Tuple of ``(rows, column_names, row_count)``.
        """
        rows, column_names, _row_format = self._collect_rows_from_cursor(cursor, fetched)
        return rows, column_names, len(rows)

    def resolve_rowcount(self, cursor: Any) -> int:
        """Resolve rowcount from asyncmy cursor for the direct execution path."""
        # Bare name resolves to the module-level helper imported from ``core``, not this method.
        return resolve_rowcount(cursor)

    def _collect_rows_from_cursor(self, cursor: Any, fetched: Any) -> "tuple[Any, Any, Any]":
        """Shape fetched rows using the cursor description, decoding JSON columns.

        Shared by ``dispatch_execute`` and ``collect_rows`` so both paths use the
        same column-name resolution and the configured ``json_deserializer``
        (``from_json`` by default).

        Args:
            cursor: Cursor whose ``description`` describes ``fetched``.
            fetched: Rows returned by ``fetchall``.

        Returns:
            ``(rows, column_names, row_format)`` as produced by the core ``collect_rows`` helper.
        """
        description = cursor.description or None
        column_names = resolve_column_names(description)
        json_indexes = detect_json_columns_from_description(description, ASYNCMY_JSON_TYPE_CODES)
        deserializer = cast("Callable[[Any], Any]", self.driver_features.get("json_deserializer", from_json))
        # Bare name resolves to the module-level ``core.collect_rows`` helper, not the method above.
        rows, column_names, row_format = collect_rows(
            fetched, description, json_indexes, deserializer, column_names=column_names, logger=logger
        )
        return rows, column_names, row_format

    def _connection_in_transaction(self) -> bool:
        """Check if connection is in transaction.

        AsyncMy uses explicit BEGIN and does not expose reliable transaction state.

        Returns:
            False - AsyncMy requires explicit transaction management.
        """
        return False
# Register this adapter's driver profile (``driver_profile`` from ``sqlspec.adapters.asyncmy.core``)
# under the "asyncmy" key at import time.
register_driver_profile("asyncmy", driver_profile)