# Source code for sqlspec.core.parameters._converter

"""Parameter style conversion utilities."""

from collections.abc import Callable, Mapping, Sequence
from typing import Any

from mypy_extensions import mypyc_attr

from sqlspec.core.parameters._types import ParameterInfo, ParameterStyle
from sqlspec.core.parameters._validator import ParameterValidator

# Public API of this module: only the converter class is exported.
__all__ = ("ParameterConverter",)


@mypyc_attr(allow_interpreted_subclasses=False)
class ParameterConverter:
    """Parameter style conversion helper.

    Rewrites SQL placeholders from one :class:`ParameterStyle` to another and
    reshapes the accompanying parameter payload (sequence or mapping) so that it
    lines up with the rewritten SQL. Placeholder discovery is delegated to
    :class:`ParameterValidator`.
    """

    __slots__ = ("_format_converters", "_placeholder_generators", "validator")

    def __init__(self) -> None:
        """Create the validator and the per-style dispatch tables."""
        self.validator = ParameterValidator()
        # Target style -> routine that reshapes a parameter payload for that style.
        self._format_converters = {
            ParameterStyle.POSITIONAL_COLON: self._convert_to_positional_colon_format,
            ParameterStyle.NAMED_COLON: self._convert_to_named_colon_format,
            ParameterStyle.NAMED_PYFORMAT: self._convert_to_named_pyformat_format,
            ParameterStyle.QMARK: self._convert_to_positional_format,
            ParameterStyle.NUMERIC: self._convert_to_positional_format,
            ParameterStyle.POSITIONAL_PYFORMAT: self._convert_to_positional_format,
            ParameterStyle.NAMED_AT: self._convert_to_named_colon_format,
            ParameterStyle.NAMED_DOLLAR: self._convert_to_named_colon_format,
        }
        # Target style -> renderer for one placeholder. Positional renderers receive
        # a zero-based index; named renderers receive the parameter name.
        self._placeholder_generators: dict[ParameterStyle, Callable[[Any], str]] = {
            ParameterStyle.QMARK: lambda _: "?",
            ParameterStyle.NUMERIC: lambda i: f"${int(i) + 1}",
            ParameterStyle.NAMED_COLON: lambda name: f":{name}",
            ParameterStyle.POSITIONAL_COLON: lambda i: f":{int(i) + 1}",
            ParameterStyle.NAMED_AT: lambda name: f"@{name}",
            ParameterStyle.NAMED_DOLLAR: lambda name: f"${name}",
            ParameterStyle.NAMED_PYFORMAT: lambda name: f"%({name})s",
            ParameterStyle.POSITIONAL_PYFORMAT: lambda _: "%s",
        }

    def normalize_sql_for_parsing(self, sql: str, dialect: str | None = None) -> "tuple[str, list[ParameterInfo]]":
        """Rewrite placeholders sqlglot cannot parse for ``dialect`` into ``:param_N`` form.

        Args:
            sql: SQL text to inspect.
            dialect: Dialect name forwarded to the validator to decide which
                placeholder styles are incompatible.

        Returns:
            The (possibly rewritten) SQL and the placeholder metadata extracted
            from the *original* ``sql``.
        """
        param_info = self.validator.extract_parameters(sql)
        incompatible_styles = self.validator.get_sqlglot_incompatible_styles(dialect)
        needs_conversion = any(p.style in incompatible_styles for p in param_info)
        if not needs_conversion:
            return sql, param_info
        converted_sql = self._convert_to_sqlglot_compatible(sql, param_info, incompatible_styles)
        return converted_sql, param_info

    def _convert_to_sqlglot_compatible(
        self, sql: str, param_info: "list[ParameterInfo]", incompatible_styles: "set[ParameterStyle]"
    ) -> str:
        """Replace every placeholder whose style is in ``incompatible_styles`` with ``:param_<ordinal>``."""
        edits = [(param, f":param_{param.ordinal}") for param in param_info if param.style in incompatible_styles]
        return self._splice_placeholders(sql, edits)

    def convert_placeholder_style(
        self, sql: str, parameters: Any, target_style: ParameterStyle, is_many: bool = False
    ) -> tuple[str, Any]:
        """Convert ``sql`` placeholders and ``parameters`` to ``target_style``.

        Args:
            sql: SQL text containing placeholders.
            parameters: Parameter payload (sequence, mapping, or other value).
            target_style: Desired placeholder style. ``ParameterStyle.STATIC``
                embeds the values directly into the SQL as literals.
            is_many: Accepted for API compatibility; not used by this method.

        Returns:
            The converted SQL and the converted parameters. For the STATIC target
            the parameters element is ``None``.

        Raises:
            ValueError: If SQL rewriting is required and ``target_style`` has no
                placeholder generator.
        """
        param_info = self.validator.extract_parameters(sql)

        if target_style == ParameterStyle.STATIC:
            return self._embed_static_parameters(sql, parameters, param_info)

        current_styles = {p.style for p in param_info}
        if len(current_styles) == 1 and target_style in current_styles:
            # SQL already uses exactly the target style; only the payload may need reshaping.
            converted_sql = sql
        else:
            converted_sql = self._convert_placeholders_to_style(sql, param_info, target_style)

        converted_parameters = self._convert_parameter_format(
            parameters, param_info, target_style, parameters, preserve_parameter_format=True
        )
        return converted_sql, converted_parameters

    def _parameter_key(self, param: "ParameterInfo") -> str:
        """Return the de-duplication key for one placeholder occurrence.

        Anonymous positional placeholders (QMARK ``?`` and POSITIONAL_PYFORMAT
        ``%s``) are a distinct bind at every occurrence, so they are keyed by
        ordinal. Placeholders with a name share one bind per distinct placeholder
        text (e.g. ``:a ... :a`` or ``$1 ... $1``).
        """
        if param.style in {ParameterStyle.QMARK, ParameterStyle.POSITIONAL_PYFORMAT}:
            return f"{param.placeholder_text}_{param.ordinal}"
        if param.name:
            return param.placeholder_text
        return f"{param.placeholder_text}_{param.ordinal}"

    def _splice_placeholders(self, sql: str, edits: "list[tuple[ParameterInfo, str]]") -> str:
        """Replace each ``(param, replacement)`` placeholder span in ``sql``.

        Edits are applied from last to first so that the ``position`` offsets of
        earlier placeholders remain valid while the string length changes.
        """
        result = sql
        for param, replacement in reversed(edits):
            start = param.position
            end = start + len(param.placeholder_text)
            result = result[:start] + replacement + result[end:]
        return result

    def _convert_placeholders_to_style(
        self, sql: str, param_info: "list[ParameterInfo]", target_style: "ParameterStyle"
    ) -> str:
        """Rewrite every placeholder in ``sql`` into ``target_style``.

        Positional targets number binds by first appearance of each distinct
        :meth:`_parameter_key`; named targets use the parameter name, or
        ``param_<ordinal>`` when the placeholder has none.

        Raises:
            ValueError: If ``target_style`` has no placeholder generator.
        """
        generator = self._placeholder_generators.get(target_style)
        if generator is None:
            msg = f"Unsupported target parameter style: {target_style}"
            raise ValueError(msg)

        positional_target = target_style in {
            ParameterStyle.QMARK,
            ParameterStyle.NUMERIC,
            ParameterStyle.POSITIONAL_PYFORMAT,
            ParameterStyle.POSITIONAL_COLON,
        }

        # Zero-based bind index per distinct key, assigned in order of first appearance.
        unique_params: dict[str, int] = {}
        for param in param_info:
            unique_params.setdefault(self._parameter_key(param), len(unique_params))

        edits: list[tuple[ParameterInfo, str]] = []
        for param in param_info:
            if positional_target:
                new_placeholder = generator(unique_params[self._parameter_key(param)])
            else:
                new_placeholder = generator(param.name or f"param_{param.ordinal}")
            edits.append((param, new_placeholder))
        return self._splice_placeholders(sql, edits)

    def _convert_sequence_to_dict(
        self, parameters: Sequence[Any], param_info: "list[ParameterInfo]"
    ) -> "dict[str, Any]":
        """Pair positional values with placeholders, keyed by name or ``param_<ordinal>``.

        Extra values or extra placeholders beyond the shorter of the two are ignored.
        """
        return {param.name or f"param_{param.ordinal}": value for param, value in zip(param_info, parameters)}

    def _lookup_by_ordinal(self, param: "ParameterInfo", parameters: Mapping[str, Any]) -> "tuple[Any, bool]":
        """Ordinal-based fallbacks shared by both mapping extractors.

        Tries ``param_<ordinal>``, then the one-based ``"<ordinal + 1>"`` key, then
        the mapping's key at position ``ordinal`` in iteration order.
        """
        synthetic_key = f"param_{param.ordinal}"
        if synthetic_key in parameters:
            return parameters[synthetic_key], True
        ordinal_key = str(param.ordinal + 1)
        if ordinal_key in parameters:
            return parameters[ordinal_key], True
        try:
            ordered_keys = list(parameters.keys())
        except AttributeError:
            ordered_keys = []
        if ordered_keys and param.ordinal < len(ordered_keys):
            key = ordered_keys[param.ordinal]
            if key in parameters:
                return parameters[key], True
        return None, False

    def _extract_param_value_mixed_styles(
        self, param: "ParameterInfo", parameters: Mapping[str, Any], param_keys: "list[str]"
    ) -> "tuple[Any, bool]":
        """Look up ``param`` in a mapping when the SQL mixes placeholder styles.

        Returns:
            ``(value, True)`` when found, otherwise ``(None, False)``.
        """
        if param.name and param.name in parameters:
            return parameters[param.name], True
        if (
            param.style == ParameterStyle.NUMERIC
            and param.name
            and param.name.isdigit()
            and param.ordinal < len(param_keys)
        ):
            return parameters[param_keys[param.ordinal]], True
        return self._lookup_by_ordinal(param, parameters)

    def _extract_param_value_single_style(
        self, param: "ParameterInfo", parameters: Mapping[str, Any]
    ) -> "tuple[Any, bool]":
        """Look up ``param`` in a mapping when the SQL uses a single placeholder style.

        Returns:
            ``(value, True)`` when found, otherwise ``(None, False)``.
        """
        if param.name and param.name in parameters:
            return parameters[param.name], True
        return self._lookup_by_ordinal(param, parameters)

    def _preserve_original_format(self, param_values: list[Any], original_parameters: Any) -> Any:
        """Wrap ``param_values`` in a container resembling ``original_parameters``.

        Tuples and mappings yield a tuple, lists yield the list itself. Any other
        type is rebuilt via its class constructor, falling back to a tuple when
        that constructor rejects a list argument.
        """
        if isinstance(original_parameters, tuple):
            return tuple(param_values)
        if isinstance(original_parameters, list):
            return param_values
        if isinstance(original_parameters, Mapping):
            return tuple(param_values)
        try:
            return original_parameters.__class__(param_values)
        except (TypeError, ValueError):
            return tuple(param_values)

    def _convert_parameter_format(
        self,
        parameters: Any,
        param_info: "list[ParameterInfo]",
        target_style: "ParameterStyle",
        original_parameters: Any = None,
        preserve_parameter_format: bool = False,
    ) -> Any:
        """Reshape ``parameters`` to match ``target_style``.

        * Named targets: mappings pass through; sequences become a dict.
        * Positional targets: sequences pass through; mappings are flattened
          into a list with one value per distinct bind (see :meth:`_parameter_key`),
          in order of first appearance.
        * Empty parameters, SQL without placeholders, and any other payload type
          are returned unchanged.

        When ``preserve_parameter_format`` is true and ``original_parameters`` is
        not ``None``, a flattened mapping is wrapped via :meth:`_preserve_original_format`.
        """
        if not parameters or not param_info:
            return parameters

        is_sequence = isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes))

        is_named_style = target_style in {
            ParameterStyle.NAMED_COLON,
            ParameterStyle.NAMED_AT,
            ParameterStyle.NAMED_DOLLAR,
            ParameterStyle.NAMED_PYFORMAT,
        }
        if is_named_style:
            if isinstance(parameters, Mapping):
                return parameters
            if is_sequence:
                return self._convert_sequence_to_dict(parameters, param_info)
            return parameters

        if is_sequence or not isinstance(parameters, Mapping):
            return parameters

        has_mixed_styles = len({p.style for p in param_info}) > 1
        param_keys = list(parameters.keys()) if has_mixed_styles else []

        # Insertion order of this dict is the bind order of the flattened values.
        collected: dict[str, Any] = {}
        for param in param_info:
            param_key = self._parameter_key(param)
            if param_key in collected:
                continue
            if has_mixed_styles:
                value, found = self._extract_param_value_mixed_styles(param, parameters, param_keys)
            else:
                value, found = self._extract_param_value_single_style(param, parameters)
            if found:
                collected[param_key] = value

        param_values = list(collected.values())
        if preserve_parameter_format and original_parameters is not None:
            return self._preserve_original_format(param_values, original_parameters)
        return param_values

    def _render_literal(self, value: Any) -> str:
        """Render ``value`` as an inline SQL literal.

        ``None`` -> ``NULL``; booleans -> ``TRUE``/``FALSE``; ints and floats via
        ``str``; everything else as a single-quoted string with embedded single
        quotes doubled so the value cannot terminate the literal early.
        """
        if value is None:
            return "NULL"
        if isinstance(value, str):
            text = value
        elif isinstance(value, bool):
            # bool is checked before int because bool is an int subclass.
            return "TRUE" if value else "FALSE"
        elif isinstance(value, (int, float)):
            return str(value)
        else:
            text = str(value)
        escaped = text.replace("'", "''")
        return f"'{escaped}'"

    def _embed_static_parameters(
        self, sql: str, parameters: Any, param_info: "list[ParameterInfo]"
    ) -> "tuple[str, Any]":
        """Inline parameter values into ``sql`` as literals.

        Returns:
            The SQL with every placeholder replaced by a literal, and ``None`` in
            place of parameters.
        """
        if not param_info:
            return sql, None

        unique_params: dict[str, int] = {}
        for param in param_info:
            unique_params.setdefault(self._parameter_key(param), len(unique_params))

        edits = [
            (param, self._render_literal(self._get_parameter_value_with_reuse(parameters, param, unique_params)))
            for param in param_info
        ]
        return self._splice_placeholders(sql, edits), None

    def _get_parameter_value(self, parameters: Any, param: "ParameterInfo") -> Any:
        """Fetch the value for ``param`` by name/ordinal; ``None`` when absent."""
        if isinstance(parameters, Mapping):
            if param.name and param.name in parameters:
                return parameters[param.name]
            if f"param_{param.ordinal}" in parameters:
                return parameters[f"param_{param.ordinal}"]
            if str(param.ordinal + 1) in parameters:
                return parameters[str(param.ordinal + 1)]
        elif isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
            if param.ordinal < len(parameters):
                return parameters[param.ordinal]
        return None

    def _get_parameter_value_with_reuse(
        self, parameters: Any, param: "ParameterInfo", unique_params: "dict[str, int]"
    ) -> Any:
        """Fetch the value for ``param`` using its de-duplicated bind index.

        Repeated named placeholders resolve to the same index, so they reuse one
        value. Returns ``None`` when the value cannot be found.
        """
        unique_ordinal = unique_params.get(self._parameter_key(param))
        if unique_ordinal is None:
            return None

        if isinstance(parameters, Mapping):
            if param.name and param.name in parameters:
                return parameters[param.name]
            if f"param_{unique_ordinal}" in parameters:
                return parameters[f"param_{unique_ordinal}"]
            if str(unique_ordinal + 1) in parameters:
                return parameters[str(unique_ordinal + 1)]
        elif isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
            if unique_ordinal < len(parameters):
                return parameters[unique_ordinal]
        return None

    def _convert_to_positional_format(self, parameters: Any, param_info: "list[ParameterInfo]") -> Any:
        """Reshape ``parameters`` for a positional (QMARK-like) target."""
        return self._convert_parameter_format(
            parameters, param_info, ParameterStyle.QMARK, parameters, preserve_parameter_format=False
        )

    def _convert_to_named_colon_format(self, parameters: Any, param_info: "list[ParameterInfo]") -> Any:
        """Reshape ``parameters`` for a named (``:name``-like) target."""
        return self._convert_parameter_format(
            parameters, param_info, ParameterStyle.NAMED_COLON, parameters, preserve_parameter_format=False
        )

    def _convert_to_positional_colon_format(self, parameters: Any, param_info: "list[ParameterInfo]") -> Any:
        """Reshape ``parameters`` for ``:1``-style binds.

        Mappings pass through; sequences become ``{"1": v0, "2": v1, ...}``; any
        other payload yields an empty dict.
        """
        if isinstance(parameters, Mapping):
            return parameters
        if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
            return {str(index + 1): value for index, value in enumerate(parameters)}
        return {}

    def _convert_to_named_pyformat_format(self, parameters: Any, param_info: "list[ParameterInfo]") -> Any:
        """Reshape ``parameters`` for ``%(name)s``-style binds."""
        return self._convert_parameter_format(
            parameters, param_info, ParameterStyle.NAMED_PYFORMAT, parameters, preserve_parameter_format=False
        )