Source code for sqlspec.adapters.cockroach_psycopg.driver

"""CockroachDB psycopg driver implementation."""

import asyncio
import contextlib
import time
from typing import TYPE_CHECKING, Any, cast

import psycopg
from typing_extensions import Self

from sqlspec.adapters.cockroach_psycopg._typing import (
    CockroachAsyncConnection,
    CockroachPsycopgAsyncSessionContext,
    CockroachPsycopgSyncSessionContext,
    CockroachSyncConnection,
)
from sqlspec.adapters.cockroach_psycopg.core import (
    CockroachPsycopgRetryConfig,
    apply_driver_features,
    build_statement_config,
    calculate_backoff_seconds,
    driver_profile,
    is_retryable_error,
)
from sqlspec.adapters.cockroach_psycopg.data_dictionary import (
    CockroachPsycopgAsyncDataDictionary,
    CockroachPsycopgSyncDataDictionary,
)
from sqlspec.adapters.psycopg.core import create_mapped_exception
from sqlspec.adapters.psycopg.driver import PsycopgAsyncDriver, PsycopgSyncDriver
from sqlspec.core import SQL, StatementConfig, get_cache_config, register_driver_profile
from sqlspec.exceptions import SerializationConflictError, TransactionRetryError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.type_guards import has_sqlstate

if TYPE_CHECKING:
    from collections.abc import Callable

    from sqlspec.driver import ExecutionResult

# Public API of this module. The two session context types are imported from
# ``_typing`` above and re-exported here alongside the drivers and handlers.
__all__ = (
    "CockroachPsycopgAsyncDriver",
    "CockroachPsycopgAsyncExceptionHandler",
    "CockroachPsycopgAsyncSessionContext",
    "CockroachPsycopgSyncDriver",
    "CockroachPsycopgSyncExceptionHandler",
    "CockroachPsycopgSyncSessionContext",
)

# Module-level logger; the drivers below use it to report retry attempts.
logger = get_logger("sqlspec.adapters.cockroach_psycopg")


class CockroachPsycopgSyncExceptionHandler:
    """Context manager for handling CockroachDB psycopg exceptions."""

    __slots__ = ("pending_exception",)

    def __init__(self) -> None:
        self.pending_exception: Exception | None = None

    def __enter__(self) -> Self:
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        if exc_type is None:
            return False
        if issubclass(exc_type, psycopg.Error):
            if has_sqlstate(exc_val) and str(exc_val.sqlstate) == "40001":
                self.pending_exception = SerializationConflictError(str(exc_val))
                return True
            self.pending_exception = create_mapped_exception(exc_val)
            return True
        return False


class CockroachPsycopgAsyncExceptionHandler:
    """Async context manager for handling CockroachDB psycopg exceptions."""

    __slots__ = ("pending_exception",)

    def __init__(self) -> None:
        self.pending_exception: Exception | None = None

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        if exc_type is None:
            return False
        if issubclass(exc_type, psycopg.Error):
            if has_sqlstate(exc_val) and str(exc_val.sqlstate) == "40001":
                self.pending_exception = SerializationConflictError(str(exc_val))
                return True
            self.pending_exception = create_mapped_exception(exc_val)
            return True
        return False


class CockroachPsycopgSyncDriver(PsycopgSyncDriver):
    """CockroachDB sync driver using psycopg.crdb.

    Builds on the psycopg sync driver and adds two behaviours driven by
    ``driver_features``: automatic retry of retryable errors
    (``enable_auto_retry``) and an optional ``AS OF SYSTEM TIME`` prelude for
    row-returning statements (``enable_follower_reads`` / ``default_staleness``).
    """

    __slots__ = ("_enable_retry", "_follower_staleness", "_retry_config")
    dialect = "postgres"

    def __init__(
        self,
        connection: CockroachSyncConnection,
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Initialise the driver.

        Args:
            connection: CockroachDB psycopg connection to execute on.
            statement_config: Statement configuration; when omitted, one is
                built with caching following the global cache configuration.
            driver_features: Optional feature flags; passed through
                ``apply_driver_features`` before reaching the parent driver.
        """
        if statement_config is None:
            default_config = build_statement_config()
            statement_config = default_config.replace(enable_caching=get_cache_config().compiled_cache_enabled)

        statement_config, features = apply_driver_features(statement_config, driver_features)
        super().__init__(connection=connection, statement_config=statement_config, driver_features=features)

        resolved_features = self.driver_features
        self._retry_config = CockroachPsycopgRetryConfig.from_features(resolved_features)
        # Auto-retry is on unless explicitly disabled.
        self._enable_retry = bool(resolved_features.get("enable_auto_retry", True))
        self._follower_staleness = cast("str | None", resolved_features.get("default_staleness"))
        # Data dictionary is created lazily by the property below; it lives in the parent's slot.
        self._data_dictionary = None

    def _execute_with_retry(self, operation: "Callable[..., ExecutionResult]", *args: Any) -> "ExecutionResult":
        """Run ``operation(*args)``, retrying retryable failures with backoff.

        Non-retryable errors, and the error from the final permitted attempt,
        are re-raised as they are. Between attempts the connection is rolled
        back (best effort) and the call sleeps for the computed backoff.

        Raises:
            TransactionRetryError: Only if the loop finishes without an attempt
                returning or raising, i.e. when no attempt was made at all.
        """
        if not self._enable_retry:
            return operation(*args)

        retry_config = self._retry_config
        last_error: Exception | None = None
        for attempt in range(retry_config.max_retries + 1):
            try:
                return operation(*args)
            except Exception as exc:
                last_error = exc
                retryable = is_retryable_error(exc)
                if not retryable or attempt >= retry_config.max_retries:
                    raise

                # A failed rollback must not mask the retry itself.
                with contextlib.suppress(Exception):
                    self.connection.rollback()

                delay = calculate_backoff_seconds(attempt, retry_config)
                if retry_config.enable_logging:
                    logger.debug(
                        "CockroachDB retry %s/%s after %.3fs", attempt + 1, retry_config.max_retries, delay
                    )
                time.sleep(delay)

        msg = "CockroachDB transaction retry limit exceeded"
        raise TransactionRetryError(msg) from last_error

    def _apply_follower_reads(self, cursor: Any) -> None:
        """Issue ``SET TRANSACTION AS OF SYSTEM TIME`` when follower reads are configured.

        Does nothing unless ``enable_follower_reads`` is truthy and a default
        staleness value was supplied. The staleness text is interpolated into
        the statement verbatim.
        """
        follower_reads_enabled = self.driver_features.get("enable_follower_reads", False)
        if follower_reads_enabled and self._follower_staleness:
            cursor.execute(f"SET TRANSACTION AS OF SYSTEM TIME {self._follower_staleness}")

    def _dispatch_execute_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult":
        """Single-statement execution; applies follower reads first for row-returning statements."""
        if statement.returns_rows():
            self._apply_follower_reads(cursor)
        return super().dispatch_execute(cursor, statement)

    def _dispatch_execute_many_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult":
        """Delegate to the parent's ``dispatch_execute_many`` (used as the retry target)."""
        return super().dispatch_execute_many(cursor, statement)

    def _dispatch_execute_script_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult":
        """Delegate to the parent's ``dispatch_execute_script`` (used as the retry target)."""
        return super().dispatch_execute_script(cursor, statement)

    def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult":
        """Execute a single statement, through the retry loop when auto-retry is enabled."""
        if self._enable_retry:
            return self._execute_with_retry(self._dispatch_execute_impl, cursor, statement)
        return self._dispatch_execute_impl(cursor, statement)

    def dispatch_execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult":
        """Execute a batch statement, through the retry loop when auto-retry is enabled."""
        if self._enable_retry:
            return self._execute_with_retry(self._dispatch_execute_many_impl, cursor, statement)
        return super().dispatch_execute_many(cursor, statement)

    def dispatch_execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult":
        """Execute a script, through the retry loop when auto-retry is enabled."""
        if self._enable_retry:
            return self._execute_with_retry(self._dispatch_execute_script_impl, cursor, statement)
        return super().dispatch_execute_script(cursor, statement)

    def handle_database_exceptions(self) -> "CockroachPsycopgSyncExceptionHandler":  # type: ignore[override]
        """Return a fresh CockroachDB-aware exception handler."""
        return CockroachPsycopgSyncExceptionHandler()

    @property
    def data_dictionary(self) -> "CockroachPsycopgSyncDataDictionary":  # type: ignore[override]
        """CockroachDB data dictionary, created on first access and then reused."""
        if self._data_dictionary is None:
            # Intentionally assign CockroachDB-specific data dictionary to parent slot
            object.__setattr__(self, "_data_dictionary", CockroachPsycopgSyncDataDictionary())
        return cast("CockroachPsycopgSyncDataDictionary", self._data_dictionary)
class CockroachPsycopgAsyncDriver(PsycopgAsyncDriver):
    """CockroachDB async driver using psycopg.crdb.

    Builds on the psycopg async driver and adds two behaviours driven by
    ``driver_features``: automatic retry of retryable errors
    (``enable_auto_retry``) and an optional ``AS OF SYSTEM TIME`` prelude for
    row-returning statements (``enable_follower_reads`` / ``default_staleness``).
    """

    __slots__ = ("_enable_retry", "_follower_staleness", "_retry_config")
    dialect = "postgres"

    def __init__(
        self,
        connection: CockroachAsyncConnection,
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Initialise the driver.

        Args:
            connection: CockroachDB psycopg async connection to execute on.
            statement_config: Statement configuration; when omitted, one is
                built with caching following the global cache configuration.
            driver_features: Optional feature flags; passed through
                ``apply_driver_features`` before reaching the parent driver.
        """
        if statement_config is None:
            default_config = build_statement_config()
            statement_config = default_config.replace(enable_caching=get_cache_config().compiled_cache_enabled)

        statement_config, features = apply_driver_features(statement_config, driver_features)
        super().__init__(connection=connection, statement_config=statement_config, driver_features=features)

        resolved_features = self.driver_features
        self._retry_config = CockroachPsycopgRetryConfig.from_features(resolved_features)
        # Auto-retry is on unless explicitly disabled.
        self._enable_retry = bool(resolved_features.get("enable_auto_retry", True))
        self._follower_staleness = cast("str | None", resolved_features.get("default_staleness"))
        # Data dictionary is created lazily by the property below; it lives in the parent's slot.
        self._data_dictionary = None

    async def _execute_with_retry(self, operation: "Callable[..., Any]", *args: Any) -> "ExecutionResult":
        """Await ``operation(*args)``, retrying retryable failures with backoff.

        Non-retryable errors, and the error from the final permitted attempt,
        are re-raised as they are. Between attempts the connection is rolled
        back (best effort) and the coroutine sleeps for the computed backoff.

        Raises:
            TransactionRetryError: Only if the loop finishes without an attempt
                returning or raising, i.e. when no attempt was made at all.
        """
        if not self._enable_retry:
            return cast("ExecutionResult", await operation(*args))

        retry_config = self._retry_config
        last_error: Exception | None = None
        for attempt in range(retry_config.max_retries + 1):
            try:
                return cast("ExecutionResult", await operation(*args))
            except Exception as exc:
                last_error = exc
                retryable = is_retryable_error(exc)
                if not retryable or attempt >= retry_config.max_retries:
                    raise

                # A failed rollback must not mask the retry itself.
                with contextlib.suppress(Exception):
                    await self.connection.rollback()

                delay = calculate_backoff_seconds(attempt, retry_config)
                if retry_config.enable_logging:
                    logger.debug(
                        "CockroachDB retry %s/%s after %.3fs", attempt + 1, retry_config.max_retries, delay
                    )
                await asyncio.sleep(delay)

        msg = "CockroachDB transaction retry limit exceeded"
        raise TransactionRetryError(msg) from last_error

    async def _apply_follower_reads(self, cursor: Any) -> None:
        """Issue ``SET TRANSACTION AS OF SYSTEM TIME`` when follower reads are configured.

        Does nothing unless ``enable_follower_reads`` is truthy and a default
        staleness value was supplied. The staleness text is interpolated into
        the statement verbatim.
        """
        follower_reads_enabled = self.driver_features.get("enable_follower_reads", False)
        if follower_reads_enabled and self._follower_staleness:
            await cursor.execute(f"SET TRANSACTION AS OF SYSTEM TIME {self._follower_staleness}")

    async def _dispatch_execute_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult":
        """Single-statement execution; applies follower reads first for row-returning statements."""
        if statement.returns_rows():
            await self._apply_follower_reads(cursor)
        return await super().dispatch_execute(cursor, statement)

    async def _dispatch_execute_many_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult":
        """Delegate to the parent's ``dispatch_execute_many`` (used as the retry target)."""
        return await super().dispatch_execute_many(cursor, statement)

    async def _dispatch_execute_script_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult":
        """Delegate to the parent's ``dispatch_execute_script`` (used as the retry target)."""
        return await super().dispatch_execute_script(cursor, statement)

    async def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult":
        """Execute a single statement, through the retry loop when auto-retry is enabled."""
        if self._enable_retry:
            return await self._execute_with_retry(self._dispatch_execute_impl, cursor, statement)
        return await self._dispatch_execute_impl(cursor, statement)

    async def dispatch_execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult":
        """Execute a batch statement, through the retry loop when auto-retry is enabled."""
        if self._enable_retry:
            return await self._execute_with_retry(self._dispatch_execute_many_impl, cursor, statement)
        return await super().dispatch_execute_many(cursor, statement)

    async def dispatch_execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult":
        """Execute a script, through the retry loop when auto-retry is enabled."""
        if self._enable_retry:
            return await self._execute_with_retry(self._dispatch_execute_script_impl, cursor, statement)
        return await super().dispatch_execute_script(cursor, statement)

    def handle_database_exceptions(self) -> "CockroachPsycopgAsyncExceptionHandler":  # type: ignore[override]
        """Return a fresh CockroachDB-aware async exception handler."""
        return CockroachPsycopgAsyncExceptionHandler()

    @property
    def data_dictionary(self) -> "CockroachPsycopgAsyncDataDictionary":  # type: ignore[override]
        """CockroachDB data dictionary, created on first access and then reused."""
        if self._data_dictionary is None:
            # Intentionally assign CockroachDB-specific data dictionary to parent slot
            object.__setattr__(self, "_data_dictionary", CockroachPsycopgAsyncDataDictionary())
        return cast("CockroachPsycopgAsyncDataDictionary", self._data_dictionary)
# Import-time side effect: register this adapter's driver profile under the
# "cockroach_psycopg" key.
register_driver_profile("cockroach_psycopg", driver_profile)