Source code for sqlspec.adapters.mssql_python.driver

"""mssql-python sync and async drivers."""

import asyncio
from typing import TYPE_CHECKING, Any, cast

from sqlspec.adapters.mssql_python._typing import (
    MSSQL_PYTHON_MODULE,
    MssqlPythonAsyncCursor,
    MssqlPythonAsyncSessionContext,
    MssqlPythonConnection,
    MssqlPythonCursor,
    MssqlPythonRawCursor,
    MssqlPythonSessionContext,
)
from sqlspec.adapters.mssql_python.core import create_mapped_exception, default_statement_config, driver_profile
from sqlspec.adapters.mssql_python.data_dictionary import MssqlPythonAsyncDataDictionary, MssqlPythonSyncDataDictionary
from sqlspec.core import build_arrow_result_from_table, get_cache_config, register_driver_profile
from sqlspec.driver import (
    AsyncDriverAdapterBase,
    BaseAsyncExceptionHandler,
    BaseSyncExceptionHandler,
    SyncDriverAdapterBase,
)
from sqlspec.exceptions import SQLSpecError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.module_loader import ensure_pyarrow

if TYPE_CHECKING:
    from collections.abc import Iterable

    from sqlspec.builder import QueryBuilder
    from sqlspec.core import SQL, ArrowResult, Statement, StatementConfig, StatementFilter
    from sqlspec.driver import ExecutionResult
    from sqlspec.typing import ArrowReturnFormat, StatementParameters

# Public names of this module: the two drivers, their exception handlers, and the
# cursor/session-context types imported from ``._typing``.
__all__ = (
    "MssqlPythonAsyncCursor",
    "MssqlPythonAsyncDriver",
    "MssqlPythonAsyncExceptionHandler",
    "MssqlPythonAsyncSessionContext",
    "MssqlPythonCursor",
    "MssqlPythonDriver",
    "MssqlPythonExceptionHandler",
    "MssqlPythonSessionContext",
)

# Module logger; the exception handlers below hand it to ``create_mapped_exception``.
logger = get_logger("sqlspec.adapters.mssql_python")
# Exception class used to recognise driver errors (exception handlers, commit/rollback).
# Falls back to ``Exception`` when the loaded mssql-python module has no ``Error``
# attribute, in which case every ``Exception`` is treated as a driver error.
_MSSQL_ERROR = cast("type[BaseException]", getattr(MSSQL_PYTHON_MODULE, "Error", Exception))


class MssqlPythonExceptionHandler(BaseSyncExceptionHandler):
    """Synchronous exception handler that maps mssql-python driver errors.

    Exceptions that are instances of the driver's ``Error`` class are converted
    with ``create_mapped_exception`` and stored on ``pending_exception``.
    Anything else is not recorded.
    """

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Capture a mapped exception when ``exc_val`` is a driver error.

        Returns:
            ``True`` if a driver error was captured into ``pending_exception``;
            ``False`` if no exception occurred or it was not a driver error.
        """
        # ``and`` short-circuits, so isinstance() is never evaluated without an exception.
        is_driver_error = exc_type is not None and isinstance(exc_val, _MSSQL_ERROR)
        if not is_driver_error:
            return False
        mapped_error = create_mapped_exception(cast("Exception", exc_val), logger=logger)
        self.pending_exception = mapped_error
        return True


class MssqlPythonAsyncExceptionHandler(BaseAsyncExceptionHandler):
    """Asynchronous exception handler that maps mssql-python driver errors.

    Driver errors are translated through ``create_mapped_exception`` and kept
    on ``pending_exception``. Other exceptions are not recorded.
    """

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Return whether ``exc_val`` was a driver error that has been captured."""
        if exc_type is not None and isinstance(exc_val, _MSSQL_ERROR):
            self.pending_exception = create_mapped_exception(cast("Exception", exc_val), logger=logger)
            captured = True
        else:
            captured = False
        return captured


class MssqlPythonDriver(SyncDriverAdapterBase):
    """mssql-python sync driver.

    Executes statements directly on the mssql-python DB-API connection and
    cursor, and reports results as tuple rows.
    """

    __slots__ = ("_data_dictionary",)
    dialect = "tsql"

    def __init__(
        self,
        connection: "MssqlPythonConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Create the driver.

        When ``statement_config`` is omitted, the adapter default is used with
        ``enable_caching`` taken from the global compiled-cache setting.
        """
        resolved_config = statement_config
        if resolved_config is None:
            cache_enabled = get_cache_config().compiled_cache_enabled
            resolved_config = default_statement_config.replace(enable_caching=cache_enabled)
        super().__init__(connection=connection, statement_config=resolved_config, driver_features=driver_features)
        self._data_dictionary: MssqlPythonSyncDataDictionary | None = None

    @property
    def data_dictionary(self) -> "MssqlPythonSyncDataDictionary":
        """Return the sync data dictionary, creating it on first access."""
        cached = self._data_dictionary
        if cached is None:
            cached = MssqlPythonSyncDataDictionary()
            self._data_dictionary = cached
        return cached

    def dispatch_execute(self, cursor: "MssqlPythonRawCursor", statement: "SQL") -> "ExecutionResult":
        """Execute one compiled statement.

        Row-returning statements are fully fetched and reported as tuple rows,
        with column names from ``cursor.description``. Other statements report
        the cursor rowcount, which ``_cursor_rowcount`` clamps to 0 for
        non-positive values.
        """
        compiled_sql, params = self._get_compiled_sql(statement, self.statement_config)
        _execute_cursor(cursor, compiled_sql, params)
        if not statement.returns_rows():
            return self.create_execution_result(cursor, rowcount_override=_cursor_rowcount(cursor))
        rows = cursor.fetchall()
        description = cursor.description or []
        return self.create_execution_result(
            cursor,
            selected_data=rows,
            column_names=[column[0] for column in description],
            data_row_count=len(rows),
            is_select_result=True,
            row_format="tuple",
        )

    def dispatch_execute_many(self, cursor: "MssqlPythonRawCursor", statement: "SQL") -> "ExecutionResult":
        """Execute a compiled statement once per parameter set via ``executemany``."""
        compiled_sql, params = self._get_compiled_sql(statement, self.statement_config)
        cursor.executemany(compiled_sql, cast("Any", params))
        affected = _cursor_rowcount(cursor)
        return self.create_execution_result(cursor, rowcount_override=affected, is_many_result=True)

    def dispatch_execute_script(self, cursor: "MssqlPythonRawCursor", statement: "SQL") -> "ExecutionResult":
        """Split a compiled script and execute each piece in order.

        Every piece is executed with the same compiled parameters.
        """
        compiled_sql, params = self._get_compiled_sql(statement, self.statement_config)
        script_parts = self.split_script_statements(
            compiled_sql, statement.statement_config, strip_trailing_semicolon=True
        )
        executed = 0
        for part in script_parts:
            _execute_cursor(cursor, part, params)
            executed += 1
        return self.create_execution_result(
            cursor, statement_count=len(script_parts), successful_statements=executed, is_script_result=True
        )

    def collect_rows(self, cursor: "MssqlPythonRawCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Pair already-fetched rows with column names and the row count."""
        names = [entry[0] for entry in (cursor.description or [])]
        return fetched, names, len(fetched)

    def resolve_rowcount(self, cursor: "MssqlPythonRawCursor") -> int:
        """Return the cursor's rowcount, clamped to 0 when non-positive."""
        return _cursor_rowcount(cursor)

    def begin(self) -> None:
        """No explicit BEGIN is issued; this returns immediately.

        NOTE(review): presumably relies on the connection's implicit transaction
        handling; confirm against mssql-python autocommit defaults.
        """
        return None

    def commit(self) -> None:
        """Commit the connection, wrapping driver errors in ``SQLSpecError``."""
        try:
            self.connection.commit()
        except _MSSQL_ERROR as error:
            msg = f"Failed to commit transaction: {error}"
            raise SQLSpecError(msg) from error

    def rollback(self) -> None:
        """Roll back the connection, wrapping driver errors in ``SQLSpecError``."""
        try:
            self.connection.rollback()
        except _MSSQL_ERROR as error:
            msg = f"Failed to rollback transaction: {error}"
            raise SQLSpecError(msg) from error

    def with_cursor(self, connection: "MssqlPythonConnection") -> "MssqlPythonCursor":
        """Return a cursor context manager bound to ``connection``."""
        return MssqlPythonCursor(connection)

    def handle_database_exceptions(self) -> "MssqlPythonExceptionHandler":
        """Return a fresh sync exception handler."""
        return MssqlPythonExceptionHandler()

    def create_savepoint(self, name: str) -> None:
        """Create a T-SQL savepoint named ``name``."""
        # NOTE(review): ``name`` is interpolated directly into the SQL text; callers
        # must supply a trusted T-SQL identifier.
        self.execute_script(f"SAVE TRANSACTION {name}")

    def release_savepoint(self, name: str) -> None:
        """No-op, presumably because T-SQL has no RELEASE SAVEPOINT statement."""
        return None

    def rollback_to_savepoint(self, name: str) -> None:
        """Roll back to the savepoint named ``name`` (interpolated into the SQL text)."""
        self.execute_script(f"ROLLBACK TRANSACTION {name}")

    def select_to_arrow(
        self,
        statement: "Statement | QueryBuilder",
        /,
        *parameters: "StatementParameters | StatementFilter",
        statement_config: "StatementConfig | None" = None,
        return_format: "ArrowReturnFormat" = "table",
        native_only: bool = False,
        batch_size: int | None = None,
        arrow_schema: Any = None,
        **kwargs: Any,
    ) -> "ArrowResult":
        """Execute a query and return native mssql-python Arrow results.

        ``batch_size`` is forwarded to ``cursor.arrow()`` only when given.
        ``native_only`` is accepted for interface compatibility and not consulted.

        Raises:
            SQLSpecError: If the cursor returns no Arrow table.
        """
        ensure_pyarrow()
        active_config = statement_config or self.statement_config
        prepared = self.prepare_statement(statement, parameters, statement_config=active_config, kwargs=kwargs)
        compiled_sql, params = self._get_compiled_sql(prepared, active_config)
        arrow_kwargs: dict[str, Any] = {}
        if batch_size is not None:
            arrow_kwargs["batch_size"] = batch_size
        arrow_table: Any | None = None
        handler = self.handle_database_exceptions()
        with handler, self.with_cursor(self.connection) as cursor:
            _execute_cursor(cursor, compiled_sql, params)
            arrow_table = cursor.arrow(**arrow_kwargs)
        # Re-raise any driver error the handler captured inside the block.
        self._check_pending_exception(handler)
        if arrow_table is None:
            msg = "mssql-python did not return an Arrow table."
            raise SQLSpecError(msg)
        return build_arrow_result_from_table(
            prepared, arrow_table, return_format=return_format, batch_size=batch_size, arrow_schema=arrow_schema
        )

    def bulk_copy(
        self,
        target_table: str,
        rows: "Iterable[tuple[Any, ...]]",
        *,
        batch_size: int = 64_000,
        timeout: int = 3600,
        column_mappings: list[str] | list[tuple[int, str]] | None = None,
        keep_identity: bool = False,
        check_constraints: bool = True,
        table_lock: bool = False,
        keep_nulls: bool = False,
        fire_triggers: bool = False,
        use_internal_transaction: bool = False,
    ) -> int:
        """Bulk insert rows via mssql-python cursor.bulkcopy().

        All options are forwarded unchanged to ``cursor.bulkcopy``.

        Returns:
            The cursor rowcount after the copy, clamped to 0 when non-positive.
        """
        copied = 0
        handler = self.handle_database_exceptions()
        with handler, self.with_cursor(self.connection) as cursor:
            cursor.bulkcopy(
                target_table,
                rows,
                batch_size=batch_size,
                timeout=timeout,
                column_mappings=column_mappings,
                keep_identity=keep_identity,
                check_constraints=check_constraints,
                table_lock=table_lock,
                keep_nulls=keep_nulls,
                fire_triggers=fire_triggers,
                use_internal_transaction=use_internal_transaction,
            )
            copied = _cursor_rowcount(cursor)
        self._check_pending_exception(handler)
        return int(copied)
class MssqlPythonAsyncDriver(AsyncDriverAdapterBase):
    """Async wrapper around mssql-python's sync DB-API via asyncio.to_thread.

    These blocking driver calls are offloaded with ``asyncio.to_thread``:
    statement execution, row fetching, commit/rollback, Arrow export and
    bulk copy.
    """

    __slots__ = ("_data_dictionary",)
    dialect = "tsql"

    def __init__(
        self,
        connection: "MssqlPythonConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Create the async driver.

        When ``statement_config`` is omitted, the adapter default is used with
        ``enable_caching`` taken from the global compiled-cache setting.
        """
        resolved_config = statement_config
        if resolved_config is None:
            cache_enabled = get_cache_config().compiled_cache_enabled
            resolved_config = default_statement_config.replace(enable_caching=cache_enabled)
        super().__init__(connection=connection, statement_config=resolved_config, driver_features=driver_features)
        self._data_dictionary: MssqlPythonAsyncDataDictionary | None = None

    @property
    def data_dictionary(self) -> "MssqlPythonAsyncDataDictionary":
        """Return the async data dictionary, creating it on first access."""
        cached = self._data_dictionary
        if cached is None:
            cached = MssqlPythonAsyncDataDictionary()
            self._data_dictionary = cached
        return cached

    async def dispatch_execute(self, cursor: "MssqlPythonRawCursor", statement: "SQL") -> "ExecutionResult":
        """Execute one compiled statement in a worker thread.

        Row-returning statements are fully fetched (also off-loop) and reported
        as tuple rows. Other statements report the clamped cursor rowcount.
        """
        compiled_sql, params = self._get_compiled_sql(statement, self.statement_config)
        await asyncio.to_thread(_execute_cursor, cursor, compiled_sql, params)
        if not statement.returns_rows():
            return self.create_execution_result(cursor, rowcount_override=_cursor_rowcount(cursor))
        rows = await asyncio.to_thread(cursor.fetchall)
        description = cursor.description or []
        return self.create_execution_result(
            cursor,
            selected_data=rows,
            column_names=[column[0] for column in description],
            data_row_count=len(rows),
            is_select_result=True,
            row_format="tuple",
        )

    async def dispatch_execute_many(self, cursor: "MssqlPythonRawCursor", statement: "SQL") -> "ExecutionResult":
        """Run ``executemany`` for the compiled statement in a worker thread."""
        compiled_sql, params = self._get_compiled_sql(statement, self.statement_config)
        await asyncio.to_thread(cursor.executemany, compiled_sql, cast("Any", params))
        affected = _cursor_rowcount(cursor)
        return self.create_execution_result(cursor, rowcount_override=affected, is_many_result=True)

    async def dispatch_execute_script(self, cursor: "MssqlPythonRawCursor", statement: "SQL") -> "ExecutionResult":
        """Split a compiled script and execute each piece sequentially off-loop.

        Every piece is executed with the same compiled parameters.
        """
        compiled_sql, params = self._get_compiled_sql(statement, self.statement_config)
        script_parts = self.split_script_statements(
            compiled_sql, statement.statement_config, strip_trailing_semicolon=True
        )
        executed = 0
        for part in script_parts:
            await asyncio.to_thread(_execute_cursor, cursor, part, params)
            executed += 1
        return self.create_execution_result(
            cursor, statement_count=len(script_parts), successful_statements=executed, is_script_result=True
        )

    def collect_rows(self, cursor: "MssqlPythonRawCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Pair already-fetched rows with column names and the row count."""
        names = [entry[0] for entry in (cursor.description or [])]
        return fetched, names, len(fetched)

    def resolve_rowcount(self, cursor: "MssqlPythonRawCursor") -> int:
        """Return the cursor's rowcount, clamped to 0 when non-positive."""
        return _cursor_rowcount(cursor)

    async def begin(self) -> None:
        """No explicit BEGIN is issued; this returns immediately.

        NOTE(review): presumably relies on the connection's implicit transaction
        handling; confirm against mssql-python autocommit defaults.
        """
        return None

    async def commit(self) -> None:
        """Commit in a worker thread, wrapping driver errors in ``SQLSpecError``."""
        try:
            await asyncio.to_thread(self.connection.commit)
        except _MSSQL_ERROR as error:
            msg = f"Failed to commit transaction: {error}"
            raise SQLSpecError(msg) from error

    async def rollback(self) -> None:
        """Roll back in a worker thread, wrapping driver errors in ``SQLSpecError``."""
        try:
            await asyncio.to_thread(self.connection.rollback)
        except _MSSQL_ERROR as error:
            msg = f"Failed to rollback transaction: {error}"
            raise SQLSpecError(msg) from error

    def with_cursor(self, connection: "MssqlPythonConnection") -> "MssqlPythonAsyncCursor":
        """Return an async cursor context manager bound to ``connection``."""
        return MssqlPythonAsyncCursor(connection)

    def handle_database_exceptions(self) -> "MssqlPythonAsyncExceptionHandler":
        """Return a fresh async exception handler."""
        return MssqlPythonAsyncExceptionHandler()

    async def create_savepoint(self, name: str) -> None:
        """Create a T-SQL savepoint named ``name``."""
        # NOTE(review): ``name`` is interpolated directly into the SQL text; callers
        # must supply a trusted T-SQL identifier.
        await self.execute_script(f"SAVE TRANSACTION {name}")

    async def release_savepoint(self, name: str) -> None:
        """No-op, presumably because T-SQL has no RELEASE SAVEPOINT statement."""
        return None

    async def rollback_to_savepoint(self, name: str) -> None:
        """Roll back to the savepoint named ``name`` (interpolated into the SQL text)."""
        await self.execute_script(f"ROLLBACK TRANSACTION {name}")

    async def select_to_arrow(
        self,
        statement: "Statement | QueryBuilder",
        /,
        *parameters: "StatementParameters | StatementFilter",
        statement_config: "StatementConfig | None" = None,
        return_format: "ArrowReturnFormat" = "table",
        native_only: bool = False,
        batch_size: int | None = None,
        arrow_schema: Any = None,
        **kwargs: Any,
    ) -> "ArrowResult":
        """Execute a query and return native mssql-python Arrow results.

        Execution and ``cursor.arrow()`` both run in worker threads.
        ``batch_size`` is forwarded only when given. ``native_only`` is accepted
        for interface compatibility and not consulted.

        Raises:
            SQLSpecError: If the cursor returns no Arrow table.
        """
        ensure_pyarrow()
        active_config = statement_config or self.statement_config
        prepared = self.prepare_statement(statement, parameters, statement_config=active_config, kwargs=kwargs)
        compiled_sql, params = self._get_compiled_sql(prepared, active_config)
        arrow_kwargs: dict[str, Any] = {}
        if batch_size is not None:
            arrow_kwargs["batch_size"] = batch_size
        arrow_table: Any | None = None
        handler = self.handle_database_exceptions()
        async with handler, self.with_cursor(self.connection) as cursor:
            await asyncio.to_thread(_execute_cursor, cursor, compiled_sql, params)
            arrow_table = await asyncio.to_thread(cursor.arrow, **arrow_kwargs)
        # Re-raise any driver error the handler captured inside the block.
        self._check_pending_exception(handler)
        if arrow_table is None:
            msg = "mssql-python did not return an Arrow table."
            raise SQLSpecError(msg)
        return build_arrow_result_from_table(
            prepared, arrow_table, return_format=return_format, batch_size=batch_size, arrow_schema=arrow_schema
        )

    async def bulk_copy(
        self,
        target_table: str,
        rows: "Iterable[tuple[Any, ...]]",
        *,
        batch_size: int = 64_000,
        timeout: int = 3600,
        column_mappings: list[str] | list[tuple[int, str]] | None = None,
        keep_identity: bool = False,
        check_constraints: bool = True,
        table_lock: bool = False,
        keep_nulls: bool = False,
        fire_triggers: bool = False,
        use_internal_transaction: bool = False,
    ) -> int:
        """Bulk insert rows via mssql-python cursor.bulkcopy().

        The copy runs in a worker thread, with all options forwarded unchanged.

        Returns:
            The cursor rowcount after the copy, clamped to 0 when non-positive.
        """
        handler = self.handle_database_exceptions()
        copied = 0
        async with handler, self.with_cursor(self.connection) as cursor:
            await asyncio.to_thread(
                cursor.bulkcopy,
                target_table,
                rows,
                batch_size=batch_size,
                timeout=timeout,
                column_mappings=column_mappings,
                keep_identity=keep_identity,
                check_constraints=check_constraints,
                table_lock=table_lock,
                keep_nulls=keep_nulls,
                fire_triggers=fire_triggers,
                use_internal_transaction=use_internal_transaction,
            )
            copied = _cursor_rowcount(cursor)
        self._check_pending_exception(handler)
        return copied
def _execute_cursor(cursor: "MssqlPythonRawCursor", sql: str, parameters: Any) -> None:
    """Execute ``sql`` on ``cursor``, passing ``parameters`` only when it is not ``None``."""
    call_args: tuple[Any, ...] = (sql,) if parameters is None else (sql, parameters)
    cursor.execute(*call_args)


def _cursor_rowcount(cursor: "MssqlPythonRawCursor") -> int:
    """Return the cursor's ``rowcount`` when it is a positive int, otherwise 0.

    A missing attribute, a non-int value, or a non-positive count (such as the
    DB-API ``-1`` "unknown" sentinel) all yield 0.
    """
    reported = getattr(cursor, "rowcount", 0)
    if isinstance(reported, int) and reported > 0:
        return reported
    return 0


# Register this adapter's driver profile at import time, replacing any existing entry.
register_driver_profile("mssql_python", driver_profile, allow_override=True)