# Source code for sqlspec.adapters.psycopg.driver

"""PostgreSQL psycopg driver implementation."""

from collections.abc import Sized
from contextlib import AsyncExitStack, ExitStack
from typing import TYPE_CHECKING, Any, cast

import psycopg
from typing_extensions import Self

from sqlspec.adapters.psycopg._typing import (
    PsycopgAsyncConnection,
    PsycopgAsyncSessionContext,
    PsycopgSyncConnection,
    PsycopgSyncSessionContext,
)
from sqlspec.adapters.psycopg.core import (
    TRANSACTION_STATUS_IDLE,
    PipelineCursorEntry,
    PreparedStackOperation,
    build_async_pipeline_execution_result,
    build_copy_from_command,
    build_pipeline_execution_result,
    build_truncate_command,
    create_mapped_exception,
    default_statement_config,
    driver_profile,
    execute_with_optional_parameters,
    execute_with_optional_parameters_async,
    pipeline_supported,
    resolve_many_rowcount,
    resolve_rowcount,
)
from sqlspec.adapters.psycopg.data_dictionary import PsycopgAsyncDataDictionary, PsycopgSyncDataDictionary
from sqlspec.core import (
    SQL,
    SQLResult,
    StackResult,
    StatementConfig,
    StatementStack,
    get_cache_config,
    is_copy_from_operation,
    is_copy_operation,
    is_copy_to_operation,
    register_driver_profile,
)
from sqlspec.driver import (
    AsyncDriverAdapterBase,
    StackExecutionObserver,
    SyncDriverAdapterBase,
    describe_stack_statement,
)
from sqlspec.exceptions import SQLSpecError, StackExecutionError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.type_guards import is_readable

if TYPE_CHECKING:
    from sqlspec.adapters.psycopg._typing import PsycopgPipelineDriver
    from sqlspec.core import ArrowResult
    from sqlspec.driver import ExecutionResult
    from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry


# Public names exported by this module.
__all__ = (
    "PsycopgAsyncCursor",
    "PsycopgAsyncDriver",
    "PsycopgAsyncExceptionHandler",
    "PsycopgAsyncSessionContext",
    "PsycopgSyncCursor",
    "PsycopgSyncDriver",
    "PsycopgSyncExceptionHandler",
    "PsycopgSyncSessionContext",
)

# Module-level logger for adapter diagnostics.
logger = get_logger("sqlspec.adapters.psycopg")
# Upper bound on cached cursor-description -> column-name entries per driver instance.
COLUMN_CACHE_MAX_SIZE = 256


class PsycopgPipelineMixin:
    """Shared helpers for psycopg sync/async pipeline execution."""

    __slots__ = ()

    def _prepare_pipeline_operations(self, stack: "StatementStack") -> "list[PreparedStackOperation] | None":
        """Compile every operation of *stack* for pipeline execution.

        Returns None when any operation is not pipeline-eligible (a non-execute
        method, a script, or an execute-many statement), signalling the caller
        to fall back to sequential execution.
        """
        driver = cast("PsycopgPipelineDriver", self)
        compiled: list[PreparedStackOperation] = []
        for position, op in enumerate(stack.operations):
            if op.method != "execute":
                return None

            keyword_args = dict(op.keyword_arguments) if op.keyword_arguments else {}
            override_config = keyword_args.pop("statement_config", None)
            active_config = override_config or driver.statement_config

            statement = driver.prepare_statement(
                op.statement, op.arguments, statement_config=active_config, kwargs=keyword_args
            )
            if statement.is_script or statement.is_many:
                return None

            compiled_sql, compiled_params = driver._get_compiled_sql(  # pyright: ignore[reportPrivateUsage]
                statement, active_config
            )
            compiled.append(
                PreparedStackOperation(
                    operation_index=position,
                    operation=op,
                    statement=statement,
                    sql=compiled_sql,
                    parameters=compiled_params,
                )
            )
        return compiled


class PsycopgSyncCursor:
    """Context manager that opens a psycopg cursor on enter and closes it on exit."""

    __slots__ = ("connection", "cursor")

    def __init__(self, connection: PsycopgSyncConnection) -> None:
        self.connection = connection
        # Populated by __enter__; None until then.
        self.cursor: Any | None = None

    def __enter__(self) -> Any:
        cursor = self.connection.cursor()
        self.cursor = cursor
        return cursor

    def __exit__(self, *_: Any) -> None:
        cursor = self.cursor
        if cursor is not None:
            cursor.close()

class PsycopgSyncExceptionHandler:
    """Context manager for handling PostgreSQL psycopg database exceptions.

    Maps PostgreSQL SQLSTATE error codes to specific SQLSpec exceptions
    for better error handling in application code.

    Uses deferred exception pattern for mypyc compatibility: exceptions
    are stored in pending_exception rather than raised from __exit__
    to avoid ABI boundary violations with compiled code.
    """

    __slots__ = ("pending_exception",)

    def __init__(self) -> None:
        self.pending_exception: Exception | None = None

    def __enter__(self) -> Self:
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        if exc_type is None:
            return False
        if issubclass(exc_type, psycopg.Error):
            self.pending_exception = create_mapped_exception(exc_val)
            return True
        return False


class PsycopgSyncDriver(PsycopgPipelineMixin, SyncDriverAdapterBase):
    """PostgreSQL psycopg synchronous driver.

    Provides synchronous database operations for PostgreSQL using psycopg3:
    SQL execution with parameter binding, transaction management, result
    processing with column metadata, parameter style conversion, PostgreSQL
    array/JSON handling, COPY bulk transfer, and mapping of PostgreSQL errors
    to SQLSpec exceptions.
    """

    __slots__ = ("_column_name_cache", "_data_dictionary")

    dialect = "postgres"
[docs] def __init__( self, connection: PsycopgSyncConnection, statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, ) -> None: if statement_config is None: statement_config = default_statement_config.replace( enable_caching=get_cache_config().compiled_cache_enabled ) super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) self._data_dictionary: PsycopgSyncDataDictionary | None = None self._column_name_cache: dict[int, tuple[Any, list[str]]] = {}
# ───────────────────────────────────────────────────────────────────────────── # CORE DISPATCH METHODS # ─────────────────────────────────────────────────────────────────────────────
[docs] def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute single SQL statement. Args: cursor: Database cursor statement: SQL statement to execute Returns: ExecutionResult with statement execution details """ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) execute_with_optional_parameters(cursor, sql, prepared_parameters) if statement.returns_rows(): fetched_data = cursor.fetchall() data = cast("list[Any] | None", fetched_data) or [] column_names = self._resolve_column_names(cursor.description) return self.create_execution_result( cursor, selected_data=data, column_names=column_names, data_row_count=len(data), is_select_result=True, row_format="tuple", ) affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows)
[docs] def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL with multiple parameter sets. Args: cursor: Database cursor statement: SQL statement with parameter list Returns: ExecutionResult with batch execution details """ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) if not prepared_parameters: return self.create_execution_result(cursor, rowcount_override=0, is_many_result=True) parameter_count = len(prepared_parameters) if isinstance(prepared_parameters, Sized) else None cursor.executemany(sql, prepared_parameters) affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count) return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True)
[docs] def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL script with multiple statements. Args: cursor: Database cursor statement: SQL statement containing multiple commands Returns: ExecutionResult with script execution details """ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) successful_count = 0 last_cursor = cursor for stmt in statements: execute_with_optional_parameters(cursor, stmt, prepared_parameters) successful_count += 1 return self.create_execution_result( last_cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True )
[docs] def dispatch_special_handling(self, cursor: Any, statement: "SQL") -> "SQLResult | None": """Hook for PostgreSQL-specific special operations. Args: cursor: Psycopg cursor object statement: SQL statement to analyze Returns: SQLResult if special handling was applied, None otherwise """ if not is_copy_operation(statement.operation_type): return None sql, _ = self._get_compiled_sql(statement, statement.statement_config) operation_type = statement.operation_type copy_data = statement.parameters if isinstance(copy_data, list) and len(copy_data) == 1: copy_data = copy_data[0] if is_copy_from_operation(operation_type): if isinstance(copy_data, (str, bytes)): data_to_write = copy_data elif is_readable(copy_data): data_to_write = copy_data.read() else: data_to_write = str(copy_data) if isinstance(data_to_write, str): data_to_write = data_to_write.encode() with cursor.copy(sql) as copy_ctx: copy_ctx.write(data_to_write) rows_affected = max(cursor.rowcount, 0) return SQLResult( data=None, rows_affected=rows_affected, statement=statement, metadata={"copy_operation": "FROM_STDIN"} ) if is_copy_to_operation(operation_type): output_data: list[str] = [] with cursor.copy(sql) as copy_ctx: output_data.extend(row.decode() if isinstance(row, bytes) else str(row) for row in copy_ctx) exported_data = "".join(output_data) return SQLResult( data=[{"copy_output": exported_data}], rows_affected=0, statement=statement, metadata={"copy_operation": "TO_STDOUT"}, ) cursor.execute(sql) rows_affected = max(cursor.rowcount, 0) return SQLResult( data=None, rows_affected=rows_affected, statement=statement, metadata={"copy_operation": "FILE"} )
# ───────────────────────────────────────────────────────────────────────────── # TRANSACTION MANAGEMENT # ─────────────────────────────────────────────────────────────────────────────
[docs] def begin(self) -> None: """Begin a database transaction on the current connection.""" try: if self.connection.autocommit: self.connection.autocommit = False except Exception as e: msg = f"Failed to begin transaction: {e}" raise SQLSpecError(msg) from e
[docs] def commit(self) -> None: """Commit the current transaction on the current connection.""" try: self.connection.commit() except Exception as e: msg = f"Failed to commit transaction: {e}" raise SQLSpecError(msg) from e
[docs] def rollback(self) -> None: """Rollback the current transaction on the current connection.""" try: self.connection.rollback() except Exception as e: msg = f"Failed to rollback transaction: {e}" raise SQLSpecError(msg) from e
[docs] def with_cursor(self, connection: PsycopgSyncConnection) -> PsycopgSyncCursor: """Create context manager for PostgreSQL cursor.""" return PsycopgSyncCursor(connection)
[docs] def handle_database_exceptions(self) -> "PsycopgSyncExceptionHandler": """Handle database-specific exceptions and wrap them appropriately.""" return PsycopgSyncExceptionHandler()
# ───────────────────────────────────────────────────────────────────────────── # STACK EXECUTION METHODS # ─────────────────────────────────────────────────────────────────────────────
[docs] def execute_stack(self, stack: "StatementStack", *, continue_on_error: bool = False) -> "tuple[StackResult, ...]": """Execute a StatementStack using psycopg pipeline mode when supported.""" if ( not isinstance(stack, StatementStack) or not stack or self.stack_native_disabled or not pipeline_supported() or continue_on_error ): return super().execute_stack(stack, continue_on_error=continue_on_error) prepared_ops = self._prepare_pipeline_operations(stack) if prepared_ops is None: return super().execute_stack(stack, continue_on_error=continue_on_error) return self._execute_stack_pipeline(stack, prepared_ops)
def _execute_stack_pipeline( self, stack: "StatementStack", prepared_ops: "list[PreparedStackOperation]" ) -> "tuple[StackResult, ...]": def _raise_pending_exception(exception_ctx: "PsycopgSyncExceptionHandler") -> None: if exception_ctx.pending_exception is not None: raise exception_ctx.pending_exception from None results: list[StackResult] = [] started_transaction = False with StackExecutionObserver(self, stack, continue_on_error=False, native_pipeline=True): try: if not self._connection_in_transaction(): self.begin() started_transaction = True exception_handlers = [] with ExitStack() as resource_stack: pipeline = resource_stack.enter_context(self.connection.pipeline()) pending: list[PipelineCursorEntry] = [] for prepared in prepared_ops: exception_ctx = self.handle_database_exceptions() exception_handlers.append(exception_ctx) resource_stack.enter_context(exception_ctx) cursor = resource_stack.enter_context(self.with_cursor(self.connection)) try: if prepared.parameters: cursor.execute(prepared.sql, prepared.parameters) else: cursor.execute(prepared.sql) except Exception as exc: stack_error = StackExecutionError( prepared.operation_index, describe_stack_statement(prepared.operation.statement), exc, adapter=type(self).__name__, mode="fail-fast", ) raise stack_error from exc pending.append(PipelineCursorEntry(prepared=prepared, cursor=cursor)) pipeline.sync() for entry in pending: statement = entry.prepared.statement cursor = entry.cursor execution_result = build_pipeline_execution_result( statement, cursor, column_name_resolver=self._resolve_column_names ) sql_result = self.build_statement_result(statement, execution_result) results.append(StackResult.from_sql_result(sql_result)) for exception_ctx in exception_handlers: _raise_pending_exception(exception_ctx) if started_transaction: self.commit() except Exception: if started_transaction: try: self.rollback() except Exception as rollback_error: # pragma: no cover - diagnostics only logger.debug("Rollback after 
psycopg pipeline failure failed: %s", rollback_error) raise return tuple(results) # ───────────────────────────────────────────────────────────────────────────── # STORAGE API METHODS # ─────────────────────────────────────────────────────────────────────────────
[docs] def select_to_storage( self, statement: "SQL | str", destination: "StorageDestination", /, *parameters: Any, statement_config: "StatementConfig | None" = None, partitioner: "dict[str, object] | None" = None, format_hint: "StorageFormat | None" = None, telemetry: "StorageTelemetry | None" = None, **kwargs: Any, ) -> "StorageBridgeJob": """Execute a query and stream Arrow results to storage (sync).""" self._require_capability("arrow_export_enabled") arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) sync_pipeline = self._storage_pipeline() telemetry_payload = self._write_result_to_storage_sync( arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry)
[docs] def load_from_arrow( self, table: str, source: "ArrowResult | Any", *, partitioner: "dict[str, object] | None" = None, overwrite: bool = False, telemetry: "StorageTelemetry | None" = None, ) -> "StorageBridgeJob": """Load Arrow data into PostgreSQL using COPY.""" self._require_capability("arrow_import_enabled") arrow_table = self._coerce_arrow_table(source) if overwrite: truncate_sql = build_truncate_command(table) exc_handler = self.handle_database_exceptions() with self.with_cursor(self.connection) as cursor, exc_handler: cursor.execute(truncate_sql) if exc_handler.pending_exception is not None: raise exc_handler.pending_exception from None columns, records = self._arrow_table_to_rows(arrow_table) if records: copy_sql = build_copy_from_command(table, columns) exc_handler = self.handle_database_exceptions() with ExitStack() as stack: stack.enter_context(exc_handler) cursor = stack.enter_context(self.with_cursor(self.connection)) copy_ctx = stack.enter_context(cursor.copy(copy_sql)) for record in records: copy_ctx.write_row(record) if exc_handler.pending_exception is not None: raise exc_handler.pending_exception from None telemetry_payload = self._build_ingest_telemetry(arrow_table) telemetry_payload["destination"] = table self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry)
[docs] def load_from_storage( self, table: str, source: "StorageDestination", *, file_format: "StorageFormat", partitioner: "dict[str, object] | None" = None, overwrite: bool = False, ) -> "StorageBridgeJob": """Load staged artifacts into PostgreSQL via COPY.""" arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format) return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound)
# ───────────────────────────────────────────────────────────────────────────── # UTILITY METHODS # ───────────────────────────────────────────────────────────────────────────── @property def data_dictionary(self) -> "PsycopgSyncDataDictionary": """Get the data dictionary for this driver. Returns: Data dictionary instance for metadata queries """ if self._data_dictionary is None: self._data_dictionary = PsycopgSyncDataDictionary() return self._data_dictionary # ───────────────────────────────────────────────────────────────────────────── # PRIVATE / INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── def _resolve_column_names(self, description: Any) -> list[str]: """Resolve and cache psycopg column names for hot row materialization paths.""" if not description: return [] cache_key = id(description) cached = self._column_name_cache.get(cache_key) if cached is not None and cached[0] is description: return cached[1] column_names = [col.name for col in description] if len(self._column_name_cache) >= COLUMN_CACHE_MAX_SIZE: self._column_name_cache.pop(next(iter(self._column_name_cache))) self._column_name_cache[cache_key] = (description, column_names) return column_names
[docs] def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect psycopg sync rows for the direct execution path.""" data = cast("list[Any] | None", fetched) or [] column_names = self._resolve_column_names(cursor.description) return data, column_names, len(data)
[docs] def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from psycopg cursor for the direct execution path.""" return resolve_rowcount(cursor)
def _connection_in_transaction(self) -> bool: """Check if connection is in transaction.""" return bool(self.connection.info.transaction_status != TRANSACTION_STATUS_IDLE)
class PsycopgAsyncCursor: """Async context manager for PostgreSQL psycopg cursor management.""" __slots__ = ("connection", "cursor") def __init__(self, connection: "PsycopgAsyncConnection") -> None: self.connection = connection self.cursor: Any | None = None async def __aenter__(self) -> Any: self.cursor = self.connection.cursor() return self.cursor async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: _ = (exc_type, exc_val, exc_tb) if self.cursor is not None: await self.cursor.close() class PsycopgAsyncExceptionHandler: """Async context manager for handling PostgreSQL psycopg database exceptions. Maps PostgreSQL SQLSTATE error codes to specific SQLSpec exceptions for better error handling in application code. Uses deferred exception pattern for mypyc compatibility: exceptions are stored in pending_exception rather than raised from __aexit__ to avoid ABI boundary violations with compiled code. """ __slots__ = ("pending_exception",) def __init__(self) -> None: self.pending_exception: Exception | None = None async def __aenter__(self) -> Self: return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: if exc_type is None: return False if issubclass(exc_type, psycopg.Error): self.pending_exception = create_mapped_exception(exc_val) return True return False
class PsycopgAsyncDriver(PsycopgPipelineMixin, AsyncDriverAdapterBase):
    """PostgreSQL psycopg asynchronous driver.

    Provides asynchronous database operations for PostgreSQL using psycopg3:
    async SQL execution with parameter binding, async transaction management,
    async result processing with column metadata, parameter style conversion,
    PostgreSQL array/JSON handling, COPY bulk transfer, mapping of PostgreSQL
    errors to SQLSpec exceptions, and async pub/sub support.
    """

    __slots__ = ("_column_name_cache", "_data_dictionary")

    dialect = "postgres"
[docs] def __init__( self, connection: "PsycopgAsyncConnection", statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, ) -> None: if statement_config is None: statement_config = default_statement_config.replace( enable_caching=get_cache_config().compiled_cache_enabled ) super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) self._data_dictionary: PsycopgAsyncDataDictionary | None = None self._column_name_cache: dict[int, tuple[Any, list[str]]] = {}
# ───────────────────────────────────────────────────────────────────────────── # CORE DISPATCH METHODS # ─────────────────────────────────────────────────────────────────────────────
[docs] async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute single SQL statement (async). Args: cursor: Database cursor statement: SQL statement to execute Returns: ExecutionResult with statement execution details """ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) await execute_with_optional_parameters_async(cursor, sql, prepared_parameters) if statement.returns_rows(): fetched_data = await cursor.fetchall() data = cast("list[Any] | None", fetched_data) or [] column_names = self._resolve_column_names(cursor.description) return self.create_execution_result( cursor, selected_data=data, column_names=column_names, data_row_count=len(data), is_select_result=True, row_format="tuple", ) affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows)
[docs] async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL with multiple parameter sets (async). Args: cursor: Database cursor statement: SQL statement with parameter list Returns: ExecutionResult with batch execution details """ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) if not prepared_parameters: return self.create_execution_result(cursor, rowcount_override=0, is_many_result=True) parameter_count = len(prepared_parameters) if isinstance(prepared_parameters, Sized) else None await cursor.executemany(sql, prepared_parameters) affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count) return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True)
[docs] async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL script with multiple statements (async). Args: cursor: Database cursor statement: SQL statement containing multiple commands Returns: ExecutionResult with script execution details """ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) successful_count = 0 last_cursor = cursor for stmt in statements: await execute_with_optional_parameters_async(cursor, stmt, prepared_parameters) successful_count += 1 return self.create_execution_result( last_cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True )
[docs] async def dispatch_special_handling(self, cursor: Any, statement: "SQL") -> "SQLResult | None": """Hook for PostgreSQL-specific special operations. Args: cursor: Psycopg async cursor object statement: SQL statement to analyze Returns: SQLResult if special handling was applied, None otherwise """ if not is_copy_operation(statement.operation_type): return None sql, _ = self._get_compiled_sql(statement, statement.statement_config) sql_upper = sql.upper() operation_type = statement.operation_type copy_data = statement.parameters if isinstance(copy_data, list) and len(copy_data) == 1: copy_data = copy_data[0] if is_copy_from_operation(operation_type) and "FROM STDIN" in sql_upper: if isinstance(copy_data, (str, bytes)): data_to_write = copy_data elif is_readable(copy_data): data_to_write = copy_data.read() else: data_to_write = str(copy_data) if isinstance(data_to_write, str): data_to_write = data_to_write.encode() async with cursor.copy(sql) as copy_ctx: await copy_ctx.write(data_to_write) rows_affected = max(cursor.rowcount, 0) return SQLResult( data=None, rows_affected=rows_affected, statement=statement, metadata={"copy_operation": "FROM_STDIN"} ) if is_copy_to_operation(operation_type) and "TO STDOUT" in sql_upper: output_data: list[str] = [] async with cursor.copy(sql) as copy_ctx: output_data.extend([row.decode() if isinstance(row, bytes) else str(row) async for row in copy_ctx]) exported_data = "".join(output_data) return SQLResult( data=[{"copy_output": exported_data}], rows_affected=0, statement=statement, metadata={"copy_operation": "TO_STDOUT"}, ) await cursor.execute(sql) rows_affected = max(cursor.rowcount, 0) return SQLResult( data=None, rows_affected=rows_affected, statement=statement, metadata={"copy_operation": "FILE"} )
# ───────────────────────────────────────────────────────────────────────────── # TRANSACTION MANAGEMENT # ─────────────────────────────────────────────────────────────────────────────
[docs] async def begin(self) -> None: """Begin a database transaction on the current connection.""" try: try: autocommit_flag = self.connection.autocommit except AttributeError: autocommit_flag = None if isinstance(autocommit_flag, bool) and not autocommit_flag: return await self.connection.set_autocommit(False) except Exception as e: msg = f"Failed to begin transaction: {e}" raise SQLSpecError(msg) from e
[docs] async def commit(self) -> None: """Commit the current transaction on the current connection.""" try: await self.connection.commit() except Exception as e: msg = f"Failed to commit transaction: {e}" raise SQLSpecError(msg) from e
[docs] async def rollback(self) -> None: """Rollback the current transaction on the current connection.""" try: await self.connection.rollback() except Exception as e: msg = f"Failed to rollback transaction: {e}" raise SQLSpecError(msg) from e
[docs] def with_cursor(self, connection: "PsycopgAsyncConnection") -> "PsycopgAsyncCursor": """Create async context manager for PostgreSQL cursor.""" return PsycopgAsyncCursor(connection)
[docs] def handle_database_exceptions(self) -> "PsycopgAsyncExceptionHandler": """Handle database-specific exceptions and wrap them appropriately.""" return PsycopgAsyncExceptionHandler()
# ───────────────────────────────────────────────────────────────────────────── # STACK EXECUTION METHODS # ─────────────────────────────────────────────────────────────────────────────
[docs] async def execute_stack( self, stack: "StatementStack", *, continue_on_error: bool = False ) -> "tuple[StackResult, ...]": """Execute a StatementStack using psycopg async pipeline when supported.""" if ( not isinstance(stack, StatementStack) or not stack or self.stack_native_disabled or not pipeline_supported() or continue_on_error ): return await super().execute_stack(stack, continue_on_error=continue_on_error) prepared_ops = self._prepare_pipeline_operations(stack) if prepared_ops is None: return await super().execute_stack(stack, continue_on_error=continue_on_error) return await self._execute_stack_pipeline(stack, prepared_ops)
async def _execute_stack_pipeline( self, stack: "StatementStack", prepared_ops: "list[PreparedStackOperation]" ) -> "tuple[StackResult, ...]": def _raise_pending_exception(exception_ctx: "PsycopgAsyncExceptionHandler") -> None: if exception_ctx.pending_exception is not None: raise exception_ctx.pending_exception from None results: list[StackResult] = [] started_transaction = False with StackExecutionObserver(self, stack, continue_on_error=False, native_pipeline=True): try: if not self._connection_in_transaction(): await self.begin() started_transaction = True exception_handlers = [] async with AsyncExitStack() as resource_stack: pipeline = await resource_stack.enter_async_context(self.connection.pipeline()) pending: list[PipelineCursorEntry] = [] for prepared in prepared_ops: exception_ctx = self.handle_database_exceptions() exception_handlers.append(exception_ctx) await resource_stack.enter_async_context(exception_ctx) cursor = await resource_stack.enter_async_context(self.with_cursor(self.connection)) try: if prepared.parameters: await cursor.execute(prepared.sql, prepared.parameters) else: await cursor.execute(prepared.sql) except Exception as exc: stack_error = StackExecutionError( prepared.operation_index, describe_stack_statement(prepared.operation.statement), exc, adapter=type(self).__name__, mode="fail-fast", ) raise stack_error from exc pending.append(PipelineCursorEntry(prepared=prepared, cursor=cursor)) await pipeline.sync() for entry in pending: statement = entry.prepared.statement cursor = entry.cursor execution_result = await build_async_pipeline_execution_result( statement, cursor, column_name_resolver=self._resolve_column_names ) sql_result = self.build_statement_result(statement, execution_result) results.append(StackResult.from_sql_result(sql_result)) for exception_ctx in exception_handlers: _raise_pending_exception(exception_ctx) if started_transaction: await self.commit() except Exception: if started_transaction: try: await self.rollback() 
except Exception as rollback_error: # pragma: no cover - diagnostics only logger.debug("Rollback after psycopg pipeline failure failed: %s", rollback_error) raise return tuple(results) # ───────────────────────────────────────────────────────────────────────────── # STORAGE API METHODS # ─────────────────────────────────────────────────────────────────────────────
[docs] async def select_to_storage( self, statement: "SQL | str", destination: "StorageDestination", /, *parameters: Any, statement_config: "StatementConfig | None" = None, partitioner: "dict[str, object] | None" = None, format_hint: "StorageFormat | None" = None, telemetry: "StorageTelemetry | None" = None, **kwargs: Any, ) -> "StorageBridgeJob": """Execute a query and stream Arrow data to storage asynchronously.""" self._require_capability("arrow_export_enabled") arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) async_pipeline = self._storage_pipeline() telemetry_payload = await self._write_result_to_storage_async( arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry)
[docs] async def load_from_arrow( self, table: str, source: "ArrowResult | Any", *, partitioner: "dict[str, object] | None" = None, overwrite: bool = False, telemetry: "StorageTelemetry | None" = None, ) -> "StorageBridgeJob": """Load Arrow data into PostgreSQL asynchronously via COPY.""" self._require_capability("arrow_import_enabled") arrow_table = self._coerce_arrow_table(source) if overwrite: truncate_sql = build_truncate_command(table) exc_handler = self.handle_database_exceptions() async with self.with_cursor(self.connection) as cursor, exc_handler: await cursor.execute(truncate_sql) if exc_handler.pending_exception is not None: raise exc_handler.pending_exception from None columns, records = self._arrow_table_to_rows(arrow_table) if records: copy_sql = build_copy_from_command(table, columns) exc_handler = self.handle_database_exceptions() async with AsyncExitStack() as stack: await stack.enter_async_context(exc_handler) cursor = await stack.enter_async_context(self.with_cursor(self.connection)) copy_ctx = await stack.enter_async_context(cursor.copy(copy_sql)) for record in records: await copy_ctx.write_row(record) if exc_handler.pending_exception is not None: raise exc_handler.pending_exception from None telemetry_payload = self._build_ingest_telemetry(arrow_table) telemetry_payload["destination"] = table self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry)
[docs] async def load_from_storage( self, table: str, source: "StorageDestination", *, file_format: "StorageFormat", partitioner: "dict[str, object] | None" = None, overwrite: bool = False, ) -> "StorageBridgeJob": """Load staged artifacts asynchronously.""" arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format) return await self.load_from_arrow( table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound )
# ───────────────────────────────────────────────────────────────────────────── # UTILITY METHODS # ───────────────────────────────────────────────────────────────────────────── @property def data_dictionary(self) -> "PsycopgAsyncDataDictionary": """Get the data dictionary for this driver. Returns: Data dictionary instance for metadata queries """ if self._data_dictionary is None: self._data_dictionary = PsycopgAsyncDataDictionary() return self._data_dictionary # ───────────────────────────────────────────────────────────────────────────── # PRIVATE / INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── def _resolve_column_names(self, description: Any) -> list[str]: """Resolve and cache psycopg column names for hot row materialization paths.""" if not description: return [] cache_key = id(description) cached = self._column_name_cache.get(cache_key) if cached is not None and cached[0] is description: return cached[1] column_names = [col.name for col in description] if len(self._column_name_cache) >= COLUMN_CACHE_MAX_SIZE: self._column_name_cache.pop(next(iter(self._column_name_cache))) self._column_name_cache[cache_key] = (description, column_names) return column_names
[docs] def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect psycopg async rows for the direct execution path.""" data = cast("list[Any] | None", fetched) or [] column_names = self._resolve_column_names(cursor.description) return data, column_names, len(data)
[docs] def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from psycopg cursor for the direct execution path.""" return resolve_rowcount(cursor)
def _connection_in_transaction(self) -> bool: """Check if connection is in transaction.""" return bool(self.connection.info.transaction_status != TRANSACTION_STATUS_IDLE)
# Register this adapter's driver profile under the "psycopg" key at import time.
register_driver_profile("psycopg", driver_profile)