"""pymssql SQL Server driver implementation."""
from collections.abc import Sized
from typing import TYPE_CHECKING, Any, cast
from sqlspec.adapters.pymssql._typing import (
PYMSSQL_MODULE,
PymssqlConnection,
PymssqlCursor,
PymssqlRawCursor,
PymssqlSessionContext,
)
from sqlspec.adapters.pymssql.core import (
collect_rows,
create_mapped_exception,
default_statement_config,
driver_profile,
normalize_execute_many_parameters,
normalize_execute_parameters,
resolve_column_names,
resolve_many_rowcount,
resolve_rowcount,
)
from sqlspec.adapters.pymssql.data_dictionary import PymssqlSyncDataDictionary
from sqlspec.core import SQL, StatementConfig, get_cache_config, register_driver_profile
from sqlspec.driver import BaseSyncExceptionHandler, ExecutionResult, SyncDriverAdapterBase
from sqlspec.exceptions import SQLSpecError
from sqlspec.utils.logging import get_logger
if TYPE_CHECKING:
from collections.abc import Sequence
from pymssql._pymssql import QueryParams
__all__ = ("PymssqlCursor", "PymssqlDriver", "PymssqlExceptionHandler", "PymssqlSessionContext")
logger = get_logger("sqlspec.adapters.pymssql")
pymssql = PYMSSQL_MODULE
class _UnavailablePymssqlError(Exception):
"""Fallback pymssql exception base when pymssql is unavailable."""
class PymssqlExceptionHandler(BaseSyncExceptionHandler):
"""Context manager for handling pymssql exceptions."""
__slots__ = ()
def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
if exc_type is None:
return False
error_type = _pymssql_error_type()
if isinstance(exc_val, error_type):
self.pending_exception = create_mapped_exception(cast("Exception", exc_val), logger=logger)
return True
return False
[docs]
class PymssqlDriver(SyncDriverAdapterBase):
"""SQL Server database driver using pymssql."""
__slots__ = ("_data_dictionary",)
dialect = "tsql"
[docs]
def __init__(
self,
connection: "PymssqlConnection",
statement_config: "StatementConfig | None" = None,
driver_features: "dict[str, Any] | None" = None,
) -> None:
if statement_config is None:
statement_config = default_statement_config.replace(
enable_caching=get_cache_config().compiled_cache_enabled
)
super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
self._data_dictionary: PymssqlSyncDataDictionary | None = None
[docs]
def dispatch_execute(self, cursor: "PymssqlRawCursor", statement: "SQL") -> "ExecutionResult":
sql, prepared_parameters = self._compiled_sql(statement, self.statement_config)
cursor.execute(sql, normalize_execute_parameters(prepared_parameters))
if statement.returns_rows():
fetched_data = cursor.fetchall()
description = cursor.description or None
rows, column_names, row_format = collect_rows(fetched_data, description)
return self.create_execution_result(
cursor,
selected_data=rows,
column_names=column_names,
data_row_count=len(rows),
is_select_result=True,
row_format=row_format,
)
return self.create_execution_result(cursor, rowcount_override=resolve_rowcount(cursor))
[docs]
def dispatch_execute_many(self, cursor: "PymssqlRawCursor", statement: "SQL") -> "ExecutionResult":
sql, prepared_parameters = self._compiled_sql(statement, self.statement_config)
prepared_parameters = normalize_execute_many_parameters(prepared_parameters)
parameter_count = len(prepared_parameters) if isinstance(prepared_parameters, Sized) else None
cursor.executemany(sql, cast("Sequence[QueryParams]", prepared_parameters))
affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count)
return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True)
[docs]
def dispatch_execute_script(self, cursor: "PymssqlRawCursor", statement: "SQL") -> "ExecutionResult":
sql, prepared_parameters = self._compiled_sql(statement, self.statement_config)
statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
if prepared_parameters and len(statements) > 1:
msg = "execute_script with parameters is not supported for multi-statement scripts; use execute or execute_many for parameterized statements"
raise SQLSpecError(msg)
successful_count = 0
for stmt in statements:
cursor.execute(stmt, normalize_execute_parameters(prepared_parameters))
successful_count += 1
return self.create_execution_result(
cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True
)
[docs]
def begin(self) -> None:
try:
with PymssqlCursor(self.connection) as cursor:
cursor.execute("BEGIN TRANSACTION")
except _pymssql_error_type() as exc:
msg = f"Failed to begin SQL Server transaction: {exc}"
raise SQLSpecError(msg) from exc
[docs]
def commit(self) -> None:
try:
self.connection.commit()
except _pymssql_error_type() as exc:
msg = f"Failed to commit SQL Server transaction: {exc}"
raise SQLSpecError(msg) from exc
[docs]
def rollback(self) -> None:
try:
self.connection.rollback()
except _pymssql_error_type() as exc:
msg = f"Failed to rollback SQL Server transaction: {exc}"
raise SQLSpecError(msg) from exc
[docs]
def with_cursor(self, connection: "PymssqlConnection") -> "PymssqlCursor":
return PymssqlCursor(connection)
[docs]
def handle_database_exceptions(self) -> "PymssqlExceptionHandler":
return PymssqlExceptionHandler()
def create_savepoint(self, name: str) -> None:
self.execute_script(f"SAVE TRANSACTION {name}")
def release_savepoint(self, name: str) -> None:
return None
def rollback_to_savepoint(self, name: str) -> None:
self.execute_script(f"ROLLBACK TRANSACTION {name}")
@property
def data_dictionary(self) -> "PymssqlSyncDataDictionary":
if self._data_dictionary is None:
self._data_dictionary = PymssqlSyncDataDictionary()
return self._data_dictionary
[docs]
def collect_rows(self, cursor: "PymssqlRawCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
column_names = resolve_column_names(cursor.description or None)
return fetched, column_names, len(fetched)
[docs]
def resolve_rowcount(self, cursor: "PymssqlRawCursor") -> int:
return resolve_rowcount(cursor)
def _connection_in_transaction(self) -> bool:
return False
def _pymssql_error_type() -> "type[BaseException]":
return cast("type[BaseException]", getattr(pymssql, "Error", _UnavailablePymssqlError))
register_driver_profile("pymssql", driver_profile)