# Source code for sqlspec.extensions.fastapi.providers

"""Application dependency providers for FastAPI filter injection.

This module provides filter dependency injection for FastAPI routes, allowing
automatic parsing of query parameters into SQLSpec filter objects.
"""

import datetime
import inspect
import typing
from collections.abc import Callable, Mapping
from enum import Enum
from functools import partial
from inspect import isclass
from types import GenericAlias
from typing import Annotated, Any, Literal, NamedTuple, cast
from uuid import UUID

from fastapi import Depends, Query
from fastapi.exceptions import RequestValidationError
from typing_extensions import NotRequired, TypedDict

from sqlspec.core import (
    BeforeAfterFilter,
    BooleanFilter,
    ChoicesFilter,
    FilterTypes,
    InCollectionFilter,
    LimitOffsetFilter,
    NotInCollectionFilter,
    NotNullFilter,
    NullFilter,
    OrderByFilter,
    SearchFilter,
)
from sqlspec.utils.text import camelize

# Public API: the names exported by ``from sqlspec.extensions.fastapi.providers import *``.
# ``_``-prefixed helpers and the ``DependencyCache`` class are not listed here.
__all__ = (
    "DEPENDENCY_DEFAULTS",
    "BooleanOrNone",
    "ChoiceField",
    "DTorNone",
    "DependencyDefaults",
    "FieldNameType",
    "FilterConfig",
    "HashableType",
    "HashableValue",
    "IntOrNone",
    "SortField",
    "SortOrder",
    "SortOrderOrNone",
    "StringOrNone",
    "UuidOrNone",
    "dep_cache",
    "normalize_choice_field_types",
    "provide_filters",
)

# Exported optional-type aliases (each is ``X | None``).
DTorNone = datetime.datetime | None
StringOrNone = str | None
UuidOrNone = UUID | None
IntOrNone = int | None
BooleanOrNone = bool | None
# Sort direction used for ``FilterConfig["sort_order"]`` and the ``sortOrder`` query parameter.
SortOrder = Literal["asc", "desc"]
SortOrderOrNone = SortOrder | None
# ``FilterConfig["sort_field"]``: a single field name or a collection of field names.
SortField = str | set[str] | list[str]
# Shapes produced by ``_make_hashable`` when building the ``provide_filters`` cache key.
HashableValue = str | int | float | bool | None
HashableType = HashableValue | tuple[Any, ...] | tuple[tuple[str, Any], ...] | tuple[HashableValue, ...]
# Config keys that enable a filter dependency. ``_has_filter_config`` checks these; when none of
# them holds a value other than ``None``, ``False`` or ``[]``, ``provide_filters`` returns a
# provider that yields no filters. Modifier keys such as ``id_field`` or ``pagination_size`` are
# not listed because they do not create a dependency on their own.
_FILTER_CONFIG_KEYS = frozenset({
    "id_filter",
    "created_at",
    "updated_at",
    "pagination_type",
    "search",
    "sort_field",
    "not_in_fields",
    "in_fields",
    "null_fields",
    "not_null_fields",
    "boolean_fields",
    "choice_fields",
})


[docs] class DependencyDefaults: """Default values for dependency generation.""" CREATED_FILTER_DEPENDENCY_KEY: str = "created_filter" ID_FILTER_DEPENDENCY_KEY: str = "id_filter" LIMIT_OFFSET_FILTER_DEPENDENCY_KEY: str = "limit_offset_filter" UPDATED_FILTER_DEPENDENCY_KEY: str = "updated_filter" ORDER_BY_FILTER_DEPENDENCY_KEY: str = "order_by_filter" SEARCH_FILTER_DEPENDENCY_KEY: str = "search_filter" DEFAULT_PAGINATION_SIZE: int = 20
DEPENDENCY_DEFAULTS = DependencyDefaults() class FieldNameType(NamedTuple): """Type for field name and associated type information for filter configuration.""" name: str """Name of the field to filter on.""" type_hint: type[Any] = str """Type of the filter value. Defaults to str.""" class ChoiceField: """Type for choice field name and allowed choices for filter configuration.""" __slots__ = ("choices", "name") def __init__(self, name: str, choices: list[Any] | tuple[Any, ...] | type[Enum]) -> None: self.name = name self.choices = choices def normalize_choice_field_types(choices: list[Any] | tuple[Any, ...] | type[Enum]) -> Any: """Normalize choices into a generic type hint (Literal or Enum).""" if isclass(choices) and issubclass(choices, Enum): return choices return cast("Any", typing.Literal).__getitem__(tuple(choices)) class _SortFieldResolution(NamedTuple): default_field: str default_query_value: str allowed_fields: frozenset[str] inbound_aliases: dict[str, str] field_display_names: dict[str, str] allowed_display_names: tuple[str, ...] def normalize(self, value: str | None) -> str | None: if value is None: return self.default_field return self.inbound_aliases.get(value) # Keep FilterConfig field unions and provider signatures in sync with sqlspec.extensions.litestar.providers.
class FilterConfig(TypedDict):
    """Configuration for generated FastAPI filter dependencies.

    Every key is optional, and a filter dependency is generated only for the keys that
    are enabled. Field names are SQL-facing allowlist values, whereas generated query
    parameter names and ``orderBy`` aliases are API-facing.
    """

    #: Type of ID filter to enable. When set, creates an ``ids`` collection filter.
    id_filter: NotRequired[type[UUID | int | str]]
    #: SQL-facing field name for ID filtering. Defaults to ``"id"``.
    id_field: NotRequired[str]
    #: Allowed SQL-facing field or fields for ``orderBy`` sorting.
    sort_field: NotRequired[SortField]
    #: Additional API-facing ``orderBy`` aliases mapped to configured ``sort_field`` values.
    sort_field_aliases: NotRequired[dict[str, str]]
    #: Whether to accept camel-case aliases for configured sort fields. Defaults to ``True``.
    sort_field_camelize: NotRequired[bool]
    #: Default sort order. Defaults to ``"desc"``.
    sort_order: NotRequired[SortOrder]
    #: Pagination strategy to enable. Currently supports ``"limit_offset"``.
    pagination_type: NotRequired[Literal["limit_offset"]]
    #: Default page size for limit/offset pagination.
    pagination_size: NotRequired[int]
    #: SQL-facing field or fields to search. Strings may be comma-separated.
    search: NotRequired[str | set[str] | list[str]]
    #: Whether search filtering is case-insensitive. Defaults to ``False``.
    search_ignore_case: NotRequired[bool]
    #: Whether to enable ``created_at`` before/after range filtering.
    created_at: NotRequired[bool]
    #: Whether to enable ``updated_at`` before/after range filtering.
    updated_at: NotRequired[bool]
    #: Field or fields that support ``NOT IN`` collection filtering.
    not_in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
    #: Field or fields that support ``IN`` collection filtering.
    in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
    #: Field or fields that support ``IS NULL`` filtering.
    null_fields: NotRequired[str | set[str] | list[str]]
    #: Field or fields that support ``IS NOT NULL`` filtering.
    not_null_fields: NotRequired[str | set[str] | list[str]]
    #: Field or fields that support boolean filtering.
    boolean_fields: NotRequired[str | set[str] | list[str]]
    #: Field or fields that support choices filtering.
    choice_fields: NotRequired[ChoiceField | set[ChoiceField] | list[str | ChoiceField]]
class DependencyCache:
    """In-memory memo of dynamically generated filter dependency callables, keyed by int."""

    def __init__(self) -> None:
        # Maps a config hash to the aggregate dependency built for it.
        self.dependencies: "dict[int, Callable[..., list[FilterTypes]]]" = {}

    def add_dependencies(self, key: int, dependencies: "Callable[..., list[FilterTypes]]") -> None:
        """Store a dependency callable under ``key``, replacing any previous entry.

        Args:
            key: Cache key (hash of config).
            dependencies: Dependency callable to cache.
        """
        self.dependencies.update({key: dependencies})

    def get_dependencies(self, key: int) -> "Callable[..., list[FilterTypes]] | None":
        """Look up the dependency callable stored under ``key``.

        Args:
            key: Cache key (hash of config).

        Returns:
            The cached dependency callable, or ``None`` when nothing is stored for ``key``.
        """
        try:
            return self.dependencies[key]
        except KeyError:
            return None


dep_cache = DependencyCache()
def _dependency_defaults_key(dep_defaults: DependencyDefaults) -> HashableType:
    """Build a hashable snapshot of the upper-case settings on ``dep_defaults``.

    The snapshot is part of the ``provide_filters`` cache key, so calls that differ
    only in their dependency defaults never share a cached provider.

    Args:
        dep_defaults: Dependency defaults instance (class or instance attributes).

    Returns:
        Tuple of ``(attribute_name, hashable_value)`` pairs in ``dir()`` order.
    """
    return tuple((name, _make_hashable(getattr(dep_defaults, name))) for name in dir(dep_defaults) if name.isupper())


def provide_filters(
    config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> "Callable[..., list[FilterTypes]]":
    """Create FastAPI dependency provider for filters based on configuration.

    This function dynamically generates a FastAPI dependency function that parses
    query parameters into SQLSpec filter objects. Generated providers are memoised
    in ``dep_cache`` by a hash of both ``config`` and ``dep_defaults``.

    Args:
        config: Filter configuration specifying which filters to enable.
        dep_defaults: Dependency defaults for filter configuration.

    Returns:
        A FastAPI dependency callable that returns list of filters.
    """
    if not _has_filter_config(config):
        return _empty_filter_list

    # ``dep_defaults`` decides the generated parameter names and the fallback page size,
    # so it must be part of the key; otherwise a provider built for other defaults is reused.
    cache_key = hash((_make_hashable(config), _dependency_defaults_key(dep_defaults)))
    cached_dep = dep_cache.get_dependencies(cache_key)
    if cached_dep is not None:
        return cached_dep

    dep = _create_filter_aggregate_function(config, dep_defaults)
    dep_cache.add_dependencies(cache_key, dep)
    return dep
def _as_iterable(value: Any, scalar_types: tuple[type[Any], ...]) -> Any:
    """Wrap a single configured entry in a tuple so it can be iterated like a collection.

    Args:
        value: A config value that is either one entry or a collection of entries.
        scalar_types: Types denoting a single entry. ``str`` belongs here so that a field
            name is not iterated character by character.

    Returns:
        ``(value,)`` when ``value`` is an instance of ``scalar_types``, otherwise ``value``.
    """
    return (value,) if isinstance(value, scalar_types) else value


def _as_field_name_type(field_def: "str | FieldNameType") -> FieldNameType:
    """Promote a bare field name to a ``FieldNameType`` with a ``str`` value type."""
    return FieldNameType(name=field_def, type_hint=str) if isinstance(field_def, str) else field_def


def _create_filter_aggregate_function(
    config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> "Callable[..., list[FilterTypes]]":
    """Create a FastAPI dependency function that aggregates multiple filter dependencies.

    One sub-dependency is registered per enabled config key, in this order: ids,
    created/updated ranges, pagination, search, ordering, then the per-field
    ``not_in``/``in``/null/not-null/boolean/choices filters. Each becomes a
    keyword-only parameter of the returned aggregate function.

    Args:
        config: Filter configuration.
        dep_defaults: Dependency defaults (parameter names and default page size).

    Returns:
        A FastAPI dependency function that aggregates multiple filter dependencies.
    """
    params: list[inspect.Parameter] = []
    annotations: dict[str, Any] = {}

    # Compared with ``is not False`` so that any explicitly provided value enables the ID filter.
    if (id_type := config.get("id_filter", False)) is not False:
        _add_dependency(
            params,
            annotations,
            dep_defaults.ID_FILTER_DEPENDENCY_KEY,
            # Values that are not classes (e.g. a ``UUID | int`` union) fall back to ``object``.
            _IdFilterProvider(config.get("id_field", "id"), id_type if isinstance(id_type, type) else object),
        )
    if config.get("created_at", False):
        _add_dependency(
            params,
            annotations,
            dep_defaults.CREATED_FILTER_DEPENDENCY_KEY,
            _BeforeAfterFilterProvider("created_at", "createdBefore", "createdAfter"),
        )
    if config.get("updated_at", False):
        _add_dependency(
            params,
            annotations,
            dep_defaults.UPDATED_FILTER_DEPENDENCY_KEY,
            _BeforeAfterFilterProvider("updated_at", "updatedBefore", "updatedAfter"),
        )
    if config.get("pagination_type") == "limit_offset":
        _add_dependency(
            params,
            annotations,
            dep_defaults.LIMIT_OFFSET_FILTER_DEPENDENCY_KEY,
            _LimitOffsetFilterProvider(config.get("pagination_size", dep_defaults.DEFAULT_PAGINATION_SIZE)),
        )
    if search_fields := config.get("search"):
        _add_dependency(
            params,
            annotations,
            dep_defaults.SEARCH_FILTER_DEPENDENCY_KEY,
            _SearchFilterProvider(search_fields, config.get("search_ignore_case", False)),
        )
    if sort_field := config.get("sort_field"):
        _add_dependency(
            params, annotations, dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY, _OrderByProvider(sort_field, config)
        )
    if not_in_fields := config.get("not_in_fields"):
        for field_def in _as_iterable(not_in_fields, (str, FieldNameType)):
            resolved_field = _as_field_name_type(field_def)
            _add_dependency(
                params,
                annotations,
                f"{resolved_field.name}_not_in_filter",
                _CollectionFilterProvider(resolved_field, negated=True),
            )
    if in_fields := config.get("in_fields"):
        for field_def in _as_iterable(in_fields, (str, FieldNameType)):
            resolved_field = _as_field_name_type(field_def)
            _add_dependency(
                params,
                annotations,
                f"{resolved_field.name}_in_filter",
                _CollectionFilterProvider(resolved_field, negated=False),
            )
    if null_fields := config.get("null_fields"):
        for field_name in _as_iterable(null_fields, (str,)):
            _add_dependency(
                params, annotations, f"{field_name}_null_filter", _NullFilterProvider(field_name, negated=False)
            )
    if not_null_fields := config.get("not_null_fields"):
        for field_name in _as_iterable(not_null_fields, (str,)):
            _add_dependency(
                params, annotations, f"{field_name}_not_null_filter", _NullFilterProvider(field_name, negated=True)
            )
    if boolean_fields := config.get("boolean_fields"):
        for field_name in _as_iterable(boolean_fields, (str,)):
            _add_dependency(params, annotations, f"{field_name}_boolean_filter", _BooleanFilterProvider(field_name))
    if choice_fields := config.get("choice_fields"):
        # A bare string is one field name, not a sequence of one-letter field names.
        for choice_def in _as_iterable(choice_fields, (str, ChoiceField)):
            resolved_choice = ChoiceField(name=choice_def, choices=[]) if isinstance(choice_def, str) else choice_def
            _add_dependency(
                params,
                annotations,
                f"{resolved_choice.name}_choices_filter",
                _ChoicesFilterProvider(resolved_choice.name, resolved_choice.choices),
            )
    return _make_aggregate_filter_provider(params, annotations)


def _empty_filter_list() -> "list[FilterTypes]":
    """Dependency used when a config enables no filters; always returns an empty list."""
    return []


def _make_hashable(value: Any) -> HashableType:
    """Convert a value into a hashable type for caching purposes.

    Dicts become key-sorted ``(str(key), value)`` pairs. Lists keep their order,
    because order is meaningful for some keys (the first ``sort_field`` entry is the
    default sort column). Sets are sorted so equal sets give equal keys. ``None``
    entries inside lists/sets are dropped. Other non-primitive values fall back to ``str()``.

    Args:
        value: Any value that needs to be made hashable.

    Returns:
        A hashable version of the value.
    """
    if isinstance(value, dict):
        return tuple((str(key), _make_hashable(value[key])) for key in sorted(value))
    if isinstance(value, (list, set)):
        items = [item for item in (_make_hashable(entry) for entry in value) if item is not None]
        if isinstance(value, set):
            # Sets are unordered: sort so that equal sets always produce the same key.
            return tuple(sorted(items, key=str))
        return tuple(items)
    if isinstance(value, (str, int, float, bool, type(None))):
        return value
    return str(value)


def _has_filter_config(config: FilterConfig) -> bool:
    """Return whether any key in ``_FILTER_CONFIG_KEYS`` is set to something other than ``None``/``False``/``[]``."""
    for key in _FILTER_CONFIG_KEYS:
        value = config.get(key)
        if value is not None and value is not False and value != []:
            return True
    return False


def _collection_value_annotation(collection_type: type[Any], value_type: type[Any]) -> Any:
    """Build ``collection_type[value_type] | None`` at runtime."""
    return GenericAlias(collection_type, (value_type,)) | None


def _query_parameter_annotation(value_annotation: Any, query: Any) -> Any:
    """Attach FastAPI metadata (``Query(...)`` or ``Depends(...)``) to a type via ``Annotated``."""
    return Annotated[value_annotation, query]


def _set_provider_metadata(
    provider: Any, signature: inspect.Signature, annotations: dict[str, Any]
) -> Callable[..., Any]:
    """Attach the signature and annotations FastAPI inspects to ``provider`` and return it."""
    provider.__signature__ = signature
    provider.__annotations__ = annotations
    return cast("Callable[..., Any]", provider)


def _aggregate_filter_provider(**kwargs: Any) -> list[FilterTypes]:
    """Collect resolved sub-dependency values into a flat filter list.

    ``None`` values, searches without a value and orderings without a field are skipped;
    list values are flattened into the result.
    """
    filters: list[FilterTypes] = []
    for filter_value in kwargs.values():
        if filter_value is None:
            continue
        if isinstance(filter_value, list):
            filters.extend(filter_value)
        elif (isinstance(filter_value, SearchFilter) and filter_value.value is None) or (
            isinstance(filter_value, OrderByFilter) and filter_value.field_name is None
        ):
            continue
        else:
            filters.append(filter_value)
    return filters


def _make_aggregate_filter_provider(
    parameters: list[inspect.Parameter], annotations: dict[str, Any]
) -> Callable[..., list[FilterTypes]]:
    """Wrap ``_aggregate_filter_provider`` in a fresh ``partial`` carrying the given signature."""
    aggregate_annotations = dict(annotations)
    aggregate_annotations["return"] = list[FilterTypes]
    # A new ``partial`` per call keeps each generated signature off the shared function object.
    return _set_provider_metadata(
        partial(_aggregate_filter_provider),
        inspect.Signature(parameters=parameters, return_annotation=list[FilterTypes]),
        aggregate_annotations,
    )


class _IdFilterProvider:
    """Dependency turning the ``ids`` query list into an ``InCollectionFilter``."""

    def __init__(self, field_name: str, id_type: type[Any]) -> None:
        self.field_name = field_name
        self.return_annotation = InCollectionFilter[id_type] | None  # type: ignore[valid-type]
        ids_parameter_annotation = _query_parameter_annotation(
            _collection_value_annotation(list, id_type), Query(alias="ids", description="IDs to filter by.")
        )
        self.__signature__ = inspect.Signature(
            parameters=[
                inspect.Parameter(
                    "ids", kind=inspect.Parameter.KEYWORD_ONLY, default=None, annotation=ids_parameter_annotation
                )
            ],
            return_annotation=self.return_annotation,
        )

    def __call__(self, ids: list[Any] | None = None) -> InCollectionFilter[Any] | None:
        return InCollectionFilter(field_name=self.field_name, values=ids) if ids else None


class _BeforeAfterFilterProvider:
    """Dependency parsing before/after timestamp query values into a ``BeforeAfterFilter``."""

    def __init__(self, field_name: str, before_alias: str, after_alias: str) -> None:
        self.field_name = field_name
        self.before_alias = before_alias
        self.after_alias = after_alias
        self.before_param = f"{field_name}_before"
        self.after_param = f"{field_name}_after"
        self.return_annotation = BeforeAfterFilter | None
        self.__signature__ = inspect.Signature(
            parameters=[
                inspect.Parameter(
                    self.before_param,
                    kind=inspect.Parameter.KEYWORD_ONLY,
                    default=None,
                    annotation=Annotated[
                        str | None,
                        Query(
                            alias=before_alias,
                            description=f"Filter by {field_name} before this timestamp.",
                            json_schema_extra={"format": "date-time"},
                        ),
                    ],
                ),
                inspect.Parameter(
                    self.after_param,
                    kind=inspect.Parameter.KEYWORD_ONLY,
                    default=None,
                    annotation=Annotated[
                        str | None,
                        Query(
                            alias=after_alias,
                            description=f"Filter by {field_name} after this timestamp.",
                            json_schema_extra={"format": "date-time"},
                        ),
                    ],
                ),
            ],
            return_annotation=self.return_annotation,
        )

    def __call__(self, **kwargs: Any) -> BeforeAfterFilter | None:
        before_dt = self._parse_datetime(kwargs.get(self.before_param), self.before_alias)
        after_dt = self._parse_datetime(kwargs.get(self.after_param), self.after_alias)
        if before_dt or after_dt:
            return BeforeAfterFilter(field_name=self.field_name, before=before_dt, after=after_dt)
        return None

    @staticmethod
    def _parse_datetime(value: Any, alias: str) -> datetime.datetime | None:
        """Parse an ISO-8601 string; raise ``RequestValidationError`` on bad input."""
        if value is None:
            return None
        if isinstance(value, datetime.datetime):
            return value
        try:
            # Rewrite a ``Z`` suffix to an explicit UTC offset before parsing.
            return datetime.datetime.fromisoformat(value.replace("Z", "+00:00"))
        except (ValueError, TypeError, AttributeError):
            msg = f"Invalid date format for {alias}"
            raise RequestValidationError(errors=[{"loc": ("query", alias), "msg": msg, "type": "value_error.datetime"}])


class _LimitOffsetFilterProvider:
    """Dependency turning ``currentPage``/``pageSize`` into a ``LimitOffsetFilter``."""

    def __init__(self, default_page_size: int) -> None:
        self.default_page_size = default_page_size
        self.return_annotation = LimitOffsetFilter
        self.__signature__ = inspect.Signature(
            parameters=[
                inspect.Parameter(
                    "current_page",
                    kind=inspect.Parameter.KEYWORD_ONLY,
                    default=1,
                    annotation=Annotated[
                        int, Query(ge=1, alias="currentPage", description="Page number for pagination.")
                    ],
                ),
                inspect.Parameter(
                    "page_size",
                    kind=inspect.Parameter.KEYWORD_ONLY,
                    default=default_page_size,
                    annotation=Annotated[int, Query(ge=1, alias="pageSize", description="Number of items per page.")],
                ),
            ],
            return_annotation=self.return_annotation,
        )

    def __call__(self, current_page: int = 1, page_size: int | None = None) -> LimitOffsetFilter:
        resolved_page_size = page_size if page_size is not None else self.default_page_size
        return LimitOffsetFilter(limit=resolved_page_size, offset=resolved_page_size * (current_page - 1))


def _split_search_fields(search_fields: str | set[str] | list[str]) -> set[Any]:
    """Normalise the ``search`` config into a set of field names.

    Comma-separated strings are split, whitespace around each name is stripped, and
    empty names are dropped, so ``"name, email"`` yields ``{"name", "email"}``.
    """
    if isinstance(search_fields, str):
        return {name.strip() for name in search_fields.split(",") if name.strip()}
    return set(search_fields)


class _SearchFilterProvider:
    """Dependency turning ``searchString``/``searchIgnoreCase`` into a ``SearchFilter``."""

    def __init__(self, search_fields: str | set[str] | list[str], ignore_case_default: bool) -> None:
        self.search_fields = search_fields
        self.ignore_case_default = ignore_case_default
        # Resolved once here instead of on every request.
        self.field_names: set[Any] = _split_search_fields(search_fields)
        self.return_annotation = SearchFilter | None
        self.__signature__ = inspect.Signature(
            parameters=[
                inspect.Parameter(
                    "search_string",
                    kind=inspect.Parameter.KEYWORD_ONLY,
                    default=None,
                    annotation=Annotated[str | None, Query(alias="searchString", description="Search term.")],
                ),
                inspect.Parameter(
                    "ignore_case",
                    kind=inspect.Parameter.KEYWORD_ONLY,
                    default=ignore_case_default,
                    annotation=Annotated[
                        bool | None,
                        Query(alias="searchIgnoreCase", description="Whether search should be case-insensitive."),
                    ],
                ),
            ],
            return_annotation=self.return_annotation,
        )

    def __call__(self, search_string: str | None = None, ignore_case: bool | None = None) -> SearchFilter | None:
        if search_string:
            return SearchFilter(
                # Fresh copy per filter so filters never share a mutable set.
                field_name=set(self.field_names),
                value=search_string,
                ignore_case=self.ignore_case_default if ignore_case is None else ignore_case,
            )
        return None


class _OrderByProvider:
    """Dependency resolving ``orderBy``/``sortOrder`` into an ``OrderByFilter`` via an alias allowlist."""

    def __init__(self, sort_field: SortField, config: FilterConfig) -> None:
        self.sort_resolution = _resolve_sort_field_aliases(
            sort_field,
            sort_field_aliases=config.get("sort_field_aliases"),
            sort_field_camelize=config.get("sort_field_camelize", True),
        )
        self.sort_order_default: SortOrder = config.get("sort_order", "desc")
        self.allowed_field_names = ", ".join(self.sort_resolution.allowed_display_names)
        self.return_annotation = OrderByFilter
        self.__signature__ = inspect.Signature(
            parameters=[
                inspect.Parameter(
                    "field_name",
                    kind=inspect.Parameter.KEYWORD_ONLY,
                    default=self.sort_resolution.default_query_value,
                    annotation=Annotated[str, Query(alias="orderBy", description="Field to order by.")],
                ),
                inspect.Parameter(
                    "sort_order",
                    kind=inspect.Parameter.KEYWORD_ONLY,
                    default=self.sort_order_default,
                    annotation=Annotated[
                        SortOrder | None, Query(alias="sortOrder", description="Sort order ('asc' or 'desc').")
                    ],
                ),
            ],
            return_annotation=self.return_annotation,
        )

    def __call__(self, field_name: str | None = None, sort_order: SortOrder | None = None) -> OrderByFilter:
        query_value = field_name or self.sort_resolution.default_query_value
        resolved_field = self.sort_resolution.normalize(query_value)
        if resolved_field is None:
            msg = f"Invalid orderBy field '{query_value}'. Allowed fields: {self.allowed_field_names}"
            raise RequestValidationError(errors=[{"loc": ("query", "orderBy"), "msg": msg, "type": "value_error"}])
        return OrderByFilter(field_name=resolved_field, sort_order=sort_order or self.sort_order_default)


def _resolve_sort_field_aliases(
    sort_field: SortField, sort_field_aliases: Mapping[str, str] | None = None, sort_field_camelize: bool = True
) -> _SortFieldResolution:
    """Build the inbound alias map and display names for the configured sort fields.

    Raises:
        ValueError: If an explicit alias targets an unknown field, an alias is ambiguous,
            or no sort field is configured.
    """
    fields = _coerce_sort_fields(sort_field)
    allowed_fields = frozenset(fields)
    inbound_aliases: dict[str, str] = {}
    field_display_names = {field: field for field in fields}
    for field in fields:
        _add_sort_field_alias(inbound_aliases, alias=field, field=field)
    if sort_field_camelize:
        for field in fields:
            alias = camelize(field)
            _add_sort_field_alias(inbound_aliases, alias=alias, field=field)
            field_display_names[field] = alias
    if sort_field_aliases:
        for alias, field in sort_field_aliases.items():
            if field not in allowed_fields:
                msg = f"sort field alias '{alias}' targets unknown sort field '{field}'"
                raise ValueError(msg)
            _add_sort_field_alias(inbound_aliases, alias=alias, field=field)
            field_display_names[field] = alias
    allowed_display_names = tuple(field_display_names[field] for field in fields)
    return _SortFieldResolution(
        default_field=fields[0],
        default_query_value=field_display_names[fields[0]],
        allowed_fields=allowed_fields,
        inbound_aliases=inbound_aliases,
        field_display_names=field_display_names,
        allowed_display_names=allowed_display_names,
    )


def _coerce_sort_fields(sort_field: SortField) -> tuple[str, ...]:
    """Return sort fields as a tuple; sets are sorted, lists keep their order (first is default)."""
    if isinstance(sort_field, str):
        return (sort_field,)
    fields = tuple(sorted(sort_field)) if isinstance(sort_field, set) else tuple(sort_field)
    if not fields:
        msg = "sort_field must include at least one field"
        raise ValueError(msg)
    return fields


def _add_sort_field_alias(inbound_aliases: dict[str, str], *, alias: str, field: str) -> None:
    """Register ``alias -> field``; raise ``ValueError`` if ``alias`` already maps to another field."""
    existing_field = inbound_aliases.get(alias)
    if existing_field is None or existing_field == field:
        inbound_aliases[alias] = field
        return
    msg = f"ambiguous sort field alias '{alias}' maps to both '{existing_field}' and '{field}'"
    raise ValueError(msg)


class _CollectionFilterProvider:
    """Dependency building an ``IN`` / ``NOT IN`` collection filter for one field."""

    def __init__(self, field: FieldNameType, *, negated: bool) -> None:
        self.field_name = field.name
        self.type_hint = field.type_hint
        self.param_name = f"{field.name}_{'not_in' if negated else 'in'}_values"
        self.filter_cls: Any = NotInCollectionFilter if negated else InCollectionFilter
        query_suffix = "not_in" if negated else "in"
        parameter_annotation = _query_parameter_annotation(
            _collection_value_annotation(set, field.type_hint),
            Query(
                alias=camelize(f"{field.name}_{query_suffix}"), description=f"Filter {field.name} {query_suffix} values"
            ),
        )
        self.return_annotation = self.filter_cls[field.type_hint] | None
        self.__signature__ = inspect.Signature(
            parameters=[
                inspect.Parameter(
                    self.param_name, kind=inspect.Parameter.KEYWORD_ONLY, default=None, annotation=parameter_annotation
                )
            ],
            return_annotation=self.return_annotation,
        )

    def __call__(self, **kwargs: Any) -> Any:
        values = kwargs.get(self.param_name)
        return self.filter_cls[self.type_hint](field_name=self.field_name, values=values) if values else None


class _NullFilterProvider:
    """Dependency building an ``IS NULL`` / ``IS NOT NULL`` filter when its flag is truthy."""

    def __init__(self, field_name: str, *, negated: bool) -> None:
        self.field_name = field_name
        self.param_name = f"{field_name}_{'is_not_null' if negated else 'is_null'}"
        self.filter_cls: type[Any] = NotNullFilter if negated else NullFilter
        self.return_annotation = self.filter_cls | None
        self.__signature__ = inspect.Signature(
            parameters=[
                inspect.Parameter(
                    self.param_name,
                    kind=inspect.Parameter.KEYWORD_ONLY,
                    default=None,
                    annotation=Annotated[
                        bool | None,
                        Query(
                            alias=camelize(self.param_name),
                            description=f"Filter where {field_name} {'IS NOT NULL' if negated else 'IS NULL'}",
                        ),
                    ],
                )
            ],
            return_annotation=self.return_annotation,
        )

    def __call__(self, **kwargs: Any) -> Any:
        return self.filter_cls(field_name=self.field_name) if kwargs.get(self.param_name) else None


class _BooleanFilterProvider:
    """Dependency building a ``BooleanFilter`` whenever its query value is provided."""

    __slots__ = ("__signature__", "field_name", "param_name", "return_annotation")

    def __init__(self, field_name: str) -> None:
        self.field_name = field_name
        self.param_name = f"{field_name}_boolean"
        self.return_annotation = BooleanFilter | None
        annotation = _query_parameter_annotation(
            bool | None, Query(alias=camelize(self.param_name), description=f"Filter by boolean field {field_name}")
        )
        self.__signature__ = inspect.Signature(
            parameters=[
                inspect.Parameter(
                    self.param_name, kind=inspect.Parameter.KEYWORD_ONLY, default=None, annotation=annotation
                )
            ],
            return_annotation=self.return_annotation,
        )

    def __call__(self, **kwargs: Any) -> BooleanFilter | None:
        val = kwargs.get(self.param_name)
        if val is None:
            return None
        return BooleanFilter(field_name=self.field_name, value=val)


class _ChoicesFilterProvider:
    """Dependency building a ``ChoicesFilter`` from a list of allowed choice values."""

    __slots__ = ("__signature__", "choices", "field_name", "param_name", "return_annotation")

    def __init__(self, field_name: str, choices: list[Any] | tuple[Any, ...] | type[Enum]) -> None:
        self.field_name = field_name
        self.choices = choices
        self.param_name = f"{field_name}_choices"
        choices_type = normalize_choice_field_types(choices)
        self.return_annotation = ChoicesFilter[Any] | None
        parameter_annotation = _query_parameter_annotation(
            _collection_value_annotation(list, choices_type),
            Query(alias=camelize(self.param_name), description=f"Filter {field_name} by choices"),
        )
        self.__signature__ = inspect.Signature(
            parameters=[
                inspect.Parameter(
                    self.param_name, kind=inspect.Parameter.KEYWORD_ONLY, default=None, annotation=parameter_annotation
                )
            ],
            return_annotation=self.return_annotation,
        )

    def __call__(self, **kwargs: Any) -> Any:
        values = kwargs.get(self.param_name)
        if not values:
            return None
        return ChoicesFilter[Any](field_name=self.field_name, values=values)


def _add_dependency(params: list[inspect.Parameter], annotations: dict[str, Any], name: str, provider: Any) -> None:
    """Register ``provider`` as a ``Depends`` keyword-only parameter called ``name``."""
    dependency_annotation = _query_parameter_annotation(provider.return_annotation, Depends(provider))
    params.append(inspect.Parameter(name=name, kind=inspect.Parameter.KEYWORD_ONLY, annotation=dependency_annotation))
    annotations[name] = dependency_annotation