from typing import TYPE_CHECKING, Any
from starlette.middleware.base import BaseHTTPMiddleware
from sqlspec.core import CorrelationExtractor
from sqlspec.extensions.starlette._utils import get_state_value, pop_state_value, set_state_value
from sqlspec.utils.correlation import CorrelationContext
if TYPE_CHECKING:
from starlette.requests import Request
from starlette.responses import Response
from sqlspec.extensions.starlette._state import SQLSpecConfigState
# Public names exported by a star-import of this module.
__all__ = (
    "CorrelationMiddleware",
    "SQLSpecAutocommitMiddleware",
    "SQLSpecManualMiddleware",
)
# Status-code boundaries used when deciding whether a response should commit:
# [200, 300) is the success range, [300, 400) is the redirect range.
HTTP_200_OK, HTTP_300_MULTIPLE_CHOICES, HTTP_400_BAD_REQUEST = 200, 300, 400
[docs]
class SQLSpecManualMiddleware(BaseHTTPMiddleware):
    """Middleware for manual transaction mode.

    Acquires a connection for each request, stores it on ``request.state``
    under the configured connection key, and releases it once the downstream
    handler has returned. No automatic commit or rollback is performed - user
    code must handle transactions.
    """

    def __init__(self, app: Any, config_state: "SQLSpecConfigState") -> None:
        """Initialize middleware.

        Args:
            app: Starlette application instance.
            config_state: Configuration state for this database.
        """
        super().__init__(app)
        self.config_state = config_state

    async def dispatch(self, request: "Request", call_next: Any) -> Any:
        """Process request with manual transaction mode.

        When the config supports connection pooling, the pool is read from
        ``request.app.state`` and a connection is borrowed through
        ``config.provide_connection``; otherwise a dedicated connection is
        created and closed for this request. In both cases the connection is
        removed from ``request.state`` before it is released.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.

        Returns:
            HTTP response.
        """
        config = self.config_state.config
        connection_key = self.config_state.connection_key

        if config.supports_connection_pooling:
            pool = get_state_value(request.app.state, self.config_state.pool_key)
            async with config.provide_connection(pool) as connection:  # type: ignore[union-attr]
                set_state_value(request.state, connection_key, connection)
                try:
                    return await call_next(request)
                finally:
                    # Drop the reference before the context manager releases the connection.
                    pop_state_value(request.state, connection_key)
        else:
            connection = await config.create_connection()
            set_state_value(request.state, connection_key, connection)
            try:
                return await call_next(request)
            finally:
                # Mirror the pooled branch: do not leave a closed connection on
                # request.state. The nested finally guarantees close() still runs
                # even if removing the state entry fails.
                try:
                    pop_state_value(request.state, connection_key)
                finally:
                    await connection.close()
[docs]
class SQLSpecAutocommitMiddleware(BaseHTTPMiddleware):
    """Middleware for autocommit transaction mode.

    Acquires a connection for each request, commits when the response status
    code qualifies as a success, and rolls back on any other status code or
    when the downstream handler raises.
    """

    def __init__(self, app: Any, config_state: "SQLSpecConfigState", include_redirect: bool = False) -> None:
        """Initialize middleware.

        Args:
            app: Starlette application instance.
            config_state: Configuration state for this database.
            include_redirect: If True, commit on 3xx status codes as well.
        """
        super().__init__(app)
        self.config_state = config_state
        self.include_redirect = include_redirect

    async def dispatch(self, request: "Request", call_next: Any) -> Any:
        """Process request with autocommit transaction mode.

        When the config supports connection pooling, a connection is borrowed
        from the pool stored on ``request.app.state``; otherwise a dedicated
        connection is created and closed for this request. The commit/rollback
        handling is identical for both kinds of connection.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.

        Returns:
            HTTP response.
        """
        config = self.config_state.config

        if config.supports_connection_pooling:
            pool = get_state_value(request.app.state, self.config_state.pool_key)
            async with config.provide_connection(pool) as connection:  # type: ignore[union-attr]
                return await self._run_in_transaction(request, call_next, connection)
        else:
            connection = await config.create_connection()
            try:
                return await self._run_in_transaction(request, call_next, connection)
            finally:
                await connection.close()

    async def _run_in_transaction(self, request: "Request", call_next: Any, connection: Any) -> Any:
        """Run the downstream handler with ``connection`` exposed on ``request.state``.

        Commits or rolls back based on the response status code, rolls back and
        re-raises if anything in that sequence raises, and always removes the
        connection from ``request.state`` afterwards.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.
            connection: Connection to expose to the handler and finalize.

        Returns:
            HTTP response produced by ``call_next``.
        """
        connection_key = self.config_state.connection_key
        set_state_value(request.state, connection_key, connection)
        try:
            response = await call_next(request)
            if self._should_commit(response.status_code):
                await connection.commit()
            else:
                await connection.rollback()
        except Exception:
            # Covers handler errors and a failed commit alike.
            await connection.rollback()
            raise
        else:
            return response
        finally:
            # Never leave the connection reachable once the caller releases it.
            pop_state_value(request.state, connection_key)

    def _should_commit(self, status_code: int) -> bool:
        """Determine if response status code should trigger commit.

        Explicit overrides from the config state are consulted first, with
        ``extra_commit_statuses`` taking precedence over
        ``extra_rollback_statuses``. Otherwise 2xx commits, and 3xx commits
        only when ``include_redirect`` is set.

        Args:
            status_code: HTTP status code.

        Returns:
            True if should commit, False if should rollback.
        """
        extra_commit = self.config_state.extra_commit_statuses or set()
        extra_rollback = self.config_state.extra_rollback_statuses or set()

        if status_code in extra_commit:
            return True
        if status_code in extra_rollback:
            return False
        if HTTP_200_OK <= status_code < HTTP_300_MULTIPLE_CHOICES:
            return True
        return bool(self.include_redirect and HTTP_300_MULTIPLE_CHOICES <= status_code < HTTP_400_BAD_REQUEST)
[docs]
class CorrelationMiddleware(BaseHTTPMiddleware):
    """Middleware that extracts a correlation ID and propagates it per request.

    For every request the middleware:

    1. Asks its ``CorrelationExtractor`` for a correlation ID, giving it
       access to the request headers.
    2. Sets that ID on ``CorrelationContext``.
    3. Stores it in ``request.state`` under ``correlation_id``.
    4. Adds an ``X-Correlation-ID`` header to the response.
    5. Restores the previous context value and removes the state entry
       when the request completes.

    Example:
        ```python
        from starlette.applications import Starlette
        from sqlspec.extensions.starlette.middleware import (
            CorrelationMiddleware,
        )

        app = Starlette()
        app.add_middleware(
            CorrelationMiddleware,
            primary_header="x-request-id",
            auto_trace_headers=True,
        )
        ```
    """

    def __init__(
        self,
        app: Any,
        *,
        primary_header: str = "x-request-id",
        additional_headers: tuple[str, ...] | None = None,
        auto_trace_headers: bool = True,
        max_length: int = 128,
    ) -> None:
        """Set up the middleware and its correlation extractor.

        Args:
            app: Starlette application instance.
            primary_header: Header consulted first. Defaults to "x-request-id".
            additional_headers: Further headers to consult after the primary one.
            auto_trace_headers: If True, include standard trace context headers as fallbacks.
            max_length: Maximum length for correlation IDs. Defaults to 128.
        """
        super().__init__(app)
        extractor_options = {
            "primary_header": primary_header,
            "additional_headers": additional_headers,
            "auto_trace_headers": auto_trace_headers,
            "max_length": max_length,
        }
        self._extractor = CorrelationExtractor(**extractor_options)

    async def dispatch(self, request: "Request", call_next: Any) -> "Response":
        """Attach a correlation ID to the request and echo it on the response.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.

        Returns:
            HTTP response with X-Correlation-ID header.
        """

        def read_header(name: str) -> Any:
            # Header lookup handed to the extractor.
            return request.headers.get(name)

        active_id = self._extractor.extract(read_header)

        # Remember whatever was set before so it can be restored afterwards.
        outer_id = CorrelationContext.get()
        CorrelationContext.set(active_id)
        set_state_value(request.state, "correlation_id", active_id)

        try:
            response: "Response" = await call_next(request)
            response.headers["X-Correlation-ID"] = active_id
        finally:
            CorrelationContext.set(outer_id)
            pop_state_value(request.state, "correlation_id")

        return response