# ruff: noqa: B008
"""Application dependency providers generators.
This module contains functions to create dependency providers for services and filters.
"""
import datetime
import inspect
from collections.abc import Callable, Hashable
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, TypedDict, cast
from uuid import UUID

from litestar.di import Provide
from litestar.exceptions import ValidationException
from litestar.params import Dependency, Parameter
from typing_extensions import NotRequired

from sqlspec.core import (
    BeforeAfterFilter,
    FilterTypes,
    InCollectionFilter,
    LimitOffsetFilter,
    NotInCollectionFilter,
    NotNullFilter,
    NullFilter,
    OrderByFilter,
    SearchFilter,
)
from sqlspec.utils.singleton import SingletonMeta
from sqlspec.utils.text import camelize

if TYPE_CHECKING:
    from sqlglot import exp
# Names exported by ``from <this module> import *`` (kept in sorted order).
__all__ = (
    "DEPENDENCY_DEFAULTS",
    "BooleanOrNone",
    "DTorNone",
    "DependencyDefaults",
    "FieldNameType",
    "FilterConfig",
    "HashableType",
    "HashableValue",
    "IntOrNone",
    "SortField",
    "SortOrder",
    "SortOrderOrNone",
    "StringOrNone",
    "UuidOrNone",
    "create_filter_dependencies",
    "dep_cache",
)
DTorNone = datetime.datetime | None
StringOrNone = str | None
UuidOrNone = UUID | None
IntOrNone = int | None
BooleanOrNone = bool | None
SortOrder = Literal["asc", "desc"]
SortOrderOrNone = SortOrder | None
SortField = str | set[str] | list[str]
HashableValue = str | int | float | bool | None
HashableType = HashableValue | tuple[Any, ...] | tuple[tuple[str, Any], ...] | tuple[HashableValue, ...]
class DependencyDefaults:
    """Default dependency names and pagination settings used by the filter providers.

    Each ``*_DEPENDENCY_KEY`` is the name under which the matching provider is placed in
    the mapping returned by ``create_filter_dependencies``. Callers may pass their own
    instance to that function instead of the shared ``DEPENDENCY_DEFAULTS``.
    """

    # Key of the aggregate provider that returns the list of all active filters.
    FILTERS_DEPENDENCY_KEY: str = "filters"
    # Key of the ``createdBefore``/``createdAfter`` filter provider.
    CREATED_FILTER_DEPENDENCY_KEY: str = "created_filter"
    # Key of the ``ids`` collection filter provider.
    ID_FILTER_DEPENDENCY_KEY: str = "id_filter"
    # Key of the ``currentPage``/``pageSize`` pagination provider.
    LIMIT_OFFSET_FILTER_DEPENDENCY_KEY: str = "limit_offset_filter"
    # Key of the ``updatedBefore``/``updatedAfter`` filter provider.
    UPDATED_FILTER_DEPENDENCY_KEY: str = "updated_filter"
    # Key of the ``orderBy``/``sortOrder`` provider.
    ORDER_BY_FILTER_DEPENDENCY_KEY: str = "order_by_filter"
    # Key of the ``searchString`` provider.
    SEARCH_FILTER_DEPENDENCY_KEY: str = "search_filter"
    # Page size used when the filter config does not set ``pagination_size``.
    DEFAULT_PAGINATION_SIZE: int = 20


# Shared instance used whenever callers do not supply their own defaults.
DEPENDENCY_DEFAULTS = DependencyDefaults()
class FieldNameType(NamedTuple):
    """A filterable field name together with the element type of its filter values.

    Used for the ``in_fields`` / ``not_in_fields`` filter configuration; when only a name
    is given, ``type_hint`` falls back to ``str``.
    """

    # Name of the field the generated filter targets.
    name: str
    # Element type used to annotate the provider's list of accepted values.
    type_hint: type[Any] = str
class FilterConfig(TypedDict):
    """Configuration for generating dynamic filters.

    All keys are optional. Leaving out a filter key disables that filter; leaving out a
    setting key falls back to the default noted beside it.
    """

    # Enables the ``ids`` query-parameter filter; the type parameterizes the
    # ``InCollectionFilter`` annotation of the aggregated ``filters`` dependency.
    id_filter: NotRequired[type[UUID | int | str]]
    # Field the ``ids`` filter matches against (``"id"`` when omitted).
    id_field: NotRequired[str]
    # Field(s) accepted by ``orderBy``. The default is the string itself, the first list
    # entry, or the alphabetically first set entry.
    sort_field: NotRequired[SortField]
    # Default ``sortOrder`` (``"desc"`` when omitted).
    sort_order: NotRequired[SortOrder]
    # ``"limit_offset"`` enables ``currentPage``/``pageSize`` pagination.
    pagination_type: NotRequired[Literal["limit_offset"]]
    # Default ``pageSize`` (``DependencyDefaults.DEFAULT_PAGINATION_SIZE`` when omitted).
    pagination_size: NotRequired[int]
    # Field(s) searched by ``searchString``; a string may name several comma-separated fields.
    search: NotRequired[str | set[str] | list[str]]
    # Default ``searchIgnoreCase`` (``False`` when omitted).
    search_ignore_case: NotRequired[bool]
    # Enables ``createdBefore``/``createdAfter`` filtering on ``created_at``.
    created_at: NotRequired[bool]
    # Enables ``updatedBefore``/``updatedAfter`` filtering on ``updated_at``.
    updated_at: NotRequired[bool]
    # Fields that get a NOT IN filter, read from a camelized ``<field>_not_in`` query
    # parameter. Bare string entries in a list get a ``str`` type hint.
    not_in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
    # Fields that get an IN filter, read from a camelized ``<field>_in`` query parameter.
    # Bare string entries in a list get a ``str`` type hint.
    in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
    # Fields that support IS NULL filtering via a camelized ``<field>_is_null`` query parameter.
    null_fields: NotRequired[str | set[str] | list[str]]
    # Fields that support IS NOT NULL filtering via a camelized ``<field>_is_not_null``
    # query parameter.
    not_null_fields: NotRequired[str | set[str] | list[str]]
class DependencyCache(metaclass=SingletonMeta):
    """Memo of dynamically generated dependency mappings, keyed by any hashable value.

    NOTE(review): ``SingletonMeta`` presumably makes every instantiation return the same
    object — confirm in ``sqlspec.utils.singleton``. Entries are never evicted.
    """

    def __init__(self) -> None:
        # cache key -> (dependency name -> provider)
        self.dependencies: dict[Hashable, dict[str, Provide]] = {}

    def add_dependencies(self, key: Hashable, dependencies: dict[str, Provide]) -> None:
        """Store ``dependencies`` under ``key``, replacing any previous entry."""
        self.dependencies[key] = dependencies

    def get_dependencies(self, key: Hashable) -> dict[str, Provide] | None:
        """Return the mapping stored under ``key``, or ``None`` when there is none."""
        return self.dependencies.get(key)


dep_cache = DependencyCache()
def create_filter_dependencies(
    config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> dict[str, Provide]:
    """Create (or fetch from cache) the filter dependency providers for ``config``.

    Results are memoized in ``dep_cache``. The cache key is the full hashable form of
    ``config`` (as produced by ``_make_hashable``) together with the settings visible on
    ``dep_defaults``. Using the key itself rather than its ``hash()`` means a hash
    collision can never hand back another config's providers. Including ``dep_defaults``
    means custom dependency names are never served providers registered under different
    names.

    Args:
        config: FilterConfig instance with desired settings.
        dep_defaults: Dependency defaults to use for the filter dependencies.

    Returns:
        Mapping of dependency name to ``Provide`` for each enabled filter, plus the
        aggregate provider under ``dep_defaults.FILTERS_DEPENDENCY_KEY`` when any filter
        is enabled.
    """
    cache_key = (_make_hashable(config), _dependency_defaults_cache_key(dep_defaults))
    if (deps := dep_cache.get_dependencies(cache_key)) is not None:
        return deps
    deps = _create_statement_filters(config, dep_defaults)
    dep_cache.add_dependencies(cache_key, deps)
    return deps


def _dependency_defaults_cache_key(dep_defaults: DependencyDefaults) -> HashableType:
    """Return a hashable snapshot of the upper-case settings visible on ``dep_defaults``.

    ``dir()`` sees class attributes, subclass overrides, and per-instance overrides alike.
    """
    return tuple(
        (name, _make_hashable(getattr(dep_defaults, name)))
        for name in sorted(dir(dep_defaults))
        if name.isupper()
    )
def _make_hashable(value: Any) -> HashableType:
"""Convert a value into a hashable type for caching purposes.
Args:
value: Any value that needs to be made hashable.
Returns:
A hashable version of the value.
"""
if isinstance(value, dict):
items = []
for k in sorted(value.keys()): # pyright: ignore
v = value[k]
items.append((str(k), _make_hashable(v)))
return tuple(items)
if isinstance(value, (list, set)):
hashable_items = [_make_hashable(item) for item in value]
filtered_items = [item for item in hashable_items if item is not None]
return tuple(sorted(filtered_items, key=str))
if isinstance(value, (str, int, float, bool, type(None))):
return value
return str(value)
def _resolve_sort_fields(sort_field: SortField) -> tuple[str, set[str]]:
    """Split a ``sort_field`` config value into ``(default_field, allowed_fields)``.

    A string is both the default and the only allowed field. For a set, the alphabetically
    first member is the default; for any other sequence, its first entry is. An empty
    collection raises ``IndexError``.
    """
    if not isinstance(sort_field, str):
        ordered = sorted(sort_field) if isinstance(sort_field, set) else list(sort_field)
        return ordered[0], set(ordered)
    return sort_field, {sort_field}
def _normalize_field_defs(
    fields: str | FieldNameType | set[FieldNameType] | list[str | FieldNameType] | None,
) -> list[FieldNameType]:
    """Normalize an ``in_fields``/``not_in_fields`` config value into ``FieldNameType`` entries.

    A bare string or a single ``FieldNameType`` is treated as one field rather than
    iterated; iterating it would yield single characters or the tuple's members. Bare
    strings get a ``str`` type hint. Entries are de-duplicated by name: the first
    occurrence fixes the position and the last one supplies the type hint.

    Args:
        fields: The raw config value, or ``None`` when the key is absent.

    Returns:
        The normalized field definitions; empty when ``fields`` is empty or ``None``.
    """
    if not fields:
        return []
    items = [fields] if isinstance(fields, (str, FieldNameType)) else fields
    by_name: dict[str, FieldNameType] = {}
    for item in items:  # pyright: ignore
        field_def = FieldNameType(name=item, type_hint=str) if isinstance(item, str) else item
        by_name[field_def.name] = field_def  # pyright: ignore
    return list(by_name.values())


def _normalize_field_names(fields: str | set[str] | list[str] | None) -> list[str]:
    """Normalize a ``null_fields``/``not_null_fields`` config value into unique field names.

    A bare string is one field; collections keep their iteration order, minus duplicates.
    """
    if not fields:
        return []
    if isinstance(fields, str):
        return [fields]
    return list(dict.fromkeys(fields))


def _split_search_fields(search: str | set[str] | list[str]) -> tuple[str, ...]:
    """Return the configured search fields.

    A string is split on commas, with surrounding whitespace trimmed and empty entries
    dropped, so ``"name, email"`` yields ``("name", "email")``.
    """
    if isinstance(search, str):
        return tuple(part.strip() for part in search.split(",") if part.strip())
    return tuple(search)


def _make_not_in_filter_provider(field_def: FieldNameType) -> Callable[..., NotInCollectionFilter[Any] | None]:
    """Build the provider reading the camelized ``<field>_not_in`` query parameter."""
    name, type_hint = field_def.name, field_def.type_hint

    def provide_not_in_filter(  # pyright: ignore
        values: list[type_hint] | None = Parameter(  # type: ignore
            query=camelize(f"{name}_not_in"), default=None, required=False
        ),
    ) -> NotInCollectionFilter[type_hint] | None:  # type: ignore
        # No filter unless at least one value was supplied.
        if not values:
            return None
        return NotInCollectionFilter[type_hint](field_name=name, values=values)  # type: ignore

    return provide_not_in_filter  # pyright: ignore


def _make_in_filter_provider(field_def: FieldNameType) -> Callable[..., InCollectionFilter[Any] | None]:
    """Build the provider reading the camelized ``<field>_in`` query parameter."""
    name, type_hint = field_def.name, field_def.type_hint

    def provide_in_filter(  # pyright: ignore
        values: list[type_hint] | None = Parameter(  # type: ignore
            query=camelize(f"{name}_in"), default=None, required=False
        ),
    ) -> InCollectionFilter[type_hint] | None:  # type: ignore
        # No filter unless at least one value was supplied.
        if not values:
            return None
        return InCollectionFilter[type_hint](field_name=name, values=values)  # type: ignore

    return provide_in_filter  # pyright: ignore


def _make_null_filter_provider(field_name: str) -> Callable[..., NullFilter | None]:
    """Build the provider reading the camelized ``<field>_is_null`` query parameter."""

    def provide_null_filter(
        is_null: bool | None = Parameter(query=camelize(f"{field_name}_is_null"), default=None, required=False),
    ) -> NullFilter | None:
        # Only a truthy flag produces a filter; absent or false means "no filter".
        return NullFilter(field_name=field_name) if is_null else None

    return provide_null_filter


def _make_not_null_filter_provider(field_name: str) -> Callable[..., NotNullFilter | None]:
    """Build the provider reading the camelized ``<field>_is_not_null`` query parameter."""

    def provide_not_null_filter(
        is_not_null: bool | None = Parameter(
            query=camelize(f"{field_name}_is_not_null"), default=None, required=False
        ),
    ) -> NotNullFilter | None:
        # Only a truthy flag produces a filter; absent or false means "no filter".
        return NotNullFilter(field_name=field_name) if is_not_null else None

    return provide_not_null_filter


def _create_statement_filters(  # noqa: C901
    config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> dict[str, Provide]:
    """Create filter dependencies based on configuration.

    Config values are read once, when the providers are built. The providers therefore
    never observe later mutations of ``config``; the result is cached by ``config``'s
    contents in ``create_filter_dependencies``.

    Args:
        config: Configuration dictionary specifying which filters to enable
        dep_defaults: Dependency defaults to use for the filter dependencies

    Returns:
        Dictionary of filter provider functions keyed by dependency name. When any filter
        is enabled, it also holds the aggregate provider under
        ``dep_defaults.FILTERS_DEPENDENCY_KEY``, built with the same ``dep_defaults``.
    """
    filters: dict[str, Provide] = {}

    if config.get("id_filter", False):
        id_field = config.get("id_field", "id")

        def provide_id_filter(  # pyright: ignore[reportUnknownParameterType]
            ids: list[str] | None = Parameter(query="ids", default=None, required=False),
        ) -> InCollectionFilter:  # pyright: ignore[reportMissingTypeArgument]
            return InCollectionFilter(field_name=id_field, values=ids)

        filters[dep_defaults.ID_FILTER_DEPENDENCY_KEY] = Provide(provide_id_filter, sync_to_thread=False)  # pyright: ignore[reportUnknownArgumentType]

    if config.get("created_at", False):

        def provide_created_filter(
            before: DTorNone = Parameter(query="createdBefore", default=None, required=False),
            after: DTorNone = Parameter(query="createdAfter", default=None, required=False),
        ) -> BeforeAfterFilter:
            return BeforeAfterFilter("created_at", before, after)

        filters[dep_defaults.CREATED_FILTER_DEPENDENCY_KEY] = Provide(provide_created_filter, sync_to_thread=False)

    if config.get("updated_at", False):

        def provide_updated_filter(
            before: DTorNone = Parameter(query="updatedBefore", default=None, required=False),
            after: DTorNone = Parameter(query="updatedAfter", default=None, required=False),
        ) -> BeforeAfterFilter:
            return BeforeAfterFilter("updated_at", before, after)

        filters[dep_defaults.UPDATED_FILTER_DEPENDENCY_KEY] = Provide(provide_updated_filter, sync_to_thread=False)

    if config.get("pagination_type") == "limit_offset":
        default_page_size = config.get("pagination_size", dep_defaults.DEFAULT_PAGINATION_SIZE)

        def provide_limit_offset_pagination(
            current_page: int = Parameter(ge=1, query="currentPage", default=1, required=False),
            page_size: int = Parameter(query="pageSize", ge=1, default=default_page_size, required=False),
        ) -> LimitOffsetFilter:
            # Pages are 1-based: page N starts after N - 1 full pages.
            return LimitOffsetFilter(page_size, page_size * (current_page - 1))

        filters[dep_defaults.LIMIT_OFFSET_FILTER_DEPENDENCY_KEY] = Provide(
            provide_limit_offset_pagination, sync_to_thread=False
        )

    if search_fields := config.get("search"):
        search_field_names = _split_search_fields(search_fields)
        ignore_case_default = config.get("search_ignore_case", False)

        def provide_search_filter(
            search_string: StringOrNone = Parameter(
                title="Field to search", query="searchString", default=None, required=False
            ),
            ignore_case: BooleanOrNone = Parameter(
                title="Search should be case insensitive",
                query="searchIgnoreCase",
                default=ignore_case_default,
                required=False,
            ),
        ) -> SearchFilter:
            # A new set per request, as before, instead of sharing one mutable set.
            field_names: set[str | exp.Expression] = set(search_field_names)
            return SearchFilter(field_name=field_names, value=search_string, ignore_case=ignore_case or False)

        filters[dep_defaults.SEARCH_FILTER_DEPENDENCY_KEY] = Provide(provide_search_filter, sync_to_thread=False)

    if sort_field := config.get("sort_field"):
        default_field, allowed_fields = _resolve_sort_fields(sort_field)
        allowed_field_names = ", ".join(sorted(allowed_fields))
        sort_order_default = config.get("sort_order", "desc")

        def provide_order_by(
            field_name: StringOrNone = Parameter(
                title="Order by field", query="orderBy", default=default_field, required=False
            ),
            sort_order: SortOrderOrNone = Parameter(
                title="Sort order", query="sortOrder", default=sort_order_default, required=False
            ),
        ) -> OrderByFilter:
            resolved_field = field_name or default_field
            # Only configured fields may be ordered by; reject anything else explicitly.
            if resolved_field not in allowed_fields:
                msg = f"Invalid orderBy field '{resolved_field}'. Allowed fields: {allowed_field_names}"
                raise ValidationException(detail=msg)
            return OrderByFilter(field_name=resolved_field, sort_order=sort_order or sort_order_default)

        filters[dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY] = Provide(provide_order_by, sync_to_thread=False)

    for field_def in _normalize_field_defs(config.get("not_in_fields")):
        filters[f"{field_def.name}_not_in_filter"] = Provide(
            _make_not_in_filter_provider(field_def), sync_to_thread=False
        )  # pyright: ignore
    for field_def in _normalize_field_defs(config.get("in_fields")):
        filters[f"{field_def.name}_in_filter"] = Provide(
            _make_in_filter_provider(field_def), sync_to_thread=False
        )  # pyright: ignore
    for field_name in _normalize_field_names(config.get("null_fields")):
        filters[f"{field_name}_null_filter"] = Provide(_make_null_filter_provider(field_name), sync_to_thread=False)
    for field_name in _normalize_field_names(config.get("not_null_fields")):
        filters[f"{field_name}_not_null_filter"] = Provide(
            _make_not_null_filter_provider(field_name), sync_to_thread=False
        )

    if filters:
        filters[dep_defaults.FILTERS_DEPENDENCY_KEY] = Provide(
            _create_filter_aggregate_function(config, dep_defaults), sync_to_thread=False
        )
    return filters


def _create_filter_aggregate_function(  # noqa: C901
    config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> Callable[..., list[FilterTypes]]:
    """Create filter aggregation function based on configuration.

    The returned function's signature has one dependency parameter per enabled filter. Each
    parameter is named exactly like the provider ``_create_statement_filters`` registers for
    the same ``dep_defaults``, so the framework can inject every filter by name.

    Args:
        config: The filter configuration.
        dep_defaults: Dependency defaults supplying the names of the built-in filter
            dependencies. The default reproduces the previously hard-coded names.

    Returns:
        Function that returns list of configured filters.
    """
    id_key = dep_defaults.ID_FILTER_DEPENDENCY_KEY
    created_key = dep_defaults.CREATED_FILTER_DEPENDENCY_KEY
    updated_key = dep_defaults.UPDATED_FILTER_DEPENDENCY_KEY
    search_key = dep_defaults.SEARCH_FILTER_DEPENDENCY_KEY
    limit_offset_key = dep_defaults.LIMIT_OFFSET_FILTER_DEPENDENCY_KEY
    order_by_key = dep_defaults.ORDER_BY_FILTER_DEPENDENCY_KEY

    # Annotation of every dependency parameter; insertion order is the signature order.
    annotations: dict[str, Any] = {}
    if id_type := config.get("id_filter"):
        annotations[id_key] = InCollectionFilter[id_type]  # type: ignore[valid-type]
    if config.get("created_at"):
        annotations[created_key] = BeforeAfterFilter
    if config.get("updated_at"):
        annotations[updated_key] = BeforeAfterFilter
    if config.get("search"):
        annotations[search_key] = SearchFilter
    if config.get("pagination_type") == "limit_offset":
        annotations[limit_offset_key] = LimitOffsetFilter
    if config.get("sort_field"):
        annotations[order_by_key] = OrderByFilter

    # Per-field filters, named ``<field>_<kind>_filter`` like their providers.
    field_filter_keys: list[str] = []
    for field_def in _normalize_field_defs(config.get("not_in_fields")):
        key = f"{field_def.name}_not_in_filter"
        annotations[key] = NotInCollectionFilter[field_def.type_hint]  # type: ignore
        field_filter_keys.append(key)
    for field_def in _normalize_field_defs(config.get("in_fields")):
        key = f"{field_def.name}_in_filter"
        annotations[key] = InCollectionFilter[field_def.type_hint]  # type: ignore
        field_filter_keys.append(key)
    for field_name in _normalize_field_names(config.get("null_fields")):
        key = f"{field_name}_null_filter"
        annotations[key] = NullFilter | None
        field_filter_keys.append(key)
    for field_name in _normalize_field_names(config.get("not_null_fields")):
        key = f"{field_name}_not_null_filter"
        annotations[key] = NotNullFilter | None
        field_filter_keys.append(key)

    parameters = [
        inspect.Parameter(
            name=name,
            kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=Dependency(skip_validation=True),
            annotation=annotation,
        )
        for name, annotation in annotations.items()
    ]

    def provide_filters(**kwargs: FilterTypes) -> list[FilterTypes]:
        """Aggregate filter dependencies based on configuration.

        Args:
            **kwargs: Filter parameters dynamically provided based on configuration.

        Returns:
            List of configured filters.
        """
        filters: list[FilterTypes] = []
        # Built-in filters are appended in this fixed order; it differs from the signature
        # order in that pagination precedes the updated-at filter.
        for key in (id_key, created_key, limit_offset_key, updated_key):
            if filter_ := kwargs.get(key):
                filters.append(filter_)
        # Search and ordering are only applied when they carry something to act on.
        search_filter = cast("SearchFilter | None", kwargs.get(search_key))
        if (
            search_filter
            and search_filter.field_name is not None  # pyright: ignore[reportUnnecessaryComparison]
            and search_filter.value is not None  # pyright: ignore[reportUnnecessaryComparison]
        ):
            filters.append(search_filter)
        order_by = cast("OrderByFilter | None", kwargs.get(order_by_key))
        if order_by and order_by.field_name is not None:  # pyright: ignore[reportUnnecessaryComparison]
            filters.append(order_by)
        # Per-field providers return ``None`` when their query parameter was not supplied.
        for key in field_filter_keys:
            filter_ = kwargs.get(key)
            if filter_ is not None:
                filters.append(filter_)
        return filters

    provide_filters.__signature__ = inspect.Signature(  # type: ignore
        parameters=parameters, return_annotation=list[FilterTypes]
    )
    provide_filters.__annotations__ = {**annotations, "return": list[FilterTypes]}
    return provide_filters