# Source code for sqlspec.dialects.spanner._spangres

r"""Google Cloud Spanner PostgreSQL-interface dialect ("Spangres")."""

from typing import Any, Final, cast

from sqlglot import exp
from sqlglot.dialects.postgres import Postgres
from sqlglot.parsers.postgres import PostgresParser
from sqlglot.tokenizer_core import TokenType

from sqlspec.dialects.spanner._generators import SpangresGenerator

__all__ = ("Spangres",)

# Property name stored on the ``exp.Property`` node built for Spanner's
# ``ROW DELETION POLICY`` clause.
_ROW_DELETION_NAME: Final[str] = "ROW_DELETION_POLICY"
# Not referenced in this module; presumably the minimum component count of the
# (column, interval) tuple built for ROW DELETION POLICY — confirm against callers.
_TTL_MIN_COMPONENTS: Final[int] = 2
# Attribute on ``PostgresParser`` that caches the pre-hook ``_parse_property``.
_ORIGINAL_PARSE_PROPERTY_ATTR: Final[str] = "_sqlspec_original_parse_property"
# Marker attribute on ``PostgresParser`` recording that the hook is installed.
_HOOKS_REGISTERED_ATTR: Final[str] = "_sqlspec_spangres_hooks_registered"


def _normalize_interval_expression(expression: exp.Expr) -> exp.Expr:
    """Turn an ``Alias(value, unit)`` node into ``Interval(value, unit)``.

    After the ``INTERVAL`` keyword has been consumed, the value that follows is
    parsed as a generic expression. A value followed by a bare word may come
    back as an alias node. When the alias is an identifier and the aliased
    value is an expression, this rebuilds it as an ``Interval`` using copies of
    both parts. Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.Alias):
        return expression

    unit = expression.args.get("alias")
    value = expression.this
    if not isinstance(unit, exp.Identifier) or not isinstance(value, exp.Expr):
        return expression

    return exp.Interval(this=value.copy(), unit=unit.copy())


def _is_spangres_dialect(parser: Any) -> bool:
    dialect = getattr(parser, "dialect", None)
    return dialect is not None and dialect.__class__.__name__ == "Spangres"


def _original_postgres_parse_property() -> Any:
    """Return the pre-hook ``PostgresParser._parse_property`` implementation.

    On the first call, the current ``PostgresParser._parse_property`` is stored
    under ``_ORIGINAL_PARSE_PROPERTY_ATTR`` on the class. Later calls return
    that cached callable.

    This must run before the hook replaces ``_parse_property``; otherwise the
    hook itself would be cached as the "original".
    """
    cached = getattr(PostgresParser, _ORIGINAL_PARSE_PROPERTY_ATTR, None)
    if not callable(cached):
        cached = PostgresParser._parse_property
        setattr(PostgresParser, _ORIGINAL_PARSE_PROPERTY_ATTR, cached)
    return cached


def _spangres_parse_property(self: Any) -> exp.Expr:
    """Parse a table property, adding support for Spanner's ROW DELETION POLICY.

    This replaces ``PostgresParser._parse_property``. When the parser's dialect
    is Spangres and the next tokens are ``ROW DELETION POLICY``, it consumes::

        ROW DELETION POLICY ( OLDER_THAN ( <column> , INTERVAL <expr> ) )

    and returns ``Property(this='ROW_DELETION_POLICY', value=Tuple(column, interval))``.
    The parentheses, comma, and the ``OLDER_THAN`` / ``INTERVAL`` words are
    matched leniently: a missing token is skipped rather than reported.

    Args:
        self: The ``PostgresParser`` instance the hook is bound to.

    Returns:
        The parsed property expression. Every other case is delegated to the
        original Postgres implementation.
    """
    is_row_deletion_policy = _is_spangres_dialect(self) and self._match_text_seq("ROW", "DELETION", "POLICY")
    if not is_row_deletion_policy:
        return cast("exp.Expr", _original_postgres_parse_property()(self))

    # Outer "(" and "OLDER_THAN(".
    self._match(TokenType.L_PAREN)
    self._match_text_seq("OLDER_THAN")
    self._match(TokenType.L_PAREN)

    column = cast("exp.Expr", self._parse_id_var())
    self._match(TokenType.COMMA)

    self._match_text_seq("INTERVAL")
    raw_interval = cast("exp.Expr", self._parse_expression())
    interval = _normalize_interval_expression(raw_interval)

    # Close "OLDER_THAN( ... )" and then the outer "( ... )".
    for _closing_paren in range(2):
        self._match(TokenType.R_PAREN)

    policy_name = exp.Literal.string(_ROW_DELETION_NAME)
    policy_args = exp.Tuple(expressions=[column, interval])
    return exp.Property(this=policy_name, value=policy_args)


def _register_postgres_spangres_parser_hooks() -> None:
    """Install the Spangres property hook on ``PostgresParser`` at most once.

    A marker attribute on ``PostgresParser`` records that installation has
    happened, so repeated calls do nothing. The original ``_parse_property`` is
    cached first so the hook can still delegate to it.
    """
    already_registered = getattr(PostgresParser, _HOOKS_REGISTERED_ATTR, False)
    if not already_registered:
        # Snapshot the pre-hook implementation before it gets overwritten below.
        _original_postgres_parse_property()
        setattr(PostgresParser, "_parse_property", _spangres_parse_property)
        setattr(PostgresParser, _HOOKS_REGISTERED_ATTR, True)


# Install the hook on sqlglot's shared PostgresParser at import time. The hook only
# changes behaviour for the Spangres dialect, and registration is guarded so that
# repeated calls are no-ops.
_register_postgres_spangres_parser_hooks()


class Spangres(Postgres):
    """Spanner PostgreSQL-compatible dialect.

    Parsing reuses the stock Postgres parser. Spanner's
    ``ROW DELETION POLICY (OLDER_THAN(<column>, INTERVAL ...))`` table property
    is handled by a hook this module installs on ``PostgresParser`` at import
    time; the hook only applies when the active dialect is Spangres. SQL
    generation is delegated to ``SpangresGenerator``.
    """

    Parser = Postgres.Parser
    Generator = SpangresGenerator