"""UPDATE statement builder.
Provides a fluent interface for building SQL UPDATE queries with
parameter binding and validation.
"""
from typing import TYPE_CHECKING, Any, cast
from sqlglot import exp
from typing_extensions import Self
from sqlspec.builder._base import QueryBuilder, SafeQuery
from sqlspec.builder._dml import UpdateFromClauseMixin, UpdateSetClauseMixin, UpdateTableClauseMixin
from sqlspec.builder._join import build_join_clause
from sqlspec.builder._select import ReturningClauseMixin, WhereClauseMixin
from sqlspec.core import SQLResult
from sqlspec.exceptions import SQLBuilderError
if TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
from sqlspec.builder._select import Select
from sqlspec.protocols import SQLBuilderProtocol
# Only the Update builder is exported by this module (e.g. for star imports).
__all__ = ("Update",)
class Update(
    QueryBuilder,
    WhereClauseMixin,
    ReturningClauseMixin,
    UpdateSetClauseMixin,
    UpdateFromClauseMixin,
    UpdateTableClauseMixin,
):
    """Builder for UPDATE statements.

    Constructs SQL UPDATE statements with parameter binding and validation.

    Example:
        ```python
        update_query = (
            Update()
            .table("users")
            .set_(name="John Doe")
            .set_(email="john@example.com")
            .where("id = 1")
        )

        update_query = (
            Update("users").set_(name="John Doe").where("id = 1")
        )

        update_query = (
            Update()
            .table("users")
            .set_(status="active")
            .where_eq("id", 123)
        )

        update_query = (
            Update()
            .table("users", "u")
            .set_(name="Updated Name")
            .from_("profiles", "p")
            .where("u.id = p.user_id AND p.is_verified = true")
        )
        ```
    """

    __slots__ = ("_table",)
    # The sqlglot expression under construction; may be None, which the
    # methods below treat as "not initialized".
    _expression: exp.Expression | None

    def __init__(self, table: str | None = None, **kwargs: Any) -> None:
        """Initialize UPDATE with optional table.

        Args:
            table: Target table name. A falsy value (``None`` or ``""``)
                leaves the target table unset so it can be supplied later
                via ``table()``.
            **kwargs: Additional QueryBuilder arguments.
        """
        super().__init__(**kwargs)
        self._initialize_expression()
        if table:
            self.table(table)

    @property
    def _expected_result_type(self) -> "type[SQLResult]":
        """Return the expected result type for this builder."""
        return SQLResult

    def _create_base_expression(self) -> exp.Update:
        """Create a base UPDATE expression.

        Returns:
            A new sqlglot Update expression with no target table, no SET
            assignments, and no joins.
        """
        return exp.Update(this=None, expressions=[], joins=[])

    def join(
        self,
        table: "str | exp.Expression | Select",
        on: "str | exp.Expression",
        alias: "str | None" = None,
        join_type: str = "INNER",
    ) -> "Self":
        """Add JOIN clause to the UPDATE statement.

        Args:
            table: The table name, expression, or subquery to join.
            on: The JOIN condition.
            alias: Optional alias for the joined table.
            join_type: Type of join (INNER, LEFT, RIGHT, FULL).

        Returns:
            The current builder instance for method chaining.

        Raises:
            SQLBuilderError: If the current expression is missing or is not
                an UPDATE statement.
        """
        # isinstance() is False for None, so this also covers an
        # uninitialized expression.
        if not isinstance(self._expression, exp.Update):
            msg = "Cannot add JOIN clause to non-UPDATE expression."
            raise SQLBuilderError(msg)
        join_expr = build_join_clause(cast("SQLBuilderProtocol", self), table, on, alias, join_type)
        # Expression.append creates the "joins" list when it is absent and
        # registers the UPDATE node as the join's parent. Appending straight
        # onto args["joins"] would leave the new node detached from the tree.
        self._expression.append("joins", join_expr)
        return self

    def build(self, dialect: "DialectType" = None) -> "SafeQuery":
        """Build the UPDATE query with validation.

        Args:
            dialect: Optional dialect override for SQL generation.

        Returns:
            SafeQuery: The built query with SQL and parameters.

        Raises:
            SQLBuilderError: If the expression is not initialized, is not an
                UPDATE, has no target table, or has no SET assignments.
        """
        if self._expression is None:
            msg = "UPDATE expression not initialized."
            raise SQLBuilderError(msg)
        if not isinstance(self._expression, exp.Update):
            msg = "No UPDATE expression to build or expression is of the wrong type."
            raise SQLBuilderError(msg)
        if self._expression.this is None:
            msg = "No table specified for UPDATE statement."
            raise SQLBuilderError(msg)
        # The SET assignments are stored in the "expressions" arg.
        if not self._expression.args.get("expressions"):
            msg = "At least one SET clause must be specified for UPDATE statement."
            raise SQLBuilderError(msg)
        return super().build(dialect=dialect)