"""Parameter extraction utilities."""
import re
from collections import OrderedDict
from mypy_extensions import mypyc_attr
from sqlspec.core.parameters._types import ParameterInfo, ParameterStyle
__all__ = ("PARAMETER_REGEX", "ParameterValidator")
# Single-pass tokenizer for SQL text.
#
# Alternatives are tried left to right at each position, so the "skip" tokens
# (double/single-quoted strings, dollar-quoted bodies, line and block
# comments, the ``??`` / ``?|`` / ``?&`` operators, ``::type`` casts and
# ``@@name`` globals) are listed first: they consume any placeholder-looking
# text they contain before the placeholder alternatives can match it.
#
# The placeholder alternatives that start with ``:``, ``@`` or ``$`` carry a
# negative lookbehind so they do not fire directly after an identifier
# character (for example the ``:b`` in ``a:b``).
#
# Flags: VERBOSE permits the multi-line layout; DOTALL lets ``\\.`` inside the
# quoted-string alternatives span an escaped newline; IGNORECASE makes the
# closing dollar-quote tag compare case-insensitively against the opening tag.
PARAMETER_REGEX = re.compile(
    r"""
    (?P<dquote>"(?:[^"\\]|\\.)*") |
    (?P<squote>'(?:[^'\\]|\\.)*') |
    (?P<dollar_quoted_string>\$(?P<dollar_quote_tag_inner>\w*)?\$[\s\S]*?\$(?P=dollar_quote_tag_inner)\$) |
    (?P<line_comment>--[^\r\n]*) |
    (?P<block_comment>/\*(?:[^*]|\*(?!/))*\*/) |
    (?P<pg_q_operator>\?\?|\?\||\?&) |
    (?P<pg_cast>::(?P<cast_type>\w+)) |
    (?P<sql_server_global>@@(?P<global_var_name>\w+)) |
    (?P<pyformat_named>%\((?P<pyformat_name>\w+)\)s) |
    (?P<pyformat_pos>%s) |
    (?P<positional_colon>(?<![A-Za-z0-9_]):(?P<colon_num>\d+)) |
    (?P<named_colon>(?<![A-Za-z0-9_]):(?P<colon_name>\w+)) |
    (?P<named_at>(?<![A-Za-z0-9_])@(?P<at_name>\w+)) |
    (?P<numeric>(?<![A-Za-z0-9_])\$(?P<numeric_num>\d+)) |
    (?P<named_dollar_param>(?<![A-Za-z0-9_])\$(?P<dollar_param_name>\w+)) |
    (?P<qmark>\?)
    """,
    re.VERBOSE | re.IGNORECASE | re.MULTILINE | re.DOTALL,
)
@mypyc_attr(allow_interpreted_subclasses=False)
class ParameterValidator:
"""Extracts placeholder metadata and dialect compatibility information."""
__slots__ = ("_cache_max_size", "_parameter_cache")
def __init__(self, cache_max_size: int = 5000) -> None:
self._parameter_cache: OrderedDict[str, list[ParameterInfo]] = OrderedDict()
self._cache_max_size = cache_max_size
def _extract_parameter_style(self, match: re.Match[str]) -> "tuple[ParameterStyle | None, str | None]":
"""Map a regex match to a placeholder style and optional name."""
if match.group("qmark"):
return ParameterStyle.QMARK, None
if match.group("named_colon"):
return ParameterStyle.NAMED_COLON, match.group("colon_name")
if match.group("numeric"):
return ParameterStyle.NUMERIC, match.group("numeric_num")
if match.group("named_at"):
return ParameterStyle.NAMED_AT, match.group("at_name")
if match.group("pyformat_named"):
return ParameterStyle.NAMED_PYFORMAT, match.group("pyformat_name")
if match.group("pyformat_pos"):
return ParameterStyle.POSITIONAL_PYFORMAT, None
if match.group("positional_colon"):
return ParameterStyle.POSITIONAL_COLON, match.group("colon_num")
if match.group("named_dollar_param"):
return ParameterStyle.NAMED_DOLLAR, match.group("dollar_param_name")
return None, None
def get_sqlglot_incompatible_styles(self, dialect: str | None = None) -> "set[ParameterStyle]":
    """Return the placeholder styles treated as SQLGlot-incompatible.

    Args:
        dialect: Dialect name, compared case-insensitively. ``None`` or an
            empty string is handled like an unrecognised dialect.

    Returns:
        A new set on every call. For ``postgres``, ``postgresql`` and
        ``sqlite`` it contains only ``POSITIONAL_COLON``. For every other
        value (including ``mysql``, ``mariadb``, ``oracle``, ``bigquery``
        and no dialect at all) it contains ``NAMED_PYFORMAT``,
        ``POSITIONAL_PYFORMAT`` and ``POSITIONAL_COLON``.
    """
    normalized = dialect.lower() if dialect else ""

    # Only these dialects get the reduced set; everything else gets the full one.
    if normalized in {"postgres", "postgresql", "sqlite"}:
        return {ParameterStyle.POSITIONAL_COLON}

    return {
        ParameterStyle.NAMED_PYFORMAT,
        ParameterStyle.POSITIONAL_PYFORMAT,
        ParameterStyle.POSITIONAL_COLON,
    }