Source code for sqlspec.core.filters

"""Filter system for SQL statement manipulation.

This module provides filters that can be applied to SQL statements to add
WHERE clauses, ORDER BY clauses, LIMIT/OFFSET, and other modifications.

Components:
- StatementFilter: Abstract base class for all filters
- BeforeAfterFilter: Date range filtering
- InCollectionFilter: IN clause filtering
- LimitOffsetFilter: Pagination support
- OrderByFilter: Sorting support
- SearchFilter: Text search filtering
- Various collection and negation filters

Features:
- Parameter conflict resolution
- Type-safe filter application
- Cacheable filter configurations
"""

import uuid
from abc import ABC, abstractmethod
from collections import abc
from collections.abc import Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias

import sqlglot
from sqlglot import exp
from typing_extensions import TypeVar

if TYPE_CHECKING:
    from sqlglot.expressions import Condition

    from sqlspec.core.statement import SQL

__all__ = (
    "AnyCollectionFilter",
    "BeforeAfterFilter",
    "FilterTypeT",
    "FilterTypes",
    "InAnyFilter",
    "InCollectionFilter",
    "LimitOffsetFilter",
    "NotAnyCollectionFilter",
    "NotInCollectionFilter",
    "NotInSearchFilter",
    "OffsetPagination",
    "OnBeforeAfterFilter",
    "OrderByFilter",
    "PaginationFilter",
    "SearchFilter",
    "StatementFilter",
    "apply_filter",
    "canonicalize_filters",
    "create_filters",
)

T = TypeVar("T")
FilterTypeT = TypeVar("FilterTypeT", bound="StatementFilter")


[docs] class StatementFilter(ABC): """Abstract base class for filters that can be appended to a statement.""" __slots__ = ()
[docs] @abstractmethod def append_to_statement(self, statement: "SQL") -> "SQL": """Append the filter to the statement. This method should modify the SQL expression only, not the parameters. Parameters should be provided via extract_parameters(). """
[docs] def extract_parameters(self) -> "tuple[list[Any], dict[str, Any]]": """Extract parameters that this filter contributes. Returns: Tuple of (positional_parameters, named_parameters) where: - positional_parameters: List of positional parameter values - named_parameters: Dict of parameter name to value """ return [], {}
def _resolve_parameter_conflicts(self, statement: "SQL", proposed_names: list[str]) -> list[str]: """Resolve parameter name conflicts. Args: statement: The SQL statement to check for existing parameters proposed_names: List of proposed parameter names Returns: List of resolved parameter names (same length as proposed_names) """ existing_params = set(statement.named_parameters.keys()) existing_params.update(statement.parameters.keys() if isinstance(statement.parameters, dict) else []) resolved_names = [] for name in proposed_names: if name in existing_params: unique_suffix = str(uuid.uuid4()).replace("-", "")[:8] resolved_name = f"{name}_{unique_suffix}" else: resolved_name = name resolved_names.append(resolved_name) existing_params.add(resolved_name) return resolved_names
[docs] @abstractmethod def get_cache_key(self) -> tuple[Any, ...]: """Return a cache key for this filter's configuration. Returns: Tuple of hashable values representing the filter's configuration """
class BeforeAfterFilter(StatementFilter): """Filter for datetime range queries. Applies WHERE clauses for before/after datetime filtering. """ __slots__ = ("_after", "_before", "_field_name") def __init__(self, field_name: str, before: datetime | None = None, after: datetime | None = None) -> None: self._field_name = field_name self._before = before self._after = after @property def field_name(self) -> str: return self._field_name @property def before(self) -> datetime | None: return self._before @property def after(self) -> datetime | None: return self._after def get_param_names(self) -> list[str]: """Get parameter names without storing them.""" names = [] if self.before: names.append(f"{self.field_name}_before") if self.after: names.append(f"{self.field_name}_after") return names def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]: """Extract filter parameters.""" named_parameters = {} param_names = self.get_param_names() param_idx = 0 if self.before: named_parameters[param_names[param_idx]] = self.before param_idx += 1 if self.after: named_parameters[param_names[param_idx]] = self.after return [], named_parameters def append_to_statement(self, statement: "SQL") -> "SQL": """Apply filter to SQL expression only.""" conditions: list[Condition] = [] col_expr = exp.column(self.field_name) proposed_names = self.get_param_names() if not proposed_names: return statement resolved_names = self._resolve_parameter_conflicts(statement, proposed_names) param_idx = 0 result = statement if self.before: before_param_name = resolved_names[param_idx] param_idx += 1 conditions.append(exp.LT(this=col_expr, expression=exp.Placeholder(this=before_param_name))) result = result.add_named_parameter(before_param_name, self.before) if self.after: after_param_name = resolved_names[param_idx] conditions.append(exp.GT(this=col_expr, expression=exp.Placeholder(this=after_param_name))) result = result.add_named_parameter(after_param_name, self.after) final_condition = conditions[0] for cond in conditions[1:]: final_condition = exp.And(this=final_condition, expression=cond) return result.where(final_condition) def get_cache_key(self) -> tuple[Any, ...]: """Return cache key for this filter configuration.""" return ("BeforeAfterFilter", self.field_name, self.before, self.after) class OnBeforeAfterFilter(StatementFilter): """Filter for inclusive datetime range queries. Applies WHERE clauses for on-or-before/on-or-after datetime filtering. """ __slots__ = ("_field_name", "_on_or_after", "_on_or_before") def __init__( self, field_name: str, on_or_before: datetime | None = None, on_or_after: datetime | None = None ) -> None: self._field_name = field_name self._on_or_before = on_or_before self._on_or_after = on_or_after @property def field_name(self) -> str: return self._field_name @property def on_or_before(self) -> datetime | None: return self._on_or_before @property def on_or_after(self) -> datetime | None: return self._on_or_after def get_param_names(self) -> list[str]: """Get parameter names without storing them.""" names = [] if self.on_or_before: names.append(f"{self.field_name}_on_or_before") if self.on_or_after: names.append(f"{self.field_name}_on_or_after") return names def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]: """Extract filter parameters.""" named_parameters = {} param_names = self.get_param_names() param_idx = 0 if self.on_or_before: named_parameters[param_names[param_idx]] = self.on_or_before param_idx += 1 if self.on_or_after: named_parameters[param_names[param_idx]] = self.on_or_after return [], named_parameters def append_to_statement(self, statement: "SQL") -> "SQL": conditions: list[Condition] = [] proposed_names = self.get_param_names() if not proposed_names: return statement resolved_names = self._resolve_parameter_conflicts(statement, proposed_names) param_idx = 0 result = statement if self.on_or_before: before_param_name = resolved_names[param_idx] param_idx += 1 conditions.append( exp.LTE(this=exp.column(self.field_name), expression=exp.Placeholder(this=before_param_name)) ) result = result.add_named_parameter(before_param_name, self.on_or_before) if self.on_or_after: after_param_name = resolved_names[param_idx] conditions.append( exp.GTE(this=exp.column(self.field_name), expression=exp.Placeholder(this=after_param_name)) ) result = result.add_named_parameter(after_param_name, self.on_or_after) final_condition = conditions[0] for cond in conditions[1:]: final_condition = exp.And(this=final_condition, expression=cond) return result.where(final_condition) def get_cache_key(self) -> tuple[Any, ...]: """Return cache key for this filter configuration.""" return ("OnBeforeAfterFilter", self.field_name, self.on_or_before, self.on_or_after) class InAnyFilter(StatementFilter, ABC, Generic[T]): """Base class for collection-based filters that support ANY operations.""" __slots__ = () def append_to_statement(self, statement: "SQL") -> "SQL": raise NotImplementedError class InCollectionFilter(InAnyFilter[T]): """Filter for IN clause queries. Constructs WHERE ... IN (...) clauses. """ __slots__ = ("_field_name", "_values") def __init__(self, field_name: str, values: abc.Collection[T] | None = None) -> None: self._field_name = field_name self._values = values @property def field_name(self) -> str: return self._field_name @property def values(self) -> abc.Collection[T] | None: return self._values def get_param_names(self) -> list[str]: """Get parameter names without storing them.""" if not self.values: return [] return [f"{self.field_name}_in_{i}" for i, _ in enumerate(self.values)] def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]: """Extract filter parameters.""" named_parameters = {} if self.values: param_names = self.get_param_names() for i, value in enumerate(self.values): named_parameters[param_names[i]] = value return [], named_parameters def append_to_statement(self, statement: "SQL") -> "SQL": if self.values is None: return statement if not self.values: return statement.where(exp.false()) resolved_names = self._resolve_parameter_conflicts(statement, self.get_param_names()) placeholder_expressions: list[exp.Placeholder] = [ exp.Placeholder(this=param_name) for param_name in resolved_names ] result = statement.where(exp.In(this=exp.column(self.field_name), expressions=placeholder_expressions)) for resolved_name, value in zip(resolved_names, self.values, strict=False): result = result.add_named_parameter(resolved_name, value) return result def get_cache_key(self) -> tuple[Any, ...]: """Return cache key for this filter configuration.""" values_tuple = tuple(self.values) if self.values is not None else None return ("InCollectionFilter", self.field_name, values_tuple) class NotInCollectionFilter(InAnyFilter[T]): """Filter for NOT IN clause queries. Constructs WHERE ... NOT IN (...) clauses. """ __slots__ = ("_field_name", "_values") def __init__(self, field_name: str, values: abc.Collection[T] | None = None) -> None: self._field_name = field_name self._values = values @property def field_name(self) -> str: return self._field_name @property def values(self) -> abc.Collection[T] | None: return self._values def get_param_names(self) -> list[str]: """Get parameter names without storing them.""" if not self.values: return [] # Use object id to ensure uniqueness between instances return [f"{self.field_name}_notin_{i}_{id(self)}" for i, _ in enumerate(self.values)] def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]: """Extract filter parameters.""" named_parameters = {} if self.values: param_names = self.get_param_names() for i, value in enumerate(self.values): named_parameters[param_names[i]] = value return [], named_parameters def append_to_statement(self, statement: "SQL") -> "SQL": if self.values is None or not self.values: return statement resolved_names = self._resolve_parameter_conflicts(statement, self.get_param_names()) placeholder_expressions: list[exp.Placeholder] = [ exp.Placeholder(this=param_name) for param_name in resolved_names ] result = statement.where( exp.Not(this=exp.In(this=exp.column(self.field_name), expressions=placeholder_expressions)) ) for resolved_name, value in zip(resolved_names, self.values, strict=False): result = result.add_named_parameter(resolved_name, value) return result def get_cache_key(self) -> tuple[Any, ...]: """Return cache key for this filter configuration.""" values_tuple = tuple(self.values) if self.values is not None else None return ("NotInCollectionFilter", self.field_name, values_tuple) class AnyCollectionFilter(InAnyFilter[T]): """Filter for PostgreSQL-style ANY clause queries. Constructs WHERE column_name = ANY (array_expression) clauses. """ __slots__ = ("_field_name", "_values") def __init__(self, field_name: str, values: abc.Collection[T] | None = None) -> None: self._field_name = field_name self._values = values @property def field_name(self) -> str: return self._field_name @property def values(self) -> abc.Collection[T] | None: return self._values def get_param_names(self) -> list[str]: """Get parameter names without storing them.""" if not self.values: return [] return [f"{self.field_name}_any_{i}" for i, _ in enumerate(self.values)] def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]: """Extract filter parameters.""" named_parameters = {} if self.values: param_names = self.get_param_names() for i, value in enumerate(self.values): named_parameters[param_names[i]] = value return [], named_parameters def append_to_statement(self, statement: "SQL") -> "SQL": if self.values is None: return statement if not self.values: return statement.where(exp.false()) resolved_names = self._resolve_parameter_conflicts(statement, self.get_param_names()) placeholder_expressions: list[exp.Expression] = [ exp.Placeholder(this=param_name) for param_name in resolved_names ] array_expr = exp.Array(expressions=placeholder_expressions) result = statement.where(exp.EQ(this=exp.column(self.field_name), expression=exp.Any(this=array_expr))) for resolved_name, value in zip(resolved_names, self.values, strict=False): result = result.add_named_parameter(resolved_name, value) return result def get_cache_key(self) -> tuple[Any, ...]: """Return cache key for this filter configuration.""" values_tuple = tuple(self.values) if self.values is not None else None return ("AnyCollectionFilter", self.field_name, values_tuple) class NotAnyCollectionFilter(InAnyFilter[T]): """Filter for PostgreSQL-style NOT ANY clause queries. Constructs WHERE NOT (column_name = ANY (array_expression)) clauses. """ __slots__ = ("_field_name", "_values") def __init__(self, field_name: str, values: abc.Collection[T] | None = None) -> None: self._field_name = field_name self._values = values @property def field_name(self) -> str: return self._field_name @property def values(self) -> abc.Collection[T] | None: return self._values def get_param_names(self) -> list[str]: """Get parameter names without storing them.""" if not self.values: return [] return [f"{self.field_name}_not_any_{i}" for i, _ in enumerate(self.values)] def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]: """Extract filter parameters.""" named_parameters = {} if self.values: param_names = self.get_param_names() for i, value in enumerate(self.values): named_parameters[param_names[i]] = value return [], named_parameters def append_to_statement(self, statement: "SQL") -> "SQL": if self.values is None or not self.values: return statement resolved_names = self._resolve_parameter_conflicts(statement, self.get_param_names()) placeholder_expressions: list[exp.Expression] = [ exp.Placeholder(this=param_name) for param_name in resolved_names ] array_expr = exp.Array(expressions=placeholder_expressions) condition = exp.EQ(this=exp.column(self.field_name), expression=exp.Any(this=array_expr)) result = statement.where(exp.Not(this=condition)) for resolved_name, value in zip(resolved_names, self.values, strict=False): result = result.add_named_parameter(resolved_name, value) return result def get_cache_key(self) -> tuple[Any, ...]: """Return cache key for this filter configuration.""" values_tuple = tuple(self.values) if self.values is not None else None return ("NotAnyCollectionFilter", self.field_name, values_tuple) class PaginationFilter(StatementFilter, ABC): """Base class for pagination-related filters.""" __slots__ = () @abstractmethod def append_to_statement(self, statement: "SQL") -> "SQL": raise NotImplementedError
[docs] class LimitOffsetFilter(PaginationFilter): """Filter for LIMIT and OFFSET clauses. Adds pagination support through LIMIT/OFFSET SQL clauses. """ __slots__ = ("_limit", "_offset")
[docs] def __init__(self, limit: int, offset: int) -> None: self._limit = limit self._offset = offset
@property def limit(self) -> int: return self._limit @property def offset(self) -> int: return self._offset
[docs] def get_param_names(self) -> list[str]: """Get parameter names without storing them.""" return ["limit", "offset"]
[docs] def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]: """Extract filter parameters.""" param_names = self.get_param_names() return [], {param_names[0]: self.limit, param_names[1]: self.offset}
[docs] def append_to_statement(self, statement: "SQL") -> "SQL": resolved_names = self._resolve_parameter_conflicts(statement, self.get_param_names()) limit_param_name, offset_param_name = resolved_names limit_placeholder = exp.Placeholder(this=limit_param_name) offset_placeholder = exp.Placeholder(this=offset_param_name) if statement.statement_expression is not None: current_statement = statement.statement_expression else: try: current_statement = sqlglot.parse_one(statement.raw_sql, dialect=statement.dialect) except Exception: current_statement = exp.Select().from_(f"({statement.raw_sql})") if isinstance(current_statement, exp.Select): new_statement = current_statement.limit(limit_placeholder).offset(offset_placeholder) else: new_statement = exp.Select().from_(current_statement).limit(limit_placeholder).offset(offset_placeholder) result = statement.copy(statement=new_statement) result = result.add_named_parameter(limit_param_name, self.limit) return result.add_named_parameter(offset_param_name, self.offset)
[docs] def get_cache_key(self) -> tuple[Any, ...]: """Return cache key for this filter configuration.""" return ("LimitOffsetFilter", self.limit, self.offset)
[docs] class OrderByFilter(StatementFilter): """Filter for ORDER BY clauses. Adds sorting capability to SQL queries. """ __slots__ = ("_field_name", "_sort_order")
[docs] def __init__(self, field_name: str, sort_order: Literal["asc", "desc"] = "asc") -> None: self._field_name = field_name self._sort_order = sort_order
@property def field_name(self) -> str: return self._field_name @property def sort_order(self) -> Literal["asc", "desc"]: return self._sort_order # pyright: ignore
[docs] def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]: """Extract filter parameters.""" return [], {}
[docs] def append_to_statement(self, statement: "SQL") -> "SQL": converted_sort_order = self.sort_order.lower() if converted_sort_order not in {"asc", "desc"}: converted_sort_order = "asc" col_expr = exp.column(self.field_name) order_expr = col_expr.desc() if converted_sort_order == "desc" else col_expr.asc() if statement.statement_expression is not None: current_statement = statement.statement_expression else: try: current_statement = sqlglot.parse_one(statement.raw_sql, dialect=statement.dialect) except Exception: current_statement = exp.Select().from_(f"({statement.raw_sql})") if isinstance(current_statement, exp.Select): new_statement = current_statement.order_by(order_expr) else: new_statement = exp.Select().from_(current_statement).order_by(order_expr) return statement.copy(statement=new_statement)
[docs] def get_cache_key(self) -> tuple[Any, ...]: """Return cache key for this filter configuration.""" return ("OrderByFilter", self.field_name, self.sort_order)
[docs] class SearchFilter(StatementFilter): """Filter for text search queries. Constructs WHERE field_name LIKE '%value%' clauses. """ __slots__ = ("_field_name", "_ignore_case", "_value")
[docs] def __init__(self, field_name: str | set[str], value: str, ignore_case: bool | None = False) -> None: self._field_name = field_name self._value = value self._ignore_case = ignore_case
@property def field_name(self) -> str | set[str]: return self._field_name @property def value(self) -> str: return self._value @property def ignore_case(self) -> bool | None: return self._ignore_case
[docs] def get_param_name(self) -> str | None: """Get parameter name without storing it.""" if not self.value: return None if isinstance(self.field_name, str): return f"{self.field_name}_search" return "search_value"
[docs] def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]: """Extract filter parameters.""" named_parameters = {} param_name = self.get_param_name() if self.value and param_name: search_value_with_wildcards = f"%{self.value}%" named_parameters[param_name] = search_value_with_wildcards return [], named_parameters
[docs] def append_to_statement(self, statement: "SQL") -> "SQL": param_name = self.get_param_name() if not self.value or not param_name: return statement resolved_names = self._resolve_parameter_conflicts(statement, [param_name]) param_name = resolved_names[0] pattern_expr = exp.Placeholder(this=param_name) like_op = exp.ILike if self.ignore_case else exp.Like if isinstance(self.field_name, str): result = statement.where(like_op(this=exp.column(self.field_name), expression=pattern_expr)) elif isinstance(self.field_name, set) and self.field_name: field_conditions: list[Condition] = [ like_op(this=exp.column(field), expression=pattern_expr) for field in self.field_name ] if not field_conditions: return statement final_condition: Condition = field_conditions[0] for cond in field_conditions[1:]: final_condition = exp.Or(this=final_condition, expression=cond) result = statement.where(final_condition) else: result = statement search_value_with_wildcards = f"%{self.value}%" return result.add_named_parameter(param_name, search_value_with_wildcards)
[docs] def get_cache_key(self) -> tuple[Any, ...]: """Return cache key for this filter configuration.""" field_names = tuple(sorted(self.field_name)) if isinstance(self.field_name, set) else self.field_name return ("SearchFilter", field_names, self.value, self.ignore_case)
class NotInSearchFilter(SearchFilter): """Filter for negated text search queries. Constructs WHERE field_name NOT LIKE '%value%' clauses. """ def get_param_name(self) -> str | None: """Get parameter name without storing it.""" if not self.value: return None if isinstance(self.field_name, str): return f"{self.field_name}_not_search" return "not_search_value" def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]: """Extract filter parameters.""" named_parameters = {} param_name = self.get_param_name() if self.value and param_name: search_value_with_wildcards = f"%{self.value}%" named_parameters[param_name] = search_value_with_wildcards return [], named_parameters def append_to_statement(self, statement: "SQL") -> "SQL": param_name = self.get_param_name() if not self.value or not param_name: return statement resolved_names = self._resolve_parameter_conflicts(statement, [param_name]) param_name = resolved_names[0] pattern_expr = exp.Placeholder(this=param_name) like_op = exp.ILike if self.ignore_case else exp.Like result = statement if isinstance(self.field_name, str): result = statement.where(exp.Not(this=like_op(this=exp.column(self.field_name), expression=pattern_expr))) elif isinstance(self.field_name, set) and self.field_name: field_conditions: list[Condition] = [ exp.Not(this=like_op(this=exp.column(field), expression=pattern_expr)) for field in self.field_name ] if not field_conditions: return statement final_condition: Condition = field_conditions[0] if len(field_conditions) > 1: for cond in field_conditions[1:]: final_condition = exp.And(this=final_condition, expression=cond) result = statement.where(final_condition) search_value_with_wildcards = f"%{self.value}%" return result.add_named_parameter(param_name, search_value_with_wildcards) def get_cache_key(self) -> tuple[Any, ...]: """Return cache key for this filter configuration.""" field_names = tuple(sorted(self.field_name)) if isinstance(self.field_name, set) else self.field_name return ("NotInSearchFilter", field_names, self.value, self.ignore_case) class OffsetPagination(Generic[T]): """Container for data returned using limit/offset pagination.""" __slots__ = ("items", "limit", "offset", "total") items: Sequence[T] limit: int offset: int total: int def __init__(self, items: Sequence[T], limit: int, offset: int, total: int) -> None: """Initialize OffsetPagination. Args: items: List of data being sent as part of the response. limit: Maximal number of items to send. offset: Offset from the beginning of the query. Identical to an index. total: Total number of items. """ self.items = items self.limit = limit self.offset = offset self.total = total def apply_filter(statement: "SQL", filter_obj: StatementFilter) -> "SQL": """Apply a statement filter to a SQL query object. Args: statement: The SQL query object to modify. filter_obj: The filter to apply. Returns: The modified query object. """ return filter_obj.append_to_statement(statement) FilterTypes: TypeAlias = ( BeforeAfterFilter | OnBeforeAfterFilter | InCollectionFilter[Any] | LimitOffsetFilter | OrderByFilter | SearchFilter | NotInCollectionFilter[Any] | NotInSearchFilter | AnyCollectionFilter[Any] | NotAnyCollectionFilter[Any] ) def create_filters(filters: "list[StatementFilter]") -> tuple["StatementFilter", ...]: """Convert mutable filters to immutable tuple. Since StatementFilter classes are now immutable (with read-only properties), we just need to convert to a tuple for consistent sharing. Args: filters: List of StatementFilter objects (already immutable) Returns: Tuple of StatementFilter objects """ return tuple(filters) def canonicalize_filters(filters: "list[StatementFilter]") -> tuple["StatementFilter", ...]: """Sort filters by type and field_name for consistent hashing. Args: filters: List of StatementFilter objects Returns: Canonically sorted tuple of filters """ def sort_key(f: "StatementFilter") -> tuple[str, str]: class_name = type(f).__name__ field_name = getattr(f, "field_name", "") return (class_name, str(field_name)) return tuple(sorted(filters, key=sort_key))