Source code for sqlspec.dialects.spanner._spanner

"""Google Cloud Spanner SQL dialect (GoogleSQL variant).

Extends the BigQuery dialect with Spanner-only DDL features:
`INTERLEAVE IN PARENT` for interleaved tables and `ROW DELETION POLICY`
for row-level time-to-live policies (GoogleSQL).
"""

import re
from typing import Any, Final, cast

from sqlglot import exp
from sqlglot.dialects.bigquery import BigQuery
from sqlglot.errors import ParseError, TokenError
from sqlglot.parsers.bigquery import BigQueryParser
from sqlglot.tokenizer_core import TokenType

from sqlspec.dialects.spanner._generators import SpannerGenerator

__all__ = ("Spanner",)

# Spanner-only keywords merged into the BigQuery tokenizer. A keyword is added only
# when the installed sqlglot's TokenType defines a member of that name.
interleave_token = cast("TokenType | None", TokenType.__dict__.get("INTERLEAVE"))
ttl_token = cast("TokenType | None", TokenType.__dict__.get("TTL"))
_SPANNER_KEYWORDS: "dict[str, TokenType]" = {
    keyword: token
    for keyword, token in (("INTERLEAVE", interleave_token), ("TTL", ttl_token))
    if token is not None
}

# NOTE(review): not referenced in this file; presumably consumed elsewhere (e.g. the generator) — confirm.
_TTL_MIN_COMPONENTS = 2
# Property keys used for the Spanner-only CREATE TABLE clauses.
_ROW_DELETION_NAME = "ROW_DELETION_POLICY"
_INTERLEAVE_NAME = "INTERLEAVE_IN_PARENT"
# Marker attributes stored on BigQueryParser: the saved original _parse_property and the
# "hooks already installed" flag.
_ORIGINAL_PARSE_PROPERTY_ATTR: Final[str] = "_sqlspec_original_parse_property"
_HOOKS_REGISTERED_ATTR: Final[str] = "_sqlspec_spanner_hooks_registered"
# Locates an ``INTERLEAVE IN PARENT <parent> [ON DELETE CASCADE | NO ACTION]`` clause in raw SQL.
# The clause ends where a ROW DELETION POLICY / TTL clause begins, whether separated by
# whitespace or by a comma (as Spanner DDL writes it), or at the end of the statement,
# optionally followed by ``;``. Without accepting the comma/semicolon, the lazy ``parent``
# group would swallow ``ON DELETE ...`` plus the trailing punctuation.
_INTERLEAVE_PATTERN: Final[re.Pattern[str]] = re.compile(
    r"""
    \bINTERLEAVE\s+IN\s+PARENT\s+
    (?P<parent>.+?)
    (?:\s+ON\s+DELETE\s+(?P<on_delete>CASCADE|NO\s+ACTION))?
    (?=(?:\s*,\s*|\s+)(?:ROW\s+DELETION\s+POLICY|TTL)\b|\s*;?\s*$)
    """,
    re.IGNORECASE | re.DOTALL | re.VERBOSE,
)


def _normalize_on_delete_value(on_delete: str) -> str:
    """Upper-case an ON DELETE action and collapse internal whitespace to single spaces.

    For example ``"no   action"`` becomes ``"NO ACTION"``.
    """
    return " ".join(word.upper() for word in on_delete.split())


def _build_interleave_property(parent: exp.Expr, on_delete: str | None = None) -> exp.Property:
    """Build the property node representing ``INTERLEAVE IN PARENT``.

    Args:
        parent: The parent table expression.
        on_delete: Optional ON DELETE action; normalized with ``_normalize_on_delete_value``.

    Returns:
        ``Property(this=_INTERLEAVE_NAME, value=Tuple(parent[, on_delete]))``.
    """
    if on_delete is None:
        payload: list[exp.Expr] = [parent]
    else:
        payload = [parent, exp.Literal.string(_normalize_on_delete_value(on_delete))]
    key = exp.Literal.string(_INTERLEAVE_NAME)
    return exp.Property(this=key, value=exp.Tuple(expressions=payload))


def _normalize_interval_expression(expression: exp.Expr) -> exp.Expr:
    """Turn an ``<value> AS <unit>`` alias into an ``exp.Interval``.

    The property parsers consume the ``INTERVAL`` keyword themselves and then parse
    the remainder. Presumably that is why something like ``30 DAY`` arrives here as
    an alias rather than an interval (confirm with the parser in use).

    Any other expression, or an alias whose alias is not an identifier, is returned unchanged.
    """
    if not isinstance(expression, exp.Alias):
        return expression
    unit = expression.args.get("alias")
    value = expression.this
    if isinstance(unit, exp.Identifier) and isinstance(value, exp.Expr):
        return exp.Interval(this=value.copy(), unit=unit.copy())
    return expression


def _extract_interleave_property(sql: str) -> tuple[str, exp.Property | None]:
    """Split an ``INTERLEAVE IN PARENT`` clause out of raw CREATE TABLE text.

    ``Spanner.parse`` uses this when sqlglot fell back to ``exp.Command``. The clause is
    removed from the SQL so the remainder can be re-parsed, and it is returned
    separately as a property expression.

    Args:
        sql: The original statement text.

    Returns:
        ``(repaired_sql, property)``. When no clause is found, ``sql`` is returned
        unchanged together with ``None``.
    """
    match = _INTERLEAVE_PATTERN.search(sql)
    if match is None:
        return sql, None

    # Parse the parent name with BigQuery (GoogleSQL) rules so backtick-quoted and
    # dotted identifiers are read the way Spanner writes them.
    parent = exp.to_table(match.group("parent").strip(), dialect="bigquery")
    interleave_property = _build_interleave_property(parent, match.group("on_delete"))

    head = sql[: match.start()].rstrip()
    tail = sql[match.end() :].lstrip()
    # Spanner DDL separates the clause from ``PRIMARY KEY (...)`` with a comma. Once the
    # clause is gone, that comma would dangle at the end of the statement or be followed
    # by another separator (``, ,`` / ``, ;``), so drop it in those cases.
    if head.endswith(",") and (not tail or tail.startswith((",", ";"))):
        head = head[:-1].rstrip()
    repaired_sql = f"{head} {tail}".strip()
    return repaired_sql, interleave_property


def _attach_create_property(create: exp.Create, property_expression: exp.Property) -> exp.Create:
    """Prepend *property_expression* to the properties of *create*, mutating it in place.

    If *create* has no ``exp.Properties`` node yet, a new one holding only
    *property_expression* is set.

    Returns:
        The same *create* object.
    """
    current = create.args.get("properties")
    if not isinstance(current, exp.Properties):
        create.set("properties", exp.Properties(expressions=[property_expression]))
    else:
        current.set("expressions", [property_expression, *current.expressions])
    return create


def _is_spanner_dialect(parser: Any) -> bool:
    """Return True when *parser* belongs to the ``Spanner`` dialect or a subclass of it.

    The check stays name-based, so it does not reference the ``Spanner`` class defined
    later in this module. It walks the dialect's MRO, so subclasses of ``Spanner`` keep
    the Spanner-only property parsing instead of silently falling back to plain BigQuery.

    Args:
        parser: A sqlglot parser instance; only its ``dialect`` attribute is read.

    Returns:
        ``True`` if any class in ``type(parser.dialect).__mro__`` is named ``"Spanner"``,
        else ``False`` (including when the parser has no dialect).
    """
    dialect = getattr(parser, "dialect", None)
    if dialect is None:
        return False
    return any(klass.__name__ == "Spanner" for klass in type(dialect).__mro__)


def _original_bigquery_parse_property() -> Any:
    """Return BigQuery's unpatched ``_parse_property``, caching it on first use.

    On first use the unpatched method is stored on ``BigQueryParser`` under
    ``_ORIGINAL_PARSE_PROPERTY_ATTR``. Later calls return that stored callable, even
    after ``_parse_property`` itself has been replaced by the Spanner hook.
    """
    saved = getattr(BigQueryParser, _ORIGINAL_PARSE_PROPERTY_ATTR, None)
    if not callable(saved):
        saved = BigQueryParser._parse_property
        setattr(BigQueryParser, _ORIGINAL_PARSE_PROPERTY_ATTR, saved)
    return saved


def _spanner_parse_property(self: Any) -> exp.Expr:
    """Parse one CREATE TABLE property, adding Spanner-only clauses on top of BigQuery.

    ``_register_bigquery_spanner_parser_hooks`` installs this as
    ``BigQueryParser._parse_property``. The Spanner branches run only when
    ``_is_spanner_dialect(self)`` is true. Every other case falls through to the saved
    original BigQuery implementation.

    Recognized clauses:

    * ``INTERLEAVE IN PARENT <table> [ON DELETE CASCADE | ON DELETE NO ACTION]``
      -> property keyed ``_INTERLEAVE_NAME`` (built by ``_build_interleave_property``).
    * ``ROW DELETION POLICY (OLDER_THAN(<column>, INTERVAL <expr>))``
      -> ``Property(_ROW_DELETION_NAME, Tuple(column, interval))``.
    * ``TTL INTERVAL <expr> ON <column>``
      -> ``Property("TTL", Tuple(interval, column))``.

    Args:
        self: The active sqlglot parser (this function is bound as a parser method).

    Returns:
        The parsed Spanner property, or whatever the original BigQuery
        ``_parse_property`` returns for any other input.
    """
    if _is_spanner_dialect(self):
        if self._match_text_seq("INTERLEAVE", "IN", "PARENT"):
            # NOTE(review): is_db_reference=True may make sqlglot place a bare name in the
            # Table's ``db`` slot rather than ``this``, unlike the exp.to_table path used in
            # _extract_interleave_property; confirm the generator handles both shapes.
            parent = cast("exp.Expr", self._parse_table(schema=True, is_db_reference=True))
            on_delete: str | None = None

            if self._match_text_seq("ON", "DELETE"):
                # Only CASCADE and NO ACTION are recognized; any other action raises no
                # error here and leaves on_delete as None.
                if self._match_text_seq("CASCADE"):
                    on_delete = "CASCADE"
                elif self._match_text_seq("NO", "ACTION"):
                    on_delete = "NO ACTION"

            return _build_interleave_property(parent, on_delete)

        if self._match_text_seq("ROW", "DELETION", "POLICY"):
            # Expected shape: ( OLDER_THAN ( <column> , INTERVAL <expr> ) ).
            # The punctuation matches are not checked, so missing parens/commas are
            # tolerated rather than reported.
            self._match(TokenType.L_PAREN)
            self._match_text_seq("OLDER_THAN")
            self._match(TokenType.L_PAREN)
            column = cast("exp.Expr", self._parse_id_var())
            self._match(TokenType.COMMA)
            # The INTERVAL keyword is consumed here, so the rest is parsed as a plain
            # expression and then converted by _normalize_interval_expression.
            self._match_text_seq("INTERVAL")
            interval = _normalize_interval_expression(cast("exp.Expr", self._parse_expression()))
            self._match(TokenType.R_PAREN)
            self._match(TokenType.R_PAREN)

            return exp.Property(
                this=exp.Literal.string(_ROW_DELETION_NAME), value=exp.Tuple(expressions=[column, interval])
            )

        if self._match_text_seq("TTL"):
            # Expected shape: TTL INTERVAL <expr> ON <column>. Note the tuple order is
            # (interval, column), the reverse of the ROW DELETION POLICY property.
            self._match_text_seq("INTERVAL")
            interval = _normalize_interval_expression(cast("exp.Expr", self._parse_expression()))
            self._match_text_seq("ON")
            column = cast("exp.Expr", self._parse_id_var())

            return exp.Property(this=exp.Literal.string("TTL"), value=exp.Tuple(expressions=[interval, column]))

    # Not a Spanner dialect, or not a Spanner-only clause: defer to BigQuery's original parser.
    return cast("exp.Expr", _original_bigquery_parse_property()(self))


def _register_bigquery_spanner_parser_hooks() -> None:
    """Patch ``BigQueryParser._parse_property`` with the Spanner-aware version, once.

    A marker attribute on ``BigQueryParser`` makes repeat calls no-ops. The original
    method is snapshotted before patching, so ``_spanner_parse_property`` can delegate
    to it.
    """
    if not getattr(BigQueryParser, _HOOKS_REGISTERED_ATTR, False):
        # Snapshot first: after the next line, _parse_property would be our own hook.
        _original_bigquery_parse_property()
        setattr(BigQueryParser, "_parse_property", _spanner_parse_property)
        setattr(BigQueryParser, _HOOKS_REGISTERED_ATTR, True)


class SpannerTokenizer(BigQuery.Tokenizer):
    """BigQuery tokenizer extended with the Spanner keywords in ``_SPANNER_KEYWORDS``.

    ``_SPANNER_KEYWORDS`` only contains keywords whose TokenType member exists in the
    installed sqlglot, so this may add nothing.
    """

    KEYWORDS = dict(BigQuery.Tokenizer.KEYWORDS, **_SPANNER_KEYWORDS)


# Install the Spanner property-parsing hook on BigQueryParser at import time. Repeated
# calls are no-ops because the function checks a marker attribute on BigQueryParser.
_register_bigquery_spanner_parser_hooks()


class Spanner(BigQuery):
    """Google Cloud Spanner SQL dialect.

    Builds on BigQuery with:

    * a tokenizer carrying the Spanner keywords;
    * BigQuery's parser, which picks up Spanner DDL clauses through the
      ``_parse_property`` hook installed on ``BigQueryParser`` at import time;
    * ``SpannerGenerator`` for output.
    """

    Tokenizer = SpannerTokenizer
    # NOTE(review): assumed to be (or derive from) BigQueryParser, which carries the hook.
    Parser = BigQuery.Parser
    Generator = SpannerGenerator

    def parse(self, sql: str, **opts: Any) -> list[exp.Expr | None]:
        """Parse *sql*, repairing CREATE TABLE statements that sqlglot falls back to Command for.

        When a single statement comes back as ``exp.Command``, the method first cuts the
        ``INTERLEAVE IN PARENT`` clause out of the text. It then re-parses the remainder
        with ``BigQuery.parse`` and re-attaches the clause as the first CREATE property.

        The repair is best effort. The original parse result is returned unchanged when:

        * no clause is found;
        * re-parsing does not yield exactly one ``exp.Create``;
        * the repair raises a sqlglot ``ParseError``/``TokenError``.

        Args:
            sql: SQL text to parse.
            **opts: Options forwarded to ``BigQuery.parse``.

        Returns:
            The parsed expressions.
        """
        expressions = super().parse(sql, **opts)
        if len(expressions) != 1 or not isinstance(expressions[0], exp.Command):
            return expressions

        try:
            repaired_sql, interleave_property = _extract_interleave_property(sql)
            if interleave_property is None:
                return expressions
            # Call BigQuery.parse directly to avoid re-entering this repair logic.
            reparsed = BigQuery.parse(self, repaired_sql, **opts)
        except (ParseError, TokenError):
            # sqlglot already produced a valid Command; a failed repair must not turn
            # that successful parse into an exception.
            return expressions

        if len(reparsed) != 1 or not isinstance(reparsed[0], exp.Create):
            return expressions
        return [_attach_create_property(reparsed[0], interleave_property)]