Source code for sqlspec.adapters.arrow_odbc.adk.store

"""arrow-odbc ADK stores for Google Agent Development Kit session storage."""

from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, ClassVar, Final, NoReturn, cast

from typing_extensions import NotRequired

from sqlspec.config import ADKConfig
from sqlspec.exceptions import SQLSpecError
from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord
from sqlspec.extensions.adk.memory import BaseSyncADKMemoryStore, MemoryRecord
from sqlspec.utils.serializers import from_json, to_json

if TYPE_CHECKING:
    from datetime import timedelta

    from sqlspec.adapters.arrow_odbc.config import ArrowOdbcConfig
else:
    ArrowOdbcConfig = Any


# Public names exported by this module.
__all__ = ("ArrowOdbcADKConfig", "ArrowOdbcADKMemoryStore", "ArrowOdbcADKStore")

# Schema used to qualify every table reference built by ``_table_ref`` and the
# index-existence lookups in this module.
MSSQL_SCHEMA: Final[str] = "dbo"
# Column type for serialized JSON payloads (session state, event data, scoped state);
# values are written with ``to_json`` and read back with ``from_json``.
JSON_COLUMN_TYPE: Final[str] = "NVARCHAR(MAX)"


class ArrowOdbcADKConfig(ADKConfig):
    """arrow-odbc ADK extension settings.

    Adds adapter-specific keys on top of the shared ``ADKConfig`` settings.
    """

    # NOTE(review): not read anywhere in this module; JSON columns are always
    # created with JSON_COLUMN_TYPE regardless of this flag.
    native_json: NotRequired[bool]
    """Accepted for parity with SQL Server adapters; arrow-odbc uses NVARCHAR(MAX)."""


class ArrowOdbcADKStore(BaseSyncADKStore["ArrowOdbcConfig"]):
    """Synchronous SQL Server ADK session/event store using arrow-odbc.

    All statements are T-SQL against tables qualified with ``MSSQL_SCHEMA``.
    Session state, event payloads and scoped state are stored as serialized
    JSON text. Read paths treat a missing table as "no data" instead of raising.
    """

    connector_name: ClassVar[str] = "arrow_odbc"

    __slots__ = ()

    def create_tables(self) -> None:
        """Create all ADK session tables if they do not exist.

        Runs the idempotent DDL for the session, event, app-state, user-state and
        metadata tables, seeds the schema-version metadata row, then commits.
        """
        with self._config.provide_session() as driver:
            driver.execute(self._get_create_sessions_table_sql())
            driver.execute(self._get_create_events_table_sql())
            driver.execute(self._get_create_app_states_table_sql())
            driver.execute(self._get_create_user_states_table_sql())
            driver.execute(self._get_create_metadata_table_sql())
            driver.execute(self._get_seed_metadata_sql())
            driver.commit()

    def create_session(
        self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None
    ) -> SessionRecord:
        """Create a new ADK session.

        Args:
            session_id: Primary identifier of the session.
            app_name: Application the session belongs to.
            user_id: User the session belongs to.
            state: Initial durable state, stored as JSON.
            owner_id: Value for the optional owner column; ignored when no owner
                column is configured.

        Returns:
            The session as read back from the database after the insert.

        Raises:
            RuntimeError: If the inserted row cannot be read back.
        """
        owner_column = f", {_quote_identifier(self._owner_id_column_name)}" if self._owner_id_column_name else ""
        owner_param = ", ?" if self._owner_id_column_name else ""
        params: tuple[Any, ...]
        if self._owner_id_column_name:
            params = (session_id, app_name, user_id, owner_id, to_json(state))
        else:
            params = (session_id, app_name, user_id, to_json(state))
        with self._config.provide_session() as driver:
            driver.execute(
                f"""
                INSERT INTO {_table_ref(self._session_table)} (
                    id, app_name, user_id{owner_column}, state, create_time, update_time
                ) VALUES (?, ?, ?{owner_param}, ?, SYSUTCDATETIME(), SYSUTCDATETIME())
                """,
                params,
            )
            # Read back on the same session so server-side timestamps are returned.
            row = driver.select_one_or_none(
                _get_session_select_sql(self._session_table), (app_name, user_id, session_id)
            )
            driver.commit()
        if row is None:
            msg = "Failed to fetch created session"
            raise RuntimeError(msg)
        return _session_record_from_row(row)

    def get_session(
        self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None
    ) -> "SessionRecord | None":
        """Return a scoped session or ``None`` if absent.

        When ``renew_for`` is given and ``_calculate_expires_at`` yields a value,
        the session's ``update_time`` is bumped before it is read; the session
        is committed whenever ``renew_for`` is not ``None``.
        Returns ``None`` when the session table does not exist.
        """
        try:
            with self._config.provide_session() as driver:
                if renew_for is not None and self._calculate_expires_at(renew_for) is not None:
                    driver.execute(
                        f"""
                        UPDATE {_table_ref(self._session_table)}
                        SET update_time = SYSUTCDATETIME()
                        WHERE app_name = ? AND user_id = ? AND id = ?
                        """,
                        (app_name, user_id, session_id),
                    )
                row = driver.select_one_or_none(
                    _get_session_select_sql(self._session_table), (app_name, user_id, session_id)
                )
                if renew_for is not None:
                    driver.commit()
        except SQLSpecError as exc:
            if _is_table_missing(exc):
                return None
            raise
        return _session_record_from_row(row) if row is not None else None

    def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None:
        """Replace a session's durable state and bump its ``update_time``."""
        self._execute(
            f"""
            UPDATE {_table_ref(self._session_table)}
            SET state = ?, update_time = SYSUTCDATETIME()
            WHERE app_name = ? AND user_id = ? AND id = ?
            """,
            (to_json(state), app_name, user_id, session_id),
            commit=True,
        )

    def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]":
        """List ADK sessions for an application, optionally scoped to a user.

        Results are ordered by ``update_time`` descending. Returns an empty list
        when the session table does not exist.
        """
        if user_id is None:
            sql = f"""
            SELECT id, app_name, user_id, state, create_time, update_time
            FROM {_table_ref(self._session_table)}
            WHERE app_name = ?
            ORDER BY update_time DESC
            """
            params: tuple[Any, ...] = (app_name,)
        else:
            sql = f"""
            SELECT id, app_name, user_id, state, create_time, update_time
            FROM {_table_ref(self._session_table)}
            WHERE app_name = ? AND user_id = ?
            ORDER BY update_time DESC
            """
            params = (app_name, user_id)
        try:
            rows = self._execute_fetchall(sql, params)
        except SQLSpecError as exc:
            if _is_table_missing(exc):
                return []
            raise
        return [_session_record_from_row(row) for row in rows]

    def delete_session(self, app_name: str, user_id: str, session_id: str) -> None:
        """Delete a session. Event rows cascade through the FK."""
        self._execute(
            f"DELETE FROM {_table_ref(self._session_table)} WHERE app_name = ? AND user_id = ? AND id = ?",
            (app_name, user_id, session_id),
            commit=True,
        )

    def append_event(self, event_record: EventRecord) -> None:
        """Append an event to a session and commit."""
        self._execute(_get_insert_event_sql(self._events_table), _event_insert_params(event_record), commit=True)

    def append_event_and_update_state(
        self,
        event_record: EventRecord,
        app_name: str,
        user_id: str,
        session_id: str,
        state: "dict[str, Any]",
        *,
        app_state: "dict[str, Any] | None" = None,
        user_state: "dict[str, Any] | None" = None,
    ) -> SessionRecord:
        """Atomically append an event and update durable session/scoped state.

        All statements run on one driver session with a single commit at the end.

        Returns:
            The session row as read after the state update.

        Raises:
            ValueError: If the session does not exist.
        """
        with self._config.provide_session() as driver:
            driver.execute(
                f"""
                UPDATE {_table_ref(self._session_table)}
                SET state = ?, update_time = SYSUTCDATETIME()
                WHERE app_name = ? AND user_id = ? AND id = ?
                """,
                (to_json(state), app_name, user_id, session_id),
            )
            row = driver.select_one_or_none(
                _get_session_select_sql(self._session_table), (app_name, user_id, session_id)
            )
            if row is None:
                _raise_session_not_found(session_id)
            driver.execute(_get_insert_event_sql(self._events_table), _event_insert_params(event_record))
            if app_state is not None:
                driver.execute(self._get_upsert_app_state_sql(), (app_name, to_json(app_state)))
            if user_state is not None:
                driver.execute(self._get_upsert_user_state_sql(), (app_name, user_id, to_json(user_state)))
            driver.commit()
        return _session_record_from_row(row)

    def get_events(
        self,
        app_name: str,
        user_id: str,
        session_id: str,
        after_timestamp: "datetime | None" = None,
        limit: "int | None" = None,
    ) -> "list[EventRecord]":
        """Return events for a scoped session ordered by event timestamp.

        A non-positive ``limit`` returns an empty list without querying. Returns an
        empty list when the event table does not exist.
        """
        if limit is not None and limit <= 0:
            return []
        sql, params = self._get_events_query(app_name, user_id, session_id, after_timestamp, limit)
        try:
            rows = self._execute_fetchall(sql, params)
        except SQLSpecError as exc:
            if _is_table_missing(exc):
                return []
            raise
        return [_event_record_from_row(row) for row in rows]

    def delete_expired_events(self, before: datetime) -> int:
        """Delete events older than ``before``.

        Returns:
            The number of matching rows counted before the delete, or ``0`` when
            the event table does not exist.
        """
        try:
            deleted = self._count_and_delete(self._events_table, "timestamp", before)
        except SQLSpecError as exc:
            if _is_table_missing(exc):
                return 0
            raise
        return deleted

    def delete_idle_sessions(self, updated_before: datetime) -> int:
        """Delete sessions whose update_time is older than ``updated_before``.

        Returns:
            The number of matching rows counted before the delete, or ``0`` when
            the session table does not exist.
        """
        try:
            deleted = self._count_and_delete(self._session_table, "update_time", updated_before)
        except SQLSpecError as exc:
            if _is_table_missing(exc):
                return 0
            raise
        return deleted

    def get_app_state(self, app_name: str) -> "dict[str, Any] | None":
        """Return app-scoped state, or ``None`` if absent or the table is missing."""
        try:
            row = self._execute_fetchone(
                f"SELECT TOP 1 state FROM {_table_ref(self._app_state_table)} WHERE app_name = ?", (app_name,)
            )
        except SQLSpecError as exc:
            if _is_table_missing(exc):
                return None
            raise
        return _json_dict(_row_value(row, "state", 0)) if row is not None else None

    def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None":
        """Return user-scoped state, or ``None`` if absent or the table is missing."""
        try:
            row = self._execute_fetchone(
                f"""
                SELECT TOP 1 state
                FROM {_table_ref(self._user_state_table)}
                WHERE app_name = ? AND user_id = ?
                """,
                (app_name, user_id),
            )
        except SQLSpecError as exc:
            if _is_table_missing(exc):
                return None
            raise
        return _json_dict(_row_value(row, "state", 0)) if row is not None else None

    def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None:
        """Insert or replace app-scoped state."""
        self._execute(self._get_upsert_app_state_sql(), (app_name, to_json(state)), commit=True)

    def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None:
        """Insert or replace user-scoped state."""
        self._execute(self._get_upsert_user_state_sql(), (app_name, user_id, to_json(state)), commit=True)

    def get_metadata(self, key: str) -> "str | None":
        """Return an ADK metadata value, or ``None`` if absent or the table is missing."""
        try:
            row = self._execute_fetchone(
                f"SELECT TOP 1 value FROM {_table_ref(self._metadata_table)} WHERE [key] = ?", (key,)
            )
        except SQLSpecError as exc:
            if _is_table_missing(exc):
                return None
            raise
        value = _row_value(row, "value", 0) if row is not None else None
        return str(value) if value is not None else None

    def set_metadata(self, key: str, value: str) -> None:
        """Set an ADK metadata value."""
        self._execute(_get_upsert_metadata_sql(self._metadata_table), (key, value), commit=True)

    def _get_create_sessions_table_sql(self) -> str:
        """Return T-SQL DDL for the ADK session table."""
        return _get_create_sessions_table_sql(self._session_table, self._owner_id_column_ddl)

    def _get_create_events_table_sql(self) -> str:
        """Return T-SQL DDL for the ADK event table."""
        return _get_create_events_table_sql(self._events_table, self._session_table)

    def _get_create_app_states_table_sql(self) -> str:
        """Return T-SQL DDL for the app-scoped state table."""
        return _get_create_app_states_table_sql(self._app_state_table)

    def _get_create_user_states_table_sql(self) -> str:
        """Return T-SQL DDL for the user-scoped state table."""
        return _get_create_user_states_table_sql(self._user_state_table)

    def _get_create_metadata_table_sql(self) -> str:
        """Return T-SQL DDL for the ADK metadata table."""
        return _get_create_metadata_table_sql(self._metadata_table)

    def _get_seed_metadata_sql(self) -> str:
        """Return T-SQL to seed schema-version metadata."""
        return _get_seed_metadata_sql(self._metadata_table)

    def _get_drop_app_states_table_sql(self) -> str:
        return f"DROP TABLE IF EXISTS {_table_ref(self._app_state_table)}"

    def _get_drop_user_states_table_sql(self) -> str:
        return f"DROP TABLE IF EXISTS {_table_ref(self._user_state_table)}"

    def _get_drop_metadata_table_sql(self) -> str:
        return f"DROP TABLE IF EXISTS {_table_ref(self._metadata_table)}"

    def _get_drop_tables_sql(self) -> "list[str]":
        # Events are dropped before sessions because events reference sessions via FK.
        return [
            self._get_drop_metadata_table_sql(),
            self._get_drop_user_states_table_sql(),
            self._get_drop_app_states_table_sql(),
            f"DROP TABLE IF EXISTS {_table_ref(self._events_table)}",
            f"DROP TABLE IF EXISTS {_table_ref(self._session_table)}",
        ]

    def _get_upsert_app_state_sql(self) -> str:
        return _get_upsert_state_sql(self._app_state_table, ("app_name",), ("?",))

    def _get_upsert_user_state_sql(self) -> str:
        return _get_upsert_state_sql(self._user_state_table, ("app_name", "user_id"), ("?", "?"))

    def _get_events_query(
        self,
        app_name: str,
        user_id: str,
        session_id: str,
        after_timestamp: "datetime | None" = None,
        limit: "int | None" = None,
    ) -> "tuple[str, tuple[Any, ...]]":
        return _get_events_query(self._events_table, app_name, user_id, session_id, after_timestamp, limit)

    def _count_and_delete(self, table: str, column: str, cutoff: datetime) -> int:
        """Count, then delete, rows of ``table`` whose ``column`` is earlier than ``cutoff``.

        Both statements share one driver session and one commit, so the count and
        the delete run on the same connection rather than two independent ones.
        """
        cutoff_params = (_format_datetime(cutoff),)
        table_ref = _table_ref(table)
        with self._config.provide_session() as driver:
            count = driver.select_value(
                f"SELECT COUNT(*) AS row_count FROM {table_ref} WHERE {column} < ?", cutoff_params
            )
            driver.execute(f"DELETE FROM {table_ref} WHERE {column} < ?", cutoff_params)
            driver.commit()
        return int(count or 0)

    def _execute_fetchone(self, sql: str, params: "tuple[Any, ...]" = ()) -> "dict[str, Any] | None":
        with self._config.provide_session() as driver:
            return driver.select_one_or_none(sql, params)

    def _execute_fetchall(self, sql: str, params: "tuple[Any, ...]" = ()) -> "list[dict[str, Any]]":
        with self._config.provide_session() as driver:
            return driver.select(sql, params)

    def _execute(self, sql: str, params: "tuple[Any, ...]" = (), *, commit: bool = False) -> int:
        with self._config.provide_session() as driver:
            result = driver.execute(sql, params)
            if commit:
                driver.commit()
            return int(result.rows_affected)

    def _select_count(self, sql: str, params: "tuple[Any, ...]" = ()) -> int:
        with self._config.provide_session() as driver:
            value = driver.select_value(sql, params)
            return int(value or 0)
class ArrowOdbcADKMemoryStore(BaseSyncADKMemoryStore["ArrowOdbcConfig"]):
    """SQL Server ADK memory store using arrow-odbc.

    Memory entries live in a single table qualified with ``MSSQL_SCHEMA``.
    Search is a literal substring match on ``content_text``.
    """

    __slots__ = ()

    def create_tables(self) -> None:
        """Create the memory table if memory storage is enabled."""
        if not self._enabled:
            return
        with self._config.provide_session() as driver:
            for statement in self._get_create_memory_table_sql():
                driver.execute(statement)
            driver.commit()

    def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
        """Insert memory entries, skipping duplicates by event_id.

        Args:
            entries: Records to insert.
            owner_id: Value for the optional owner column; ignored when no owner
                column is configured.

        Returns:
            Number of entries actually inserted.

        Raises:
            RuntimeError: If the memory store is disabled.
        """
        if not self._enabled:
            msg = "Memory store is disabled"
            raise RuntimeError(msg)
        if not entries:
            return 0
        table_ref = _table_ref(self._memory_table)
        # The owner column and statements are identical for every entry, so build them once.
        owner_column = f", {_quote_identifier(self._owner_id_column_name)}" if self._owner_id_column_name else ""
        owner_param = ", ?" if self._owner_id_column_name else ""
        exists_sql = f"SELECT TOP 1 id FROM {table_ref} WHERE event_id = ?"
        insert_sql = f"""
            INSERT INTO {table_ref} (
                id, session_id, app_name, user_id, event_id, author, timestamp,
                content_json, content_text, metadata_json, inserted_at{owner_column}
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?{owner_param})
            """
        inserted_count = 0
        with self._config.provide_session() as driver:
            for entry in entries:
                if driver.select_one_or_none(exists_sql, (entry["event_id"],)) is not None:
                    continue
                params: tuple[Any, ...] = _memory_insert_params(entry)
                if self._owner_id_column_name:
                    params = (*params, owner_id)
                driver.execute(insert_sql, params)
                inserted_count += 1
            driver.commit()
        return inserted_count

    def search_entries(
        self, query: str, app_name: str, user_id: str, limit: "int | None" = None
    ) -> "list[MemoryRecord]":
        """Search memory entries with SQL Server LIKE matching.

        ``query`` is matched literally as a substring: LIKE metacharacters in it
        (``%``, ``_``, ``[``, and the escape character ``\\``) are escaped.
        Results are ordered by ``timestamp`` descending and capped at ``limit``
        (or the store's configured maximum when ``limit`` is ``None``).

        Raises:
            RuntimeError: If the memory store is disabled.
        """
        if not self._enabled:
            msg = "Memory store is disabled"
            raise RuntimeError(msg)
        effective_limit = max(0, int(limit if limit is not None else self._max_results))
        if effective_limit == 0:
            return []
        pattern = f"%{self._escape_like(query)}%"
        rows = self._execute_fetchall(
            f"""
            SELECT id, session_id, app_name, user_id, event_id, author, timestamp,
                   content_json, content_text, metadata_json, inserted_at
            FROM {_table_ref(self._memory_table)}
            WHERE app_name = ? AND user_id = ? AND content_text LIKE ? ESCAPE '\\'
            ORDER BY timestamp DESC
            OFFSET 0 ROWS FETCH NEXT {effective_limit} ROWS ONLY
            """,
            (app_name, user_id, pattern),
        )
        return [_memory_record_from_row(row) for row in rows]

    def delete_entries_by_session(self, session_id: str) -> int:
        """Delete all memory entries for a specific session and return the counted rows."""
        return self._count_and_delete("session_id = ?", (session_id,))

    def delete_entries_older_than(self, days: int) -> int:
        """Delete memory entries inserted more than ``days`` days ago and return the counted rows."""
        cutoff = datetime.now(timezone.utc).timestamp() - (days * 86_400)
        cutoff_dt = datetime.fromtimestamp(cutoff, tz=timezone.utc)
        return self._count_and_delete("inserted_at < ?", (_format_datetime(cutoff_dt),))

    @staticmethod
    def _escape_like(value: str) -> str:
        r"""Escape T-SQL LIKE metacharacters so ``value`` matches literally under ``ESCAPE '\'``."""
        # The escape character itself must be escaped first.
        return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_").replace("[", "\\[")

    def _count_and_delete(self, where_clause: str, params: "tuple[Any, ...]") -> int:
        """Count, then delete, memory rows matching ``where_clause`` on one session with one commit."""
        table_ref = _table_ref(self._memory_table)
        with self._config.provide_session() as driver:
            count = driver.select_value(f"SELECT COUNT(*) AS row_count FROM {table_ref} WHERE {where_clause}", params)
            driver.execute(f"DELETE FROM {table_ref} WHERE {where_clause}", params)
            driver.commit()
        return int(count or 0)

    def _get_create_memory_table_sql(self) -> "list[str]":
        """Return idempotent T-SQL statements creating the memory table and its indexes."""
        table = self._memory_table
        owner_line = f",\n                {self._owner_id_column_ddl}" if self._owner_id_column_ddl else ""
        return [
            f"""
            IF NOT EXISTS (
                SELECT 1 FROM sys.tables
                WHERE name = N'{_escape_sql_literal(table)}'
                AND schema_id = SCHEMA_ID(N'{_escape_sql_literal(MSSQL_SCHEMA)}')
            )
            BEGIN
                CREATE TABLE {_table_ref(table)} (
                    id NVARCHAR(128) NOT NULL,
                    session_id NVARCHAR(128) NOT NULL,
                    app_name NVARCHAR(128) NOT NULL,
                    user_id NVARCHAR(128) NOT NULL,
                    event_id NVARCHAR(128) NOT NULL,
                    author NVARCHAR(256) NULL,
                    timestamp DATETIME2(6) NOT NULL,
                    content_json NVARCHAR(MAX) NOT NULL,
                    content_text NVARCHAR(MAX) NOT NULL,
                    metadata_json NVARCHAR(MAX) NULL,
                    inserted_at DATETIME2(6) NOT NULL{owner_line},
                    CONSTRAINT {_constraint_ref("pk", table, "id")} PRIMARY KEY (id),
                    CONSTRAINT {_constraint_ref("uq", table, "event_id")} UNIQUE (event_id)
                );
            END;
            """,
            _get_create_index_sql(table, f"idx_{table}_app_user_time", "app_name, user_id, timestamp DESC"),
            _get_create_index_sql(table, f"idx_{table}_session", "session_id"),
        ]

    def _get_drop_memory_table_sql(self) -> "list[str]":
        return [f"DROP TABLE IF EXISTS {_table_ref(self._memory_table)}"]

    def _execute_fetchall(self, sql: str, params: "tuple[Any, ...]" = ()) -> "list[dict[str, Any]]":
        with self._config.provide_session() as driver:
            return driver.select(sql, params)

    def _execute(self, sql: str, params: "tuple[Any, ...]" = (), *, commit: bool = False) -> int:
        with self._config.provide_session() as driver:
            result = driver.execute(sql, params)
            if commit:
                driver.commit()
            return int(result.rows_affected)

    def _select_count(self, sql: str, params: "tuple[Any, ...]" = ()) -> int:
        with self._config.provide_session() as driver:
            value = driver.select_value(sql, params)
            return int(value or 0)
def _get_table_missing_guard(table: str) -> str:
    """Return a T-SQL ``IF NOT EXISTS`` predicate that is true when ``table`` is absent from ``MSSQL_SCHEMA``."""
    return (
        "IF NOT EXISTS (SELECT 1 FROM sys.tables "
        f"WHERE name = N'{_escape_sql_literal(table)}' "
        f"AND schema_id = SCHEMA_ID(N'{_escape_sql_literal(MSSQL_SCHEMA)}'))"
    )


def _get_session_select_sql(table: str) -> str:
    """Return a single-session SELECT; parameters are (app_name, user_id, id) in that order."""
    return f"""
    SELECT TOP 1 id, app_name, user_id, state, create_time, update_time
    FROM {_table_ref(table)}
    WHERE app_name = ? AND user_id = ? AND id = ?
    """


def _get_create_sessions_table_sql(table: str, owner_id_column_ddl: "str | None") -> str:
    """Return idempotent DDL for the session table plus its indexes.

    ``owner_id_column_ddl``, when given, is inserted verbatim after ``user_id``.
    """
    owner_line = f",\n        {owner_id_column_ddl}" if owner_id_column_ddl else ""
    return f"""
    {_get_table_missing_guard(table)}
    BEGIN
        CREATE TABLE {_table_ref(table)} (
            row_id UNIQUEIDENTIFIER NOT NULL
                CONSTRAINT {_constraint_ref("df", table, "row_id")} DEFAULT NEWSEQUENTIALID(),
            id NVARCHAR(128) NOT NULL,
            app_name NVARCHAR(128) NOT NULL,
            user_id NVARCHAR(128) NOT NULL{owner_line},
            state {JSON_COLUMN_TYPE} NOT NULL,
            create_time DATETIME2(6) NOT NULL
                CONSTRAINT {_constraint_ref("df", table, "create_time")} DEFAULT SYSUTCDATETIME(),
            update_time DATETIME2(6) NOT NULL
                CONSTRAINT {_constraint_ref("df", table, "update_time")} DEFAULT SYSUTCDATETIME(),
            CONSTRAINT {_constraint_ref("pk", table, "row_id")} PRIMARY KEY (row_id),
            CONSTRAINT {_constraint_ref("uq", table, "id")} UNIQUE (id)
        );
    END;
    {_get_create_index_sql(table, f"idx_{table}_app_user", "app_name, user_id")}
    {_get_create_index_sql(table, f"idx_{table}_update_time", "update_time DESC")}
    """


def _get_create_events_table_sql(table: str, session_table: str) -> str:
    """Return idempotent DDL for the event table; ``session_id`` cascades deletes from ``session_table``."""
    return f"""
    {_get_table_missing_guard(table)}
    BEGIN
        CREATE TABLE {_table_ref(table)} (
            row_id UNIQUEIDENTIFIER NOT NULL
                CONSTRAINT {_constraint_ref("df", table, "row_id")} DEFAULT NEWSEQUENTIALID(),
            id NVARCHAR(128) NOT NULL,
            app_name NVARCHAR(128) NOT NULL,
            user_id NVARCHAR(128) NOT NULL,
            session_id NVARCHAR(128) NOT NULL,
            invocation_id NVARCHAR(256) NOT NULL,
            timestamp DATETIME2(6) NOT NULL,
            event_data {JSON_COLUMN_TYPE} NOT NULL,
            CONSTRAINT {_constraint_ref("pk", table, "row_id")} PRIMARY KEY (row_id),
            CONSTRAINT {_constraint_ref("uq", table, "id")} UNIQUE (id),
            CONSTRAINT {_constraint_ref("fk", table, "session")} FOREIGN KEY (session_id)
                REFERENCES {_table_ref(session_table)}(id) ON DELETE CASCADE
        );
    END;
    {_get_create_index_sql(table, f"idx_{table}_scope", "app_name, user_id, session_id, timestamp ASC")}
    {_get_create_index_sql(table, f"idx_{table}_session", "session_id, timestamp ASC")}
    {_get_create_index_sql(table, f"idx_{table}_invocation", "invocation_id")}
    {_get_create_index_sql(table, f"idx_{table}_timestamp", "timestamp ASC")}
    """


def _get_create_app_states_table_sql(table: str) -> str:
    """Return idempotent DDL for the app-scoped state table (one row per app_name)."""
    return f"""
    {_get_table_missing_guard(table)}
    BEGIN
        CREATE TABLE {_table_ref(table)} (
            app_name NVARCHAR(128) NOT NULL,
            state {JSON_COLUMN_TYPE} NOT NULL,
            update_time DATETIME2(6) NOT NULL
                CONSTRAINT {_constraint_ref("df", table, "update_time")} DEFAULT SYSUTCDATETIME(),
            CONSTRAINT {_constraint_ref("pk", table, "app_name")} PRIMARY KEY (app_name)
        );
    END;
    """


def _get_create_user_states_table_sql(table: str) -> str:
    """Return idempotent DDL for the user-scoped state table (one row per app_name/user_id)."""
    return f"""
    {_get_table_missing_guard(table)}
    BEGIN
        CREATE TABLE {_table_ref(table)} (
            app_name NVARCHAR(128) NOT NULL,
            user_id NVARCHAR(128) NOT NULL,
            state {JSON_COLUMN_TYPE} NOT NULL,
            update_time DATETIME2(6) NOT NULL
                CONSTRAINT {_constraint_ref("df", table, "update_time")} DEFAULT SYSUTCDATETIME(),
            CONSTRAINT {_constraint_ref("pk", table, "app_user")} PRIMARY KEY (app_name, user_id)
        );
    END;
    """


def _get_create_metadata_table_sql(table: str) -> str:
    """Return idempotent DDL for the key/value metadata table."""
    return f"""
    {_get_table_missing_guard(table)}
    BEGIN
        CREATE TABLE {_table_ref(table)} (
            [key] NVARCHAR(128) NOT NULL,
            value NVARCHAR(512) NOT NULL,
            CONSTRAINT {_constraint_ref("pk", table, "key")} PRIMARY KEY ([key])
        );
    END;
    """


def _get_create_index_sql(table: str, index_name: str, columns: str) -> str:
    """Return T-SQL creating ``index_name`` on ``table`` unless it already exists.

    The table is passed to ``OBJECT_ID`` in bracket-quoted form so names containing
    dots or other special characters resolve correctly.
    """
    return f"""
    IF NOT EXISTS (
        SELECT 1 FROM sys.indexes
        WHERE name = N'{_escape_sql_literal(index_name)}'
        AND object_id = OBJECT_ID(N'{_escape_sql_literal(_table_ref(table))}')
    )
    BEGIN
        CREATE INDEX {_quote_identifier(index_name)} ON {_table_ref(table)} ({columns});
    END;
    """


def _get_insert_event_sql(table: str) -> str:
    """Return the event INSERT; parameter order matches ``_event_insert_params``."""
    return f"""
    INSERT INTO {_table_ref(table)} (
        id, app_name, user_id, session_id, invocation_id, timestamp, event_data
    ) VALUES (?, ?, ?, ?, ?, ?, ?)
    """


def _get_upsert_state_sql(table: str, key_columns: "tuple[str, ...]", key_params: "tuple[str, ...]") -> str:
    """Return a MERGE that inserts or replaces the ``state`` row identified by ``key_columns``.

    Parameters bind in order: one per key column, then the serialized state.
    ``key_columns`` and ``key_params`` must have the same length.

    Raises:
        ValueError: If ``key_columns`` and ``key_params`` differ in length.
    """
    # strict=True: a length mismatch would otherwise silently drop key columns
    # and produce a MERGE that matches on the wrong keys.
    source_columns = ", ".join(
        f"{param} AS {_quote_identifier(column)}" for column, param in zip(key_columns, key_params, strict=True)
    )
    source_columns = f"{source_columns}, ? AS state"
    insert_columns = ", ".join(_quote_identifier(column) for column in (*key_columns, "state", "update_time"))
    insert_values = ", ".join(f"source.{_quote_identifier(column)}" for column in (*key_columns, "state"))
    match_clause = " AND ".join(
        f"target.{_quote_identifier(column)} = source.{_quote_identifier(column)}" for column in key_columns
    )
    # HOLDLOCK keeps the MERGE's existence check and write under one range lock.
    return f"""
    MERGE INTO {_table_ref(table)} WITH (HOLDLOCK) AS target
    USING (SELECT {source_columns}) AS source
    ON ({match_clause})
    WHEN MATCHED THEN
        UPDATE SET state = source.state, update_time = SYSUTCDATETIME()
    WHEN NOT MATCHED THEN
        INSERT ({insert_columns}) VALUES ({insert_values}, SYSUTCDATETIME());
    """


def _get_upsert_metadata_sql(table: str) -> str:
    """Return a MERGE upserting one metadata row; parameters are (key, value)."""
    return f"""
    MERGE INTO {_table_ref(table)} WITH (HOLDLOCK) AS target
    USING (SELECT ? AS [key], ? AS value) AS source
    ON (target.[key] = source.[key])
    WHEN MATCHED THEN
        UPDATE SET value = source.value
    WHEN NOT MATCHED THEN
        INSERT ([key], value) VALUES (source.[key], source.value);
    """


def _get_seed_metadata_sql(table: str) -> str:
    """Return a MERGE that sets ``schema_version`` to ``'1'`` (inserting or overwriting)."""
    return f"""
    MERGE INTO {_table_ref(table)} WITH (HOLDLOCK) AS target
    USING (SELECT N'schema_version' AS [key], N'1' AS value) AS source
    ON (target.[key] = source.[key])
    WHEN MATCHED THEN
        UPDATE SET value = source.value
    WHEN NOT MATCHED THEN
        INSERT ([key], value) VALUES (source.[key], source.value);
    """


def _get_events_query(
    table: str, app_name: str, user_id: str, session_id: str, after_timestamp: "datetime | None", limit: "int | None"
) -> "tuple[str, tuple[Any, ...]]":
    """Build the scoped event SELECT and its parameters.

    Events are ordered by ``timestamp`` ascending; ``limit`` becomes ``TOP n``
    (negative values are clamped to ``TOP 0``), and ``after_timestamp`` adds a
    strict ``timestamp > ?`` filter.
    """
    top_clause = f"TOP {max(int(limit), 0)} " if limit is not None else ""
    params: list[Any] = [app_name, user_id, session_id]
    after_clause = ""
    if after_timestamp is not None:
        after_clause = " AND timestamp > ?"
        params.append(_format_datetime(after_timestamp))
    sql = f"""
    SELECT {top_clause}id, app_name, user_id, session_id, invocation_id, timestamp, event_data
    FROM {_table_ref(table)}
    WHERE app_name = ? AND user_id = ? AND session_id = ?{after_clause}
    ORDER BY timestamp ASC
    """
    return sql, tuple(params)


def _event_insert_params(event_record: EventRecord) -> "tuple[Any, ...]":
    """Return INSERT parameters for ``_get_insert_event_sql`` from an event record."""
    return (
        event_record["id"],
        event_record["app_name"],
        event_record["user_id"],
        event_record["session_id"],
        event_record["invocation_id"],
        _format_datetime(event_record["timestamp"]),
        to_json(event_record["event_data"]),
    )


def _session_record_from_row(row: Any) -> SessionRecord:
    """Convert a session row (dict, sequence, or attribute object) into a ``SessionRecord``."""
    return SessionRecord(
        id=str(_row_value(row, "id", 0)),
        app_name=str(_row_value(row, "app_name", 1)),
        user_id=str(_row_value(row, "user_id", 2)),
        state=_json_dict(_row_value(row, "state", 3)),
        create_time=_datetime_value(_row_value(row, "create_time", 4)),
        update_time=_datetime_value(_row_value(row, "update_time", 5)),
    )


def _event_record_from_row(row: Any) -> EventRecord:
    """Convert an event row into an ``EventRecord``."""
    return EventRecord(
        id=str(_row_value(row, "id", 0)),
        app_name=str(_row_value(row, "app_name", 1)),
        user_id=str(_row_value(row, "user_id", 2)),
        session_id=str(_row_value(row, "session_id", 3)),
        invocation_id=str(_row_value(row, "invocation_id", 4)),
        timestamp=_datetime_value(_row_value(row, "timestamp", 5)),
        event_data=_json_dict(_row_value(row, "event_data", 6)),
    )


def _memory_insert_params(entry: MemoryRecord) -> "tuple[Any, ...]":
    """Return memory-table INSERT parameters (without the optional owner column)."""
    return (
        entry["id"],
        entry["session_id"],
        entry["app_name"],
        entry["user_id"],
        entry["event_id"],
        entry["author"],
        _format_datetime(entry["timestamp"]),
        to_json(entry["content_json"]),
        entry["content_text"],
        to_json(entry["metadata_json"]) if entry["metadata_json"] is not None else None,
        _format_datetime(entry["inserted_at"]),
    )


def _memory_record_from_row(row: Any) -> MemoryRecord:
    """Convert a memory-table row into a ``MemoryRecord``."""
    return MemoryRecord(
        id=str(_row_value(row, "id", 0)),
        session_id=str(_row_value(row, "session_id", 1)),
        app_name=str(_row_value(row, "app_name", 2)),
        user_id=str(_row_value(row, "user_id", 3)),
        event_id=str(_row_value(row, "event_id", 4)),
        author=cast("str | None", _row_value(row, "author", 5)),
        timestamp=_datetime_value(_row_value(row, "timestamp", 6)),
        content_json=_json_dict(_row_value(row, "content_json", 7)),
        content_text=str(_row_value(row, "content_text", 8) or ""),
        metadata_json=_optional_json_dict(_row_value(row, "metadata_json", 9)),
        inserted_at=_datetime_value(_row_value(row, "inserted_at", 10)),
    )


def _row_value(row: Any, key: str, index: int) -> Any:
    """Read a column from a dict row (exact or upper-cased key), a sequence row (by index), or an object."""
    if isinstance(row, dict):
        if key in row:
            return row[key]
        upper_key = key.upper()
        if upper_key in row:
            return row[upper_key]
        return None
    if isinstance(row, (list, tuple)) and len(row) > index:
        return row[index]
    return getattr(row, key, None)


def _json_dict(value: Any) -> "dict[str, Any]":
    """Decode a JSON column value (dict, str, bytes, or other) into a dict; ``None`` becomes ``{}``."""
    if value is None:
        return {}
    if isinstance(value, dict):
        return cast("dict[str, Any]", value)
    if isinstance(value, bytearray):
        value = bytes(value)
    if isinstance(value, bytes):
        value = value.decode("utf-8")
    if isinstance(value, str):
        return cast("dict[str, Any]", from_json(value))
    return cast("dict[str, Any]", from_json(str(value)))


def _optional_json_dict(value: Any) -> "dict[str, Any] | None":
    """Like ``_json_dict`` but preserves ``None``."""
    if value is None:
        return None
    return _json_dict(value)


def _datetime_value(value: Any) -> datetime:
    """Normalize a datetime column value to an aware UTC ``datetime``.

    Naive datetimes and ISO strings without an offset are treated as UTC.
    """
    if isinstance(value, datetime):
        if value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc)
    if isinstance(value, bytearray):
        value = bytes(value)
    if isinstance(value, bytes):
        value = value.decode("utf-8")
    if isinstance(value, str):
        normalized = value.replace("Z", "+00:00")
        parsed = datetime.fromisoformat(normalized)
        if parsed.tzinfo is None:
            return parsed.replace(tzinfo=timezone.utc)
        return parsed.astimezone(timezone.utc)
    # NOTE(review): any other type (including None) silently becomes "now";
    # confirm this best-effort fallback is intended rather than an error.
    return datetime.now(timezone.utc)


def _format_datetime(value: "datetime | None") -> "str | None":
    """Format a datetime as a naive UTC ISO string with microseconds for DATETIME2(6) parameters."""
    if value is None:
        return None
    normalized = value.replace(tzinfo=timezone.utc) if value.tzinfo is None else value.astimezone(timezone.utc)
    return normalized.replace(tzinfo=None).isoformat(timespec="microseconds")


def _is_table_missing(exc: BaseException) -> bool:
    """Return True when ``exc`` looks like SQL Server's "invalid object name" (error 208 / SQLSTATE 42S02)."""
    text = str(exc).lower()
    return "invalid object name" in text or "42s02" in text or "(208)" in text


def _quote_identifier(identifier: str) -> str:
    """Bracket-quote a T-SQL identifier, doubling any closing bracket."""
    return f"[{identifier.replace(']', ']]')}]"


def _table_ref(table: str) -> str:
    """Return the schema-qualified, bracket-quoted reference for ``table``."""
    return f"{_quote_identifier(MSSQL_SCHEMA)}.{_quote_identifier(table)}"


def _constraint_ref(prefix: str, table: str, suffix: str) -> str:
    """Return a quoted constraint name of the form ``[{prefix}_{table}_{suffix}]``."""
    return _quote_identifier(f"{prefix}_{table}_{suffix}")


def _escape_sql_literal(value: str) -> str:
    """Escape single quotes for embedding ``value`` inside a T-SQL string literal."""
    return value.replace("'", "''")


def _raise_session_not_found(session_id: str) -> NoReturn:
    """Raise ``ValueError`` for a session missing during ``append_event_and_update_state``."""
    msg = f"Session {session_id} not found during append_event_and_update_state."
    raise ValueError(msg)