"""AioSQL adapter implementation for SQLSpec.
This module provides adapter classes that implement the aiosql adapter protocols
while using SQLSpec drivers under the hood. This enables users to load SQL queries
from files using aiosql while using SQLSpec's features for execution and type mapping.
"""
import logging
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, ClassVar, Generic, TypeVar
from sqlspec.core import SQL, SQLResult, StatementConfig
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
from sqlspec.typing import AiosqlAsyncProtocol, AiosqlParamType, AiosqlSQLOperationType, AiosqlSyncProtocol
from sqlspec.utils.module_loader import ensure_aiosql
logger = logging.getLogger("sqlspec.extensions.aiosql")
__all__ = ("AiosqlAsyncAdapter", "AiosqlSyncAdapter")
T = TypeVar("T")
DriverT = TypeVar("DriverT", bound="SyncDriverAdapterBase | AsyncDriverAdapterBase")
class AsyncCursorLike:
def __init__(self, result: Any) -> None:
self.result = result
async def fetchall(self) -> list[Any]:
if isinstance(self.result, SQLResult) and self.result.data is not None:
return list(self.result.data)
return []
async def fetchone(self) -> Any | None:
rows = await self.fetchall()
return rows[0] if rows else None
class CursorLike:
def __init__(self, result: Any) -> None:
self.result = result
def fetchall(self) -> list[Any]:
if isinstance(self.result, SQLResult) and self.result.data is not None:
return list(self.result.data)
return []
def fetchone(self) -> Any | None:
rows = self.fetchall()
return rows[0] if rows else None
def _normalize_dialect(dialect: "str | Any | None") -> str:
"""Normalize dialect name for SQLGlot compatibility.
Args:
dialect: Original dialect name (can be str, Dialect, type[Dialect], or None)
Returns:
Converted dialect name compatible with SQLGlot
"""
if dialect is None:
return "sql"
if isinstance(dialect, str):
dialect_str = dialect.lower()
elif hasattr(dialect, "__name__"):
dialect_str = str(dialect.__name__).lower()
elif hasattr(dialect, "name"):
dialect_str = str(dialect.name).lower()
else:
dialect_str = str(dialect).lower()
dialect_mapping = {
"postgresql": "postgres",
"psycopg": "postgres",
"asyncpg": "postgres",
"psqlpy": "postgres",
"sqlite3": "sqlite",
"aiosqlite": "sqlite",
}
return dialect_mapping.get(dialect_str, dialect_str)
class _AiosqlAdapterBase(Generic[DriverT]):
"""Base adapter class providing common functionality for aiosql integration."""
def __init__(self, driver: DriverT) -> None:
"""Initialize the base adapter.
Args:
driver: SQLSpec driver to use for execution.
"""
ensure_aiosql()
self.driver: DriverT = driver
def process_sql(self, query_name: str, op_type: "AiosqlSQLOperationType", sql: str) -> str:
"""Process SQL string for aiosql compatibility.
Args:
query_name: Name of the query
op_type: Operation type
sql: SQL string to process
Returns:
Processed SQL string
"""
return sql
def _create_sql_object(self, sql: str, parameters: "AiosqlParamType" = None) -> SQL:
"""Create SQL object with proper configuration.
Args:
sql: SQL string
parameters: Query parameters
Returns:
Configured SQL object
"""
return SQL(
sql,
parameters,
config=StatementConfig(enable_validation=False),
dialect=_normalize_dialect(getattr(self.driver, "dialect", "sqlite")),
)
[docs]
class AiosqlSyncAdapter(_AiosqlAdapterBase[SyncDriverAdapterBase], AiosqlSyncProtocol):
"""Synchronous adapter that implements aiosql protocol using SQLSpec drivers.
This adapter bridges aiosql's synchronous driver protocol with SQLSpec's sync drivers,
enabling queries loaded by aiosql to be executed with SQLSpec drivers.
"""
is_aio_driver: ClassVar[bool] = False
[docs]
def __init__(self, driver: "SyncDriverAdapterBase") -> None:
"""Initialize the sync adapter.
Args:
driver: SQLSpec sync driver to use for execution
"""
super().__init__(driver)
[docs]
def select(
self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType", record_class: Any | None = None
) -> Generator[Any, None, None]:
"""Execute a SELECT query and return results as generator.
Args:
conn: Database connection (passed through to SQLSpec driver)
query_name: Name of the query
sql: SQL string
parameters: Query parameters
record_class: Deprecated - use schema_type in driver.execute instead
Yields:
Query result rows
Note:
The record_class parameter is ignored for compatibility. Use schema_type
in driver.execute or _sqlspec_schema_type in parameters for type mapping.
"""
if record_class is not None:
logger.warning(
"record_class parameter is deprecated and ignored. "
"Use schema_type in driver.execute or _sqlspec_schema_type in parameters."
)
result = self.driver.execute(self._create_sql_object(sql, parameters), connection=conn)
if isinstance(result, SQLResult) and result.data is not None:
yield from result.data
[docs]
def select_one(
self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType", record_class: Any | None = None
) -> tuple[Any, ...] | None:
"""Execute a SELECT query and return first result.
Args:
conn: Database connection
query_name: Name of the query
sql: SQL string
parameters: Query parameters
record_class: Deprecated - use schema_type in driver.execute instead
Returns:
First result row or None
Note:
The record_class parameter is ignored for compatibility. Use schema_type
in driver.execute or _sqlspec_schema_type in parameters for type mapping.
"""
if record_class is not None:
logger.warning(
"record_class parameter is deprecated and ignored. "
"Use schema_type in driver.execute or _sqlspec_schema_type in parameters."
)
result = self.driver.execute(self._create_sql_object(sql, parameters), connection=conn)
if hasattr(result, "data") and result.data and isinstance(result, SQLResult):
row = result.data[0]
return tuple(row.values()) if isinstance(row, dict) else row
return None
[docs]
def select_value(self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType") -> Any | None:
"""Execute a SELECT query and return first value of first row.
Args:
conn: Database connection
query_name: Name of the query
sql: SQL string
parameters: Query parameters
Returns:
First value of first row or None
"""
row = self.select_one(conn, query_name, sql, parameters)
if row is None:
return None
if isinstance(row, dict):
return next(iter(row.values())) if row else None
if hasattr(row, "__getitem__"):
return row[0] if len(row) > 0 else None
return row
[docs]
@contextmanager
def select_cursor(
self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType"
) -> Generator[Any, None, None]:
"""Execute a SELECT query and return cursor context manager.
Args:
conn: Database connection
query_name: Name of the query
sql: SQL string
parameters: Query parameters
Yields:
Cursor-like object with results
"""
result = self.driver.execute(self._create_sql_object(sql, parameters), connection=conn)
yield CursorLike(result)
[docs]
def insert_update_delete(self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType") -> int:
"""Execute INSERT/UPDATE/DELETE and return affected rows.
Args:
conn: Database connection
query_name: Name of the query
sql: SQL string
parameters: Query parameters
Returns:
Number of affected rows
"""
result = self.driver.execute(self._create_sql_object(sql, parameters), connection=conn)
return result.rows_affected if hasattr(result, "rows_affected") else 0
[docs]
def insert_update_delete_many(self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType") -> int:
"""Execute INSERT/UPDATE/DELETE with many parameter sets.
Args:
conn: Database connection
query_name: Name of the query
sql: SQL string
parameters: Sequence of parameter sets
Returns:
Number of affected rows
"""
result = self.driver.execute(self._create_sql_object(sql, parameters), connection=conn)
return result.rows_affected if hasattr(result, "rows_affected") else 0
[docs]
def insert_returning(self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType") -> Any | None:
"""Execute INSERT with RETURNING and return result.
Args:
conn: Database connection
query_name: Name of the query
sql: SQL string
parameters: Query parameters
Returns:
Returned value or None
"""
return self.select_one(conn, query_name, sql, parameters)
class _AsyncCursorContextManager(Generic[T]):
def __init__(self, cursor_like: T) -> None:
self._cursor_like = cursor_like
async def __aenter__(self) -> T:
return self._cursor_like
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
pass
[docs]
class AiosqlAsyncAdapter(_AiosqlAdapterBase[AsyncDriverAdapterBase], AiosqlAsyncProtocol):
"""Asynchronous adapter that implements aiosql protocol using SQLSpec drivers.
This adapter bridges aiosql's async driver protocol with SQLSpec's async drivers,
enabling queries loaded by aiosql to be executed with SQLSpec async drivers.
"""
is_aio_driver: ClassVar[bool] = True
[docs]
def __init__(self, driver: "AsyncDriverAdapterBase") -> None:
"""Initialize the async adapter.
Args:
driver: SQLSpec async driver to use for execution
"""
super().__init__(driver)
[docs]
async def select(
self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType", record_class: Any | None = None
) -> list[Any]:
"""Execute a SELECT query and return results as list.
Args:
conn: Database connection
query_name: Name of the query
sql: SQL string
parameters: Query parameters
record_class: Deprecated - use schema_type in driver.execute instead
Returns:
List of query result rows
Note:
The record_class parameter is ignored for compatibility. Use schema_type
in driver.execute or _sqlspec_schema_type in parameters for type mapping.
"""
if record_class is not None:
logger.warning(
"record_class parameter is deprecated and ignored. "
"Use schema_type in driver.execute or _sqlspec_schema_type in parameters."
)
result = await self.driver.execute(self._create_sql_object(sql, parameters), connection=conn)
if hasattr(result, "data") and result.data is not None and isinstance(result, SQLResult):
return list(result.data)
return []
[docs]
async def select_one(
self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType", record_class: Any | None = None
) -> Any | None:
"""Execute a SELECT query and return first result.
Args:
conn: Database connection
query_name: Name of the query
sql: SQL string
parameters: Query parameters
record_class: Deprecated - use schema_type in driver.execute instead
Returns:
First result row or None
Note:
The record_class parameter is ignored for compatibility. Use schema_type
in driver.execute or _sqlspec_schema_type in parameters for type mapping.
"""
if record_class is not None:
logger.warning(
"record_class parameter is deprecated and ignored. "
"Use schema_type in driver.execute or _sqlspec_schema_type in parameters."
)
result = await self.driver.execute(self._create_sql_object(sql, parameters), connection=conn)
if hasattr(result, "data") and result.data and isinstance(result, SQLResult):
row = result.data[0]
return tuple(row.values()) if isinstance(row, dict) else row
return None
[docs]
async def select_value(self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType") -> Any | None:
"""Execute a SELECT query and return first value of first row.
Args:
conn: Database connection
query_name: Name of the query
sql: SQL string
parameters: Query parameters
Returns:
First value of first row or None
"""
row = await self.select_one(conn, query_name, sql, parameters)
if row is None:
return None
if isinstance(row, dict):
return next(iter(row.values())) if row else None
if hasattr(row, "__getitem__"):
return row[0] if len(row) > 0 else None
return row
[docs]
async def select_cursor(
self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType"
) -> "_AsyncCursorContextManager[Any]":
"""Execute a SELECT query and return cursor context manager.
Args:
conn: Database connection
query_name: Name of the query
sql: SQL string
parameters: Query parameters
Yields:
Cursor-like object with results
"""
result = await self.driver.execute(self._create_sql_object(sql, parameters), connection=conn)
return _AsyncCursorContextManager(AsyncCursorLike(result))
[docs]
async def insert_update_delete(self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType") -> None:
"""Execute INSERT/UPDATE/DELETE.
Args:
conn: Database connection
query_name: Name of the query
sql: SQL string
parameters: Query parameters
Note:
Returns None per aiosql async protocol
"""
await self.driver.execute(self._create_sql_object(sql, parameters), connection=conn)
[docs]
async def insert_update_delete_many(
self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType"
) -> None:
"""Execute INSERT/UPDATE/DELETE with many parameter sets.
Args:
conn: Database connection
query_name: Name of the query
sql: SQL string
parameters: Sequence of parameter sets
Note:
Returns None per aiosql async protocol
"""
await self.driver.execute(self._create_sql_object(sql, parameters), connection=conn)
[docs]
async def insert_returning(self, conn: Any, query_name: str, sql: str, parameters: "AiosqlParamType") -> Any | None:
"""Execute INSERT with RETURNING and return result.
Args:
conn: Database connection
query_name: Name of the query
sql: SQL string
parameters: Query parameters
Returns:
Returned value or None
"""
return await self.select_one(conn, query_name, sql, parameters)