"""Base session store classes for Litestar integration."""
import re
from abc import ABC, abstractmethod
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar, cast
from sqlspec.utils.logging import get_logger
if TYPE_CHECKING:
from types import TracebackType
__all__ = ("BaseSQLSpecStore",)

# Generic parameter standing for the SQLSpec configuration type a concrete store wraps.
ConfigT = TypeVar("ConfigT")

# Logger for the Litestar session-store extension.
logger = get_logger("extensions.litestar.store")
# Allowed session-table names: an ASCII letter or underscore followed by ASCII
# letters, digits, or underscores. ``\A``/``\Z`` anchor to the true start/end of
# the string; ``$`` would also match just before a trailing newline, which would
# let a name such as "sessions\n" pass ``match()``.
VALID_TABLE_NAME_PATTERN: Final = re.compile(r"\A[a-zA-Z_][a-zA-Z0-9_]*\Z")
# Longest accepted table name, in characters (PostgreSQL identifier limit).
MAX_TABLE_NAME_LENGTH: Final = 63
class BaseSQLSpecStore(ABC, Generic[ConfigT]):
    """Base class for SQLSpec-backed Litestar session stores.

    Implements the ``litestar.stores.base.Store`` protocol for server-side
    session storage using SQLSpec database adapters.

    This abstract base class provides functionality shared by all
    database-specific store implementations:

    - Connection management via SQLSpec configs
    - Session expiration calculation
    - Table creation utilities
    - Table-name validation (guards the identifier interpolated into SQL)

    Subclasses must implement the dialect-specific SQL queries.

    Args:
        config: SQLSpec database configuration with ``extension_config["litestar"]`` settings.

    Example:
        from sqlspec.adapters.asyncpg import AsyncpgConfig
        from sqlspec.adapters.asyncpg.litestar.store import AsyncpgStore

        config = AsyncpgConfig(
            pool_config={"dsn": "postgresql://..."},
            extension_config={"litestar": {"session_table": "my_sessions"}},
        )
        store = AsyncpgStore(config)
        await store.create_table()

    Notes:
        Configuration is read from ``config.extension_config["litestar"]``:

        - session_table: Table name (default: ``"litestar_session"``)
    """

    __slots__ = ("_config", "_table_name")

    def __init__(self, config: ConfigT) -> None:
        """Initialize the session store.

        Args:
            config: SQLSpec database configuration.

        Raises:
            ValueError: If the configured table name fails validation.

        Notes:
            Reads the table name from ``config.extension_config["litestar"]["session_table"]``,
            defaulting to ``"litestar_session"`` when it is not specified.
        """
        self._config = config
        self._table_name = self._get_table_name_from_config()
        self._validate_table_name(self._table_name)

    def _get_table_name_from_config(self) -> str:
        """Extract the session table name from ``config.extension_config``.

        Returns:
            The configured ``session_table`` value converted to ``str``, or
            ``"litestar_session"`` when the config has no (or a ``None``)
            ``extension_config``, no (or a ``None``) ``"litestar"`` section,
            or no (or a ``None``) ``session_table`` entry.
        """
        default = "litestar_session"
        # ``extension_config`` may be missing entirely or set to ``None``; both mean "not configured".
        extension_config = cast("dict[str, Any] | None", getattr(self._config, "extension_config", None))
        if not extension_config:
            return default
        litestar_config = cast("dict[str, Any] | None", extension_config.get("litestar"))
        if not litestar_config:
            return default
        table_name = litestar_config.get("session_table")
        # Treat an explicit None like a missing key so it is not stringified into the name "None".
        if table_name is None:
            return default
        return str(table_name)

    @property
    def config(self) -> ConfigT:
        """Return the database configuration."""
        return self._config

    @property
    def table_name(self) -> str:
        """Return the session table name."""
        return self._table_name

    @abstractmethod
    async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None":
        """Get a session value by key.

        Args:
            key: Session ID to retrieve.
            renew_for: If given and the value had an initial expiry time set, renew the
                expiry time for ``renew_for`` seconds. If the value has not been set
                with an expiry time this is a no-op.

        Returns:
            Session data as bytes if found and not expired, None otherwise.
        """
        raise NotImplementedError

    @abstractmethod
    async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None:
        """Store a session value.

        Args:
            key: Session ID.
            value: Session data (will be converted to bytes if string).
            expires_in: Time in seconds or timedelta before expiration.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete(self, key: str) -> None:
        """Delete a session by key.

        Args:
            key: Session ID to delete.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete_all(self) -> None:
        """Delete all sessions from the store."""
        raise NotImplementedError

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """Check if a session key exists and is not expired.

        Args:
            key: Session ID to check.

        Returns:
            True if the session exists and is not expired.
        """
        raise NotImplementedError

    @abstractmethod
    async def expires_in(self, key: str) -> "int | None":
        """Get the time in seconds until the session expires.

        Args:
            key: Session ID to check.

        Returns:
            Seconds until expiration, or None if no expiry or key doesn't exist.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete_expired(self) -> int:
        """Delete all expired sessions.

        Returns:
            Number of sessions deleted.
        """
        raise NotImplementedError

    @abstractmethod
    async def create_table(self) -> None:
        """Create the session table if it doesn't exist."""
        raise NotImplementedError

    @abstractmethod
    def _get_create_table_sql(self) -> str:
        """Get the CREATE TABLE SQL for this database dialect.

        Returns:
            SQL statement to create the sessions table.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_drop_table_sql(self) -> "list[str]":
        """Get the DROP TABLE SQL statements for this database dialect.

        Returns:
            List of SQL statements to drop the table and all indexes.
            Order matters: drop indexes before table.

        Notes:
            Should use IF EXISTS or dialect-specific error handling
            to allow idempotent migrations.
        """
        raise NotImplementedError

    async def __aenter__(self) -> "BaseSQLSpecStore[ConfigT]":
        """Enter the async context manager and return this store."""
        return self

    async def __aexit__(
        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None"
    ) -> None:
        """Exit the async context manager.

        Performs no cleanup, and returning None means exceptions are not suppressed.
        """
        return

    def _calculate_expires_at(self, expires_in: "int | timedelta | None") -> "datetime | None":
        """Calculate the expiration timestamp from ``expires_in``.

        Args:
            expires_in: Seconds or timedelta until expiration.

        Returns:
            Timezone-aware UTC datetime of expiration, or None when ``expires_in``
            is None or not strictly positive (treated as "no expiration").
        """
        if expires_in is None:
            return None
        # Keep the timedelta's full precision: truncating to whole seconds would turn
        # sub-second durations into 0, which this helper maps to "never expires".
        delta = expires_in if isinstance(expires_in, timedelta) else timedelta(seconds=expires_in)
        if delta <= timedelta(0):
            return None
        return datetime.now(timezone.utc) + delta

    def _value_to_bytes(self, value: "str | bytes") -> bytes:
        """Convert a value to bytes if needed.

        Args:
            value: String or bytes value; strings are UTF-8 encoded.

        Returns:
            Value as bytes.
        """
        if isinstance(value, str):
            return value.encode("utf-8")
        return value

    @staticmethod
    def _validate_table_name(table_name: str) -> None:
        """Validate a table name for SQL safety.

        Args:
            table_name: Table name to validate.

        Raises:
            ValueError: If the table name is empty, too long, or contains disallowed characters.

        Notes:
            - Must start with letter or underscore
            - Can only contain letters, numbers, and underscores
            - Maximum length is 63 characters (PostgreSQL limit)
            - Prevents SQL injection in table names
        """
        if not table_name:
            msg = "Table name cannot be empty"
            raise ValueError(msg)
        if len(table_name) > MAX_TABLE_NAME_LENGTH:
            msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})"
            raise ValueError(msg)
        # fullmatch (not match): a ``$``-anchored pattern would accept a trailing newline.
        if not VALID_TABLE_NAME_PATTERN.fullmatch(table_name):
            msg = f"Invalid table name: {table_name!r}. Must start with letter/underscore and contain only alphanumeric characters and underscores"
            raise ValueError(msg)