"""AsyncPG session store for Litestar integration."""
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING
from sqlspec.extensions.litestar.store import BaseSQLSpecStore
from sqlspec.utils.logging import get_logger
if TYPE_CHECKING:
from sqlspec.adapters.asyncpg.config import AsyncpgConfig
# Module-level logger, obtained from sqlspec's logging helper; used by the
# store below for debug messages about table creation and session cleanup.
logger = get_logger("adapters.asyncpg.litestar.store")
# Public API of this module: only the AsyncpgStore class is exported.
__all__ = ("AsyncpgStore",)
[docs]
class AsyncpgStore(BaseSQLSpecStore["AsyncpgConfig"]):
    """Litestar session store persisted in PostgreSQL through AsyncPG.

    Sessions live in a single table keyed by ``session_id``. The store offers:

    - async reads and writes over connections supplied by the config
    - insert-or-update writes via ``INSERT ... ON CONFLICT``
    - expiry-aware lookups (expired rows are never returned)
    - bulk removal of expired rows

    Args:
        config: AsyncpgConfig instance with extension_config["litestar"] settings.

    Example:
        from sqlspec.adapters.asyncpg import AsyncpgConfig
        from sqlspec.adapters.asyncpg.litestar.store import AsyncpgStore

        config = AsyncpgConfig(
            pool_config={"dsn": "postgresql://..."},
            extension_config={"litestar": {"session_table": "my_sessions"}}
        )
        store = AsyncpgStore(config)
        await store.create_table()
    """

    __slots__ = ()

    def __init__(self, config: "AsyncpgConfig") -> None:
        """Set up the store on top of an AsyncPG configuration.

        Args:
            config: AsyncpgConfig instance.

        Notes:
            All state is handled by the base class; the table name is read from
            config.extension_config["litestar"]["session_table"].
        """
        super().__init__(config)

    def _get_create_table_sql(self) -> str:
        """Build the DDL script that creates the session table and its index.

        Returns:
            A multi-statement SQL script (CREATE TABLE, CREATE INDEX, ALTER TABLE).

        Notes:
            - ``expires_at`` is TIMESTAMPTZ, so expiry comparisons are timezone-aware.
            - The index on ``expires_at`` is partial (``WHERE expires_at IS NOT NULL``),
              so rows without an expiry are not indexed.
            - ``fillfactor = 80`` and the autovacuum scale factors are storage
              parameters set on the table.
            - ``created_at`` / ``updated_at`` default to the insertion time.
            - The table name is interpolated directly into the SQL; it is expected
              to be internally controlled, not user input.
        """
        # The SQL text is kept flush-left so the emitted script is unchanged.
        ddl = f"""
CREATE TABLE IF NOT EXISTS {self._table_name} (
session_id TEXT PRIMARY KEY,
data BYTEA NOT NULL,
expires_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
) WITH (fillfactor = 80);
CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at
ON {self._table_name}(expires_at) WHERE expires_at IS NOT NULL;
ALTER TABLE {self._table_name} SET (
autovacuum_vacuum_scale_factor = 0.05,
autovacuum_analyze_scale_factor = 0.02
);
"""
        return ddl

    def _get_drop_table_sql(self) -> "list[str]":
        """Build the statements that remove the session table and its index.

        Returns:
            Two statements, in order: drop the expiry index, then drop the table.
        """
        drop_index = f"DROP INDEX IF EXISTS idx_{self._table_name}_expires_at"
        drop_table = f"DROP TABLE IF EXISTS {self._table_name}"
        return [drop_index, drop_table]

    async def create_table(self) -> None:
        """Run the table-creation script; a no-op if the table already exists."""
        ddl = self._get_create_table_sql()
        # The DDL is a multi-statement script, so it goes through the driver's
        # script execution rather than a single-statement execute.
        async with self._config.provide_session() as session:
            await session.execute_script(ddl)
        logger.debug("Created session table: %s", self._table_name)

    async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None":
        """Fetch the data stored for a session, optionally extending its expiry.

        Args:
            key: Session ID to retrieve.
            renew_for: If given, push the expiry forward by this duration. Only
                applied to sessions that already carry an expiry.

        Returns:
            The session data as bytes, or None when the session is missing or
            has expired.

        Notes:
            Expiry is evaluated in the database against CURRENT_TIMESTAMP.
            Renewal is issued as a second statement on the same connection.
        """
        sql = f"""
SELECT data, expires_at FROM {self._table_name}
WHERE session_id = $1
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)
"""
        async with self._config.provide_connection() as conn:
            record = await conn.fetchrow(sql, key)
            if record is None:
                return None

            # No renewal requested, or the session never expires: return as-is.
            if renew_for is None or record["expires_at"] is None:
                return bytes(record["data"])

            renewed_until = self._calculate_expires_at(renew_for)
            if renewed_until is not None:
                update_sql = f"""
UPDATE {self._table_name}
SET expires_at = $1, updated_at = CURRENT_TIMESTAMP
WHERE session_id = $2
"""
                await conn.execute(update_sql, renewed_until, key)
            return bytes(record["data"])

    async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None:
        """Insert a session, or overwrite it if the key already exists.

        Args:
            key: Session ID.
            value: Session data.
            expires_in: Time until expiration.

        Notes:
            On conflict the existing row takes the new data and expiry (via
            EXCLUDED) and ``updated_at`` is refreshed.
        """
        payload = self._value_to_bytes(value)
        deadline = self._calculate_expires_at(expires_in)
        sql = f"""
INSERT INTO {self._table_name} (session_id, data, expires_at)
VALUES ($1, $2, $3)
ON CONFLICT (session_id)
DO UPDATE SET
data = EXCLUDED.data,
expires_at = EXCLUDED.expires_at,
updated_at = CURRENT_TIMESTAMP
"""
        async with self._config.provide_connection() as conn:
            await conn.execute(sql, key, payload, deadline)

    async def delete(self, key: str) -> None:
        """Remove a single session.

        Args:
            key: Session ID to delete.
        """
        sql = f"DELETE FROM {self._table_name} WHERE session_id = $1"
        async with self._config.provide_connection() as conn:
            await conn.execute(sql, key)

    async def delete_all(self) -> None:
        """Remove every session row from the table."""
        sql = f"DELETE FROM {self._table_name}"
        async with self._config.provide_connection() as conn:
            await conn.execute(sql)
        logger.debug("Deleted all sessions from table: %s", self._table_name)

    async def exists(self, key: str) -> bool:
        """Report whether a live (non-expired) session exists for the key.

        Args:
            key: Session ID to check.

        Returns:
            True if a row for the key exists and has not expired.

        Notes:
            Applies the same CURRENT_TIMESTAMP expiry filter as get().
        """
        sql = f"""
SELECT 1 FROM {self._table_name}
WHERE session_id = $1
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP)
"""
        async with self._config.provide_connection() as conn:
            # fetchval yields None when no row matched.
            return (await conn.fetchval(sql, key)) is not None

    async def expires_in(self, key: str) -> "int | None":
        """Compute how many whole seconds remain before a session expires.

        Args:
            key: Session ID to check.

        Returns:
            Remaining seconds (0 if already expired), or None when the session
            has no expiry or does not exist.
        """
        sql = f"""
SELECT expires_at FROM {self._table_name}
WHERE session_id = $1
"""
        async with self._config.provide_connection() as conn:
            deadline = await conn.fetchval(sql, key)
            if deadline is None:
                return None

            # Compared against an aware UTC "now"; the column is TIMESTAMPTZ.
            seconds_left = (deadline - datetime.now(timezone.utc)).total_seconds()
            return int(seconds_left) if seconds_left > 0 else 0

    async def delete_expired(self) -> int:
        """Purge every session whose expiry has passed.

        Returns:
            Number of sessions deleted.

        Notes:
            Rows without an expiry are never matched. On very large tables,
            consider deleting in batches to avoid holding locks too long.
        """
        sql = f"DELETE FROM {self._table_name} WHERE expires_at <= CURRENT_TIMESTAMP"
        async with self._config.provide_connection() as conn:
            status = await conn.execute(sql)
            # The last whitespace-separated token of the returned command
            # status (e.g. "DELETE 5") is the affected row count.
            deleted = int(status.rsplit(maxsplit=1)[-1])
            if deleted > 0:
                logger.debug("Cleaned up %d expired sessions", deleted)
            return deleted