# Source code for sqlspec.builder._join

# pyright: reportPrivateUsage=false
"""JOIN operation mixins.

Provides mixins for JOIN operations in SELECT statements.
"""

from typing import TYPE_CHECKING, Any, Union, cast

from mypy_extensions import trait
from sqlglot import exp
from typing_extensions import Self

from sqlspec.builder._base import BuiltQuery, QueryBuilder
from sqlspec.builder._parsing_utils import parse_table_expression
from sqlspec.exceptions import SQLBuilderError
from sqlspec.utils.type_guards import has_expression_and_parameters, has_expression_and_sql, has_parameter_builder

if TYPE_CHECKING:
    from sqlspec.core import SQL
    from sqlspec.protocols import HasParameterBuilderProtocol, SQLBuilderProtocol

__all__ = ("JoinBuilder", "JoinClauseMixin", "create_join_builder")


def _handle_sql_object_condition(on: Any, builder: "SQLBuilderProtocol") -> exp.Expression:
    """Turn a SQL-like object into a JOIN condition expression.

    When ``on`` carries both an expression and parameters, every parameter is
    registered on ``builder`` under its existing name. A non-``None``
    ``on.expression`` is then returned directly; otherwise ``on.sql`` is parsed
    with the builder's dialect.

    Args:
        on: Object exposing ``sql`` and, optionally, ``expression`` and ``parameters``.
        builder: Builder that receives the parameters and supplies the dialect.

    Returns:
        The expression to use as the JOIN condition.
    """
    if has_expression_and_parameters(on):
        prebuilt = on.expression
        for name, value in on.parameters.items():
            builder.add_parameter(value, name=name)
        if prebuilt is not None:
            return cast("exp.Expression", prebuilt)

    raw_sql = on.sql  # pyright: ignore[reportAttributeAccessIssue]
    parsed = exp.maybe_parse(raw_sql, dialect=builder.dialect)
    if parsed is None:
        # Last resort: treat the stringified SQL as a bare condition.
        return exp.condition(str(raw_sql))
    return parsed


def _parse_join_condition(
    builder: "SQLBuilderProtocol", on: Union[str, exp.Expression, "SQL"] | None
) -> exp.Expression | None:
    """Normalize a JOIN condition into a sqlglot expression.

    Args:
        builder: Builder that receives any parameters carried by a SQL object
            and supplies the dialect used to parse its text.
        on: The condition as raw SQL text, a sqlglot expression, a SQL-like
            object, or ``None`` for joins without a condition.

    Returns:
        ``None`` when ``on`` is ``None``; otherwise the condition expression.
    """
    if on is None:
        return None
    if isinstance(on, exp.Expression):
        # Concrete type first: sqlglot nodes expose both an ``expression``
        # property and a ``sql`` method, so a purely attribute-based guard
        # would route a ready-made expression into the SQL-object path below.
        # NOTE(review): the guard's implementation is not visible in this
        # module; this ordering is safe whichever way it is written.
        return on
    if isinstance(on, str):
        return exp.condition(on)
    if has_expression_and_sql(on):
        return _handle_sql_object_condition(on, builder)
    # Anything else is stringified and parsed as a condition.
    return exp.condition(str(on))


def _handle_query_builder_table(table: Any, alias: str | None, builder: "SQLBuilderProtocol") -> exp.Expression:
    """Wrap a builder-like object as a parenthesized subquery join target.

    The subquery's parameters are copied onto ``builder`` under their existing
    names so the outer statement carries them.

    Args:
        table: A ``QueryBuilder`` or another object exposing ``parameters`` and ``build()``.
        alias: Optional alias applied to the parenthesized subquery.
        builder: Outer builder that receives the parameters and supplies the dialect.

    Returns:
        The parenthesized subquery, aliased when ``alias`` is truthy.
    """
    source = cast("HasParameterBuilderProtocol", table)
    # Captured before building, mirroring the established order of operations.
    source_parameters = source.parameters

    inner: exp.Expression
    if not isinstance(table, QueryBuilder):
        built = source.build()
        if isinstance(built, BuiltQuery):
            text = built.sql
        else:
            text = str(built)
        inner = exp.maybe_parse(text, dialect=builder.dialect) or exp.convert(text)
    else:
        inner = table._build_final_expression(copy=True)

    for key, value in (source_parameters or {}).items():
        builder.add_parameter(value, name=key)

    wrapped = exp.paren(inner)
    if alias:
        return exp.alias_(wrapped, alias)
    return wrapped


def _parse_join_table(
    builder: "SQLBuilderProtocol", table: str | exp.Expression | Any, alias: str | None
) -> exp.Expression:
    """Resolve the join target into a sqlglot expression.

    Strings go through ``parse_table_expression`` and builder-like objects
    become aliased subqueries. Every other input, including sqlglot
    expressions, is returned untouched; ``alias`` is not applied to it here.

    Args:
        builder: Outer builder, used when ``table`` is itself a builder.
        table: Table name, builder-like object, or expression.
        alias: Optional alias for string and builder inputs.

    Returns:
        The expression to place on the right-hand side of the JOIN.
    """
    if isinstance(table, str):
        return parse_table_expression(table, alias)
    if not has_parameter_builder(table):
        # The cast only informs the type checker; the object is passed through.
        return cast("exp.Expression", table)
    return _handle_query_builder_table(table, alias, builder)


def _create_join_expression(table_expr: exp.Expression, on_expr: exp.Expression | None, join_type: str) -> exp.Join:
    """Build the ``exp.Join`` node for a named join type.

    Args:
        table_expr: Expression being joined.
        on_expr: Join condition; not used for CROSS joins.
        join_type: One of INNER, LEFT, RIGHT, FULL or CROSS (case-insensitive).

    Returns:
        The constructed join node.

    Raises:
        SQLBuilderError: If ``join_type`` is not one of the supported names.
    """
    normalized = join_type.upper()

    # CROSS is the only variant built without an ON condition.
    if normalized == "CROSS":
        return exp.Join(this=table_expr, kind="CROSS")

    modifiers: dict[str, dict[str, str]] = {
        "INNER": {},
        "LEFT": {"side": "LEFT"},
        "RIGHT": {"side": "RIGHT"},
        "FULL": {"side": "FULL", "kind": "OUTER"},
    }
    if normalized not in modifiers:
        msg = f"Unsupported join type: {join_type}"
        raise SQLBuilderError(msg)
    return exp.Join(this=table_expr, on=on_expr, **modifiers[normalized])


def _apply_lateral_modifier(join_expr: exp.Join) -> None:
    """Rewrite a join node in place so it carries a LATERAL kind.

    The new ``kind`` depends on the node's current ``kind`` and ``side``:
    CROSS becomes ``CROSS LATERAL``; FULL OUTER keeps its side and becomes
    ``OUTER LATERAL``; any other sided join folds the side into the kind
    (``"<side> LATERAL"``) and clears ``side``; a plain join becomes ``LATERAL``.

    Args:
        join_expr: Join node to mutate.
    """
    side = join_expr.args.get("side")
    kind = join_expr.args.get("kind")

    if kind == "CROSS":
        join_expr.set("kind", "CROSS LATERAL")
        return

    if kind == "OUTER" and side == "FULL":
        join_expr.set("side", "FULL")
        join_expr.set("kind", "OUTER LATERAL")
        return

    if not side:
        join_expr.set("kind", "LATERAL")
        return

    join_expr.set("kind", f"{side} LATERAL")
    join_expr.set("side", None)


def build_join_clause(
    builder: "SQLBuilderProtocol",
    table: str | exp.Expression | Any,
    on: Union[str, exp.Expression, "SQL"] | None,
    alias: str | None,
    join_type: str,
    *,
    lateral: bool = False,
) -> exp.Join:
    """Assemble a complete JOIN node from loosely typed inputs.

    Args:
        builder: Builder that receives parameters and supplies the dialect.
        table: Join target (name, expression, or builder-like object).
        on: Join condition, or ``None``.
        alias: Optional alias for the join target.
        join_type: INNER, LEFT, RIGHT, FULL or CROSS (case-insensitive).
        lateral: When true, the join is rewritten into its LATERAL form.

    Returns:
        The join node, ready to attach to a SELECT.

    Raises:
        SQLBuilderError: If ``join_type`` is not supported.
    """
    # Arguments evaluate left to right, so the table is resolved before the condition.
    clause = _create_join_expression(
        _parse_join_table(builder, table, alias),
        _parse_join_condition(builder, on),
        join_type,
    )
    if not lateral:
        return clause
    _apply_lateral_modifier(clause)
    return clause


@trait
class JoinClauseMixin:
    """Mixin providing JOIN clause methods for SELECT builders.

    ``_expression`` is declared here for type checking only (``__slots__`` is
    empty). The JOIN methods create an empty ``exp.Select`` when it is still
    ``None`` and reject any other non-SELECT expression.
    """

    __slots__ = ()

    _expression: exp.Expression | None

    def join(
        self,
        table: str | exp.Expression | Any,
        on: Union[str, exp.Expression, "SQL"] | None = None,
        alias: str | None = None,
        join_type: str = "INNER",
        lateral: bool = False,
        as_of: Any | None = None,
        as_of_type: str | None = None,
    ) -> Self:
        """Add a JOIN clause to the SELECT expression.

        Args:
            table: Join target. A ready-made ``exp.Join`` is attached as-is and
                all remaining arguments are ignored.
            on: Join condition, or ``None``.
            alias: Optional alias for the join target.
            join_type: INNER, LEFT, RIGHT, FULL or CROSS (case-insensitive).
            lateral: When true, build the LATERAL form of the join.
            as_of: Value for a temporal ``AS OF`` version clause on the joined table.
            as_of_type: Version type for ``as_of``; falls back to ``"TIMESTAMP"``.

        Returns:
            Self for method chaining.

        Raises:
            SQLBuilderError: If the current expression is not a SELECT, or if
                ``join_type`` is not supported.
        """
        builder = cast("SQLBuilderProtocol", self)
        if builder._expression is None:
            builder._expression = exp.Select()
        if not isinstance(builder._expression, exp.Select):
            msg = "JOIN clause is only supported for SELECT statements."
            raise SQLBuilderError(msg)

        if isinstance(table, exp.Join):
            builder._expression = builder._expression.join(table, copy=False)
            return cast("Self", builder)

        join_expr = build_join_clause(builder, table, on, alias, join_type, lateral=lateral)

        if as_of is not None:
            # Work on a copy so a caller-supplied expression is never mutated.
            inner_table = join_expr.this
            inner_table = inner_table.copy()

            target_alias = alias

            # Unwrap an Alias node so the version attaches to the table itself.
            if isinstance(inner_table, exp.Alias):
                target_alias = inner_table.alias
                inner_table = inner_table.this

            # Without an explicit alias, lift the table's own alias off so it
            # can be re-applied after the version clause is attached.
            if target_alias is None and isinstance(inner_table, exp.Table):
                alias_expr = inner_table.args.get("alias")
                if alias_expr is not None:
                    target_alias = alias_expr.this
                    inner_table.set("alias", None)

            version = exp.Version(this=as_of_type or "TIMESTAMP", kind="AS OF", expression=exp.convert(as_of))
            inner_table.set("version", version)

            new_this = exp.alias_(inner_table, target_alias) if target_alias else inner_table
            join_expr.set("this", new_this)

        builder._expression = builder._expression.join(join_expr, copy=False)
        return cast("Self", builder)

    def inner_join(
        self,
        table: str | exp.Expression | Any,
        on: Union[str, exp.Expression, "SQL"],
        alias: str | None = None,
        as_of: Any | None = None,
        as_of_type: str | None = None,
    ) -> Self:
        """Add an INNER JOIN.

        Args:
            table: Join target.
            on: Join condition.
            alias: Optional alias for the join target.
            as_of: Optional value for an ``AS OF`` version clause.
            as_of_type: Version type for ``as_of``; falls back to ``"TIMESTAMP"``.

        Returns:
            Self for method chaining.
        """
        return self.join(table, on, alias, "INNER", as_of=as_of, as_of_type=as_of_type)

    def left_join(
        self,
        table: str | exp.Expression | Any,
        on: Union[str, exp.Expression, "SQL"],
        alias: str | None = None,
        as_of: Any | None = None,
        as_of_type: str | None = None,
    ) -> Self:
        """Add a LEFT JOIN.

        Args:
            table: Join target.
            on: Join condition.
            alias: Optional alias for the join target.
            as_of: Optional value for an ``AS OF`` version clause.
            as_of_type: Version type for ``as_of``; falls back to ``"TIMESTAMP"``.

        Returns:
            Self for method chaining.
        """
        return self.join(table, on, alias, "LEFT", as_of=as_of, as_of_type=as_of_type)

    def right_join(
        self,
        table: str | exp.Expression | Any,
        on: Union[str, exp.Expression, "SQL"],
        alias: str | None = None,
        as_of: Any | None = None,
        as_of_type: str | None = None,
    ) -> Self:
        """Add a RIGHT JOIN.

        Args:
            table: Join target.
            on: Join condition.
            alias: Optional alias for the join target.
            as_of: Optional value for an ``AS OF`` version clause.
            as_of_type: Version type for ``as_of``; falls back to ``"TIMESTAMP"``.

        Returns:
            Self for method chaining.
        """
        return self.join(table, on, alias, "RIGHT", as_of=as_of, as_of_type=as_of_type)

    def full_join(
        self,
        table: str | exp.Expression | Any,
        on: Union[str, exp.Expression, "SQL"],
        alias: str | None = None,
        as_of: Any | None = None,
        as_of_type: str | None = None,
    ) -> Self:
        """Add a FULL OUTER JOIN.

        Args:
            table: Join target.
            on: Join condition.
            alias: Optional alias for the join target.
            as_of: Optional value for an ``AS OF`` version clause.
            as_of_type: Version type for ``as_of``; falls back to ``"TIMESTAMP"``.

        Returns:
            Self for method chaining.
        """
        return self.join(table, on, alias, "FULL", as_of=as_of, as_of_type=as_of_type)

    def cross_join(
        self,
        table: str | exp.Expression | Any,
        alias: str | None = None,
        as_of: Any | None = None,
        as_of_type: str | None = None,
    ) -> Self:
        """Add a CROSS JOIN (no ON condition).

        Args:
            table: Join target.
            alias: Optional alias for the join target.
            as_of: Optional value for an ``AS OF`` version clause.
            as_of_type: Version type for ``as_of``; falls back to ``"TIMESTAMP"``.

        Returns:
            Self for method chaining.

        Raises:
            SQLBuilderError: If the current expression is not a SELECT.
        """
        builder = cast("SQLBuilderProtocol", self)
        if builder._expression is None:
            builder._expression = exp.Select()
        if not isinstance(builder._expression, exp.Select):
            msg = "Cannot add cross join to a non-SELECT expression."
            raise SQLBuilderError(msg)
        table_expr = _parse_join_table(builder, table, alias)

        if as_of is not None:
            # Copy to avoid mutating the original expression.
            inner_table = table_expr.copy()

            target_alias = alias
            if isinstance(inner_table, exp.Alias):
                target_alias = inner_table.alias
                inner_table = inner_table.this
            elif isinstance(inner_table, exp.Table):
                # Unlike join(), a table's own alias takes precedence here
                # even when ``alias`` was passed explicitly.
                alias_expr = inner_table.args.get("alias")
                if alias_expr is not None:
                    target_alias = alias_expr.this
                    inner_table.set("alias", None)

            # Create the Version expression and attach it to the table.
            version = exp.Version(this=as_of_type or "TIMESTAMP", kind="AS OF", expression=exp.convert(as_of))
            inner_table.set("version", version)
            table_expr = exp.alias_(inner_table, target_alias) if target_alias else inner_table

        join_expr = exp.Join(this=table_expr, kind="CROSS")
        builder._expression = builder._expression.join(join_expr, copy=False)
        return cast("Self", builder)

    def lateral_join(
        self,
        table: str | exp.Expression | Any,
        on: Union[str, exp.Expression, "SQL"] | None = None,
        alias: str | None = None,
    ) -> Self:
        """Create a LATERAL JOIN.

        Args:
            table: Table, subquery, or table function to join
            on: Optional join condition (for LATERAL JOINs with ON clause)
            alias: Optional alias for the joined table/subquery

        Returns:
            Self for method chaining

        Example:
            ```python
            query = (
                sql
                .select("u.name", "arr.value")
                .from_("users u")
                .lateral_join("UNNEST(u.tags)", alias="arr")
            )
            ```
        """
        return self.join(table, on=on, alias=alias, join_type="INNER", lateral=True)

    def left_lateral_join(
        self,
        table: str | exp.Expression | Any,
        on: Union[str, exp.Expression, "SQL"] | None = None,
        alias: str | None = None,
    ) -> Self:
        """Create a LEFT LATERAL JOIN.

        Args:
            table: Table, subquery, or table function to join
            on: Optional join condition
            alias: Optional alias for the joined table/subquery

        Returns:
            Self for method chaining
        """
        return self.join(table, on=on, alias=alias, join_type="LEFT", lateral=True)

    def cross_lateral_join(self, table: str | exp.Expression | Any, alias: str | None = None) -> Self:
        """Create a CROSS LATERAL JOIN (no ON condition).

        Args:
            table: Table, subquery, or table function to join
            alias: Optional alias for the joined table/subquery

        Returns:
            Self for method chaining
        """
        return self.join(table, on=None, alias=alias, join_type="CROSS", lateral=True)


class JoinBuilder:
    """Builder for JOIN operations with fluent syntax.

    Example:
        ```python
        from sqlspec import sql

        # sql.left_join_("posts").on("users.id = posts.user_id")
        join_clause = sql.left_join_("posts").on(
            "users.id = posts.user_id"
        )

        # Or with query builder
        query = (
            sql
            .select("users.name", "posts.title")
            .from_("users")
            .join(
                sql.left_join_("posts").on(
                    "users.id = posts.user_id"
                )
            )
        )
        ```
    """

    def __init__(self, join_type: str, lateral: bool = False) -> None:
        """Initialize the join builder.

        Args:
            join_type: Type of join (inner, left, right, full, cross, lateral);
                stored upper-cased.
            lateral: Whether this is a LATERAL join
        """
        self._join_type = join_type.upper()
        self._lateral = lateral
        self._table: str | exp.Expression | None = None
        self._condition: exp.Expression | None = None
        self._alias: str | None = None
        self._as_of: Any | None = None
        self._as_of_type: str | None = None

    def __call__(self, table: str | exp.Expression, alias: str | None = None) -> Self:
        """Set the table to join.

        Args:
            table: Table name or expression to join
            alias: Optional alias for the table

        Returns:
            Self for method chaining
        """
        self._table = table
        self._alias = alias
        return self

    def as_of(self, time_expr: Any, kind: str | None = None) -> Self:
        """Set AS OF clause for the join (Time Travel/Flashback).

        Args:
            time_expr: Timestamp or system time expression
            kind: Type of AS OF clause (SYSTEM TIME, TIMESTAMP). When omitted,
                ``on()`` falls back to ``"TIMESTAMP"``.

        Returns:
            Self for method chaining
        """
        self._as_of = time_expr
        self._as_of_type = kind
        return self

    def on(self, condition: str | exp.Expression) -> exp.Expression:
        """Set the join condition and build the JOIN expression.

        Args:
            condition: JOIN condition (e.g., "users.id = posts.user_id").
                Not used for CROSS joins.

        Returns:
            Complete JOIN expression

        Raises:
            SQLBuilderError: If no table has been set via ``__call__``.
        """
        if not self._table:
            msg = "Table must be set before calling .on()"
            raise SQLBuilderError(msg)

        condition_expr: exp.Expression
        if isinstance(condition, str):
            parsed: exp.Expression | None = exp.maybe_parse(condition)
            condition_expr = parsed or exp.condition(condition)
        else:
            condition_expr = condition

        table_expr: exp.Expression
        table_expr = exp.to_table(self._table) if isinstance(self._table, str) else self._table
        if self._alias:
            table_expr = exp.alias_(table_expr, self._alias)

        if self._as_of is not None:
            # Work on a copy so the stored table expression is not mutated.
            inner_table = table_expr.copy()
            target_alias = self._alias

            if isinstance(inner_table, exp.Alias):
                target_alias = inner_table.alias
                inner_table = inner_table.this
            elif isinstance(inner_table, exp.Table):
                # Lift the table's alias off so it can be re-applied after
                # the version clause is attached.
                alias_expr = inner_table.args.get("alias")
                if alias_expr is not None:
                    target_alias = alias_expr.this
                    inner_table.set("alias", None)

            version = exp.Version(
                this=self._as_of_type or "TIMESTAMP", kind="AS OF", expression=exp.convert(self._as_of)
            )
            inner_table.set("version", version)
            table_expr = exp.alias_(inner_table, target_alias) if target_alias else inner_table

        if self._join_type in {"LEFT JOIN", "LEFT"}:
            join_expr = exp.Join(this=table_expr, on=condition_expr, side="LEFT")
        elif self._join_type in {"RIGHT JOIN", "RIGHT"}:
            join_expr = exp.Join(this=table_expr, on=condition_expr, side="RIGHT")
        elif self._join_type in {"FULL JOIN", "FULL"}:
            join_expr = exp.Join(this=table_expr, on=condition_expr, side="FULL", kind="OUTER")
        elif self._join_type in {"CROSS JOIN", "CROSS"}:
            join_expr = exp.Join(this=table_expr, kind="CROSS")
        else:
            # INNER, LATERAL and any unrecognized type all build a plain join.
            join_expr = exp.Join(this=table_expr, on=condition_expr)

        if self._lateral or self._join_type == "LATERAL JOIN":
            _apply_lateral_modifier(join_expr)

        return join_expr
def create_join_builder(join_type: str, lateral: bool = False) -> "JoinBuilder":
    """Create a JoinBuilder without tripping trait instantiation errors.

    Normal construction is attempted first. Only a ``TypeError`` whose message
    contains "traits may not be directly created" triggers the fallback, which
    allocates the instance with ``object.__new__`` and fills in the same
    attributes ``JoinBuilder.__init__`` would set. Any other ``TypeError``
    propagates.

    Args:
        join_type: Type of join; stored upper-cased.
        lateral: Whether this is a LATERAL join.

    Returns:
        A ready-to-use ``JoinBuilder``.
    """
    try:
        return JoinBuilder(join_type, lateral=lateral)
    except TypeError as exc:
        trait_guard_hit = "traits may not be directly created" in str(exc)
        if not trait_guard_hit:
            raise

    # Bypass __init__ and populate the instance state by hand.
    instance = object.__new__(JoinBuilder)
    instance._join_type = join_type.upper()
    instance._lateral = lateral
    instance._table = None
    instance._condition = None
    instance._alias = None
    instance._as_of = None
    instance._as_of_type = None
    return instance