# pyright: reportPrivateUsage=false, reportPrivateImportUsage=false
"""JOIN operation mixins.
Provides mixins for JOIN operations in SELECT statements.
"""
from typing import TYPE_CHECKING, Any, Union, cast, final
from mypy_extensions import trait
from sqlglot import exp
from typing_extensions import Self
from sqlspec.builder._base import BuiltQuery, QueryBuilder
from sqlspec.builder._parsing_utils import parse_table_expression
from sqlspec.exceptions import SQLBuilderError
from sqlspec.utils.type_guards import has_expression_and_parameters, has_expression_and_sql, has_parameter_builder
if TYPE_CHECKING:
from sqlspec.core import SQL
from sqlspec.protocols import HasParameterBuilderProtocol, SQLBuilderProtocol
# Public names exported by this module.
__all__ = ("JoinBuilder", "JoinClauseMixin", "create_join_builder")
def _handle_sql_object_condition(on: Any, builder: "SQLBuilderProtocol") -> exp.Expr:
    """Convert a SQL-like object into a JOIN condition expression.

    Parameters carried by ``on`` are registered on ``builder`` under their
    original names. If ``on`` exposes a non-``None`` ``expression`` it is
    returned directly; otherwise the object's ``raw_sql`` text (falling back
    to ``sql`` when ``raw_sql`` is absent or ``None``) is parsed using the
    builder's dialect.

    Args:
        on: Object exposing ``expression``/``sql`` style attributes.
        builder: Builder that receives the parameters and supplies the dialect.

    Returns:
        The expression to use as the JOIN's ON condition.
    """
    if has_expression_and_parameters(on):
        expression = on.expression
        for name, value in on.parameters.items():
            builder.add_parameter(value, name=name)
        if expression is not None:
            return cast("exp.Expr", expression)

    # No usable pre-built expression: fall back to the textual SQL.
    sql_text = getattr(on, "raw_sql", None)
    if sql_text is None:
        sql_text = on.sql  # pyright: ignore[reportAttributeAccessIssue]
    condition = exp.maybe_parse(sql_text, dialect=builder.dialect)
    if condition is None:
        return exp.condition(sql_text)
    return condition
def _parse_join_condition(builder: "SQLBuilderProtocol", on: Union[str, exp.Expr, "SQL"] | None) -> exp.Expr | None:
    """Normalize the ``on`` argument of a JOIN into an expression.

    Args:
        builder: Builder used when ``on`` is a SQL-like object (parameter
            registration and dialect).
        on: ``None``, a SQL string, a SQL-like object, or an expression.

    Returns:
        ``None`` when no condition was given, otherwise the condition
        expression. Strings (and the ``str()`` of unrecognized objects) are
        passed to ``exp.condition``; expressions are returned unchanged.
    """
    if on is None:
        return None

    if isinstance(on, str):
        condition_text = on
    elif has_expression_and_sql(on):
        return _handle_sql_object_condition(on, builder)
    elif isinstance(on, exp.Expr):
        return on
    else:
        # Last resort: rely on the object's string form.
        condition_text = str(on)
    return exp.condition(condition_text)
def _handle_query_builder_table(table: Any, alias: str | None, builder: "SQLBuilderProtocol") -> exp.Expr:
    """Turn a nested builder into a parenthesized (optionally aliased) subquery.

    ``QueryBuilder`` instances contribute their final expression directly;
    any other builder is built and its SQL text parsed with the outer
    builder's dialect. The nested builder's parameters are then registered on
    the outer ``builder`` under their original names.

    Args:
        table: Builder object exposing ``parameters`` and ``build()``.
        alias: Optional alias applied to the parenthesized subquery.
        builder: Outer builder receiving the parameters.

    Returns:
        The parenthesized subquery, wrapped in an alias when one is given.
    """
    source = cast("HasParameterBuilderProtocol", table)
    # Read the parameters before building, exactly as the build step may rely on.
    source_parameters = source.parameters

    subquery: exp.Expr
    if not isinstance(table, QueryBuilder):
        built = source.build()
        sql_text = built.sql if isinstance(built, BuiltQuery) else str(built)
        subquery = exp.maybe_parse(sql_text, dialect=builder.dialect) or exp.convert(sql_text)
    else:
        subquery = table._build_final_expression(copy=True)

    for name, value in source_parameters.items() if source_parameters else ():
        builder.add_parameter(value, name=name)

    wrapped = exp.paren(subquery)
    if alias:
        return exp.alias_(wrapped, alias)
    return wrapped
def _parse_join_table(builder: "SQLBuilderProtocol", table: str | exp.Expr | Any, alias: str | None) -> exp.Expr:
    """Resolve the JOIN target into an expression.

    Args:
        builder: Builder supplying the dialect and receiving sub-builder parameters.
        table: Table name, nested builder, or ready-made expression.
        alias: Alias for the target. It is forwarded for string tables and
            nested builders only; ready-made expressions are returned
            unchanged, so ``alias`` is not applied to them here.

    Returns:
        The expression to place on the right-hand side of the JOIN.
    """
    if isinstance(table, str):
        return parse_table_expression(table, alias, dialect=builder.dialect)
    if not has_parameter_builder(table):
        # Expressions pass through; anything else is trusted to behave like one.
        return table if isinstance(table, exp.Expr) else cast("exp.Expr", table)
    return _handle_query_builder_table(table, alias, builder)
def _create_join_expression(table_expr: exp.Expr, on_expr: exp.Expr | None, join_type: str) -> exp.Join:
    """Build the ``exp.Join`` node for a given join type.

    Args:
        table_expr: Expression being joined.
        on_expr: Optional ON condition; ignored for CROSS joins.
        join_type: One of INNER, LEFT, RIGHT, FULL or CROSS (case-insensitive).

    Returns:
        The constructed JOIN expression.

    Raises:
        SQLBuilderError: If ``join_type`` is not one of the supported values.
    """
    normalized = join_type.upper()
    if normalized == "CROSS":
        # CROSS joins never carry an ON condition.
        return exp.Join(this=table_expr, kind="CROSS")

    modifiers_by_type: dict[str, dict[str, str]] = {
        "INNER": {},
        "LEFT": {"side": "LEFT"},
        "RIGHT": {"side": "RIGHT"},
        "FULL": {"side": "FULL", "kind": "OUTER"},
    }
    modifiers = modifiers_by_type.get(normalized)
    if modifiers is None:
        msg = f"Unsupported join type: {join_type}"
        raise SQLBuilderError(msg)
    return exp.Join(this=table_expr, on=on_expr, **modifiers)
def _apply_lateral_modifier(join_expr: exp.Join) -> None:
    """Rewrite a JOIN node in place so it becomes a LATERAL join.

    The existing ``kind``/``side`` arguments decide the new ``kind``:
    CROSS becomes ``CROSS LATERAL``; FULL OUTER keeps its side and gets
    ``OUTER LATERAL``; any other side is folded into ``"<side> LATERAL"``
    with the side cleared; a plain join becomes ``LATERAL``.

    Args:
        join_expr: JOIN expression to mutate.
    """
    kind = join_expr.args.get("kind")
    side = join_expr.args.get("side")

    if kind == "CROSS":
        join_expr.set("kind", "CROSS LATERAL")
        return
    if kind == "OUTER" and side == "FULL":
        join_expr.set("side", "FULL")
        join_expr.set("kind", "OUTER LATERAL")
        return
    if side:
        # Fold the side into the kind so it precedes the LATERAL keyword.
        join_expr.set("kind", f"{side} LATERAL")
        join_expr.set("side", None)
        return
    join_expr.set("kind", "LATERAL")
def _attach_as_of_version(
    table_expr: exp.Expr, alias: str | None, as_of: Any, as_of_type: str | None = None
) -> exp.Expr:
    """Return a copy of ``table_expr`` carrying an ``AS OF`` version clause.

    The input is copied first so the caller's expression is not mutated. An
    alias already present on the expression (either an ``exp.Alias`` wrapper
    or a table-level alias) takes precedence over ``alias``; it is detached,
    the version clause is set on the inner table, and the alias is re-applied
    around the result.

    Args:
        table_expr: Table (or aliased table) expression to version.
        alias: Alias to use when the expression does not carry its own.
        as_of: Value converted with ``exp.convert`` and used as the version point.
        as_of_type: Version type; ``"TIMESTAMP"`` is used when falsy.

    Returns:
        The versioned expression, aliased when an alias is available.
    """
    target = table_expr.copy()
    resolved_alias = alias

    if isinstance(target, exp.Alias):
        resolved_alias = target.alias
        target = target.this
    elif isinstance(target, exp.Table):
        existing_alias = target.args.get("alias")
        if existing_alias is not None:
            resolved_alias = existing_alias.this
            target.set("alias", None)

    target.set(
        "version",
        exp.Version(this=as_of_type or "TIMESTAMP", kind="AS OF", expression=exp.convert(as_of)),
    )
    if resolved_alias:
        return exp.alias_(target, resolved_alias)
    return target
def build_join_clause(
    builder: "SQLBuilderProtocol",
    table: str | exp.Expr | Any,
    on: Union[str, exp.Expr, "SQL"] | None,
    alias: str | None,
    join_type: str,
    *,
    lateral: bool = False,
) -> exp.Join:
    """Assemble a complete JOIN expression from its loosely typed inputs.

    Args:
        builder: Builder supplying the dialect and collecting parameters.
        table: Join target (name, expression or nested builder).
        on: Optional join condition.
        alias: Optional alias for the join target.
        join_type: INNER, LEFT, RIGHT, FULL or CROSS (case-insensitive).
        lateral: When true, the resulting JOIN is rewritten as LATERAL.

    Returns:
        The JOIN expression, not yet attached to any SELECT.

    Raises:
        SQLBuilderError: If ``join_type`` is unsupported.
    """
    # Arguments are evaluated left to right: table first, then condition.
    join_expr = _create_join_expression(
        _parse_join_table(builder, table, alias), _parse_join_condition(builder, on), join_type
    )
    if lateral:
        _apply_lateral_modifier(join_expr)
    return join_expr
@trait
class JoinClauseMixin:
    """Mixin providing JOIN clause methods for SELECT builders.

    ``_expression`` is declared here and supplied by the concrete builder.
    When it is still ``None`` at the time a join is added, an empty ``SELECT``
    is created first; any non-``SELECT`` expression is rejected with
    ``SQLBuilderError``.
    """

    __slots__ = ()

    _expression: exp.Expr | None

    def join(
        self,
        table: str | exp.Expr | Any,
        on: Union[str, exp.Expr, "SQL"] | None = None,
        alias: str | None = None,
        join_type: str = "INNER",
        lateral: bool = False,
        as_of: Any | None = None,
        as_of_type: str | None = None,
    ) -> Self:
        """Add a JOIN clause to the SELECT expression.

        ``as_of`` attaches a temporal version clause by copying the inner table, honoring existing aliases, and updating the JOIN target without mutating shared expressions.

        Args:
            table: Table name, expression, nested builder, or a ready-made
                ``exp.Join``. A ready-made join is appended unchanged and all
                other arguments are ignored.
            on: Optional join condition.
            alias: Optional alias for the joined table/subquery.
            join_type: INNER, LEFT, RIGHT, FULL or CROSS (case-insensitive).
            lateral: Whether to emit a LATERAL join.
            as_of: Optional version point for an ``AS OF`` clause.
            as_of_type: Version type for ``as_of``; ``TIMESTAMP`` when omitted.

        Returns:
            Self for method chaining

        Raises:
            SQLBuilderError: If the current expression is not a SELECT or
                ``join_type`` is unsupported.
        """
        builder = cast("SQLBuilderProtocol", self)
        if builder._expression is None:
            builder._expression = exp.Select()
        if not isinstance(builder._expression, exp.Select):
            msg = "JOIN clause is only supported for SELECT statements."
            raise SQLBuilderError(msg)

        if isinstance(table, exp.Join):
            # Pre-built joins are attached verbatim.
            builder._expression = builder._expression.join(table, copy=False)
            return cast("Self", builder)

        join_expr = build_join_clause(builder, table, on, alias, join_type, lateral=lateral)
        if as_of is not None:
            join_expr.set("this", _attach_as_of_version(join_expr.this, alias, as_of, as_of_type))
        builder._expression = builder._expression.join(join_expr, copy=False)
        return cast("Self", builder)

    def inner_join(
        self,
        table: str | exp.Expr | Any,
        on: Union[str, exp.Expr, "SQL"],
        alias: str | None = None,
        as_of: Any | None = None,
        as_of_type: str | None = None,
    ) -> Self:
        """Add an INNER JOIN.

        Args:
            table: Table, subquery, or expression to join
            on: Join condition
            alias: Optional alias for the joined table/subquery
            as_of: Optional version point for an ``AS OF`` clause
            as_of_type: Version type for ``as_of``; ``TIMESTAMP`` when omitted

        Returns:
            Self for method chaining
        """
        return self.join(table, on, alias, "INNER", as_of=as_of, as_of_type=as_of_type)

    def left_join(
        self,
        table: str | exp.Expr | Any,
        on: Union[str, exp.Expr, "SQL"],
        alias: str | None = None,
        as_of: Any | None = None,
        as_of_type: str | None = None,
    ) -> Self:
        """Add a LEFT JOIN.

        Args:
            table: Table, subquery, or expression to join
            on: Join condition
            alias: Optional alias for the joined table/subquery
            as_of: Optional version point for an ``AS OF`` clause
            as_of_type: Version type for ``as_of``; ``TIMESTAMP`` when omitted

        Returns:
            Self for method chaining
        """
        return self.join(table, on, alias, "LEFT", as_of=as_of, as_of_type=as_of_type)

    def right_join(
        self,
        table: str | exp.Expr | Any,
        on: Union[str, exp.Expr, "SQL"],
        alias: str | None = None,
        as_of: Any | None = None,
        as_of_type: str | None = None,
    ) -> Self:
        """Add a RIGHT JOIN.

        Args:
            table: Table, subquery, or expression to join
            on: Join condition
            alias: Optional alias for the joined table/subquery
            as_of: Optional version point for an ``AS OF`` clause
            as_of_type: Version type for ``as_of``; ``TIMESTAMP`` when omitted

        Returns:
            Self for method chaining
        """
        return self.join(table, on, alias, "RIGHT", as_of=as_of, as_of_type=as_of_type)

    def full_join(
        self,
        table: str | exp.Expr | Any,
        on: Union[str, exp.Expr, "SQL"],
        alias: str | None = None,
        as_of: Any | None = None,
        as_of_type: str | None = None,
    ) -> Self:
        """Add a FULL OUTER JOIN.

        Args:
            table: Table, subquery, or expression to join
            on: Join condition
            alias: Optional alias for the joined table/subquery
            as_of: Optional version point for an ``AS OF`` clause
            as_of_type: Version type for ``as_of``; ``TIMESTAMP`` when omitted

        Returns:
            Self for method chaining
        """
        return self.join(table, on, alias, "FULL", as_of=as_of, as_of_type=as_of_type)

    def cross_join(
        self,
        table: str | exp.Expr | Any,
        alias: str | None = None,
        as_of: Any | None = None,
        as_of_type: str | None = None,
    ) -> Self:
        """Add a CROSS JOIN (no ON condition).

        Args:
            table: Table, subquery, or expression to join
            alias: Optional alias for the joined table/subquery
            as_of: Optional version point for an ``AS OF`` clause
            as_of_type: Version type for ``as_of``; ``TIMESTAMP`` when omitted

        Returns:
            Self for method chaining

        Raises:
            SQLBuilderError: If the current expression is not a SELECT.
        """
        builder = cast("SQLBuilderProtocol", self)
        if builder._expression is None:
            builder._expression = exp.Select()
        if not isinstance(builder._expression, exp.Select):
            msg = "Cannot add cross join to a non-SELECT expression."
            raise SQLBuilderError(msg)

        table_expr = _parse_join_table(builder, table, alias)
        if as_of is not None:
            table_expr = _attach_as_of_version(table_expr, alias, as_of, as_of_type)
        join_expr = exp.Join(this=table_expr, kind="CROSS")
        builder._expression = builder._expression.join(join_expr, copy=False)
        return cast("Self", builder)

    def lateral_join(
        self, table: str | exp.Expr | Any, on: Union[str, exp.Expr, "SQL"] | None = None, alias: str | None = None
    ) -> Self:
        """Create a LATERAL JOIN.

        Args:
            table: Table, subquery, or table function to join
            on: Optional join condition (for LATERAL JOINs with ON clause)
            alias: Optional alias for the joined table/subquery

        Returns:
            Self for method chaining
        """
        return self.join(table, on=on, alias=alias, join_type="INNER", lateral=True)

    def left_lateral_join(
        self, table: str | exp.Expr | Any, on: Union[str, exp.Expr, "SQL"] | None = None, alias: str | None = None
    ) -> Self:
        """Create a LEFT LATERAL JOIN.

        Args:
            table: Table, subquery, or table function to join
            on: Optional join condition
            alias: Optional alias for the joined table/subquery

        Returns:
            Self for method chaining
        """
        return self.join(table, on=on, alias=alias, join_type="LEFT", lateral=True)

    def cross_lateral_join(self, table: str | exp.Expr | Any, alias: str | None = None) -> Self:
        """Create a CROSS LATERAL JOIN (no ON condition).

        Args:
            table: Table, subquery, or table function to join
            alias: Optional alias for the joined table/subquery

        Returns:
            Self for method chaining
        """
        return self.join(table, on=None, alias=alias, join_type="CROSS", lateral=True)
[docs]
@final
class JoinBuilder:
"""Builder for JOIN operations with fluent syntax."""
__slots__ = ("_alias", "_as_of", "_as_of_type", "_join_type", "_lateral", "_table")
[docs]
def __init__(self, join_type: str, lateral: bool = False) -> None:
"""Initialize the join builder.
Args:
join_type: Type of join (inner, left, right, full, cross, lateral)
lateral: Whether this is a LATERAL join
"""
self._join_type = join_type.upper()
self._lateral = lateral
self._table: str | exp.Expr | None = None
self._alias: str | None = None
self._as_of: Any | None = None
self._as_of_type: str | None = None
[docs]
def __call__(self, table: str | exp.Expr, alias: str | None = None) -> Self:
"""Set the table to join.
Args:
table: Table name or expression to join
alias: Optional alias for the table
Returns:
Self for method chaining
"""
self._table = table
self._alias = alias
return self
[docs]
def as_of(self, time_expr: Any, kind: str | None = None) -> Self:
"""Set AS OF clause for the join (Time Travel/Flashback).
Args:
time_expr: Timestamp or system time expression
kind: Type of AS OF clause (SYSTEM TIME, TIMESTAMP). If None, defaults based on dialect.
Returns:
Self for method chaining
"""
self._as_of = time_expr
self._as_of_type = kind
return self
[docs]
def on(self, condition: str | exp.Expr) -> exp.Expr:
"""Set the join condition and build the JOIN expression.
Args:
condition: JOIN condition
Returns:
Complete JOIN expression
"""
if not self._table:
msg = "Table must be set before calling .on()"
raise SQLBuilderError(msg)
condition_expr: exp.Expr
if isinstance(condition, str):
parsed: exp.Expr | None = exp.maybe_parse(condition)
condition_expr = parsed or exp.condition(condition)
else:
condition_expr = condition
table_expr: exp.Expr
if isinstance(self._table, str):
table_expr = exp.to_table(self._table)
if self._alias:
table_expr = exp.alias_(table_expr, self._alias)
else:
table_expr = self._table
if self._alias:
table_expr = exp.alias_(table_expr, self._alias)
if self._as_of is not None:
table_expr = _attach_as_of_version(table_expr, self._alias, self._as_of, self._as_of_type)
match self._join_type:
case "INNER JOIN" | "INNER" | "LATERAL JOIN":
join_expr = exp.Join(this=table_expr, on=condition_expr)
case "LEFT JOIN" | "LEFT":
join_expr = exp.Join(this=table_expr, on=condition_expr, side="LEFT")
case "RIGHT JOIN" | "RIGHT":
join_expr = exp.Join(this=table_expr, on=condition_expr, side="RIGHT")
case "FULL JOIN" | "FULL":
join_expr = exp.Join(this=table_expr, on=condition_expr, side="FULL", kind="OUTER")
case "CROSS JOIN" | "CROSS":
join_expr = exp.Join(this=table_expr, kind="CROSS")
case _:
join_expr = exp.Join(this=table_expr, on=condition_expr)
if self._lateral or self._join_type == "LATERAL JOIN":
current_kind = join_expr.args.get("kind")
current_side = join_expr.args.get("side")
if current_kind == "CROSS":
join_expr.set("kind", "CROSS LATERAL")
elif current_kind == "OUTER" and current_side == "FULL":
join_expr.set("side", "FULL")
join_expr.set("kind", "OUTER LATERAL")
elif current_side:
join_expr.set("kind", f"{current_side} LATERAL")
join_expr.set("side", None)
else:
join_expr.set("kind", "LATERAL")
return join_expr
def create_join_builder(join_type: str, lateral: bool = False) -> "JoinBuilder":
    """Create a JoinBuilder without tripping trait instantiation errors.

    Normal construction is attempted first. Only a ``TypeError`` whose
    message contains "traits may not be directly created" triggers the
    fallback, which allocates the instance without calling ``__init__`` and
    fills in the same attributes ``__init__`` would set. Any other
    ``TypeError`` propagates.

    Args:
        join_type: Type of join; stored upper-cased.
        lateral: Whether this is a LATERAL join.

    Returns:
        A ready-to-use ``JoinBuilder``.
    """
    try:
        return JoinBuilder(join_type, lateral=lateral)
    except TypeError as exc:
        if "traits may not be directly created" not in str(exc):
            raise

    # Fallback: bypass __init__ and populate the slots by hand.
    instance = object.__new__(JoinBuilder)
    instance._join_type = join_type.upper()
    instance._lateral = lateral
    instance._table = None
    instance._alias = None
    instance._as_of = None
    instance._as_of_type = None
    return instance