# Source code for sqlspec.core.statement

"""SQL statement and configuration management."""

from typing import TYPE_CHECKING, Any, Final, Optional, TypeAlias

import sqlglot
from mypy_extensions import mypyc_attr
from sqlglot import exp
from sqlglot.errors import ParseError

import sqlspec.exceptions
from sqlspec.core.compiler import OperationProfile, OperationType
from sqlspec.core.parameters import (
    ParameterConverter,
    ParameterProfile,
    ParameterStyle,
    ParameterStyleConfig,
    ParameterValidator,
)
from sqlspec.core.pipeline import compile_with_shared_pipeline
from sqlspec.typing import Empty, EmptyEnum
from sqlspec.utils.logging import get_logger
from sqlspec.utils.type_guards import is_statement_filter, supports_where

if TYPE_CHECKING:
    from collections.abc import Callable

    from sqlglot.dialects.dialect import DialectType

    from sqlspec.core.cache import FiltersView
    from sqlspec.core.filters import StatementFilter


# Names exported by ``from sqlspec.core.statement import *``.
__all__ = (
    "SQL",
    "ProcessedState",
    "Statement",
    "StatementConfig",
    "get_default_config",
    "get_default_parameter_config",
)
# Module logger; used by SQL._handle_compile_failure to report fallback compilation.
logger = get_logger("sqlspec.core.statement")

# Operation-type names that SQL.returns_rows() treats as producing result rows.
RETURNS_ROWS_OPERATIONS: Final = {"SELECT", "WITH", "VALUES", "TABLE", "SHOW", "DESCRIBE", "PRAGMA"}
# Operation-type names that SQL.is_modifying_operation() treats as data-modifying.
MODIFYING_OPERATIONS: Final = {"INSERT", "UPDATE", "DELETE", "MERGE", "UPSERT"}

# Attribute names of StatementConfig: used as its ``__slots__`` and as the set
# of keyword names accepted by StatementConfig.replace().
SQL_CONFIG_SLOTS: Final = (
    "pre_process_steps",
    "post_process_steps",
    "dialect",
    "enable_analysis",
    "enable_caching",
    "enable_expression_simplification",
    "enable_parameter_type_wrapping",
    "enable_parsing",
    "enable_transformations",
    "enable_validation",
    "execution_mode",
    "execution_args",
    "output_transformer",
    "parameter_config",
    "parameter_converter",
    "parameter_validator",
)

# Attribute names of ProcessedState; used as its ``__slots__``.
PROCESSED_STATE_SLOTS: Final = (
    "compiled_sql",
    "execution_parameters",
    "parsed_expression",
    "operation_type",
    "parameter_casts",
    "parameter_profile",
    "operation_profile",
    "validation_errors",
    "is_many",
)


@mypyc_attr(allow_interpreted_subclasses=False)
class ProcessedState:
    """Result of processing a SQL statement.

    Holds the compiled SQL text and its execution parameters together with
    the parsed expression, operation type, parameter casts, parameter and
    operation profiles, validation errors and the ``is_many`` flag.
    """

    __slots__ = PROCESSED_STATE_SLOTS
    operation_type: "OperationType"

    def __init__(
        self,
        compiled_sql: str,
        execution_parameters: Any,
        parsed_expression: "exp.Expression | None" = None,
        operation_type: "OperationType" = "UNKNOWN",
        parameter_casts: "dict[int, str] | None" = None,
        validation_errors: "list[str] | None" = None,
        parameter_profile: "ParameterProfile | None" = None,
        operation_profile: "OperationProfile | None" = None,
        is_many: bool = False,
    ) -> None:
        """Store the processing results.

        Args:
            compiled_sql: Compiled SQL string
            execution_parameters: Parameters to execute the compiled SQL with
            parsed_expression: Parsed SQLGlot expression, if any
            operation_type: Operation type; defaults to ``"UNKNOWN"``
            parameter_casts: Parameter casts; a falsy value becomes ``{}``
            validation_errors: Validation errors; a falsy value becomes ``[]``
            parameter_profile: Parameter profile; a falsy value becomes
                ``ParameterProfile.empty()``
            operation_profile: Operation profile; a falsy value becomes
                ``OperationProfile.empty()``
            is_many: Whether the statement is an execute_many operation
        """
        # Values stored exactly as given.
        self.compiled_sql = compiled_sql
        self.execution_parameters = execution_parameters
        self.parsed_expression = parsed_expression
        self.operation_type = operation_type
        self.is_many = is_many

        # Optional values: anything falsy is replaced by a fresh empty default.
        self.parameter_casts = parameter_casts if parameter_casts else {}
        self.validation_errors = validation_errors if validation_errors else []
        self.parameter_profile = parameter_profile if parameter_profile else ParameterProfile.empty()
        self.operation_profile = operation_profile if operation_profile else OperationProfile.empty()

    def __hash__(self) -> int:
        """Hash the compiled SQL, the stringified parameters and the operation type."""
        identity = (self.compiled_sql, str(self.execution_parameters), self.operation_type)
        return hash(identity)


@mypyc_attr(allow_interpreted_subclasses=False)
class SQL:
    """SQL statement with parameter and filter support.

    Represents a SQL statement that can be compiled with parameters and filters.
    Supports both positional and named parameters, statement filtering,
    and various execution modes including batch operations.
    """

    __slots__ = (
        "_dialect",
        "_filters",
        "_hash",
        "_is_many",
        "_is_script",
        "_named_parameters",
        "_original_parameters",
        "_positional_parameters",
        "_processed_state",
        "_raw_sql",
        "_statement_config",
    )

    def __init__(
        self,
        statement: "str | exp.Expression | 'SQL'",
        *parameters: "Any | StatementFilter | list[Any | StatementFilter]",
        statement_config: Optional["StatementConfig"] = None,
        is_many: bool | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize SQL statement.

        Args:
            statement: SQL string, expression, or existing SQL object
            *parameters: Parameters and filters
            statement_config: Configuration; a default one is created when omitted
            is_many: Mark as execute_many operation; when ``None`` the flag is
                copied from an existing SQL object or detected from ``parameters``
            **kwargs: Additional named parameters (``dialect`` and ``is_script``
                are consumed as options rather than stored as parameters)
        """
        config = statement_config or self._create_auto_config(statement, parameters, kwargs)
        self._statement_config = config
        self._dialect = self._normalize_dialect(config.dialect)
        self._processed_state: EmptyEnum | ProcessedState = Empty
        self._hash: int | None = None
        self._filters: list[StatementFilter] = []
        self._named_parameters: dict[str, Any] = {}
        self._positional_parameters: list[Any] = []
        self._is_script = False

        if isinstance(statement, SQL):
            self._init_from_sql_object(statement)
            # An explicit is_many overrides the flag copied from the source object.
            if is_many is not None:
                self._is_many = is_many
        else:
            if isinstance(statement, str):
                self._raw_sql = statement
            else:
                dialect = self._dialect
                self._raw_sql = statement.sql(dialect=str(dialect) if dialect else None)
            self._is_many = is_many if is_many is not None else self._should_auto_detect_many(parameters)

        self._original_parameters = parameters
        self._process_parameters(*parameters, **kwargs)

    def _create_auto_config(
        self, _statement: "str | exp.Expression | 'SQL'", _parameters: tuple, _kwargs: dict[str, Any]
    ) -> "StatementConfig":
        """Create default StatementConfig when none provided.

        Args:
            _statement: The SQL statement (unused)
            _parameters: Statement parameters (unused)
            _kwargs: Additional keyword arguments (unused)

        Returns:
            Default StatementConfig instance
        """
        return get_default_config()

    def _normalize_dialect(self, dialect: "DialectType | None") -> "str | None":
        """Convert dialect to string representation.

        Args:
            dialect: Dialect name, dialect instance, dialect class, or None

        Returns:
            ``None`` for ``None``, the string itself for a string, otherwise
            the lower-cased class name of the dialect
        """
        if dialect is None:
            return None
        if isinstance(dialect, str):
            return dialect
        # When the dialect is given as a class, ``dialect.__class__`` is its
        # metaclass, so take the name from the class itself in that case.
        dialect_cls = dialect if isinstance(dialect, type) else dialect.__class__
        return dialect_cls.__name__.lower()

    def _init_from_sql_object(self, sql_obj: "SQL") -> None:
        """Initialize instance attributes from existing SQL object.

        Args:
            sql_obj: Existing SQL object to copy from
        """
        self._raw_sql = sql_obj.raw_sql
        self._filters = sql_obj.filters.copy()
        self._named_parameters = sql_obj.named_parameters.copy()
        self._positional_parameters = sql_obj.positional_parameters.copy()
        self._is_many = sql_obj.is_many
        self._is_script = sql_obj.is_script
        if sql_obj.is_processed:
            self._processed_state = sql_obj.get_processed_state()

    def _should_auto_detect_many(self, parameters: tuple) -> bool:
        """Detect execute_many mode from parameter structure.

        Args:
            parameters: Parameter tuple to analyze

        Returns:
            True when ``parameters`` is a single list holding more than one
            item and every item is a tuple or list
        """
        if len(parameters) == 1 and isinstance(parameters[0], list):
            param_list = parameters[0]
            if param_list and all(isinstance(item, (tuple, list)) for item in param_list):
                return len(param_list) > 1
        return False

    def _process_parameters(self, *parameters: Any, dialect: str | None = None, **kwargs: Any) -> None:
        """Process and organize parameters and filters.

        Args:
            *parameters: Variable parameters and filters
            dialect: SQL dialect override
            **kwargs: Additional named parameters; an ``is_script`` entry is
                removed and stored as the script flag instead
        """
        if dialect is not None:
            self._dialect = self._normalize_dialect(dialect)
        if "is_script" in kwargs:
            self._is_script = bool(kwargs.pop("is_script"))

        # Separate statement filters from actual parameter values.
        filters: list[StatementFilter] = []
        actual_params: list[Any] = []
        for p in parameters:
            if is_statement_filter(p):
                filters.append(p)
            else:
                actual_params.append(p)

        self._filters.extend(filters)

        if actual_params:
            param_count = len(actual_params)
            if param_count == 1:
                param = actual_params[0]
                if isinstance(param, dict):
                    self._named_parameters.update(param)
                elif isinstance(param, (list, tuple)):
                    if self._is_many:
                        self._positional_parameters = list(param)
                    else:
                        # For drivers with native list expansion support, each item in the tuple/list
                        # should be treated as a separate parameter (but preserve inner lists/arrays)
                        # This allows passing arrays/lists as single JSONB parameters
                        self._positional_parameters.extend(param)
                else:
                    self._positional_parameters.append(param)
            else:
                self._positional_parameters.extend(actual_params)

        self._named_parameters.update(kwargs)

    @property
    def sql(self) -> str:
        """Get the raw SQL string."""
        return self._raw_sql

    @property
    def raw_sql(self) -> str:
        """Get raw SQL string (public API).

        Returns:
            The raw SQL string
        """
        return self._raw_sql

    @property
    def parameters(self) -> Any:
        """Get the named parameters if any are set, otherwise the positional parameters."""
        if self._named_parameters:
            return self._named_parameters
        return self._positional_parameters or []

    @property
    def positional_parameters(self) -> "list[Any]":
        """Get positional parameters (public API)."""
        return self._positional_parameters or []

    @property
    def named_parameters(self) -> "dict[str, Any]":
        """Get named parameters (public API)."""
        return self._named_parameters

    @property
    def original_parameters(self) -> Any:
        """Get the positional arguments as originally passed to the constructor (public API)."""
        return self._original_parameters

    @property
    def operation_type(self) -> "OperationType":
        """SQL operation type; ``"UNKNOWN"`` until the statement has been processed."""
        if self._processed_state is Empty:
            return "UNKNOWN"
        return self._processed_state.operation_type

    @property
    def statement_config(self) -> "StatementConfig":
        """Statement configuration."""
        return self._statement_config

    @property
    def expression(self) -> "exp.Expression | None":
        """SQLGlot expression, or None when the statement has not been processed."""
        if self._processed_state is not Empty:
            return self._processed_state.parsed_expression
        return None

    @property
    def filters(self) -> "list[StatementFilter]":
        """Applied filters (a copy of the internal list)."""
        return self._filters.copy()

    def get_filters_view(self) -> "FiltersView":
        """Get zero-copy filters view (public API).

        Returns:
            Read-only view of filters without copying
        """
        from sqlspec.core.cache import FiltersView

        return FiltersView(self._filters)

    @property
    def is_processed(self) -> bool:
        """Check if SQL has been processed (public API)."""
        return self._processed_state is not Empty

    def get_processed_state(self) -> Any:
        """Get processed state (public API)."""
        return self._processed_state

    @property
    def dialect(self) -> "str | None":
        """SQL dialect."""
        return self._dialect

    @property
    def _statement(self) -> "exp.Expression | None":
        """Internal SQLGlot expression."""
        return self.expression

    @property
    def statement_expression(self) -> "exp.Expression | None":
        """Get parsed statement expression (public API).

        Returns:
            Parsed SQLGlot expression or None if not parsed
        """
        if self._processed_state is not Empty:
            return self._processed_state.parsed_expression
        return None

    @property
    def is_many(self) -> bool:
        """Check if this is execute_many."""
        return self._is_many

    @property
    def is_script(self) -> bool:
        """Check if this is script execution."""
        return self._is_script

    @property
    def validation_errors(self) -> "list[str]":
        """Validation errors (a copy); empty until the statement has been processed."""
        if self._processed_state is Empty:
            return []
        return self._processed_state.validation_errors.copy()

    @property
    def has_errors(self) -> bool:
        """Check if there are validation errors."""
        return len(self.validation_errors) > 0

    def returns_rows(self) -> bool:
        """Check if statement returns rows.

        Compiles the statement first when it has not been processed yet.

        Returns:
            True if the SQL statement returns result rows
        """
        if self._processed_state is Empty:
            self.compile()
        if self._processed_state is Empty:
            return False

        profile = getattr(self._processed_state, "operation_profile", None)
        if profile and profile.returns_rows:
            return True

        op_type = self._processed_state.operation_type
        if op_type in RETURNS_ROWS_OPERATIONS:
            return True

        if self._processed_state.parsed_expression:
            expr = self._processed_state.parsed_expression
            # INSERT/UPDATE/DELETE return rows when they carry a RETURNING clause.
            if isinstance(expr, (exp.Insert, exp.Update, exp.Delete)) and expr.args.get("returning"):
                return True
        return False

    def is_modifying_operation(self) -> bool:
        """Check if the SQL statement is a modifying operation.

        Unlike ``returns_rows`` this does not compile the statement; an
        unprocessed statement reports ``False``.

        Returns:
            True if the operation modifies data (INSERT/UPDATE/DELETE/MERGE/UPSERT)
        """
        if self._processed_state is Empty:
            return False

        profile = getattr(self._processed_state, "operation_profile", None)
        if profile and profile.modifies_rows:
            return True

        op_type = self._processed_state.operation_type
        if op_type in MODIFYING_OPERATIONS:
            return True

        if self._processed_state.parsed_expression:
            return isinstance(self._processed_state.parsed_expression, (exp.Insert, exp.Update, exp.Delete, exp.Merge))
        return False

    def compile(self) -> tuple[str, Any]:
        """Compile SQL statement with parameters.

        The result is stored on the instance, so later calls reuse it.
        ``SQLSpecError`` propagates; any other exception is logged and the
        raw SQL with the unprocessed parameters is used as a fallback.

        Returns:
            Tuple of compiled SQL string and execution parameters
        """
        if self._processed_state is Empty:
            try:
                config = self._statement_config
                raw_sql = self._raw_sql
                params = self._named_parameters or self._positional_parameters
                is_many = self._is_many
                compiled_result = compile_with_shared_pipeline(config, raw_sql, params, is_many=is_many)

                self._processed_state = ProcessedState(
                    compiled_sql=compiled_result.compiled_sql,
                    execution_parameters=compiled_result.execution_parameters,
                    parsed_expression=compiled_result.expression,
                    operation_type=compiled_result.operation_type,
                    parameter_casts=compiled_result.parameter_casts,
                    parameter_profile=compiled_result.parameter_profile,
                    operation_profile=compiled_result.operation_profile,
                    validation_errors=[],
                    is_many=self._is_many,
                )
            except sqlspec.exceptions.SQLSpecError:
                raise
            except Exception as e:
                self._processed_state = self._handle_compile_failure(e)

        return self._processed_state.compiled_sql, self._processed_state.execution_parameters

    def _derive(self, statement: "str | exp.Expression") -> "SQL":
        """Build a new SQL for ``statement`` that carries over this instance's parameters and filters."""
        new_sql = SQL(
            statement, *self._original_parameters, statement_config=self._statement_config, is_many=self._is_many
        )
        new_sql._named_parameters.update(self._named_parameters)
        new_sql._positional_parameters = self._positional_parameters.copy()
        new_sql._filters = self._filters.copy()
        return new_sql

    def as_script(self) -> "SQL":
        """Create copy marked for script execution.

        Returns:
            New SQL instance configured for script execution
        """
        new_sql = self._derive(self._raw_sql)
        new_sql._is_script = True
        return new_sql

    def copy(
        self, statement: "str | exp.Expression | None" = None, parameters: Any | None = None, **kwargs: Any
    ) -> "SQL":
        """Create copy with modifications.

        Args:
            statement: New SQL statement to use
            parameters: New parameters to use. A list or tuple is expanded
                into individual arguments; any other value (for example a
                dict of named parameters) is passed as one argument. When
                omitted, the current parameters are carried over.
            **kwargs: Additional modifications

        Returns:
            New SQL instance with modifications applied
        """
        if parameters is None:
            call_args: Any = self._original_parameters
        elif isinstance(parameters, (list, tuple)):
            call_args = tuple(parameters)
        else:
            # Star-unpacking a dict would pass only its keys and a string would
            # be split into characters, so hand such values over as-is.
            call_args = (parameters,)

        new_sql = SQL(
            statement or self._raw_sql,
            *call_args,
            statement_config=self._statement_config,
            is_many=self._is_many,
            **kwargs,
        )
        if parameters is None:
            new_sql._named_parameters.update(self._named_parameters)
            new_sql._positional_parameters = self._positional_parameters.copy()
        new_sql._filters = self._filters.copy()
        return new_sql

    def _handle_compile_failure(self, error: Exception) -> ProcessedState:
        """Log a compile failure and build a fallback state from the raw SQL and current parameters."""
        logger.warning("Processing failed, using fallback: %s", error)
        return ProcessedState(
            compiled_sql=self._raw_sql,
            execution_parameters=self._named_parameters or self._positional_parameters,
            operation_type="UNKNOWN",
            parameter_casts={},
            parameter_profile=ParameterProfile.empty(),
            operation_profile=OperationProfile.empty(),
            is_many=self._is_many,
        )

    def add_named_parameter(self, name: str, value: Any) -> "SQL":
        """Add a named parameter and return a new SQL instance.

        Args:
            name: Parameter name
            value: Parameter value

        Returns:
            New SQL instance with the added parameter
        """
        new_sql = self._derive(self._raw_sql)
        new_sql._named_parameters[name] = value
        return new_sql

    def where(self, condition: "str | exp.Expression") -> "SQL":
        """Add WHERE condition to the SQL statement.

        Args:
            condition: WHERE condition as string or SQLGlot expression

        Returns:
            New SQL instance with the WHERE condition applied
        """
        try:
            current_expr = sqlglot.parse_one(self._raw_sql, dialect=self._dialect)
        except ParseError:
            # Unparseable as-is: retry with the raw SQL wrapped as a subquery.
            subquery_sql = f"SELECT * FROM ({self._raw_sql}) AS subquery"
            current_expr = sqlglot.parse_one(subquery_sql, dialect=self._dialect)

        condition_expr: exp.Expression
        if isinstance(condition, str):
            try:
                condition_expr = sqlglot.parse_one(condition, dialect=self._dialect, into=exp.Condition)
            except ParseError:
                condition_expr = exp.Condition(this=condition)
        else:
            condition_expr = condition

        if isinstance(current_expr, exp.Select) or supports_where(current_expr):
            new_expr = current_expr.where(condition_expr, copy=False)
        else:
            # Expressions without WHERE support are selected from, then filtered.
            new_expr = exp.Select().from_(current_expr).where(condition_expr, copy=False)

        return self._derive(new_expr)

    def __hash__(self) -> int:
        """Hash value computation (cached after the first call)."""
        if self._hash is None:
            positional_tuple = tuple(self._positional_parameters)
            named_tuple = tuple(sorted(self._named_parameters.items())) if self._named_parameters else ()
            raw_sql = self._raw_sql
            is_many = self._is_many
            is_script = self._is_script
            self._hash = hash((raw_sql, positional_tuple, named_tuple, is_many, is_script))
        return self._hash

    def __eq__(self, other: object) -> bool:
        """Equality comparison on raw SQL, parameters and the is_many/is_script flags."""
        if not isinstance(other, SQL):
            return False
        return (
            self._raw_sql == other._raw_sql
            and self._positional_parameters == other._positional_parameters
            and self._named_parameters == other._named_parameters
            and self._is_many == other._is_many
            and self._is_script == other._is_script
        )

    def __repr__(self) -> str:
        """String representation."""
        params_parts = []
        if self._positional_parameters:
            params_parts.append(f"params={self._positional_parameters}")
        if self._named_parameters:
            params_parts.append(f"named_params={self._named_parameters}")
        params_str = f", {', '.join(params_parts)}" if params_parts else ""

        flags = []
        if self._is_many:
            flags.append("is_many")
        if self._is_script:
            flags.append("is_script")
        flags_str = f", {', '.join(flags)}" if flags else ""

        return f"SQL({self._raw_sql!r}{params_str}{flags_str})"
@mypyc_attr(allow_interpreted_subclasses=False)
class StatementConfig:
    """Configuration for SQL statement processing.

    Groups the switches and collaborators that govern how a statement is
    parsed, validated, transformed and how its parameters are handled.
    """

    __slots__ = SQL_CONFIG_SLOTS

    def __init__(
        self,
        parameter_config: "ParameterStyleConfig | None" = None,
        enable_parsing: bool = True,
        enable_validation: bool = True,
        enable_transformations: bool = True,
        enable_analysis: bool = False,
        enable_expression_simplification: bool = False,
        enable_parameter_type_wrapping: bool = True,
        enable_caching: bool = True,
        parameter_converter: "ParameterConverter | None" = None,
        parameter_validator: "ParameterValidator | None" = None,
        dialect: "DialectType | None" = None,
        pre_process_steps: "list[Any] | None" = None,
        post_process_steps: "list[Any] | None" = None,
        execution_mode: "str | None" = None,
        execution_args: "dict[str, Any] | None" = None,
        output_transformer: "Callable[[str, Any], tuple[str, Any]] | None" = None,
    ) -> None:
        """Initialize StatementConfig.

        Args:
            parameter_config: Parameter style configuration; a falsy value
                becomes a QMARK-only configuration
            enable_parsing: Enable SQL parsing
            enable_validation: Run SQL validators
            enable_transformations: Apply SQL transformers
            enable_analysis: Run SQL analyzers
            enable_expression_simplification: Apply expression simplification
            enable_parameter_type_wrapping: Wrap parameters with type information
            enable_caching: Cache processed SQL statements
            parameter_converter: Handles parameter style conversions; a falsy
                value becomes a new ``ParameterConverter``
            parameter_validator: Validates parameter usage and styles; a falsy
                value becomes a new ``ParameterValidator``
            dialect: SQL dialect
            pre_process_steps: Optional list of preprocessing steps
            post_process_steps: Optional list of postprocessing steps
            execution_mode: Special execution mode
            execution_args: Arguments for special execution modes
            output_transformer: Optional output transformation function
        """
        # Collaborators, with defaults built on demand.
        self.parameter_converter = parameter_converter or ParameterConverter()
        self.parameter_validator = parameter_validator or ParameterValidator()
        self.parameter_config = parameter_config or ParameterStyleConfig(
            default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK}
        )

        # Feature switches.
        self.enable_caching = enable_caching
        self.enable_parameter_type_wrapping = enable_parameter_type_wrapping
        self.enable_expression_simplification = enable_expression_simplification
        self.enable_analysis = enable_analysis
        self.enable_transformations = enable_transformations
        self.enable_validation = enable_validation
        self.enable_parsing = enable_parsing

        # Remaining options, stored as given.
        self.dialect = dialect
        self.pre_process_steps = pre_process_steps
        self.post_process_steps = post_process_steps
        self.execution_mode = execution_mode
        self.execution_args = execution_args
        self.output_transformer = output_transformer

    def replace(self, **kwargs: Any) -> "StatementConfig":
        """Return a new configuration with selected attributes changed.

        Args:
            **kwargs: Attributes to update

        Returns:
            New StatementConfig instance with updated attributes

        Raises:
            TypeError: If a keyword is not a StatementConfig field
        """
        unknown = next((name for name in kwargs if name not in SQL_CONFIG_SLOTS), None)
        if unknown is not None:
            msg = f"{unknown!r} is not a field in {type(self).__name__}"
            raise TypeError(msg)

        # Every slot is also a constructor keyword, so the current state can
        # be collected generically and overlaid with the requested changes.
        merged: dict[str, Any] = {slot: getattr(self, slot) for slot in SQL_CONFIG_SLOTS}
        merged.update(kwargs)
        return type(self)(**merged)

    def __hash__(self) -> int:
        """Hash based on the boolean switches and the stringified dialect."""
        switches = (
            self.enable_parsing,
            self.enable_validation,
            self.enable_transformations,
            self.enable_analysis,
            self.enable_expression_simplification,
            self.enable_parameter_type_wrapping,
            self.enable_caching,
            str(self.dialect),
        )
        return hash(switches)

    def __repr__(self) -> str:
        """String representation of the StatementConfig instance."""
        ordered_fields = (
            "parameter_config",
            "enable_parsing",
            "enable_validation",
            "enable_transformations",
            "enable_analysis",
            "enable_expression_simplification",
            "enable_parameter_type_wrapping",
            "enable_caching",
            "parameter_converter",
            "parameter_validator",
            "dialect",
            "pre_process_steps",
            "post_process_steps",
            "execution_mode",
            "execution_args",
            "output_transformer",
        )
        rendered = ", ".join(f"{field}={getattr(self, field)!r}" for field in ordered_fields)
        return f"{self.__class__.__name__}({rendered})"

    def __eq__(self, other: object) -> bool:
        """Equality comparison.

        The parameter converter and validator are not part of the comparison.
        """
        if not isinstance(other, type(self)):
            return False
        if not self._compare_parameter_configs(self.parameter_config, other.parameter_config):
            return False

        compared_fields = (
            "enable_parsing",
            "enable_validation",
            "enable_transformations",
            "enable_analysis",
            "enable_expression_simplification",
            "enable_parameter_type_wrapping",
            "enable_caching",
            "dialect",
            "pre_process_steps",
            "post_process_steps",
            "execution_mode",
            "execution_args",
            "output_transformer",
        )
        # Stops at the first field that differs.
        return all(getattr(self, field) == getattr(other, field) for field in compared_fields)

    def _compare_parameter_configs(self, config1: Any, config2: Any) -> bool:
        """Compare two parameter configs by default style and both supported-style sets."""
        if not (config1.default_parameter_style == config2.default_parameter_style):
            return False
        if not (config1.supported_parameter_styles == config2.supported_parameter_styles):
            return False
        return bool(config1.supported_execution_parameter_styles == config2.supported_execution_parameter_styles)
def get_default_config() -> StatementConfig:
    """Build a statement configuration using only constructor defaults.

    Returns:
        StatementConfig with default settings
    """
    return StatementConfig()


def get_default_parameter_config() -> ParameterStyleConfig:
    """Build the default parameter configuration.

    Returns:
        ParameterStyleConfig with QMARK as the default and only supported style
    """
    qmark = ParameterStyle.QMARK
    return ParameterStyleConfig(default_parameter_style=qmark, supported_parameter_styles={qmark})


Statement: TypeAlias = str | exp.Expression | SQL