"""Flask extension for SQLSpec database integration."""
import atexit
import logging
from typing import TYPE_CHECKING, Any, Literal
from sqlspec.base import SQLSpec
from sqlspec.config import AsyncDatabaseConfig, NoPoolAsyncConfig
from sqlspec.core import CorrelationExtractor
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.flask._state import FlaskConfigState
from sqlspec.extensions.flask._utils import (
get_context_value,
get_or_create_session,
has_context_value,
pop_context_value,
set_context_value,
)
from sqlspec.utils.correlation import CorrelationContext
from sqlspec.utils.logging import get_logger, log_with_context
from sqlspec.utils.portal import PortalProvider
if TYPE_CHECKING:
from flask import Flask, Response
# Public API of this module: only the plugin class is exported.
__all__ = ("SQLSpecPlugin",)
# Module-level logger passed to every log_with_context call in this file.
logger = get_logger("sqlspec.extensions.flask")
# Fallback used when a config's "flask" extension settings omit "commit_mode".
DEFAULT_COMMIT_MODE: Literal["manual"] = "manual"
# Fallback used when a config's "flask" extension settings omit "session_key".
DEFAULT_SESSION_KEY = "db_session"
[docs]
class SQLSpecPlugin:
    """Flask extension for SQLSpec database integration.

    Provides request-scoped session management, automatic transaction handling,
    and async adapter support via portal pattern.

    Example:
        from flask import Flask
        from sqlspec import SQLSpec
        from sqlspec.adapters.sqlite import SqliteConfig
        from sqlspec.extensions.flask import SQLSpecPlugin

        sqlspec = SQLSpec()
        config = SqliteConfig(
            connection_config={"database": "app.db"},
            extension_config={
                "flask": {
                    "commit_mode": "autocommit",
                    "session_key": "db"
                }
            }
        )
        sqlspec.add_config(config)

        app = Flask(__name__)
        plugin = SQLSpecPlugin(sqlspec, app)

        @app.route("/users")
        def list_users():
            db = plugin.get_session()
            result = db.execute("SELECT * FROM users")
            return {"users": result.all()}
    """
[docs]
def __init__(self, sqlspec: SQLSpec, app: "Flask | None" = None) -> None:
    """Build per-config state for the given SQLSpec instance.

    One ``FlaskConfigState`` is created for every config registered on
    ``sqlspec``. The first config that enables the correlation middleware
    supplies the settings for the shared ``CorrelationExtractor``.

    Args:
        sqlspec: SQLSpec instance with registered configs.
        app: Optional Flask application; when given, ``init_app`` is called
            immediately.
    """
    self._sqlspec = sqlspec
    self._portal: PortalProvider | None = None
    self._extractor: CorrelationExtractor | None = None
    self._enable_correlation = False
    self._cleanup_registered = False
    self._shutdown_complete = False

    # Populate the list incrementally, one state per registered config.
    self._config_states: list[FlaskConfigState] = []
    self._config_states.extend(
        self._create_config_state(db_config) for db_config in self._sqlspec.configs.values()
    )
    self._has_async_configs = any(entry.is_async for entry in self._config_states)

    # Only the first correlation-enabled config configures the extractor.
    correlation_source = next(
        (entry for entry in self._config_states if entry.enable_correlation_middleware), None
    )
    if correlation_source is not None:
        self._enable_correlation = True
        self._extractor = CorrelationExtractor(
            primary_header=correlation_source.correlation_header,
            additional_headers=correlation_source.correlation_headers,
            auto_trace_headers=correlation_source.auto_trace_headers,
        )

    if app is None:
        return
    self.init_app(app)
def _create_config_state(self, config: Any) -> FlaskConfigState:
    """Translate a database config's ``"flask"`` settings into a state object.

    Missing settings fall back to the defaults visible below; the connection
    key defaults to ``sqlspec_connection_<session_key>``.

    Args:
        config: Database configuration instance exposing ``extension_config``.

    Returns:
        FlaskConfigState populated from the config's flask settings.
    """
    settings = config.extension_config.get("flask", {})
    session_key = settings.get("session_key", DEFAULT_SESSION_KEY)
    extra_headers = settings.get("correlation_headers")

    return FlaskConfigState(
        config=config,
        connection_key=settings.get("connection_key", f"sqlspec_connection_{session_key}"),
        session_key=session_key,
        commit_mode=settings.get("commit_mode", DEFAULT_COMMIT_MODE),
        extra_commit_statuses=settings.get("extra_commit_statuses"),
        extra_rollback_statuses=settings.get("extra_rollback_statuses"),
        is_async=isinstance(config, (AsyncDatabaseConfig, NoPoolAsyncConfig)),
        disable_di=settings.get("disable_di", False),
        enable_correlation_middleware=settings.get("enable_correlation_middleware", False),
        correlation_header=settings.get("correlation_header", "x-request-id"),
        # Normalise any provided iterable of header names to a tuple.
        correlation_headers=None if extra_headers is None else tuple(extra_headers),
        auto_trace_headers=settings.get("auto_trace_headers", True),
    )
[docs]
def init_app(self, app: "Flask") -> None:
    """Attach this plugin to a Flask application.

    Checks key uniqueness, starts the portal when any config is async,
    creates a pool for each pooling-capable config, stores plugin and pools
    under ``app.extensions["sqlspec"]``, and registers request hooks plus the
    process-exit shutdown hook.

    Args:
        app: Flask application to initialize.

    Raises:
        ImproperConfigurationError: If extension already registered or keys not unique.
    """
    if "sqlspec" in app.extensions:
        msg = "SQLSpec extension already registered on this Flask application"
        raise ImproperConfigurationError(msg)

    self._validate_unique_keys()

    if self._has_async_configs:
        self._portal = PortalProvider()
        self._portal.start()
        log_with_context(logger, logging.DEBUG, "extension.init", framework="flask", stage="portal_started")

    pools: dict[str, Any] = {}
    for entry in self._config_states:
        db_config = entry.config
        if not db_config.supports_connection_pooling:
            continue
        # Async pool factories are run through the portal.
        created_pool = (
            self._portal.portal.call(db_config.create_pool)  # type: ignore[union-attr,arg-type]
            if entry.is_async
            else db_config.create_pool()
        )
        pools[entry.session_key] = created_pool
        log_with_context(
            logger, logging.DEBUG, "session.create", framework="flask", session_key=entry.session_key
        )

    app.extensions["sqlspec"] = {"plugin": self, "pools": pools}

    # Request hooks are only needed when at least one config leaves DI enabled.
    di_disabled_everywhere = all(entry.disable_di for entry in self._config_states)
    if not di_disabled_everywhere:
        app.before_request(self._before_request_handler)
        app.after_request(self._after_request_handler)
        app.teardown_appcontext(self._teardown_appcontext_handler)

    self._register_shutdown_hook()
    log_with_context(
        logger,
        logging.DEBUG,
        "extension.init",
        framework="flask",
        stage="configured",
        config_count=len(self._config_states),
        async_enabled=self._has_async_configs,
    )
def _validate_unique_keys(self) -> None:
"""Validate that all state keys are unique across configs.
Raises:
ImproperConfigurationError: If duplicate keys found.
"""
all_keys: set[str] = set()
for state in self._config_states:
keys = {state.connection_key, state.session_key}
duplicates = all_keys & keys
if duplicates:
msg = f"Duplicate state keys found: {duplicates}. Use unique session_key values."
raise ImproperConfigurationError(msg)
all_keys.update(keys)
def _register_shutdown_hook(self) -> None:
"""Register shutdown hook for pool and portal cleanup."""
if self._cleanup_registered:
return
atexit.register(self.shutdown)
self._cleanup_registered = True
def _before_request_handler(self) -> None:
    """Acquire a connection for every DI-enabled config before the request runs.

    When correlation is enabled, the correlation ID is extracted from the
    request headers and stored both on ``g`` and in ``CorrelationContext``.
    Each acquired connection is stored on ``g`` under the config's connection
    key; for pooled configs the connection context manager is stored as well
    (under ``<connection_key>_ctx``) so teardown can exit it.
    """
    from flask import current_app, g, request

    if self._enable_correlation and self._extractor is not None:
        correlation_id = self._extractor.extract(lambda header: request.headers.get(header))
        set_context_value(g, "correlation_id", correlation_id)
        CorrelationContext.set(correlation_id)

    for entry in self._config_states:
        if entry.disable_di:
            continue

        db_config = entry.config
        if not db_config.supports_connection_pooling:
            # No pool: open a dedicated connection for this request.
            if entry.is_async:
                acquired = self._portal.portal.call(db_config.create_connection)  # type: ignore[union-attr,arg-type]
            else:
                acquired = db_config.create_connection()
        else:
            pool = current_app.extensions["sqlspec"]["pools"][entry.session_key]
            manager = db_config.provide_connection(pool)
            if entry.is_async:
                acquired = self._portal.portal.call(manager.__aenter__)  # type: ignore[union-attr]
            else:
                acquired = manager.__enter__()  # type: ignore[union-attr]
            # Keep the context manager so it can be exited at teardown.
            set_context_value(g, f"{entry.connection_key}_ctx", manager)

        set_context_value(g, entry.connection_key, acquired)
def _after_request_handler(self, response: "Response") -> "Response":
    """Finish the request's transactions according to the response status.

    Adds the ``X-Correlation-ID`` header when correlation is enabled and an ID
    was stored for this request. Configs with DI disabled, manual commit mode,
    or no cached session for this request are skipped.

    Args:
        response: Flask response object.

    Returns:
        The same response object, with the correlation header set if applicable.
    """
    from flask import g

    if self._enable_correlation:
        correlation_id = get_context_value(g, "correlation_id", None)
        if correlation_id:
            response.headers["X-Correlation-ID"] = correlation_id

    for entry in self._config_states:
        if entry.disable_di or entry.commit_mode == "manual":
            continue

        # A session is only cached once the request actually asked for one.
        session = get_context_value(g, f"sqlspec_session_cache_{entry.session_key}", None)
        if session is None:
            continue

        if entry.should_commit(response.status_code):
            self._execute_commit(session, entry)
        elif entry.should_rollback(response.status_code):
            self._execute_rollback(session, entry)

    return response
def _teardown_appcontext_handler(self, _exc: "BaseException | None" = None) -> None:
    """Release per-request connections and remove this plugin's entries from ``g``.

    Clears the correlation context when correlation is enabled. For every
    DI-enabled config, the stored connection is released (by exiting its
    context manager when one was stored, otherwise by closing it directly);
    failures are logged and do not propagate. The connection, context-manager
    and session-cache entries are then removed from ``g``.

    Args:
        _exc: Exception that occurred (if any). Not used by this handler.
    """
    from flask import g

    if self._enable_correlation:
        CorrelationContext.clear()
        if has_context_value(g, "correlation_id"):
            pop_context_value(g, "correlation_id")

    for entry in self._config_states:
        if entry.disable_di:
            continue

        conn_key = entry.connection_key
        ctx_key = f"{conn_key}_ctx"
        connection = get_context_value(g, conn_key, None)
        manager = get_context_value(g, ctx_key, None)

        if connection is not None:
            try:
                if manager is None:
                    # No context manager stored: close the connection itself.
                    if entry.is_async:
                        self._portal.portal.call(connection.close)  # type: ignore[union-attr]
                    else:
                        connection.close()
                elif entry.is_async:
                    self._portal.portal.call(manager.__aexit__, None, None, None)  # type: ignore[union-attr]
                else:
                    manager.__exit__(None, None, None)
            except Exception as exc:
                log_with_context(
                    logger,
                    logging.ERROR,
                    "session.close",
                    framework="flask",
                    session_key=entry.session_key,
                    operation="connection",
                    status="failed",
                    error_type=type(exc).__name__,
                )

        # Drop whatever this plugin stored on g for the config.
        cache_key = f"sqlspec_session_cache_{entry.session_key}"
        for stale_key in (conn_key, ctx_key, cache_key):
            if has_context_value(g, stale_key):
                pop_context_value(g, stale_key)
[docs]
def get_session(self, key: "str | None" = None) -> Any:
    """Return the database session for the current request.

    Delegates to ``get_or_create_session``, passing the portal when one has
    been created.

    Args:
        key: Session key for multi-database configs. Defaults to first config if None.

    Returns:
        Database session (driver instance).
    """
    if key is None:
        config_state = self._config_states[0]
    else:
        config_state = self._get_config_state_by_key(key)

    portal = self._portal.portal if self._portal else None
    return get_or_create_session(config_state, portal)
[docs]
def get_connection(self, key: "str | None" = None) -> Any:
    """Return the connection stored on ``g`` for the current request.

    Args:
        key: Session key for multi-database configs. Defaults to first config if None.

    Returns:
        Raw database connection.
    """
    from flask import g

    if key is not None:
        selected = self._get_config_state_by_key(key)
    else:
        selected = self._config_states[0]
    return get_context_value(g, selected.connection_key)
def _get_config_state_by_key(self, key: str) -> FlaskConfigState:
    """Look up the config state whose session key equals ``key``.

    Args:
        key: Session key to look up.

    Returns:
        The first FlaskConfigState with a matching session key.

    Raises:
        ImproperConfigurationError: If key not found.
    """
    match = next((entry for entry in self._config_states if entry.session_key == key), None)
    if match is None:
        msg = f"No configuration found for key: {key}"
        raise ImproperConfigurationError(msg)
    return match
[docs]
def shutdown(self) -> None:
    """Close connection pools and stop the async portal.

    Runs at most once: the completion flag is set before any cleanup starts,
    so repeated calls return immediately. A failure while stopping the portal
    is logged rather than raised, and the portal reference is cleared either
    way.
    """
    if self._shutdown_complete:
        return
    self._shutdown_complete = True

    for entry in self._config_states:
        if not entry.config.supports_connection_pooling:
            continue
        self._close_pool_state(entry)

    portal = self._portal
    if portal is None:
        return

    try:
        portal.stop()
    except Exception as exc:
        log_with_context(
            logger,
            logging.ERROR,
            "extension.init",
            framework="flask",
            stage="shutdown",
            status="failed",
            error_type=type(exc).__name__,
        )
    finally:
        self._portal = None
def _close_pool_state(self, config_state: FlaskConfigState) -> None:
    """Close the pool belonging to one configuration state.

    Sync configs are closed directly; async configs are closed through the
    portal. If an async config has no portal available the close is skipped
    and logged. Any exception is logged and swallowed so that the remaining
    cleanup can continue.

    Args:
        config_state: Configuration state whose pool should be closed.
    """
    try:
        if not config_state.is_async:
            config_state.config.close_pool()
        elif self._portal is not None:
            _ = self._portal.portal.call(config_state.config.close_pool)  # type: ignore[arg-type]
        else:
            # Async pool but no portal to run the coroutine on.
            log_with_context(
                logger,
                logging.DEBUG,
                "session.close",
                framework="flask",
                session_key=config_state.session_key,
                operation="pool",
                status="skipped",
                reason="portal_not_initialized",
            )
            return

        log_with_context(
            logger,
            logging.DEBUG,
            "session.close",
            framework="flask",
            session_key=config_state.session_key,
            operation="pool",
            status="complete",
        )
    except Exception as exc:
        log_with_context(
            logger,
            logging.ERROR,
            "session.close",
            framework="flask",
            session_key=config_state.session_key,
            operation="pool",
            status="failed",
            error_type=type(exc).__name__,
        )
def _execute_commit(self, session: Any, config_state: FlaskConfigState) -> None:
    """Commit the current request's connection for the given config.

    The commit is issued on the raw connection returned by
    ``get_connection``; async connections are committed through the portal.
    Failures are logged and not re-raised.

    Args:
        session: Database session. Not used by this method.
        config_state: Configuration state identifying which connection to commit.
    """
    try:
        connection = self.get_connection(config_state.session_key)
        if config_state.is_async:
            self._portal.portal.call(connection.commit)  # type: ignore[union-attr]
        else:
            connection.commit()
    except Exception as exc:
        log_with_context(
            logger,
            logging.ERROR,
            "session.close",
            framework="flask",
            session_key=config_state.session_key,
            operation="commit",
            status="failed",
            error_type=type(exc).__name__,
        )
def _execute_rollback(self, session: Any, config_state: FlaskConfigState) -> None:
    """Roll back the current request's connection for the given config.

    The rollback is issued on the raw connection returned by
    ``get_connection``; async connections are rolled back through the portal.
    Failures are logged and not re-raised.

    Args:
        session: Database session. Not used by this method.
        config_state: Configuration state identifying which connection to roll back.
    """
    try:
        connection = self.get_connection(config_state.session_key)
        if config_state.is_async:
            self._portal.portal.call(connection.rollback)  # type: ignore[union-attr]
        else:
            connection.rollback()
    except Exception as exc:
        # Logged at ERROR, like commit/close/pool failures elsewhere in this
        # class: a failed rollback must be visible at normal log levels.
        log_with_context(
            logger,
            logging.ERROR,
            "session.close",
            framework="flask",
            session_key=config_state.session_key,
            operation="rollback",
            status="failed",
            error_type=type(exc).__name__,
        )