"""Spanner configuration."""
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
from google.cloud.spanner_v1 import Client
from google.cloud.spanner_v1.pool import AbstractSessionPool, FixedSizePool
from typing_extensions import NotRequired
from sqlspec.adapters.spanner._types import SpannerConnection
from sqlspec.adapters.spanner.driver import SpannerSyncDriver, spanner_statement_config
from sqlspec.config import SyncDatabaseConfig
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.events._hints import EventRuntimeHints
from sqlspec.utils.config_normalization import apply_pool_deprecations, normalize_connection_config
from sqlspec.utils.serializers import from_json, to_json
if TYPE_CHECKING:
from google.auth.credentials import Credentials
from google.cloud.spanner_v1.database import Database
from sqlspec.config import ExtensionConfigs
from sqlspec.core import StatementConfig
from sqlspec.observability import ObservabilityConfig
# Public names exported by this module.
__all__ = ("SpannerConnectionParams", "SpannerDriverFeatures", "SpannerPoolParams", "SpannerSyncConfig")
class SpannerConnectionParams(TypedDict):
    """Keys describing how to reach a Spanner database.

    Every key is optional at the type level.

    Attributes:
        project: Project handed to the Spanner ``Client``.
        instance_id: Identifier of the Spanner instance to open.
        database_id: Identifier of the database inside that instance.
        credentials: Credentials object handed to the Spanner ``Client``.
        client_options: Client options handed to the Spanner ``Client``.
        extra: Additional settings; not read by the code in this module
            (NOTE(review): presumably merged by the config normalizer - confirm).
    """

    project: "NotRequired[str]"
    instance_id: "NotRequired[str]"
    database_id: "NotRequired[str]"
    credentials: "NotRequired[Credentials]"
    client_options: "NotRequired[dict[str, Any]]"
    extra: "NotRequired[dict[str, Any]]"
class SpannerPoolParams(SpannerConnectionParams):
    """Session pool configuration.

    Extends the connection parameters with the keys consulted when the
    session pool is built. Every key is optional at the type level.

    Attributes:
        pool_type: Pool class to instantiate.
        size: Explicit pool size; takes precedence over ``max_sessions``
            when the pool is created.
        min_sessions: Lower bound on sessions (NOTE(review): not consulted
            by the pool factory in this module - confirm intended use).
        max_sessions: Used as the pool size when ``size`` is not given.
        labels: Labels forwarded to the pool constructor.
        ping_interval: Forwarded to the pool constructor for pool types
            other than ``FixedSizePool``.
    """

    pool_type: "NotRequired[type[AbstractSessionPool]]"
    # Read by the pool factory but previously missing from the declaration.
    size: "NotRequired[int]"
    min_sessions: "NotRequired[int]"
    max_sessions: "NotRequired[int]"
    labels: "NotRequired[dict[str, str]]"
    ping_interval: "NotRequired[int]"
class SpannerDriverFeatures(TypedDict):
    """Driver feature flags for Spanner.

    Every key is optional at the type level.

    Attributes:
        enable_uuid_conversion: Enable automatic UUID string conversion.
        json_serializer: Custom JSON serializer for parameter conversion.
        json_deserializer: Custom JSON deserializer for result conversion.
        session_labels: Labels to apply to Spanner sessions.
        enable_events: Enable database event channel support.
            Defaults to True when extension_config["events"] is configured.
        events_backend: Backend type for event handling.
            Spanner only supports "table_queue" (no native pub/sub).
    """

    enable_uuid_conversion: "NotRequired[bool]"
    json_serializer: "NotRequired[Callable[[Any], str]]"
    json_deserializer: "NotRequired[Callable[[str], Any]]"
    session_labels: "NotRequired[dict[str, str]]"
    # Documented above but previously not declared on the TypedDict.
    enable_events: "NotRequired[bool]"
    events_backend: "NotRequired[str]"
class SpannerSyncConfig(SyncDatabaseConfig["SpannerConnection", "AbstractSessionPool", SpannerSyncDriver]):
    """Spanner configuration and session management.

    Holds the connection settings, lazily builds the Spanner ``Client`` and
    ``Database`` objects, and hands out snapshot or transaction contexts
    wrapped in ``SpannerSyncDriver`` instances.
    """

    driver_type: ClassVar[type["SpannerSyncDriver"]] = SpannerSyncDriver
    connection_type: ClassVar[type["SpannerConnection"]] = cast("type[SpannerConnection]", SpannerConnection)
    supports_transactional_ddl: ClassVar[bool] = False
    supports_native_arrow_export: ClassVar[bool] = True
    supports_native_arrow_import: ClassVar[bool] = True
    supports_native_parquet_export: ClassVar[bool] = False
    supports_native_parquet_import: ClassVar[bool] = False
    requires_staging_for_load: ClassVar[bool] = False

    def __init__(
        self,
        *,
        connection_config: "SpannerPoolParams | dict[str, Any] | None" = None,
        connection_instance: "AbstractSessionPool | None" = None,
        migration_config: "dict[str, Any] | None" = None,
        statement_config: "StatementConfig | None" = None,
        driver_features: "SpannerDriverFeatures | dict[str, Any] | None" = None,
        bind_key: "str | None" = None,
        extension_config: "ExtensionConfigs | None" = None,
        observability_config: "ObservabilityConfig | None" = None,
        **kwargs: Any,
    ) -> None:
        """Build the configuration and fill in Spanner-specific defaults.

        Args:
            connection_config: Connection and pool settings. ``min_sessions``
                defaults to 1, ``max_sessions`` to 10 and ``pool_type`` to
                ``FixedSizePool`` when absent.
            connection_instance: An already constructed session pool, if any.
            migration_config: Forwarded unchanged to the base configuration.
            statement_config: Statement configuration; falls back to
                ``spanner_statement_config``.
            driver_features: Feature flags. ``enable_uuid_conversion`` defaults
                to True, ``json_serializer`` to ``to_json`` and
                ``json_deserializer`` to ``from_json``.
            bind_key: Forwarded unchanged to the base configuration.
            extension_config: Forwarded unchanged to the base configuration.
            observability_config: Forwarded unchanged to the base configuration.
            **kwargs: Passed through ``apply_pool_deprecations`` and then to
                the base configuration (presumably legacy keyword names -
                confirm against that helper).
        """
        connection_config, connection_instance = apply_pool_deprecations(
            kwargs=kwargs, connection_config=connection_config, connection_instance=connection_instance
        )

        self.connection_config = normalize_connection_config(connection_config)
        self.connection_config.setdefault("min_sessions", 1)
        self.connection_config.setdefault("max_sessions", 10)
        self.connection_config.setdefault("pool_type", FixedSizePool)

        # Copy so the caller's mapping is never mutated by the defaults below.
        features: dict[str, Any] = dict(driver_features) if driver_features else {}
        features.setdefault("enable_uuid_conversion", True)
        features.setdefault("json_serializer", to_json)
        features.setdefault("json_deserializer", from_json)

        base_statement_config = statement_config or spanner_statement_config

        super().__init__(
            connection_config=self.connection_config,
            connection_instance=connection_instance,
            migration_config=migration_config,
            statement_config=base_statement_config,
            driver_features=features,
            bind_key=bind_key,
            extension_config=extension_config,
            observability_config=observability_config,
            **kwargs,
        )

        # Both objects are created lazily on first use.
        self._client: "Client | None" = None
        self._database: "Database | None" = None

    def _require_database_ids(self) -> "tuple[str, str]":
        """Return ``(instance_id, database_id)`` from the connection config.

        Raises:
            ImproperConfigurationError: If either identifier is missing or empty.
        """
        instance_id = self.connection_config.get("instance_id")
        database_id = self.connection_config.get("database_id")
        if not instance_id or not database_id:
            msg = "instance_id and database_id are required."
            raise ImproperConfigurationError(msg)
        return instance_id, database_id

    @staticmethod
    def _transaction_is_open(txn: Any) -> bool:
        """Tell whether ``txn`` has begun and is neither committed nor rolled back."""
        if getattr(txn, "_transaction_id", None) is None:
            return False
        if getattr(txn, "committed", None) is not None:
            return False
        # Compare against True so objects lacking a real boolean flag count as open.
        return getattr(txn, "rolled_back", False) is not True

    def _get_client(self) -> Client:
        """Return the cached Spanner ``Client``, creating it on first use."""
        if self._client is None:
            self._client = Client(
                project=self.connection_config.get("project"),
                credentials=self.connection_config.get("credentials"),
                client_options=self.connection_config.get("client_options"),
            )
        return self._client

    def get_database(self) -> "Database":
        """Return the cached ``Database`` bound to the session pool.

        The pool is obtained through ``provide_pool()`` when none is set yet,
        and the database object is created once and then reused.

        Raises:
            ImproperConfigurationError: If ``instance_id`` or ``database_id``
                is missing from the connection config.
        """
        instance_id, database_id = self._require_database_ids()
        if self.connection_instance is None:
            self.connection_instance = self.provide_pool()
        if self._database is None:
            client = self._get_client()
            self._database = client.instance(instance_id).database(database_id, pool=self.connection_instance)  # type: ignore[no-untyped-call]
        return self._database

    def create_connection(self) -> SpannerConnection:
        """Return a new snapshot taken from the configured database.

        Raises:
            ImproperConfigurationError: If ``instance_id`` or ``database_id``
                is missing from the connection config.
        """
        # Reuse the cached database instead of building (and re-binding the
        # shared pool to) a fresh Database object on every call.
        database = self.get_database()
        return cast("SpannerConnection", cast("Any", database).snapshot())

    def _create_pool(self) -> AbstractSessionPool:
        """Instantiate the session pool described by the connection config.

        For ``FixedSizePool`` only ``size`` (or ``max_sessions``) and
        ``labels`` are forwarded. Other pool types receive the non-None values
        of ``size``, ``labels`` and ``ping_interval``, with ``max_sessions``
        standing in for a missing ``size``.

        Raises:
            ImproperConfigurationError: If ``instance_id`` or ``database_id``
                is missing from the connection config.
        """
        import inspect

        self._require_database_ids()

        config = self.connection_config
        pool_type = cast("type[AbstractSessionPool]", config.get("pool_type", FixedSizePool))

        pool_kwargs: dict[str, Any] = {}
        if pool_type is FixedSizePool:
            if "size" in config:
                pool_kwargs["size"] = config["size"]
            elif "max_sessions" in config:
                pool_kwargs["size"] = config["max_sessions"]
            if "labels" in config:
                pool_kwargs["labels"] = config["labels"]
        else:
            valid_pool_keys = {"size", "labels", "ping_interval"}
            pool_kwargs = {k: v for k, v in config.items() if k in valid_pool_keys and v is not None}
            if "size" not in pool_kwargs and "max_sessions" in config:
                pool_kwargs["size"] = config["max_sessions"]

            # Some pool classes name their capacity parameter ``target_size``
            # rather than ``size``; translate so the default ``max_sessions``
            # does not make construction fail with a TypeError.
            if "size" in pool_kwargs:
                try:
                    accepted = inspect.signature(pool_type).parameters
                except (TypeError, ValueError):
                    accepted = None
                if accepted is not None and "size" not in accepted and "target_size" in accepted:
                    pool_kwargs["target_size"] = pool_kwargs.pop("size")

        pool_factory = cast("Callable[..., AbstractSessionPool]", pool_type)
        return pool_factory(**pool_kwargs)

    def _close_pool(self) -> None:
        """Release the sessions held by the pool and forget the bound database.

        Calls ``close()`` when the pool offers it and otherwise ``clear()``.
        """
        pool = self.connection_instance
        if pool is not None:
            for method_name in ("close", "clear"):
                release = getattr(pool, method_name, None)
                if callable(release):
                    release()
                    break
        # The cached database is bound to the pool that was just released.
        self._database = None

    @contextmanager
    def provide_connection(
        self, *args: Any, transaction: "bool" = False, **kwargs: Any
    ) -> Generator[SpannerConnection, None, None]:
        """Yield a Snapshot (default) or Transaction context from the configured pool.

        Args:
            *args: Additional positional arguments (unused, for interface compatibility).
            transaction: If True, yields a Transaction context that supports
                execute_update() for DML statements. If False (default), yields
                a read-only Snapshot context for SELECT queries.
            **kwargs: Additional keyword arguments (unused, for interface compatibility).

        Yields:
            The snapshot or transaction object to run statements against.

        Note: For complex transactional logic with retries, use database.run_in_transaction()
        directly. The Transaction context here auto-commits on successful exit.
        """
        database = self.get_database()

        if not transaction:
            with cast("Any", database).snapshot(multi_use=True) as snapshot:
                yield cast("SpannerConnection", snapshot)
            return

        session = cast("Any", database).session()
        session.create()
        try:
            txn = session.transaction()
            # Entered manually: commit and rollback are decided below rather
            # than by the transaction's own context-manager exit.
            txn.__enter__()
            try:
                yield cast("SpannerConnection", txn)
                # Skip the commit when the driver already committed or rolled
                # back, or when the transaction never actually began.
                if self._transaction_is_open(txn):
                    txn.commit()
            except Exception:
                # Only roll back a transaction that is still open, so a
                # failing rollback cannot mask the original error.
                if self._transaction_is_open(txn):
                    txn.rollback()
                raise
        finally:
            session.delete()

    @contextmanager
    def provide_session(
        self, *args: Any, statement_config: "StatementConfig | None" = None, transaction: "bool" = False, **kwargs: Any
    ) -> Generator[SpannerSyncDriver, None, None]:
        """Yield a driver wrapping a connection from ``provide_connection``.

        Args:
            *args: Forwarded to ``provide_connection``.
            statement_config: Overrides the configuration's statement config
                for this driver when given.
            transaction: Forwarded to ``provide_connection``; True requests a
                Transaction instead of a Snapshot.
            **kwargs: Forwarded to ``provide_connection``.

        Yields:
            The driver after it has been passed through ``_prepare_driver``.
        """
        with self.provide_connection(*args, transaction=transaction, **kwargs) as connection:
            driver = self.driver_type(
                connection=connection,
                statement_config=statement_config or self.statement_config,
                driver_features=self.driver_features,
            )
            yield self._prepare_driver(driver)

    @contextmanager
    def provide_write_session(
        self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any
    ) -> Generator[SpannerSyncDriver, None, None]:
        """Yield a driver backed by a Transaction (``provide_session`` with ``transaction=True``)."""
        with self.provide_session(*args, statement_config=statement_config, transaction=True, **kwargs) as driver:
            yield driver

    def get_signature_namespace(self) -> dict[str, Any]:
        """Return the base signature namespace extended with this adapter's types."""
        namespace = super().get_signature_namespace()
        namespace.update({
            "SpannerSyncConfig": SpannerSyncConfig,
            "SpannerConnection": SpannerConnection,
            "SpannerConnectionParams": SpannerConnectionParams,
            "SpannerPoolParams": SpannerPoolParams,
            "SpannerDriverFeatures": SpannerDriverFeatures,
            "SpannerSyncDriver": SpannerSyncDriver,
        })
        return namespace

    def get_event_runtime_hints(self) -> "EventRuntimeHints":
        """Return queue defaults for Spanner JSON handling."""
        return EventRuntimeHints(json_passthrough=True)