Source code for sqlspec.extensions.litestar.providers

"""Application dependency providers generators.

This module contains functions to create dependency providers for services and filters.
"""

import datetime
import inspect
import typing
from collections.abc import Callable, Mapping
from enum import Enum
from functools import partial
from inspect import isclass
from types import GenericAlias
from typing import Annotated, Any, Literal, NamedTuple, TypedDict, cast
from uuid import UUID

from litestar.di import NamedDependency, Provide
from litestar.exceptions import ValidationException
from litestar.params import QueryParameter, SkipValidation
from litestar.utils.signature import ParsedSignature
from typing_extensions import NotRequired

from sqlspec.core import (
    BeforeAfterFilter,
    BooleanFilter,
    ChoicesFilter,
    FilterTypes,
    InCollectionFilter,
    LimitOffsetFilter,
    NotInCollectionFilter,
    NotNullFilter,
    NullFilter,
    OrderByFilter,
    SearchFilter,
)
from sqlspec.utils.text import camelize

__all__ = (
    "DEPENDENCY_DEFAULTS",
    "BooleanOrNone",
    "ChoiceField",
    "DTorNone",
    "DependencyDefaults",
    "FieldNameType",
    "FilterConfig",
    "HashableType",
    "HashableValue",
    "IntOrNone",
    "SortField",
    "SortOrder",
    "SortOrderOrNone",
    "StringOrNone",
    "UuidOrNone",
    "create_filter_dependencies",
    "dep_cache",
    "normalize_choice_field_types",
)

DTorNone = datetime.datetime | None
StringOrNone = str | None
UuidOrNone = UUID | None
IntOrNone = int | None
BooleanOrNone = bool | None
SortOrder = Literal["asc", "desc"]
SortOrderOrNone = SortOrder | None
SortField = str | set[str] | list[str]
HashableValue = str | int | float | bool | None
HashableType = HashableValue | tuple[Any, ...] | tuple[tuple[str, Any], ...] | tuple[HashableValue, ...]


[docs] class DependencyDefaults: FILTERS_DEPENDENCY_KEY: str = "filters" CREATED_FILTER_DEPENDENCY_KEY: str = "created_filter" ID_FILTER_DEPENDENCY_KEY: str = "id_filter" LIMIT_OFFSET_FILTER_DEPENDENCY_KEY: str = "limit_offset_filter" UPDATED_FILTER_DEPENDENCY_KEY: str = "updated_filter" ORDER_BY_FILTER_DEPENDENCY_KEY: str = "order_by_filter" SEARCH_FILTER_DEPENDENCY_KEY: str = "search_filter" DEFAULT_PAGINATION_SIZE: int = 20
DEPENDENCY_DEFAULTS = DependencyDefaults() class FieldNameType(NamedTuple): """Type for field name and associated type information for filter configuration.""" name: str type_hint: type[Any] = str class ChoiceField: """Type for choice field name and allowed choices for filter configuration.""" __slots__ = ("choices", "name") def __init__(self, name: str, choices: list[Any] | tuple[Any, ...] | type[Enum]) -> None: self.name = name self.choices = choices def normalize_choice_field_types(choices: list[Any] | tuple[Any, ...] | type[Enum]) -> Any: """Normalize choices into a generic type hint (Literal or Enum).""" if isclass(choices) and issubclass(choices, Enum): return choices return cast("Any", typing.Literal).__getitem__(tuple(choices)) class _SortFieldResolution(NamedTuple): default_field: str default_query_value: str allowed_fields: frozenset[str] inbound_aliases: dict[str, str] field_display_names: dict[str, str] allowed_display_names: tuple[str, ...] def normalize(self, value: str | None) -> str | None: if value is None: return self.default_field return self.inbound_aliases.get(value) # Keep FilterConfig field unions and provider signatures in sync with sqlspec.extensions.fastapi.providers.
class FilterConfig(TypedDict):
    """Configuration for generated Litestar filter dependencies.

    All keys are optional. A filter dependency is created only for each enabled key.
    Field names are SQL-facing allowlist values; generated query parameter names and
    order-by aliases remain API-facing.
    """

    id_filter: NotRequired[type[UUID | int | str]]
    """Type of ID filter to enable. When set, creates an ``ids`` collection filter."""
    id_field: NotRequired[str]
    """SQL-facing field name for ID filtering. Defaults to ``"id"``."""
    sort_field: NotRequired[SortField]
    """Allowed SQL-facing field or fields for ``orderBy`` sorting."""
    sort_field_aliases: NotRequired[dict[str, str]]
    """Additional API-facing ``orderBy`` aliases mapped to configured ``sort_field`` values."""
    sort_field_camelize: NotRequired[bool]
    """Whether to accept camel-case aliases for configured sort fields. Defaults to ``True``."""
    sort_order: NotRequired[SortOrder]
    """Default sort order. Defaults to ``"desc"``."""
    pagination_type: NotRequired[Literal["limit_offset"]]
    """Pagination strategy to enable. Currently supports ``"limit_offset"``."""
    pagination_size: NotRequired[int]
    """Default page size for limit/offset pagination."""
    search: NotRequired[str | set[str] | list[str]]
    """SQL-facing field or fields to search. Strings may be comma-separated."""
    search_ignore_case: NotRequired[bool]
    """Whether search filtering is case-insensitive. Defaults to ``False``."""
    created_at: NotRequired[bool]
    """Whether to enable ``created_at`` before/after range filtering."""
    updated_at: NotRequired[bool]
    """Whether to enable ``updated_at`` before/after range filtering."""
    not_in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
    """Field or fields that support ``NOT IN`` collection filtering."""
    in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
    """Field or fields that support ``IN`` collection filtering."""
    null_fields: NotRequired[str | set[str] | list[str]]
    """Field or fields that support ``IS NULL`` filtering."""
    not_null_fields: NotRequired[str | set[str] | list[str]]
    """Field or fields that support ``IS NOT NULL`` filtering."""
    boolean_fields: NotRequired[str | set[str] | list[str]]
    """Field or fields that support boolean filtering."""
    choice_fields: NotRequired[ChoiceField | set[ChoiceField] | list[str | ChoiceField]]
    """Field or fields that support choices filtering."""
class DependencyCache:
    """Dependency cache for memoizing dynamically generated dependencies."""

    def __init__(self) -> None:
        # Maps a cache key to the provider dictionary generated for it.
        self.dependencies: dict[int | str, dict[str, Provide]] = {}

    def add_dependencies(self, key: int | str, dependencies: dict[str, Provide]) -> None:
        """Store ``dependencies`` under ``key``, replacing any previous entry."""
        self.dependencies[key] = dependencies

    def get_dependencies(self, key: int | str) -> dict[str, Provide] | None:
        """Return the dependencies stored under ``key``, or ``None`` when absent."""
        return self.dependencies.get(key)


dep_cache = DependencyCache()

# ``DependencyDefaults`` attributes read while building providers; they are part
# of the cache key because changing any of them changes the generated result.
_DEPENDENCY_DEFAULT_ATTRIBUTES: tuple[str, ...] = (
    "FILTERS_DEPENDENCY_KEY",
    "CREATED_FILTER_DEPENDENCY_KEY",
    "ID_FILTER_DEPENDENCY_KEY",
    "LIMIT_OFFSET_FILTER_DEPENDENCY_KEY",
    "UPDATED_FILTER_DEPENDENCY_KEY",
    "ORDER_BY_FILTER_DEPENDENCY_KEY",
    "SEARCH_FILTER_DEPENDENCY_KEY",
    "DEFAULT_PAGINATION_SIZE",
)


def _dependency_defaults_key(dep_defaults: DependencyDefaults) -> tuple[tuple[str, HashableType], ...]:
    """Return a hashable snapshot of the defaults that influence generated providers."""
    return tuple(
        (name, _make_hashable(getattr(dep_defaults, name, None))) for name in _DEPENDENCY_DEFAULT_ATTRIBUTES
    )


def create_filter_dependencies(
    config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> dict[str, Provide]:
    """Create (or fetch from cache) the filter dependency providers for a configuration.

    Args:
        config: FilterConfig instance with desired settings.
        dep_defaults: Dependency defaults to use for the filter dependencies

    Returns:
        Mapping of dependency key to ``Provide`` for every enabled filter, plus the
        combined filter provider when at least one filter is enabled.
    """
    # The key covers both arguments: the same config combined with different
    # defaults yields different dependency keys and a different default page size.
    cache_key = hash((_make_hashable(config), _dependency_defaults_key(dep_defaults)))
    if (deps := dep_cache.get_dependencies(cache_key)) is not None:
        return deps
    deps = _create_statement_filters(config, dep_defaults)
    dep_cache.add_dependencies(cache_key, deps)
    return deps


def _normalize_collection_fields(fields: Any) -> list[FieldNameType]:
    """Coerce an ``in_fields`` / ``not_in_fields`` value into ``FieldNameType`` entries."""
    # ``FieldNameType`` is itself a tuple, so a single value must be wrapped
    # before iterating or its ``name`` / ``type_hint`` members would be walked.
    if isinstance(fields, (str, FieldNameType)):
        fields = (fields,)
    return [
        FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def
        for field_def in fields
    ]


def _normalize_field_names(fields: str | set[str] | list[str]) -> set[str]:
    """Coerce a single field name or a collection of names into a set of names."""
    return {fields} if isinstance(fields, str) else set(fields)


def _normalize_choice_fields(fields: Any) -> list[ChoiceField]:
    """Coerce a ``choice_fields`` value into ``ChoiceField`` entries."""
    if isinstance(fields, ChoiceField):
        fields = (fields,)
    return [
        ChoiceField(name=choice_def, choices=[]) if isinstance(choice_def, str) else choice_def
        for choice_def in fields
    ]


def _create_statement_filters(
    config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> dict[str, Provide]:
    """Create filter dependencies based on configuration.

    Args:
        config: Configuration dictionary specifying which filters to enable
        dep_defaults: Dependency defaults to use for the filter dependencies

    Returns:
        Dictionary of filter provider functions
    """
    filters: dict[str, Provide] = {}

    if id_type := config.get("id_filter", False):
        filters[dep_defaults.ID_FILTER_DEPENDENCY_KEY] = _create_provide(
            _bind_provider(
                _IdFilterProvider(config.get("id_field", "id"), id_type if isinstance(id_type, type) else object),
                _provide_id_filter,
            )
        )

    if config.get("created_at", False):
        filters[dep_defaults.CREATED_FILTER_DEPENDENCY_KEY] = _create_provide(
            _build_before_after_provider("created_at", "createdBefore", "createdAfter")
        )

    if config.get("updated_at", False):
        filters[dep_defaults.UPDATED_FILTER_DEPENDENCY_KEY] = _create_provide(
            _build_before_after_provider("updated_at", "updatedBefore", "updatedAfter")
        )

    if config.get("pagination_type") == "limit_offset":
        filters[dep_defaults.LIMIT_OFFSET_FILTER_DEPENDENCY_KEY] = _create_provide(
            _bind_provider(
                _LimitOffsetFilterProvider(config.get("pagination_size", dep_defaults.DEFAULT_PAGINATION_SIZE)),
                _provide_limit_offset_filter,
            )
        )

    if search_fields := config.get("search"):
        filters[dep_defaults.SEARCH_FILTER_DEPENDENCY_KEY] = _create_provide(
            _bind_provider(
                _SearchFilterProvider(search_fields, config.get("search_ignore_case", False)), _provide_search_filter
            )
        )

    if sort_field := config.get("sort_field"):
        filters[dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY] = _create_provide(
            _bind_provider(_OrderByProvider(sort_field, config), _provide_order_by_filter)
        )

    if not_in_fields := config.get("not_in_fields"):
        for resolved in _normalize_collection_fields(not_in_fields):
            filters[f"{resolved.name}_not_in_filter"] = _create_provide(
                _build_in_collection_provider(resolved, negated=True)
            )

    if in_fields := config.get("in_fields"):
        for resolved in _normalize_collection_fields(in_fields):
            filters[f"{resolved.name}_in_filter"] = _create_provide(
                _build_in_collection_provider(resolved, negated=False)
            )

    if null_fields := config.get("null_fields"):
        for field_name in _normalize_field_names(null_fields):
            filters[f"{field_name}_null_filter"] = _create_provide(_build_null_provider(field_name, negated=False))

    if not_null_fields := config.get("not_null_fields"):
        for field_name in _normalize_field_names(not_null_fields):
            filters[f"{field_name}_not_null_filter"] = _create_provide(_build_null_provider(field_name, negated=True))

    if boolean_fields := config.get("boolean_fields"):
        for field_name in _normalize_field_names(boolean_fields):
            filters[f"{field_name}_boolean_filter"] = _create_provide(
                _bind_provider(_BooleanFilterProvider(field_name), _provide_boolean_filter)
            )

    if choice_fields := config.get("choice_fields"):
        for resolved_choice in _normalize_choice_fields(choice_fields):
            filters[f"{resolved_choice.name}_choices_filter"] = _create_provide(
                _bind_provider(
                    _ChoicesFilterProvider(resolved_choice.name, resolved_choice.choices), _provide_choices_filter
                )
            )

    # The combined provider is only registered when something feeds it.
    if filters:
        filters[dep_defaults.FILTERS_DEPENDENCY_KEY] = _create_provide(_create_filter_aggregate_function(config))

    return filters


def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., list[FilterTypes]]:
    """Create filter aggregation function based on configuration.

    Args:
        config: The filter configuration.

    Returns:
        Function that returns list of configured filters.
    """
    # NOTE(review): parameter names below are fixed ("id_filter", "created_filter", ...)
    # while the providers are registered under ``dep_defaults`` keys; confirm that
    # customised keys are meant to be resolvable by the aggregate.
    parameters: dict[str, inspect.Parameter] = {}
    annotations: dict[str, Any] = {}

    def register(name: str, annotation: Any) -> None:
        # Record one named dependency in both the signature and the annotations.
        parameters[name] = inspect.Parameter(
            name=name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=annotation
        )
        annotations[name] = annotation

    if cls := config.get("id_filter"):
        register("id_filter", NamedDependency[SkipValidation[InCollectionFilter[cls]]])  # type: ignore[valid-type]

    if config.get("created_at"):
        register("created_filter", NamedDependency[SkipValidation[BeforeAfterFilter]])

    if config.get("updated_at"):
        register("updated_filter", NamedDependency[SkipValidation[BeforeAfterFilter]])

    if config.get("search"):
        register("search_filter", NamedDependency[SkipValidation[SearchFilter]])

    if config.get("pagination_type") == "limit_offset":
        register("limit_offset_filter", NamedDependency[SkipValidation[LimitOffsetFilter]])

    if config.get("sort_field"):
        register("order_by_filter", NamedDependency[SkipValidation[OrderByFilter]])

    # Collection fields are normalised exactly as in ``_create_statement_filters``
    # so a single ``FieldNameType`` (or bare string) yields the same dependency names.
    if not_in_fields := config.get("not_in_fields"):
        for field_def in _normalize_collection_fields(not_in_fields):
            register(
                f"{field_def.name}_not_in_filter", NamedDependency[SkipValidation[NotInCollectionFilter[Any]]]
            )

    if in_fields := config.get("in_fields"):
        for field_def in _normalize_collection_fields(in_fields):
            register(f"{field_def.name}_in_filter", NamedDependency[SkipValidation[InCollectionFilter[Any]]])

    if null_fields := config.get("null_fields"):
        for field_name in _normalize_field_names(null_fields):
            register(f"{field_name}_null_filter", NamedDependency[SkipValidation[NullFilter]] | None)

    if not_null_fields := config.get("not_null_fields"):
        for field_name in _normalize_field_names(not_null_fields):
            register(f"{field_name}_not_null_filter", NamedDependency[SkipValidation[NotNullFilter]] | None)

    if boolean_fields := config.get("boolean_fields"):
        for field_name in _normalize_field_names(boolean_fields):
            register(f"{field_name}_boolean_filter", NamedDependency[SkipValidation[BooleanFilter]] | None)

    if choice_fields := config.get("choice_fields"):
        for resolved_choice in _normalize_choice_fields(choice_fields):
            register(
                f"{resolved_choice.name}_choices_filter", NamedDependency[SkipValidation[ChoicesFilter[Any]]] | None
            )

    return _make_aggregate_filter_provider(list(parameters.values()), annotations)


def _make_hashable(value: Any) -> HashableType:
    """Convert a value into a hashable type for caching purposes.

    Args:
        value: Any value that needs to be made hashable.

    Returns:
        A hashable version of the value.
    """
    if isinstance(value, dict):
        items = []
        for k in sorted(value.keys()):  # pyright: ignore
            v = value[k]
            items.append((str(k), _make_hashable(v)))
        return tuple(items)
    if isinstance(value, ChoiceField):
        # Key on content rather than ``str(value)``, which may embed the object's
        # memory address and so differ between equal configurations.
        choices = value.choices
        choices_key: Any = str(choices) if isclass(choices) else tuple(_make_hashable(item) for item in choices)
        return ("ChoiceField", value.name, choices_key)
    if isinstance(value, (list, set)):
        hashable_items = [_make_hashable(item) for item in value]
        filtered_items = [item for item in hashable_items if item is not None]
        # Sorting by ``str`` makes the result independent of list/set ordering.
        return tuple(sorted(filtered_items, key=str))
    if isinstance(value, (str, int, float, bool, type(None))):
        return value
    return str(value)


def _set_provider_metadata(
    provider: Any, signature: inspect.Signature, annotations: dict[str, Any]
) -> Callable[..., Any]:
    """Attach an explicit ``__signature__`` and ``__annotations__`` to ``provider``."""
    provider.__signature__ = signature
    provider.__annotations__ = annotations
    return cast("Callable[..., Any]", provider)


def _bind_provider(context: Any, provider: Callable[..., Any]) -> Callable[..., Any]:
    """Bind ``context`` as the first argument of ``provider`` and expose the context's signature."""
    return _set_provider_metadata(partial(provider, context), context.signature, context.annotations)


def _create_provide(provider: Callable[..., Any]) -> Provide:
    """Wrap ``provider`` in a synchronous ``Provide`` whose parsed signature matches its metadata."""
    dependency = Provide(provider, sync_to_thread=False)
    dependency.parsed_fn_signature = ParsedSignature.from_signature(
        inspect.signature(provider), getattr(provider, "__annotations__", {})
    )
    return dependency


def _aggregate_filter_provider(**kwargs: Any) -> list[FilterTypes]:
    """Collect the resolved filter dependencies into a single list.

    ``None`` values are skipped, list values are flattened, and search / order-by
    filters without a value / field name are dropped.
    """
    filters: list[FilterTypes] = []
    for filter_value in kwargs.values():
        if filter_value is None:
            continue
        if isinstance(filter_value, list):
            filters.extend(filter_value)
        elif (isinstance(filter_value, SearchFilter) and filter_value.value is None) or (
            isinstance(filter_value, OrderByFilter) and filter_value.field_name is None
        ):
            continue
        else:
            filters.append(filter_value)
    return filters


def _make_aggregate_filter_provider(
    parameters: list[inspect.Parameter], annotations: dict[str, Any]
) -> Callable[..., list[FilterTypes]]:
    """Build an aggregate provider exposing ``parameters`` as its signature."""
    aggregate_annotations = dict(annotations)
    aggregate_annotations["return"] = list[FilterTypes]
    # A fresh ``partial`` gives each aggregate its own object to carry metadata.
    return _set_provider_metadata(
        partial(_aggregate_filter_provider),
        inspect.Signature(parameters=parameters, return_annotation=list[FilterTypes]),
        aggregate_annotations,
    )


def _collection_value_annotation(collection_type: type[Any], value_type: type[Any]) -> Any:
    """Return the annotation ``collection_type[value_type] | None``."""
    return GenericAlias(collection_type, (value_type,)) | None


def _query_parameter_annotation(value_annotation: Any, query: Any) -> Any:
    """Return ``Annotated[value_annotation, query]``."""
    return Annotated[value_annotation, query]


def _keyword_only_parameter(name: str, annotation: Any, default: Any = None) -> inspect.Parameter:
    """Build a keyword-only ``inspect.Parameter`` with the given default and annotation."""
    return inspect.Parameter(name, kind=inspect.Parameter.KEYWORD_ONLY, default=default, annotation=annotation)


class _CollectionFilterProvider:
    """Per-field `IN` / `NOT IN` provider with a unique parameter name (issue #435)."""

    def __init__(self, field: FieldNameType, *, negated: bool) -> None:
        self.type_hint = field.type_hint
        self.field_name = field.name
        self.param_name = f"{field.name}_values"
        self.filter_cls: Any = NotInCollectionFilter if negated else InCollectionFilter
        self.return_annotation = self.filter_cls[field.type_hint] | None
        annotation = _query_parameter_annotation(
            _collection_value_annotation(list, field.type_hint),
            QueryParameter(name=camelize(f"{field.name}_{'not_in' if negated else 'in'}"), required=False),
        )
        self.signature = inspect.Signature(
            parameters=[_keyword_only_parameter(self.param_name, annotation)],
            return_annotation=self.return_annotation,
        )
        self.annotations = {self.param_name: annotation, "return": self.return_annotation}

    def __call__(self, **kwargs: Any) -> Any:
        return _provide_collection_filter(self, **kwargs)


class _NullFilterProvider:
    """Per-field `IS NULL` / `IS NOT NULL` provider driven by a boolean query parameter."""

    def __init__(self, field_name: str, *, negated: bool) -> None:
        suffix = "is_not_null" if negated else "is_null"
        self.field_name = field_name
        self.param_name = f"{field_name}_{suffix}"
        self.filter_cls: type[Any] = NotNullFilter if negated else NullFilter
        self.return_annotation = self.filter_cls | None
        annotation = _query_parameter_annotation(
            bool | None, QueryParameter(name=camelize(self.param_name), required=False)
        )
        self.signature = inspect.Signature(
            parameters=[_keyword_only_parameter(self.param_name, annotation)],
            return_annotation=self.return_annotation,
        )
        self.annotations = {self.param_name: annotation, "return": self.return_annotation}

    def __call__(self, **kwargs: Any) -> Any:
        return _provide_null_filter(self, **kwargs)


class _BeforeAfterFilterProvider:
    """Before/after provider with unique parameter names for sibling dependencies."""

    def __init__(self, field_name: str, before_alias: str, after_alias: str) -> None:
        self.field_name = field_name
        self.before_param = f"{field_name}_before"
        self.after_param = f"{field_name}_after"
        before_annotation = _query_parameter_annotation(DTorNone, QueryParameter(name=before_alias, required=False))
        after_annotation = _query_parameter_annotation(DTorNone, QueryParameter(name=after_alias, required=False))
        self.signature = inspect.Signature(
            parameters=[
                _keyword_only_parameter(self.before_param, before_annotation),
                _keyword_only_parameter(self.after_param, after_annotation),
            ],
            return_annotation=BeforeAfterFilter,
        )
        self.annotations = {
            self.before_param: before_annotation,
            self.after_param: after_annotation,
            "return": BeforeAfterFilter,
        }

    def __call__(self, **kwargs: Any) -> BeforeAfterFilter:
        return _provide_before_after_filter(self, **kwargs)


class _IdFilterProvider:
    """Provider for the ``ids`` collection filter on the configured ID field."""

    def __init__(self, field_name: str, id_type: type[Any]) -> None:
        self.field_name = field_name
        self.return_annotation = InCollectionFilter[id_type]  # type: ignore[valid-type]
        annotation = _query_parameter_annotation(
            _collection_value_annotation(list, id_type), QueryParameter(name="ids", required=False)
        )
        self.signature = inspect.Signature(
            parameters=[_keyword_only_parameter("ids", annotation)],
            return_annotation=self.return_annotation,
        )
        self.annotations = {"ids": annotation, "return": self.return_annotation}

    def __call__(self, ids: list[Any] | None = None) -> InCollectionFilter[Any]:
        return _provide_id_filter(self, ids)


class _LimitOffsetFilterProvider:
    """Provider for limit/offset pagination via ``currentPage`` / ``pageSize``."""

    def __init__(self, default_page_size: int) -> None:
        self.default_page_size = default_page_size
        self.return_annotation = LimitOffsetFilter
        current_annotation = _query_parameter_annotation(int, QueryParameter(name="currentPage", required=False, ge=1))
        size_annotation = _query_parameter_annotation(int, QueryParameter(name="pageSize", required=False, ge=1))
        self.signature = inspect.Signature(
            parameters=[
                _keyword_only_parameter("current_page", current_annotation, default=1),
                _keyword_only_parameter("page_size", size_annotation, default=default_page_size),
            ],
            return_annotation=self.return_annotation,
        )
        self.annotations = {
            "current_page": current_annotation,
            "page_size": size_annotation,
            "return": self.return_annotation,
        }

    def __call__(self, current_page: int = 1, page_size: int | None = None) -> LimitOffsetFilter:
        return _provide_limit_offset_filter(self, current_page, page_size)


class _SearchFilterProvider:
    """Provider for the ``searchString`` / ``searchIgnoreCase`` search filter."""

    def __init__(self, search_fields: str | set[str] | list[str], ignore_case_default: bool) -> None:
        self.search_fields = search_fields
        self.ignore_case_default = ignore_case_default
        self.return_annotation = SearchFilter
        search_annotation = _query_parameter_annotation(
            StringOrNone, QueryParameter(name="searchString", required=False, title="Field to search")
        )
        # NOTE(review): the title reads as the opposite of the parameter name; kept unchanged.
        ignore_annotation = _query_parameter_annotation(
            BooleanOrNone,
            QueryParameter(name="searchIgnoreCase", required=False, title="Search should be case sensitive"),
        )
        self.signature = inspect.Signature(
            parameters=[
                _keyword_only_parameter("search_string", search_annotation),
                _keyword_only_parameter("ignore_case", ignore_annotation, default=ignore_case_default),
            ],
            return_annotation=self.return_annotation,
        )
        self.annotations = {
            "search_string": search_annotation,
            "ignore_case": ignore_annotation,
            "return": SearchFilter,
        }

    def __call__(self, search_string: StringOrNone = None, ignore_case: BooleanOrNone = None) -> SearchFilter:
        return _provide_search_filter(self, search_string, ignore_case)


class _OrderByProvider:
    """Provider for the ``orderBy`` / ``sortOrder`` filter with alias validation."""

    def __init__(self, sort_field: SortField, config: FilterConfig) -> None:
        self.sort_resolution = _resolve_sort_field_aliases(
            sort_field,
            sort_field_aliases=config.get("sort_field_aliases"),
            sort_field_camelize=config.get("sort_field_camelize", True),
        )
        # Pre-rendered for the validation error message.
        self.allowed_field_names = ", ".join(self.sort_resolution.allowed_display_names)
        self.sort_order_default: SortOrder = config.get("sort_order", "desc")
        self.return_annotation = OrderByFilter
        field_annotation = _query_parameter_annotation(
            StringOrNone, QueryParameter(name="orderBy", required=False, title="Order by field")
        )
        # NOTE(review): title looks copied from the search parameter; kept unchanged.
        order_annotation = _query_parameter_annotation(
            SortOrderOrNone, QueryParameter(name="sortOrder", required=False, title="Field to search")
        )
        self.signature = inspect.Signature(
            parameters=[
                _keyword_only_parameter(
                    "field_name", field_annotation, default=self.sort_resolution.default_query_value
                ),
                _keyword_only_parameter("sort_order", order_annotation, default=self.sort_order_default),
            ],
            return_annotation=self.return_annotation,
        )
        self.annotations = {"field_name": field_annotation, "sort_order": order_annotation, "return": OrderByFilter}

    def __call__(self, field_name: StringOrNone = None, sort_order: SortOrderOrNone = None) -> OrderByFilter:
        return _provide_order_by_filter(self, field_name, sort_order)


def _resolve_sort_field_aliases(
    sort_field: SortField, sort_field_aliases: Mapping[str, str] | None = None, sort_field_camelize: bool = True
) -> _SortFieldResolution:
    """Build the alias tables used to validate and translate ``orderBy`` values.

    Args:
        sort_field: Allowed sort field or fields; the first one becomes the default.
        sort_field_aliases: Extra inbound aliases mapped to configured sort fields.
        sort_field_camelize: Whether camel-cased field names are accepted as aliases.

    Returns:
        The resolved default field, allowlist, inbound aliases and display names.

    Raises:
        ValueError: If an alias targets a field that is not configured, or one
            alias would map to two different fields.
    """
    fields = _coerce_sort_fields(sort_field)
    allowed_fields = frozenset(fields)
    inbound_aliases: dict[str, str] = {}
    field_display_names = {field: field for field in fields}

    # Every field is always accepted under its own name.
    for field in fields:
        _add_sort_field_alias(inbound_aliases, alias=field, field=field)

    if sort_field_camelize:
        for field in fields:
            alias = camelize(field)
            _add_sort_field_alias(inbound_aliases, alias=alias, field=field)
            field_display_names[field] = alias

    if sort_field_aliases:
        for alias, field in sort_field_aliases.items():
            if field not in allowed_fields:
                msg = f"sort field alias '{alias}' targets unknown sort field '{field}'"
                raise ValueError(msg)
            _add_sort_field_alias(inbound_aliases, alias=alias, field=field)
            # Explicit aliases take precedence as the displayed name.
            field_display_names[field] = alias

    allowed_display_names = tuple(field_display_names[field] for field in fields)
    return _SortFieldResolution(
        default_field=fields[0],
        default_query_value=field_display_names[fields[0]],
        allowed_fields=allowed_fields,
        inbound_aliases=inbound_aliases,
        field_display_names=field_display_names,
        allowed_display_names=allowed_display_names,
    )


def _coerce_sort_fields(sort_field: SortField) -> tuple[str, ...]:
    """Return the configured sort fields as a tuple.

    Sets are sorted so the default (first) field is deterministic.

    Raises:
        ValueError: If an empty collection is supplied.
    """
    if isinstance(sort_field, str):
        return (sort_field,)
    fields = tuple(sorted(sort_field)) if isinstance(sort_field, set) else tuple(sort_field)
    if not fields:
        msg = "sort_field must include at least one field"
        raise ValueError(msg)
    return fields


def _add_sort_field_alias(inbound_aliases: dict[str, str], *, alias: str, field: str) -> None:
    """Register ``alias`` for ``field``.

    Raises:
        ValueError: If ``alias`` is already mapped to a different field.
    """
    existing_field = inbound_aliases.get(alias)
    if existing_field is None or existing_field == field:
        inbound_aliases[alias] = field
        return
    msg = f"ambiguous sort field alias '{alias}' maps to both '{existing_field}' and '{field}'"
    raise ValueError(msg)


def _provide_collection_filter(context: _CollectionFilterProvider, **kwargs: Any) -> Any:
    """Return the `IN` / `NOT IN` filter for the supplied values, or ``None`` when there are none."""
    values = kwargs.get(context.param_name)
    return context.filter_cls[context.type_hint](field_name=context.field_name, values=values) if values else None


def _provide_null_filter(context: _NullFilterProvider, **kwargs: Any) -> Any:
    """Return the null-check filter when its query parameter is truthy, else ``None``."""
    return context.filter_cls(field_name=context.field_name) if kwargs.get(context.param_name) else None


def _provide_before_after_filter(context: _BeforeAfterFilterProvider, **kwargs: Any) -> BeforeAfterFilter:
    """Return a ``BeforeAfterFilter`` from the before/after query values (either may be ``None``)."""
    return BeforeAfterFilter(context.field_name, kwargs.get(context.before_param), kwargs.get(context.after_param))


def _provide_id_filter(context: _IdFilterProvider, ids: list[Any] | None = None) -> InCollectionFilter[Any]:
    """Return an ``InCollectionFilter`` on the configured ID field."""
    return InCollectionFilter(field_name=context.field_name, values=ids)


def _provide_limit_offset_filter(
    context: _LimitOffsetFilterProvider, current_page: int = 1, page_size: int | None = None
) -> LimitOffsetFilter:
    """Return a ``LimitOffsetFilter`` for a 1-based page number and page size."""
    resolved_page_size = page_size if page_size is not None else context.default_page_size
    return LimitOffsetFilter(resolved_page_size, resolved_page_size * (current_page - 1))


def _provide_search_filter(
    context: _SearchFilterProvider, search_string: StringOrNone = None, ignore_case: BooleanOrNone = None
) -> SearchFilter:
    """Return a ``SearchFilter`` over the configured fields.

    A string configuration is split on commas; ``ignore_case`` falls back to the
    configured default when not supplied.
    """
    field_names: set[Any] = (
        set(context.search_fields.split(",")) if isinstance(context.search_fields, str) else set(context.search_fields)
    )
    return SearchFilter(
        field_name=field_names,
        value=search_string,
        ignore_case=context.ignore_case_default if ignore_case is None else ignore_case,
    )


def _provide_order_by_filter(
    context: _OrderByProvider, field_name: StringOrNone = None, sort_order: SortOrderOrNone = None
) -> OrderByFilter:
    """Return an ``OrderByFilter`` for the requested (or default) field and order.

    Raises:
        ValidationException: If ``field_name`` is not a known sort field or alias.
    """
    resolved_field = (
        context.sort_resolution.normalize(field_name) if field_name else context.sort_resolution.default_field
    )
    if resolved_field is None:
        msg = f"Invalid orderBy field '{field_name}'. Allowed fields: {context.allowed_field_names}"
        raise ValidationException(detail=msg)
    return OrderByFilter(field_name=resolved_field, sort_order=sort_order or context.sort_order_default)


def _build_in_collection_provider(field: FieldNameType, *, negated: bool) -> Callable[..., Any]:
    """Build the bound `IN` / `NOT IN` provider for ``field``."""
    return _bind_provider(_CollectionFilterProvider(field, negated=negated), _provide_collection_filter)


def _build_null_provider(field_name: str, *, negated: bool) -> Callable[..., Any]:
    """Build the bound `IS NULL` / `IS NOT NULL` provider for ``field_name``."""
    return _bind_provider(_NullFilterProvider(field_name, negated=negated), _provide_null_filter)


def _build_before_after_provider(field_name: str, before_alias: str, after_alias: str) -> Callable[..., Any]:
    """Build the bound before/after provider for ``field_name`` with the given query aliases."""
    return _bind_provider(
        _BeforeAfterFilterProvider(field_name, before_alias, after_alias), _provide_before_after_filter
    )


class _BooleanFilterProvider:
    """Signature/annotation holder for a per-field boolean filter provider."""

    __slots__ = ("annotations", "field_name", "param_name", "return_annotation", "signature")

    def __init__(self, field_name: str) -> None:
        self.field_name = field_name
        self.param_name = f"{field_name}_boolean"
        self.return_annotation = BooleanFilter | None
        annotation = _query_parameter_annotation(
            bool | None, QueryParameter(name=camelize(self.param_name), required=False)
        )
        self.signature = inspect.Signature(
            parameters=[_keyword_only_parameter(self.param_name, annotation)],
            return_annotation=self.return_annotation,
        )
        self.annotations = {self.param_name: annotation, "return": self.return_annotation}


class _ChoicesFilterProvider:
    """Signature/annotation holder for a per-field choices filter provider."""

    __slots__ = ("annotations", "choices", "field_name", "param_name", "return_annotation", "signature")

    def __init__(self, field_name: str, choices: list[Any] | tuple[Any, ...] | type[Enum]) -> None:
        self.field_name = field_name
        self.choices = choices
        self.param_name = f"{field_name}_choices"
        choices_type = normalize_choice_field_types(choices)
        self.return_annotation = ChoicesFilter[Any] | None
        annotation = _query_parameter_annotation(
            _collection_value_annotation(list, choices_type),
            QueryParameter(name=camelize(self.param_name), required=False),
        )
        self.signature = inspect.Signature(
            parameters=[_keyword_only_parameter(self.param_name, annotation)],
            return_annotation=self.return_annotation,
        )
        self.annotations = {self.param_name: annotation, "return": self.return_annotation}


def _provide_boolean_filter(context: _BooleanFilterProvider, **kwargs: Any) -> BooleanFilter | None:
    """Return a ``BooleanFilter`` for the supplied value, or ``None`` when it was not provided."""
    val = kwargs.get(context.param_name)
    # ``False`` is a meaningful value, so only ``None`` disables the filter.
    if val is None:
        return None
    return BooleanFilter(field_name=context.field_name, value=val)


def _provide_choices_filter(context: _ChoicesFilterProvider, **kwargs: Any) -> Any:
    """Return a ``ChoicesFilter`` for the supplied values, or ``None`` when there are none."""
    values = kwargs.get(context.param_name)
    if not values:
        return None
    return ChoicesFilter[Any](field_name=context.field_name, values=values)