"""Column expressions for query building.
Provides Column objects that support Python operators for building
SQL conditions with parameter binding.
"""
from collections.abc import Iterable
from datetime import date, datetime
from typing import Any, cast
from sqlglot import exp
__all__ = ("Column", "ColumnExpression", "FunctionColumn")
def _convert_value(value: Any) -> exp.Expression:
"""Convert a Python value to a SQLGlot expression.
Special handling for datetime objects to prevent SQLGlot from
converting them to TIME_STR_TO_TIME function calls. Datetime
objects should be passed as parameters, not converted to SQL functions.
Args:
value: The value to convert
Returns:
A SQLGlot expression representing the value
"""
if isinstance(value, (datetime, date)):
return exp.Literal(this=value, is_string=False)
return exp.convert(value)
class ColumnExpression:
"""Base class for column expressions that can be combined with operators."""
__slots__ = ("_expression",)
def __init__(self, expression: exp.Expression) -> None:
self._expression = expression
def __and__(self, other: "ColumnExpression") -> "ColumnExpression":
"""Combine with AND operator (&)."""
if not isinstance(other, ColumnExpression):
return NotImplemented
return ColumnExpression(exp.And(this=self._expression, expression=other._expression))
def __or__(self, other: "ColumnExpression") -> "ColumnExpression":
"""Combine with OR operator (|)."""
if not isinstance(other, ColumnExpression):
return NotImplemented
return ColumnExpression(exp.Or(this=self._expression, expression=other._expression))
def __invert__(self) -> "ColumnExpression":
"""Apply NOT operator (~)."""
return ColumnExpression(exp.Not(this=self._expression))
def __bool__(self) -> bool:
"""Prevent accidental use of 'and'/'or' keywords."""
msg = (
"Cannot use 'and'/'or' operators on ColumnExpression. "
"Use '&'/'|' operators instead. "
f"Expression: {self._expression.sql()}"
)
raise TypeError(msg)
@property
def sqlglot_expression(self) -> exp.Expression:
"""Get the underlying SQLGlot expression."""
return self._expression
class Column:
"""Represents a database column with Python operator support."""
__slots__ = ("_expression", "name", "table")
def __init__(self, name: str, table: str | None = None) -> None:
self.name = name
self.table = table
if table:
self._expression = exp.Column(this=exp.Identifier(this=name), table=exp.Identifier(this=table))
else:
self._expression = exp.Column(this=exp.Identifier(this=name))
def _convert_value(self, value: Any) -> exp.Expression:
"""Convert a Python value to a SQLGlot expression."""
return _convert_value(value)
def __eq__(self, other: object) -> ColumnExpression: # type: ignore[override]
"""Equal to (==)."""
if other is None:
return ColumnExpression(exp.Is(this=self._expression, expression=exp.Null()))
return ColumnExpression(exp.EQ(this=self._expression, expression=self._convert_value(other)))
def __ne__(self, other: object) -> ColumnExpression: # type: ignore[override]
"""Not equal to (!=)."""
if other is None:
return ColumnExpression(exp.Not(this=exp.Is(this=self._expression, expression=exp.Null())))
return ColumnExpression(exp.NEQ(this=self._expression, expression=self._convert_value(other)))
def __gt__(self, other: Any) -> ColumnExpression:
"""Greater than (>)."""
return ColumnExpression(exp.GT(this=self._expression, expression=self._convert_value(other)))
def __ge__(self, other: Any) -> ColumnExpression:
"""Greater than or equal (>=)."""
return ColumnExpression(exp.GTE(this=self._expression, expression=self._convert_value(other)))
def __lt__(self, other: Any) -> ColumnExpression:
"""Less than (<)."""
return ColumnExpression(exp.LT(this=self._expression, expression=self._convert_value(other)))
def __le__(self, other: Any) -> ColumnExpression:
"""Less than or equal (<=)."""
return ColumnExpression(exp.LTE(this=self._expression, expression=self._convert_value(other)))
def __invert__(self) -> ColumnExpression:
"""Apply NOT operator (~)."""
return ColumnExpression(exp.Not(this=self._expression))
def like(self, pattern: str, escape: str | None = None) -> ColumnExpression:
"""SQL LIKE pattern matching."""
if escape:
like_expr = exp.Like(
this=self._expression, expression=self._convert_value(pattern), escape=self._convert_value(escape)
)
else:
like_expr = exp.Like(this=self._expression, expression=self._convert_value(pattern))
return ColumnExpression(like_expr)
def ilike(self, pattern: str) -> ColumnExpression:
"""Case-insensitive LIKE."""
return ColumnExpression(exp.ILike(this=self._expression, expression=self._convert_value(pattern)))
def in_(self, values: Iterable[Any]) -> ColumnExpression:
"""SQL IN clause."""
converted_values = [self._convert_value(v) for v in values]
return ColumnExpression(exp.In(this=self._expression, expressions=converted_values))
def not_in(self, values: Iterable[Any]) -> ColumnExpression:
"""SQL NOT IN clause."""
return ~self.in_(values)
def between(self, start: Any, end: Any) -> ColumnExpression:
"""SQL BETWEEN clause."""
return ColumnExpression(
exp.Between(this=self._expression, low=self._convert_value(start), high=self._convert_value(end))
)
def is_null(self) -> ColumnExpression:
"""SQL IS NULL."""
return ColumnExpression(exp.Is(this=self._expression, expression=exp.Null()))
def is_not_null(self) -> ColumnExpression:
"""SQL IS NOT NULL."""
return ColumnExpression(exp.Not(this=exp.Is(this=self._expression, expression=exp.Null())))
def not_like(self, pattern: str, escape: str | None = None) -> ColumnExpression:
"""SQL NOT LIKE pattern matching."""
return ~self.like(pattern, escape)
def not_ilike(self, pattern: str) -> ColumnExpression:
"""Case-insensitive NOT LIKE."""
return ~self.ilike(pattern)
def any_(self, values: Iterable[Any]) -> ColumnExpression:
"""SQL = ANY(...) clause."""
converted_values = [self._convert_value(v) for v in values]
return ColumnExpression(exp.EQ(this=self._expression, expression=exp.Any(expressions=converted_values)))
def not_any_(self, values: Iterable[Any]) -> ColumnExpression:
"""SQL <> ANY(...) clause."""
converted_values = [self._convert_value(v) for v in values]
return ColumnExpression(exp.NEQ(this=self._expression, expression=exp.Any(expressions=converted_values)))
def lower(self) -> "FunctionColumn":
"""SQL LOWER() function."""
return FunctionColumn(exp.Lower(this=self._expression))
def upper(self) -> "FunctionColumn":
"""SQL UPPER() function."""
return FunctionColumn(exp.Upper(this=self._expression))
def length(self) -> "FunctionColumn":
"""SQL LENGTH() function."""
return FunctionColumn(exp.Length(this=self._expression))
def trim(self) -> "FunctionColumn":
"""SQL TRIM() function."""
return FunctionColumn(exp.Trim(this=self._expression))
def abs(self) -> "FunctionColumn":
"""SQL ABS() function."""
return FunctionColumn(exp.Abs(this=self._expression))
def round(self, decimals: int = 0) -> "FunctionColumn":
"""SQL ROUND() function."""
if decimals == 0:
return FunctionColumn(exp.Round(this=self._expression))
return FunctionColumn(exp.Round(this=self._expression, expression=exp.convert(decimals)))
def floor(self) -> "FunctionColumn":
"""SQL FLOOR() function."""
return FunctionColumn(exp.Floor(this=self._expression))
def ceil(self) -> "FunctionColumn":
"""SQL CEIL() function."""
return FunctionColumn(exp.Ceil(this=self._expression))
def substring(self, start: int, length: int | None = None) -> "FunctionColumn":
"""SQL SUBSTRING() function."""
args = [self._convert_value(start)]
if length is not None:
args.append(self._convert_value(length))
return FunctionColumn(exp.Substring(this=self._expression, expressions=args))
def coalesce(self, *values: Any) -> "FunctionColumn":
"""SQL COALESCE() function."""
expressions = [self._expression] + [self._convert_value(v) for v in values]
return FunctionColumn(exp.Coalesce(expressions=expressions))
def cast(self, data_type: str) -> "FunctionColumn":
"""SQL CAST() function."""
return FunctionColumn(exp.Cast(this=self._expression, to=exp.DataType.build(data_type)))
def count(self) -> "FunctionColumn":
"""SQL COUNT() function."""
return FunctionColumn(exp.Count(this=self._expression))
def sum(self) -> "FunctionColumn":
"""SQL SUM() function."""
return FunctionColumn(exp.Sum(this=self._expression))
def avg(self) -> "FunctionColumn":
"""SQL AVG() function."""
return FunctionColumn(exp.Avg(this=self._expression))
def min(self) -> "FunctionColumn":
"""SQL MIN() function."""
return FunctionColumn(exp.Min(this=self._expression))
def max(self) -> "FunctionColumn":
"""SQL MAX() function."""
return FunctionColumn(exp.Max(this=self._expression))
def count_distinct(self) -> "FunctionColumn":
"""SQL COUNT(DISTINCT column) function."""
return FunctionColumn(exp.Count(this=exp.Distinct(expressions=[self._expression])))
@staticmethod
def count_all() -> "FunctionColumn":
"""SQL COUNT(*) function."""
return FunctionColumn(exp.Count(this=exp.Star()))
@staticmethod
def _normalize_metric(metric: str) -> str:
"""Normalize and validate vector distance metric."""
normalized_metric = metric.lower()
valid_metrics = {"euclidean", "cosine", "inner_product", "euclidean_squared"}
if normalized_metric not in valid_metrics:
msg = f"Invalid metric: {metric}. Must be one of {valid_metrics}"
raise ValueError(msg)
return normalized_metric
def _convert_vector_value(self, value: "list[float] | Column | exp.Expression") -> "exp.Expression":
"""Convert a vector input into a SQLGlot expression."""
if isinstance(value, list):
return exp.Array(expressions=[exp.Literal.number(v) for v in value])
if isinstance(value, Column):
return value._expression
if isinstance(value, exp.Expression):
return value
msg = f"Unsupported vector type: {type(value)}"
raise TypeError(msg)
[docs]
def vector_distance(
self, other_vector: "list[float] | Column | exp.Expression", metric: str = "euclidean"
) -> "FunctionColumn":
"""Calculate vector distance using specified metric.
Generates dialect-specific SQL for vector distance calculation:
- PostgreSQL (pgvector): Operators <->, <=>, <#>
- MySQL 9+: DISTANCE(col, vec, 'METRIC') function
- Oracle 23ai+: VECTOR_DISTANCE(col, vec, METRIC) function
Args:
other_vector: Vector to compare against (list, Column, or SQLGlot expression).
metric: Distance metric to use. Options:
- "euclidean": L2 distance (default)
- "cosine": Cosine distance
- "inner_product": Negative inner product
- "euclidean_squared": L2² distance (Oracle only)
Returns:
FunctionColumn expression for use in SELECT, WHERE, ORDER BY.
Examples:
Basic distance query with threshold:
>>> query = (
... sql
... .select("*")
... .from_("docs")
... .where(
... Column("embedding").vector_distance(
... [0.1, 0.2], metric="euclidean"
... )
... < 0.5
... )
... )
Distance in SELECT clause with alias:
>>> query = (
... sql
... .select(
... "id",
... Column("embedding")
... .vector_distance([0.1, 0.2])
... .as_("dist"),
... )
... .from_("docs")
... .order_by("dist")
... )
Compare two vector columns:
>>> query = (
... sql
... .select("*")
... .from_("pairs")
... .where(
... Column("vec1").vector_distance(
... Column("vec2"), metric="cosine"
... )
... < 0.3
... )
... )
"""
from sqlspec.builder._vector_expressions import VectorDistance
normalized_metric = self._normalize_metric(metric)
vec_expr = self._convert_vector_value(other_vector)
distance_expr = VectorDistance(this=self._expression, expression=vec_expr, metric=normalized_metric)
return FunctionColumn(distance_expr)
[docs]
def cosine_similarity(self, other_vector: "list[float] | Column | exp.Expression") -> "FunctionColumn":
"""Calculate cosine similarity (1 - cosine_distance).
Convenience method that computes similarity instead of distance.
Returns values in range [-1, 1] where 1 = identical vectors.
Args:
other_vector: Vector to compare against (list, Column, or expression).
Returns:
FunctionColumn expression: 1 - cosine_distance(self, other_vector).
Examples:
Find most similar documents:
>>> query = (
... sql
... .select(
... "id",
... Column("embedding")
... .cosine_similarity([0.1, 0.2])
... .as_("score"),
... )
... .from_("docs")
... .order_by(sql.column("score").desc())
... .limit(10)
... )
"""
cosine_dist = self.vector_distance(other_vector, metric="cosine")
similarity_expr = exp.Sub(this=exp.Literal.number(1), expression=exp.Paren(this=cosine_dist._expression)) # pyright: ignore[reportPrivateUsage]
return FunctionColumn(similarity_expr)
def alias(self, alias_name: str) -> exp.Expression:
"""Create an aliased column expression."""
return exp.Alias(this=self._expression, alias=alias_name)
def asc(self) -> exp.Ordered:
"""Create an ASC ordering expression."""
return exp.Ordered(this=self._expression, desc=False)
def desc(self) -> exp.Ordered:
"""Create a DESC ordering expression."""
return exp.Ordered(this=self._expression, desc=True)
def as_(self, alias: str) -> exp.Alias:
"""Create an aliased expression."""
return cast("exp.Alias", exp.alias_(self._expression, alias))
def __repr__(self) -> str:
if self.table:
return f"Column<{self.table}.{self.name}>"
return f"Column<{self.name}>"
def __hash__(self) -> int:
"""Hash based on table and column name."""
return hash((self.table, self.name))
@property
def sqlglot_expression(self) -> exp.Expression:
"""Get the underlying SQLGlot expression (public API).
Returns:
The SQLGlot expression for this column
"""
return self._expression
class FunctionColumn:
"""Represents the result of a SQL function call on a column."""
__slots__ = ("_expression",)
def __init__(self, expression: "exp.Expression") -> None:
self._expression = expression
def _convert_value(self, value: Any) -> exp.Expression:
"""Convert a Python value to a SQLGlot expression."""
return _convert_value(value)
@property
def sqlglot_expression(self) -> "exp.Expression":
"""Return underlying SQLGlot expression."""
return self._expression
def __eq__(self, other: object) -> ColumnExpression: # type: ignore[override]
return ColumnExpression(exp.EQ(this=self._expression, expression=self._convert_value(other)))
def __ne__(self, other: object) -> ColumnExpression: # type: ignore[override]
return ColumnExpression(exp.NEQ(this=self._expression, expression=self._convert_value(other)))
def __gt__(self, other: Any) -> ColumnExpression:
"""Greater than (>)."""
return ColumnExpression(exp.GT(this=self._expression, expression=self._convert_value(other)))
def __ge__(self, other: Any) -> ColumnExpression:
"""Greater than or equal (>=)."""
return ColumnExpression(exp.GTE(this=self._expression, expression=self._convert_value(other)))
def __lt__(self, other: Any) -> ColumnExpression:
"""Less than (<)."""
return ColumnExpression(exp.LT(this=self._expression, expression=self._convert_value(other)))
def __le__(self, other: Any) -> ColumnExpression:
"""Less than or equal (<=)."""
return ColumnExpression(exp.LTE(this=self._expression, expression=self._convert_value(other)))
def like(self, pattern: str) -> ColumnExpression:
return ColumnExpression(exp.Like(this=self._expression, expression=self._convert_value(pattern)))
def ilike(self, pattern: str) -> ColumnExpression:
"""Case-insensitive LIKE."""
return ColumnExpression(exp.ILike(this=self._expression, expression=self._convert_value(pattern)))
def in_(self, values: Iterable[Any]) -> ColumnExpression:
"""SQL IN clause."""
converted_values = [self._convert_value(v) for v in values]
return ColumnExpression(exp.In(this=self._expression, expressions=converted_values))
def not_in_(self, values: Iterable[Any]) -> ColumnExpression:
"""SQL NOT IN clause."""
return ~self.in_(values)
def not_like(self, pattern: str) -> ColumnExpression:
"""SQL NOT LIKE."""
return ~self.like(pattern)
def not_ilike(self, pattern: str) -> ColumnExpression:
"""Case-insensitive NOT LIKE."""
return ~self.ilike(pattern)
def between(self, start: Any, end: Any) -> ColumnExpression:
"""SQL BETWEEN clause."""
return ColumnExpression(
exp.Between(this=self._expression, low=self._convert_value(start), high=self._convert_value(end))
)
def is_null(self) -> ColumnExpression:
"""SQL IS NULL."""
return ColumnExpression(exp.Is(this=self._expression, expression=exp.Null()))
def is_not_null(self) -> ColumnExpression:
"""SQL IS NOT NULL."""
return ColumnExpression(exp.Not(this=exp.Is(this=self._expression, expression=exp.Null())))
def any_(self, values: Iterable[Any]) -> ColumnExpression:
"""SQL = ANY(...) clause."""
converted_values = [self._convert_value(v) for v in values]
return ColumnExpression(exp.EQ(this=self._expression, expression=exp.Any(expressions=converted_values)))
def not_any_(self, values: Iterable[Any]) -> ColumnExpression:
"""SQL <> ANY(...) clause."""
converted_values = [self._convert_value(v) for v in values]
return ColumnExpression(exp.NEQ(this=self._expression, expression=exp.Any(expressions=converted_values)))
def alias(self, alias_name: str) -> "exp.Expression":
"""Create an aliased function expression."""
return exp.Alias(this=self._expression, alias=alias_name)
def as_(self, alias: str) -> "exp.Alias":
"""Create an aliased expression using sqlglot helper."""
return cast("exp.Alias", exp.alias_(self._expression, alias))
def cast(self, data_type: str) -> "FunctionColumn":
"""SQL CAST() function."""
return FunctionColumn(exp.Cast(this=self._expression, to=exp.DataType.build(data_type)))
def asc(self) -> "exp.Ordered":
"""Create an ASC ordering expression."""
return exp.Ordered(this=self._expression, desc=False)
def desc(self) -> "exp.Ordered":
"""Create a DESC ordering expression."""
return exp.Ordered(this=self._expression, desc=True)
def __hash__(self) -> int:
"""Hash based on the SQL expression."""
return hash(self._expression.sql())