from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any

from starlette.middleware.base import BaseHTTPMiddleware

from sqlspec.utils.logging import get_logger

if TYPE_CHECKING:
    from collections.abc import AsyncIterator

    from starlette.requests import Request

    from sqlspec.extensions.starlette._state import SQLSpecConfigState
# Status-code boundaries used by the autocommit middleware: [200, 300) counts as
# success and [300, 400) as redirection.
HTTP_200_OK = 200
HTTP_300_MULTIPLE_CHOICES = 300
HTTP_400_BAD_REQUEST = 400

logger = get_logger("extensions.starlette.middleware")

__all__ = (
    "SQLSpecAutocommitMiddleware",
    "SQLSpecManualMiddleware",
)
class SQLSpecManualMiddleware(BaseHTTPMiddleware):
    """Middleware for manual transaction mode.

    Acquires a connection for each request, exposes it on ``request.state`` under
    the configured connection key, and releases it after the request completes.
    No automatic commit or rollback is performed - user code must handle
    transactions.
    """

    def __init__(self, app: Any, config_state: "SQLSpecConfigState") -> None:
        """Initialize middleware.

        Args:
            app: Starlette application instance.
            config_state: Configuration state for this database.
        """
        super().__init__(app)
        self.config_state = config_state

    async def dispatch(self, request: "Request", call_next: Any) -> Any:
        """Process request with manual transaction mode.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.

        Returns:
            HTTP response.
        """
        connection_key = self.config_state.connection_key
        async with self._acquire_connection(request) as connection:
            setattr(request.state, connection_key, connection)
            try:
                return await call_next(request)
            finally:
                # Remove the entry in both pooled and non-pooled modes so a released
                # (or closed) connection is never left reachable from the request.
                delattr(request.state, connection_key)

    @asynccontextmanager
    async def _acquire_connection(self, request: "Request") -> "AsyncIterator[Any]":
        """Yield a database connection for the duration of one request.

        When the config supports pooling, the connection is borrowed via
        ``config.provide_connection`` from the pool stored on ``request.app.state``
        under ``config_state.pool_key``. Otherwise a new connection is created with
        ``config.create_connection()`` and closed when the context exits.

        Args:
            request: Incoming HTTP request (used to reach the app-level pool).

        Yields:
            The acquired connection.
        """
        config = self.config_state.config
        if config.supports_connection_pooling:
            pool = getattr(request.app.state, self.config_state.pool_key)
            async with config.provide_connection(pool) as connection:  # type: ignore[union-attr]
                yield connection
        else:
            connection = await config.create_connection()
            try:
                yield connection
            finally:
                await connection.close()
class SQLSpecAutocommitMiddleware(BaseHTTPMiddleware):
    """Middleware for autocommit transaction mode.

    Acquires a connection per request and exposes it on ``request.state`` under the
    configured connection key. When the downstream handler returns, the transaction
    is committed if ``_should_commit`` accepts the response status code and rolled
    back otherwise. If the handler or the commit raises, the transaction is rolled
    back and the exception is re-raised.
    """

    def __init__(self, app: Any, config_state: "SQLSpecConfigState", include_redirect: bool = False) -> None:
        """Initialize middleware.

        Args:
            app: Starlette application instance.
            config_state: Configuration state for this database.
            include_redirect: If True, commit on 3xx status codes as well.
        """
        super().__init__(app)
        self.config_state = config_state
        self.include_redirect = include_redirect

    async def dispatch(self, request: "Request", call_next: Any) -> Any:
        """Process request with autocommit transaction mode.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.

        Returns:
            HTTP response.
        """
        connection_key = self.config_state.connection_key
        async with self._acquire_connection(request) as connection:
            setattr(request.state, connection_key, connection)
            try:
                return await self._call_in_transaction(request, call_next, connection)
            finally:
                # Remove the entry in both pooled and non-pooled modes so a released
                # (or closed) connection is never left reachable from the request.
                delattr(request.state, connection_key)

    async def _call_in_transaction(self, request: "Request", call_next: Any, connection: Any) -> Any:
        """Run the downstream app, then commit or roll back based on the outcome.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.
            connection: Connection whose transaction is finalized here.

        Returns:
            The downstream HTTP response.

        Raises:
            Exception: Whatever ``call_next`` or ``connection.commit()`` raised, after
                a rollback attempt; or whatever the status-driven rollback raised.
        """
        try:
            response = await call_next(request)
        except Exception:
            await self._rollback_after_error(connection)
            raise

        if not self._should_commit(response.status_code):
            await connection.rollback()
            return response

        try:
            await connection.commit()
        except Exception:
            # Leave the connection in a clean state before propagating the commit failure.
            await self._rollback_after_error(connection)
            raise
        return response

    async def _rollback_after_error(self, connection: Any) -> None:
        """Roll back while an exception is being handled, without masking it.

        A failing rollback is logged instead of raised, so the caller's bare
        ``raise`` propagates the original error rather than the rollback error.

        Args:
            connection: Connection whose transaction should be rolled back.
        """
        try:
            await connection.rollback()
        except Exception:
            logger.warning("Rollback failed while handling a request error", exc_info=True)

    @asynccontextmanager
    async def _acquire_connection(self, request: "Request") -> "AsyncIterator[Any]":
        """Yield a database connection for the duration of one request.

        When the config supports pooling, the connection is borrowed via
        ``config.provide_connection`` from the pool stored on ``request.app.state``
        under ``config_state.pool_key``. Otherwise a new connection is created with
        ``config.create_connection()`` and closed when the context exits.

        Args:
            request: Incoming HTTP request (used to reach the app-level pool).

        Yields:
            The acquired connection.
        """
        config = self.config_state.config
        if config.supports_connection_pooling:
            pool = getattr(request.app.state, self.config_state.pool_key)
            async with config.provide_connection(pool) as connection:  # type: ignore[union-attr]
                yield connection
        else:
            connection = await config.create_connection()
            try:
                yield connection
            finally:
                await connection.close()

    def _should_commit(self, status_code: int) -> bool:
        """Determine whether a response status code should trigger a commit.

        Precedence: ``extra_commit_statuses`` first, then ``extra_rollback_statuses``.
        Next, any 2xx code commits. A 3xx code commits only when ``include_redirect``
        is set. Everything else rolls back.

        Args:
            status_code: HTTP status code.

        Returns:
            True if the transaction should be committed, False if it should be rolled back.
        """
        extra_commit = self.config_state.extra_commit_statuses or set()
        extra_rollback = self.config_state.extra_rollback_statuses or set()
        if status_code in extra_commit:
            return True
        if status_code in extra_rollback:
            return False
        if HTTP_200_OK <= status_code < HTTP_300_MULTIPLE_CHOICES:
            return True
        return bool(self.include_redirect and HTTP_300_MULTIPLE_CHOICES <= status_code < HTTP_400_BAD_REQUEST)