"""Starlette integration plugin for SQLSpec (module ``sqlspec.extensions.starlette.extension``)."""

from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any

from sqlspec.base import SQLSpec
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.starlette._state import SQLSpecConfigState
from sqlspec.extensions.starlette._utils import get_or_create_session
from sqlspec.extensions.starlette.middleware import SQLSpecAutocommitMiddleware, SQLSpecManualMiddleware
from sqlspec.utils.logging import get_logger

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator

    from starlette.applications import Starlette
    from starlette.requests import Request

# Public names exported by this module.
__all__ = ("SQLSpecPlugin",)

# Module-level logger obtained from sqlspec's logging helper.
logger = get_logger("extensions.starlette")

# Fallback values used by SQLSpecPlugin._extract_starlette_settings when the
# "starlette" section of a config's extension_config omits the matching entry.
# Commit mode; SQLSpecPlugin._add_middleware recognises "manual", "autocommit"
# and "autocommit_include_redirect".
DEFAULT_COMMIT_MODE = "manual"
# Attribute name read from request.state by SQLSpecPlugin.get_connection.
DEFAULT_CONNECTION_KEY = "db_connection"
# Attribute name under which SQLSpecPlugin.lifespan stores the pool on app.state.
DEFAULT_POOL_KEY = "db_pool"
# Key that identifies a configuration when calling get_session/get_connection
# with an explicit ``key`` argument.
DEFAULT_SESSION_KEY = "db_session"


class SQLSpecPlugin:
    """SQLSpec integration for Starlette applications.

    Provides middleware-based session management, automatic transaction
    handling, and connection pooling lifecycle management.

    Example:
        from starlette.applications import Starlette
        from starlette.responses import JSONResponse

        from sqlspec import SQLSpec
        from sqlspec.adapters.asyncpg import AsyncpgConfig
        from sqlspec.extensions.starlette import SQLSpecPlugin

        sqlspec = SQLSpec()
        sqlspec.add_config(AsyncpgConfig(
            bind_key="default",
            pool_config={"dsn": "postgresql://localhost/mydb"},
            extension_config={
                "starlette": {
                    "commit_mode": "autocommit",
                    "session_key": "db"
                }
            }
        ))

        app = Starlette()
        db_ext = SQLSpecPlugin(sqlspec, app)

        @app.route("/users")
        async def list_users(request):
            db = db_ext.get_session(request)
            result = await db.execute("SELECT * FROM users")
            return JSONResponse({"users": result.all()})
    """

    __slots__ = ("_config_states", "_sqlspec")

    def __init__(self, sqlspec: "SQLSpec", app: "Starlette | None" = None) -> None:
        """Initialize SQLSpec Starlette extension.

        Builds one ``SQLSpecConfigState`` per config registered on ``sqlspec``
        and, when ``app`` is given, immediately calls :meth:`init_app`.

        Args:
            sqlspec: Pre-configured SQLSpec instance with registered configs.
            app: Optional Starlette application to initialize immediately.
        """
        self._sqlspec = sqlspec
        self._config_states: list[SQLSpecConfigState] = []

        for cfg in self._sqlspec.configs.values():
            settings = self._extract_starlette_settings(cfg)
            state = self._create_config_state(cfg, settings)
            self._config_states.append(state)

        if app is not None:
            self.init_app(app)

    def _extract_starlette_settings(self, config: Any) -> "dict[str, Any]":
        """Extract Starlette settings from config.extension_config.

        Missing entries fall back to the module-level ``DEFAULT_*`` values.

        Args:
            config: Database configuration instance.

        Returns:
            Dictionary of Starlette-specific settings.
        """
        starlette_config = config.extension_config.get("starlette", {})

        connection_key = starlette_config.get("connection_key", DEFAULT_CONNECTION_KEY)
        pool_key = starlette_config.get("pool_key", DEFAULT_POOL_KEY)
        session_key = starlette_config.get("session_key", DEFAULT_SESSION_KEY)
        commit_mode = starlette_config.get("commit_mode", DEFAULT_COMMIT_MODE)

        # A config without pooling that still carries the default pool key gets
        # a private, per-config key instead (presumably so several such configs
        # do not collide in _validate_unique_keys).
        if not config.supports_connection_pooling and pool_key == DEFAULT_POOL_KEY:
            pool_key = f"_{DEFAULT_POOL_KEY}_{id(config)}"

        return {
            "connection_key": connection_key,
            "pool_key": pool_key,
            "session_key": session_key,
            "commit_mode": commit_mode,
            "extra_commit_statuses": starlette_config.get("extra_commit_statuses"),
            "extra_rollback_statuses": starlette_config.get("extra_rollback_statuses"),
            "disable_di": starlette_config.get("disable_di", False),
        }

    def _create_config_state(self, config: Any, settings: "dict[str, Any]") -> "SQLSpecConfigState":
        """Create configuration state object.

        Args:
            config: Database configuration instance.
            settings: Extracted Starlette settings.

        Returns:
            Configuration state instance.
        """
        return SQLSpecConfigState(
            config=config,
            connection_key=settings["connection_key"],
            pool_key=settings["pool_key"],
            session_key=settings["session_key"],
            commit_mode=settings["commit_mode"],
            extra_commit_statuses=settings["extra_commit_statuses"],
            extra_rollback_statuses=settings["extra_rollback_statuses"],
            disable_di=settings["disable_di"],
        )

    def init_app(self, app: "Starlette") -> None:
        """Initialize Starlette application with SQLSpec.

        Validates configuration, wraps lifespan, and adds middleware.

        Args:
            app: Starlette application instance.

        Raises:
            ImproperConfigurationError: If state keys are duplicated across
                configs, or a config that needs middleware has an unknown
                ``commit_mode``.
        """
        self._validate_unique_keys()

        original_lifespan = app.router.lifespan_context

        @asynccontextmanager
        async def combined_lifespan(app: "Starlette") -> "AsyncGenerator[None, None]":
            # Enter the plugin's lifespan first, then the application's own.
            async with self.lifespan(app), original_lifespan(app):
                yield

        app.router.lifespan_context = combined_lifespan

        for config_state in self._config_states:
            # Configs with disable_di set get no middleware.
            if not config_state.disable_di:
                self._add_middleware(app, config_state)

    def _validate_unique_keys(self) -> None:
        """Validate that all state keys are unique across configs.

        Raises:
            ImproperConfigurationError: If duplicate keys found.
        """
        all_keys: set[str] = set()

        for state in self._config_states:
            keys = {state.connection_key, state.pool_key, state.session_key}
            duplicates = all_keys & keys

            if duplicates:
                msg = f"Duplicate state keys found: {duplicates}"
                raise ImproperConfigurationError(msg)

            all_keys.update(keys)

    def _add_middleware(self, app: "Starlette", config_state: "SQLSpecConfigState") -> None:
        """Add transaction middleware for configuration.

        Args:
            app: Starlette application instance.
            config_state: Configuration state.

        Raises:
            ImproperConfigurationError: If ``config_state.commit_mode`` is not
                one of the supported modes.
        """
        commit_mode = config_state.commit_mode

        if commit_mode == "manual":
            app.add_middleware(SQLSpecManualMiddleware, config_state=config_state)
        elif commit_mode == "autocommit":
            app.add_middleware(SQLSpecAutocommitMiddleware, config_state=config_state, include_redirect=False)
        elif commit_mode == "autocommit_include_redirect":
            app.add_middleware(SQLSpecAutocommitMiddleware, config_state=config_state, include_redirect=True)
        else:
            # Fail at setup time: silently installing no middleware would only
            # surface later as a missing connection during a request.
            msg = (
                f"Unknown commit_mode {commit_mode!r}; expected 'manual', "
                "'autocommit' or 'autocommit_include_redirect'"
            )
            raise ImproperConfigurationError(msg)

    @asynccontextmanager
    async def lifespan(self, app: "Starlette") -> "AsyncGenerator[None, None]":
        """Manage connection pool lifecycle.

        Creates a pool for every config that supports connection pooling and
        stores it on ``app.state`` under the config's ``pool_key``. On exit,
        every pool that was successfully created is closed, including when a
        later pool fails to start.

        Args:
            app: Starlette application instance.

        Yields:
            None
        """
        started: list[Any] = []

        try:
            for config_state in self._config_states:
                if config_state.config.supports_connection_pooling:
                    pool = await config_state.config.create_pool()
                    # Record before touching app.state so the pool is closed
                    # even if storing it fails.
                    started.append(config_state)
                    setattr(app.state, config_state.pool_key, pool)
            yield
        finally:
            for config_state in started:
                close_result = config_state.config.close_pool()
                # close_pool() may return None instead of an awaitable.
                if close_result is not None:
                    await close_result

    def get_session(self, request: "Request", key: "str | None" = None) -> Any:
        """Get or create database session for request.

        Sessions are cached per request to ensure consistency.

        Args:
            request: Starlette request instance.
            key: Optional session key to retrieve specific database session.
                When omitted, the first registered config is used.

        Returns:
            Database session (driver instance).

        Raises:
            ValueError: If ``key`` is given and no configuration matches it.
        """
        config_state = self._config_states[0] if key is None else self._get_config_state_by_key(key)
        return get_or_create_session(request, config_state)

    def get_connection(self, request: "Request", key: "str | None" = None) -> Any:
        """Get database connection from request state.

        Args:
            request: Starlette request instance.
            key: Optional session key to retrieve specific database connection.
                When omitted, the first registered config is used.

        Returns:
            Database connection object.

        Raises:
            ValueError: If ``key`` is given and no configuration matches it.
        """
        config_state = self._config_states[0] if key is None else self._get_config_state_by_key(key)
        return getattr(request.state, config_state.connection_key)

    def _get_config_state_by_key(self, key: str) -> "SQLSpecConfigState":
        """Get configuration state by session key.

        Args:
            key: Session key to search for.

        Returns:
            Configuration state matching the key.

        Raises:
            ValueError: If no configuration found with the specified key.
        """
        for state in self._config_states:
            if state.session_key == key:
                return state

        msg = f"No configuration found with session_key: {key}"
        raise ValueError(msg)