# Source code for sqlspec.extensions.litestar.providers

# ruff: noqa: B008
"""Application dependency provider generators.

This module contains functions to create dependency providers for services and filters.
"""

import datetime
import inspect
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, TypedDict, cast
from uuid import UUID

from litestar.di import Provide
from litestar.exceptions import ValidationException
from litestar.params import Dependency, Parameter
from typing_extensions import NotRequired

from sqlspec.core import (
    BeforeAfterFilter,
    FilterTypes,
    InCollectionFilter,
    LimitOffsetFilter,
    NotInCollectionFilter,
    NotNullFilter,
    NullFilter,
    OrderByFilter,
    SearchFilter,
)
from sqlspec.utils.singleton import SingletonMeta
from sqlspec.utils.text import camelize

if TYPE_CHECKING:
    from sqlglot import exp

# Public names exported by ``from ... import *``.
__all__ = (
    "DEPENDENCY_DEFAULTS",
    "BooleanOrNone",
    "DTorNone",
    "DependencyDefaults",
    "FieldNameType",
    "FilterConfig",
    "HashableType",
    "HashableValue",
    "IntOrNone",
    "SortField",
    "SortOrder",
    "SortOrderOrNone",
    "StringOrNone",
    "UuidOrNone",
    "create_filter_dependencies",
    "dep_cache",
)

# Optional scalar aliases used to annotate the query parameters of the filter providers.
DTorNone = datetime.datetime | None
StringOrNone = str | None
UuidOrNone = UUID | None
IntOrNone = int | None
BooleanOrNone = bool | None
# Sort directions accepted by the ``sortOrder`` query parameter and passed to ``OrderByFilter``.
SortOrder = Literal["asc", "desc"]
SortOrderOrNone = SortOrder | None
# A single sort field name, or a collection of allowed sort field names.
SortField = str | set[str] | list[str]
# Scalar values that ``_make_hashable`` returns unchanged.
HashableValue = str | int | float | bool | None
# Shapes produced by ``_make_hashable``: scalars, or tuples built from dicts and collections.
HashableType = HashableValue | tuple[Any, ...] | tuple[tuple[str, Any], ...] | tuple[HashableValue, ...]


class DependencyDefaults:
    """Names and settings shared by the generated filter dependencies.

    The ``*_DEPENDENCY_KEY`` values are the names under which each filter provider is
    registered; pass an instance with different values as ``dep_defaults`` to change them.
    """

    # Key of the aggregate dependency that returns the list of all active filters.
    FILTERS_DEPENDENCY_KEY: str = "filters"

    # Keys of the individual built-in filter providers.
    CREATED_FILTER_DEPENDENCY_KEY: str = "created_filter"
    ID_FILTER_DEPENDENCY_KEY: str = "id_filter"
    LIMIT_OFFSET_FILTER_DEPENDENCY_KEY: str = "limit_offset_filter"
    UPDATED_FILTER_DEPENDENCY_KEY: str = "updated_filter"
    ORDER_BY_FILTER_DEPENDENCY_KEY: str = "order_by_filter"
    SEARCH_FILTER_DEPENDENCY_KEY: str = "search_filter"

    # Page size used when the filter configuration does not set ``pagination_size``.
    DEFAULT_PAGINATION_SIZE: int = 20
# Shared settings instance used as the default ``dep_defaults`` argument.
DEPENDENCY_DEFAULTS = DependencyDefaults()


class FieldNameType(NamedTuple):
    """Field name plus the type used for its filter values.

    Entries of ``in_fields`` / ``not_in_fields`` may be given as this tuple when the
    values should be annotated with something other than ``str``.
    """

    # Name of the field being filtered.
    name: str
    # Element type used to annotate the field's filter values (``str`` by default).
    type_hint: type[Any] = str
class FilterConfig(TypedDict):
    """Configuration for generating dynamic filters.

    Every key is optional; filter dependencies are generated only for the keys that are set.
    """

    # Type of the id values; enables the ``ids`` query parameter (an IN filter on ``id_field``).
    id_filter: NotRequired[type[UUID | int | str]]
    # Field the id filter applies to (``"id"`` when omitted).
    id_field: NotRequired[str]
    # Allowed ``orderBy`` field(s). The default is the string itself, the first list entry,
    # or the alphabetically first set entry.
    sort_field: NotRequired[SortField]
    # Default sort direction (``"desc"`` when omitted).
    sort_order: NotRequired[SortOrder]
    # ``"limit_offset"`` enables the ``currentPage`` / ``pageSize`` query parameters.
    pagination_type: NotRequired[Literal["limit_offset"]]
    # Default ``pageSize`` (``DependencyDefaults.DEFAULT_PAGINATION_SIZE`` when omitted).
    pagination_size: NotRequired[int]
    # Field(s) matched by the ``searchString`` query parameter; a string may be comma-separated.
    search: NotRequired[str | set[str] | list[str]]
    # Default value of the ``searchIgnoreCase`` query parameter.
    search_ignore_case: NotRequired[bool]
    # Enables ``createdBefore`` / ``createdAfter`` range filtering on ``created_at``.
    created_at: NotRequired[bool]
    # Enables ``updatedBefore`` / ``updatedAfter`` range filtering on ``updated_at``.
    updated_at: NotRequired[bool]
    # Fields supporting NOT IN filtering; plain strings (also inside sets) are treated as ``str``-typed fields.
    not_in_fields: NotRequired[FieldNameType | set[str | FieldNameType] | list[str | FieldNameType]]
    # Fields supporting IN filtering; plain strings (also inside sets) are treated as ``str``-typed fields.
    in_fields: NotRequired[FieldNameType | set[str | FieldNameType] | list[str | FieldNameType]]
    null_fields: NotRequired[str | set[str] | list[str]]
    """Fields that support IS NULL filtering."""
    not_null_fields: NotRequired[str | set[str] | list[str]]
    """Fields that support IS NOT NULL filtering."""
class DependencyCache(metaclass=SingletonMeta):
    """Dependency cache for memoizing dynamically generated dependencies."""

    def __init__(self) -> None:
        # Cache key -> generated providers (keyed by a hash of the filter configuration).
        self.dependencies: dict[int | str, dict[str, Provide]] = {}

    def add_dependencies(self, key: int | str, dependencies: dict[str, Provide]) -> None:
        """Store ``dependencies`` under ``key``, replacing any existing entry."""
        self.dependencies[key] = dependencies

    def get_dependencies(self, key: int | str) -> dict[str, Provide] | None:
        """Return the dependencies stored under ``key``, or ``None`` when nothing is cached."""
        return self.dependencies.get(key)


dep_cache = DependencyCache()


def create_filter_dependencies(
    config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> dict[str, Provide]:
    """Create (or fetch from cache) the filter dependency providers for ``config``.

    Args:
        config: FilterConfig instance with desired settings.
        dep_defaults: Dependency defaults to use for the filter dependencies.

    Returns:
        Mapping of dependency key to ``Provide``. When any filter is enabled, it also
        contains the aggregate provider under ``dep_defaults.FILTERS_DEPENDENCY_KEY``.
    """
    # ``dep_defaults`` decides the dependency names, so it must be part of the cache key;
    # otherwise a custom defaults object could receive providers registered under other names.
    cache_key = hash((_make_hashable(config), _dependency_defaults_key(dep_defaults)))
    if (deps := dep_cache.get_dependencies(cache_key)) is not None:
        return deps
    deps = _create_statement_filters(config, dep_defaults)
    dep_cache.add_dependencies(cache_key, deps)
    return deps


def _make_hashable(value: Any) -> HashableType:
    """Convert a value into a hashable type for caching purposes.

    Dicts become key-sorted tuples of ``(str(key), value)`` pairs. Lists keep their order,
    because order is significant (e.g. the first ``sort_field`` entry is the default).
    Sets are sorted by ``str``. ``None`` members of lists and sets are dropped. Scalars are
    returned unchanged, and anything else is converted with ``str``.

    Args:
        value: Any value that needs to be made hashable.

    Returns:
        A hashable version of the value.
    """
    if isinstance(value, dict):
        return tuple((str(key), _make_hashable(value[key])) for key in sorted(value.keys()))  # pyright: ignore
    if isinstance(value, (list, set, frozenset)):
        items = [item for item in (_make_hashable(member) for member in value) if item is not None]
        if isinstance(value, list):
            return tuple(items)
        return tuple(sorted(items, key=str))
    if isinstance(value, (str, int, float, bool, type(None))):
        return value
    return str(value)


def _dependency_defaults_key(dep_defaults: DependencyDefaults) -> HashableType:
    """Return a hashable snapshot of the upper-case settings exposed by ``dep_defaults``."""
    return tuple(
        (name, _make_hashable(getattr(dep_defaults, name))) for name in sorted(dir(dep_defaults)) if name.isupper()
    )


def _resolve_sort_fields(sort_field: SortField) -> tuple[str, set[str]]:
    """Return ``(default_field, allowed_fields)`` for a ``sort_field`` setting.

    The default is the string itself, the first list entry, or the smallest set entry.
    ``sort_field`` must not be empty.
    """
    if isinstance(sort_field, str):
        return sort_field, {sort_field}
    fields = tuple(sorted(sort_field)) if isinstance(sort_field, set) else tuple(sort_field)
    return fields[0], set(fields)


def _normalize_field_defs(fields: str | FieldNameType | set[Any] | list[Any] | None) -> list[FieldNameType]:
    """Normalize an ``in_fields`` / ``not_in_fields`` setting into a list of unique ``FieldNameType``.

    A single string or ``FieldNameType`` is wrapped in a list. Plain strings become
    ``str``-typed fields. Lists keep their order, while other collections are sorted by name
    so the result is deterministic. When a name repeats, the first definition wins.
    """
    if not fields:
        return []
    if isinstance(fields, (str, FieldNameType)):
        fields = [fields]
    field_defs = [
        FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
        for field_def in fields
    ]
    if not isinstance(fields, list):
        field_defs.sort(key=lambda field_def: field_def.name)
    unique: dict[str, FieldNameType] = {}
    for field_def in field_defs:
        unique.setdefault(field_def.name, field_def)
    return list(unique.values())


def _normalize_field_names(fields: str | set[str] | list[str] | None) -> list[str]:
    """Normalize a ``null_fields`` / ``not_null_fields`` setting into a list of unique names."""
    if not fields:
        return []
    if isinstance(fields, str):
        return [fields]
    if isinstance(fields, list):
        return list(dict.fromkeys(fields))
    return sorted(fields)


def _create_not_in_filter_provider(field_def: FieldNameType) -> Callable[..., NotInCollectionFilter[Any] | None]:
    """Build a provider mapping the camelized ``<field>_not_in`` query values to a ``NotInCollectionFilter``."""

    def provide_not_in_filter(  # pyright: ignore
        values: list[field_def.type_hint] | None = Parameter(  # type: ignore
            query=camelize(f"{field_def.name}_not_in"), default=None, required=False
        ),
    ) -> NotInCollectionFilter[field_def.type_hint] | None:  # type: ignore
        return (
            NotInCollectionFilter[field_def.type_hint](field_name=field_def.name, values=values)  # type: ignore
            if values
            else None
        )

    return provide_not_in_filter  # pyright: ignore


def _create_in_filter_provider(field_def: FieldNameType) -> Callable[..., InCollectionFilter[Any] | None]:
    """Build a provider mapping the camelized ``<field>_in`` query values to an ``InCollectionFilter``."""

    def provide_in_filter(  # pyright: ignore
        values: list[field_def.type_hint] | None = Parameter(  # type: ignore
            query=camelize(f"{field_def.name}_in"), default=None, required=False
        ),
    ) -> InCollectionFilter[field_def.type_hint] | None:  # type: ignore
        return (
            InCollectionFilter[field_def.type_hint](field_name=field_def.name, values=values)  # type: ignore
            if values
            else None
        )

    return provide_in_filter  # pyright: ignore


def _create_null_filter_provider(field_name: str) -> Callable[..., NullFilter | None]:
    """Build a provider returning a ``NullFilter`` when the camelized ``<field>_is_null`` flag is true."""

    def provide_null_filter(
        is_null: bool | None = Parameter(query=camelize(f"{field_name}_is_null"), default=None, required=False),
    ) -> NullFilter | None:
        return NullFilter(field_name=field_name) if is_null else None

    return provide_null_filter


def _create_not_null_filter_provider(field_name: str) -> Callable[..., NotNullFilter | None]:
    """Build a provider returning a ``NotNullFilter`` when the camelized ``<field>_is_not_null`` flag is true."""

    def provide_not_null_filter(
        is_not_null: bool | None = Parameter(
            query=camelize(f"{field_name}_is_not_null"), default=None, required=False
        ),
    ) -> NotNullFilter | None:
        return NotNullFilter(field_name=field_name) if is_not_null else None

    return provide_not_null_filter


def _create_statement_filters(  # noqa: C901
    config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> dict[str, Provide]:
    """Create filter dependencies based on configuration.

    Args:
        config: Configuration dictionary specifying which filters to enable.
        dep_defaults: Dependency defaults to use for the filter dependencies.

    Returns:
        Dictionary of filter provider functions. When at least one filter is enabled, it
        also contains the aggregate provider under ``dep_defaults.FILTERS_DEPENDENCY_KEY``.
    """
    filters: dict[str, Provide] = {}

    if config.get("id_filter", False):
        id_field = config.get("id_field", "id")

        def provide_id_filter(  # pyright: ignore[reportUnknownParameterType]
            ids: list[str] | None = Parameter(query="ids", default=None, required=False),
        ) -> InCollectionFilter:  # pyright: ignore[reportMissingTypeArgument]
            return InCollectionFilter(field_name=id_field, values=ids)

        filters[dep_defaults.ID_FILTER_DEPENDENCY_KEY] = Provide(provide_id_filter, sync_to_thread=False)  # pyright: ignore[reportUnknownArgumentType]

    if config.get("created_at", False):

        def provide_created_filter(
            before: DTorNone = Parameter(query="createdBefore", default=None, required=False),
            after: DTorNone = Parameter(query="createdAfter", default=None, required=False),
        ) -> BeforeAfterFilter:
            return BeforeAfterFilter("created_at", before, after)

        filters[dep_defaults.CREATED_FILTER_DEPENDENCY_KEY] = Provide(provide_created_filter, sync_to_thread=False)

    if config.get("updated_at", False):

        def provide_updated_filter(
            before: DTorNone = Parameter(query="updatedBefore", default=None, required=False),
            after: DTorNone = Parameter(query="updatedAfter", default=None, required=False),
        ) -> BeforeAfterFilter:
            return BeforeAfterFilter("updated_at", before, after)

        filters[dep_defaults.UPDATED_FILTER_DEPENDENCY_KEY] = Provide(provide_updated_filter, sync_to_thread=False)

    if config.get("pagination_type") == "limit_offset":
        page_size_default = config.get("pagination_size", dep_defaults.DEFAULT_PAGINATION_SIZE)

        def provide_limit_offset_pagination(
            current_page: int = Parameter(ge=1, query="currentPage", default=1, required=False),
            page_size: int = Parameter(query="pageSize", ge=1, default=page_size_default, required=False),
        ) -> LimitOffsetFilter:
            # Pages are 1-based: page N skips the (N - 1) preceding full pages.
            return LimitOffsetFilter(page_size, page_size * (current_page - 1))

        filters[dep_defaults.LIMIT_OFFSET_FILTER_DEPENDENCY_KEY] = Provide(
            provide_limit_offset_pagination, sync_to_thread=False
        )

    if search_fields := config.get("search"):
        # Resolve the searched fields once; whitespace around comma-separated names is ignored.
        search_field_names: set[str | exp.Expression] = (
            {name.strip() for name in search_fields.split(",") if name.strip()}
            if isinstance(search_fields, str)
            else set(search_fields)
        )
        search_ignore_case_default = config.get("search_ignore_case", False)

        def provide_search_filter(
            search_string: StringOrNone = Parameter(
                title="Field to search", query="searchString", default=None, required=False
            ),
            ignore_case: BooleanOrNone = Parameter(
                title="Search should be case insensitive",
                query="searchIgnoreCase",
                default=search_ignore_case_default,
                required=False,
            ),
        ) -> SearchFilter:
            # A fresh set per request keeps filters from sharing one mutable collection.
            return SearchFilter(
                field_name=set(search_field_names), value=search_string, ignore_case=ignore_case or False
            )

        filters[dep_defaults.SEARCH_FILTER_DEPENDENCY_KEY] = Provide(provide_search_filter, sync_to_thread=False)

    if sort_field := config.get("sort_field"):
        default_field, allowed_fields = _resolve_sort_fields(sort_field)
        allowed_field_names = ", ".join(sorted(allowed_fields))
        sort_order_default = config.get("sort_order", "desc")

        def provide_order_by(
            field_name: StringOrNone = Parameter(
                title="Order by field", query="orderBy", default=default_field, required=False
            ),
            sort_order: SortOrderOrNone = Parameter(
                title="Sort order", query="sortOrder", default=sort_order_default, required=False
            ),
        ) -> OrderByFilter:
            resolved_field = field_name or default_field
            # Only the configured fields may be used for ordering.
            if resolved_field not in allowed_fields:
                msg = f"Invalid orderBy field '{resolved_field}'. Allowed fields: {allowed_field_names}"
                raise ValidationException(detail=msg)
            return OrderByFilter(field_name=resolved_field, sort_order=sort_order or sort_order_default)

        filters[dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY] = Provide(provide_order_by, sync_to_thread=False)

    for field_def in _normalize_field_defs(config.get("not_in_fields")):
        filters[f"{field_def.name}_not_in_filter"] = Provide(
            _create_not_in_filter_provider(field_def), sync_to_thread=False
        )

    for field_def in _normalize_field_defs(config.get("in_fields")):
        filters[f"{field_def.name}_in_filter"] = Provide(_create_in_filter_provider(field_def), sync_to_thread=False)

    for field_name in _normalize_field_names(config.get("null_fields")):
        filters[f"{field_name}_null_filter"] = Provide(_create_null_filter_provider(field_name), sync_to_thread=False)

    for field_name in _normalize_field_names(config.get("not_null_fields")):
        filters[f"{field_name}_not_null_filter"] = Provide(
            _create_not_null_filter_provider(field_name), sync_to_thread=False
        )

    if filters:
        filters[dep_defaults.FILTERS_DEPENDENCY_KEY] = Provide(
            _create_filter_aggregate_function(config, dep_defaults), sync_to_thread=False
        )
    return filters


def _create_filter_aggregate_function(  # noqa: C901
    config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> Callable[..., list[FilterTypes]]:
    """Create filter aggregation function based on configuration.

    The returned function gets a synthetic signature with one parameter per enabled filter.
    Parameter names match the keys used by ``_create_statement_filters`` for the same
    ``config`` and ``dep_defaults``, so each filter provider is injected by name.

    Args:
        config: The filter configuration.
        dep_defaults: Dependency defaults supplying the built-in filter key names.

    Returns:
        Function that returns list of configured filters.
    """
    id_key = dep_defaults.ID_FILTER_DEPENDENCY_KEY
    created_key = dep_defaults.CREATED_FILTER_DEPENDENCY_KEY
    updated_key = dep_defaults.UPDATED_FILTER_DEPENDENCY_KEY
    search_key = dep_defaults.SEARCH_FILTER_DEPENDENCY_KEY
    limit_offset_key = dep_defaults.LIMIT_OFFSET_FILTER_DEPENDENCY_KEY
    order_by_key = dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY

    # Parameter name -> annotation, in signature order.
    annotations: dict[str, Any] = {}
    if cls := config.get("id_filter"):
        annotations[id_key] = InCollectionFilter[cls]  # type: ignore[valid-type]
    if config.get("created_at"):
        annotations[created_key] = BeforeAfterFilter
    if config.get("updated_at"):
        annotations[updated_key] = BeforeAfterFilter
    if config.get("search"):
        annotations[search_key] = SearchFilter
    if config.get("pagination_type") == "limit_offset":
        annotations[limit_offset_key] = LimitOffsetFilter
    if config.get("sort_field"):
        annotations[order_by_key] = OrderByFilter

    # Per-field filter keys, in the order their filters are appended to the result.
    field_filter_keys: list[str] = []
    for field_def in _normalize_field_defs(config.get("not_in_fields")):
        key = f"{field_def.name}_not_in_filter"
        annotations[key] = NotInCollectionFilter[field_def.type_hint]  # type: ignore
        field_filter_keys.append(key)
    for field_def in _normalize_field_defs(config.get("in_fields")):
        key = f"{field_def.name}_in_filter"
        annotations[key] = InCollectionFilter[field_def.type_hint]  # type: ignore
        field_filter_keys.append(key)
    for field_name in _normalize_field_names(config.get("null_fields")):
        key = f"{field_name}_null_filter"
        annotations[key] = NullFilter | None
        field_filter_keys.append(key)
    for field_name in _normalize_field_names(config.get("not_null_fields")):
        key = f"{field_name}_not_null_filter"
        annotations[key] = NotNullFilter | None
        field_filter_keys.append(key)

    parameters = [
        inspect.Parameter(
            name=name,
            kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=Dependency(skip_validation=True),
            annotation=annotation,
        )
        for name, annotation in annotations.items()
    ]
    # Built-in filters that are included whenever they are truthy, in output order.
    simple_keys = (id_key, created_key, limit_offset_key, updated_key)

    def provide_filters(**kwargs: FilterTypes) -> list[FilterTypes]:
        """Aggregate filter dependencies based on configuration.

        Args:
            **kwargs: Filter parameters dynamically provided based on configuration.

        Returns:
            List of configured filters.
        """
        filters: list[FilterTypes] = []
        for key in simple_keys:
            if simple_filter := kwargs.get(key):
                filters.append(simple_filter)
        search_filter = cast("SearchFilter | None", kwargs.get(search_key))
        if (
            search_filter
            and search_filter.field_name is not None  # pyright: ignore[reportUnnecessaryComparison]
            and search_filter.value is not None  # pyright: ignore[reportUnnecessaryComparison]
        ):
            filters.append(search_filter)
        order_by = cast("OrderByFilter | None", kwargs.get(order_by_key))
        if order_by and order_by.field_name is not None:  # pyright: ignore[reportUnnecessaryComparison]
            filters.append(order_by)
        for key in field_filter_keys:
            field_filter = kwargs.get(key)
            if field_filter is not None:
                filters.append(field_filter)
        return filters

    provide_filters.__signature__ = inspect.Signature(  # type: ignore
        parameters=parameters, return_annotation=list[FilterTypes]
    )
    provide_filters.__annotations__ = {**annotations, "return": list[FilterTypes]}
    return provide_filters