"""Parameter style conversion utilities."""
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from mypy_extensions import mypyc_attr
from sqlspec.core.parameters._types import (
ConvertedParameters,
NamedParameterOutput,
ParameterInfo,
ParameterMapping,
ParameterPayload,
ParameterSequence,
ParameterStyle,
PositionalParameterOutput,
)
from sqlspec.core.parameters._validator import ParameterValidator
from sqlspec.exceptions import SQLSpecError
# Public API of this module: only the converter class is exported.
__all__ = ("ParameterConverter",)
def _placeholder_qmark(_: Any) -> str:
return "?"
def _placeholder_numeric(index: Any) -> str:
return f"${int(index) + 1}"
def _placeholder_named_colon(name: Any) -> str:
return f":{name}"
def _placeholder_positional_colon(index: Any) -> str:
return f":{int(index) + 1}"
def _placeholder_named_at(name: Any) -> str:
return f"@{name}"
def _placeholder_named_dollar(name: Any) -> str:
return f"${name}"
def _placeholder_named_pyformat(name: Any) -> str:
return f"%({name})s"
def _placeholder_positional_pyformat(_: Any) -> str:
return "%s"
# Placeholder styles addressed by name.
_NAMED_STYLES: "frozenset[ParameterStyle]" = frozenset(
    (
        ParameterStyle.NAMED_COLON,
        ParameterStyle.NAMED_AT,
        ParameterStyle.NAMED_DOLLAR,
        ParameterStyle.NAMED_PYFORMAT,
    )
)
# Target styles whose placeholders are generated from a positional index.
_POSITIONAL_TARGET_STYLES: "frozenset[ParameterStyle]" = frozenset(
    (
        ParameterStyle.QMARK,
        ParameterStyle.NUMERIC,
        ParameterStyle.POSITIONAL_PYFORMAT,
        ParameterStyle.POSITIONAL_COLON,
    )
)
# Positional targets that receive one mapping value per placeholder occurrence
# (a named parameter used twice is bound twice); NUMERIC gets one value per
# distinct placeholder instead.
_EXPANDED_POSITIONAL_STYLES: "frozenset[ParameterStyle]" = frozenset(
    (
        ParameterStyle.QMARK,
        ParameterStyle.POSITIONAL_PYFORMAT,
        ParameterStyle.POSITIONAL_COLON,
    )
)
# Anonymous placeholder styles: every occurrence is a distinct parameter.
_ANONYMOUS_STYLES: "frozenset[ParameterStyle]" = frozenset(
    (ParameterStyle.QMARK, ParameterStyle.POSITIONAL_PYFORMAT)
)


@mypyc_attr(allow_interpreted_subclasses=False)
class ParameterConverter:
    """Parameter style conversion helper.

    Rewrites the placeholders reported by a ``ParameterValidator`` into a single
    target ``ParameterStyle`` and reshapes the parameter payload (positional
    sequence, mapping, or ``execute_many`` batch) so it matches the rewritten SQL.
    ``ParameterStyle.STATIC`` instead inlines the values as SQL literals.
    """

    __slots__ = ("_placeholder_generators", "validator")

    def __init__(self, validator: "ParameterValidator | None" = None) -> None:
        """Create a converter.

        Args:
            validator: Used to extract placeholder metadata from SQL. A fresh
                ``ParameterValidator`` is created when omitted.
        """
        self.validator = validator if validator is not None else ParameterValidator()
        # Target style -> function rendering one placeholder from an index or a name.
        self._placeholder_generators: dict[ParameterStyle, Callable[[Any], str]] = {
            ParameterStyle.QMARK: _placeholder_qmark,
            ParameterStyle.NUMERIC: _placeholder_numeric,
            ParameterStyle.NAMED_COLON: _placeholder_named_colon,
            ParameterStyle.POSITIONAL_COLON: _placeholder_positional_colon,
            ParameterStyle.NAMED_AT: _placeholder_named_at,
            ParameterStyle.NAMED_DOLLAR: _placeholder_named_dollar,
            ParameterStyle.NAMED_PYFORMAT: _placeholder_named_pyformat,
            ParameterStyle.POSITIONAL_PYFORMAT: _placeholder_positional_pyformat,
        }

    def convert_placeholder_style(
        self,
        sql: str,
        parameters: "ParameterPayload",
        target_style: "ParameterStyle",
        is_many: bool = False,
        *,
        strict_named_parameters: bool = True,
    ) -> "tuple[str, ConvertedParameters]":
        """Convert the placeholders in ``sql`` and the matching ``parameters`` to ``target_style``.

        Args:
            sql: SQL text containing placeholders.
            parameters: Positional sequence, mapping, or (with ``is_many``) a batch of them.
            target_style: Placeholder style to emit; ``STATIC`` inlines values as literals.
            is_many: Treat ``parameters`` as a batch of parameter sets.
            strict_named_parameters: Raise when a mapping lacks a value for a named
                placeholder that is being converted to a positional style.

        Returns:
            The converted SQL and the reshaped parameters (``None`` for ``STATIC``).

        Raises:
            ValueError: If ``target_style`` has no placeholder generator.
            SQLSpecError: If strict checking finds missing named parameters.
        """
        param_info = self.validator.extract_parameters(sql)
        if target_style == ParameterStyle.STATIC:
            return self._embed_static_parameters(sql, parameters, param_info)

        current_styles = {p.style for p in param_info}
        if len(current_styles) == 1 and target_style in current_styles:
            # SQL already uses exactly the target style: keep the text, only normalise the payload.
            converted_sql = sql
        else:
            converted_sql = self._convert_placeholders_to_style(sql, param_info, target_style)

        converted_parameters = self._convert_parameter_format(
            parameters,
            param_info,
            target_style,
            parameters,
            preserve_parameter_format=True,
            is_many=is_many,
            strict_named_parameters=strict_named_parameters,
        )
        return converted_sql, converted_parameters

    @staticmethod
    def _placeholder_key(param: "ParameterInfo") -> str:
        """Return the identity key used to de-duplicate placeholder occurrences.

        Anonymous placeholders (``?``/``%s``) and unnamed ones are unique per occurrence
        (keyed by ordinal); named or numbered placeholders are keyed by their text so
        repeated occurrences share one slot.
        """
        if param.style in _ANONYMOUS_STYLES or not param.name:
            return f"{param.placeholder_text}_{param.ordinal}"
        return param.placeholder_text

    @staticmethod
    def _target_param_name(param: "ParameterInfo") -> str:
        """Return the name a placeholder receives when rewritten to a named style.

        Unnamed and purely numeric placeholders (``?``, ``$1``) become ``param_<ordinal>``.
        """
        name = param.name
        if not name or (isinstance(name, str) and name.isdigit()):
            return f"param_{param.ordinal}"
        return name

    @staticmethod
    def _positional_value_index(param: "ParameterInfo", default_index: int, value_count: int) -> int:
        """Return the index into a positional payload that ``param`` refers to.

        A numbered placeholder (name made of decimal digits, e.g. ``$2``) refers to its
        1-based number when that number is within ``value_count``; otherwise
        ``default_index`` is used.
        """
        name = param.name
        if isinstance(name, str) and name.isdecimal():
            number = int(name)
            if 1 <= number <= value_count:
                return number - 1
        return default_index

    def _convert_placeholders_to_style(
        self, sql: str, param_info: "list[ParameterInfo]", target_style: "ParameterStyle"
    ) -> str:
        """Rewrite every placeholder in ``sql`` with ``target_style``'s generator.

        Positional targets receive an index per distinct ``_placeholder_key``: anonymous
        placeholders get one index per occurrence, while a repeated named/numbered
        placeholder reuses its index. Named targets receive ``_target_param_name``.

        Raises:
            ValueError: If no generator exists for ``target_style``.
        """
        generator = self._placeholder_generators.get(target_style)
        if generator is None:
            msg = f"Unsupported target parameter style: {target_style}"
            raise ValueError(msg)

        is_positional_target = target_style in _POSITIONAL_TARGET_STYLES
        unique_indexes: dict[str, int] = {}
        if is_positional_target:
            # Indexes follow first appearance in ``param_info``, mirroring the value order
            # used by ``_mapping_to_positional``.
            for param in param_info:
                unique_indexes.setdefault(self._placeholder_key(param), len(unique_indexes))

        # Forward build with a segment list: linear in the SQL length.
        segments: list[str] = []
        last_end = 0
        for param in sorted(param_info, key=lambda p: p.position):
            if is_positional_target:
                new_placeholder = generator(unique_indexes[self._placeholder_key(param)])
            else:
                new_placeholder = generator(self._target_param_name(param))
            segments.append(sql[last_end : param.position])
            segments.append(new_placeholder)
            last_end = param.position + len(param.placeholder_text)
        segments.append(sql[last_end:])
        return "".join(segments)

    def _convert_sequence_to_dict(
        self, parameters: "ParameterSequence", param_info: "list[ParameterInfo]"
    ) -> "NamedParameterOutput":
        """Map positional values onto the names used when SQL is rewritten to a named style.

        Keys come from ``_target_param_name`` so they match the placeholders produced by
        ``_convert_placeholders_to_style``. Numbered placeholders take the value at their
        1-based number; others take the value at their occurrence index. Placeholders
        without a corresponding value are omitted.
        """
        param_dict: dict[str, Any] = {}
        value_count = len(parameters)
        for occurrence, param in enumerate(param_info):
            source_index = self._positional_value_index(param, occurrence, value_count)
            if source_index < value_count:
                param_dict[self._target_param_name(param)] = parameters[source_index]
        return param_dict

    def _extract_param_value_mixed_styles(
        self, param: "ParameterInfo", parameters: "ParameterMapping", param_keys: "list[str]"
    ) -> "tuple[object | None, bool]":
        """Look up ``param``'s value in ``parameters`` when the SQL mixes placeholder styles.

        Order: the placeholder's name, its raw text, then, for a numeric placeholder with a
        digit name, the key at ``param.ordinal`` in ``param_keys``, then the shared
        ordinal fallbacks.

        Returns:
            ``(value, True)`` when found, otherwise ``(None, False)``.
        """
        if param.name and param.name in parameters:
            return parameters[param.name], True
        if param.placeholder_text in parameters:
            return parameters[param.placeholder_text], True
        if (
            param.style == ParameterStyle.NUMERIC
            and param.name
            and param.name.isdigit()
            and param.ordinal < len(param_keys)
        ):
            return parameters[param_keys[param.ordinal]], True
        return self._extract_param_value_by_ordinal(param, parameters, param_keys)

    def _extract_param_value_single_style(
        self, param: "ParameterInfo", parameters: "ParameterMapping", param_keys: "list[str] | None" = None
    ) -> "tuple[object | None, bool]":
        """Look up ``param``'s value in ``parameters`` when the SQL uses a single style.

        Order: the placeholder's name, its raw text, then the shared ordinal fallbacks.
        ``param_keys`` may be passed to avoid re-listing the mapping's keys per call.

        Returns:
            ``(value, True)`` when found, otherwise ``(None, False)``.
        """
        if param.name and param.name in parameters:
            return parameters[param.name], True
        if param.placeholder_text in parameters:
            return parameters[param.placeholder_text], True
        return self._extract_param_value_by_ordinal(param, parameters, param_keys)

    def _extract_param_value_by_ordinal(
        self, param: "ParameterInfo", parameters: "ParameterMapping", param_keys: "list[str] | None" = None
    ) -> "tuple[object | None, bool]":
        """Ordinal-based fallback lookups shared by both ``_extract_param_value_*`` variants.

        Tries ``param_<ordinal>``, then the 1-based ``"<ordinal + 1>"`` key, then the key at
        position ``ordinal`` in the mapping's key order.

        Returns:
            ``(value, True)`` when found, otherwise ``(None, False)``.
        """
        generated_key = f"param_{param.ordinal}"
        if generated_key in parameters:
            return parameters[generated_key], True
        ordinal_key = str(param.ordinal + 1)
        if ordinal_key in parameters:
            return parameters[ordinal_key], True
        if param_keys is None:
            try:
                param_keys = list(parameters.keys())
            except AttributeError:
                param_keys = []
        if param_keys and param.ordinal < len(param_keys):
            key = param_keys[param.ordinal]
            if key in parameters:
                return parameters[key], True
        return None, False

    def _collect_missing_named_parameters(
        self, param_info: "list[ParameterInfo]", parameters: "ParameterMapping"
    ) -> "list[str]":
        """Return the sorted, de-duplicated names of named placeholders absent from ``parameters``.

        A placeholder counts as present when either its name or its raw text is a key.
        """
        missing = {
            param.name
            for param in param_info
            if param.style in _NAMED_STYLES
            and param.name
            and param.name not in parameters
            and param.placeholder_text not in parameters
        }
        return sorted(missing)

    def _preserve_original_format(
        self, param_values: "list[Any]", original_parameters: object
    ) -> "PositionalParameterOutput":
        """Return ``param_values`` as a list if the original payload was a list, else as a tuple."""
        if isinstance(original_parameters, list):
            return param_values
        return tuple(param_values)

    def _convert_parameter_format(
        self,
        parameters: "ParameterPayload",
        param_info: "list[ParameterInfo]",
        target_style: "ParameterStyle",
        original_parameters: object | None = None,
        preserve_parameter_format: bool = False,
        is_many: bool = False,
        *,
        strict_named_parameters: bool = True,
    ) -> "ConvertedParameters":
        """Reshape ``parameters`` to suit placeholders rewritten to ``target_style``.

        * Named targets get a ``dict``; positional values are keyed by target name.
        * Positional targets keep a positional payload (lists stay lists, other
          sequences become tuples) and turn a mapping into ordered values via
          ``_mapping_to_positional``.
        * With ``is_many``, a batch whose first row is a mapping (or, for named targets,
          a sequence) is converted row by row; a tuple batch stays a tuple when
          ``preserve_parameter_format`` is set.

        Args:
            parameters: Payload to convert.
            param_info: Placeholders extracted from the SQL.
            target_style: Style the SQL placeholders are being converted to.
            original_parameters: Payload whose container type ``_preserve_original_format`` mirrors.
            preserve_parameter_format: Keep list/tuple container types where possible.
            is_many: Treat ``parameters`` as a batch of parameter sets.
            strict_named_parameters: Forwarded to ``_mapping_to_positional``.

        Returns:
            The converted payload, or ``None`` for ``None`` / unsupported payload types.

        Raises:
            SQLSpecError: From ``_mapping_to_positional`` under strict checking.
        """
        if not parameters or not param_info:
            # Nothing to map: return a concrete copy of the (possibly empty) container.
            if isinstance(parameters, Mapping):
                return dict(parameters)
            if isinstance(parameters, list):
                return list(parameters)
            if isinstance(parameters, tuple):
                return tuple(parameters)
            return None

        is_named_target = target_style in _NAMED_STYLES

        if is_many and isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes, bytearray)):
            first_row = parameters[0]
            first_is_sequence = isinstance(first_row, Sequence) and not isinstance(
                first_row, (str, bytes, bytearray)
            )
            if isinstance(first_row, Mapping) or (is_named_target and first_is_sequence):
                normalized_sets: list[Any] = []
                for row in parameters:
                    if isinstance(row, Mapping):
                        normalized_sets.append(
                            self._convert_parameter_format(
                                row,
                                param_info,
                                target_style,
                                row,
                                preserve_parameter_format,
                                is_many=False,
                                strict_named_parameters=strict_named_parameters,
                            )
                        )
                    elif (
                        is_named_target
                        and isinstance(row, Sequence)
                        and not isinstance(row, (str, bytes, bytearray))
                    ):
                        # Each positional row is keyed on its own; treating the batch itself
                        # as one positional set would key whole rows by placeholder name.
                        normalized_sets.append(self._convert_sequence_to_dict(row, param_info))
                    else:
                        normalized_sets.append(row)
                if preserve_parameter_format and isinstance(parameters, tuple):
                    return tuple(normalized_sets)
                return normalized_sets

        if is_named_target:
            if isinstance(parameters, Mapping):
                return dict(parameters)
            if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
                return self._convert_sequence_to_dict(parameters, param_info)
            return None

        if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
            # NOTE(review): positional payloads pass through unchanged, so numbered
            # placeholders that are reordered or reused ("$2 ... $1", "$1 ... $1") are not
            # re-sequenced when converted to QMARK/POSITIONAL_PYFORMAT/POSITIONAL_COLON.
            return list(parameters) if isinstance(parameters, list) else tuple(parameters)
        if isinstance(parameters, Mapping):
            return self._mapping_to_positional(
                parameters,
                param_info,
                target_style,
                original_parameters,
                preserve_parameter_format,
                strict_named_parameters,
            )
        # Unsupported payload type.
        return None

    def _mapping_to_positional(
        self,
        parameters: "ParameterMapping",
        param_info: "list[ParameterInfo]",
        target_style: "ParameterStyle",
        original_parameters: object | None,
        preserve_parameter_format: bool,
        strict_named_parameters: bool,
    ) -> "ConvertedParameters":
        """Turn a mapping payload into ordered values for a positional ``target_style``.

        Each distinct placeholder (``_placeholder_key``) is looked up once. Targets in
        ``_EXPANDED_POSITIONAL_STYLES`` get one value per occurrence; NUMERIC gets one value
        per distinct placeholder in first-appearance order.

        Raises:
            SQLSpecError: If ``strict_named_parameters`` and a named placeholder has no value.
        """
        if strict_named_parameters:
            missing_names = self._collect_missing_named_parameters(param_info, parameters)
            if missing_names:
                msg = f"Missing named parameter(s): {', '.join(missing_names)}"
                raise SQLSpecError(msg)

        has_mixed_styles = len({p.style for p in param_info}) > 1
        # Listed once and shared by every lookup below.
        param_keys = list(parameters.keys())
        placeholder_keys = [self._placeholder_key(param) for param in param_info]

        unique_values: dict[str, Any] = {}
        for param, key in zip(param_info, placeholder_keys):
            if key in unique_values:
                continue
            if has_mixed_styles:
                value, found = self._extract_param_value_mixed_styles(param, parameters, param_keys)
            else:
                value, found = self._extract_param_value_single_style(param, parameters, param_keys)
            if found:
                unique_values[key] = value

        if target_style in _EXPANDED_POSITIONAL_STYLES:
            param_values = [unique_values[key] for key in placeholder_keys if key in unique_values]
        else:
            param_values = list(unique_values.values())

        if preserve_parameter_format and original_parameters is not None:
            return self._preserve_original_format(param_values, original_parameters)
        return param_values

    def _embed_static_parameters(
        self, sql: str, parameters: "ParameterPayload", param_info: "list[ParameterInfo]"
    ) -> "tuple[str, None]":
        """Inline parameter values into ``sql`` as SQL literals (``STATIC`` style).

        Returns:
            The rewritten SQL and ``None``, since no parameters remain to be bound.
        """
        if not param_info:
            return sql, None

        unique_params: dict[str, int] = {}
        for param in param_info:
            unique_params.setdefault(self._placeholder_key(param), len(unique_params))

        # Forward build from the original SQL; sorting makes this independent of the
        # order in which the validator reports placeholders.
        segments: list[str] = []
        last_end = 0
        for param in sorted(param_info, key=lambda p: p.position):
            param_value = self._get_parameter_value_with_reuse(parameters, param, unique_params)
            segments.append(sql[last_end : param.position])
            segments.append(self._to_sql_literal(param_value))
            last_end = param.position + len(param.placeholder_text)
        segments.append(sql[last_end:])
        return "".join(segments), None

    @staticmethod
    def _to_sql_literal(value: object) -> str:
        """Render ``value`` as an inline SQL literal.

        ``None`` -> ``NULL``; ``bool`` -> ``TRUE``/``FALSE``; ``int``/``float`` -> ``str()``;
        anything else -> its string form in single quotes with embedded quotes doubled.
        """
        if value is None:
            return "NULL"
        # bool is tested before int because bool is an int subclass.
        if isinstance(value, bool):
            return "TRUE" if value else "FALSE"
        if isinstance(value, (int, float)):
            return str(value)
        text = value if isinstance(value, str) else str(value)
        # Doubling quotes applies to every quoted value, so a non-str value whose str()
        # contains "'" cannot terminate the literal early.
        escaped = text.replace("'", "''")
        return f"'{escaped}'"

    def _get_parameter_value(self, parameters: "ParameterPayload", param: "ParameterInfo") -> object | None:
        """Return the value bound to ``param``, or ``None`` when it cannot be found.

        Mappings are searched by name, ``param_<ordinal>``, then ``"<ordinal + 1>"``;
        sequences are indexed by ``param.ordinal``.
        """
        if isinstance(parameters, Mapping):
            for key in (param.name, f"param_{param.ordinal}", str(param.ordinal + 1)):
                if key and key in parameters:
                    return parameters[key]
            return None
        if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
            if param.ordinal < len(parameters):
                return parameters[param.ordinal]
        return None

    def _get_parameter_value_with_reuse(
        self, parameters: "ParameterPayload", param: "ParameterInfo", unique_params: "dict[str, int]"
    ) -> object | None:
        """Return the value for ``param`` using its de-duplicated slot in ``unique_params``.

        Mappings are searched by name, ``param_<slot>``, then ``"<slot + 1>"``. Sequences are
        indexed by the placeholder's number for numbered placeholders (``$2`` -> index 1),
        otherwise by the slot. Returns ``None`` when no value is found.
        """
        unique_ordinal = unique_params.get(self._placeholder_key(param))
        if unique_ordinal is None:
            return None

        if isinstance(parameters, Mapping):
            if param.name and param.name in parameters:
                return parameters[param.name]
            for key in (f"param_{unique_ordinal}", str(unique_ordinal + 1)):
                if key in parameters:
                    return parameters[key]
            return None

        if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)):
            index = self._positional_value_index(param, unique_ordinal, len(parameters))
            if index < len(parameters):
                return parameters[index]
        return None