"""Spanner configuration."""
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
from google.cloud.spanner_v1 import Client
from google.cloud.spanner_v1.pool import AbstractSessionPool, BurstyPool, FixedSizePool, PingingPool
from typing_extensions import NotRequired
from sqlspec.adapters.spanner._typing import SpannerConnection
from sqlspec.adapters.spanner.core import apply_driver_features, default_statement_config
from sqlspec.adapters.spanner.driver import SpannerSessionContext, SpannerSyncDriver
from sqlspec.config import SyncDatabaseConfig
from sqlspec.driver._sync import SyncPoolConnectionContext, SyncPoolSessionFactory
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.utils.config_tools import normalize_connection_config
from sqlspec.utils.type_guards import supports_close
if TYPE_CHECKING:
from collections.abc import Callable
from logging import Logger
from types import TracebackType
from google.api_core.client_info import ClientInfo
from google.api_core.client_options import ClientOptions
from google.api_core.retry import Retry
from google.auth.credentials import Credentials
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect, EncryptionConfig
from google.cloud.spanner_v1 import DirectedReadOptions, ExecuteSqlRequest
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.transaction import DefaultTransactionOptions
from sqlspec.config import ExtensionConfigs
from sqlspec.core import StatementConfig
from sqlspec.observability import ObservabilityConfig
__all__ = ("SpannerConnectionParams", "SpannerDriverFeatures", "SpannerPoolParams", "SpannerSyncConfig")
#: Default ``transaction`` flag for ``provide_session`` / ``provide_connection``.
#: ``True`` yields a write-capable :class:`Transaction` context matching every
#: other sqlspec adapter. Read-only :class:`Snapshot` contexts are available via
#: :meth:`SpannerSyncConfig.provide_read_session`. Kept as a module-level
#: constant so an eventual ``SpannerAsyncConfig`` can import the same default.
_DEFAULT_SESSION_TRANSACTION: bool = True

# connection_config keys forwarded as keyword arguments to
# ``google.cloud.spanner_v1.Client(...)`` (only when their value is not None).
_CLIENT_CONFIG_FIELDS = frozenset(
    (
        "project",
        "credentials",
        "client_info",
        "client_options",
        "query_options",
        "route_to_leader_enabled",
        "directed_read_options",
        "observability_options",
        "default_transaction_options",
        "experimental_host",
        "disable_builtin_metrics",
        "client_context",
        "use_plain_text",
        "ca_certificate",
        "client_certificate",
        "client_key",
        "instance_type",
    )
)

# connection_config keys forwarded to ``Client.instance(instance_id, ...)``.
_INSTANCE_CONFIG_FIELDS = frozenset(
    (
        "configuration_name",
        "display_name",
        "node_count",
        "processing_units",
    )
)

# connection_config keys forwarded to ``Instance.database(database_id, ...)``.
_DATABASE_CONFIG_FIELDS = frozenset(
    (
        "ddl_statements",
        "logger",
        "encryption_config",
        "database_dialect",
        "database_role",
        "enable_drop_protection",
        "enable_interceptors_in_tests",
        "proto_descriptors",
    )
)
class SpannerConnectionParams(TypedDict):
    """Spanner connection parameters.

    Every key is ``NotRequired`` at type-check time. ``SpannerSyncConfig``
    routes the keys to the ``Client`` constructor, the ``Client.instance``
    lookup or the ``Instance.database`` call, forwarding only keys whose value
    is not ``None``. ``instance_id`` and ``database_id`` must nevertheless be
    set before a database handle or session pool is built, otherwise
    ``ImproperConfigurationError`` is raised.
    """

    # --- Forwarded to ``google.cloud.spanner_v1.Client(...)`` ---
    project: "NotRequired[str]"
    credentials: "NotRequired[Credentials]"
    client_info: "NotRequired[ClientInfo]"
    client_options: "NotRequired[ClientOptions | dict[str, Any]]"
    query_options: "NotRequired[ExecuteSqlRequest.QueryOptions]"
    route_to_leader_enabled: "NotRequired[bool]"
    directed_read_options: "NotRequired[DirectedReadOptions]"
    observability_options: "NotRequired[Any]"
    default_transaction_options: "NotRequired[DefaultTransactionOptions]"
    experimental_host: "NotRequired[str]"
    disable_builtin_metrics: "NotRequired[bool]"
    client_context: "NotRequired[dict[str, str]]"
    use_plain_text: "NotRequired[bool]"
    ca_certificate: "NotRequired[str]"
    client_certificate: "NotRequired[str]"
    client_key: "NotRequired[str]"
    instance_type: "NotRequired[str]"
    # --- Instance lookup: ``Client.instance(instance_id, ...)`` ---
    # Required at runtime even though typed as NotRequired.
    instance_id: "NotRequired[str]"
    configuration_name: "NotRequired[str]"
    display_name: "NotRequired[str]"
    node_count: "NotRequired[int]"
    processing_units: "NotRequired[int]"
    # Passed to ``Client.instance`` under the keyword ``labels`` (distinct from
    # the pool/session ``labels`` key in ``SpannerPoolParams``).
    instance_labels: "NotRequired[dict[str, str]]"
    # --- Database handle: ``Instance.database(database_id, ...)`` ---
    # Required at runtime even though typed as NotRequired.
    database_id: "NotRequired[str]"
    ddl_statements: "NotRequired[tuple[str, ...] | list[str]]"
    logger: "NotRequired[Logger]"
    encryption_config: "NotRequired[EncryptionConfig | dict[str, Any]]"
    database_dialect: "NotRequired[DatabaseDialect]"
    # Also forwarded to the session pool constructor.
    database_role: "NotRequired[str]"
    enable_drop_protection: "NotRequired[bool]"
    enable_interceptors_in_tests: "NotRequired[bool]"
    proto_descriptors: "NotRequired[bytes]"
    # NOTE(review): not read directly anywhere in this module; presumably
    # handled by ``normalize_connection_config`` — confirm.
    extra: "NotRequired[dict[str, Any]]"
class SpannerPoolParams(SpannerConnectionParams):
    """Session pool configuration.

    Extends :class:`SpannerConnectionParams` with the settings used to build
    the google-cloud-spanner session pool. Which keys reach the pool
    constructor depends on ``pool_type``:

    * ``PingingPool`` subclasses: ``size``, ``default_timeout``, ``ping_interval``;
    * ``FixedSizePool`` subclasses: ``size``, ``default_timeout``, ``max_age_minutes``;
    * ``BurstyPool`` subclasses: ``target_size`` (falling back to ``size``);
    * any other pool class: whichever of ``size``, ``target_size``,
      ``default_timeout``, ``ping_interval`` and ``max_age_minutes`` are set.

    Every pool additionally receives ``labels`` and ``database_role`` when set.
    """

    # Pool class to instantiate; ``FixedSizePool`` when omitted.
    pool_type: "NotRequired[type[AbstractSessionPool]]"
    # Pool size; when omitted it defaults to ``max_sessions`` or else 10.
    size: "NotRequired[int]"
    # Target size for ``BurstyPool``; falls back to ``size``.
    target_size: "NotRequired[int]"
    # Alias for ``size``; always removed from the config during initialization.
    max_sessions: "NotRequired[int]"
    default_timeout: "NotRequired[int | float]"
    # Labels passed to the pool constructor as ``labels``; ``session_labels``
    # takes precedence over ``labels`` when both are present.
    session_labels: "NotRequired[dict[str, str]]"
    labels: "NotRequired[dict[str, str]]"
    ping_interval: "NotRequired[int]"
    max_age_minutes: "NotRequired[int]"
class SpannerDriverFeatures(TypedDict):
    """Driver feature flags for Spanner.

    Attributes:
        enable_uuid_conversion: Enable automatic UUID string conversion.
        json_serializer: Custom JSON serializer for parameter conversion.
        json_deserializer: Custom JSON deserializer for result conversion.
        retry: Per-request retry policy passed to execute_sql(), execute_update(), and batch_update().
        timeout: Per-request timeout in seconds passed to execute_sql(), execute_update(), and batch_update().
        session_labels: Deprecated compatibility alias for pool session labels.
            Prefer ``connection_config["session_labels"]``. ``SpannerSyncConfig``
            always removes this key from the driver features. It is copied
            into ``connection_config["session_labels"]`` only when neither
            ``session_labels`` nor ``labels`` is already configured there;
            otherwise the value is discarded.
        enable_events: Enable database event channel support.
            Defaults to True when extension_config["events"] is configured.
        events_backend: Backend type for event handling.
            Spanner only supports "table_queue" (no native pub/sub).
    """

    enable_uuid_conversion: "NotRequired[bool]"
    json_serializer: "NotRequired[Callable[[Any], str]]"
    json_deserializer: "NotRequired[Callable[[str], Any]]"
    retry: "NotRequired[Retry | None]"
    timeout: "NotRequired[float | None]"
    # Deprecated: migrated into connection_config by SpannerSyncConfig.__init__.
    session_labels: "NotRequired[dict[str, str]]"
    enable_events: "NotRequired[bool]"
    events_backend: "NotRequired[str]"
class SpannerConnectionContext(SyncPoolConnectionContext):
    """Context manager for Spanner connections.

    * ``transaction=True``: creates a dedicated session with
      ``Database.session()``, opens a Transaction on it and yields that
      Transaction. On exit the transaction is committed (clean exit) or rolled
      back (exception) if it was begun and not already finished by the
      caller, and the session is then deleted.
    * ``transaction=False``: enters ``Database.snapshot(multi_use=True)`` and
      yields the resulting read-only Snapshot; exit is delegated to that
      snapshot context.

    Internal references are cleared on exit even when commit, rollback or
    cleanup raises, so a failed exit never keeps a deleted session around.
    """

    __slots__ = ("_connection", "_session", "_transaction")

    def __init__(self, config: "SpannerSyncConfig", transaction: bool = False) -> None:
        super().__init__(config)
        self._transaction = transaction
        self._connection: SpannerConnection | None = None
        self._session: Any = None

    def __enter__(self) -> SpannerConnection:
        database = cast("Any", self._config.get_database())
        if self._transaction:
            return self._enter_transaction(database)
        snapshot_ctx = database.snapshot(multi_use=True)
        self._session = snapshot_ctx
        try:
            self._connection = cast("SpannerConnection", snapshot_ctx.__enter__())
        except Exception:
            # The snapshot context never became active; do not keep it around.
            self._session = None
            raise
        return self._connection

    def _enter_transaction(self, database: Any) -> SpannerConnection:
        """Create a session, begin a Transaction on it, and record both."""
        session = database.session()
        session.create()
        try:
            txn = session.transaction()
            txn.__enter__()
        except Exception:
            # The session already exists server-side; delete it so it does not leak.
            session.delete()
            raise
        self._session = session
        self._connection = cast("SpannerConnection", txn)
        return self._connection

    def __exit__(
        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None"
    ) -> bool | None:
        try:
            if self._transaction and self._connection is not None:
                self._finish_transaction(cast("Any", self._connection), commit=exc_type is None)
            elif self._session is not None:
                self._session.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Reset unconditionally so a failed commit/rollback/delete cannot
            # leave stale references to a finished transaction or deleted session.
            self._connection = None
            self._session = None
        return None

    def _finish_transaction(self, txn: Any, *, commit: bool) -> None:
        """Commit or roll back ``txn`` if it is still open, then delete the session.

        Args:
            txn: The Transaction yielded by :meth:`__enter__`.
            commit: ``True`` to commit (clean exit), ``False`` to roll back.
        """
        try:
            # No transaction id means nothing was begun server-side. A
            # transaction the caller already committed or rolled back must not
            # be finished again; the client library rejects commit/rollback of a
            # finished transaction.
            begun = getattr(txn, "_transaction_id", None) is not None
            finished = getattr(txn, "committed", None) is not None or bool(getattr(txn, "rolled_back", False))
            if begun and not finished:
                if commit:
                    txn.commit()
                else:
                    txn.rollback()
        finally:
            if self._session is not None:
                self._session.delete()
class _SpannerSessionConnectionHandler(SyncPoolSessionFactory):
    """Expose a :class:`SpannerConnectionContext` as acquire/release callables.

    ``SpannerSyncConfig.provide_session`` passes the bound
    :meth:`acquire_connection` and :meth:`release_connection` methods to
    :class:`SpannerSessionContext`.
    """

    __slots__ = ("_connection_ctx",)

    def __init__(self, config: "SpannerSyncConfig", connection_ctx: "SpannerConnectionContext") -> None:
        super().__init__(config)
        self._connection_ctx = connection_ctx

    def acquire_connection(self) -> "SpannerConnection":
        """Enter the wrapped connection context and return what it yields."""
        ctx = self._connection_ctx
        return ctx.__enter__()

    def release_connection(self, _conn: "SpannerConnection", **kwargs: Any) -> None:
        """Exit the wrapped connection context.

        The ``exc_type``, ``exc_val`` and ``exc_tb`` keyword arguments are
        forwarded positionally to ``__exit__``; any that are absent are passed
        as ``None``. The connection argument itself is not used.
        """
        exc_info = tuple(kwargs.get(name) for name in ("exc_type", "exc_val", "exc_tb"))
        self._connection_ctx.__exit__(*exc_info)
class SpannerSyncConfig(SyncDatabaseConfig["SpannerConnection", "AbstractSessionPool", SpannerSyncDriver]):
    """Spanner configuration and session management.

    ``connection_config`` keys are routed to four destinations, and only keys
    whose value is not ``None`` are forwarded:

    * ``Client(...)`` — keys in ``_CLIENT_CONFIG_FIELDS``;
    * ``Client.instance(instance_id, ...)`` — keys in
      ``_INSTANCE_CONFIG_FIELDS`` plus ``instance_labels`` (sent as ``labels``);
    * ``Instance.database(database_id, ...)`` — keys in
      ``_DATABASE_CONFIG_FIELDS`` plus the session pool;
    * the session pool constructor — see :meth:`_create_pool`.

    The client, the pool and the database handle are all created lazily.
    """

    driver_type: ClassVar[type["SpannerSyncDriver"]] = SpannerSyncDriver
    connection_type: ClassVar[type["SpannerConnection"]] = cast("type[SpannerConnection]", SpannerConnection)
    supports_transactional_ddl: ClassVar[bool] = False
    supports_native_arrow_export: ClassVar[bool] = True
    supports_native_arrow_import: ClassVar[bool] = True
    supports_native_parquet_export: ClassVar[bool] = False
    supports_native_parquet_import: ClassVar[bool] = False
    requires_staging_for_load: ClassVar[bool] = False
    _connection_context_class: "ClassVar[type[SpannerConnectionContext]]" = SpannerConnectionContext
    _session_factory_class: "ClassVar[type[_SpannerSessionConnectionHandler]]" = _SpannerSessionConnectionHandler
    _session_context_class: "ClassVar[type[SpannerSessionContext]]" = SpannerSessionContext
    _default_statement_config = default_statement_config

    def __init__(
        self,
        *,
        connection_config: "SpannerPoolParams | dict[str, Any] | None" = None,
        connection_instance: "AbstractSessionPool | None" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "SpannerDriverFeatures | dict[str, Any] | None" = None,
        bind_key: "str | None" = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Spanner configuration.

        Args:
            connection_config: Client, instance, database and pool settings.
                ``max_sessions`` is accepted as an alias for ``size`` (default
                10) and ``pool_type`` defaults to :class:`FixedSizePool`.
            connection_instance: Pre-built session pool to use instead of
                creating one from ``connection_config``.
            migration_config: Passed through to the base config.
            statement_config: Statement configuration; defaults to the Spanner
                ``default_statement_config``.
            driver_features: Driver feature flags. A legacy ``session_labels``
                entry is removed and moved into
                ``connection_config["session_labels"]`` unless ``session_labels``
                or ``labels`` is already configured there.
            bind_key: Passed through to the base config.
            extension_config: Passed through to the base config.
            observability_config: Passed through to the base config.
            **kwargs: Additional keyword arguments for the base config.

        Raises:
            ImproperConfigurationError: If ``min_sessions`` is supplied.
        """
        self.connection_config = normalize_connection_config(connection_config)
        if "min_sessions" in self.connection_config:
            msg = "Spanner session pools do not support 'min_sessions'; use 'size' or 'target_size'."
            raise ImproperConfigurationError(msg)

        raw_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
        # Legacy location for pool labels: migrate it into connection_config
        # unless labels were configured there explicitly (explicit setting wins).
        legacy_session_labels = raw_driver_features.pop("session_labels", None)
        if (
            legacy_session_labels is not None
            and "session_labels" not in self.connection_config
            and "labels" not in self.connection_config
        ):
            self.connection_config["session_labels"] = legacy_session_labels

        # ``pop`` is evaluated unconditionally, so ``max_sessions`` is always
        # consumed; an explicit ``size`` takes precedence over it.
        self.connection_config.setdefault("size", self.connection_config.pop("max_sessions", 10))
        self.connection_config.setdefault("pool_type", FixedSizePool)

        driver_features = apply_driver_features(raw_driver_features)
        statement_config = statement_config or default_statement_config
        super().__init__(
            connection_config=self.connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=statement_config,
            driver_features=driver_features,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )
        self._client: Client | None = None
        self._database: Database | None = None

    def _require_instance_and_database_ids(self) -> "tuple[str, str]":
        """Return ``(instance_id, database_id)`` or raise if either is missing/empty.

        Raises:
            ImproperConfigurationError: If ``instance_id`` or ``database_id`` is falsy.
        """
        instance_id = self.connection_config.get("instance_id")
        database_id = self.connection_config.get("database_id")
        if not instance_id or not database_id:
            msg = "instance_id and database_id are required."
            raise ImproperConfigurationError(msg)
        return instance_id, database_id

    def _get_client(self) -> Client:
        """Return the cached ``Client``, constructing it from the client fields on first use."""
        if self._client is None:
            client_kwargs = self._resolve_kwargs(_CLIENT_CONFIG_FIELDS)
            self._client = Client(**client_kwargs)
        return self._client

    def get_database(self) -> "Database":
        """Return the cached ``Database`` handle, building it on first use.

        Ensures the session pool exists (via :meth:`provide_pool`) and hands it
        to ``Instance.database`` through the ``pool`` keyword.

        Raises:
            ImproperConfigurationError: If ``instance_id`` or ``database_id`` is missing.
        """
        instance_id, database_id = self._require_instance_and_database_ids()
        if self.connection_instance is None:
            self.connection_instance = self.provide_pool()
        if self._database is None:
            client = self._get_client()
            instance_kwargs = self._resolve_kwargs(_INSTANCE_CONFIG_FIELDS)
            instance_labels = self.connection_config.get("instance_labels")
            if instance_labels is not None:
                instance_kwargs["labels"] = instance_labels
            database_kwargs = self._resolve_kwargs(_DATABASE_CONFIG_FIELDS)
            database_kwargs["pool"] = self.connection_instance
            self._database = client.instance(instance_id, **instance_kwargs).database(  # type: ignore[no-untyped-call]
                database_id, **database_kwargs
            )
        return self._database

    def create_connection(self) -> SpannerConnection:
        """Return the result of ``Database.snapshot()`` typed as a connection.

        NOTE(review): in google-cloud-spanner ``Database.snapshot()`` appears to
        return a checkout context manager rather than an entered Snapshot —
        confirm callers enter it before use.
        """
        return cast("SpannerConnection", self.get_database().snapshot())  # type: ignore[no-untyped-call]

    def _create_pool(self) -> AbstractSessionPool:
        """Instantiate the configured session pool class.

        Keyword arguments are chosen by pool family: ``PingingPool`` gets
        ``size``/``default_timeout``/``ping_interval``; ``FixedSizePool`` gets
        ``size``/``default_timeout``/``max_age_minutes``; ``BurstyPool`` gets
        ``target_size`` (falling back to ``size``); any other class gets
        whichever of those five keys are set. ``labels`` (``session_labels``
        preferred over ``labels``) and ``database_role`` are added when set.

        Raises:
            ImproperConfigurationError: If ``instance_id`` or ``database_id`` is missing.
        """
        self._require_instance_and_database_ids()
        config = self.connection_config
        pool_type = cast("type[AbstractSessionPool]", config.get("pool_type", FixedSizePool))

        # An explicit ``session_labels=None`` should not hide a configured ``labels``.
        labels = config.get("session_labels")
        if labels is None:
            labels = config.get("labels")
        pool_kwargs: dict[str, Any] = self._resolve_pool_base_kwargs(labels=cast("dict[str, str] | None", labels))

        if issubclass(pool_type, PingingPool):
            pool_kwargs.update(self._resolve_kwargs({"size", "default_timeout", "ping_interval"}))
        elif issubclass(pool_type, FixedSizePool):
            pool_kwargs.update(self._resolve_kwargs({"size", "default_timeout", "max_age_minutes"}))
        elif issubclass(pool_type, BurstyPool):
            # Same None-fallback rule: ``target_size=None`` falls back to ``size``.
            target_size = config.get("target_size")
            if target_size is None:
                target_size = config.get("size")
            if target_size is not None:
                pool_kwargs["target_size"] = target_size
        else:
            pool_kwargs.update(
                self._resolve_kwargs({"size", "target_size", "default_timeout", "ping_interval", "max_age_minutes"})
            )
        pool_factory = cast("Callable[..., AbstractSessionPool]", pool_type)
        return pool_factory(**pool_kwargs)

    def _resolve_pool_base_kwargs(self, *, labels: "dict[str, str] | None") -> dict[str, Any]:
        """Build the pool kwargs shared by every pool type (``labels``, ``database_role``)."""
        pool_kwargs: dict[str, Any] = {}
        if labels is not None:
            pool_kwargs["labels"] = labels
        database_role = self.connection_config.get("database_role")
        if database_role is not None:
            pool_kwargs["database_role"] = database_role
        return pool_kwargs

    def _resolve_kwargs(self, fields: "frozenset[str] | set[str]") -> dict[str, Any]:
        """Return the subset of ``connection_config`` named by ``fields`` whose values are not ``None``."""
        return {
            field: self.connection_config[field] for field in fields if self.connection_config.get(field) is not None
        }

    def _close_pool(self) -> None:
        """Close the pool and client when they support ``close()`` and drop cached handles."""
        if self.connection_instance is not None and supports_close(self.connection_instance):
            self.connection_instance.close()
        if self._client is not None and supports_close(self._client):
            self._client.close()
        self._client = None
        self._database = None

    def provide_connection(
        self, *args: Any, transaction: "bool" = _DEFAULT_SESSION_TRANSACTION, **kwargs: Any
    ) -> "SpannerConnectionContext":
        """Yield a Transaction (default) or Snapshot context from the configured pool.

        Args:
            *args: Additional positional arguments (unused, for interface compatibility).
            transaction: If True (default), yields a Transaction context that
                supports execute_update() for DML statements. If False, yields
                a read-only Snapshot context for SELECT queries.
            **kwargs: Additional keyword arguments (unused, for interface compatibility).

        Returns:
            A :class:`SpannerConnectionContext` for use in a ``with`` statement.
        """
        return SpannerConnectionContext(self, transaction=transaction)

    def provide_session(
        self,
        *args: Any,
        statement_config: "StatementConfig | None" = None,
        transaction: "bool" = _DEFAULT_SESSION_TRANSACTION,
        **kwargs: Any,
    ) -> "SpannerSessionContext":
        """Provide a Spanner driver session context manager.

        Returns a write-capable Transaction session by default, matching every
        other sqlspec adapter. Pass ``transaction=False`` or use
        :meth:`provide_read_session` to obtain a read-only Snapshot session.

        Args:
            *args: Additional arguments (unused).
            statement_config: Optional statement configuration override.
            transaction: Whether to use a Transaction (True, default) or
                Snapshot (False).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            A Spanner driver session context manager.
        """
        connection_ctx = SpannerConnectionContext(self, transaction=transaction)
        handler = _SpannerSessionConnectionHandler(self, connection_ctx)
        return SpannerSessionContext(
            acquire_connection=handler.acquire_connection,
            release_connection=handler.release_connection,
            statement_config=statement_config or self.statement_config or default_statement_config,
            driver_features=self.driver_features,
            prepare_driver=self._prepare_driver,
        )

    def provide_write_session(
        self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any
    ) -> "SpannerSessionContext":
        """Provide a write-capable Spanner session (alias for :meth:`provide_session`)."""
        return self.provide_session(*args, statement_config=statement_config, transaction=True, **kwargs)

    def provide_read_session(
        self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any
    ) -> "SpannerSessionContext":
        """Provide a read-only Snapshot Spanner session.

        Use for query workloads that benefit from Spanner's snapshot reads.
        For DDL/DML, use :meth:`provide_session` (write-capable by default).
        """
        return self.provide_session(*args, statement_config=statement_config, transaction=False, **kwargs)

    def get_signature_namespace(self) -> "dict[str, Any]":
        """Get the signature namespace for SpannerSyncConfig types.

        Returns:
            Dictionary mapping type names to types.
        """
        namespace = super().get_signature_namespace()
        namespace.update({
            "SpannerConnectionContext": SpannerConnectionContext,
            "SpannerConnection": SpannerConnection,
            "SpannerConnectionParams": SpannerConnectionParams,
            "SpannerDriverFeatures": SpannerDriverFeatures,
            "SpannerPoolParams": SpannerPoolParams,
            "SpannerSessionContext": SpannerSessionContext,
            "SpannerSyncConfig": SpannerSyncConfig,
            "SpannerSyncDriver": SpannerSyncDriver,
        })
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return queue defaults for Spanner JSON handling."""
        return EventRuntimeHints(json_passthrough=True)