Source code for sqlspec.adapters.spanner.driver

"""Spanner driver implementation."""

from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Protocol, cast

from google.api_core import exceptions as api_exceptions
from google.cloud.spanner_v1.transaction import Transaction
from typing_extensions import Self

from sqlspec.adapters.spanner._typing import SpannerSessionContext
from sqlspec.adapters.spanner.core import (
    build_param_type_signature,
    coerce_params,
    collect_rows,
    create_arrow_data,
    create_mapped_exception,
    default_statement_config,
    driver_profile,
    infer_param_types,
    resolve_column_names,
    supports_batch_update,
    supports_write,
)
from sqlspec.adapters.spanner.data_dictionary import SpannerDataDictionary
from sqlspec.adapters.spanner.type_converter import SpannerOutputConverter
from sqlspec.core import StatementConfig, create_arrow_result, register_driver_profile
from sqlspec.driver import ExecutionResult, SyncDriverAdapterBase
from sqlspec.exceptions import SQLConversionError
from sqlspec.utils.serializers import from_json

if TYPE_CHECKING:
    from collections.abc import Callable

    from sqlglot.dialects.dialect import DialectType

    from sqlspec.adapters.spanner._typing import SpannerConnection
    from sqlspec.core import ArrowResult
    from sqlspec.core.statement import SQL
    from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry
    from sqlspec.typing import ArrowReturnFormat

# Public names of this module. ``SpannerDataDictionary`` and
# ``SpannerSessionContext`` are imported from sibling modules above and
# re-exported here alongside the classes defined below.
__all__ = (
    "SpannerDataDictionary",
    "SpannerExceptionHandler",
    "SpannerSessionContext",
    "SpannerSyncCursor",
    "SpannerSyncDriver",
)


class _SpannerResultSetProtocol(Protocol):
    """Structural type for the object returned by ``execute_sql``.

    Declares only what this module touches: a ``metadata`` attribute and
    iteration over the result rows.
    """

    # Result-set metadata; the driver reads ``metadata.row_type.fields`` from it.
    metadata: Any

    def __iter__(self) -> Iterator[Any]:
        """Iterate over the rows of the result set."""
        ...


class _SpannerReadProtocol(Protocol):
    """Structural type for a connection that can run read queries."""

    def execute_sql(
        self,
        sql: str,
        params: "dict[str, Any] | None" = None,
        param_types: "dict[str, Any] | None" = None,
    ) -> _SpannerResultSetProtocol:
        """Run ``sql`` with optional bound ``params`` / ``param_types`` and return its result set."""
        ...


class _SpannerWriteProtocol(_SpannerReadProtocol, Protocol):
    """Structural type for a connection that can also write and be committed.

    Extends the read protocol with the DML and transaction-control members
    used by this module.
    """

    # ``None`` until the transaction has been committed; the driver's commit()
    # checks this to avoid committing twice.
    committed: "Any | None"

    def execute_update(
        self,
        sql: str,
        params: "dict[str, Any] | None" = None,
        param_types: "dict[str, Any] | None" = None,
    ) -> int:
        """Run a single DML statement and return its row count."""
        ...

    def batch_update(
        self,
        batch: "list[tuple[str, dict[str, Any] | None, dict[str, Any]]]",
    ) -> "tuple[Any, list[int]]":
        """Run ``(sql, params, param_types)`` triples; return ``(status, row_counts)``."""
        ...

    def commit(self) -> None:
        """Commit the transaction."""
        ...

    def rollback(self) -> None:
        """Roll the transaction back."""
        ...


class SpannerExceptionHandler:
    """Map Spanner client exceptions to SQLSpec exceptions.

    Uses deferred exception pattern for mypyc compatibility: exceptions
    are stored in pending_exception rather than raised from __exit__
    to avoid ABI boundary violations with compiled code.
    """

    __slots__ = ("pending_exception",)

    def __init__(self) -> None:
        self.pending_exception: Exception | None = None

    def __enter__(self) -> Self:
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        _ = exc_tb
        if exc_type is None:
            return False

        if isinstance(exc_val, api_exceptions.GoogleAPICallError):
            self.pending_exception = create_mapped_exception(exc_val)
            return True
        return False


class SpannerSyncCursor:
    """Cursor-shaped context manager that simply hands back the connection.

    Entering yields the connection it was built with; exiting does nothing
    and never suppresses exceptions.
    """

    __slots__ = ("connection",)

    def __init__(self, connection: "SpannerConnection") -> None:
        self.connection = connection

    def __enter__(self) -> "SpannerConnection":
        return self.connection

    def __exit__(self, *exc_info: Any) -> None:
        # Nothing to release: this wrapper does not close the connection.
        return


class SpannerSyncDriver(SyncDriverAdapterBase):
    """Synchronous Spanner driver operating on Snapshot or Transaction contexts.

    Reads go through ``execute_sql`` on whatever connection object is supplied.
    Writes (``execute_update`` / ``batch_update``) are only attempted when the
    connection supports them; otherwise ``SQLConversionError`` is raised.
    """

    dialect: "DialectType" = "spanner"
    __slots__ = ("_column_name_cache", "_data_dictionary", "_type_converter")

    def __init__(
        self,
        connection: "SpannerConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Initialise the driver.

        Args:
            connection: Spanner connection object the driver executes against.
            statement_config: Statement configuration; ``default_statement_config``
                is used when omitted.
            driver_features: Optional feature flags. ``json_deserializer`` and
                ``enable_uuid_conversion`` (default ``True``) configure the
                output converter; ``json_serializer`` is read when coercing
                parameters.
        """
        # Copy so the caller's mapping is never shared with the driver.
        features = dict(driver_features) if driver_features else {}
        if statement_config is None:
            statement_config = default_statement_config

        super().__init__(connection=connection, statement_config=statement_config, driver_features=features)

        json_deserializer = features.get("json_deserializer")
        self._type_converter = SpannerOutputConverter(
            enable_uuid_conversion=features.get("enable_uuid_conversion", True),
            json_deserializer=cast("Callable[[str], Any]", json_deserializer or from_json),
        )
        self._column_name_cache: dict[int, tuple[Any, list[str]]] = {}
        self._data_dictionary: SpannerDataDictionary | None = None

    # ─────────────────────────────────────────────────────────────────────────────
    # CORE DISPATCH METHODS - The Execution Engine
    # ─────────────────────────────────────────────────────────────────────────────

    def dispatch_execute(self, cursor: "SpannerConnection", statement: "SQL") -> ExecutionResult:
        """Execute a single statement.

        Row-returning statements are run with ``execute_sql`` and converted to
        tuple rows; everything else is run with ``execute_update``.

        Raises:
            SQLConversionError: If a row-returning statement yields no field
                metadata, or if DML is attempted on a connection that does not
                support writes.
        """
        sql, params = self._get_compiled_sql(statement, self.statement_config)
        params = cast("dict[str, Any] | None", params)
        coerced_params = self._coerce_params(params)
        param_types_map = self._infer_param_types(coerced_params)

        if statement.returns_rows():
            reader = cast("_SpannerReadProtocol", cursor)
            result_set = reader.execute_sql(sql, params=coerced_params, param_types=param_types_map)
            # Rows are consumed before metadata is read.
            # NOTE(review): presumably the streamed result set only exposes
            # metadata once iteration has started - confirm.
            rows = list(result_set)
            try:
                metadata = result_set.metadata
                row_type = metadata.row_type
                fields = row_type.fields
            except AttributeError:
                fields = None
            if not fields:
                msg = "Result set metadata not available."
                raise SQLConversionError(msg)

            column_names = self._resolve_column_names(fields)
            # Module-level helper, not the ``collect_rows`` method defined below.
            data, column_names = collect_rows(rows, fields, self._type_converter, column_names=column_names)
            return self.create_execution_result(
                cursor,
                selected_data=data,
                column_names=column_names,
                data_row_count=len(data),
                is_select_result=True,
                row_format="tuple",
            )

        if supports_write(cursor):
            writer = cast("_SpannerWriteProtocol", cursor)
            row_count = writer.execute_update(sql, params=coerced_params, param_types=param_types_map)
            return self.create_execution_result(cursor, rowcount_override=row_count)

        msg = "Cannot execute DML in a read-only Snapshot context."
        raise SQLConversionError(msg)

    def dispatch_execute_many(self, cursor: "SpannerConnection", statement: "SQL") -> ExecutionResult:
        """Execute one statement against many parameter sets via ``batch_update``.

        Inferred parameter types are cached per parameter-type signature so
        they are computed once for each distinct shape of parameter set.

        Raises:
            SQLConversionError: If the connection does not support batch
                updates, or if no (list of) parameter sets was supplied.
        """
        if not supports_batch_update(cursor):
            msg = "execute_many requires a Transaction context"
            raise SQLConversionError(msg)

        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        if not prepared_parameters or not isinstance(prepared_parameters, list):
            msg = "execute_many requires at least one parameter set"
            raise SQLConversionError(msg)

        # Bound-method locals keep attribute lookups out of the per-row loop.
        coerce = self._coerce_params
        infer = self._infer_param_types
        param_types_cache: dict[tuple[tuple[str, type[Any]], ...], dict[str, Any]] = {}
        empty_param_types: dict[str, Any] = {}
        batch_args: list[tuple[str, dict[str, Any] | None, dict[str, Any]]] = []
        append_batch_arg = batch_args.append

        for params in prepared_parameters:
            coerced_params = coerce(cast("dict[str, Any] | None", params))
            if not coerced_params:
                append_batch_arg((sql, {}, empty_param_types))
                continue
            signature = build_param_type_signature(coerced_params)
            param_types = param_types_cache.get(signature)
            if param_types is None:
                param_types = infer(coerced_params)
                param_types_cache[signature] = param_types
            append_batch_arg((sql, coerced_params, param_types))

        writer = cast("_SpannerWriteProtocol", cursor)
        # NOTE(review): the returned status is not inspected; confirm whether
        # per-statement failures are reported there rather than raised.
        _status, row_counts = writer.batch_update(batch_args)
        total_rows = sum(row_counts) if row_counts else 0
        return self.create_execution_result(cursor, rowcount_override=total_rows, is_many_result=True)

    def dispatch_execute_script(self, cursor: "SpannerConnection", statement: "SQL") -> ExecutionResult:
        """Execute a multi-statement script, one statement at a time.

        The same compiled parameters are applied to every statement. A
        statement counts as a read only when its text starts with ``SELECT``
        (case-insensitive, after stripping whitespace).

        Raises:
            SQLConversionError: If a non-SELECT statement is encountered on a
                connection that does not support writes.
        """
        sql, params = self._get_compiled_sql(statement, self.statement_config)
        statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
        is_transaction = supports_write(cursor)
        reader = cast("_SpannerReadProtocol", cursor)
        script_params = cast("dict[str, Any] | None", params)

        count = 0
        for stmt in statements:
            # NOTE(review): prefix check only - a query that starts with
            # anything other than SELECT (e.g. WITH ...) takes the DML path.
            is_select = stmt.upper().strip().startswith("SELECT")
            coerced_params = self._coerce_params(script_params)
            if not is_select and not is_transaction:
                msg = "Cannot execute DML in a read-only Snapshot context."
                raise SQLConversionError(msg)

            param_types = self._infer_param_types(coerced_params)
            if is_select:
                # Drain the result set; script execution discards row data.
                _ = list(reader.execute_sql(stmt, params=coerced_params, param_types=param_types))
            else:
                writer = cast("_SpannerWriteProtocol", cursor)
                writer.execute_update(stmt, params=coerced_params, param_types=param_types)
            count += 1

        return self.create_execution_result(
            cursor, statement_count=count, successful_statements=count, is_script_result=True
        )

    # ─────────────────────────────────────────────────────────────────────────────
    # TRANSACTION MANAGEMENT
    # ─────────────────────────────────────────────────────────────────────────────

    def begin(self) -> None:
        """Do nothing; this driver does not start transactions itself."""
        return None

    def commit(self) -> None:
        """Commit when the connection is a ``Transaction`` that is not yet committed."""
        if isinstance(self.connection, Transaction):
            writer = cast("_SpannerWriteProtocol", self.connection)
            if writer.committed is not None:
                # Already committed - committing again is skipped.
                return
            writer.commit()

    def rollback(self) -> None:
        """Roll back when the connection is a ``Transaction``; otherwise do nothing."""
        if isinstance(self.connection, Transaction):
            writer = cast("_SpannerWriteProtocol", self.connection)
            writer.rollback()

    def with_cursor(self, connection: "SpannerConnection") -> "SpannerSyncCursor":
        """Wrap ``connection`` in a ``SpannerSyncCursor`` context manager."""
        return SpannerSyncCursor(connection)

    def handle_database_exceptions(self) -> "SpannerExceptionHandler":
        """Return a fresh ``SpannerExceptionHandler``."""
        return SpannerExceptionHandler()

    # ─────────────────────────────────────────────────────────────────────────────
    # ARROW API METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def select_to_arrow(self, statement: "Any", /, *parameters: "Any", **kwargs: Any) -> "ArrowResult":
        """Execute a query and convert its rows to Arrow.

        Args:
            statement: Statement to execute.
            *parameters: Positional arguments forwarded to ``execute``.
            **kwargs: ``return_format`` (default ``"table"``) selects the Arrow
                output format; remaining keyword arguments are forwarded to
                ``execute``.
        """
        # ``return_format`` is an Arrow output option consumed here, so it is
        # removed before the remaining keyword arguments reach execute().
        # NOTE(review): execute() is expected to treat unknown keyword
        # arguments as statement parameters - confirm.
        return_format = cast("ArrowReturnFormat", kwargs.pop("return_format", "table"))
        result = self.execute(statement, *parameters, **kwargs)
        arrow_data = create_arrow_data(result.get_data(), return_format)
        return create_arrow_result(result.statement, arrow_data, rows_affected=result.rows_affected)

    # ─────────────────────────────────────────────────────────────────────────────
    # STORAGE API METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def select_to_storage(
        self,
        statement: "SQL | str",
        destination: "StorageDestination",
        /,
        *parameters: Any,
        statement_config: "StatementConfig | None" = None,
        partitioner: "dict[str, object] | None" = None,
        format_hint: "StorageFormat | None" = None,
        telemetry: "StorageTelemetry | None" = None,
        **kwargs: Any,
    ) -> "StorageBridgeJob":
        """Execute query and stream Arrow results to storage."""
        self._require_capability("arrow_export_enabled")
        arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
        sync_pipeline = self._storage_pipeline()
        telemetry_payload = self._write_result_to_storage_sync(
            arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline
        )
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    def load_from_arrow(
        self,
        table: str,
        source: "ArrowResult | Any",
        *,
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
        telemetry: "StorageTelemetry | None" = None,
    ) -> "StorageBridgeJob":
        """Load Arrow data into Spanner table via batch mutations.

        Args:
            table: Target table name; interpolated into the SQL text as-is, so
                it must come from a trusted source.
            source: Arrow data (or an ``ArrowResult``) to load.
            partitioner: Optional partition info attached to the telemetry.
            overwrite: When true, all existing rows are deleted first.
            telemetry: Optional inbound telemetry passed to the storage job.

        Raises:
            SQLConversionError: If a delete or insert is required but the
                connection is not a ``Transaction``.
        """
        self._require_capability("arrow_import_enabled")
        arrow_table = self._coerce_arrow_table(source)

        if overwrite:
            delete_sql = f"DELETE FROM {table} WHERE TRUE"
            if not isinstance(self.connection, Transaction):
                msg = "Delete requires a Transaction context."
                raise SQLConversionError(msg)
            cast("_SpannerWriteProtocol", self.connection).execute_update(delete_sql)

        columns, records = self._arrow_table_to_rows(arrow_table)
        if records:
            # Validate the connection before coercing every record, so an
            # unusable context fails fast instead of after building the batch.
            conn = self.connection
            if not isinstance(conn, Transaction):
                msg = "Arrow import requires a Transaction context."
                raise SQLConversionError(msg)
            writer = cast("_SpannerWriteProtocol", conn)

            column_list = ", ".join(columns)
            placeholders = ", ".join(f"@p{i}" for i in range(len(columns)))
            insert_sql = f"INSERT INTO {table} ({column_list}) VALUES ({placeholders})"

            batch_args: list[tuple[str, dict[str, Any] | None, dict[str, Any]]] = []
            for record in records:
                # Positional values are bound as @p0, @p1, ... to match the placeholders.
                coerced = self._coerce_params({f"p{i}": val for i, val in enumerate(record)})
                batch_args.append((insert_sql, coerced, self._infer_param_types(coerced)))
            # NOTE(review): the (status, row_counts) result is discarded; confirm
            # whether per-statement failures are reported there rather than raised.
            writer.batch_update(batch_args)

        telemetry_payload = self._build_ingest_telemetry(arrow_table)
        telemetry_payload["destination"] = table
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    def load_from_storage(
        self,
        table: str,
        source: "StorageDestination",
        *,
        file_format: "StorageFormat",
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
    ) -> "StorageBridgeJob":
        """Load artifacts from storage into Spanner table."""
        arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format)
        return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound)

    # ─────────────────────────────────────────────────────────────────────────────
    # UTILITY METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    @property
    def data_dictionary(self) -> "SpannerDataDictionary":
        """Return the driver's ``SpannerDataDictionary``, created on first access."""
        if self._data_dictionary is None:
            self._data_dictionary = SpannerDataDictionary()
        return self._data_dictionary

    # ─────────────────────────────────────────────────────────────────────────────
    # PRIVATE/INTERNAL METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Collect Spanner rows for the direct execution path.

        Note: Spanner's collect_rows requires result set fields and a type converter.
        The direct execution path may not always have this metadata available,
        so this falls back to basic collection.

        Returns:
            ``(rows, column_names, row_count)``. Column names are taken from the
            first row's keys when rows are dicts, and are empty otherwise.
        """
        if not fetched:
            return [], [], 0
        # Attempt to extract column names from dict keys if rows are dicts.
        if isinstance(fetched[0], dict):
            column_names = list(fetched[0].keys())
            return fetched, column_names, len(fetched)
        # For tuple rows without metadata, return as-is.
        return fetched, [], len(fetched)

    def resolve_rowcount(self, cursor: Any) -> int:
        """Resolve rowcount from Spanner cursor for the direct execution path."""
        # Spanner uses execute_update return value, not cursor.rowcount.
        return 0

    def _connection_in_transaction(self) -> bool:
        """Check if connection is in transaction."""
        return False

    def _coerce_params(self, params: "dict[str, Any] | list[Any] | tuple[Any, ...] | None") -> "dict[str, Any] | None":
        """Coerce parameters using the ``json_serializer`` driver feature, if set."""
        return coerce_params(params, json_serializer=self.driver_features.get("json_serializer"))

    def _infer_param_types(self, params: "dict[str, Any] | list[Any] | tuple[Any, ...] | None") -> "dict[str, Any]":
        """Infer the Spanner parameter types for ``params``."""
        return infer_param_types(params)

    def _resolve_column_names(self, fields: Any) -> list[str]:
        """Resolve column names for ``fields`` through the per-driver cache."""
        return resolve_column_names(fields, self._column_name_cache)
# Import-time side effect: passes this adapter's ``driver_profile`` to
# ``register_driver_profile`` under the "spanner" key.
register_driver_profile("spanner", driver_profile)