"""INSERT statement builder.
Provides a fluent interface for building SQL INSERT queries with
parameter binding and validation.
"""
from typing import TYPE_CHECKING, Any, Final
from sqlglot import exp
from typing_extensions import Self
from sqlspec.builder._base import QueryBuilder
from sqlspec.builder._dml import InsertFromSelectMixin, InsertIntoClauseMixin, InsertValuesMixin
from sqlspec.builder._parsing_utils import extract_sql_object_expression
from sqlspec.builder._select import ReturningClauseMixin
from sqlspec.core import SQLResult
from sqlspec.exceptions import SQLBuilderError
from sqlspec.utils.type_guards import has_expression_and_sql
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
__all__ = ("Insert",)
ERR_MSG_TABLE_NOT_SET: Final[str] = "The target table must be set using .into() before adding values."
ERR_MSG_INTERNAL_EXPRESSION_TYPE: Final[str] = "Internal error: expression is not an Insert instance as expected."
ERR_MSG_EXPRESSION_NOT_INITIALIZED: Final[str] = "Internal error: base expression not initialized."
[docs]
class Insert(QueryBuilder, ReturningClauseMixin, InsertValuesMixin, InsertFromSelectMixin, InsertIntoClauseMixin):
"""Builder for INSERT statements.
Constructs SQL INSERT queries with parameter binding and validation.
"""
__slots__ = ("_columns", "_table", "_values_added_count")
[docs]
def __init__(self, table: str | None = None, **kwargs: Any) -> None:
"""Initialize INSERT with optional table.
Args:
table: Target table name
**kwargs: Additional QueryBuilder arguments
"""
super().__init__(**kwargs)
self._table: str | None = None
self._columns: list[str] = []
self._values_added_count: int = 0
self._initialize_expression()
if table:
self.into(table)
def _create_base_expression(self) -> exp.Insert:
"""Create a base INSERT expression.
This method is called by the base QueryBuilder during initialization.
Returns:
A new sqlglot Insert expression.
"""
return exp.Insert()
@property
def _expected_result_type(self) -> "type[SQLResult]":
"""Specifies the expected result type for an INSERT query.
Returns:
The type of result expected for INSERT operations.
"""
return SQLResult
def _get_insert_expression(self) -> exp.Insert:
"""Safely gets and casts the internal expression to exp.Insert.
Returns:
The internal expression as exp.Insert.
Raises:
SQLBuilderError: If the expression is not initialized or is not an Insert.
"""
if self._expression is None:
raise SQLBuilderError(ERR_MSG_EXPRESSION_NOT_INITIALIZED)
if not isinstance(self._expression, exp.Insert):
raise SQLBuilderError(ERR_MSG_INTERNAL_EXPRESSION_TYPE)
return self._expression
[docs]
def get_insert_expression(self) -> exp.Insert:
"""Get the insert expression (public API)."""
return self._get_insert_expression()
[docs]
def values_from_dict(self, data: "Mapping[str, Any]") -> "Self":
"""Adds a row of values from a dictionary.
This is a convenience method that automatically sets columns based on
the dictionary keys and values based on the dictionary values.
Args:
data: A mapping of column names to values.
Returns:
The current builder instance for method chaining.
Raises:
SQLBuilderError: If `into()` has not been called to set the table.
"""
if not self._table:
raise SQLBuilderError(ERR_MSG_TABLE_NOT_SET)
data_keys = list(data.keys())
if not self._columns:
self.columns(*data_keys)
elif set(self._columns) != set(data_keys):
msg = f"Dictionary keys {set(data_keys)} do not match existing columns {set(self._columns)}."
raise SQLBuilderError(msg)
return self.values(*[data[col] for col in self._columns])
[docs]
def values_from_dicts(self, data: "Sequence[Mapping[str, Any]]") -> "Self":
"""Adds multiple rows of values from a sequence of dictionaries.
This is a convenience method for bulk inserts from structured data.
Args:
data: A sequence of mappings, each representing a row of data.
Returns:
The current builder instance for method chaining.
Raises:
SQLBuilderError: If `into()` has not been called to set the table,
or if dictionaries have inconsistent keys.
"""
if not data:
return self
first_dict = data[0]
if not self._columns:
self.columns(*first_dict.keys())
expected_keys = set(self._columns)
for i, row_dict in enumerate(data):
if set(row_dict.keys()) != expected_keys:
msg = (
f"Dictionary at index {i} has keys {set(row_dict.keys())} "
f"which do not match expected keys {expected_keys}."
)
raise SQLBuilderError(msg)
for row_dict in data:
self.values(*[row_dict[col] for col in self._columns])
return self
[docs]
def on_conflict(self, *columns: str) -> "ConflictBuilder":
"""Adds an ON CONFLICT clause with specified columns.
Args:
*columns: Column names that define the conflict. If no columns provided,
creates an ON CONFLICT without specific columns (catches all conflicts).
Returns:
A ConflictBuilder instance for chaining conflict resolution methods.
Example:
```python
sql.insert("users").values(id=1, name="John").on_conflict(
"id"
).do_nothing()
sql.insert("users").values(...).on_conflict(
"email", "username"
).do_update(updated_at=sql.raw("NOW()"))
sql.insert("users").values(...).on_conflict().do_nothing()
```
"""
return ConflictBuilder(self, columns)
[docs]
def on_conflict_do_nothing(self, *columns: str) -> "Insert":
"""Adds an ON CONFLICT DO NOTHING clause (convenience method).
Args:
*columns: Column names that define the conflict. If no columns provided,
creates an ON CONFLICT without specific columns.
Returns:
The current builder instance for method chaining.
Note:
This is a convenience method. For more control, use on_conflict().do_nothing().
"""
return self.on_conflict(*columns).do_nothing()
[docs]
def on_duplicate_key_update(self, **kwargs: Any) -> "Insert":
"""Adds MySQL-style ON DUPLICATE KEY UPDATE clause.
Args:
**kwargs: Column-value pairs to update on duplicate key.
Returns:
The current builder instance for method chaining.
Note:
This method creates MySQL-specific ON DUPLICATE KEY UPDATE syntax.
For PostgreSQL, use on_conflict() instead.
"""
if not kwargs:
return self
insert_expr = self._get_insert_expression()
set_expressions = []
for col, val in kwargs.items():
if has_expression_and_sql(val):
value_expr = extract_sql_object_expression(val, builder=self)
elif isinstance(val, exp.Expression):
value_expr = val
else:
param_name = self.generate_unique_parameter_name(col)
_, param_name = self.add_parameter(val, name=param_name)
value_expr = exp.Placeholder(this=param_name)
set_expressions.append(exp.EQ(this=exp.column(col), expression=value_expr))
on_conflict = exp.OnConflict(duplicate=True, action=exp.var("UPDATE"), expressions=set_expressions or None)
insert_expr.set("conflict", on_conflict)
return self
class ConflictBuilder:
"""Builder for ON CONFLICT clauses in INSERT statements.
Constructs conflict resolution clauses using PostgreSQL-style syntax,
which SQLGlot can transpile to other dialects.
"""
__slots__ = ("_columns", "_insert_builder")
def __init__(self, insert_builder: "Insert", columns: tuple[str, ...]) -> None:
"""Initialize ConflictBuilder.
Args:
insert_builder: The parent Insert builder
columns: Column names that define the conflict
"""
self._insert_builder = insert_builder
self._columns = columns
def do_nothing(self) -> "Insert":
"""Add DO NOTHING conflict resolution.
Returns:
The parent Insert builder for method chaining.
Example:
```python
sql.insert("users").values(id=1, name="John").on_conflict(
"id"
).do_nothing()
```
"""
insert_expr = self._insert_builder.get_insert_expression()
conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
on_conflict = exp.OnConflict(conflict_keys=conflict_keys, action=exp.var("DO NOTHING"))
insert_expr.set("conflict", on_conflict)
return self._insert_builder
def do_update(self, **kwargs: Any) -> "Insert":
"""Add DO UPDATE conflict resolution with SET clauses.
Args:
**kwargs: Column-value pairs to update on conflict.
Returns:
The parent Insert builder for method chaining.
Example:
```python
sql.insert("users").values(id=1, name="John").on_conflict(
"id"
).do_update(
name="Updated Name", updated_at=sql.raw("NOW()")
)
```
"""
insert_expr = self._insert_builder.get_insert_expression()
set_expressions = []
for col, val in kwargs.items():
if has_expression_and_sql(val):
value_expr = extract_sql_object_expression(val, builder=self._insert_builder)
elif isinstance(val, exp.Expression):
value_expr = val
else:
param_name = self._insert_builder.generate_unique_parameter_name(col)
_, param_name = self._insert_builder.add_parameter(val, name=param_name)
value_expr = exp.Placeholder(this=param_name)
set_expressions.append(exp.EQ(this=exp.column(col), expression=value_expr))
conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
on_conflict = exp.OnConflict(
conflict_keys=conflict_keys, action=exp.var("DO UPDATE"), expressions=set_expressions or None
)
insert_expr.set("conflict", on_conflict)
return self._insert_builder