"""Google Cloud Spanner SQL dialect (GoogleSQL variant).
Extends the BigQuery dialect with Spanner-only DDL features:
`INTERLEAVE IN PARENT` for interleaved tables and `ROW DELETION POLICY`
for row-level time-to-live policies (GoogleSQL).
"""
import re
from typing import Any, Final, cast
from sqlglot import exp
from sqlglot.dialects.bigquery import BigQuery
from sqlglot.parsers.bigquery import BigQueryParser
from sqlglot.tokenizer_core import TokenType
from sqlspec.dialects.spanner._generators import SpannerGenerator
__all__ = ("Spanner",)
_SPANNER_KEYWORDS: "dict[str, TokenType]" = {}
# Register Spanner-only keywords only when the installed sqlglot exposes the
# matching TokenType member; on older sqlglot versions the keyword is skipped.
for _keyword in ("INTERLEAVE", "TTL"):
    _token = cast("TokenType | None", TokenType.__dict__.get(_keyword))
    if _token is not None:
        _SPANNER_KEYWORDS[_keyword] = _token

# NOTE(review): not referenced in this module; presumably used elsewhere — confirm.
_TTL_MIN_COMPONENTS = 2
# Names used for the synthetic exp.Property nodes produced by this module.
_ROW_DELETION_NAME = "ROW_DELETION_POLICY"
_INTERLEAVE_NAME = "INTERLEAVE_IN_PARENT"
# Attribute names stashed on BigQueryParser by the hook-registration machinery.
_ORIGINAL_PARSE_PROPERTY_ATTR: Final[str] = "_sqlspec_original_parse_property"
_HOOKS_REGISTERED_ATTR: Final[str] = "_sqlspec_spanner_hooks_registered"
_INTERLEAVE_PATTERN: Final[re.Pattern[str]] = re.compile(
r"""
\bINTERLEAVE\s+IN\s+PARENT\s+
(?P<parent>.+?)
(?:\s+ON\s+DELETE\s+(?P<on_delete>CASCADE|NO\s+ACTION))?
(?=\s+(?:ROW\s+DELETION\s+POLICY|TTL)\b|\s*$)
""",
re.IGNORECASE | re.DOTALL | re.VERBOSE,
)
def _normalize_on_delete_value(on_delete: str) -> str:
return " ".join(on_delete.upper().split())
def _build_interleave_property(parent: exp.Expr, on_delete: str | None = None) -> exp.Property:
    """Encode an INTERLEAVE IN PARENT clause as a generic Property node.

    The property's value is a Tuple holding the parent table expression and,
    when provided, the normalized ON DELETE action as a string literal.
    """
    members: list[exp.Expr] = [parent]
    if on_delete is not None:
        normalized = _normalize_on_delete_value(on_delete)
        members.append(exp.Literal.string(normalized))
    return exp.Property(this=exp.Literal.string(_INTERLEAVE_NAME), value=exp.Tuple(expressions=members))
def _normalize_interval_expression(expression: exp.Expr) -> exp.Expr:
    """Rewrite an `<expr> AS <unit>` alias into a proper Interval node.

    Anything that is not an Alias over an expression with an Identifier alias
    is returned unchanged.
    """
    if not isinstance(expression, exp.Alias):
        return expression
    unit = expression.args.get("alias")
    inner = expression.this
    if not (isinstance(unit, exp.Identifier) and isinstance(inner, exp.Expr)):
        return expression
    return exp.Interval(this=inner.copy(), unit=unit.copy())
def _extract_interleave_property(sql: str) -> tuple[str, exp.Property | None]:
    """Strip a textual INTERLEAVE IN PARENT clause from raw SQL.

    Returns the SQL with the clause removed plus the clause encoded as a
    Property, or the original SQL and None when no clause is present.
    """
    found = _INTERLEAVE_PATTERN.search(sql)
    if not found:
        return sql, None
    parent_table = exp.to_table(found.group("parent").strip())
    interleave_property = _build_interleave_property(parent_table, found.group("on_delete"))
    before = sql[: found.start()]
    after = sql[found.end() :]
    return f"{before} {after}".strip(), interleave_property
def _attach_create_property(create: exp.Create, property_expression: exp.Property) -> exp.Create:
    """Prepend a property to a CREATE expression, creating Properties if missing."""
    existing = create.args.get("properties")
    if not isinstance(existing, exp.Properties):
        create.set("properties", exp.Properties(expressions=[property_expression]))
        return create
    existing.set("expressions", [property_expression, *existing.expressions])
    return create
def _is_spanner_dialect(parser: Any) -> bool:
dialect = getattr(parser, "dialect", None)
return dialect is not None and dialect.__class__.__name__ == "Spanner"
def _original_bigquery_parse_property() -> Any:
    """Return BigQueryParser's pristine ``_parse_property``, stashing it on first use.

    The original method is cached on the class under a private attribute so the
    Spanner hook can delegate to it even after the class attribute is replaced.
    """
    cached = getattr(BigQueryParser, _ORIGINAL_PARSE_PROPERTY_ATTR, None)
    if not callable(cached):
        cached = BigQueryParser._parse_property
        setattr(BigQueryParser, _ORIGINAL_PARSE_PROPERTY_ATTR, cached)
    return cached
def _spanner_parse_property(self: Any) -> exp.Expr:
    """Parse Spanner-only table properties, delegating everything else to BigQuery.

    Installed onto ``BigQueryParser`` as ``_parse_property`` by
    ``_register_bigquery_spanner_parser_hooks``. When the active dialect is
    Spanner, three clauses are recognized and each is encoded as a generic
    ``exp.Property`` whose value is an ``exp.Tuple``; any other input — and all
    input under a non-Spanner dialect — falls through to the original
    BigQuery implementation.
    """
    if _is_spanner_dialect(self):
        # INTERLEAVE IN PARENT <table> [ON DELETE CASCADE | ON DELETE NO ACTION]
        if self._match_text_seq("INTERLEAVE", "IN", "PARENT"):
            parent = cast("exp.Expr", self._parse_table(schema=True, is_db_reference=True))
            on_delete: str | None = None
            if self._match_text_seq("ON", "DELETE"):
                if self._match_text_seq("CASCADE"):
                    on_delete = "CASCADE"
                elif self._match_text_seq("NO", "ACTION"):
                    on_delete = "NO ACTION"
            return _build_interleave_property(parent, on_delete)
        # ROW DELETION POLICY (OLDER_THAN(<column>, INTERVAL <expr>))
        if self._match_text_seq("ROW", "DELETION", "POLICY"):
            self._match(TokenType.L_PAREN)
            self._match_text_seq("OLDER_THAN")
            self._match(TokenType.L_PAREN)
            column = cast("exp.Expr", self._parse_id_var())
            self._match(TokenType.COMMA)
            self._match_text_seq("INTERVAL")
            # The interval may come back as an Alias (`<expr> AS <unit>`); normalize it.
            interval = _normalize_interval_expression(cast("exp.Expr", self._parse_expression()))
            self._match(TokenType.R_PAREN)
            self._match(TokenType.R_PAREN)
            return exp.Property(
                this=exp.Literal.string(_ROW_DELETION_NAME), value=exp.Tuple(expressions=[column, interval])
            )
        # TTL INTERVAL <expr> ON <column>
        # NOTE(review): presumably an alternate spelling of the row-deletion
        # policy accepted by this project — confirm against callers/tests.
        if self._match_text_seq("TTL"):
            self._match_text_seq("INTERVAL")
            interval = _normalize_interval_expression(cast("exp.Expr", self._parse_expression()))
            self._match_text_seq("ON")
            column = cast("exp.Expr", self._parse_id_var())
            return exp.Property(this=exp.Literal.string("TTL"), value=exp.Tuple(expressions=[interval, column]))
    # Not a Spanner-only clause (or not the Spanner dialect): use the original parser.
    return cast("exp.Expr", _original_bigquery_parse_property()(self))
def _register_bigquery_spanner_parser_hooks() -> None:
    """Install the Spanner-aware ``_parse_property`` hook on BigQueryParser.

    Idempotent: a sentinel attribute on the class guards against double
    registration. The original method is stashed first so the hook can
    delegate to it.
    """
    if getattr(BigQueryParser, _HOOKS_REGISTERED_ATTR, False):
        return
    # Cache the pristine implementation before overriding the class attribute.
    _original_bigquery_parse_property()
    BigQueryParser._parse_property = _spanner_parse_property
    setattr(BigQueryParser, _HOOKS_REGISTERED_ATTR, True)
class SpannerTokenizer(BigQuery.Tokenizer):
    """BigQuery tokenizer extended with Spanner-only keywords (when sqlglot supports them)."""

    KEYWORDS = BigQuery.Tokenizer.KEYWORDS | _SPANNER_KEYWORDS
# Install the parser hook at import time so BigQueryParser instances bound to
# the Spanner dialect understand the Spanner-only property grammar.
_register_bigquery_spanner_parser_hooks()
class Spanner(BigQuery):
    """Google Cloud Spanner SQL dialect.

    Reuses the BigQuery pipeline with a Spanner tokenizer (extra keywords) and
    a Spanner generator, plus a parse-time repair step for INTERLEAVE IN PARENT
    clauses that sqlglot cannot parse natively.
    """

    Tokenizer = SpannerTokenizer
    Parser = BigQuery.Parser
    Generator = SpannerGenerator

    def parse(self, sql: str, **opts: Any) -> list[exp.Expr | None]:
        """Parse SQL, repairing CREATE TABLE statements that fell back to Command.

        When parsing degrades the single statement to a bare ``exp.Command``
        (sqlglot's fallback for unparseable DDL) and the text contains an
        INTERLEAVE IN PARENT clause, strip the clause, re-parse the remainder,
        and re-attach the clause as a property on the resulting Create node.

        Args:
            sql: The SQL text to parse.
            **opts: Forwarded to the base dialect's ``parse``.

        Returns:
            The parsed expression list (repaired when possible).
        """
        expressions = super().parse(sql, **opts)
        # Only attempt the repair for a single statement that fell back to Command.
        if len(expressions) != 1 or not isinstance(expressions[0], exp.Command):
            return expressions
        repaired_sql, interleave_property = _extract_interleave_property(sql)
        if interleave_property is None:
            return expressions
        reparsed = BigQuery.parse(self, repaired_sql, **opts)
        if len(reparsed) != 1 or not isinstance(reparsed[0], exp.Create):
            # The repair did not yield a clean CREATE; keep the original fallback.
            return expressions
        return [_attach_create_property(reparsed[0], interleave_property)]