"""AsyncPG database configuration with direct field-based configuration."""
import logging
from collections.abc import Callable
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict
from asyncpg import Connection, Record
from asyncpg import create_pool as asyncpg_create_pool
from asyncpg.connection import ConnectionMeta
from asyncpg.pool import Pool, PoolConnectionProxy, PoolConnectionProxyMeta
from typing_extensions import NotRequired
from sqlspec.adapters.asyncpg._type_handlers import register_json_codecs, register_pgvector_support
from sqlspec.adapters.asyncpg._types import AsyncpgConnection, AsyncpgPool, AsyncpgPreparedStatement
from sqlspec.adapters.asyncpg.driver import (
AsyncpgCursor,
AsyncpgDriver,
AsyncpgExceptionHandler,
asyncpg_statement_config,
build_asyncpg_statement_config,
)
from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.typing import ALLOYDB_CONNECTOR_INSTALLED, CLOUD_SQL_CONNECTOR_INSTALLED, PGVECTOR_INSTALLED
from sqlspec.utils.serializers import from_json, to_json
if TYPE_CHECKING:
from asyncio.events import AbstractEventLoop
from collections.abc import AsyncGenerator, Awaitable
from sqlspec.core import StatementConfig
from sqlspec.observability import ObservabilityConfig
# Public names exported by a star-import of this module.
__all__ = (
    "AsyncpgConfig",
    "AsyncpgConnectionConfig",
    "AsyncpgDriverFeatures",
    "AsyncpgPoolConfig",
)

# Module logger, registered under the "sqlspec" name.
logger = logging.getLogger("sqlspec")
class AsyncpgConnectionConfig(TypedDict):
    """TypedDict for AsyncPG connection parameters.

    Every key is optional (``NotRequired``). ``AsyncpgConfig`` copies these
    values into a plain dict, drops entries whose value is ``None`` and passes
    the remainder as keyword arguments to asyncpg's ``create_pool``.
    """

    # Connection target and credentials. When a Google Cloud SQL or AlloyDB
    # connector is enabled, AsyncpgConfig removes all six of these keys from
    # the pool kwargs; ``user``, ``password`` and ``database`` are handed to
    # the connector instead.
    dsn: NotRequired[str]
    host: NotRequired[str]
    port: NotRequired[int]
    user: NotRequired[str]
    password: NotRequired[str]
    database: NotRequired[str]
    # Transport / authentication options.
    ssl: NotRequired[Any]
    passfile: NotRequired[str]
    direct_tls: NotRequired[bool]
    # Timeouts.
    # NOTE(review): asyncpg.connect() appears to call its connect timeout
    # ``timeout`` — confirm ``connect_timeout`` is accepted or translated
    # before it reaches asyncpg; nothing in this file renames it.
    connect_timeout: NotRequired[float]
    command_timeout: NotRequired[float]
    # Prepared-statement cache tuning.
    statement_cache_size: NotRequired[int]
    max_cached_statement_lifetime: NotRequired[int]
    max_cacheable_statement_size: NotRequired[int]
    # Presumably server run-time parameters applied on connect — verify
    # against the asyncpg documentation.
    server_settings: NotRequired[dict[str, str]]
[docs]
class AsyncpgPoolConfig(AsyncpgConnectionConfig):
    """TypedDict for AsyncPG pool parameters, inheriting connection parameters.

    All keys are optional. ``AsyncpgConfig`` forwards them, together with the
    inherited connection keys, as keyword arguments to asyncpg's
    ``create_pool``.
    """

    # Pool sizing and connection recycling.
    min_size: NotRequired[int]
    max_size: NotRequired[int]
    max_queries: NotRequired[int]
    max_inactive_connection_lifetime: NotRequired[float]
    # Per-connection hooks. When ``init`` is absent, AsyncpgConfig installs its
    # own ``_init_connection`` (JSON codec / pgvector registration) in its
    # place; a user-supplied ``init`` is passed through unchanged.
    setup: NotRequired["Callable[[AsyncpgConnection], Awaitable[None]]"]
    init: NotRequired["Callable[[AsyncpgConnection], Awaitable[None]]"]
    loop: NotRequired["AbstractEventLoop"]
    connection_class: NotRequired[type["AsyncpgConnection"]]
    record_class: NotRequired[type[Record]]
    # Free-form additional keyword arguments. AsyncpgConfig merges this dict
    # into the top-level pool kwargs, so on a key clash the ``extra`` value wins.
    extra: NotRequired[dict[str, Any]]
class AsyncpgDriverFeatures(TypedDict):
    """AsyncPG driver feature flags.

    All keys are optional; ``AsyncpgConfig.__init__`` fills in the defaults
    noted below for any key that is not supplied.

    json_serializer: Custom JSON serializer function for PostgreSQL JSON/JSONB types.
        Defaults to sqlspec.utils.serializers.to_json.
        Use for performance optimization (e.g., orjson) or custom encoding behavior.
        Applied when enable_json_codecs is True.
    json_deserializer: Custom JSON deserializer function for PostgreSQL JSON/JSONB types.
        Defaults to sqlspec.utils.serializers.from_json.
        Use for performance optimization (e.g., orjson) or custom decoding behavior.
        Applied when enable_json_codecs is True.
    enable_json_codecs: Enable automatic JSON/JSONB codec registration on connections.
        Defaults to True for seamless Python dict/list to PostgreSQL JSON/JSONB conversion.
        Set to False to disable automatic codec registration (manual handling required).
    enable_pgvector: Enable pgvector extension support for vector similarity search.
        Requires pgvector-python package (pip install pgvector) and PostgreSQL with pgvector extension.
        Defaults to True when pgvector-python is installed.
        Provides automatic conversion between Python objects and PostgreSQL vector types.
        Enables vector similarity operations and index support.
    enable_cloud_sql: Enable Google Cloud SQL connector integration.
        Requires cloud-sql-python-connector package.
        Defaults to False (explicit opt-in required).
        Auto-configures IAM authentication, SSL, and IP routing.
        Mutually exclusive with enable_alloydb.
    cloud_sql_instance: Cloud SQL instance connection name.
        Format: "project:region:instance"
        Required when enable_cloud_sql is True.
    cloud_sql_enable_iam_auth: Enable IAM database authentication.
        Defaults to False. Set to True for passwordless (IAM) authentication.
        When False, requires user/password in pool_config.
    cloud_sql_ip_type: IP address type for connection.
        Options: "PUBLIC", "PRIVATE", "PSC"
        Defaults to "PRIVATE".
    enable_alloydb: Enable Google AlloyDB connector integration.
        Requires cloud-alloydb-python-connector package.
        Defaults to False (explicit opt-in required).
        Auto-configures IAM authentication and private networking.
        Mutually exclusive with enable_cloud_sql.
    alloydb_instance_uri: AlloyDB instance URI.
        Format: "projects/PROJECT/locations/REGION/clusters/CLUSTER/instances/INSTANCE"
        Required when enable_alloydb is True.
    alloydb_enable_iam_auth: Enable IAM database authentication.
        Defaults to False. Set to True for passwordless (IAM) authentication.
    alloydb_ip_type: IP address type for connection.
        Options: "PUBLIC", "PRIVATE", "PSC"
        Defaults to "PRIVATE".
    """

    # JSON codec behaviour.
    json_serializer: NotRequired[Callable[[Any], str]]
    json_deserializer: NotRequired[Callable[[str], Any]]
    enable_json_codecs: NotRequired[bool]
    # pgvector type support.
    enable_pgvector: NotRequired[bool]
    # Google Cloud SQL connector options.
    enable_cloud_sql: NotRequired[bool]
    cloud_sql_instance: NotRequired[str]
    cloud_sql_enable_iam_auth: NotRequired[bool]
    cloud_sql_ip_type: NotRequired[str]
    # Google AlloyDB connector options.
    enable_alloydb: NotRequired[bool]
    alloydb_instance_uri: NotRequired[str]
    alloydb_enable_iam_auth: NotRequired[bool]
    alloydb_ip_type: NotRequired[str]
[docs]
class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", AsyncpgDriver]):
    """Configuration for AsyncPG database connections using TypedDict.

    Holds the pool settings and driver feature flags, lazily creates the
    asyncpg pool, and hands out connections and ``AsyncpgDriver`` sessions
    backed by that pool. Connections can optionally be opened through the
    Google Cloud SQL or AlloyDB connector instead of a direct DSN/host.
    """

    # Driver class instantiated for each session.
    driver_type: "ClassVar[type[AsyncpgDriver]]" = AsyncpgDriver
    # NOTE(review): this stores ``type(AsyncpgConnection)`` — the type *of* the
    # AsyncpgConnection object rather than AsyncpgConnection itself. Confirm
    # that is intended (hence the type: ignore).
    connection_type: "ClassVar[type[AsyncpgConnection]]" = type(AsyncpgConnection) # type: ignore[assignment]
    # Capability flags; presumably read by the base class or other sqlspec
    # components — their consumers are not visible in this file.
    supports_transactional_ddl: "ClassVar[bool]" = True
    supports_native_arrow_export: "ClassVar[bool]" = True
    supports_native_arrow_import: "ClassVar[bool]" = True
    supports_native_parquet_export: "ClassVar[bool]" = True
    supports_native_parquet_import: "ClassVar[bool]" = True
[docs]
def __init__(
    self,
    *,
    pool_config: "AsyncpgPoolConfig | dict[str, Any] | None" = None,
    pool_instance: "Pool[Record] | None" = None,
    migration_config: "dict[str, Any] | None" = None,
    statement_config: "StatementConfig | None" = None,
    driver_features: "AsyncpgDriverFeatures | dict[str, Any] | None" = None,
    bind_key: "str | None" = None,
    extension_config: "ExtensionConfigs | None" = None,
    observability_config: "ObservabilityConfig | None" = None,
) -> None:
    """Set up an AsyncPG configuration.

    The supplied ``driver_features`` and ``pool_config`` are copied, so the
    caller's mappings are never mutated. Missing feature flags receive their
    defaults, and the Google connector settings are validated last.

    Args:
        pool_config: Pool configuration parameters (TypedDict or dict).
        pool_instance: Existing pool instance to use.
        migration_config: Migration configuration.
        statement_config: Statement configuration override. When falsy, one
            is built from the resolved JSON serializer/deserializer.
        driver_features: Driver features configuration (TypedDict or dict).
        bind_key: Optional unique identifier for this configuration.
        extension_config: Extension-specific configuration (e.g., Litestar plugin settings).
        observability_config: Adapter-level observability overrides for lifecycle hooks and observers.

    Raises:
        ImproperConfigurationError: If the Cloud SQL / AlloyDB connector
            settings are invalid.
    """
    resolved_features: dict[str, Any] = {}
    if driver_features:
        resolved_features.update(driver_features)

    # Fill in any feature the caller left out; explicit values always win.
    for feature_name, fallback in (
        ("json_serializer", to_json),
        ("json_deserializer", from_json),
        ("enable_json_codecs", True),
        ("enable_pgvector", PGVECTOR_INSTALLED),
        ("enable_cloud_sql", False),
        ("enable_alloydb", False),
    ):
        resolved_features.setdefault(feature_name, fallback)

    if statement_config:
        effective_statement_config = statement_config
    else:
        effective_statement_config = build_asyncpg_statement_config(
            json_serializer=resolved_features["json_serializer"],
            json_deserializer=resolved_features["json_deserializer"],
        )

    pool_settings: dict[str, Any] = dict(pool_config) if pool_config else {}

    super().__init__(
        pool_config=pool_settings,
        pool_instance=pool_instance,
        migration_config=migration_config,
        statement_config=effective_statement_config,
        driver_features=resolved_features,
        bind_key=bind_key,
        extension_config=extension_config,
        observability_config=observability_config,
    )

    # Connector handles are created lazily when the pool is built.
    self._cloud_sql_connector: Any | None = None
    self._alloydb_connector: Any | None = None

    self._validate_connector_config()
def _validate_connector_config(self) -> None:
"""Validate Google Cloud connector configuration.
Raises:
ImproperConfigurationError: If configuration is invalid.
"""
enable_cloud_sql = self.driver_features.get("enable_cloud_sql", False)
enable_alloydb = self.driver_features.get("enable_alloydb", False)
if enable_cloud_sql and enable_alloydb:
msg = "Cannot enable both Cloud SQL and AlloyDB connectors simultaneously. Use separate configs for each database."
raise ImproperConfigurationError(msg)
if enable_cloud_sql:
if not CLOUD_SQL_CONNECTOR_INSTALLED:
msg = "cloud-sql-python-connector package not installed. Install with: pip install cloud-sql-python-connector"
raise ImproperConfigurationError(msg)
instance = self.driver_features.get("cloud_sql_instance")
if not instance:
msg = "cloud_sql_instance required when enable_cloud_sql is True. Format: 'project:region:instance'"
raise ImproperConfigurationError(msg)
cloud_sql_instance_parts_expected = 2
if instance.count(":") != cloud_sql_instance_parts_expected:
msg = f"Invalid Cloud SQL instance format: {instance}. Expected format: 'project:region:instance'"
raise ImproperConfigurationError(msg)
elif enable_alloydb:
if not ALLOYDB_CONNECTOR_INSTALLED:
msg = "cloud-alloydb-python-connector package not installed. Install with: pip install cloud-alloydb-python-connector"
raise ImproperConfigurationError(msg)
instance_uri = self.driver_features.get("alloydb_instance_uri")
if not instance_uri:
msg = "alloydb_instance_uri required when enable_alloydb is True. Format: 'projects/PROJECT/locations/REGION/clusters/CLUSTER/instances/INSTANCE'"
raise ImproperConfigurationError(msg)
if not instance_uri.startswith("projects/"):
msg = f"Invalid AlloyDB instance URI format: {instance_uri}. Expected format: 'projects/PROJECT/locations/REGION/clusters/CLUSTER/instances/INSTANCE'"
raise ImproperConfigurationError(msg)
def _get_pool_config_dict(self) -> "dict[str, Any]":
"""Get pool configuration as plain dict for external library.
Returns:
Dictionary with pool parameters, filtering out None values.
"""
config: dict[str, Any] = dict(self.pool_config)
extras = config.pop("extra", {})
config.update(extras)
return {k: v for k, v in config.items() if v is not None}
def _setup_cloud_sql_connector(self, config: "dict[str, Any]") -> None:
    """Route pool connections through the Google Cloud SQL connector.

    Creates a ``Connector`` and stores it on ``self._cloud_sql_connector``,
    removes the direct-connection keys from ``config`` and installs a
    ``connect`` factory that opens each connection via the connector.

    Args:
        config: Pool configuration dictionary to modify in-place.
    """
    from google.cloud.sql.connector import Connector  # type: ignore[import-untyped,unused-ignore]

    self._cloud_sql_connector = Connector()

    # Snapshot the credentials now: the keys are removed from ``config`` below.
    # The connector calls the database name "db".
    credentials: dict[str, Any] = {
        "user": config.get("user"),
        "password": config.get("password"),
        "db": config.get("database"),
    }

    async def get_conn() -> "AsyncpgConnection":
        connect_options: dict[str, Any] = {
            "instance_connection_string": self.driver_features["cloud_sql_instance"],
            "driver": "asyncpg",
            "enable_iam_auth": self.driver_features.get("cloud_sql_enable_iam_auth", False),
            "ip_type": self.driver_features.get("cloud_sql_ip_type", "PRIVATE"),
        }
        # Forward only the credentials that were actually provided.
        connect_options.update({name: value for name, value in credentials.items() if value})
        connection: AsyncpgConnection = await self._cloud_sql_connector.connect_async(**connect_options)  # type: ignore[union-attr]
        return connection

    # The factory supersedes every direct-connection parameter.
    for direct_key in ("dsn", "host", "port", "user", "password", "database"):
        config.pop(direct_key, None)
    config["connect"] = get_conn
def _setup_alloydb_connector(self, config: "dict[str, Any]") -> None:
    """Route pool connections through the Google AlloyDB connector.

    Creates an ``AsyncConnector`` and stores it on ``self._alloydb_connector``,
    removes the direct-connection keys from ``config`` and installs a
    ``connect`` factory that opens each connection via the connector.

    Args:
        config: Pool configuration dictionary to modify in-place.
    """
    from google.cloud.alloydb.connector import AsyncConnector  # type: ignore[import-untyped,unused-ignore]

    self._alloydb_connector = AsyncConnector()

    # Snapshot the credentials now: the keys are removed from ``config`` below.
    # The connector calls the database name "db".
    credentials: dict[str, Any] = {
        "user": config.get("user"),
        "password": config.get("password"),
        "db": config.get("database"),
    }

    async def get_conn() -> "AsyncpgConnection":
        connect_options: dict[str, Any] = {
            "instance_uri": self.driver_features["alloydb_instance_uri"],
            "driver": "asyncpg",
            "enable_iam_auth": self.driver_features.get("alloydb_enable_iam_auth", False),
            "ip_type": self.driver_features.get("alloydb_ip_type", "PRIVATE"),
        }
        # Forward only the credentials that were actually provided.
        connect_options.update({name: value for name, value in credentials.items() if value})
        connection: AsyncpgConnection = await self._alloydb_connector.connect(**connect_options)  # type: ignore[union-attr]
        return connection

    # The factory supersedes every direct-connection parameter.
    for direct_key in ("dsn", "host", "port", "user", "password", "database"):
        config.pop(direct_key, None)
    config["connect"] = get_conn
async def _create_pool(self) -> "Pool[Record]":
    """Build the asyncpg pool from the stored configuration.

    When a Google connector is enabled, the pool kwargs are rewritten to use
    its connection factory (Cloud SQL is checked first). Unless the caller
    supplied an ``init`` hook, ``self._init_connection`` is installed.

    Returns:
        The newly created asyncpg pool.
    """
    pool_kwargs = self._get_pool_config_dict()

    features = self.driver_features
    if features.get("enable_cloud_sql", False):
        self._setup_cloud_sql_connector(pool_kwargs)
    elif features.get("enable_alloydb", False):
        self._setup_alloydb_connector(pool_kwargs)

    # A user-provided ``init`` takes precedence over the built-in initializer.
    if "init" not in pool_kwargs:
        pool_kwargs["init"] = self._init_connection

    pool = await asyncpg_create_pool(**pool_kwargs)
    return pool
async def _init_connection(self, connection: "AsyncpgConnection") -> None:
"""Initialize connection with JSON codecs and pgvector support.
Args:
connection: AsyncPG connection to initialize.
"""
if self.driver_features.get("enable_json_codecs", True):
await register_json_codecs(
connection,
encoder=self.driver_features.get("json_serializer", to_json),
decoder=self.driver_features.get("json_deserializer", from_json),
)
if self.driver_features.get("enable_pgvector", False):
await register_pgvector_support(connection)
async def _close_pool(self) -> None:
"""Close the actual async connection pool and cleanup connectors."""
if self.pool_instance:
await self.pool_instance.close()
if self._cloud_sql_connector is not None:
await self._cloud_sql_connector.close_async()
self._cloud_sql_connector = None
if self._alloydb_connector is not None:
await self._alloydb_connector.close()
self._alloydb_connector = None
[docs]
async def close_pool(self) -> None:
    """Shut down the connection pool.

    Public entry point that delegates to ``_close_pool``.
    """
    shutdown = self._close_pool()
    await shutdown
[docs]
async def create_connection(self) -> "AsyncpgConnection":
    """Acquire a single connection from the pool.

    The pool is created on first use. The connection is acquired but not
    released by this method.

    Returns:
        An AsyncPG connection instance.
    """
    pool = self.pool_instance
    if pool is None:
        pool = await self._create_pool()
        self.pool_instance = pool
    return await pool.acquire()
[docs]
@asynccontextmanager
async def provide_connection(self, *args: Any, **kwargs: Any) -> "AsyncGenerator[AsyncpgConnection, None]":
    """Yield a pooled connection and release it on exit.

    The pool is created on first use. The connection is returned to the
    pool when the context exits, including when the body raises.

    Args:
        *args: Additional arguments (not used by this method).
        **kwargs: Additional keyword arguments (not used by this method).

    Yields:
        An AsyncPG connection instance.
    """
    if self.pool_instance is None:
        self.pool_instance = await self._create_pool()

    # If acquire() itself fails there is nothing to release, so it sits
    # outside the try block.
    acquired = await self.pool_instance.acquire()
    try:
        yield acquired
    finally:
        if acquired is not None:
            await self.pool_instance.release(acquired)
[docs]
@asynccontextmanager
async def provide_session(
    self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any
) -> "AsyncGenerator[AsyncpgDriver, None]":
    """Yield a driver session bound to a pooled connection.

    The statement configuration is resolved in this order: the explicit
    ``statement_config`` argument, then ``self.statement_config``, then the
    module-level ``asyncpg_statement_config``.

    Args:
        *args: Positional arguments forwarded to ``provide_connection``.
        statement_config: Optional statement configuration override.
        **kwargs: Keyword arguments forwarded to ``provide_connection``.

    Yields:
        An AsyncpgDriver instance.
    """
    async with self.provide_connection(*args, **kwargs) as connection:
        session_statement_config = statement_config
        if not session_statement_config:
            session_statement_config = self.statement_config or asyncpg_statement_config

        session_driver = self.driver_type(
            connection=connection,
            statement_config=session_statement_config,
            driver_features=self.driver_features,
        )
        prepared_driver = self._prepare_driver(session_driver)
        yield prepared_driver
[docs]
async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool[Record]":
    """Return the connection pool, creating it on first use.

    Args:
        *args: Additional arguments (not used by this method).
        **kwargs: Additional keyword arguments (not used by this method).

    Returns:
        The async connection pool.
    """
    if self.pool_instance:
        return self.pool_instance
    self.pool_instance = await self.create_pool()
    return self.pool_instance
[docs]
def get_signature_namespace(self) -> "dict[str, Any]":
    """Build the signature namespace containing the AsyncPG types.

    Extends the namespace returned by the parent class with the asyncpg
    library types and this adapter's own types, so that Litestar can
    recognize them and avoid serialization attempts.

    Returns:
        Dictionary mapping type names to types.
    """
    namespace = super().get_signature_namespace()

    # Types that come straight from the asyncpg library.
    library_types: dict[str, Any] = {
        "Connection": Connection,
        "Pool": Pool,
        "PoolConnectionProxy": PoolConnectionProxy,
        "PoolConnectionProxyMeta": PoolConnectionProxyMeta,
        "ConnectionMeta": ConnectionMeta,
        "Record": Record,
    }
    # Types defined by the sqlspec asyncpg adapter.
    adapter_types: dict[str, Any] = {
        "AsyncpgConnection": AsyncpgConnection,
        "AsyncpgConnectionConfig": AsyncpgConnectionConfig,
        "AsyncpgCursor": AsyncpgCursor,
        "AsyncpgDriver": AsyncpgDriver,
        "AsyncpgExceptionHandler": AsyncpgExceptionHandler,
        "AsyncpgPool": AsyncpgPool,
        "AsyncpgPoolConfig": AsyncpgPoolConfig,
        "AsyncpgPreparedStatement": AsyncpgPreparedStatement,
    }

    namespace.update(library_types)
    namespace.update(adapter_types)
    return namespace