Source code for sqlspec.core.compiler

"""SQL compilation and caching.

Components:
- CompiledSQL: Immutable compilation result
- SQLProcessor: SQL compiler with caching
- Parameter processing via ParameterProcessor
"""

import hashlib
from collections import OrderedDict
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Literal, Optional

import sqlglot
from mypy_extensions import mypyc_attr
from sqlglot import expressions as exp
from sqlglot.errors import ParseError

import sqlspec.exceptions
from sqlspec.core.parameters import ParameterProcessor, ParameterProfile, validate_parameter_alignment
from sqlspec.utils.logging import get_logger

if TYPE_CHECKING:
    import logging

    from sqlspec.core.statement import StatementConfig


__all__ = (
    "CompiledSQL",
    "OperationProfile",
    "OperationType",
    "SQLProcessor",
    "is_copy_from_operation",
    "is_copy_operation",
    "is_copy_to_operation",
)

logger: "logging.Logger" = get_logger("sqlspec.core.compiler")
OperationType = Literal[
    "SELECT",
    "INSERT",
    "UPDATE",
    "DELETE",
    "COPY",
    "COPY_FROM",
    "COPY_TO",
    "EXECUTE",
    "SCRIPT",
    "DDL",
    "PRAGMA",
    "MERGE",
    "UNKNOWN",
]

OPERATION_TYPE_MAP: "dict[type[exp.Expression], OperationType]" = {
    exp.Select: "SELECT",
    exp.Insert: "INSERT",
    exp.Update: "UPDATE",
    exp.Delete: "DELETE",
    exp.Pragma: "PRAGMA",
    exp.Command: "EXECUTE",
    exp.Create: "DDL",
    exp.Drop: "DDL",
    exp.Alter: "DDL",
    exp.Merge: "MERGE",
}

COPY_OPERATION_TYPES: "tuple[OperationType, ...]" = ("COPY", "COPY_FROM", "COPY_TO")

COPY_FROM_OPERATION_TYPES: "tuple[OperationType, ...]" = ("COPY", "COPY_FROM")

COPY_TO_OPERATION_TYPES: "tuple[OperationType, ...]" = ("COPY_TO",)


def is_copy_operation(operation_type: "OperationType") -> bool:
    """Determine if the operation corresponds to any PostgreSQL COPY variant.

    Args:
        operation_type: Operation type detected by the compiler.

    Returns:
        True when the operation type represents COPY, COPY FROM, or COPY TO.
    """

    return operation_type in COPY_OPERATION_TYPES


def is_copy_from_operation(operation_type: "OperationType") -> bool:
    """Check if the operation streams data into the database using COPY.

    Args:
        operation_type: Operation type detected by the compiler.

    Returns:
        True for COPY operations that read from client input (COPY FROM).
    """

    return operation_type in COPY_FROM_OPERATION_TYPES


def is_copy_to_operation(operation_type: "OperationType") -> bool:
    """Check if the operation streams data out from the database using COPY.

    Args:
        operation_type: Operation type detected by the compiler.

    Returns:
        True for COPY operations that write to client output (COPY TO).
    """

    return operation_type in COPY_TO_OPERATION_TYPES


@mypyc_attr(allow_interpreted_subclasses=False)
class OperationProfile:
    """Semantic characteristics derived from the parsed SQL expression."""

    __slots__ = ("modifies_rows", "returns_rows")

    def __init__(self, returns_rows: bool = False, modifies_rows: bool = False) -> None:
        self.returns_rows = returns_rows
        self.modifies_rows = modifies_rows

    @classmethod
    def empty(cls) -> "OperationProfile":
        return cls(returns_rows=False, modifies_rows=False)

    def __repr__(self) -> str:
        return f"OperationProfile(returns_rows={self.returns_rows!r}, modifies_rows={self.modifies_rows!r})"


def _is_effectively_empty_parameters(value: Any) -> bool:
    if value is None:
        return True
    if isinstance(value, Mapping):
        return len(value) == 0
    if isinstance(value, (list, tuple, set, frozenset)):
        return len(value) == 0
    return False


[docs] @mypyc_attr(allow_interpreted_subclasses=False) class CompiledSQL: """Compiled SQL result. Contains the result of SQL compilation with information needed for execution. Immutable container holding compiled SQL text, processed parameters, operation type, and execution metadata. """ __slots__ = ( "_hash", "compiled_sql", "execution_parameters", "expression", "operation_profile", "operation_type", "parameter_casts", "parameter_profile", "parameter_style", "supports_many", ) operation_type: "OperationType"
[docs] def __init__( self, compiled_sql: str, execution_parameters: Any, operation_type: "OperationType", expression: Optional["exp.Expression"] = None, parameter_style: str | None = None, supports_many: bool = False, parameter_casts: Optional["dict[int, str]"] = None, parameter_profile: "ParameterProfile | None" = None, operation_profile: "OperationProfile | None" = None, ) -> None: """Initialize compiled result. Args: compiled_sql: SQL string ready for execution execution_parameters: Parameters in driver-specific format operation_type: SQL operation type (SELECT, INSERT, etc.) expression: SQLGlot AST expression parameter_style: Parameter style used in compilation supports_many: Whether this supports execute_many operations parameter_casts: Mapping of parameter positions to cast types parameter_profile: Profile describing detected placeholders operation_profile: Profile describing semantic characteristics """ self.compiled_sql = compiled_sql self.execution_parameters = execution_parameters self.operation_type = operation_type self.expression = expression self.parameter_style = parameter_style self.supports_many = supports_many self.parameter_casts = parameter_casts or {} self.parameter_profile = parameter_profile self.operation_profile = operation_profile or OperationProfile.empty() self._hash: int | None = None
[docs] def __hash__(self) -> int: """Cached hash value.""" if self._hash is None: param_str = str(self.execution_parameters) self._hash = hash((self.compiled_sql, param_str, self.operation_type, self.parameter_style)) return self._hash
[docs] def __eq__(self, other: object) -> bool: """Equality comparison.""" if not isinstance(other, CompiledSQL): return False return ( self.compiled_sql == other.compiled_sql and self.execution_parameters == other.execution_parameters and self.operation_type == other.operation_type and self.parameter_style == other.parameter_style )
[docs] def __repr__(self) -> str: """String representation.""" return ( f"CompiledSQL(sql={self.compiled_sql!r}, " f"params={self.execution_parameters!r}, " f"type={self.operation_type!r})" )
[docs] @mypyc_attr(allow_interpreted_subclasses=False) class SQLProcessor: """SQL processor with compilation and caching. Processes SQL statements by compiling them into executable format with parameter substitution. Includes LRU-style caching for compilation results to avoid re-processing identical statements. """ __slots__ = ("_cache", "_cache_hits", "_cache_misses", "_config", "_max_cache_size", "_parameter_processor")
[docs] def __init__(self, config: "StatementConfig", max_cache_size: int = 1000) -> None: """Initialize processor. Args: config: Statement configuration max_cache_size: Maximum number of compilation results to cache """ self._config = config self._cache: OrderedDict[str, CompiledSQL] = OrderedDict() self._parameter_processor = ParameterProcessor() self._max_cache_size = max_cache_size self._cache_hits = 0 self._cache_misses = 0
[docs] def compile(self, sql: str, parameters: Any = None, is_many: bool = False) -> CompiledSQL: """Compile SQL statement. Args: sql: SQL string for compilation parameters: Parameter values for substitution is_many: Whether this is for execute_many operation Returns: CompiledSQL with execution information """ if not self._config.enable_caching: return self._compile_uncached(sql, parameters, is_many) cache_key = self._make_cache_key(sql, parameters, is_many) if cache_key in self._cache: result = self._cache[cache_key] del self._cache[cache_key] self._cache[cache_key] = result self._cache_hits += 1 return result self._cache_misses += 1 result = self._compile_uncached(sql, parameters, is_many) if len(self._cache) >= self._max_cache_size: self._cache.popitem(last=False) self._cache[cache_key] = result return result
def _compile_uncached(self, sql: str, parameters: Any, is_many: bool = False) -> CompiledSQL: """Compile SQL without caching. Args: sql: SQL string parameters: Parameter values is_many: Whether this is for execute_many operation Returns: CompiledSQL result """ parameter_profile = ParameterProfile.empty() operation_profile = OperationProfile.empty() try: dialect_str = str(self._config.dialect) if self._config.dialect else None processed_sql: str processed_params: Any process_result = self._parameter_processor.process( sql=sql, parameters=parameters, config=self._config.parameter_config, dialect=dialect_str, is_many=is_many, ) processed_sql = process_result.sql processed_params = process_result.parameters parameter_profile = process_result.parameter_profile if self._config.parameter_config.needs_static_script_compilation and processed_params is None: sqlglot_sql = processed_sql else: sqlglot_sql, _ = self._parameter_processor.get_sqlglot_compatible_sql( sql, parameters, self._config.parameter_config, dialect_str ) final_parameters = processed_params ast_was_transformed = False expression = None operation_type: OperationType = "EXECUTE" parameter_casts: dict[int, str] = {} if self._config.enable_parsing: try: expression = sqlglot.parse_one(sqlglot_sql, dialect=dialect_str) operation_type = self._detect_operation_type(expression) parameter_casts = self._detect_parameter_casts(expression) operation_profile = self._build_operation_profile(expression, operation_type) ast_transformer = self._config.parameter_config.ast_transformer if ast_transformer: expression, final_parameters = ast_transformer(expression, processed_params) ast_was_transformed = True except ParseError: expression = None operation_type = "EXECUTE" parameter_casts = {} operation_profile = OperationProfile.empty() if self._config.parameter_config.needs_static_script_compilation and processed_params is None: final_sql, final_params = processed_sql, processed_params elif ast_was_transformed and expression is not None: transformed_result = self._parameter_processor.process( sql=expression.sql(dialect=dialect_str), parameters=final_parameters, config=self._config.parameter_config, dialect=dialect_str, is_many=is_many, ) final_sql = transformed_result.sql final_params = transformed_result.parameters parameter_profile = transformed_result.parameter_profile output_transformer = self._config.output_transformer if output_transformer: final_sql, final_params = output_transformer(final_sql, final_params) else: final_sql, final_params = self._apply_final_transformations( expression, processed_sql, final_parameters, dialect_str ) should_validate = True if ( _is_effectively_empty_parameters(final_params) and _is_effectively_empty_parameters(parameters) and not is_many ): should_validate = False if self._config.enable_validation and should_validate: validate_parameter_alignment(parameter_profile, final_params, is_many=is_many) return CompiledSQL( compiled_sql=final_sql, execution_parameters=final_params, operation_type=operation_type, expression=expression, parameter_style=self._config.parameter_config.default_parameter_style.value, supports_many=isinstance(final_params, list) and len(final_params) > 0, parameter_casts=parameter_casts, parameter_profile=parameter_profile, operation_profile=operation_profile, ) except sqlspec.exceptions.SQLSpecError: raise except Exception as e: logger.warning("Compilation failed, using fallback: %s", e) return CompiledSQL( compiled_sql=sql, execution_parameters=parameters, operation_type="UNKNOWN", parameter_casts={}, parameter_profile=parameter_profile, operation_profile=operation_profile, ) def _make_cache_key(self, sql: str, parameters: Any, is_many: bool = False) -> str: """Generate cache key. Args: sql: SQL string parameters: Parameter values is_many: Whether this is for execute_many operation Returns: Cache key string """ param_repr = repr(parameters) dialect_str = str(self._config.dialect) if self._config.dialect else None param_style = self._config.parameter_config.default_parameter_style.value hash_data = ( sql, param_repr, param_style, dialect_str, self._config.enable_parsing, self._config.enable_transformations, is_many, ) hash_str = hashlib.sha256(str(hash_data).encode("utf-8")).hexdigest()[:16] return f"sql_{hash_str}" def _detect_operation_type(self, expression: "exp.Expression") -> "OperationType": """Detect operation type from AST. Args: expression: AST expression Returns: Operation type literal """ expr_type = type(expression) if expr_type in OPERATION_TYPE_MAP: return OPERATION_TYPE_MAP[expr_type] # pyright: ignore if isinstance(expression, exp.Copy): copy_kind = expression.args.get("kind") if copy_kind is True: return "COPY_FROM" if copy_kind is False: return "COPY_TO" return "COPY" return "UNKNOWN" def _detect_parameter_casts(self, expression: Optional["exp.Expression"]) -> "dict[int, str]": """Detect explicit type casts on parameters in the AST. Args: expression: SQLGlot AST expression to analyze Returns: Dict mapping parameter positions (1-based) to cast type names """ if not expression: return {} cast_positions: dict[int, str] = {} placeholder_positions: dict[str, int] = {} placeholder_counter = 0 def _assign_placeholder_position(placeholder: "exp.Placeholder") -> "int | None": nonlocal placeholder_counter name_expr = placeholder.name if placeholder.name is not None else None if name_expr is not None: placeholder_key = str(name_expr) else: value = placeholder.args.get("this") placeholder_key = str(value) if value is not None else placeholder.sql() if not placeholder_key: return None if placeholder_key not in placeholder_positions: placeholder_counter += 1 placeholder_positions[placeholder_key] = placeholder_counter return placeholder_positions[placeholder_key] # Walk all nodes in order to track parameter positions for node in expression.walk(): if isinstance(node, exp.Placeholder): _assign_placeholder_position(node) # Check for cast nodes with parameter children if isinstance(node, exp.Cast): cast_target = node.this position = None if isinstance(cast_target, exp.Parameter): # Handle $1, $2 style parameters param_value = cast_target.this if isinstance(param_value, exp.Literal): position = int(param_value.this) elif isinstance(cast_target, exp.Placeholder): position = _assign_placeholder_position(cast_target) elif isinstance(cast_target, exp.Column): # Handle cases where $1 gets parsed as a column column_name = str(cast_target.this) if cast_target.this else str(cast_target) if column_name.startswith("$") and column_name[1:].isdigit(): position = int(column_name[1:]) if position is not None: # Extract cast type if isinstance(node.to, exp.DataType): cast_type = node.to.this.value if hasattr(node.to.this, "value") else str(node.to.this) else: cast_type = str(node.to) cast_positions[position] = cast_type.upper() return cast_positions def _apply_final_transformations( self, expression: "exp.Expression | None", sql: str, parameters: Any, dialect_str: "str | None" ) -> "tuple[str, Any]": """Apply final transformations. Args: expression: SQLGlot AST expression sql: SQL string parameters: Execution parameters dialect_str: SQL dialect Returns: Tuple of (final_sql, final_parameters) """ output_transformer = self._config.output_transformer if output_transformer: if expression is not None: ast_sql = expression.sql(dialect=dialect_str) return output_transformer(ast_sql, parameters) return output_transformer(sql, parameters) return sql, parameters def _build_operation_profile( self, expression: "exp.Expression | None", operation_type: "OperationType" ) -> "OperationProfile": if expression is None: return OperationProfile.empty() returns_rows = False modifies_rows = False expr = expression if isinstance(expr, (exp.Select, exp.Values, exp.Table, exp.TableSample, exp.With)): returns_rows = True elif isinstance(expr, (exp.Insert, exp.Update, exp.Delete, exp.Merge)): modifies_rows = True returns_rows = bool(expr.args.get("returning")) elif isinstance(expr, exp.Copy): copy_kind = expr.args.get("kind") modifies_rows = copy_kind is True returns_rows = copy_kind is False if not returns_rows and operation_type in {"SELECT", "WITH", "VALUES", "TABLE"}: returns_rows = True if not modifies_rows and operation_type in {"INSERT", "UPDATE", "DELETE", "MERGE"}: modifies_rows = True return OperationProfile(returns_rows=returns_rows, modifies_rows=modifies_rows)
[docs] def clear_cache(self) -> None: """Clear compilation cache and reset statistics.""" self._cache.clear() self._cache_hits = 0 self._cache_misses = 0
@property def cache_stats(self) -> "dict[str, int]": """Get cache statistics. Returns: Dictionary with cache statistics """ total_requests = self._cache_hits + self._cache_misses hit_rate_pct = int((self._cache_hits / total_requests) * 100) if total_requests > 0 else 0 return { "hits": self._cache_hits, "misses": self._cache_misses, "size": len(self._cache), "max_size": self._max_cache_size, "hit_rate_percent": hit_rate_pct, }