Source code for sqlspec.builder._base

"""Base query builder with validation and parameter binding.

Provides abstract base classes and core functionality for SQL query builders.
"""

import hashlib
import re
import uuid
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Mapping
from typing import Any, NoReturn, cast

import sqlglot
from sqlglot import Dialect, exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.errors import ParseError as SQLGlotParseError
from sqlglot.optimizer import optimize
from typing_extensions import Self

from sqlspec.builder._vector_expressions import VectorDistance
from sqlspec.core import (
    SQL,
    ParameterStyle,
    ParameterStyleConfig,
    SQLResult,
    StatementConfig,
    get_cache,
    get_cache_config,
    hash_optimized_expression,
)
from sqlspec.exceptions import SQLBuilderError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.type_guards import has_expression_and_parameters, has_name, has_with_method, is_expression

__all__ = ("BuiltQuery", "ExpressionBuilder", "QueryBuilder")

# Highest numeric suffix tried when de-duplicating a parameter name; past this
# the builder falls back to a random uuid-derived suffix.
MAX_PARAMETER_COLLISION_ATTEMPTS = 1000
# Matches parameter names of the exact form ``param_<digits>`` and captures the
# digits as ``index``.
PARAMETER_INDEX_PATTERN = re.compile(r"^param_(?P<index>\d+)$")


class _ExpressionParameterizer:
    """Tree-transform callback that swaps literal nodes for bound placeholders.

    Each literal that is replaced has its value registered on the owning
    builder through ``add_parameter_for_expression``.
    """

    __slots__ = ("_builder",)

    def __init__(self, builder: "QueryBuilder") -> None:
        self._builder = builder

    def __call__(self, node: exp.Expression) -> exp.Expression:
        """Return a placeholder for ``node`` when it is a bindable literal, else ``node``."""
        if not isinstance(node, exp.Literal):
            return node

        raw = node.this
        if raw in {True, False, None}:
            return node

        # Literals sitting directly in an array under a VectorDistance node are left inline.
        if isinstance(node.parent, exp.Array) and node.find_ancestor(VectorDistance) is not None:
            return node

        bound_value = raw
        if node.is_number and isinstance(raw, str):
            wants_float = "." in raw or "e" in raw.lower()
            try:
                bound_value = float(raw) if wants_float else int(raw)
            except ValueError:
                # Keep the original text when it does not parse as a number.
                bound_value = raw

        placeholder_name = self._builder.add_parameter_for_expression(bound_value, context="where")
        return exp.Placeholder(this=placeholder_name)


class _PlaceholderReplacer:
    """Tree-transform callback that renames placeholders using a name mapping."""

    __slots__ = ("_param_mapping",)

    def __init__(self, param_mapping: dict[str, str]) -> None:
        self._param_mapping = param_mapping

    def __call__(self, node: exp.Expression) -> exp.Expression:
        """Return a renamed placeholder when ``node`` is a mapped placeholder, else ``node``."""
        if not isinstance(node, exp.Placeholder):
            return node
        current_name = str(node.this)
        if current_name not in self._param_mapping:
            return node
        return exp.Placeholder(this=self._param_mapping[current_name])


def _unquote_identifier(node: exp.Expression) -> exp.Expression:
    """Clear the ``quoted`` flag on identifier nodes; return every node unchanged otherwise."""
    if not isinstance(node, exp.Identifier):
        return node
    node.set("quoted", False)
    return node


logger = get_logger(__name__)


class BuiltQuery:
    """SQL query with bound parameters.

    Attributes are stored in slots:
        sql: The rendered SQL string.
        parameters: Mapping of parameter name to bound value (never ``None``).
        dialect: The dialect the SQL was rendered for, if any.
    """

    __slots__ = ("dialect", "parameters", "sql")

    def __init__(
        self, sql: str, parameters: "dict[str, Any] | None" = None, dialect: "DialectType | None" = None
    ) -> None:
        """Store the SQL text, its parameters and the dialect.

        Args:
            sql: Rendered SQL string.
            parameters: Bound parameters; ``None`` is replaced by a new empty dict.
            dialect: Dialect associated with the SQL.
        """
        self.sql = sql
        self.parameters = parameters if parameters is not None else {}
        self.dialect = dialect

    def __repr__(self) -> str:
        # Only parameter names are shown so bound values do not leak into logs.
        parameter_keys = sorted(self.parameters.keys())
        return f"BuiltQuery(sql={self.sql!r}, parameters={parameter_keys!r}, dialect={self.dialect!r})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, BuiltQuery):
            return NotImplemented
        return self.sql == other.sql and self.parameters == other.parameters and self.dialect == other.dialect

    def __hash__(self) -> int:
        """Hash on SQL, parameters and dialect.

        Bound values may be unhashable (for example a list or a dict). In that
        case the hash falls back to the parameter names only, which is still
        consistent with ``__eq__`` because equal queries share the same names.
        """
        try:
            return hash((self.sql, frozenset(self.parameters.items()), self.dialect))
        except TypeError:
            return hash((self.sql, frozenset(self.parameters), self.dialect))
[docs] class QueryBuilder(ABC): """Abstract base class for SQL query builders. Provides common functionality for dialect handling, parameter management, and query construction using SQLGlot. """ __slots__ = ( "_expression", "_lock_targets_quoted", "_merge_target_quoted", "_parameter_counter", "_parameter_name_counters", "_parameters", "_with_ctes", "dialect", "enable_optimization", "optimize_joins", "optimize_predicates", "schema", "simplify_expressions", )
[docs] def __init__( self, dialect: DialectType | None = None, schema: dict[str, dict[str, str]] | None = None, enable_optimization: bool = True, optimize_joins: bool = True, optimize_predicates: bool = True, simplify_expressions: bool = True, ) -> None: self.dialect = dialect self.schema = schema self.enable_optimization = enable_optimization self.optimize_joins = optimize_joins self.optimize_predicates = optimize_predicates self.simplify_expressions = simplify_expressions self._expression: exp.Expression | None = None self._parameter_name_counters: dict[str, int] = {} self._parameters: dict[str, Any] = {} self._parameter_counter: int = 0 self._with_ctes: dict[str, exp.CTE] = {} self._lock_targets_quoted = False self._merge_target_quoted = False
@classmethod def _parse_query_builder_kwargs( cls, kwargs: "dict[str, Any]" ) -> "tuple[DialectType | None, dict[str, dict[str, str]] | None, bool, bool, bool, bool]": dialect = kwargs.pop("dialect", None) schema = kwargs.pop("schema", None) enable_optimization = kwargs.pop("enable_optimization", True) optimize_joins = kwargs.pop("optimize_joins", True) optimize_predicates = kwargs.pop("optimize_predicates", True) simplify_expressions = kwargs.pop("simplify_expressions", True) if kwargs: unknown = ", ".join(sorted(kwargs.keys())) cls._raise_sql_builder_error(f"Unexpected QueryBuilder arguments: {unknown}") return (dialect, schema, enable_optimization, optimize_joins, optimize_predicates, simplify_expressions) def _initialize_expression(self) -> None: """Initialize the base expression. Called after __init__.""" self._expression = self._create_base_expression() if not self._expression: self._raise_sql_builder_error( "QueryBuilder._create_base_expression must return a valid sqlglot expression." )
[docs] def get_expression(self) -> exp.Expression | None: """Get expression reference (no copy). Returns: The current SQLGlot expression or None if not set """ return self._expression
[docs] def set_expression(self, expression: exp.Expression) -> None: """Set expression with validation. Args: expression: SQLGlot expression to set """ if not is_expression(expression): self._raise_invalid_expression_type(expression) self._expression = expression
[docs] def has_expression(self) -> bool: """Check if expression exists. Returns: True if expression is set, False otherwise """ return self._expression is not None
@abstractmethod def _create_base_expression(self) -> exp.Expression: """Create the base sqlglot expression for the specific query type. Returns: A new sqlglot expression appropriate for the query type. """ @property @abstractmethod def _expected_result_type(self) -> "type[SQLResult]": """The expected result type for the query being built. Returns: type[ResultT]: The type of the result. """ @staticmethod def _raise_sql_builder_error(message: str, cause: BaseException | None = None) -> NoReturn: """Helper to raise SQLBuilderError, potentially with a cause. Args: message: The error message. cause: The optional original exception to chain. Raises: SQLBuilderError: Always raises this exception. """ raise SQLBuilderError(message) from cause @staticmethod def _raise_invalid_expression_type(expression: Any) -> NoReturn: """Raise error for invalid expression type. Args: expression: The invalid expression object Raises: TypeError: Always raised for type mismatch """ msg = f"Expected Expression, got {type(expression)}" raise TypeError(msg) @staticmethod def _raise_cte_query_error(alias: str, message: str) -> NoReturn: """Raise error for CTE query issues. Args: alias: CTE alias name message: Specific error message Raises: SQLBuilderError: Always raised for CTE errors """ msg = f"CTE '{alias}': {message}" raise SQLBuilderError(msg) @staticmethod def _raise_cte_parse_error(cause: BaseException) -> NoReturn: """Raise error for CTE parsing failures. Args: cause: The original parsing exception Raises: SQLBuilderError: Always raised with chained cause """ msg = f"Failed to parse CTE query: {cause!s}" raise SQLBuilderError(msg) from cause def _build_final_expression(self, *, copy: bool = False) -> exp.Expression: """Construct the current expression with attached CTEs. Args: copy: Whether to copy the underlying expression tree before applying transformations. Returns: Expression representing the current builder state with CTEs applied. 
""" if self._expression is None: self._raise_sql_builder_error("QueryBuilder expression not initialized.") base_expression = self._expression.copy() if copy or self._with_ctes else self._expression if not self._with_ctes: return base_expression final_expression: exp.Expression = base_expression if has_with_method(final_expression): for alias, cte_node in self._with_ctes.items(): final_expression = cast("Any", final_expression).with_(cte_node.args["this"], as_=alias, copy=False) return cast("exp.Expression", final_expression) if isinstance(final_expression, (exp.Select, exp.Insert, exp.Update, exp.Delete, exp.Union)): return exp.With(expressions=list(self._with_ctes.values()), this=final_expression) return final_expression def _spawn_like_self(self: Self) -> Self: """Create a new builder instance with matching configuration.""" return type(self)( dialect=self.dialect, schema=self.schema, enable_optimization=self.enable_optimization, optimize_joins=self.optimize_joins, optimize_predicates=self.optimize_predicates, simplify_expressions=self.simplify_expressions, ) def _resolve_cte_query(self, alias: str, query: "QueryBuilder | exp.Select | str") -> exp.Select: """Resolve a CTE query into a Select expression with merged parameters.""" if isinstance(query, QueryBuilder): query_expr = query.get_expression() if query_expr is None: self._raise_cte_query_error(alias, "query builder has no expression") if not isinstance(query_expr, exp.Select): self._raise_cte_query_error(alias, f"expression must be a Select, got {type(query_expr).__name__}") cte_select_expression = query_expr.copy() param_mapping = self._merge_cte_parameters(alias, query.parameters) updated_expression = self._update_placeholders_in_expression(cte_select_expression, param_mapping) if not isinstance(updated_expression, exp.Select): # pragma: no cover - defensive msg = "CTE placeholder update produced non-select expression" raise SQLBuilderError(msg) return updated_expression if isinstance(query, str): try: 
parsed_expression = sqlglot.parse_one(query, read=self.dialect_name) except SQLGlotParseError as e: # pragma: no cover - defensive self._raise_cte_parse_error(e) if not isinstance(parsed_expression, exp.Select): self._raise_cte_query_error( alias, f"query string must parse to SELECT, got {type(parsed_expression).__name__}" ) return parsed_expression if isinstance(query, exp.Select): return query self._raise_cte_query_error(alias, f"invalid query type: {type(query).__name__}") msg = "Unreachable" raise AssertionError(msg) def _add_parameter(self, value: Any, context: str | None = None) -> str: """Adds a parameter to the query and returns its placeholder name. Args: value: The value of the parameter. context: Optional context hint for parameter naming (e.g., "where", "join") Returns: str: The placeholder name for the parameter (e.g., :param_1 or :where_param_1). """ self._parameter_counter += 1 param_name = f"{context}_param_{self._parameter_counter}" if context else f"param_{self._parameter_counter}" self._parameters[param_name] = value return param_name
[docs] def add_parameter_for_expression(self, value: Any, context: str | None = None) -> str: """Add a parameter for expression parameterization. Args: value: The value of the parameter. context: Optional context hint for parameter naming. Returns: Parameter placeholder name. """ return self._add_parameter(value, context=context)
def _parameterize_expression(self, expression: exp.Expression) -> exp.Expression: """Replace literal values in an expression with bound parameters. This method traverses a SQLGlot expression tree and replaces literal values with parameter placeholders, adding the values to the builder's parameter collection. Args: expression: The SQLGlot expression to parameterize Returns: A new expression with literals replaced by parameter placeholders """ return cast("exp.Expression", expression.transform(_ExpressionParameterizer(self), copy=False))
[docs] def add_parameter(self: Self, value: Any, name: str | None = None) -> tuple[Self, str]: """Explicitly adds a parameter to the query. This is useful for parameters that are not directly tied to a builder method like `where` or `values`. Args: value: The value of the parameter. name: Optional explicit name for the parameter. If None, a name will be generated. Returns: tuple[Self, str]: The builder instance and the parameter name. """ if name: if name in self._parameters: self._raise_sql_builder_error(f"Parameter name '{name}' already exists.") self._parameters[name] = value return self, name self._parameter_counter += 1 param_name = f"param_{self._parameter_counter}" self._parameters[param_name] = value return self, param_name
[docs] def load_parameters(self, parameters: "Mapping[str, Any]") -> None: """Load a parameter mapping into the builder. Args: parameters: Mapping of parameter names to values. """ if not parameters: return for name, value in parameters.items(): if name in self._parameters: self._raise_sql_builder_error(f"Parameter name '{name}' already exists.") self._parameters[name] = value self._update_parameter_counter(name)
[docs] def load_ctes(self, ctes: "Iterable[exp.CTE]") -> None: """Load SQLGlot CTE nodes into the builder. Args: ctes: Iterable of CTE expressions to register. """ for cte in ctes: alias = self._resolve_cte_alias(cte) if alias in self._with_ctes: self._raise_sql_builder_error(f"CTE '{alias}' already exists.") self._with_ctes[alias] = cte
def _resolve_cte_alias(self, cte: exp.CTE) -> str: alias_name = cte.alias_or_name if not alias_name: self._raise_sql_builder_error("CTE alias is required.") return str(alias_name) def _update_parameter_counter(self, name: str) -> None: match = PARAMETER_INDEX_PATTERN.match(name) if not match: return index = int(match.group("index")) self._parameter_counter = max(self._parameter_counter, index) def _generate_unique_parameter_name(self, base_name: str) -> str: """Generate unique parameter name when collision occurs. Args: base_name: The desired base name for the parameter Returns: A unique parameter name that doesn't exist in current parameters """ current_index = self._parameter_name_counters.get(base_name, 0) if base_name not in self._parameters: # First use keeps the base name, counter stays at 0 self._parameter_name_counters[base_name] = current_index return base_name next_index = current_index + 1 candidate = f"{base_name}_{next_index}" while candidate in self._parameters: next_index += 1 if next_index > MAX_PARAMETER_COLLISION_ATTEMPTS: return f"{base_name}_{uuid.uuid4().hex[:8]}" candidate = f"{base_name}_{next_index}" self._parameter_name_counters[base_name] = next_index return candidate def _create_placeholder(self, value: Any, base_name: str) -> tuple[exp.Placeholder, str]: """Backwards-compatible placeholder helper (delegates to create_placeholder).""" return self.create_placeholder(value, base_name)
[docs] def create_placeholder(self, value: Any, base_name: str) -> tuple[exp.Placeholder, str]: """Create placeholder expression with a unique parameter name. Args: value: Parameter value to bind. base_name: Seed for parameter naming. Returns: Tuple of placeholder expression and the final parameter name. """ param_name = self._generate_unique_parameter_name(base_name) _, param_name = self.add_parameter(value, name=param_name) return exp.Placeholder(this=param_name), param_name
def _merge_cte_parameters(self, cte_name: str, parameters: dict[str, Any]) -> dict[str, str]: """Merge CTE parameters with unique naming to prevent collisions. Args: cte_name: The name of the CTE for parameter prefixing parameters: The CTE's parameter dictionary Returns: Mapping of old parameter names to new unique names """ param_mapping = {} for old_name, value in parameters.items(): new_name = self._generate_unique_parameter_name(f"{cte_name}_{old_name}") param_mapping[old_name] = new_name self.add_parameter(value, name=new_name) return param_mapping def _update_placeholders_in_expression( self, expression: exp.Expression, param_mapping: dict[str, str] ) -> exp.Expression: """Update parameter placeholders in expression to use new names. Args: expression: The SQLGlot expression to update param_mapping: Mapping of old parameter names to new names Returns: Updated expression with new placeholder names """ return cast("exp.Expression", expression.transform(_PlaceholderReplacer(param_mapping), copy=False)) def _generate_builder_cache_key(self, config: "StatementConfig | None" = None) -> str: """Generate cache key based on builder state and configuration. 
Args: config: Optional SQL configuration that affects the generated SQL Returns: A unique cache key representing the builder state and configuration """ dialect_name: str = self.dialect_name or "default" if self._expression is None: self._expression = self._create_base_expression() if self._expression: expr_sql = self._expression.sql() expr_hash = hashlib.blake2b(expr_sql.encode(), digest_size=8).hexdigest() else: expr_hash = "None" parameters_snapshot = sorted(self._parameters.items()) parameters_hash = hashlib.sha256(str(parameters_snapshot).encode()).hexdigest()[:8] state_parts = [ f"expression_hash:{expr_hash}", f"parameters_hash:{parameters_hash}", f"ctes:{sorted(self._with_ctes.keys())}", f"dialect:{dialect_name}", f"schema_hash:{hashlib.sha256(str(self.schema).encode()).hexdigest()[:8]}", f"optimization:{self.enable_optimization}", f"optimize_joins:{self.optimize_joins}", f"optimize_predicates:{self.optimize_predicates}", f"simplify_expressions:{self.simplify_expressions}", ] if config: config_parts = [ f"config_dialect:{config.dialect or 'default'}", f"enable_parsing:{config.enable_parsing}", f"enable_validation:{config.enable_validation}", f"enable_transformations:{config.enable_transformations}", f"enable_analysis:{config.enable_analysis}", f"enable_caching:{config.enable_caching}", f"param_style:{config.parameter_config.default_parameter_style.value}", ] state_parts.extend(config_parts) state_string = "|".join(state_parts) return f"builder:{hashlib.sha256(state_string.encode()).hexdigest()[:16]}"
[docs] def with_cte(self: Self, alias: str, query: "QueryBuilder | exp.Select | str") -> Self: """Adds a Common Table Expression (CTE) to the query. Args: alias: The alias for the CTE. query: The CTE query, which can be another QueryBuilder instance, a raw SQL string, or a sqlglot Select expression. Returns: Self: The current builder instance for method chaining. """ if alias in self._with_ctes: self._raise_sql_builder_error(f"CTE with alias '{alias}' already exists.") cte_select_expression = self._resolve_cte_query(alias, query) self._with_ctes[alias] = exp.CTE(this=cte_select_expression, alias=exp.to_table(alias)) return self
[docs] def build(self, dialect: DialectType = None) -> "BuiltQuery": """Builds the SQL query string and parameters. Args: dialect: Optional dialect override. If provided, generates SQL for this dialect instead of the builder's default dialect. Returns: BuiltQuery: A dataclass containing the SQL string and parameters. Examples: # Use builder's default dialect query = sql.select("*").from_("products") result = query.build() # Override dialect at build time postgres_sql = query.build(dialect="postgres") mysql_sql = query.build(dialect="mysql") """ final_expression = self._build_final_expression() if self.enable_optimization and isinstance(final_expression, exp.Expression): final_expression = self._optimize_expression(final_expression) target_dialect = str(dialect) if dialect else self.dialect_name try: if isinstance(final_expression, exp.Expression): normalized_expression = ( self._unquote_identifiers_for_oracle(final_expression) if self._is_oracle_dialect(target_dialect) else final_expression ) identify = self._should_identify(target_dialect) sql_string = normalized_expression.sql(dialect=target_dialect, pretty=True, identify=identify) sql_string = self._strip_lock_identifier_quotes(sql_string) else: sql_string = str(final_expression) except Exception as e: err_msg = f"Error generating SQL from expression: {e!s}" self._raise_sql_builder_error(err_msg, e) return BuiltQuery(sql=sql_string, parameters=self._parameters.copy(), dialect=dialect or self.dialect)
[docs] def to_sql(self, show_parameters: bool = False, dialect: DialectType = None) -> str: """Return SQL string with optional parameter substitution. Args: show_parameters: If True, replace parameter placeholders with actual values (for debugging). If False (default), return SQL with parameter placeholders. dialect: Optional dialect override. If provided, generates SQL for this dialect instead of the builder's default dialect. Returns: SQL string with or without parameter values filled in Examples: Get SQL with placeholders (for execution): sql_str = query.to_sql() # "SELECT * FROM products WHERE id = :id" Get SQL with values (for debugging): sql_str = query.to_sql(show_parameters=True) # "SELECT * FROM products WHERE id = 123" Override dialect at output time: postgres_sql = query.to_sql(dialect="postgres") mysql_sql = query.to_sql(dialect="mysql") Warning: SQL with show_parameters=True is for debugging ONLY. Never execute SQL with interpolated parameters directly - use parameterized queries. """ safe_query = self.build(dialect=dialect) if not show_parameters: return safe_query.sql sql = safe_query.sql parameters = safe_query.parameters for param_name, param_value in parameters.items(): placeholder = f":{param_name}" if isinstance(param_value, str): replacement = f"'{param_value}'" elif param_value is None: replacement = "NULL" elif isinstance(param_value, bool): replacement = "TRUE" if param_value else "FALSE" else: replacement = str(param_value) sql = sql.replace(placeholder, replacement) return sql
def _optimize_expression(self, expression: exp.Expression) -> exp.Expression: """Apply SQLGlot optimizations to the expression. Args: expression: The expression to optimize Returns: The optimized expression """ if not self.enable_optimization: return expression if not self.optimize_joins and not self.optimize_predicates and not self.simplify_expressions: return expression optimizer_settings = { "optimize_joins": self.optimize_joins, "pushdown_predicates": self.optimize_predicates, "simplify_expressions": self.simplify_expressions, } dialect_name = self.dialect_name or "default" cache_key = hash_optimized_expression( expression, dialect=dialect_name, schema=self.schema, optimizer_settings=optimizer_settings ) cache = get_cache() cached_optimized = cache.get_optimized(cache_key) if cached_optimized: return cast("exp.Expression", cached_optimized) try: optimized = optimize( expression, schema=self.schema, dialect=self.dialect_name, optimizer_settings=optimizer_settings ) cache.put_optimized(cache_key, optimized) except Exception: logger.debug("Expression optimization failed, using original expression") return expression else: return optimized
[docs] def to_statement(self, config: "StatementConfig | None" = None) -> "SQL": """Converts the built query into a SQL statement object. Args: config: Optional SQL configuration. Returns: SQL: A SQL statement object. """ cache_config = get_cache_config() if not cache_config.compiled_cache_enabled: return self._to_statement(config) cache_key_str = self._generate_builder_cache_key(config) cache = get_cache() cached_sql = cache.get_builder(cache_key_str) if cached_sql is not None: return cast("SQL", cached_sql) sql_statement = self._to_statement(config) cache.put_builder(cache_key_str, sql_statement) return sql_statement
def _to_statement(self, config: "StatementConfig | None" = None) -> "SQL": """Internal method to create SQL statement. Args: config: Optional SQL configuration. Returns: SQL: A SQL statement object. """ dialect_override = config.dialect if config else None safe_query = self.build(dialect=dialect_override) kwargs, parameters = self._extract_statement_parameters(safe_query.parameters) if config is None: config = StatementConfig( parameter_config=ParameterStyleConfig( default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK} ), dialect=safe_query.dialect, ) sql_string = safe_query.sql if ( config.dialect is not None and config.dialect != safe_query.dialect and isinstance(self._expression, exp.Expression) ): try: identify = self._should_identify(config.dialect) sql_string = self._expression.sql(dialect=config.dialect, pretty=True, identify=identify) except Exception: sql_string = safe_query.sql if kwargs: return SQL(sql_string, statement_config=config, **kwargs) if parameters: return SQL(sql_string, *parameters, statement_config=config) return SQL(sql_string, statement_config=config) def _extract_statement_parameters( self, raw_parameters: Any ) -> "tuple[dict[str, Any] | None, tuple[Any, ...] | None]": """Extract parameters for SQL statement creation. Args: raw_parameters: Raw parameter data from BuiltQuery Returns: Tuple of (kwargs, parameters) for SQL statement construction """ if isinstance(raw_parameters, dict): return raw_parameters, None if isinstance(raw_parameters, tuple): return None, raw_parameters if raw_parameters: return None, tuple(raw_parameters) return None, None
[docs] def __str__(self) -> str: """Return the SQL string representation of the query. Returns: str: The SQL string for this query. """ return self.build().sql
@property def dialect_name(self) -> "str | None": """Returns the name of the dialect, if set.""" if isinstance(self.dialect, str): return self.dialect if self.dialect is None: return None if isinstance(self.dialect, type) and issubclass(self.dialect, Dialect): return self.dialect.__name__.lower() if isinstance(self.dialect, Dialect): return type(self.dialect).__name__.lower() if has_name(self.dialect): return self.dialect.__name__.lower() return str(self.dialect).lower() def _merge_sql_object_parameters(self, sql_obj: Any) -> None: """Merge parameters from a SQL object into the builder. Args: sql_obj: Object with parameters attribute containing parameter mappings """ if not has_expression_and_parameters(sql_obj): return sql_parameters = sql_obj.parameters for param_name, param_value in sql_parameters.items(): unique_name = self._generate_unique_parameter_name(param_name) self.add_parameter(param_value, name=unique_name) @property def parameters(self) -> dict[str, Any]: """Public access to query parameters.""" return self._parameters
[docs] def set_parameters(self, parameters: dict[str, Any]) -> None: """Set query parameters (public API).""" self._parameters = parameters.copy()
def _is_oracle_dialect(self, dialect: "DialectType | str | None") -> bool: """Check if target dialect is Oracle.""" if dialect is None: return False return str(dialect).lower() == "oracle" def _unquote_identifiers_for_oracle(self, expression: exp.Expression) -> exp.Expression: """Remove identifier quoting to avoid Oracle case-sensitive lookup issues.""" return cast("exp.Expression", expression.copy().transform(_unquote_identifier, copy=False)) def _strip_lock_identifier_quotes(self, sql_string: str) -> str: for keyword in ("FOR UPDATE OF ", "FOR SHARE OF "): if keyword in sql_string and not self._lock_targets_quoted: head, tail = sql_string.split(keyword, 1) tail = tail.replace('"', "") return f"{head}{keyword}{tail}" if sql_string.startswith('MERGE INTO "') and not self._merge_target_quoted: # Remove quotes around target table only, leave alias/rest intact end_quote = sql_string.find('"', len('MERGE INTO "')) if end_quote > 0: table_name = sql_string[len('MERGE INTO "') : end_quote] remainder = sql_string[end_quote + 1 :] return f"MERGE INTO {table_name}{remainder}" return sql_string def _should_identify(self, dialect: "DialectType | str | None") -> bool: """Determine whether to quote identifiers for the given dialect.""" if dialect is None: return True dialect_name = str(dialect).lower() # Oracle folds unquoted identifiers to uppercase; quoting lower-case breaks table lookup return dialect_name != "oracle" @property def with_ctes(self) -> "dict[str, exp.CTE]": """Get WITH clause CTEs (public API).""" return dict(self._with_ctes)
[docs] def generate_unique_parameter_name(self, base_name: str) -> str: """Generate unique parameter name (public API).""" return self._generate_unique_parameter_name(base_name)
[docs] def build_static_expression( self, expression: exp.Expression | None = None, parameters: dict[str, Any] | None = None, *, cache_key: str | None = None, expression_factory: Callable[[], exp.Expression] | None = None, copy: bool = True, optimize_expression: bool | None = None, dialect: DialectType | None = None, ) -> "BuiltQuery": """Compile a pre-built expression with optional caching and parameters. Designed for hot paths that construct an AST once and reuse it with different parameters, avoiding repeated parse/optimize cycles. Args: expression: Pre-built sqlglot expression to render (required when cache_key is not provided). parameters: Optional parameter mapping to include in the result. cache_key: When provided, the expression will be cached under this key. expression_factory: Factory used to build the expression on cache miss. copy: Copy the expression before rendering to avoid caller mutation. optimize_expression: Override builder optimization toggle for this call. dialect: Optional dialect override for SQL generation. Returns: BuiltQuery containing SQL and parameters. 
""" expr: exp.Expression | None = None if cache_key is not None: cache = get_cache() cached_expr = cache.get_expression(cache_key) if cached_expr is None: if expression_factory is None: msg = "expression_factory is required when cache_key is provided" self._raise_sql_builder_error(msg) expr_candidate = expression_factory() if not is_expression(expr_candidate): self._raise_invalid_expression_type(expr_candidate) expr_to_store = expr_candidate.copy() if copy else expr_candidate should_optimize = self.enable_optimization if optimize_expression is None else optimize_expression if should_optimize: expr_to_store = self._optimize_expression(expr_to_store) cache.put_expression(cache_key, expr_to_store) cached_expr = expr_to_store expr = cached_expr.copy() if copy else cached_expr else: if expression is None: msg = "expression must be provided when cache_key is not set" self._raise_sql_builder_error(msg) expr = expression.copy() if copy else expression should_optimize = self.enable_optimization if optimize_expression is None else optimize_expression if should_optimize: expr = self._optimize_expression(expr) if expr is None: self._raise_sql_builder_error("Static expression could not be resolved.") target_dialect = str(dialect) if dialect else self.dialect_name identify = self._should_identify(target_dialect) sql_string = expr.sql(dialect=target_dialect, pretty=True, identify=identify) return BuiltQuery( sql=sql_string, parameters=parameters.copy() if parameters else {}, dialect=dialect or self.dialect )
class ExpressionBuilder(QueryBuilder):
    """Query builder that wraps an already-constructed SQLGlot expression."""

    __slots__ = ()

    def __init__(self, expression: exp.Expression, **kwargs: Any) -> None:
        """Validate ``expression`` and configure the builder from ``kwargs``.

        Args:
            expression: The SQLGlot expression this builder wraps.
            **kwargs: Builder options accepted by ``_parse_query_builder_kwargs``.
        """
        parsed_options = self._parse_query_builder_kwargs(kwargs)
        (
            dialect,
            schema,
            enable_optimization,
            optimize_joins,
            optimize_predicates,
            simplify_expressions,
        ) = parsed_options
        super().__init__(
            dialect=dialect,
            schema=schema,
            enable_optimization=enable_optimization,
            optimize_joins=optimize_joins,
            optimize_predicates=optimize_predicates,
            simplify_expressions=simplify_expressions,
        )
        if not is_expression(expression):
            self._raise_invalid_expression_type(expression)
        self._expression = expression

    def _create_base_expression(self) -> exp.Expression:
        """Return the wrapped expression, raising if none was stored."""
        if self._expression is None:
            self._raise_sql_builder_error("ExpressionBuilder requires an expression at construction.")
        return self._expression

    @property
    def _expected_result_type(self) -> "type[SQLResult]":
        """Result type produced by this builder."""
        return SQLResult