"""AsyncPG database configuration with direct field-based configuration."""
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
from asyncpg import Connection, Record
from asyncpg import create_pool as asyncpg_create_pool
from asyncpg.connection import ConnectionMeta
from asyncpg.pool import Pool, PoolConnectionProxy, PoolConnectionProxyMeta
from mypy_extensions import mypyc_attr
from typing_extensions import NotRequired
from sqlspec.adapters.asyncpg._typing import AsyncpgConnection, AsyncpgPool, AsyncpgPreparedStatement
from sqlspec.adapters.asyncpg.core import (
apply_driver_features,
build_connection_config,
default_statement_config,
register_json_codecs,
register_pgvector_support,
)
from sqlspec.adapters.asyncpg.driver import AsyncpgCursor, AsyncpgDriver, AsyncpgExceptionHandler, AsyncpgSessionContext
from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs
from sqlspec.exceptions import ImproperConfigurationError, MissingDependencyError
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.typing import ALLOYDB_CONNECTOR_INSTALLED, CLOUD_SQL_CONNECTOR_INSTALLED, PGVECTOR_INSTALLED
from sqlspec.utils.config_tools import normalize_connection_config
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import from_json, to_json
if TYPE_CHECKING:
from asyncio.events import AbstractEventLoop
from collections.abc import Awaitable, Callable
from sqlspec.core import StatementConfig
from sqlspec.observability import ObservabilityConfig
__all__ = (
"PGVECTOR_INSTALLED",
"AsyncpgConfig",
"AsyncpgConnectionConfig",
"AsyncpgDriverFeatures",
"AsyncpgPoolConfig",
"register_json_codecs",
"register_pgvector_support",
)
logger = get_logger(__name__)
[docs]
class AsyncpgConnectionConfig(TypedDict):
"""TypedDict for AsyncPG connection parameters."""
dsn: NotRequired[str]
host: NotRequired[str]
port: NotRequired[int]
user: NotRequired[str]
password: NotRequired[str]
database: NotRequired[str]
ssl: NotRequired[Any]
passfile: NotRequired[str]
direct_tls: NotRequired[bool]
connect_timeout: NotRequired[float]
command_timeout: NotRequired[float]
statement_cache_size: NotRequired[int]
max_cached_statement_lifetime: NotRequired[int]
max_cacheable_statement_size: NotRequired[int]
server_settings: NotRequired["dict[str, str]"]
[docs]
class AsyncpgPoolConfig(AsyncpgConnectionConfig):
"""TypedDict for AsyncPG pool parameters, inheriting connection parameters."""
min_size: NotRequired[int]
max_size: NotRequired[int]
max_queries: NotRequired[int]
max_inactive_connection_lifetime: NotRequired[float]
setup: NotRequired["Callable[[AsyncpgConnection], Awaitable[None]]"]
init: NotRequired["Callable[[AsyncpgConnection], Awaitable[None]]"]
loop: NotRequired["AbstractEventLoop"]
connection_class: NotRequired[type["AsyncpgConnection"]]
record_class: NotRequired[type[Record]]
extra: NotRequired["dict[str, Any]"]
class AsyncpgDriverFeatures(TypedDict):
"""AsyncPG driver feature flags.
json_serializer: Custom JSON serializer function for PostgreSQL JSON/JSONB types.
Defaults to sqlspec.utils.serializers.to_json.
Use for performance optimization (e.g., orjson) or custom encoding behavior.
Applied when enable_json_codecs is True.
json_deserializer: Custom JSON deserializer function for PostgreSQL JSON/JSONB types.
Defaults to sqlspec.utils.serializers.from_json.
Use for performance optimization (e.g., orjson) or custom decoding behavior.
Applied when enable_json_codecs is True.
enable_json_codecs: Enable automatic JSON/JSONB codec registration on connections.
Defaults to True for seamless Python dict/list to PostgreSQL JSON/JSONB conversion.
Set to False to disable automatic codec registration (manual handling required).
enable_pgvector: Enable pgvector extension support for vector similarity search.
Requires pgvector-python package (pip install pgvector) and PostgreSQL with pgvector extension.
Defaults to True when pgvector-python is installed.
Provides automatic conversion between Python objects and PostgreSQL vector types.
Enables vector similarity operations and index support.
enable_paradedb: Enable ParadeDB (pg_search) extension detection.
When enabled and the pg_search extension is detected, the SQL dialect
switches to "paradedb" which supports search operators (@@@, &&&, etc.)
and inherits all pgvector distance operators.
Defaults to True. Independent of enable_pgvector.
enable_cloud_sql: Enable Google Cloud SQL connector integration.
Requires cloud-sql-python-connector package.
Defaults to False (explicit opt-in required).
Auto-configures IAM authentication, SSL, and IP routing.
Mutually exclusive with enable_alloydb.
cloud_sql_instance: Cloud SQL instance connection name.
Format: "project:region:instance"
Required when enable_cloud_sql is True.
cloud_sql_enable_iam_auth: Enable IAM database authentication.
Defaults to False for passwordless authentication.
When False, requires user/password in connection_config.
cloud_sql_ip_type: IP address type for connection.
Options: "PUBLIC", "PRIVATE", "PSC"
Defaults to "PRIVATE".
enable_alloydb: Enable Google AlloyDB connector integration.
Requires cloud-alloydb-python-connector package.
Defaults to False (explicit opt-in required).
Auto-configures IAM authentication and private networking.
Mutually exclusive with enable_cloud_sql.
alloydb_instance_uri: AlloyDB instance URI.
Format: "projects/PROJECT/locations/REGION/clusters/CLUSTER/instances/INSTANCE"
Required when enable_alloydb is True.
alloydb_enable_iam_auth: Enable IAM database authentication.
Defaults to False for passwordless authentication.
alloydb_ip_type: IP address type for connection.
Options: "PUBLIC", "PRIVATE", "PSC"
Defaults to "PRIVATE".
enable_events: Enable database event channel support.
Defaults to True when extension_config["events"] is configured.
Provides pub/sub capabilities via LISTEN/NOTIFY or table-backed fallback.
Requires extension_config["events"] for migration setup when using table_queue backend.
events_backend: Event channel backend selection.
Options: "listen_notify", "table_queue", "listen_notify_durable"
- "listen_notify": Zero-copy PostgreSQL LISTEN/NOTIFY (ephemeral, real-time)
- "table_queue": Durable table-backed queue with retries and exactly-once delivery
- "listen_notify_durable": Hybrid - combines real-time LISTEN/NOTIFY with table durability (recommended for production)
Defaults to "listen_notify" for backward compatibility.
Note: "listen_notify_durable" provides best of both worlds - <100ms latency with full durability.
"""
json_serializer: NotRequired["Callable[[Any], str]"]
json_deserializer: NotRequired["Callable[[str], Any]"]
enable_json_codecs: NotRequired[bool]
enable_pgvector: NotRequired[bool]
enable_paradedb: NotRequired[bool]
enable_cloud_sql: NotRequired[bool]
cloud_sql_instance: NotRequired[str]
cloud_sql_enable_iam_auth: NotRequired[bool]
cloud_sql_ip_type: NotRequired[str]
enable_alloydb: NotRequired[bool]
alloydb_instance_uri: NotRequired[str]
alloydb_enable_iam_auth: NotRequired[bool]
alloydb_ip_type: NotRequired[str]
enable_events: NotRequired[bool]
events_backend: NotRequired[str]
connection_instance: NotRequired["AsyncpgPool"]
on_connection_create: NotRequired["Callable[[AsyncpgConnection], Awaitable[None]]"]
class _AsyncpgCloudSqlConnector:
__slots__ = ("_config", "_database", "_password", "_user")
def __init__(self, config: "AsyncpgConfig", user: str | None, password: str | None, database: str | None) -> None:
self._config = config
self._user = user
self._password = password
self._database = database
async def __call__(self) -> "AsyncpgConnection":
connector = self._config.get_cloud_sql_connector()
if connector is None:
msg = "Cloud SQL connector is not initialized"
raise ImproperConfigurationError(msg)
conn_kwargs: dict[str, Any] = {
"instance_connection_string": self._config.driver_features["cloud_sql_instance"],
"driver": "asyncpg",
"enable_iam_auth": self._config.driver_features.get("cloud_sql_enable_iam_auth", False),
"ip_type": self._config.driver_features.get("cloud_sql_ip_type", "PRIVATE"),
}
if self._user:
conn_kwargs["user"] = self._user
if self._password:
conn_kwargs["password"] = self._password
if self._database:
conn_kwargs["db"] = self._database
return cast("AsyncpgConnection", await connector.connect_async(**conn_kwargs))
class _AsyncpgAlloydbConnector:
__slots__ = ("_config", "_database", "_password", "_user")
def __init__(self, config: "AsyncpgConfig", user: str | None, password: str | None, database: str | None) -> None:
self._config = config
self._user = user
self._password = password
self._database = database
async def __call__(self) -> "AsyncpgConnection":
connector = self._config.get_alloydb_connector()
if connector is None:
msg = "AlloyDB connector is not initialized"
raise ImproperConfigurationError(msg)
conn_kwargs: dict[str, Any] = {
"instance_uri": self._config.driver_features["alloydb_instance_uri"],
"driver": "asyncpg",
"enable_iam_auth": self._config.driver_features.get("alloydb_enable_iam_auth", False),
"ip_type": self._config.driver_features.get("alloydb_ip_type", "PRIVATE"),
}
if self._user:
conn_kwargs["user"] = self._user
if self._password:
conn_kwargs["password"] = self._password
if self._database:
conn_kwargs["db"] = self._database
return cast("AsyncpgConnection", await connector.connect(**conn_kwargs))
class _AsyncpgSessionFactory:
__slots__ = ("_config", "_connection")
def __init__(self, config: "AsyncpgConfig") -> None:
self._config = config
self._connection: AsyncpgConnection | None = None
async def acquire_connection(self) -> "AsyncpgConnection":
pool = self._config.connection_instance
if pool is None:
pool = await self._config.create_pool()
self._config.connection_instance = pool
self._connection = await pool.acquire()
return self._connection
async def release_connection(self, _conn: "AsyncpgConnection") -> None:
if self._connection is not None and self._config.connection_instance is not None:
await self._config.connection_instance.release(self._connection) # type: ignore[arg-type]
self._connection = None
class AsyncpgConnectionContext:
"""Async context manager for AsyncPG connections."""
__slots__ = ("_config", "_connection")
def __init__(self, config: "AsyncpgConfig") -> None:
self._config = config
self._connection: AsyncpgConnection | None = None
async def __aenter__(self) -> "AsyncpgConnection":
pool = self._config.connection_instance
if pool is None:
pool = await self._config.create_pool()
self._config.connection_instance = pool
self._connection = await pool.acquire()
return self._connection
async def __aexit__(
self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any
) -> bool | None:
if self._connection is not None:
if self._config.connection_instance:
await self._config.connection_instance.release(self._connection) # type: ignore[arg-type]
self._connection = None
return None
[docs]
@mypyc_attr(native_class=False)
class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", AsyncpgDriver]):
"""Configuration for AsyncPG database connections using TypedDict."""
driver_type: "ClassVar[type[AsyncpgDriver]]" = AsyncpgDriver
connection_type: "ClassVar[type[AsyncpgConnection]]" = type(AsyncpgConnection) # type: ignore[assignment]
supports_transactional_ddl: "ClassVar[bool]" = True
supports_native_arrow_export: "ClassVar[bool]" = True
supports_native_arrow_import: "ClassVar[bool]" = True
supports_native_parquet_export: "ClassVar[bool]" = True
supports_native_parquet_import: "ClassVar[bool]" = True
[docs]
def __init__(
self,
*,
connection_config: "AsyncpgPoolConfig | dict[str, Any] | None" = None,
connection_instance: "Pool[Record] | None" = None,
migration_config: "dict[str, Any] | None" = None,
statement_config: "StatementConfig | None" = None,
driver_features: "AsyncpgDriverFeatures | dict[str, Any] | None" = None,
bind_key: "str | None" = None,
extension_config: "ExtensionConfigs | None" = None,
observability_config: "ObservabilityConfig | None" = None,
**kwargs: Any,
) -> None:
"""Initialize AsyncPG configuration.
Args:
connection_config: Connection and pool configuration parameters (TypedDict or dict)
connection_instance: Existing pool instance to use
migration_config: Migration configuration
statement_config: Statement configuration override
driver_features: Driver features configuration (TypedDict or dict)
bind_key: Optional unique identifier for this configuration
extension_config: Extension-specific configuration (e.g., Litestar plugin settings)
observability_config: Adapter-level observability overrides for lifecycle hooks and observers
**kwargs: Additional keyword arguments
"""
statement_config = statement_config or default_statement_config
statement_config, driver_features = apply_driver_features(statement_config, driver_features)
# Extract user connection hook before storing driver_features
features_dict = dict(driver_features) if driver_features else {}
self._user_connection_hook: Callable[[AsyncpgConnection], Awaitable[None]] | None = features_dict.pop(
"on_connection_create", None
)
super().__init__(
connection_config=normalize_connection_config(connection_config),
connection_instance=connection_instance,
migration_config=migration_config,
statement_config=statement_config,
driver_features=features_dict,
bind_key=bind_key,
extension_config=extension_config,
observability_config=observability_config,
**kwargs,
)
self._cloud_sql_connector: Any | None = None
self._alloydb_connector: Any | None = None
self._pgvector_available: bool | None = None
self._paradedb_available: bool | None = None
self._validate_connector_config()
[docs]
def get_cloud_sql_connector(self) -> Any | None:
"""Return the configured Cloud SQL connector instance."""
return self._cloud_sql_connector
[docs]
def get_alloydb_connector(self) -> Any | None:
"""Return the configured AlloyDB connector instance."""
return self._alloydb_connector
def _validate_connector_config(self) -> None:
"""Validate Google Cloud connector configuration.
Raises:
ImproperConfigurationError: If configuration is invalid.
MissingDependencyError: If required connector packages are not installed.
"""
enable_cloud_sql = self.driver_features.get("enable_cloud_sql", False)
enable_alloydb = self.driver_features.get("enable_alloydb", False)
match (enable_cloud_sql, enable_alloydb):
case (True, True):
msg = (
"Cannot enable both Cloud SQL and AlloyDB connectors simultaneously. "
"Use separate configs for each database."
)
raise ImproperConfigurationError(msg)
case (False, False):
return
case (True, False):
if not CLOUD_SQL_CONNECTOR_INSTALLED:
raise MissingDependencyError(package="cloud-sql-python-connector", install_package="cloud-sql")
instance = self.driver_features.get("cloud_sql_instance")
if not instance:
msg = "cloud_sql_instance required when enable_cloud_sql is True. Format: 'project:region:instance'"
raise ImproperConfigurationError(msg)
cloud_sql_instance_parts_expected = 2
if instance.count(":") != cloud_sql_instance_parts_expected:
msg = f"Invalid Cloud SQL instance format: {instance}. Expected format: 'project:region:instance'"
raise ImproperConfigurationError(msg)
case (False, True):
if not ALLOYDB_CONNECTOR_INSTALLED:
raise MissingDependencyError(
package="google-cloud-alloydb-connector", install_package="google-cloud-alloydb-connector"
)
instance_uri = self.driver_features.get("alloydb_instance_uri")
if not instance_uri:
msg = "alloydb_instance_uri required when enable_alloydb is True. Format: 'projects/PROJECT/locations/REGION/clusters/CLUSTER/instances/INSTANCE'"
raise ImproperConfigurationError(msg)
if not instance_uri.startswith("projects/"):
msg = f"Invalid AlloyDB instance URI format: {instance_uri}. Expected format: 'projects/PROJECT/locations/REGION/clusters/CLUSTER/instances/INSTANCE'"
raise ImproperConfigurationError(msg)
def _setup_cloud_sql_connector(self, config: "dict[str, Any]") -> None:
"""Setup Cloud SQL connector and configure pool for connection factory pattern.
Args:
config: Pool configuration dictionary to modify in-place.
"""
from google.cloud.sql.connector import Connector # type: ignore[import-untyped,unused-ignore]
self._cloud_sql_connector = Connector()
user = config.get("user")
password = config.get("password")
database = config.get("database")
for key in ("dsn", "host", "port", "user", "password", "database"):
config.pop(key, None)
config["connect"] = _AsyncpgCloudSqlConnector(self, user, password, database)
def _setup_alloydb_connector(self, config: "dict[str, Any]") -> None:
"""Setup AlloyDB connector and configure pool for connection factory pattern.
Args:
config: Pool configuration dictionary to modify in-place.
"""
from google.cloud.alloydb.connector import AsyncConnector # type: ignore[import-untyped,unused-ignore]
self._alloydb_connector = AsyncConnector()
user = config.get("user")
password = config.get("password")
database = config.get("database")
for key in ("dsn", "host", "port", "user", "password", "database"):
config.pop(key, None)
config["connect"] = _AsyncpgAlloydbConnector(self, user, password, database)
async def _create_pool(self) -> "Pool[Record]":
"""Create the actual async connection pool."""
config = build_connection_config(self.connection_config)
if self.driver_features.get("enable_cloud_sql", False):
self._setup_cloud_sql_connector(config)
elif self.driver_features.get("enable_alloydb", False):
self._setup_alloydb_connector(config)
config.setdefault("init", self._init_connection)
return await asyncpg_create_pool(**config)
async def _init_connection(self, connection: "AsyncpgConnection") -> None:
"""Initialize connection with JSON codecs, pgvector support, and user callback.
Args:
connection: AsyncPG connection to initialize.
"""
if self.driver_features.get("enable_json_codecs", True):
await register_json_codecs(
connection,
encoder=self.driver_features.get("json_serializer", to_json),
decoder=self.driver_features.get("json_deserializer", from_json),
)
# Detect extensions on first connection, update dialect
if self._pgvector_available is None:
extensions = [
name
for name, enabled in [
("vector", self.driver_features.get("enable_pgvector", False)),
("pg_search", self.driver_features.get("enable_paradedb", False)),
]
if enabled
]
if extensions:
try:
results = await connection.fetch(
"SELECT extname FROM pg_extension WHERE extname = ANY($1::text[])", extensions
)
detected = {r["extname"] for r in results}
self._pgvector_available = "vector" in detected
self._paradedb_available = "pg_search" in detected
except Exception:
self._pgvector_available = False
self._paradedb_available = False
else:
self._pgvector_available = False
self._paradedb_available = False
self._update_dialect_for_extensions()
if self._pgvector_available:
await register_pgvector_support(connection)
# Call user-provided callback after internal setup
if self._user_connection_hook is not None:
await self._user_connection_hook(connection)
def _update_dialect_for_extensions(self) -> None:
"""Update statement_config dialect based on detected extensions.
Priority: paradedb > pgvector > postgres (default).
"""
current_dialect = getattr(self.statement_config, "dialect", "postgres")
if current_dialect != "postgres":
return
if self._paradedb_available:
self.statement_config = self.statement_config.replace(dialect="paradedb")
elif self._pgvector_available:
self.statement_config = self.statement_config.replace(dialect="pgvector")
async def _close_pool(self) -> None:
"""Close the actual async connection pool and cleanup connectors."""
if self.connection_instance:
await self.connection_instance.close()
self.connection_instance = None
if self._cloud_sql_connector is not None:
await self._cloud_sql_connector.close_async()
self._cloud_sql_connector = None
if self._alloydb_connector is not None:
await self._alloydb_connector.close()
self._alloydb_connector = None
[docs]
async def close_pool(self) -> None:
"""Close the connection pool."""
await self._close_pool()
[docs]
async def create_connection(self) -> "AsyncpgConnection":
"""Create a single async connection from the pool.
Returns:
An AsyncPG connection instance.
"""
pool = self.connection_instance
if pool is None:
pool = await self.create_pool()
self.connection_instance = pool
return await pool.acquire()
[docs]
def provide_connection(self, *args: Any, **kwargs: Any) -> "AsyncpgConnectionContext":
"""Provide an async connection context manager.
Args:
*args: Additional arguments.
**kwargs: Additional keyword arguments.
Returns:
An AsyncPG connection context manager.
"""
return AsyncpgConnectionContext(self)
[docs]
def provide_session(
self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any
) -> "AsyncpgSessionContext":
"""Provide an async driver session context manager.
Args:
*_args: Additional arguments.
statement_config: Optional statement configuration override.
**_kwargs: Additional keyword arguments.
Returns:
An AsyncPG driver session context manager.
"""
factory = _AsyncpgSessionFactory(self)
return AsyncpgSessionContext(
acquire_connection=factory.acquire_connection,
release_connection=factory.release_connection,
statement_config=statement_config or (lambda: self.statement_config or default_statement_config),
driver_features=self.driver_features,
prepare_driver=self._prepare_driver,
)
[docs]
async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool[Record]":
"""Provide async pool instance.
Returns:
The async connection pool.
"""
if not self.connection_instance:
self.connection_instance = await self.create_pool()
return self.connection_instance
[docs]
def get_signature_namespace(self) -> "dict[str, Any]":
"""Get the signature namespace for AsyncPG types.
This provides all AsyncPG-specific types that Litestar needs to recognize
to avoid serialization attempts.
Returns:
Dictionary mapping type names to types.
"""
namespace = super().get_signature_namespace()
namespace.update({
"Connection": Connection,
"Pool": Pool,
"PoolConnectionProxy": PoolConnectionProxy,
"PoolConnectionProxyMeta": PoolConnectionProxyMeta,
"ConnectionMeta": ConnectionMeta,
"Record": Record,
"AsyncpgConnection": AsyncpgConnection,
"AsyncpgConnectionConfig": AsyncpgConnectionConfig,
"AsyncpgConnectionContext": AsyncpgConnectionContext,
"AsyncpgCursor": AsyncpgCursor,
"AsyncpgDriver": AsyncpgDriver,
"AsyncpgExceptionHandler": AsyncpgExceptionHandler,
"AsyncpgPool": AsyncpgPool,
"AsyncpgPoolConfig": AsyncpgPoolConfig,
"AsyncpgPreparedStatement": AsyncpgPreparedStatement,
"AsyncpgSessionContext": AsyncpgSessionContext,
})
return namespace
[docs]
def get_event_runtime_hints(self) -> "EventRuntimeHints":
"""Return polling defaults for PostgreSQL queue fallback."""
return EventRuntimeHints(poll_interval=0.5, select_for_update=True, skip_locked=True)