"""PostgreSQL psycopg driver implementation."""
from collections.abc import Sized
from contextlib import AsyncExitStack, ExitStack
from typing import TYPE_CHECKING, Any, cast
import psycopg
from typing_extensions import Self
from sqlspec.adapters.psycopg._typing import (
PsycopgAsyncConnection,
PsycopgAsyncSessionContext,
PsycopgSyncConnection,
PsycopgSyncSessionContext,
)
from sqlspec.adapters.psycopg.core import (
TRANSACTION_STATUS_IDLE,
PipelineCursorEntry,
PreparedStackOperation,
build_async_pipeline_execution_result,
build_copy_from_command,
build_pipeline_execution_result,
build_truncate_command,
create_mapped_exception,
default_statement_config,
driver_profile,
execute_with_optional_parameters,
execute_with_optional_parameters_async,
pipeline_supported,
resolve_many_rowcount,
resolve_rowcount,
)
from sqlspec.adapters.psycopg.data_dictionary import PsycopgAsyncDataDictionary, PsycopgSyncDataDictionary
from sqlspec.core import (
SQL,
SQLResult,
StackResult,
StatementConfig,
StatementStack,
get_cache_config,
is_copy_from_operation,
is_copy_operation,
is_copy_to_operation,
register_driver_profile,
)
from sqlspec.driver import (
AsyncDriverAdapterBase,
StackExecutionObserver,
SyncDriverAdapterBase,
describe_stack_statement,
)
from sqlspec.exceptions import SQLSpecError, StackExecutionError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.type_guards import is_readable
if TYPE_CHECKING:
from sqlspec.adapters.psycopg._typing import PsycopgPipelineDriver
from sqlspec.core import ArrowResult
from sqlspec.driver import ExecutionResult
from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry
# Public API of this module.
__all__ = (
    "PsycopgAsyncCursor",
    "PsycopgAsyncDriver",
    "PsycopgAsyncExceptionHandler",
    "PsycopgAsyncSessionContext",
    "PsycopgSyncCursor",
    "PsycopgSyncDriver",
    "PsycopgSyncExceptionHandler",
    "PsycopgSyncSessionContext",
)

# Module-level logger for the psycopg adapter.
logger = get_logger("sqlspec.adapters.psycopg")

# Upper bound on the per-driver cache that maps cursor descriptions to
# resolved column-name lists; one (oldest-inserted) entry is evicted
# whenever the cap is reached. See _resolve_column_names.
COLUMN_CACHE_MAX_SIZE = 256
class PsycopgPipelineMixin:
    """Shared helpers for psycopg sync/async pipeline execution."""

    __slots__ = ()

    def _prepare_pipeline_operations(self, stack: "StatementStack") -> "list[PreparedStackOperation] | None":
        """Compile a statement stack into pipeline-ready operations.

        Args:
            stack: Stack of operations requested by the caller.

        Returns:
            Prepared operations ready for pipeline execution, or ``None``
            when any operation cannot run in a pipeline (non-``execute``
            methods, scripts, or execute-many statements), signalling the
            caller to fall back to sequential execution.
        """
        # The driver view of ``self`` is loop-invariant; resolve it once
        # instead of casting on every iteration.
        driver = cast("PsycopgPipelineDriver", self)
        prepared: list[PreparedStackOperation] = []
        for index, operation in enumerate(stack.operations):
            if operation.method != "execute":
                return None
            kwargs = dict(operation.keyword_arguments) if operation.keyword_arguments else {}
            statement_config = kwargs.pop("statement_config", None)
            config = statement_config or driver.statement_config
            sql_statement = driver.prepare_statement(
                operation.statement, operation.arguments, statement_config=config, kwargs=kwargs
            )
            # Scripts and execute-many cannot be expressed as single
            # pipelined statements.
            if sql_statement.is_script or sql_statement.is_many:
                return None
            sql_text, prepared_parameters = driver._get_compiled_sql(  # pyright: ignore[reportPrivateUsage]
                sql_statement, config
            )
            prepared.append(
                PreparedStackOperation(
                    operation_index=index,
                    operation=operation,
                    statement=sql_statement,
                    sql=sql_text,
                    parameters=prepared_parameters,
                )
            )
        return prepared
class PsycopgSyncCursor:
    """Context manager for PostgreSQL psycopg cursor management.

    Opens a cursor on the wrapped connection on entry and guarantees it
    is closed on exit, even when the managed block raises.
    """

    __slots__ = ("connection", "cursor")

    def __init__(self, connection: "PsycopgSyncConnection") -> None:
        # Annotation is quoted (lazy) to match PsycopgAsyncCursor and avoid
        # evaluating the adapter type at definition time.
        self.connection = connection
        self.cursor: Any | None = None

    def __enter__(self) -> Any:
        self.cursor = self.connection.cursor()
        return self.cursor

    def __exit__(self, *_: Any) -> None:
        # Close unconditionally on exit; open cursors hold server-side
        # resources. Exceptions are never suppressed (implicit None return).
        if self.cursor is not None:
            self.cursor.close()
class PsycopgSyncExceptionHandler:
"""Context manager for handling PostgreSQL psycopg database exceptions.
Maps PostgreSQL SQLSTATE error codes to specific SQLSpec exceptions
for better error handling in application code.
Uses deferred exception pattern for mypyc compatibility: exceptions
are stored in pending_exception rather than raised from __exit__
to avoid ABI boundary violations with compiled code.
"""
__slots__ = ("pending_exception",)
def __init__(self) -> None:
self.pending_exception: Exception | None = None
def __enter__(self) -> Self:
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
if exc_type is None:
return False
if issubclass(exc_type, psycopg.Error):
self.pending_exception = create_mapped_exception(exc_val)
return True
return False
class PsycopgSyncDriver(PsycopgPipelineMixin, SyncDriverAdapterBase):
    """PostgreSQL psycopg synchronous driver.

    Provides synchronous database operations for PostgreSQL using psycopg3.
    Supports SQL statement execution with parameter binding, transaction
    management, result processing with column metadata, parameter style
    conversion, PostgreSQL arrays and JSON handling, COPY operations for
    bulk data transfer, and PostgreSQL-specific error handling.
    """

    __slots__ = ("_column_name_cache", "_data_dictionary")
    dialect = "postgres"

    def __init__(
        self,
        connection: PsycopgSyncConnection,
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        if statement_config is None:
            # Default config honours the global compiled-statement cache flag.
            statement_config = default_statement_config.replace(
                enable_caching=get_cache_config().compiled_cache_enabled
            )
        super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
        self._data_dictionary: PsycopgSyncDataDictionary | None = None
        self._column_name_cache: dict[int, tuple[Any, list[str]]] = {}

    # ─────────────────────────────────────────────────────────────────────────────
    # CORE DISPATCH METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute single SQL statement.

        Args:
            cursor: Database cursor
            statement: SQL statement to execute

        Returns:
            ExecutionResult with statement execution details
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        execute_with_optional_parameters(cursor, sql, prepared_parameters)
        if statement.returns_rows():
            fetched_data = cursor.fetchall()
            data = cast("list[Any] | None", fetched_data) or []
            column_names = self._resolve_column_names(cursor.description)
            return self.create_execution_result(
                cursor,
                selected_data=data,
                column_names=column_names,
                data_row_count=len(data),
                is_select_result=True,
                row_format="tuple",
            )
        affected_rows = resolve_rowcount(cursor)
        return self.create_execution_result(cursor, rowcount_override=affected_rows)

    def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute SQL with multiple parameter sets.

        Args:
            cursor: Database cursor
            statement: SQL statement with parameter list

        Returns:
            ExecutionResult with batch execution details
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        if not prepared_parameters:
            # Nothing to execute; report zero rows without touching the cursor.
            return self.create_execution_result(cursor, rowcount_override=0, is_many_result=True)
        parameter_count = len(prepared_parameters) if isinstance(prepared_parameters, Sized) else None
        cursor.executemany(sql, prepared_parameters)
        affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count)
        return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True)

    def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute SQL script with multiple statements.

        Args:
            cursor: Database cursor
            statement: SQL statement containing multiple commands

        Returns:
            ExecutionResult with script execution details
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
        successful_count = 0
        last_cursor = cursor
        for stmt in statements:
            execute_with_optional_parameters(cursor, stmt, prepared_parameters)
            successful_count += 1
        return self.create_execution_result(
            last_cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True
        )

    def dispatch_special_handling(self, cursor: Any, statement: "SQL") -> "SQLResult | None":
        """Hook for PostgreSQL-specific special operations.

        Args:
            cursor: Psycopg cursor object
            statement: SQL statement to analyze

        Returns:
            SQLResult if special handling was applied, None otherwise
        """
        if not is_copy_operation(statement.operation_type):
            return None
        sql, _ = self._get_compiled_sql(statement, statement.statement_config)
        # Match the async driver: only route STDIN/STDOUT COPY variants
        # through cursor.copy(); file-based COPY falls through to a plain
        # execute below.
        sql_upper = sql.upper()
        operation_type = statement.operation_type
        copy_data = statement.parameters
        if isinstance(copy_data, list) and len(copy_data) == 1:
            copy_data = copy_data[0]
        if is_copy_from_operation(operation_type) and "FROM STDIN" in sql_upper:
            if isinstance(copy_data, (str, bytes)):
                data_to_write = copy_data
            elif is_readable(copy_data):
                data_to_write = copy_data.read()
            else:
                data_to_write = str(copy_data)
            if isinstance(data_to_write, str):
                data_to_write = data_to_write.encode()
            with cursor.copy(sql) as copy_ctx:
                copy_ctx.write(data_to_write)
            rows_affected = max(cursor.rowcount, 0)
            return SQLResult(
                data=None, rows_affected=rows_affected, statement=statement, metadata={"copy_operation": "FROM_STDIN"}
            )
        if is_copy_to_operation(operation_type) and "TO STDOUT" in sql_upper:
            output_data: list[str] = []
            with cursor.copy(sql) as copy_ctx:
                output_data.extend(row.decode() if isinstance(row, bytes) else str(row) for row in copy_ctx)
            exported_data = "".join(output_data)
            return SQLResult(
                data=[{"copy_output": exported_data}],
                rows_affected=0,
                statement=statement,
                metadata={"copy_operation": "TO_STDOUT"},
            )
        # Server-side (file-based) COPY: no data flows through the client.
        cursor.execute(sql)
        rows_affected = max(cursor.rowcount, 0)
        return SQLResult(
            data=None, rows_affected=rows_affected, statement=statement, metadata={"copy_operation": "FILE"}
        )

    # ─────────────────────────────────────────────────────────────────────────────
    # TRANSACTION MANAGEMENT
    # ─────────────────────────────────────────────────────────────────────────────

    def begin(self) -> None:
        """Begin a database transaction on the current connection."""
        try:
            # psycopg starts an implicit transaction once autocommit is off.
            if self.connection.autocommit:
                self.connection.autocommit = False
        except Exception as e:
            msg = f"Failed to begin transaction: {e}"
            raise SQLSpecError(msg) from e

    def commit(self) -> None:
        """Commit the current transaction on the current connection."""
        try:
            self.connection.commit()
        except Exception as e:
            msg = f"Failed to commit transaction: {e}"
            raise SQLSpecError(msg) from e

    def rollback(self) -> None:
        """Rollback the current transaction on the current connection."""
        try:
            self.connection.rollback()
        except Exception as e:
            msg = f"Failed to rollback transaction: {e}"
            raise SQLSpecError(msg) from e

    def with_cursor(self, connection: PsycopgSyncConnection) -> PsycopgSyncCursor:
        """Create context manager for PostgreSQL cursor."""
        return PsycopgSyncCursor(connection)

    def handle_database_exceptions(self) -> "PsycopgSyncExceptionHandler":
        """Handle database-specific exceptions and wrap them appropriately."""
        return PsycopgSyncExceptionHandler()

    # ─────────────────────────────────────────────────────────────────────────────
    # STACK EXECUTION METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def execute_stack(self, stack: "StatementStack", *, continue_on_error: bool = False) -> "tuple[StackResult, ...]":
        """Execute a StatementStack using psycopg pipeline mode when supported."""
        if (
            not isinstance(stack, StatementStack)
            or not stack
            or self.stack_native_disabled
            or not pipeline_supported()
            or continue_on_error
        ):
            return super().execute_stack(stack, continue_on_error=continue_on_error)
        prepared_ops = self._prepare_pipeline_operations(stack)
        if prepared_ops is None:
            # Stack contains operations that cannot be pipelined; run sequentially.
            return super().execute_stack(stack, continue_on_error=continue_on_error)
        return self._execute_stack_pipeline(stack, prepared_ops)

    def _execute_stack_pipeline(
        self, stack: "StatementStack", prepared_ops: "list[PreparedStackOperation]"
    ) -> "tuple[StackResult, ...]":
        """Run prepared operations inside a psycopg pipeline, fail-fast."""

        def _raise_pending_exception(exception_ctx: "PsycopgSyncExceptionHandler") -> None:
            if exception_ctx.pending_exception is not None:
                raise exception_ctx.pending_exception from None

        results: list[StackResult] = []
        started_transaction = False
        with StackExecutionObserver(self, stack, continue_on_error=False, native_pipeline=True):
            try:
                if not self._connection_in_transaction():
                    self.begin()
                    started_transaction = True
                exception_handlers = []
                with ExitStack() as resource_stack:
                    pipeline = resource_stack.enter_context(self.connection.pipeline())
                    pending: list[PipelineCursorEntry] = []
                    for prepared in prepared_ops:
                        exception_ctx = self.handle_database_exceptions()
                        exception_handlers.append(exception_ctx)
                        resource_stack.enter_context(exception_ctx)
                        cursor = resource_stack.enter_context(self.with_cursor(self.connection))
                        try:
                            if prepared.parameters:
                                cursor.execute(prepared.sql, prepared.parameters)
                            else:
                                cursor.execute(prepared.sql)
                        except Exception as exc:
                            stack_error = StackExecutionError(
                                prepared.operation_index,
                                describe_stack_statement(prepared.operation.statement),
                                exc,
                                adapter=type(self).__name__,
                                mode="fail-fast",
                            )
                            raise stack_error from exc
                        pending.append(PipelineCursorEntry(prepared=prepared, cursor=cursor))
                    # Flush the pipeline so every queued statement completes
                    # before results are read back.
                    pipeline.sync()
                    for entry in pending:
                        statement = entry.prepared.statement
                        cursor = entry.cursor
                        execution_result = build_pipeline_execution_result(
                            statement, cursor, column_name_resolver=self._resolve_column_names
                        )
                        sql_result = self.build_statement_result(statement, execution_result)
                        results.append(StackResult.from_sql_result(sql_result))
                # Deferred psycopg errors (mypyc-safe pattern) surface here.
                for exception_ctx in exception_handlers:
                    _raise_pending_exception(exception_ctx)
                if started_transaction:
                    self.commit()
            except Exception:
                if started_transaction:
                    try:
                        self.rollback()
                    except Exception as rollback_error:  # pragma: no cover - diagnostics only
                        logger.debug("Rollback after psycopg pipeline failure failed: %s", rollback_error)
                raise
        return tuple(results)

    # ─────────────────────────────────────────────────────────────────────────────
    # STORAGE API METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def select_to_storage(
        self,
        statement: "SQL | str",
        destination: "StorageDestination",
        /,
        *parameters: Any,
        statement_config: "StatementConfig | None" = None,
        partitioner: "dict[str, object] | None" = None,
        format_hint: "StorageFormat | None" = None,
        telemetry: "StorageTelemetry | None" = None,
        **kwargs: Any,
    ) -> "StorageBridgeJob":
        """Execute a query and stream Arrow results to storage (sync)."""
        self._require_capability("arrow_export_enabled")
        arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
        sync_pipeline = self._storage_pipeline()
        telemetry_payload = self._write_result_to_storage_sync(
            arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline
        )
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    def load_from_arrow(
        self,
        table: str,
        source: "ArrowResult | Any",
        *,
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
        telemetry: "StorageTelemetry | None" = None,
    ) -> "StorageBridgeJob":
        """Load Arrow data into PostgreSQL using COPY."""
        self._require_capability("arrow_import_enabled")
        arrow_table = self._coerce_arrow_table(source)
        if overwrite:
            truncate_sql = build_truncate_command(table)
            exc_handler = self.handle_database_exceptions()
            with self.with_cursor(self.connection) as cursor, exc_handler:
                cursor.execute(truncate_sql)
            if exc_handler.pending_exception is not None:
                raise exc_handler.pending_exception from None
        columns, records = self._arrow_table_to_rows(arrow_table)
        if records:
            copy_sql = build_copy_from_command(table, columns)
            exc_handler = self.handle_database_exceptions()
            with ExitStack() as stack:
                stack.enter_context(exc_handler)
                cursor = stack.enter_context(self.with_cursor(self.connection))
                copy_ctx = stack.enter_context(cursor.copy(copy_sql))
                for record in records:
                    copy_ctx.write_row(record)
            if exc_handler.pending_exception is not None:
                raise exc_handler.pending_exception from None
        telemetry_payload = self._build_ingest_telemetry(arrow_table)
        telemetry_payload["destination"] = table
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    def load_from_storage(
        self,
        table: str,
        source: "StorageDestination",
        *,
        file_format: "StorageFormat",
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
    ) -> "StorageBridgeJob":
        """Load staged artifacts into PostgreSQL via COPY."""
        arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format)
        return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound)

    # ─────────────────────────────────────────────────────────────────────────────
    # UTILITY METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    @property
    def data_dictionary(self) -> "PsycopgSyncDataDictionary":
        """Get the data dictionary for this driver.

        Returns:
            Data dictionary instance for metadata queries
        """
        if self._data_dictionary is None:
            self._data_dictionary = PsycopgSyncDataDictionary()
        return self._data_dictionary

    # ─────────────────────────────────────────────────────────────────────────────
    # PRIVATE / INTERNAL METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def _resolve_column_names(self, description: Any) -> list[str]:
        """Resolve and cache psycopg column names for hot row materialization paths."""
        if not description:
            return []
        cache_key = id(description)
        cached = self._column_name_cache.get(cache_key)
        # Identity check guards against id() reuse after garbage collection.
        if cached is not None and cached[0] is description:
            return cached[1]
        column_names = [col.name for col in description]
        if len(self._column_name_cache) >= COLUMN_CACHE_MAX_SIZE:
            # Evict oldest insertion (dicts preserve insertion order).
            self._column_name_cache.pop(next(iter(self._column_name_cache)))
        self._column_name_cache[cache_key] = (description, column_names)
        return column_names

    def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Collect psycopg sync rows for the direct execution path."""
        data = cast("list[Any] | None", fetched) or []
        column_names = self._resolve_column_names(cursor.description)
        return data, column_names, len(data)

    def resolve_rowcount(self, cursor: Any) -> int:
        """Resolve rowcount from psycopg cursor for the direct execution path."""
        return resolve_rowcount(cursor)

    def _connection_in_transaction(self) -> bool:
        """Check if connection is in transaction."""
        return bool(self.connection.info.transaction_status != TRANSACTION_STATUS_IDLE)
class PsycopgAsyncCursor:
    """Async context manager for PostgreSQL psycopg cursor management.

    Acquires a cursor from the wrapped connection on entry and closes it
    on exit, even when the managed block raises.
    """

    __slots__ = ("connection", "cursor")

    def __init__(self, connection: "PsycopgAsyncConnection") -> None:
        self.connection = connection
        self.cursor: Any | None = None

    async def __aenter__(self) -> Any:
        cursor = self.connection.cursor()
        self.cursor = cursor
        return cursor

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Exception info is intentionally unused: errors always propagate.
        _ = (exc_type, exc_val, exc_tb)
        cursor = self.cursor
        if cursor is not None:
            await cursor.close()
class PsycopgAsyncExceptionHandler:
"""Async context manager for handling PostgreSQL psycopg database exceptions.
Maps PostgreSQL SQLSTATE error codes to specific SQLSpec exceptions
for better error handling in application code.
Uses deferred exception pattern for mypyc compatibility: exceptions
are stored in pending_exception rather than raised from __aexit__
to avoid ABI boundary violations with compiled code.
"""
__slots__ = ("pending_exception",)
def __init__(self) -> None:
self.pending_exception: Exception | None = None
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
if exc_type is None:
return False
if issubclass(exc_type, psycopg.Error):
self.pending_exception = create_mapped_exception(exc_val)
return True
return False
class PsycopgAsyncDriver(PsycopgPipelineMixin, AsyncDriverAdapterBase):
    """PostgreSQL psycopg asynchronous driver.

    Provides asynchronous database operations for PostgreSQL using psycopg3.
    Supports async SQL statement execution with parameter binding, async
    transaction management, async result processing with column metadata,
    parameter style conversion, PostgreSQL arrays and JSON handling, COPY
    operations for bulk data transfer, PostgreSQL-specific error handling,
    and async pub/sub support.
    """

    __slots__ = ("_column_name_cache", "_data_dictionary")
    dialect = "postgres"

    def __init__(
        self,
        connection: "PsycopgAsyncConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        if statement_config is None:
            # Default config honours the global compiled-statement cache flag.
            statement_config = default_statement_config.replace(
                enable_caching=get_cache_config().compiled_cache_enabled
            )
        super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
        self._data_dictionary: PsycopgAsyncDataDictionary | None = None
        self._column_name_cache: dict[int, tuple[Any, list[str]]] = {}

    # ─────────────────────────────────────────────────────────────────────────────
    # CORE DISPATCH METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute single SQL statement (async).

        Args:
            cursor: Database cursor
            statement: SQL statement to execute

        Returns:
            ExecutionResult with statement execution details
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        await execute_with_optional_parameters_async(cursor, sql, prepared_parameters)
        if statement.returns_rows():
            fetched_data = await cursor.fetchall()
            data = cast("list[Any] | None", fetched_data) or []
            column_names = self._resolve_column_names(cursor.description)
            return self.create_execution_result(
                cursor,
                selected_data=data,
                column_names=column_names,
                data_row_count=len(data),
                is_select_result=True,
                row_format="tuple",
            )
        affected_rows = resolve_rowcount(cursor)
        return self.create_execution_result(cursor, rowcount_override=affected_rows)

    async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute SQL with multiple parameter sets (async).

        Args:
            cursor: Database cursor
            statement: SQL statement with parameter list

        Returns:
            ExecutionResult with batch execution details
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        if not prepared_parameters:
            # Nothing to execute; report zero rows without touching the cursor.
            return self.create_execution_result(cursor, rowcount_override=0, is_many_result=True)
        parameter_count = len(prepared_parameters) if isinstance(prepared_parameters, Sized) else None
        await cursor.executemany(sql, prepared_parameters)
        affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count)
        return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True)

    async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
        """Execute SQL script with multiple statements (async).

        Args:
            cursor: Database cursor
            statement: SQL statement containing multiple commands

        Returns:
            ExecutionResult with script execution details
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
        successful_count = 0
        last_cursor = cursor
        for stmt in statements:
            await execute_with_optional_parameters_async(cursor, stmt, prepared_parameters)
            successful_count += 1
        return self.create_execution_result(
            last_cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True
        )

    async def dispatch_special_handling(self, cursor: Any, statement: "SQL") -> "SQLResult | None":
        """Hook for PostgreSQL-specific special operations.

        Args:
            cursor: Psycopg async cursor object
            statement: SQL statement to analyze

        Returns:
            SQLResult if special handling was applied, None otherwise
        """
        if not is_copy_operation(statement.operation_type):
            return None
        sql, _ = self._get_compiled_sql(statement, statement.statement_config)
        # Only STDIN/STDOUT COPY variants flow through cursor.copy();
        # file-based COPY falls through to a plain execute below.
        sql_upper = sql.upper()
        operation_type = statement.operation_type
        copy_data = statement.parameters
        if isinstance(copy_data, list) and len(copy_data) == 1:
            copy_data = copy_data[0]
        if is_copy_from_operation(operation_type) and "FROM STDIN" in sql_upper:
            if isinstance(copy_data, (str, bytes)):
                data_to_write = copy_data
            elif is_readable(copy_data):
                data_to_write = copy_data.read()
            else:
                data_to_write = str(copy_data)
            if isinstance(data_to_write, str):
                data_to_write = data_to_write.encode()
            async with cursor.copy(sql) as copy_ctx:
                await copy_ctx.write(data_to_write)
            rows_affected = max(cursor.rowcount, 0)
            return SQLResult(
                data=None, rows_affected=rows_affected, statement=statement, metadata={"copy_operation": "FROM_STDIN"}
            )
        if is_copy_to_operation(operation_type) and "TO STDOUT" in sql_upper:
            output_data: list[str] = []
            async with cursor.copy(sql) as copy_ctx:
                output_data.extend([row.decode() if isinstance(row, bytes) else str(row) async for row in copy_ctx])
            exported_data = "".join(output_data)
            return SQLResult(
                data=[{"copy_output": exported_data}],
                rows_affected=0,
                statement=statement,
                metadata={"copy_operation": "TO_STDOUT"},
            )
        # Server-side (file-based) COPY: no data flows through the client.
        await cursor.execute(sql)
        rows_affected = max(cursor.rowcount, 0)
        return SQLResult(
            data=None, rows_affected=rows_affected, statement=statement, metadata={"copy_operation": "FILE"}
        )

    # ─────────────────────────────────────────────────────────────────────────────
    # TRANSACTION MANAGEMENT
    # ─────────────────────────────────────────────────────────────────────────────

    async def begin(self) -> None:
        """Begin a database transaction on the current connection."""
        try:
            try:
                autocommit_flag = self.connection.autocommit
            except AttributeError:
                autocommit_flag = None
            if isinstance(autocommit_flag, bool) and not autocommit_flag:
                # Already in manual-commit mode; nothing to do.
                return
            await self.connection.set_autocommit(False)
        except Exception as e:
            msg = f"Failed to begin transaction: {e}"
            raise SQLSpecError(msg) from e

    async def commit(self) -> None:
        """Commit the current transaction on the current connection."""
        try:
            await self.connection.commit()
        except Exception as e:
            msg = f"Failed to commit transaction: {e}"
            raise SQLSpecError(msg) from e

    async def rollback(self) -> None:
        """Rollback the current transaction on the current connection."""
        try:
            await self.connection.rollback()
        except Exception as e:
            msg = f"Failed to rollback transaction: {e}"
            raise SQLSpecError(msg) from e

    def with_cursor(self, connection: "PsycopgAsyncConnection") -> "PsycopgAsyncCursor":
        """Create async context manager for PostgreSQL cursor."""
        return PsycopgAsyncCursor(connection)

    def handle_database_exceptions(self) -> "PsycopgAsyncExceptionHandler":
        """Handle database-specific exceptions and wrap them appropriately."""
        return PsycopgAsyncExceptionHandler()

    # ─────────────────────────────────────────────────────────────────────────────
    # STACK EXECUTION METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    async def execute_stack(
        self, stack: "StatementStack", *, continue_on_error: bool = False
    ) -> "tuple[StackResult, ...]":
        """Execute a StatementStack using psycopg async pipeline when supported."""
        if (
            not isinstance(stack, StatementStack)
            or not stack
            or self.stack_native_disabled
            or not pipeline_supported()
            or continue_on_error
        ):
            return await super().execute_stack(stack, continue_on_error=continue_on_error)
        prepared_ops = self._prepare_pipeline_operations(stack)
        if prepared_ops is None:
            # Stack contains operations that cannot be pipelined; run sequentially.
            return await super().execute_stack(stack, continue_on_error=continue_on_error)
        return await self._execute_stack_pipeline(stack, prepared_ops)

    async def _execute_stack_pipeline(
        self, stack: "StatementStack", prepared_ops: "list[PreparedStackOperation]"
    ) -> "tuple[StackResult, ...]":
        """Run prepared operations inside an async psycopg pipeline, fail-fast."""

        def _raise_pending_exception(exception_ctx: "PsycopgAsyncExceptionHandler") -> None:
            if exception_ctx.pending_exception is not None:
                raise exception_ctx.pending_exception from None

        results: list[StackResult] = []
        started_transaction = False
        with StackExecutionObserver(self, stack, continue_on_error=False, native_pipeline=True):
            try:
                if not self._connection_in_transaction():
                    await self.begin()
                    started_transaction = True
                exception_handlers = []
                async with AsyncExitStack() as resource_stack:
                    pipeline = await resource_stack.enter_async_context(self.connection.pipeline())
                    pending: list[PipelineCursorEntry] = []
                    for prepared in prepared_ops:
                        exception_ctx = self.handle_database_exceptions()
                        exception_handlers.append(exception_ctx)
                        await resource_stack.enter_async_context(exception_ctx)
                        cursor = await resource_stack.enter_async_context(self.with_cursor(self.connection))
                        try:
                            if prepared.parameters:
                                await cursor.execute(prepared.sql, prepared.parameters)
                            else:
                                await cursor.execute(prepared.sql)
                        except Exception as exc:
                            stack_error = StackExecutionError(
                                prepared.operation_index,
                                describe_stack_statement(prepared.operation.statement),
                                exc,
                                adapter=type(self).__name__,
                                mode="fail-fast",
                            )
                            raise stack_error from exc
                        pending.append(PipelineCursorEntry(prepared=prepared, cursor=cursor))
                    # Flush the pipeline so every queued statement completes
                    # before results are read back.
                    await pipeline.sync()
                    for entry in pending:
                        statement = entry.prepared.statement
                        cursor = entry.cursor
                        execution_result = await build_async_pipeline_execution_result(
                            statement, cursor, column_name_resolver=self._resolve_column_names
                        )
                        sql_result = self.build_statement_result(statement, execution_result)
                        results.append(StackResult.from_sql_result(sql_result))
                # Deferred psycopg errors (mypyc-safe pattern) surface here.
                for exception_ctx in exception_handlers:
                    _raise_pending_exception(exception_ctx)
                if started_transaction:
                    await self.commit()
            except Exception:
                if started_transaction:
                    try:
                        await self.rollback()
                    except Exception as rollback_error:  # pragma: no cover - diagnostics only
                        logger.debug("Rollback after psycopg pipeline failure failed: %s", rollback_error)
                raise
        return tuple(results)

    # ─────────────────────────────────────────────────────────────────────────────
    # STORAGE API METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    async def select_to_storage(
        self,
        statement: "SQL | str",
        destination: "StorageDestination",
        /,
        *parameters: Any,
        statement_config: "StatementConfig | None" = None,
        partitioner: "dict[str, object] | None" = None,
        format_hint: "StorageFormat | None" = None,
        telemetry: "StorageTelemetry | None" = None,
        **kwargs: Any,
    ) -> "StorageBridgeJob":
        """Execute a query and stream Arrow data to storage asynchronously."""
        self._require_capability("arrow_export_enabled")
        arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
        async_pipeline = self._storage_pipeline()
        telemetry_payload = await self._write_result_to_storage_async(
            arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline
        )
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    async def load_from_arrow(
        self,
        table: str,
        source: "ArrowResult | Any",
        *,
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
        telemetry: "StorageTelemetry | None" = None,
    ) -> "StorageBridgeJob":
        """Load Arrow data into PostgreSQL asynchronously via COPY."""
        self._require_capability("arrow_import_enabled")
        arrow_table = self._coerce_arrow_table(source)
        if overwrite:
            truncate_sql = build_truncate_command(table)
            exc_handler = self.handle_database_exceptions()
            async with self.with_cursor(self.connection) as cursor, exc_handler:
                await cursor.execute(truncate_sql)
            if exc_handler.pending_exception is not None:
                raise exc_handler.pending_exception from None
        columns, records = self._arrow_table_to_rows(arrow_table)
        if records:
            copy_sql = build_copy_from_command(table, columns)
            exc_handler = self.handle_database_exceptions()
            async with AsyncExitStack() as stack:
                await stack.enter_async_context(exc_handler)
                cursor = await stack.enter_async_context(self.with_cursor(self.connection))
                copy_ctx = await stack.enter_async_context(cursor.copy(copy_sql))
                for record in records:
                    await copy_ctx.write_row(record)
            if exc_handler.pending_exception is not None:
                raise exc_handler.pending_exception from None
        telemetry_payload = self._build_ingest_telemetry(arrow_table)
        telemetry_payload["destination"] = table
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    async def load_from_storage(
        self,
        table: str,
        source: "StorageDestination",
        *,
        file_format: "StorageFormat",
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
    ) -> "StorageBridgeJob":
        """Load staged artifacts asynchronously."""
        arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format)
        return await self.load_from_arrow(
            table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound
        )

    # ─────────────────────────────────────────────────────────────────────────────
    # UTILITY METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    @property
    def data_dictionary(self) -> "PsycopgAsyncDataDictionary":
        """Get the data dictionary for this driver.

        Returns:
            Data dictionary instance for metadata queries
        """
        if self._data_dictionary is None:
            self._data_dictionary = PsycopgAsyncDataDictionary()
        return self._data_dictionary

    # ─────────────────────────────────────────────────────────────────────────────
    # PRIVATE / INTERNAL METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def _resolve_column_names(self, description: Any) -> list[str]:
        """Resolve and cache psycopg column names for hot row materialization paths."""
        if not description:
            return []
        cache_key = id(description)
        cached = self._column_name_cache.get(cache_key)
        # Identity check guards against id() reuse after garbage collection.
        if cached is not None and cached[0] is description:
            return cached[1]
        column_names = [col.name for col in description]
        if len(self._column_name_cache) >= COLUMN_CACHE_MAX_SIZE:
            # Evict oldest insertion (dicts preserve insertion order).
            self._column_name_cache.pop(next(iter(self._column_name_cache)))
        self._column_name_cache[cache_key] = (description, column_names)
        return column_names

    def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Collect psycopg async rows for the direct execution path."""
        data = cast("list[Any] | None", fetched) or []
        column_names = self._resolve_column_names(cursor.description)
        return data, column_names, len(data)

    def resolve_rowcount(self, cursor: Any) -> int:
        """Resolve rowcount from psycopg cursor for the direct execution path."""
        return resolve_rowcount(cursor)

    def _connection_in_transaction(self) -> bool:
        """Check if connection is in transaction."""
        return bool(self.connection.info.transaction_status != TRANSACTION_STATUS_IDLE)
# Register this adapter's driver profile under the "psycopg" key so that
# configuration layers can resolve it by adapter name at import time.
register_driver_profile("psycopg", driver_profile)