"""Multi-connection pool for aiosqlite."""
import asyncio
import logging
import time
from contextlib import suppress
from inspect import isawaitable
from threading import Thread
from typing import TYPE_CHECKING, Any
import aiosqlite
from sqlspec.exceptions import SQLSpecError
from sqlspec.utils.logging import POOL_LOGGER_NAME, get_logger, log_with_context
from sqlspec.utils.uuids import uuid4
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from types import TracebackType
from sqlspec.adapters.aiosqlite._typing import AiosqliteConnection
# Names exported by ``from ... import *``; everything else is internal.
__all__ = (
    "AiosqliteConnectTimeoutError",
    "AiosqliteConnectionPool",
    "AiosqlitePoolClosedError",
    "AiosqlitePoolConnection",
    "AiosqlitePoolConnectionContext",
)

# Logger shared by the pool classes below.
logger = get_logger(POOL_LOGGER_NAME)
# Adapter label passed as ``adapter=`` on every pool log record.
_ADAPTER_NAME = "aiosqlite"
class AiosqlitePoolClosedError(SQLSpecError):
    """Raised when an operation is attempted on a pool that has already been closed."""
class AiosqliteConnectTimeoutError(SQLSpecError):
    """Raised when a pooled connection cannot be obtained within the configured timeout."""
class AiosqlitePoolConnection:
    """Pool-side bookkeeping wrapper around one raw aiosqlite connection.

    Tracks a unique id, the wall-clock time the connection last became idle
    (``None`` while checked out), and whether it is closed or flagged unhealthy.
    """

    __slots__ = ("_closed", "_healthy", "connection", "id", "idle_since")

    def __init__(self, connection: "AiosqliteConnection") -> None:
        """Wrap a raw connection; it starts open, healthy and not idle.

        Args:
            connection: The raw aiosqlite connection
        """
        self.id = uuid4().hex
        self.connection = connection
        self.idle_since: float | None = None
        self._closed = False
        self._healthy = True

    @property
    def idle_time(self) -> float:
        """Seconds since the connection became idle.

        Returns:
            Elapsed idle seconds, or 0.0 while the connection is checked out
        """
        since = self.idle_since
        return 0.0 if since is None else time.time() - since

    @property
    def is_closed(self) -> bool:
        """Whether ``close()`` has run on this wrapper.

        Returns:
            True once the connection has been closed
        """
        return self._closed

    @property
    def is_healthy(self) -> bool:
        """Whether the connection is open and not flagged unhealthy.

        Returns:
            True if the connection is still presumed usable
        """
        return not self._closed and self._healthy

    def mark_as_in_use(self) -> None:
        """Flag the connection as checked out (clears the idle timestamp)."""
        self.idle_since = None

    def mark_as_idle(self) -> None:
        """Stamp the moment the connection became idle."""
        self.idle_since = time.time()

    def mark_unhealthy(self) -> None:
        """Flag the connection so the pool retires it instead of reusing it."""
        self._healthy = False

    async def is_alive(self) -> bool:
        """Probe the connection with ``SELECT 1`` and record the outcome.

        Returns:
            True if the probe succeeded; False if it raised or the wrapper is closed
        """
        if not self._closed:
            try:
                await self.connection.execute("SELECT 1")
            except Exception:
                alive = False
            else:
                alive = True
        else:
            alive = False
        self._healthy = alive
        return alive

    async def reset(self) -> None:
        """Roll back any open transaction, ignoring errors; no-op once closed."""
        if not self._closed:
            with suppress(Exception):
                await self.connection.rollback()

    async def close(self) -> None:
        """Roll back and close the raw connection exactly once.

        Errors from ``close`` are logged at DEBUG level and swallowed; the
        wrapper is marked closed regardless of the outcome.
        """
        if self._closed:
            return
        try:
            await self.reset()
            await self.connection.close()
        except Exception:
            # Note: No pool context available at connection level
            log_with_context(
                logger, logging.DEBUG, "pool.connection.close.error", adapter=_ADAPTER_NAME, connection_id=self.id
            )
        finally:
            self._closed = True
class AiosqlitePoolConnectionContext:
    """Async context manager that checks a connection out of a pool and hands it back on exit."""

    __slots__ = ("_connection", "_pool")

    def __init__(self, pool: "AiosqliteConnectionPool") -> None:
        """Bind the context manager to a pool; nothing is acquired until entry.

        Args:
            pool: Connection pool instance.
        """
        self._pool = pool
        self._connection: AiosqlitePoolConnection | None = None

    async def __aenter__(self) -> "AiosqliteConnection":
        """Acquire a pooled connection and expose the raw aiosqlite connection."""
        pooled = await self._pool.acquire()
        self._connection = pooled
        return pooled.connection

    async def __aexit__(
        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None"
    ) -> "bool | None":
        """Release the held connection, if any; never suppresses exceptions."""
        pooled = self._connection
        if pooled is not None:
            await self._pool.release(pooled)
            self._connection = None
        return False
class AiosqliteConnectionPool:
    """Multi-connection pool for aiosqlite.

    Every live connection, idle or checked out, is tracked in
    ``_connection_registry``; idle ones additionally sit in an ``asyncio.Queue``.
    ``acquire`` first tries an idle connection, then provisions a new one while
    under ``pool_size``, and only then waits (bounded by ``connect_timeout``) for
    another caller to release one.
    """

    __slots__ = (
        "_closed_event_instance",
        "_connect_timeout",
        "_connection_parameters",
        "_connection_registry",
        "_health_check_interval",
        "_idle_timeout",
        "_lock_instance",
        "_min_size",
        "_on_connection_create",
        "_operation_timeout",
        "_pending_creations",
        "_pool_id",
        "_pool_size",
        "_queue_instance",
        "_wal_initialized",
        "_warmed",
    )

    def __init__(
        self,
        connection_parameters: "dict[str, Any]",
        pool_size: int = 5,
        min_size: int = 0,
        connect_timeout: float = 30.0,
        idle_timeout: float = 24 * 60 * 60,
        operation_timeout: float = 10.0,
        health_check_interval: float = 30.0,
        on_connection_create: "Callable[[AiosqliteConnection], Awaitable[None]] | None" = None,
    ) -> None:
        """Initialize connection pool.

        Args:
            connection_parameters: Keyword arguments passed to ``aiosqlite.connect``
            pool_size: Maximum number of connections in the pool
            min_size: Minimum connections to pre-create (pool warming); capped at ``pool_size``
            connect_timeout: Maximum time ``acquire`` waits for a connection to be
                released when none is idle and the pool is at capacity
            idle_timeout: Maximum time a connection can remain idle
            operation_timeout: Maximum time for connection operations
            health_check_interval: Seconds of idle time before running health check
            on_connection_create: Async callback executed when connection is created
        """
        self._connection_parameters = connection_parameters
        self._pool_size = pool_size
        self._min_size = min(min_size, pool_size)
        self._connect_timeout = connect_timeout
        self._idle_timeout = idle_timeout
        self._operation_timeout = operation_timeout
        self._health_check_interval = health_check_interval
        self._on_connection_create = on_connection_create
        self._connection_registry: dict[str, AiosqlitePoolConnection] = {}
        # Slots reserved for connections that are being created but are not yet
        # registered; counted against pool_size so concurrent creators cannot overshoot it.
        self._pending_creations = 0
        self._wal_initialized = False
        self._warmed = False
        self._pool_id = uuid4().hex[:8]  # Short ID for logging
        self._queue_instance: asyncio.Queue[AiosqlitePoolConnection] | None = None
        self._lock_instance: asyncio.Lock | None = None
        self._closed_event_instance: asyncio.Event | None = None

    @property
    def _queue(self) -> "asyncio.Queue[AiosqlitePoolConnection]":
        """Lazy initialization of asyncio.Queue for Python 3.9 compatibility."""
        if self._queue_instance is None:
            self._queue_instance = asyncio.Queue(maxsize=self._pool_size)
        return self._queue_instance

    @property
    def _lock(self) -> asyncio.Lock:
        """Lazy initialization of asyncio.Lock for Python 3.9 compatibility."""
        if self._lock_instance is None:
            self._lock_instance = asyncio.Lock()
        return self._lock_instance

    @property
    def _closed_event(self) -> asyncio.Event:
        """Lazy initialization of asyncio.Event for Python 3.9 compatibility."""
        if self._closed_event_instance is None:
            self._closed_event_instance = asyncio.Event()
        return self._closed_event_instance

    @property
    def is_closed(self) -> bool:
        """Check if pool is closed.

        Returns:
            True if pool is closed
        """
        return self._closed_event_instance is not None and self._closed_event.is_set()

    @property
    def _database_name(self) -> str:
        """Get sanitized database name for logging."""
        db = self._connection_parameters.get("database", "unknown")
        return str(db).split("/")[-1] if db else "unknown"

    def _set_connect_proxy_daemon(self, connect_proxy: Any) -> None:
        """Set daemon mode on aiosqlite worker thread before await.

        aiosqlite <=0.21 used Connection as a Thread subclass.
        aiosqlite >=0.22 stores an internal ``_thread`` attribute instead.
        """
        try:
            if isinstance(connect_proxy, Thread):
                connect_proxy.daemon = True
                return
            worker_thread = connect_proxy._thread  # pyright: ignore[reportAttributeAccessIssue]
            if isinstance(worker_thread, Thread):
                worker_thread.daemon = True
        except Exception:
            log_with_context(
                logger,
                logging.DEBUG,
                "pool.connection.daemon.configure.error",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                database=self._database_name,
            )

    async def _force_stop_connection(self, connection: AiosqlitePoolConnection, *, reason: str) -> None:
        """Force-stop aiosqlite worker thread when graceful close times out."""
        try:
            stop_method = connection.connection.stop  # pyright: ignore[reportAttributeAccessIssue]
        except Exception:
            log_with_context(
                logger,
                logging.DEBUG,
                "pool.connection.force_stop.unavailable",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                connection_id=connection.id,
                reason=reason,
            )
            return
        try:
            stop_result = stop_method()
            if isawaitable(stop_result):
                await asyncio.wait_for(stop_result, timeout=self._operation_timeout)
            log_with_context(
                logger,
                logging.DEBUG,
                "pool.connection.force_stop.success",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                connection_id=connection.id,
                reason=reason,
            )
        except asyncio.TimeoutError:
            log_with_context(
                logger,
                logging.WARNING,
                "pool.connection.force_stop.timeout",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                connection_id=connection.id,
                timeout_seconds=self._operation_timeout,
                reason=reason,
            )
        except Exception as e:
            log_with_context(
                logger,
                logging.WARNING,
                "pool.connection.force_stop.error",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                connection_id=connection.id,
                reason=reason,
                error=str(e),
            )

    def size(self) -> int:
        """Get total number of connections in pool.

        Returns:
            Total connection count
        """
        return len(self._connection_registry)

    def checked_out(self) -> int:
        """Get number of checked out connections.

        Returns:
            Number of connections currently in use
        """
        if self._queue_instance is None:
            return len(self._connection_registry)
        return len(self._connection_registry) - self._queue.qsize()

    def _return_to_queue(self, connection: AiosqlitePoolConnection) -> None:
        """Put an unclaimed connection back on the idle queue.

        Used on cancellation paths so a connection that was taken off the queue,
        or checked out but never handed to a caller, is not left registered but
        unreachable. Does nothing once the pool is closed or the connection has
        been removed from the registry.

        Args:
            connection: Connection to make available again
        """
        if self.is_closed or connection.id not in self._connection_registry:
            return
        self._queue.put_nowait(connection)

    async def _configure_connection(self, connection: "AiosqliteConnection") -> None:
        """Apply the pool's PRAGMA settings to a freshly opened connection.

        In-memory databases get MEMORY journaling with ``synchronous = OFF``.
        File databases get WAL journaling with ``synchronous = NORMAL``. If any of
        that fails, only ``foreign_keys`` and ``busy_timeout`` are applied.

        Args:
            connection: Newly opened raw aiosqlite connection
        """
        database_path = str(self._connection_parameters.get("database", ""))
        is_shared_cache = "cache=shared" in database_path
        is_memory_db = ":memory:" in database_path or "mode=memory" in database_path
        try:
            if is_memory_db:
                await connection.execute("PRAGMA journal_mode = MEMORY")
                await connection.execute("PRAGMA synchronous = OFF")
                await connection.execute("PRAGMA temp_store = MEMORY")
                await connection.execute("PRAGMA cache_size = -16000")
            else:
                await connection.execute("PRAGMA journal_mode = WAL")
                await connection.execute("PRAGMA synchronous = NORMAL")
            await connection.execute("PRAGMA foreign_keys = ON")
            await connection.execute("PRAGMA busy_timeout = 30000")
            if is_shared_cache and is_memory_db:
                await connection.execute("PRAGMA read_uncommitted = ON")
            await connection.commit()
            if is_shared_cache:
                self._wal_initialized = True
        except Exception:
            log_with_context(
                logger,
                logging.WARNING,
                "pool.connection.configure.error",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                database=self._database_name,
            )
            await connection.execute("PRAGMA foreign_keys = ON")
            await connection.execute("PRAGMA busy_timeout = 30000")
            await connection.commit()

    async def _create_connection(self) -> AiosqlitePoolConnection:
        """Open, configure and register a new connection (registered as idle).

        Returns:
            New pool connection instance

        Raises:
            Exception: Whatever opening, configuring, or the ``on_connection_create``
                callback raised. After the raw connection has been opened, it is
                closed (best effort) before the error propagates.
        """
        connect_proxy = aiosqlite.connect(**self._connection_parameters)
        self._set_connect_proxy_daemon(connect_proxy)
        connection = await connect_proxy
        try:
            await self._configure_connection(connection)
            # Call user-provided callback after internal setup
            if self._on_connection_create is not None:
                await self._on_connection_create(connection)
            pool_connection = AiosqlitePoolConnection(connection)
            pool_connection.mark_as_idle()
            async with self._lock:
                self._connection_registry[pool_connection.id] = pool_connection
        except BaseException:
            # The connection never made it into the registry, so nothing else will
            # ever close it; do it here to avoid leaking its worker thread.
            with suppress(Exception):
                await asyncio.wait_for(connection.close(), timeout=self._operation_timeout)
            raise
        return pool_connection

    async def _claim_if_healthy(self, connection: AiosqlitePoolConnection) -> bool:
        """Check if connection is healthy and claim it.

        Uses passive health checks: connections idle less than health_check_interval
        are assumed healthy based on their last known state. Active health checks
        (SELECT 1) are only performed on long-idle connections. Connections that
        fail any check are retired.

        Args:
            connection: Connection (already taken off the queue) to check and claim

        Returns:
            True if connection was claimed
        """
        if connection.idle_time > self._idle_timeout:
            await self._retire_connection(connection, reason="idle_timeout")
            return False
        if not connection.is_healthy:
            await self._retire_connection(connection, reason="unhealthy")
            return False
        if connection.idle_time > self._health_check_interval:
            try:
                is_alive = await asyncio.wait_for(connection.is_alive(), timeout=self._operation_timeout)
            except asyncio.TimeoutError:
                await self._retire_connection(connection, reason="health_check_timeout")
                return False
            except asyncio.CancelledError:
                # The acquirer gave up mid-check (e.g. its connect timeout fired).
                # The connection is already off the queue, so hand it back
                # instead of stranding it.
                self._return_to_queue(connection)
                raise
            if not is_alive:
                await self._retire_connection(connection, reason="health_check_failed")
                return False
        connection.mark_as_in_use()
        return True

    async def _retire_connection(self, connection: AiosqlitePoolConnection, *, reason: str | None = None) -> None:
        """Retire a connection from the pool.

        Args:
            connection: Connection to retire
            reason: Optional reason for retirement
        """
        if reason:
            log_with_context(
                logger,
                logging.DEBUG,
                "pool.connection.retire",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                connection_id=connection.id,
                reason=reason,
            )
        async with self._lock:
            self._connection_registry.pop(connection.id, None)
        try:
            await asyncio.wait_for(connection.close(), timeout=self._operation_timeout)
        except asyncio.TimeoutError:
            log_with_context(
                logger,
                logging.WARNING,
                "pool.connection.close.timeout",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                connection_id=connection.id,
                timeout_seconds=self._operation_timeout,
            )
            await self._force_stop_connection(connection, reason="retire_close_timeout")

    async def _try_provision_new_connection(self) -> "AiosqlitePoolConnection | None":
        """Try to create a new connection if under capacity.

        Returns:
            New connection (marked in use) if successful, None if at capacity or
            creation failed
        """
        async with self._lock:
            # Connections are only registered once fully set up, so in-flight
            # creations must be counted too; otherwise concurrent callers all pass
            # this check and overshoot pool_size.
            if len(self._connection_registry) + self._pending_creations >= self._pool_size:
                return None
            self._pending_creations += 1
        try:
            connection = await self._create_connection()
        except Exception:
            log_with_context(
                logger,
                logging.WARNING,
                "pool.connection.create.error",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                database=self._database_name,
                pool_size=len(self._connection_registry),
                max_size=self._pool_size,
            )
            return None
        finally:
            self._pending_creations -= 1
        connection.mark_as_in_use()
        return connection

    async def _wait_for_healthy_connection(self) -> AiosqlitePoolConnection:
        """Wait for a healthy connection to become available.

        Returns:
            Available healthy connection

        Raises:
            AiosqlitePoolClosedError: If pool is closed while waiting
        """
        while True:
            get_connection_task = asyncio.create_task(self._queue.get())
            pool_closed_task = asyncio.create_task(self._closed_event.wait())
            try:
                done, _ = await asyncio.wait(
                    {get_connection_task, pool_closed_task}, return_when=asyncio.FIRST_COMPLETED
                )
            except BaseException:
                # Cancelled, typically by the acquire timeout. A still-pending
                # queue.get() must not outlive us, or it would silently swallow
                # the next released connection. One that already completed holds
                # a connection nobody will claim, so hand that connection back.
                pool_closed_task.cancel()
                if not get_connection_task.done():
                    get_connection_task.cancel()
                elif not get_connection_task.cancelled() and get_connection_task.exception() is None:
                    self._return_to_queue(get_connection_task.result())
                raise
            # Cancel synchronously rather than awaiting, so there is no suspension
            # point in which a dequeued connection could be lost. Queue.get() keeps
            # the item queued when cancelled, and cancel() is a no-op on finished tasks.
            pool_closed_task.cancel()
            get_connection_task.cancel()
            if pool_closed_task in done:
                msg = "Pool closed during connection acquisition"
                raise AiosqlitePoolClosedError(msg)
            connection = get_connection_task.result()
            if await self._claim_if_healthy(connection):
                return connection
            # The dequeued connection was retired, which freed a slot; use it
            # rather than waiting for a release that may never come.
            new_connection = await self._try_provision_new_connection()
            if new_connection is not None:
                return new_connection

    async def _warm_pool(self) -> None:
        """Pre-create minimum connections for pool warming.

        Creates connections up to min_size to avoid cold-start latency
        on first requests. Runs at most once per pool.
        """
        if self._warmed or self._min_size <= 0:
            return
        self._warmed = True
        async with self._lock:
            connections_needed = self._min_size - len(self._connection_registry) - self._pending_creations
            if connections_needed <= 0:
                return
            # Reserve the slots up front so concurrent acquirers cannot overshoot pool_size.
            self._pending_creations += connections_needed
        log_with_context(
            logger,
            logging.DEBUG,
            "pool.warmup.start",
            adapter=_ADAPTER_NAME,
            pool_id=self._pool_id,
            database=self._database_name,
            connections_needed=connections_needed,
            min_size=self._min_size,
        )

        async def create_idle_connection() -> None:
            # Enqueue each connection as soon as it exists, so an interrupted
            # warm-up cannot leave registered connections outside the queue.
            connection = await self._create_connection()
            if self.is_closed:
                await self._retire_connection(connection, reason="pool_closed")
            else:
                self._queue.put_nowait(connection)

        try:
            results = await asyncio.gather(
                *(create_idle_connection() for _ in range(connections_needed)), return_exceptions=True
            )
        finally:
            self._pending_creations -= connections_needed
        for result in results:
            if isinstance(result, Exception):
                log_with_context(
                    logger,
                    logging.WARNING,
                    "pool.warmup.connection.error",
                    adapter=_ADAPTER_NAME,
                    pool_id=self._pool_id,
                    error=str(result),
                )

    async def _try_acquire_without_waiting(self) -> "AiosqlitePoolConnection | None":
        """Get a connection without blocking on other callers.

        Warms the pool on first use, then tries idle connections from the queue,
        then provisions a new connection if under capacity.

        Returns:
            A claimed connection, or None if the caller has to wait for a release

        Raises:
            AiosqlitePoolClosedError: If pool is closed
        """
        # Fast path: check closed state directly to avoid property overhead
        if self._closed_event_instance is not None and self._closed_event_instance.is_set():
            msg = "Cannot acquire connection from closed pool"
            raise AiosqlitePoolClosedError(msg)
        if not self._warmed and self._min_size > 0:
            await self._warm_pool()
        # Fast path: try to get from queue without health check overhead for fresh connections
        while not self._queue.empty():
            connection = self._queue.get_nowait()
            # Fast claim for recently-used connections (idle < health_check_interval)
            if connection.idle_since is not None:
                idle_time = time.time() - connection.idle_since
                if idle_time <= self._health_check_interval and connection.is_healthy:
                    connection.idle_since = None  # mark_as_in_use inline
                    return connection
            # Fall back to full health check for older connections
            if await self._claim_if_healthy(connection):
                return connection
        # Fast path: check capacity without lock first (re-checked under the lock)
        if len(self._connection_registry) + self._pending_creations < self._pool_size:
            return await self._try_provision_new_connection()
        return None

    async def _get_connection(self) -> AiosqlitePoolConnection:
        """Run the three-phase connection acquisition cycle (no timeout).

        Returns:
            Available connection

        Raises:
            AiosqlitePoolClosedError: If pool is closed
        """
        connection = await self._try_acquire_without_waiting()
        if connection is None:
            connection = await self._wait_for_healthy_connection()
        return connection

    async def acquire(self) -> AiosqlitePoolConnection:
        """Acquire a connection from the pool.

        Returns:
            Available connection

        Raises:
            AiosqlitePoolClosedError: If the pool is closed
            AiosqliteConnectTimeoutError: If no connection is released within
                ``connect_timeout`` while the pool is at capacity
        """
        connection = await self._try_acquire_without_waiting()
        if connection is None:
            # Only the waiting phase pays for the timeout wrapper; it is also the
            # only phase that can block indefinitely on other callers.
            try:
                connection = await asyncio.wait_for(
                    self._wait_for_healthy_connection(), timeout=self._connect_timeout
                )
            except asyncio.TimeoutError as e:
                msg = f"Connection acquisition timed out after {self._connect_timeout}s"
                raise AiosqliteConnectTimeoutError(msg) from e
        if not self._wal_initialized and "cache=shared" in str(self._connection_parameters.get("database", "")):
            try:
                await asyncio.sleep(0.01)
            except BaseException:
                # Cancelled before the caller ever received the connection: return it.
                connection.mark_as_idle()
                self._return_to_queue(connection)
                raise
        return connection

    async def release(self, connection: AiosqlitePoolConnection) -> None:
        """Release a connection back to the pool.

        The connection is rolled back before it is queued for reuse. Releasing a
        connection the pool does not know, or one that is already idle, is logged
        and otherwise ignored. A connection whose rollback fails is retired
        rather than handed to the next caller.

        Args:
            connection: Connection to release
        """
        # Fast path: check closed state directly
        if self._closed_event_instance is not None and self._closed_event_instance.is_set():
            await self._retire_connection(connection)
            return
        if connection.id not in self._connection_registry:
            log_with_context(
                logger,
                logging.WARNING,
                "pool.connection.release.unknown",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                connection_id=connection.id,
            )
            return
        if connection.idle_since is not None:
            # Already released: queueing it twice would let two callers share it.
            log_with_context(
                logger,
                logging.WARNING,
                "pool.connection.release.duplicate",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                connection_id=connection.id,
            )
            return
        try:
            # No timeout wrapper: rollback is fast for SQLite and this is the hot path.
            await connection.connection.rollback()
        except asyncio.CancelledError:
            # Cancelled mid-reset, so the transaction state is unknown. Park the
            # connection flagged unhealthy; the next claim will retire it.
            connection.mark_unhealthy()
            connection.mark_as_idle()
            self._return_to_queue(connection)
            raise
        except Exception as e:
            # A connection that cannot roll back may still hold an open transaction
            # (or may already be closed); never hand it to another caller.
            log_with_context(
                logger,
                logging.WARNING,
                "pool.connection.reset.error",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                connection_id=connection.id,
                error=str(e),
            )
            connection.mark_unhealthy()
            await self._retire_connection(connection, reason="reset_failed")
            return
        connection.idle_since = time.time()  # mark_as_idle inline
        try:
            self._queue.put_nowait(connection)
        except asyncio.QueueFull as e:
            log_with_context(
                logger,
                logging.WARNING,
                "pool.connection.reset.error",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                connection_id=connection.id,
                error=str(e),
            )
            connection.mark_unhealthy()
            await self._retire_connection(connection)

    def get_connection(self) -> "AiosqlitePoolConnectionContext":
        """Get a connection with automatic release."""
        return AiosqlitePoolConnectionContext(self)

    async def close(self) -> None:
        """Close the connection pool and every connection it tracks.

        Waiters blocked in acquisition are woken with ``AiosqlitePoolClosedError``.
        Connections whose close times out are force-stopped.
        """
        if self.is_closed:
            return
        self._closed_event.set()
        while not self._queue.empty():
            self._queue.get_nowait()
        async with self._lock:
            connections = list(self._connection_registry.values())
            self._connection_registry.clear()
        if not connections:
            return
        results = await asyncio.gather(
            *(asyncio.wait_for(conn.close(), timeout=self._operation_timeout) for conn in connections),
            return_exceptions=True,
        )
        for conn, result in zip(connections, results):
            if not isinstance(result, Exception):
                continue
            if isinstance(result, asyncio.TimeoutError):
                await self._force_stop_connection(conn, reason="pool_close_timeout")
            log_with_context(
                logger,
                logging.WARNING,
                "pool.close.connection.error",
                adapter=_ADAPTER_NAME,
                pool_id=self._pool_id,
                connection_id=conn.id,
                error=str(result),
            )