# Source code for sqlspec.adapters.spanner.config

"""Spanner configuration."""

import inspect
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

from google.cloud.spanner_v1 import Client
from google.cloud.spanner_v1.pool import AbstractSessionPool, FixedSizePool
from typing_extensions import NotRequired

from sqlspec.adapters.spanner._types import SpannerConnection
from sqlspec.adapters.spanner.driver import SpannerSyncDriver, spanner_statement_config
from sqlspec.config import SyncDatabaseConfig
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.events._hints import EventRuntimeHints
from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config
from sqlspec.utils.serializers import from_json, to_json

if TYPE_CHECKING:
    from google.auth.credentials import Credentials
    from google.cloud.spanner_v1.database import Database

    from sqlspec.config import ExtensionConfigs
    from sqlspec.core import StatementConfig
    from sqlspec.observability import ObservabilityConfig

# Public names exported by this module (also what ``from ... import *`` pulls in).
__all__ = (
    "SpannerConnectionParams",
    "SpannerDriverFeatures",
    "SpannerPoolParams",
    "SpannerSyncConfig",
)


class SpannerConnectionParams(TypedDict):
    """Spanner connection parameters.

    ``project``, ``credentials`` and ``client_options`` are forwarded to
    ``google.cloud.spanner_v1.Client`` when ``SpannerSyncConfig`` builds its
    client. ``instance_id`` and ``database_id`` select the target database;
    ``SpannerSyncConfig`` raises ``ImproperConfigurationError`` if either is
    missing when a database handle, connection or pool is requested.
    """

    project: "NotRequired[str]"  # passed as Client(project=...)
    instance_id: "NotRequired[str]"  # Spanner instance to open (required at use time)
    database_id: "NotRequired[str]"  # database within the instance (required at use time)
    credentials: "NotRequired[Credentials]"  # passed as Client(credentials=...)
    client_options: "NotRequired[dict[str, Any]]"  # passed as Client(client_options=...)
    extra: "NotRequired[dict[str, Any]]"  # NOTE(review): not read by SpannerSyncConfig in this module; confirm the consumer


class SpannerPoolParams(SpannerConnectionParams):
    """Session pool configuration.

    Extends the connection parameters with the options ``SpannerSyncConfig``
    reads when it instantiates its session pool.
    """

    pool_type: "NotRequired[type[AbstractSessionPool]]"  # pool class to build; SpannerSyncConfig defaults it to FixedSizePool
    min_sessions: "NotRequired[int]"  # defaulted to 1 by SpannerSyncConfig, but not passed to the pool constructor in this module
    max_sessions: "NotRequired[int]"  # defaulted to 10; used as the pool size when no explicit "size" key is configured
    labels: "NotRequired[dict[str, str]]"  # session labels handed to the pool constructor
    ping_interval: "NotRequired[int]"  # never passed to FixedSizePool; only relevant for other pool types


class SpannerDriverFeatures(TypedDict):
    """Driver feature flags for Spanner.

    Attributes:
        enable_uuid_conversion: Enable automatic UUID string conversion.
            ``SpannerSyncConfig`` defaults this to True.
        json_serializer: Custom JSON serializer for parameter conversion.
            ``SpannerSyncConfig`` defaults this to sqlspec's ``to_json``.
        json_deserializer: Custom JSON deserializer for result conversion.
            ``SpannerSyncConfig`` defaults this to sqlspec's ``from_json``.
        session_labels: Labels to apply to Spanner sessions.
        enable_events: Enable database event channel support.
            Defaults to True when extension_config["events"] is configured.
        events_backend: Backend type for event handling.
            Spanner only supports "table_queue" (no native pub/sub).
    """

    enable_uuid_conversion: "NotRequired[bool]"
    json_serializer: "NotRequired[Callable[[Any], str]]"
    json_deserializer: "NotRequired[Callable[[str], Any]]"
    session_labels: "NotRequired[dict[str, str]]"
    # Declared so the event-related keys documented above are accepted by type checkers.
    enable_events: "NotRequired[bool]"
    events_backend: "NotRequired[str]"


class SpannerSyncConfig(SyncDatabaseConfig["SpannerConnection", "AbstractSessionPool", SpannerSyncDriver]):
    """Spanner configuration and session management.

    Owns a lazily created ``Client``, a session pool (``FixedSizePool`` unless
    ``pool_type`` says otherwise) and a cached ``Database`` handle bound to that
    pool. ``provide_connection``/``provide_session`` hand out read-only Snapshot
    contexts by default, or Transaction contexts when ``transaction=True``.
    """

    driver_type: ClassVar[type["SpannerSyncDriver"]] = SpannerSyncDriver
    connection_type: ClassVar[type["SpannerConnection"]] = cast("type[SpannerConnection]", SpannerConnection)
    supports_transactional_ddl: ClassVar[bool] = False
    supports_native_arrow_export: ClassVar[bool] = True
    supports_native_arrow_import: ClassVar[bool] = True
    supports_native_parquet_export: ClassVar[bool] = False
    supports_native_parquet_import: ClassVar[bool] = False
    requires_staging_for_load: ClassVar[bool] = False

    def __init__(
        self,
        *,
        connection_config: "SpannerPoolParams | dict[str, Any] | None" = None,
        connection_instance: "AbstractSessionPool | None" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "SpannerDriverFeatures | dict[str, Any] | None" = None,
        bind_key: "str | None" = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Spanner configuration.

        Args:
            connection_config: Connection and pool parameters. ``min_sessions``
                defaults to 1, ``max_sessions`` to 10 and ``pool_type`` to
                ``FixedSizePool``.
            connection_instance: Pre-built session pool to use instead of creating one.
            migration_config: Migration settings, passed through to the base config.
            statement_config: Statement configuration; defaults to ``spanner_statement_config``.
            driver_features: Driver feature flags. ``enable_uuid_conversion`` defaults
                to True and the JSON (de)serializers default to ``to_json``/``from_json``.
            bind_key: Optional key identifying this config.
            extension_config: Extension settings, passed through to the base config.
            observability_config: Observability settings, passed through to the base config.
            **kwargs: Deprecated pool arguments handled by ``apply_pool_deprecations``;
                the rest are forwarded to the base config.
        """
        connection_config, connection_instance = apply_pool_deprecations(
            kwargs=kwargs, connection_config=connection_config, connection_instance=connection_instance
        )
        self.connection_config = normalize_connection_config(connection_config)
        self.connection_config.setdefault("min_sessions", 1)
        self.connection_config.setdefault("max_sessions", 10)
        self.connection_config.setdefault("pool_type", FixedSizePool)

        features: dict[str, Any] = dict(driver_features) if driver_features else {}
        features.setdefault("enable_uuid_conversion", True)
        features.setdefault("json_serializer", to_json)
        features.setdefault("json_deserializer", from_json)

        base_statement_config = statement_config or spanner_statement_config
        super().__init__(
            connection_config=self.connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=base_statement_config,
            driver_features=features,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )
        self._client: Client | None = None
        self._database: Database | None = None

    def _require_database_ids(self) -> "tuple[str, str]":
        """Return ``(instance_id, database_id)`` from the config.

        Raises:
            ImproperConfigurationError: If either value is missing or empty.
        """
        instance_id = self.connection_config.get("instance_id")
        database_id = self.connection_config.get("database_id")
        if not instance_id or not database_id:
            msg = "instance_id and database_id are required."
            raise ImproperConfigurationError(msg)
        return instance_id, database_id

    def _get_client(self) -> Client:
        """Return the Spanner client, creating it on first use from project/credentials/client_options."""
        if self._client is None:
            self._client = Client(
                project=self.connection_config.get("project"),
                credentials=self.connection_config.get("credentials"),
                client_options=self.connection_config.get("client_options"),
            )
        return self._client

    def get_database(self) -> "Database":
        """Return the cached ``Database`` handle, creating the pool and handle on first use.

        Raises:
            ImproperConfigurationError: If ``instance_id`` or ``database_id`` is missing.
        """
        instance_id, database_id = self._require_database_ids()
        if self.connection_instance is None:
            self.connection_instance = self.provide_pool()
        if self._database is None:
            client = self._get_client()
            self._database = client.instance(instance_id).database(database_id, pool=self.connection_instance)  # type: ignore[no-untyped-call]
        return self._database

    def create_connection(self) -> SpannerConnection:
        """Return the object produced by ``database.snapshot()`` on the shared database handle.

        Reuses ``get_database()`` instead of building a new ``Database`` (and
        re-binding the pool to it) on every call.

        NOTE(review): ``provide_connection`` uses ``database.snapshot(...)`` as a
        context manager (``with ... as snapshot``); callers of this method likely
        need to do the same. Confirm against the driver's expectations.
        """
        database = self.get_database()
        return cast("SpannerConnection", cast("Any", database).snapshot())

    @staticmethod
    def _pool_init_params(pool_type: "type[AbstractSessionPool]") -> "set[str] | None":
        """Return the keyword names accepted by ``pool_type``'s constructor.

        Returns ``None`` when the signature cannot be inspected or accepts
        ``**kwargs``, meaning every candidate keyword should be forwarded.
        """
        try:
            parameters = inspect.signature(pool_type).parameters.values()
        except (TypeError, ValueError):
            return None
        if any(param.kind is inspect.Parameter.VAR_KEYWORD for param in parameters):
            return None
        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        return {param.name for param in parameters if param.kind in keyword_kinds}

    def _create_pool(self) -> AbstractSessionPool:
        """Instantiate the configured session pool type.

        Candidate constructor kwargs are ``size`` (taken from ``size``, falling
        back to ``max_sessions``), ``labels`` and ``ping_interval``. Only
        non-None values the pool's constructor accepts are forwarded. Pools
        whose constructor takes ``target_size`` instead of ``size`` receive the
        size under that name. ``min_sessions`` is not forwarded.

        Raises:
            ImproperConfigurationError: If ``instance_id`` or ``database_id`` is missing.
        """
        self._require_database_ids()
        config = self.connection_config
        pool_type = cast("type[AbstractSessionPool]", config.get("pool_type", FixedSizePool))

        size = config.get("size")
        if size is None:
            size = config.get("max_sessions")
        candidates: dict[str, Any] = {
            "size": size,
            "labels": config.get("labels"),
            "ping_interval": config.get("ping_interval"),
        }

        accepted = self._pool_init_params(pool_type)
        # Some pool types name their size argument ``target_size``; passing ``size`` would raise TypeError.
        if accepted is not None and "size" not in accepted and "target_size" in accepted:
            candidates["target_size"] = candidates.pop("size")
        pool_kwargs = {
            key: value
            for key, value in candidates.items()
            if value is not None and (accepted is None or key in accepted)
        }
        pool_factory = cast("Callable[..., AbstractSessionPool]", pool_type)
        return pool_factory(**pool_kwargs)

    def _close_pool(self) -> None:
        """Release the pool's sessions and drop the database handle bound to it.

        Uses ``close()`` when the pool defines it, otherwise ``clear()``.
        """
        pool = self.connection_instance
        if pool is not None:
            release = getattr(pool, "close", None) or getattr(pool, "clear", None)
            if callable(release):
                release()
        # The cached Database is bound to this pool; rebuild it against any new pool.
        self._database = None

    @staticmethod
    def _transaction_pending(txn: Any) -> bool:
        """Return True when ``txn`` has a server-side transaction id and has not been committed."""
        return getattr(txn, "_transaction_id", None) is not None and getattr(txn, "committed", None) is None

    @contextmanager
    def provide_connection(
        self, *args: Any, transaction: "bool" = False, **kwargs: Any
    ) -> Generator[SpannerConnection, None, None]:
        """Yield a Snapshot (default) or Transaction context from the configured pool.

        Args:
            *args: Additional positional arguments (unused, for interface compatibility).
            transaction: If True, yields a Transaction context that supports execute_update()
                for DML statements. If False (default), yields a read-only Snapshot context
                for SELECT queries.
            **kwargs: Additional keyword arguments (unused, for interface compatibility).

        Note:
            For complex transactional logic with retries, use database.run_in_transaction()
            directly. The Transaction context here auto-commits on successful exit.
        """
        database = self.get_database()
        if not transaction:
            with cast("Any", database).snapshot(multi_use=True) as snapshot:
                yield cast("SpannerConnection", snapshot)
            return

        session = cast("Any", database).session()
        session.create()
        try:
            txn = session.transaction()
            # __exit__ is deliberately never called; commit/rollback are handled explicitly below.
            txn.__enter__()
            try:
                yield cast("SpannerConnection", txn)
                # driver.commit() may already have committed; only commit a still-open transaction.
                if self._transaction_pending(txn):
                    txn.commit()
            except Exception:
                # Skip rollback if nothing began server-side or it was already committed;
                # rolling back then would raise and mask the original error.
                if self._transaction_pending(txn):
                    txn.rollback()
                raise
        finally:
            session.delete()

    @contextmanager
    def provide_session(
        self, *args: Any, statement_config: "StatementConfig | None" = None, transaction: "bool" = False, **kwargs: Any
    ) -> Generator[SpannerSyncDriver, None, None]:
        """Yield a driver wrapping a Snapshot (default) or Transaction context.

        Args:
            *args: Forwarded to ``provide_connection``.
            statement_config: Override for this session; defaults to the config's statement config.
            transaction: If True, the driver wraps a Transaction context.
            **kwargs: Forwarded to ``provide_connection``.
        """
        with self.provide_connection(*args, transaction=transaction, **kwargs) as connection:
            driver = self.driver_type(
                connection=connection,
                statement_config=statement_config or self.statement_config,
                driver_features=self.driver_features,
            )
            yield self._prepare_driver(driver)

    @contextmanager
    def provide_write_session(
        self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any
    ) -> Generator[SpannerSyncDriver, None, None]:
        """Yield a driver bound to a Transaction context (``provide_session(transaction=True)``)."""
        with self.provide_session(*args, statement_config=statement_config, transaction=True, **kwargs) as driver:
            yield driver

    def get_signature_namespace(self) -> dict[str, Any]:
        """Return the base signature namespace extended with this adapter's public types."""
        namespace = super().get_signature_namespace()
        namespace.update({
            "SpannerSyncConfig": SpannerSyncConfig,
            "SpannerConnectionParams": SpannerConnectionParams,
            "SpannerDriverFeatures": SpannerDriverFeatures,
            "SpannerSyncDriver": SpannerSyncDriver,
        })
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return queue defaults for Spanner JSON handling."""
        return EventRuntimeHints(json_passthrough=True)