"""AsyncPG PostgreSQL driver implementation for async PostgreSQL operations."""
from collections import OrderedDict
from io import BytesIO
from typing import TYPE_CHECKING, Any, cast
import asyncpg
from sqlspec.adapters.asyncpg.core import (
PREPARED_STATEMENT_CACHE_SIZE,
NormalizedStackOperation,
collect_rows,
create_mapped_exception,
default_statement_config,
driver_profile,
invoke_prepared_statement,
parse_status,
resolve_many_rowcount,
)
from sqlspec.adapters.asyncpg.data_dictionary import AsyncpgDataDictionary
from sqlspec.core import (
SQL,
StackResult,
StatementStack,
create_sql_result,
get_cache_config,
is_copy_from_operation,
is_copy_operation,
register_driver_profile,
)
from sqlspec.driver import AsyncDriverAdapterBase, StackExecutionObserver, describe_stack_statement
from sqlspec.exceptions import SQLSpecError, StackExecutionError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.type_guards import has_sqlstate
if TYPE_CHECKING:
from collections.abc import Sequence
from sqlspec.adapters.asyncpg._typing import AsyncpgConnection, AsyncpgPreparedStatement
from sqlspec.core import ArrowResult, SQLResult, StatementConfig
from sqlspec.driver import ExecutionResult
from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry
from typing_extensions import Self
from sqlspec.adapters.asyncpg._typing import AsyncpgSessionContext
__all__ = (
    "AsyncpgCursor",
    "AsyncpgDriver",
    "AsyncpgExceptionHandler",
    "AsyncpgSessionContext",
)

# Logger for the asyncpg adapter namespace.
logger = get_logger("sqlspec.adapters.asyncpg")
class AsyncpgCursor:
    """Async context manager that yields the AsyncPG connection as the "cursor".

    The driver's dispatch methods issue queries directly on the connection
    (``fetch``/``execute``), so entering this context simply hands back the
    wrapped connection. Exiting does not close or release it.
    """

    __slots__ = ("connection",)

    def __init__(self, connection: "AsyncpgConnection") -> None:
        self.connection = connection

    async def __aenter__(self) -> "AsyncpgConnection":
        """Return the wrapped connection unchanged."""
        return self.connection

    async def __aexit__(self, *exc_info: Any) -> None:
        """No-op exit: exceptions propagate and the connection is left untouched."""
        return None
class AsyncpgExceptionHandler:
    """Async context manager for handling AsyncPG database exceptions.

    Maps PostgreSQL SQLSTATE error codes to specific SQLSpec exceptions
    for better error handling in application code.

    Uses deferred exception pattern for mypyc compatibility: exceptions
    are stored in pending_exception rather than raised from __aexit__
    to avoid ABI boundary violations with compiled code.
    """

    __slots__ = ("pending_exception",)

    def __init__(self) -> None:
        # Populated by __aexit__ with the mapped SQLSpec exception; the caller
        # is responsible for checking and raising it after the block exits.
        self.pending_exception: Exception | None = None

    async def __aenter__(self) -> "Self":
        # Annotation is quoted like every other annotation in this module so the
        # class can be defined even when ``Self`` is only imported for type checking.
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        """Capture database errors as a pending mapped exception.

        Args:
            exc_type: Exception class raised in the block, if any.
            exc_val: Exception instance raised in the block, if any.
            exc_tb: Traceback of the exception, if any.

        Returns:
            True when the exception was an asyncpg ``PostgresError`` (or carries a
            SQLSTATE) and has been stored in ``pending_exception`` — this suppresses
            the original. False otherwise, letting the exception (if any) propagate.
        """
        if exc_val is None:
            return False
        if isinstance(exc_val, asyncpg.PostgresError) or has_sqlstate(exc_val):
            self.pending_exception = create_mapped_exception(exc_val)
            return True
        return False
class AsyncpgDriver(AsyncDriverAdapterBase):
    """AsyncPG PostgreSQL driver for async database operations.

    Supports COPY operations, numeric parameter style handling, PostgreSQL
    exception handling, transaction management, SQL statement compilation
    and caching, and parameter processing with type coercion.
    """

    __slots__ = ("_data_dictionary", "_prepared_statements")
    dialect = "postgres"

    def __init__(
        self,
        connection: "AsyncpgConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Initialize the driver.

        Args:
            connection: AsyncPG connection used to run every statement.
            statement_config: Statement configuration. When omitted, the module's
                default configuration is used with ``enable_caching`` taken from the
                global compiled-cache setting.
            driver_features: Optional feature flags forwarded to the base adapter.
        """
        if statement_config is None:
            statement_config = default_statement_config.replace(
                enable_caching=get_cache_config().compiled_cache_enabled
            )
        super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
        # Lazily created by the ``data_dictionary`` property.
        self._data_dictionary: AsyncpgDataDictionary | None = None
        # LRU cache of prepared statements keyed by compiled SQL text; bounded by
        # PREPARED_STATEMENT_CACHE_SIZE in _get_prepared_statement.
        self._prepared_statements: OrderedDict[str, AsyncpgPreparedStatement] = OrderedDict()

    # ─────────────────────────────────────────────────────────────────────────────
    # CORE DISPATCH METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    async def dispatch_execute(self, cursor: "AsyncpgConnection", statement: "SQL") -> "ExecutionResult":
        """Execute single SQL statement.

        Row-returning statements are run with ``fetch`` and their rows collected;
        all others are run with ``execute`` and the affected row count is parsed
        from the returned status.

        Args:
            cursor: AsyncPG connection object
            statement: SQL statement to execute

        Returns:
            ExecutionResult with statement execution details
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        params: tuple[Any, ...] = cast("tuple[Any, ...]", prepared_parameters) if prepared_parameters else ()
        if statement.returns_rows():
            records = await cursor.fetch(sql, *params) if params else await cursor.fetch(sql)
            data, column_names = collect_rows(records)
            return self.create_execution_result(
                cursor,
                selected_data=data,
                column_names=column_names,
                data_row_count=len(data),
                is_select_result=True,
                row_format="record",
            )
        result = await cursor.execute(sql, *params) if params else await cursor.execute(sql)
        affected_rows = parse_status(result)
        return self.create_execution_result(cursor, rowcount_override=affected_rows)

    async def dispatch_execute_many(self, cursor: "AsyncpgConnection", statement: "SQL") -> "ExecutionResult":
        """Execute SQL with multiple parameter sets using AsyncPG's executemany.

        When the compiled statement has no parameter sets, nothing is executed and
        the reported row count is 0.

        Args:
            cursor: AsyncPG connection object
            statement: SQL statement with multiple parameter sets

        Returns:
            ExecutionResult with batch execution details
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        if prepared_parameters:
            parameter_sets = cast("list[Sequence[object]]", prepared_parameters)
            await cursor.executemany(sql, parameter_sets)
            affected_rows = resolve_many_rowcount(parameter_sets)
        else:
            affected_rows = 0
        return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True)

    async def dispatch_execute_script(self, cursor: "AsyncpgConnection", statement: "SQL") -> "ExecutionResult":
        """Execute SQL script with statement splitting.

        The compiled script is split into individual statements, which are run in
        order with ``execute`` and without bind parameters (the compiled parameters
        are discarded). Execution stops at the first failing statement.

        Args:
            cursor: AsyncPG connection object
            statement: SQL statement containing multiple statements

        Returns:
            ExecutionResult with script execution details; its cursor result is the
            status returned by the last executed statement (None for an empty script).
        """
        sql, _ = self._get_compiled_sql(statement, self.statement_config)
        statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)
        successful_count = 0
        last_result = None
        for stmt in statements:
            result = await cursor.execute(stmt)
            last_result = result
            successful_count += 1
        return self.create_execution_result(
            last_result, statement_count=len(statements), successful_statements=successful_count, is_script_result=True
        )

    async def dispatch_special_handling(self, cursor: "AsyncpgConnection", statement: "SQL") -> "SQLResult | None":
        """Handle PostgreSQL COPY operations and other special cases.

        Args:
            cursor: AsyncPG connection object
            statement: SQL statement to analyze

        Returns:
            SQLResult if special operation was handled, None for standard execution
        """
        if is_copy_operation(statement.operation_type):
            await self._handle_copy_operation(cursor, statement)
            return self.build_statement_result(statement, self.create_execution_result(cursor))
        return None

    # ─────────────────────────────────────────────────────────────────────────────
    # TRANSACTION MANAGEMENT
    # ─────────────────────────────────────────────────────────────────────────────

    async def begin(self) -> None:
        """Begin a database transaction.

        Raises:
            SQLSpecError: If PostgreSQL rejects the BEGIN.
        """
        try:
            await self.connection.execute("BEGIN")
        except asyncpg.PostgresError as e:
            msg = f"Failed to begin async transaction: {e}"
            raise SQLSpecError(msg) from e

    async def commit(self) -> None:
        """Commit the current transaction.

        Raises:
            SQLSpecError: If PostgreSQL rejects the COMMIT.
        """
        try:
            await self.connection.execute("COMMIT")
        except asyncpg.PostgresError as e:
            msg = f"Failed to commit async transaction: {e}"
            raise SQLSpecError(msg) from e

    async def rollback(self) -> None:
        """Rollback the current transaction.

        Raises:
            SQLSpecError: If PostgreSQL rejects the ROLLBACK.
        """
        try:
            await self.connection.execute("ROLLBACK")
        except asyncpg.PostgresError as e:
            msg = f"Failed to rollback async transaction: {e}"
            raise SQLSpecError(msg) from e

    def with_cursor(self, connection: "AsyncpgConnection") -> "AsyncpgCursor":
        """Create context manager for AsyncPG cursor."""
        return AsyncpgCursor(connection)

    def handle_database_exceptions(self) -> "AsyncpgExceptionHandler":
        """Handle database exceptions with PostgreSQL error codes."""
        return AsyncpgExceptionHandler()

    # ─────────────────────────────────────────────────────────────────────────────
    # STACK EXECUTION METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    async def execute_stack(
        self, stack: "StatementStack", *, continue_on_error: bool = False
    ) -> "tuple[StackResult, ...]":
        """Execute a StatementStack using asyncpg's rapid batching.

        Falls back to the base implementation for non-StatementStack input, an
        empty stack, or when native stack execution is disabled.
        """
        if not isinstance(stack, StatementStack) or not stack or self.stack_native_disabled:
            return await super().execute_stack(stack, continue_on_error=continue_on_error)
        return await self._execute_stack_native(stack, continue_on_error=continue_on_error)

    async def _execute_stack_native(
        self, stack: "StatementStack", *, continue_on_error: bool
    ) -> "tuple[StackResult, ...]":
        """Run a stack natively, inside an asyncpg transaction in fail-fast mode.

        A transaction is only opened when ``continue_on_error`` is False and the
        connection is not already in one (asyncpg refuses ``transaction()`` inside a
        manually started transaction), so a failure rolls back earlier operations.
        """
        results: list[StackResult] = []
        transaction_cm = None
        if not continue_on_error and not self._connection_in_transaction():
            transaction_cm = self.connection.transaction()
        with StackExecutionObserver(self, stack, continue_on_error, native_pipeline=True) as observer:
            if transaction_cm is not None:
                async with transaction_cm:
                    await self._run_stack_operations(stack, continue_on_error, observer, results)
            else:
                await self._run_stack_operations(stack, continue_on_error, observer, results)
        return tuple(results)

    async def _run_stack_operations(
        self,
        stack: "StatementStack",
        continue_on_error: bool,
        observer: "StackExecutionObserver",
        results: "list[StackResult]",
    ) -> None:
        """Run operations for statement stack execution.

        Extracted from _execute_stack_native to avoid closure compilation issues.
        Single (non-script, non-many) ``execute`` operations use cached prepared
        statements; everything else goes through ``_execute_stack_operation``.
        Errors are wrapped in StackExecutionError and either recorded (continue
        mode) or raised (fail-fast mode). Results are appended to ``results``.
        """
        for index, operation in enumerate(stack.operations):
            try:
                normalized: NormalizedStackOperation | None = None
                if operation.method == "execute":
                    kwargs = dict(operation.keyword_arguments) if operation.keyword_arguments else {}
                    statement_config = kwargs.pop("statement_config", None)
                    config = statement_config or self.statement_config
                    sql_statement = self.prepare_statement(
                        operation.statement, operation.arguments, statement_config=config, kwargs=kwargs
                    )
                    if not sql_statement.is_script and not sql_statement.is_many:
                        sql_text, prepared_parameters = self._get_compiled_sql(sql_statement, config)
                        prepared_parameters = cast("tuple[Any, ...] | dict[str, Any] | None", prepared_parameters)
                        normalized = NormalizedStackOperation(
                            operation=operation, statement=sql_statement, sql=sql_text, parameters=prepared_parameters
                        )
                if normalized is not None:
                    stack_result = await self._execute_stack_operation_prepared(normalized)
                else:
                    result = await self._execute_stack_operation(operation)
                    stack_result = StackResult(result=result)
            except Exception as exc:
                stack_error = StackExecutionError(
                    index,
                    describe_stack_statement(operation.statement),
                    exc,
                    adapter=type(self).__name__,
                    mode="continue-on-error" if continue_on_error else "fail-fast",
                )
                if continue_on_error:
                    observer.record_operation_error(stack_error)
                    results.append(StackResult.from_error(stack_error))
                    continue
                raise stack_error from exc
            results.append(stack_result)

    async def _execute_stack_operation_prepared(self, normalized: "NormalizedStackOperation") -> StackResult:
        """Execute one normalized stack operation through a cached prepared statement."""
        prepared = await self._get_prepared_statement(normalized.sql)
        metadata = {"prepared_statement": True}
        if normalized.statement.returns_rows():
            rows = await invoke_prepared_statement(prepared, normalized.parameters, fetch=True)
            data, _ = collect_rows(rows)
            sql_result = create_sql_result(
                normalized.statement, data=data, rows_affected=len(data), metadata=metadata, row_format="record"
            )
            return StackResult.from_sql_result(sql_result)
        status = await invoke_prepared_statement(prepared, normalized.parameters, fetch=False)
        rowcount = parse_status(status)
        sql_result = create_sql_result(normalized.statement, rows_affected=rowcount, metadata=metadata)
        return StackResult.from_sql_result(sql_result)

    # ─────────────────────────────────────────────────────────────────────────────
    # STORAGE API METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    async def select_to_storage(
        self,
        statement: "SQL | str",
        destination: "StorageDestination",
        /,
        *parameters: Any,
        statement_config: "StatementConfig | None" = None,
        partitioner: "dict[str, object] | None" = None,
        format_hint: "StorageFormat | None" = None,
        telemetry: "StorageTelemetry | None" = None,
        **kwargs: Any,
    ) -> "StorageBridgeJob":
        """Execute a query and persist its results to storage.

        Results are fetched through ``select_to_arrow`` and then written with the
        async storage pipeline (no native COPY path is used here).
        """
        self._require_capability("arrow_export_enabled")
        arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
        async_pipeline = self._storage_pipeline()
        telemetry_payload = await self._write_result_to_storage_async(
            arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline
        )
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    async def load_from_arrow(
        self,
        table: str,
        source: "ArrowResult | Any",
        *,
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
        telemetry: "StorageTelemetry | None" = None,
    ) -> "StorageBridgeJob":
        """Load Arrow data into a PostgreSQL table via COPY.

        ``table`` may be schema-qualified (``"schema.table"``). Both the TRUNCATE
        (when ``overwrite`` is set) and the COPY target the same quoted identifiers,
        matching how asyncpg's ``copy_records_to_table`` quotes names.

        When ``overwrite`` is set and no transaction is open, TRUNCATE and COPY run
        in one transaction so a failed load does not leave the table emptied.

        Raises:
            SQLSpecError: If the TRUNCATE fails.
        """
        self._require_capability("arrow_import_enabled")
        arrow_table = self._coerce_arrow_table(source)
        if overwrite and not self._connection_in_transaction():
            async with self.connection.transaction():
                await self._copy_arrow_into_table(table, arrow_table, overwrite=True)
        else:
            await self._copy_arrow_into_table(table, arrow_table, overwrite=overwrite)
        telemetry_payload = self._build_ingest_telemetry(arrow_table)
        telemetry_payload["destination"] = table
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    async def _copy_arrow_into_table(self, table: str, arrow_table: Any, *, overwrite: bool) -> None:
        """Optionally truncate ``table``, then COPY the Arrow rows into it."""
        schema_name, table_name = self._split_table_name(table)
        if overwrite:
            qualified = self._quote_identifier(table_name)
            if schema_name:
                qualified = f"{self._quote_identifier(schema_name)}.{qualified}"
            try:
                await self.connection.execute(f"TRUNCATE TABLE {qualified}")
            except asyncpg.PostgresError as exc:
                msg = f"Failed to truncate table '{table}': {exc}"
                raise SQLSpecError(msg) from exc
        columns, records = self._arrow_table_to_rows(arrow_table)
        if records:
            await self.connection.copy_records_to_table(
                table_name, records=records, columns=columns, schema_name=schema_name
            )

    @staticmethod
    def _split_table_name(table: str) -> "tuple[str | None, str]":
        """Split ``"schema.table"`` into ``(schema, table)``; schema is None when absent."""
        schema_name, _, table_name = table.rpartition(".")
        return (schema_name or None), table_name

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Quote a PostgreSQL identifier, doubling any embedded double quotes."""
        return '"' + name.replace('"', '""') + '"'

    async def load_from_storage(
        self,
        table: str,
        source: "StorageDestination",
        *,
        file_format: "StorageFormat",
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
    ) -> "StorageBridgeJob":
        """Read an artifact from storage and ingest it via COPY (see ``load_from_arrow``)."""
        arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format)
        return await self.load_from_arrow(
            table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound
        )

    # ─────────────────────────────────────────────────────────────────────────────
    # UTILITY METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    @property
    def data_dictionary(self) -> "AsyncpgDataDictionary":
        """Get the data dictionary for this driver.

        Returns:
            Data dictionary instance for metadata queries (created on first access)
        """
        if self._data_dictionary is None:
            self._data_dictionary = AsyncpgDataDictionary()
        return self._data_dictionary

    # ─────────────────────────────────────────────────────────────────────────────
    # PRIVATE/INTERNAL METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Collect asyncpg rows for the direct execution path."""
        # ``collect_rows`` here resolves to the module-level helper, not this method.
        data, column_names = collect_rows(fetched)
        return data, column_names, len(data)

    def resolve_rowcount(self, cursor: Any) -> int:
        """Resolve rowcount from asyncpg status for the direct execution path."""
        return parse_status(cursor)

    def _connection_in_transaction(self) -> bool:
        """Check if connection is in transaction."""
        return bool(self.connection.is_in_transaction())

    async def _get_prepared_statement(self, sql: str) -> "AsyncpgPreparedStatement":
        """Return a cached prepared statement for ``sql``, preparing it on a miss.

        The cache is LRU: hits move to the end, and the oldest entry is evicted once
        the size exceeds PREPARED_STATEMENT_CACHE_SIZE.
        """
        cached = self._prepared_statements.get(sql)
        if cached is not None:
            self._prepared_statements.move_to_end(sql)
            return cached
        prepared = cast("AsyncpgPreparedStatement", await self.connection.prepare(sql))
        self._prepared_statements[sql] = prepared
        if len(self._prepared_statements) > PREPARED_STATEMENT_CACHE_SIZE:
            self._prepared_statements.popitem(last=False)
        return prepared

    async def _handle_copy_operation(self, cursor: "AsyncpgConnection", statement: "SQL") -> None:
        """Handle PostgreSQL COPY operations.

        When ``postgres_copy_data`` is supplied in the statement's execution args and
        the statement is ``COPY ... FROM STDIN``, the data is joined with newlines
        (dict values, list/tuple items, or a single value), UTF-8 encoded and
        streamed to the server. Any other COPY statement is run with ``execute``.

        Args:
            cursor: AsyncPG connection object
            statement: SQL statement with COPY operation
        """
        execution_args = statement.statement_config.execution_args
        metadata: dict[str, Any] = dict(execution_args) if execution_args else {}
        sql_text, _ = self._get_compiled_sql(statement, statement.statement_config)
        # Collapse whitespace so "FROM\n  STDIN" is detected like "FROM STDIN".
        normalized_sql = " ".join(sql_text.upper().split())
        copy_data = metadata.get("postgres_copy_data")
        if copy_data and is_copy_from_operation(statement.operation_type) and "FROM STDIN" in normalized_sql:
            if isinstance(copy_data, dict):
                values: Any = copy_data.values()
            elif isinstance(copy_data, (list, tuple)):
                values = copy_data
            else:
                values = (copy_data,)
            payload = "\n".join(str(value) for value in values).encode("utf-8")
            # asyncpg's public copy helpers build their own COPY statements
            # (copy_from_query wraps the SQL as "COPY (...) TO STDOUT" and writes
            # server output into the buffer), so a caller-supplied
            # "COPY ... FROM STDIN" must be streamed with the _copy_in primitive
            # that copy_to_table itself uses.
            await cursor._copy_in(sql_text, BytesIO(payload), None)  # noqa: SLF001
            return
        await cursor.execute(sql_text)
# Register the asyncpg driver profile under the "asyncpg" name at import time.
register_driver_profile(
    "asyncpg",
    driver_profile,
)