"""Shared query modification utilities for SQL and builder classes.
This module provides pure functions for building SQL expressions that can be
used by both the immutable SQL class and the mutable builder classes. All
functions are designed to be mypyc-compatible with no dynamic dispatch.
The utilities are organized in layers:
- Expression factories: Create comparison expressions (eq, lt, like, etc.)
- Condition builders: Create parameterized WHERE conditions
- Expression modifiers: Apply WHERE, LIMIT, OFFSET to expressions
- CTE utilities: Safe CTE extraction and reattachment
Example:
    >>> from sqlglot import exp
    >>> from sqlspec.core.query_modifiers import (
    ...     expr_eq,
    ...     create_condition,
    ...     apply_where,
    ... )
    >>> select_expr = exp.select("*").from_("users")
    >>> condition = create_condition(
    ...     "status", "status_param", expr_eq
    ... )
    >>> modified = apply_where(select_expr, condition)
"""
from collections.abc import Callable
from typing import Any
from sqlglot import exp
from sqlspec.exceptions import SQLSpecError
# Public API of this module. Every function defined below is listed here;
# the ConditionFactory type alias is intentionally not exported.
__all__ = (
    "apply_column_pruning",
    "apply_limit",
    "apply_offset",
    "apply_or_where",
    "apply_select_only",
    "apply_where",
    "create_between_condition",
    "create_condition",
    "create_exists_condition",
    "create_in_condition",
    "create_not_exists_condition",
    "create_not_in_condition",
    "expr_eq",
    "expr_gt",
    "expr_gte",
    "expr_ilike",
    "expr_is_not_null",
    "expr_is_null",
    "expr_like",
    "expr_lt",
    "expr_lte",
    "expr_neq",
    "expr_not_like",
    "extract_column_name",
    "parse_column_for_condition",
    "safe_modify_with_cte",
)
# Type alias for condition factory functions: (column expression, placeholder) -> condition.
# All expr_* factories below share this signature; expr_is_null / expr_is_not_null accept
# the placeholder only for signature compatibility and ignore it.
ConditionFactory = Callable[[exp.Expression, exp.Placeholder], exp.Expression]
# =============================================================================
# Expression Factories
# =============================================================================
[docs]
def expr_eq(col: exp.Expression, placeholder: exp.Placeholder) -> exp.Expression:
    """Build an equality comparison ``col = :param``.

    Args:
        col: Left-hand column or expression.
        placeholder: Placeholder for the bound right-hand value.

    Returns:
        An ``exp.EQ`` node comparing ``col`` with ``placeholder``.
    """
    equality = exp.EQ(this=col, expression=placeholder)
    return equality
[docs]
def expr_neq(col: exp.Expression, placeholder: exp.Placeholder) -> exp.Expression:
    """Build an inequality comparison ``col != :param``.

    Args:
        col: Left-hand column or expression.
        placeholder: Placeholder for the bound right-hand value.

    Returns:
        An ``exp.NEQ`` node comparing ``col`` with ``placeholder``.
    """
    inequality = exp.NEQ(this=col, expression=placeholder)
    return inequality
[docs]
def expr_lt(col: exp.Expression, placeholder: exp.Placeholder) -> exp.Expression:
    """Build a strict less-than comparison ``col < :param``.

    Args:
        col: Left-hand column or expression.
        placeholder: Placeholder for the bound right-hand value.

    Returns:
        An ``exp.LT`` node comparing ``col`` with ``placeholder``.
    """
    less_than = exp.LT(this=col, expression=placeholder)
    return less_than
[docs]
def expr_lte(col: exp.Expression, placeholder: exp.Placeholder) -> exp.Expression:
    """Build an inclusive less-than comparison ``col <= :param``.

    Args:
        col: Left-hand column or expression.
        placeholder: Placeholder for the bound right-hand value.

    Returns:
        An ``exp.LTE`` node comparing ``col`` with ``placeholder``.
    """
    at_most = exp.LTE(this=col, expression=placeholder)
    return at_most
[docs]
def expr_gt(col: exp.Expression, placeholder: exp.Placeholder) -> exp.Expression:
    """Build a strict greater-than comparison ``col > :param``.

    Args:
        col: Left-hand column or expression.
        placeholder: Placeholder for the bound right-hand value.

    Returns:
        An ``exp.GT`` node comparing ``col`` with ``placeholder``.
    """
    greater_than = exp.GT(this=col, expression=placeholder)
    return greater_than
[docs]
def expr_gte(col: exp.Expression, placeholder: exp.Placeholder) -> exp.Expression:
    """Build an inclusive greater-than comparison ``col >= :param``.

    Args:
        col: Left-hand column or expression.
        placeholder: Placeholder for the bound right-hand value.

    Returns:
        An ``exp.GTE`` node comparing ``col`` with ``placeholder``.
    """
    at_least = exp.GTE(this=col, expression=placeholder)
    return at_least
[docs]
def expr_like(col: exp.Expression, placeholder: exp.Placeholder) -> exp.Expression:
    """Build a pattern match ``col LIKE :param``.

    Args:
        col: Column or expression being matched.
        placeholder: Placeholder for the bound LIKE pattern.

    Returns:
        An ``exp.Like`` node.
    """
    pattern_match = exp.Like(this=col, expression=placeholder)
    return pattern_match
[docs]
def expr_not_like(col: exp.Expression, placeholder: exp.Placeholder) -> exp.Expression:
    """Build a negated pattern match ``NOT (col LIKE :param)``.

    Args:
        col: Column or expression being matched.
        placeholder: Placeholder for the bound LIKE pattern.

    Returns:
        An ``exp.Not`` node wrapping an ``exp.Like`` node.
    """
    like_match = exp.Like(this=col, expression=placeholder)
    # Negation is expressed by wrapping the LIKE node in NOT.
    return exp.Not(this=like_match)
[docs]
def expr_ilike(col: exp.Expression, placeholder: exp.Placeholder) -> exp.Expression:
    """Build a case-insensitive pattern match ``col ILIKE :param``.

    Args:
        col: Column or expression being matched.
        placeholder: Placeholder for the bound ILIKE pattern.

    Returns:
        An ``exp.ILike`` node.
    """
    insensitive_match = exp.ILike(this=col, expression=placeholder)
    return insensitive_match
[docs]
def expr_is_null(col: exp.Expression, _placeholder: exp.Placeholder) -> exp.Expression:
    """Build a null test ``col IS NULL``.

    The placeholder argument is accepted only so this function matches the
    shared condition-factory signature; it is never used.

    Args:
        col: Column or expression being tested.
        _placeholder: Ignored.

    Returns:
        An ``exp.Is`` node comparing ``col`` with ``NULL``.
    """
    null_literal = exp.null()
    return exp.Is(this=col, expression=null_literal)
[docs]
def expr_is_not_null(col: exp.Expression, _placeholder: exp.Placeholder) -> exp.Expression:
    """Build a non-null test ``NOT (col IS NULL)``.

    The placeholder argument is accepted only so this function matches the
    shared condition-factory signature; it is never used.

    Args:
        col: Column or expression being tested.
        _placeholder: Ignored.

    Returns:
        An ``exp.Not`` node wrapping an ``exp.Is`` null test.
    """
    null_check = exp.Is(this=col, expression=exp.null())
    return exp.Not(this=null_check)
# =============================================================================
# Column Parsing
# =============================================================================
[docs]
def parse_column_for_condition(column: str | exp.Column | exp.Expression) -> exp.Expression:
    """Parse a column specification for use in conditions.

    Handles various input formats:

    - ``"column_name"`` -> ``exp.Column``
    - ``"table.column"`` -> ``exp.Column`` with ``table`` set
    - ``"db.table.column"`` -> ``exp.Column`` with ``db`` and ``table`` set
    - ``"catalog.db.table.column"`` -> ``exp.Column`` with ``catalog``, ``db`` and ``table`` set
    - ``exp.Column`` / any other ``exp.Expression`` -> returned as-is (not copied)

    Dotted strings are split naively on ``.``; quoted identifiers that themselves
    contain a dot are not recognised. The last segment is always the column name,
    consistent with ``extract_column_name``.

    Args:
        column: Column specification.

    Returns:
        SQLGlot column expression.
    """
    if isinstance(column, exp.Expression):
        return column
    if isinstance(column, str):
        parts = column.split(".")
        if len(parts) > 4:
            # More segments than catalog.db.table.column: keep the historical behaviour of
            # treating the first segment as the table and the remainder as the name.
            head, tail = column.split(".", 1)
            return exp.column(tail, table=head)
        # Last segment is the column name; the preceding segments are, innermost first,
        # table, db and catalog. Missing qualifiers stay None (exp.column's defaults), so
        # one- and two-part names produce exactly what they did before.
        name, *qualifiers = reversed(parts)
        table, db, catalog = (*qualifiers, None, None, None)[:3]
        return exp.column(name, table=table, db=db, catalog=catalog)
    # Defensive fallback for inputs outside the annotated types.
    return exp.column(str(column))
[docs]
def extract_column_name(column: str | exp.Column | exp.Expression) -> str:
    """Derive a base parameter name from a column specification.

    Strings yield the text after the last ``.`` (or the whole string when there
    is no dot); ``exp.Column`` yields its ``name``; any other expression yields
    its ``name`` when that is non-empty. Everything else falls back to
    ``"column"``.

    Args:
        column: Column name string or SQLGlot expression.

    Returns:
        Name to use as the base of a generated parameter name.
    """
    if isinstance(column, str):
        # rpartition leaves the whole string in slot 2 when no dot is present.
        return column.rpartition(".")[2]
    if isinstance(column, exp.Column):
        return column.name
    if isinstance(column, exp.Expression):
        label = getattr(column, "name", None)
        if label:
            return str(label)
    return "column"
# =============================================================================
# Condition Builders
# =============================================================================
[docs]
def create_condition(
    column: str | exp.Column | exp.Expression, param_name: str, condition_factory: ConditionFactory
) -> exp.Expression:
    """Build a parameterized condition from a column and a factory.

    Only the expression tree is built here; binding a value to ``param_name``
    is the caller's responsibility.

    Args:
        column: Column name or expression.
        param_name: Pre-generated unique parameter name.
        condition_factory: Factory (e.g. ``expr_eq``) that combines the column and placeholder.

    Returns:
        Whatever expression ``condition_factory`` builds.
    """
    return condition_factory(
        parse_column_for_condition(column),
        exp.Placeholder(this=param_name),
    )
[docs]
def create_in_condition(column: str | exp.Column | exp.Expression, param_names: list[str]) -> exp.Expression:
    """Build ``column IN (:p1, :p2, ...)`` with one placeholder per name.

    Args:
        column: Column name or expression.
        param_names: Pre-generated parameter names, one per value, in order.

    Returns:
        An ``exp.In`` node whose list holds the placeholders.
    """
    target = parse_column_for_condition(column)
    placeholder_nodes = [exp.Placeholder(this=param) for param in param_names]
    return exp.In(this=target, expressions=placeholder_nodes)
[docs]
def create_not_in_condition(column: str | exp.Column | exp.Expression, param_names: list[str]) -> exp.Expression:
    """Build ``NOT (column IN (:p1, :p2, ...))``.

    Args:
        column: Column name or expression.
        param_names: Pre-generated parameter names, one per value, in order.

    Returns:
        An ``exp.Not`` node wrapping the IN condition from ``create_in_condition``.
    """
    return exp.Not(this=create_in_condition(column, param_names))
[docs]
def create_between_condition(
    column: str | exp.Column | exp.Expression, low_param: str, high_param: str
) -> exp.Expression:
    """Build ``column BETWEEN :low AND :high``.

    Args:
        column: Column name or expression.
        low_param: Parameter name for the lower bound.
        high_param: Parameter name for the upper bound.

    Returns:
        An ``exp.Between`` node with placeholder bounds.
    """
    target = parse_column_for_condition(column)
    lower, upper = (exp.Placeholder(this=name) for name in (low_param, high_param))
    return exp.Between(this=target, low=lower, high=upper)
[docs]
def create_exists_condition(subquery: exp.Expression) -> exp.Expression:
    """Wrap a subquery in ``EXISTS (...)``.

    Args:
        subquery: Subquery expression; used as-is.

    Returns:
        An ``exp.Exists`` node.
    """
    exists_node = exp.Exists(this=subquery)
    return exists_node
[docs]
def create_not_exists_condition(subquery: exp.Expression) -> exp.Expression:
    """Wrap a subquery in ``NOT EXISTS (...)``.

    Args:
        subquery: Subquery expression; used as-is.

    Returns:
        An ``exp.Not`` node wrapping an ``exp.Exists`` node.
    """
    return exp.Not(this=create_exists_condition(subquery))
# =============================================================================
# Expression Modifiers
# =============================================================================
[docs]
def apply_where(expression: exp.Expression, condition: exp.Expression) -> exp.Expression:
    """Apply a WHERE condition to an expression, ANDed with any existing WHERE.

    Works with SELECT, UPDATE, and DELETE expressions.

    Note:
        The expression is modified IN PLACE (``copy=False``) and the same object is
        returned; neither ``expression`` nor ``condition`` is copied. Callers that
        need the original untouched must pass a copy.

    Args:
        expression: Base expression to modify in place.
        condition: WHERE condition to add.

    Returns:
        The same expression object, now carrying the WHERE condition.

    Raises:
        SQLSpecError: If expression type doesn't support WHERE.
    """
    if not isinstance(expression, (exp.Select, exp.Update, exp.Delete)):
        msg = f"Cannot apply WHERE to {type(expression).__name__}"
        raise SQLSpecError(msg)
    return expression.where(condition, copy=False)
[docs]
def apply_or_where(expression: exp.Expression, condition: exp.Expression) -> exp.Expression:
    """OR a new condition onto an expression's existing WHERE clause.

    The existing WHERE predicate becomes the left operand of an ``exp.Or`` whose
    right operand is ``condition``. The expression is mutated in place and returned.

    Args:
        expression: SELECT, UPDATE, or DELETE expression that already has a WHERE.
        condition: Condition to OR with the existing predicate.

    Returns:
        The same expression object with the combined predicate.

    Raises:
        SQLSpecError: If expression type doesn't support WHERE or has no existing WHERE.
    """
    supported = (exp.Select, exp.Update, exp.Delete)
    if not isinstance(expression, supported):
        msg = f"Cannot apply WHERE to {type(expression).__name__}"
        raise SQLSpecError(msg)
    where_clause = expression.args.get("where")
    if where_clause and isinstance(where_clause, exp.Where):
        where_clause.set("this", exp.Or(this=where_clause.this, expression=condition))
        return expression
    msg = "Cannot use OR without existing WHERE clause"
    raise SQLSpecError(msg)
[docs]
def apply_limit(expression: exp.Expression, limit_value: int) -> exp.Expression:
    """Set a LIMIT on a SELECT or set operation, in place.

    Args:
        expression: SELECT or set-operation expression (mutated in place, ``copy=False``).
        limit_value: LIMIT value.

    Returns:
        The expression returned by sqlglot's ``limit`` builder.

    Raises:
        SQLSpecError: If expression does not support LIMIT.
    """
    if isinstance(expression, (exp.Select, exp.SetOperation)):
        return expression.limit(limit_value, copy=False)
    msg = f"LIMIT only valid for SELECT or set operations, got {type(expression).__name__}"
    raise SQLSpecError(msg)
[docs]
def apply_offset(expression: exp.Expression, offset_value: int) -> exp.Expression:
    """Set an OFFSET on a SELECT or set operation, in place.

    Args:
        expression: SELECT or set-operation expression (mutated in place, ``copy=False``).
        offset_value: OFFSET value.

    Returns:
        The expression returned by sqlglot's ``offset`` builder.

    Raises:
        SQLSpecError: If expression does not support OFFSET.
    """
    if isinstance(expression, (exp.Select, exp.SetOperation)):
        return expression.offset(offset_value, copy=False)
    msg = f"OFFSET only valid for SELECT or set operations, got {type(expression).__name__}"
    raise SQLSpecError(msg)
[docs]
def apply_select_only(expression: exp.Expression, columns: tuple[str | exp.Expression, ...]) -> exp.Expression:
    """Replace a SELECT's projection list with exactly ``columns``, in place.

    String entries are converted with ``parse_column_for_condition`` (so they
    become column references, not parsed SQL); expressions are used as given.

    Args:
        expression: SELECT expression to modify.
        columns: Column names or expressions to project, in order.

    Returns:
        The SELECT expression with its new projection list.

    Raises:
        SQLSpecError: If expression is not SELECT.
    """
    if not isinstance(expression, exp.Select):
        msg = f"select_only only valid for SELECT, got {type(expression).__name__}"
        raise SQLSpecError(msg)
    expression.set("expressions", [])
    projections = [parse_column_for_condition(spec) if isinstance(spec, str) else spec for spec in columns]
    for projection in projections:
        expression = expression.select(projection, copy=False)
    return expression
# =============================================================================
# CTE Utilities
# =============================================================================
[docs]
def safe_modify_with_cte(
    expression: exp.Expression, modification_fn: Callable[[exp.Expression], exp.Expression]
) -> exp.Expression:
    """Safely apply a modification, preserving CTEs at top level.

    This ensures CTEs stay at the outermost level even when the modification
    would normally wrap them in a subquery. This fixes issue #301 where
    CTEs inside subqueries generate invalid SQL.

    When ``expression`` is a SELECT/set operation with a WITH clause, it is copied,
    the copy's WITH clause is detached, ``modification_fn`` runs on the CTE-free copy,
    and the detached clause is reattached to the result. The original expression is
    never modified and shares no nodes with the result. Without a WITH clause,
    ``modification_fn`` receives ``expression`` itself.

    Note:
        If ``modification_fn`` returns something other than a SELECT or set operation,
        the WITH clause is not reattached (same as before); the result may then
        reference undefined CTEs.

    Args:
        expression: Expression that may contain CTEs.
        modification_fn: Function to apply to the expression.

    Returns:
        Modified expression with CTE preserved at top level.
    """
    if not isinstance(expression, (exp.Select, exp.SetOperation)):
        return modification_fn(expression)
    if not expression.args.get("with_"):
        return modification_fn(expression)

    working_expr = expression.copy()
    # Reattach the COPY's WITH node, not the original's: a sqlglot node has a single
    # parent, so moving the original's node would leave it shared by two trees and
    # re-parented away from the (possibly cached) original expression.
    detached_cte: Any = working_expr.args.get("with_")
    working_expr.set("with_", None)

    result = modification_fn(working_expr)
    if isinstance(result, (exp.Select, exp.SetOperation)):
        result.set("with_", detached_cte)
    return result
# =============================================================================
# Column Pruning
# =============================================================================
def apply_column_pruning(
    expression: exp.Expression, dialect: str | None = None, cache_key: str | None = None
) -> exp.Expression:
    """Apply column pruning optimization to remove unused columns from subqueries.

    Uses SQLGlot's `qualify()` to resolve column references and table aliases,
    then `pushdown_projections()` to remove columns from subqueries that aren't
    needed by outer queries.

    This optimization can improve query performance by:

    - Reducing I/O by selecting fewer columns from disk
    - Reducing network transfer when fetching results
    - Enabling better query plans in some databases

    The optimization is best-effort: if qualification fails the input is returned
    unchanged; if pushdown fails the qualified (but unpruned) tree is returned.
    The input expression itself is never mutated (work happens on a copy).

    Args:
        expression: Base expression. Only SELECT is optimized; any other expression
            type is returned unchanged.
        dialect: SQL dialect for qualification.
        cache_key: Optional cache key for looking up/storing optimized result.

    Returns:
        Optimized expression with unused columns removed from subqueries. When a
        cache is used, the returned tree is never the object held by the cache, so
        callers may mutate it freely.

    Example:
        Before pruning:
            SELECT id, name FROM (SELECT id, name, email, created_at FROM users)
        After pruning:
            SELECT id, name FROM (SELECT id, name FROM users)
    """
    from sqlglot.optimizer import pushdown_projections as pushdown_projections_module
    from sqlglot.optimizer import qualify as qualify_module

    from sqlspec.core.cache import get_cache

    if not isinstance(expression, exp.Select):
        return expression

    # Check cache first if key provided
    if cache_key is not None:
        cached = get_cache().get_optimized(cache_key, dialect)
        if isinstance(cached, exp.Expression):
            # Hand out a copy so callers cannot mutate the cached tree.
            return cached.copy()

    # Apply qualification to resolve column references
    try:
        qualified = qualify_module.qualify(expression.copy(), dialect=dialect, validate_qualify_columns=False)
    except Exception:
        # Deliberate best-effort: if qualification fails, return unchanged expression
        return expression

    # Apply pushdown_projections to remove unused columns
    try:
        pruned = pushdown_projections_module.pushdown_projections(qualified, dialect=dialect)
    except Exception:
        # Deliberate best-effort: if pushdown fails, return the qualified expression
        pruned = qualified

    # Cache a private copy: the returned tree may be mutated in place by callers (the
    # modifiers in this module use copy=False), which must not corrupt the cache entry.
    # NOTE(review): if put_optimized already copies, this is a redundant but harmless copy.
    if cache_key is not None:
        get_cache().put_optimized(cache_key, pruned.copy(), dialect)
    return pruned