"""PyMySQL MySQL driver implementation."""
from collections.abc import Sized
from typing import TYPE_CHECKING, Any, Final, cast
import pymysql
from pymysql.constants import FIELD_TYPE
from sqlspec.adapters.pymysql.core import (
build_insert_statement,
collect_rows,
create_mapped_exception,
default_statement_config,
detect_json_columns_from_description,
driver_profile,
format_identifier,
normalize_execute_many_parameters,
normalize_execute_parameters,
normalize_lastrowid,
resolve_column_names,
resolve_many_rowcount,
resolve_rowcount,
)
from sqlspec.adapters.pymysql.data_dictionary import PyMysqlDataDictionary
from sqlspec.core import ArrowResult, get_cache_config, register_driver_profile
from sqlspec.driver import SyncDriverAdapterBase
from sqlspec.exceptions import SQLSpecError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import from_json
from sqlspec.utils.type_guards import supports_json_type
if TYPE_CHECKING:
from collections.abc import Callable
from sqlspec.adapters.pymysql._typing import PyMysqlConnection
from sqlspec.core import SQL, StatementConfig
from sqlspec.driver import ExecutionResult
from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry
from typing_extensions import Self
from sqlspec.adapters.pymysql._typing import PyMysqlSessionContext
# Public names exported by this module.
__all__ = ("PyMysqlCursor", "PyMysqlDriver", "PyMysqlExceptionHandler", "PyMysqlSessionContext")

logger = get_logger("sqlspec.adapters.pymysql")

# FIELD_TYPE.JSON is only read when supports_json_type() accepts the constants
# module; otherwise no JSON type code is recorded and the set stays empty.
json_type_value = FIELD_TYPE.JSON if supports_json_type(FIELD_TYPE) else None
PYMYSQL_JSON_TYPE_CODES: Final[set[int]] = set() if json_type_value is None else {json_type_value}
class PyMysqlCursor:
    """Context manager that opens a cursor on entry and closes it on exit.

    Attributes are stored in slots:
        connection: The connection object whose ``cursor()`` method is called.
        cursor: The cursor opened by ``__enter__``; ``None`` until then.
    """

    __slots__ = ("connection", "cursor")

    def __init__(self, connection: "PyMysqlConnection") -> None:
        """Remember the connection; no cursor is opened yet."""
        self.cursor: Any | None = None
        self.connection = connection

    def __enter__(self) -> Any:
        """Open a cursor on the stored connection, keep it, and return it."""
        opened = self.connection.cursor()
        self.cursor = opened
        return opened

    def __exit__(self, *_: Any) -> None:
        """Close the stored cursor, if one was ever opened.

        Exception details are ignored and nothing is suppressed.
        """
        active = self.cursor
        if active is None:
            return
        active.close()
class PyMysqlExceptionHandler:
    """Context manager that intercepts PyMySQL errors raised inside its block.

    Any ``pymysql.MySQLError`` is suppressed. It is passed to
    ``create_mapped_exception``; unless that call returns ``True``, the value
    it returns is stored on ``pending_exception`` so the caller can raise it
    after the ``with`` block has finished. Exceptions of any other type
    propagate unchanged.
    """

    __slots__ = ("pending_exception",)

    def __init__(self) -> None:
        """Start with no pending exception recorded."""
        self.pending_exception: Exception | None = None

    # The return annotation is quoted like every other forward annotation in
    # this module, so defining the class never needs ``Self`` at runtime.
    def __enter__(self) -> "Self":
        """Return the handler itself so ``pending_exception`` can be inspected."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        """Suppress PyMySQL errors and record their mapped replacement.

        Args:
            exc_type: Exception class raised in the block, or ``None``.
            exc_val: Exception instance raised in the block, or ``None``.
            exc_tb: Traceback of the exception (unused).

        Returns:
            ``True`` when a ``pymysql.MySQLError`` was intercepted (the
            exception is suppressed), ``False`` otherwise.
        """
        # Clean exit, or an exception this handler does not own: do nothing.
        if exc_type is None or not issubclass(exc_type, pymysql.MySQLError):
            return False

        mapped = create_mapped_exception(exc_val, logger=logger)
        # A literal True from the mapper means "suppress, nothing to raise".
        if mapped is not True:
            self.pending_exception = cast("Exception", mapped)
        return True
# --- Driver implementation ---
class PyMysqlDriver(SyncDriverAdapterBase):
    """Synchronous MySQL/MariaDB driver adapter backed by PyMySQL.

    Provides the execution hooks (single statement, many, script), transaction
    control, Arrow/storage bridging, and the row/rowcount helpers used by the
    direct execution path.
    """

    __slots__ = ("_data_dictionary",)
    dialect = "mysql"

    def __init__(
        self,
        connection: "PyMysqlConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Initialise the driver.

        Args:
            connection: PyMySQL connection the driver operates on.
            statement_config: Statement configuration; when omitted, the
                adapter default is used with ``enable_caching`` taken from the
                current cache configuration.
            driver_features: Optional feature mapping forwarded to the base class.
        """
        if statement_config is None:
            caching_enabled = get_cache_config().compiled_cache_enabled
            statement_config = default_statement_config.replace(enable_caching=caching_enabled)
        super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
        # Created lazily by the ``data_dictionary`` property.
        self._data_dictionary: PyMysqlDataDictionary | None = None

    def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute one statement and build its execution result.

        Args:
            cursor: Open cursor used to run the statement.
            statement: Statement to compile and execute.

        Returns:
            A select-style result (rows, column names, row format) when the
            statement returns rows; otherwise a result carrying the affected
            row count and the last inserted id.
        """
        compiled_sql, bound_parameters = self._get_compiled_sql(statement, self.statement_config)
        cursor.execute(compiled_sql, normalize_execute_parameters(bound_parameters))

        if not statement.returns_rows():
            rowcount = resolve_rowcount(cursor)
            inserted_id = normalize_lastrowid(cursor)
            return self.create_execution_result(cursor, rowcount_override=rowcount, last_inserted_id=inserted_id)

        raw_rows = cursor.fetchall()
        description = cursor.description or None
        names = resolve_column_names(description)
        json_columns = detect_json_columns_from_description(description, PYMYSQL_JSON_TYPE_CODES)
        # A user-supplied "json_deserializer" feature overrides the default from_json.
        decode_json = cast("Callable[[Any], Any]", self.driver_features.get("json_deserializer", from_json))
        rows, names, row_format = collect_rows(
            raw_rows, description, json_columns, decode_json, column_names=names, logger=logger
        )
        return self.create_execution_result(
            cursor,
            selected_data=rows,
            column_names=names,
            data_row_count=len(rows),
            is_select_result=True,
            row_format=row_format,
        )

    def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute a statement once per parameter set via ``executemany``.

        Args:
            cursor: Open cursor used to run the statement.
            statement: Statement to compile and execute.

        Returns:
            A "many" result whose row count comes from ``resolve_many_rowcount``.
        """
        compiled_sql, raw_parameters = self._get_compiled_sql(statement, self.statement_config)
        parameter_sets = normalize_execute_many_parameters(raw_parameters)

        # The length is only known for sized containers; it is handed to the
        # rowcount resolver as ``fallback_count``.
        if isinstance(parameter_sets, Sized):
            expected_count = len(parameter_sets)
        else:
            expected_count = None

        cursor.executemany(compiled_sql, parameter_sets)
        total = resolve_many_rowcount(cursor, parameter_sets, fallback_count=expected_count)
        return self.create_execution_result(cursor, rowcount_override=total, is_many_result=True)

    def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Split a script into statements and execute them one after another.

        Every statement is executed with the same compiled parameters. An
        error raised by any statement propagates and stops the script.

        Args:
            cursor: Open cursor used to run each statement.
            statement: Script to compile, split and execute.

        Returns:
            A script result with the total and successful statement counts.
        """
        script_sql, script_parameters = self._get_compiled_sql(statement, self.statement_config)
        statements = self.split_script_statements(script_sql, statement.statement_config, strip_trailing_semicolon=True)

        executed = 0
        for executed, single_statement in enumerate(statements, start=1):
            cursor.execute(single_statement, normalize_execute_parameters(script_parameters))

        return self.create_execution_result(
            cursor, statement_count=len(statements), successful_statements=executed, is_script_result=True
        )

    def begin(self) -> None:
        """Start a transaction by executing ``BEGIN`` on a short-lived cursor.

        Raises:
            SQLSpecError: If PyMySQL reports an error.
        """
        try:
            with PyMysqlCursor(self.connection) as begin_cursor:
                begin_cursor.execute("BEGIN")
        except pymysql.MySQLError as exc:
            message = f"Failed to begin MySQL transaction: {exc}"
            raise SQLSpecError(message) from exc

    def commit(self) -> None:
        """Commit the current transaction on the connection.

        Raises:
            SQLSpecError: If PyMySQL reports an error.
        """
        try:
            self.connection.commit()
        except pymysql.MySQLError as exc:
            message = f"Failed to commit MySQL transaction: {exc}"
            raise SQLSpecError(message) from exc

    def rollback(self) -> None:
        """Roll back the current transaction on the connection.

        Raises:
            SQLSpecError: If PyMySQL reports an error.
        """
        try:
            self.connection.rollback()
        except pymysql.MySQLError as exc:
            message = f"Failed to rollback MySQL transaction: {exc}"
            raise SQLSpecError(message) from exc

    def with_cursor(self, connection: "PyMysqlConnection") -> "PyMysqlCursor":
        """Return a cursor context manager bound to ``connection``."""
        return PyMysqlCursor(connection)

    def handle_database_exceptions(self) -> "PyMysqlExceptionHandler":
        """Return a fresh exception-handling context manager."""
        return PyMysqlExceptionHandler()

    def select_to_storage(
        self,
        statement: "SQL | str",
        destination: "StorageDestination",
        /,
        *parameters: Any,
        statement_config: "StatementConfig | None" = None,
        partitioner: "dict[str, object] | None" = None,
        format_hint: "StorageFormat | None" = None,
        telemetry: "StorageTelemetry | None" = None,
        **kwargs: Any,
    ) -> "StorageBridgeJob":
        """Run a query, export it as Arrow and write it to a storage destination.

        Requires the ``arrow_export_enabled`` capability.

        Args:
            statement: Query to execute.
            destination: Where the result is written.
            *parameters: Positional parameters forwarded to ``select_to_arrow``.
            statement_config: Optional statement configuration override.
            partitioner: Optional partition info attached to the telemetry.
            format_hint: Optional storage format passed to the writer.
            telemetry: Optional telemetry passed to the job factory.
            **kwargs: Extra keyword arguments forwarded to ``select_to_arrow``.

        Returns:
            The storage job created from the write telemetry.
        """
        self._require_capability("arrow_export_enabled")
        exported = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
        storage_pipeline = self._storage_pipeline()
        payload = self._write_result_to_storage_sync(
            exported, destination, format_hint=format_hint, pipeline=storage_pipeline
        )
        self._attach_partition_telemetry(payload, partitioner)
        return self._create_storage_job(payload, telemetry)

    def load_from_arrow(
        self,
        table: str,
        source: "ArrowResult | Any",
        *,
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
        telemetry: "StorageTelemetry | None" = None,
    ) -> "StorageBridgeJob":
        """Insert Arrow data into ``table``.

        Requires the ``arrow_import_enabled`` capability.

        Args:
            table: Target table name.
            source: Arrow data, coerced with ``_coerce_arrow_table``.
            partitioner: Optional partition info attached to the telemetry.
            overwrite: When true, ``TRUNCATE TABLE`` is issued before inserting.
            telemetry: Optional telemetry passed to the job factory.

        Returns:
            The storage job created from the ingest telemetry.

        Raises:
            Exception: Whatever exception the database exception handler
                recorded while truncating or inserting.
        """
        self._require_capability("arrow_import_enabled")
        arrow_table = self._coerce_arrow_table(source)

        if overwrite:
            truncate_sql = f"TRUNCATE TABLE {format_identifier(table)}"
            truncate_guard = self.handle_database_exceptions()
            with truncate_guard, self.with_cursor(self.connection) as truncate_cursor:
                truncate_cursor.execute(truncate_sql)
            # The handler suppresses driver errors inside the block; raise the
            # recorded replacement here, outside of it.
            if truncate_guard.pending_exception is not None:
                raise truncate_guard.pending_exception from None

        columns, records = self._arrow_table_to_rows(arrow_table)
        if records:
            insert_sql = build_insert_statement(table, columns)
            insert_guard = self.handle_database_exceptions()
            with insert_guard, self.with_cursor(self.connection) as insert_cursor:
                insert_cursor.executemany(insert_sql, records)
            if insert_guard.pending_exception is not None:
                raise insert_guard.pending_exception from None

        payload = self._build_ingest_telemetry(arrow_table)
        payload["destination"] = table
        self._attach_partition_telemetry(payload, partitioner)
        return self._create_storage_job(payload, telemetry)

    def load_from_storage(
        self,
        table: str,
        source: "StorageDestination",
        *,
        file_format: "StorageFormat",
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
    ) -> "StorageBridgeJob":
        """Read Arrow data from storage and load it through ``load_from_arrow``.

        Args:
            table: Target table name.
            source: Storage location to read from.
            file_format: Format of the stored data.
            partitioner: Optional partition info attached to the telemetry.
            overwrite: Forwarded to ``load_from_arrow``.

        Returns:
            The storage job produced by ``load_from_arrow``; the telemetry
            returned by the storage read is passed along as ``telemetry``.
        """
        arrow_table, inbound_telemetry = self._read_arrow_from_storage_sync(source, file_format=file_format)
        return self.load_from_arrow(
            table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound_telemetry
        )

    @property
    def data_dictionary(self) -> "PyMysqlDataDictionary":
        """Data dictionary for this driver, created on first access and reused."""
        cached = self._data_dictionary
        if cached is None:
            cached = self._data_dictionary = PyMysqlDataDictionary()
        return cached

    def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Collect PyMySQL rows for the direct execution path.

        Args:
            cursor: Cursor whose ``description`` describes the fetched columns.
            fetched: Rows already fetched from the cursor.

        Returns:
            A ``(rows, column_names, row_count)`` tuple.
        """
        description = cursor.description or None
        names = resolve_column_names(description)
        json_columns = detect_json_columns_from_description(description, PYMYSQL_JSON_TYPE_CODES)
        decode_json = cast("Callable[[Any], Any]", self.driver_features.get("json_deserializer", from_json))
        # Inside a method body this name resolves to the module-level
        # collect_rows helper, not to this method.
        rows, names, _unused_format = collect_rows(
            fetched, description, json_columns, decode_json, column_names=names, logger=logger
        )
        return rows, names, len(rows)

    def resolve_rowcount(self, cursor: Any) -> int:
        """Resolve rowcount from PyMySQL cursor for the direct execution path."""
        # Delegates to the module-level helper that shares this method's name.
        return resolve_rowcount(cursor)

    def _connection_in_transaction(self) -> bool:
        """Report whether the connection looks like it is inside a transaction.

        The answer is derived from autocommit: ``True`` when autocommit is
        reported as off, ``False`` when it is on or cannot be determined.
        """
        autocommit_getter = getattr(self.connection, "get_autocommit", None)
        if callable(autocommit_getter):
            return not bool(autocommit_getter())

        autocommit_value = getattr(self.connection, "autocommit", None)
        if autocommit_value is None:
            return False
        try:
            return not bool(autocommit_value)
        except Exception:
            # Best effort: an object that cannot be truth-tested counts as
            # "not in a transaction".
            return False
# Register this adapter's driver profile under the "pymysql" key when the module is imported.
register_driver_profile("pymysql", driver_profile)