Source code for sqlspec.adapters.aiosqlite.litestar.store

"""AioSQLite session store for Litestar integration."""

from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING

from sqlspec.extensions.litestar.store import BaseSQLSpecStore
from sqlspec.utils.logging import get_logger

if TYPE_CHECKING:
    from sqlspec.adapters.aiosqlite.config import AiosqliteConfig

logger = get_logger("adapters.aiosqlite.litestar.store")

# Number of seconds in one day; used to convert between a duration in
# seconds and a (fractional) number of days.
SECONDS_PER_DAY = 86400.0
# Julian Day number of the Unix epoch (1970-01-01T00:00:00 UTC). Adding the
# number of days elapsed since the Unix epoch to this value yields the Julian
# Day number that is stored in the ``expires_at`` column.
JULIAN_EPOCH = 2440587.5

__all__ = ("AiosqliteStore",)


class AiosqliteStore(BaseSQLSpecStore["AiosqliteConfig"]):
    """SQLite session store using AioSQLite driver.

    Implements server-side session storage for Litestar using SQLite
    via the AioSQLite driver. Provides efficient session management with:

    - Native async SQLite operations
    - INSERT OR REPLACE for UPSERT functionality
    - Automatic expiration handling
    - Efficient cleanup of expired sessions

    Expiry times are persisted in the ``expires_at`` column as Julian Day
    numbers (REAL), so they can be compared directly against
    ``julianday('now')`` inside SQL. A NULL ``expires_at`` means the session
    never expires.

    Args:
        config: AiosqliteConfig instance.

    Example:
        from sqlspec.adapters.aiosqlite import AiosqliteConfig
        from sqlspec.adapters.aiosqlite.litestar.store import AiosqliteStore

        config = AiosqliteConfig(database=":memory:")
        store = AiosqliteStore(config)
        await store.create_table()
    """

    __slots__ = ()

    def __init__(self, config: "AiosqliteConfig") -> None:
        """Initialize AioSQLite session store.

        Args:
            config: AiosqliteConfig instance.

        Notes:
            Table name is read from
            config.extension_config["litestar"]["session_table"].
        """
        super().__init__(config)

    def _get_create_table_sql(self) -> str:
        """Get SQLite CREATE TABLE SQL.

        Returns:
            SQL script that creates the sessions table and its expiry index.

        Notes:
            - Uses REAL type for expires_at (stores Julian Day number)
            - Julian Day enables direct comparison with julianday('now')
            - Partial index WHERE expires_at IS NOT NULL reduces index size
            - Queries must compare the bare ``expires_at`` column (not a
              function of it) for the index to be usable by the optimizer
            - Table name is internally controlled, not user input
              (S608 suppressed)
        """
        return f"""
        CREATE TABLE IF NOT EXISTS {self._table_name} (
            session_id TEXT PRIMARY KEY,
            data BLOB NOT NULL,
            expires_at REAL
        );

        CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at
        ON {self._table_name}(expires_at) WHERE expires_at IS NOT NULL;
        """

    def _get_drop_table_sql(self) -> "list[str]":
        """Get SQLite DROP TABLE SQL statements.

        Returns:
            List of SQL statements to drop indexes and table, in the order
            they should be executed (index first, then table).
        """
        return [
            f"DROP INDEX IF EXISTS idx_{self._table_name}_expires_at",
            f"DROP TABLE IF EXISTS {self._table_name}",
        ]

    def _datetime_to_julian(self, dt: "datetime | None") -> "float | None":
        """Convert datetime to Julian Day number for SQLite storage.

        Args:
            dt: Datetime to convert (must be timezone-aware; subtracting the
                aware Unix epoch from a naive datetime raises ``TypeError``).

        Returns:
            Julian Day number as REAL, or None if dt is None.

        Notes:
            Julian Day number is days since November 24, 4714 BCE
            (proleptic Gregorian). This enables direct comparison with
            julianday('now') in SQL queries.
        """
        if dt is None:
            return None

        epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
        delta_days = (dt - epoch).total_seconds() / SECONDS_PER_DAY
        return JULIAN_EPOCH + delta_days

    def _julian_to_datetime(self, julian: "float | None") -> "datetime | None":
        """Convert Julian Day number back to datetime.

        Args:
            julian: Julian Day number.

        Returns:
            UTC-aware datetime, or None if julian is None.
        """
        if julian is None:
            return None

        days_since_epoch = julian - JULIAN_EPOCH
        timestamp = days_since_epoch * SECONDS_PER_DAY
        return datetime.fromtimestamp(timestamp, tz=timezone.utc)

    async def create_table(self) -> None:
        """Create the session table if it doesn't exist."""
        sql = self._get_create_table_sql()
        async with self._config.provide_session() as driver:
            # The DDL contains two statements (table + index), hence a script.
            await driver.execute_script(sql)
        logger.debug("Created session table: %s", self._table_name)

    async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None":
        """Get a session value by key.

        Args:
            key: Session ID to retrieve.
            renew_for: If given, renew the expiry time for this duration.
                Only sessions that already have an expiry are renewed;
                sessions stored without one stay non-expiring.

        Returns:
            Session data as bytes if found and not expired, None otherwise.
        """
        # expires_at already holds a Julian Day number, so it is compared
        # directly with julianday('now') instead of being re-wrapped.
        sql = f"""
        SELECT data, expires_at FROM {self._table_name}
        WHERE session_id = ?
        AND (expires_at IS NULL OR expires_at > julianday('now'))
        """

        async with self._config.provide_connection() as conn:
            async with conn.execute(sql, (key,)) as cursor:
                row = await cursor.fetchone()

            if row is None:
                return None

            data, expires_at_julian = row

            if renew_for is not None and expires_at_julian is not None:
                new_expires_at = self._calculate_expires_at(renew_for)
                new_expires_at_julian = self._datetime_to_julian(new_expires_at)
                if new_expires_at_julian is not None:
                    update_sql = f"""
                    UPDATE {self._table_name}
                    SET expires_at = ?
                    WHERE session_id = ?
                    """
                    await conn.execute(update_sql, (new_expires_at_julian, key))
                    await conn.commit()

            return bytes(data)

    async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None:
        """Store a session value.

        Args:
            key: Session ID.
            value: Session data.
            expires_in: Time until expiration.

        Notes:
            Stores expires_at as Julian Day number (REAL) for optimal index
            usage. An existing row with the same session ID is replaced.
        """
        data = self._value_to_bytes(value)
        expires_at = self._calculate_expires_at(expires_in)
        expires_at_julian = self._datetime_to_julian(expires_at)

        sql = f"""
        INSERT OR REPLACE INTO {self._table_name} (session_id, data, expires_at)
        VALUES (?, ?, ?)
        """

        async with self._config.provide_connection() as conn:
            await conn.execute(sql, (key, data, expires_at_julian))
            await conn.commit()

    async def delete(self, key: str) -> None:
        """Delete a session by key.

        Args:
            key: Session ID to delete.
        """
        sql = f"DELETE FROM {self._table_name} WHERE session_id = ?"

        async with self._config.provide_connection() as conn:
            await conn.execute(sql, (key,))
            await conn.commit()

    async def delete_all(self) -> None:
        """Delete all sessions from the store."""
        sql = f"DELETE FROM {self._table_name}"

        async with self._config.provide_connection() as conn:
            await conn.execute(sql)
            await conn.commit()
        logger.debug("Deleted all sessions from table: %s", self._table_name)

    async def exists(self, key: str) -> bool:
        """Check if a session key exists and is not expired.

        Args:
            key: Session ID to check.

        Returns:
            True if the session exists and is not expired.
        """
        # Direct comparison: expires_at is already a Julian Day number.
        sql = f"""
        SELECT 1 FROM {self._table_name}
        WHERE session_id = ?
        AND (expires_at IS NULL OR expires_at > julianday('now'))
        """

        async with self._config.provide_connection() as conn, conn.execute(sql, (key,)) as cursor:
            result = await cursor.fetchone()
            return result is not None

    async def expires_in(self, key: str) -> "int | None":
        """Get the time in seconds until the session expires.

        Args:
            key: Session ID to check.

        Returns:
            Seconds until expiration (0 if the stored expiry has already
            passed), or None if no expiry or key doesn't exist.
        """
        sql = f"""
        SELECT expires_at FROM {self._table_name}
        WHERE session_id = ?
        """

        async with self._config.provide_connection() as conn:
            async with conn.execute(sql, (key,)) as cursor:
                row = await cursor.fetchone()

            if row is None or row[0] is None:
                return None

            expires_at = self._julian_to_datetime(row[0])
            if expires_at is None:
                return None

            now = datetime.now(timezone.utc)
            if expires_at <= now:
                # Expired rows may still be present until delete_expired runs.
                return 0

            return int((expires_at - now).total_seconds())

    async def delete_expired(self) -> int:
        """Delete all expired sessions.

        Returns:
            Number of sessions deleted.

        Notes:
            The filter compares the bare ``expires_at`` column with
            julianday('now'). Wrapping the column in julianday() would be a
            no-op on the stored REAL values but would prevent SQLite from
            using the partial index on ``expires_at``. Rows with a NULL
            expiry never match the comparison and are therefore kept.
        """
        sql = f"DELETE FROM {self._table_name} WHERE expires_at <= julianday('now')"

        async with self._config.provide_connection() as conn:
            # Use the cursor as a context manager so it is closed promptly;
            # rowcount is read while the cursor is still open.
            async with conn.execute(sql) as cursor:
                count = cursor.rowcount
            await conn.commit()

        if count > 0:
            logger.debug("Cleaned up %d expired sessions", count)
        return count