Source code for sqlspec.extensions.adk.service

"""SQLSpec-backed session service for Google ADK."""

import inspect
import logging
import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, cast

from google.adk.sessions.base_session_service import BaseSessionService, GetSessionConfig, ListSessionsResponse

from sqlspec.extensions.adk.converters import (
    compute_update_marker,
    event_to_record,
    filter_temp_state,
    merge_scoped_state,
    record_to_session,
    split_scoped_state,
)
from sqlspec.utils.logging import get_logger, log_with_context
from sqlspec.utils.sync_tools import async_

if TYPE_CHECKING:
    from collections.abc import Callable

    from google.adk.events.event import Event
    from google.adk.sessions import Session

    from sqlspec.extensions.adk.store import BaseAsyncADKStore, BaseSyncADKStore

    ADKStore = BaseAsyncADKStore | BaseSyncADKStore

__all__ = ("SQLSpecSessionService",)

logger = get_logger("sqlspec.extensions.adk.service")


[docs] class SQLSpecSessionService(BaseSessionService): """SQLSpec-backed implementation of BaseSessionService. Provides session and event storage using SQLSpec database adapters. Delegates all database operations to a store implementation. Args: store: Database store implementation. """
[docs] def __init__(self, store: "ADKStore") -> None: """Initialize the session service. Args: store: Database store implementation. """ self._store = store
@property def store(self) -> "ADKStore": """Return the database store.""" return self._store
[docs] async def create_session( self, *, app_name: str, user_id: str, state: "dict[str, Any] | None" = None, session_id: "str | None" = None ) -> "Session": """Create a new session. Args: app_name: Name of the application. user_id: ID of the user. state: Initial state of the session. session_id: Client-provided session ID. If None, generates a UUID. Returns: The newly created session. """ if session_id is None: session_id = str(uuid.uuid4()) if state is None: state = {} persisted_state = filter_temp_state(state) app_state_delta, user_state_delta, session_state = split_scoped_state(persisted_state) current_app_state = await self._call_store("get_app_state", app_name) current_user_state = await self._call_store("get_user_state", app_name, user_id) app_state = dict(current_app_state or {}) if app_state_delta: app_state.update(app_state_delta) user_state = dict(current_user_state or {}) if user_state_delta: user_state.update(user_state_delta) record = await self._call_store( "create_session", session_id=session_id, app_name=app_name, user_id=user_id, state=session_state ) if app_state_delta: await self._call_store("upsert_app_state", app_name, app_state) if user_state_delta: await self._call_store("upsert_user_state", app_name, user_id, user_state) record["state"] = merge_scoped_state(record["state"], app_state, user_state) log_with_context( logger, logging.DEBUG, "adk.session.create", app_name=app_name, session_id=session_id, has_state=bool(state) ) return record_to_session(record, events=[])
[docs] async def get_session( self, *, app_name: str, user_id: str, session_id: str, config: "GetSessionConfig | None" = None ) -> "Session | None": """Get a session by ID. Args: app_name: Name of the application. user_id: ID of the user. session_id: Session identifier. config: Configuration for retrieving events. Returns: Session object if found, None otherwise. """ record = await self._call_store("get_session", app_name, user_id, session_id) if not record: log_with_context( logger, logging.DEBUG, "adk.session.get", app_name=app_name, session_id=session_id, found=False ) return None if record["app_name"] != app_name or record["user_id"] != user_id: log_with_context( logger, logging.DEBUG, "adk.session.get", app_name=app_name, session_id=session_id, found=False ) return None app_state = await self._call_store("get_app_state", app_name) user_state = await self._call_store("get_user_state", app_name, user_id) record["state"] = merge_scoped_state(record["state"], app_state, user_state) after_timestamp = None limit = None if config: if config.after_timestamp: after_timestamp = datetime.fromtimestamp(config.after_timestamp, tz=timezone.utc) limit = config.num_recent_events if limit == 0: events = [] else: events = await self._call_store( "get_events", app_name, user_id, session_id, after_timestamp=after_timestamp, limit=limit ) log_with_context( logger, logging.DEBUG, "adk.session.get", app_name=app_name, session_id=session_id, found=True, event_count=len(events), ) return record_to_session(record, events)
[docs] async def get_user_state(self, *, app_name: str, user_id: str) -> "dict[str, Any]": """Get user-scoped state for an app and user. ADK's service API returns unprefixed user state keys, while SQLSpec stores the durable state using ADK's documented ``user:`` prefix. Args: app_name: Name of the application. user_id: ID of the user. Returns: User-scoped state with ``user:`` prefixes removed. """ state = await self._call_store("get_user_state", app_name, user_id) if not state: return {} return {key.removeprefix("user:") if key.startswith("user:") else key: value for key, value in state.items()}
[docs] async def list_sessions(self, *, app_name: str, user_id: str | None = None) -> "ListSessionsResponse": """List all sessions for an app, optionally filtered by user. Args: app_name: Name of the application. user_id: ID of the user. If None, all sessions for the app are listed. Returns: Response containing list of sessions (without events). """ records = await self._call_store("list_sessions", app_name, user_id=user_id) sessions = [record_to_session(record, events=[]) for record in records] log_with_context( logger, logging.DEBUG, "adk.session.list", app_name=app_name, has_user_id=user_id is not None, count=len(sessions), ) return ListSessionsResponse(sessions=sessions)
[docs] async def delete_session(self, *, app_name: str, user_id: str, session_id: str) -> None: """Delete a session and all its events. Args: app_name: Name of the application. user_id: ID of the user. session_id: Session identifier. """ record = await self._call_store("get_session", app_name, user_id, session_id) if not record: log_with_context( logger, logging.DEBUG, "adk.session.delete", app_name=app_name, session_id=session_id, deleted=False ) return if record["app_name"] != app_name or record["user_id"] != user_id: log_with_context( logger, logging.DEBUG, "adk.session.delete", app_name=app_name, session_id=session_id, deleted=False ) return await self._call_store("delete_session", app_name, user_id, session_id) log_with_context( logger, logging.DEBUG, "adk.session.delete", app_name=app_name, session_id=session_id, deleted=True )
[docs] async def append_event(self, session: "Session", event: "Event") -> "Event": """Append an event to a session. Persists the event record and the post-append durable state atomically via ``store.append_event_and_update_state()``, then updates the in-memory session only after persistence succeeds. Implements stale-session detection: if the session has been modified in storage since it was last loaded, a ``ValueError`` is raised instead of silently overwriting. ``temp:`` keys are stripped from the persisted state snapshot so they never survive a reload. Args: session: Session to append to. event: Event to append. Returns: The appended event. Raises: ValueError: If the session has been modified in storage since it was loaded (stale session). """ if event.partial: return event # Apply temp state to in-memory session so subsequent agents in # the same invocation can read temp values, then strip temp keys # from the event delta before persistence. self._apply_temp_state(session, event) event = self._trim_temp_delta_state(event) event_record = event_to_record( event=event, app_name=session.app_name, user_id=session.user_id, session_id=session.id ) # --- Stale-session detection --- current_record = await self._call_store("get_session", session.app_name, session.user_id, session.id) if current_record is None: msg = f"Session {session.id} not found." raise ValueError(msg) if session._storage_update_marker is not None: # pyright: ignore[reportPrivateUsage] current_marker = compute_update_marker(current_record["update_time"]) if session._storage_update_marker != current_marker: # pyright: ignore[reportPrivateUsage] msg = ( "The session has been modified in storage since it was loaded. " "Please reload the session before appending more events." ) raise ValueError(msg) elif current_record["update_time"].timestamp() > session.last_update_time: msg = ( "The session has been modified in storage since it was loaded. " "Please reload the session before appending more events." ) raise ValueError(msg) state_delta = (event.actions.state_delta if event.actions else None) or {} app_state_delta, user_state_delta, session_state_delta = split_scoped_state(filter_temp_state(state_delta)) _, _, session_state = split_scoped_state(filter_temp_state(session.state)) session_state.update(session_state_delta) app_state = None if app_state_delta: app_state = dict(await self._call_store("get_app_state", session.app_name) or {}) app_state.update(app_state_delta) user_state = None if user_state_delta: user_state = dict(await self._call_store("get_user_state", session.app_name, session.user_id) or {}) user_state.update(user_state_delta) # --- Persist event and state atomically --- updated_record = await self._call_store( "append_event_and_update_state", event_record, session.app_name, session.user_id, session.id, session_state, app_state=app_state, user_state=user_state, ) updated_record["state"] = merge_scoped_state(updated_record["state"], app_state, user_state) # Use the returned record directly — saves a round-trip vs a follow-up get_session(). session.last_update_time = updated_record["update_time"].timestamp() session._storage_update_marker = compute_update_marker(updated_record["update_time"]) # pyright: ignore[reportPrivateUsage] # Update in-memory session AFTER successful persistence self._update_session_state(session, event) session.events.append(event) log_with_context( logger, logging.DEBUG, "adk.session.event.append", app_name=session.app_name, session_id=session.id ) return event
async def _call_store(self, method_name: str, *args: Any, **kwargs: Any) -> Any: """Call an async store method or offload a sync store method.""" method = getattr(self._store, method_name) if inspect.iscoroutinefunction(method): return await method(*args, **kwargs) sync_method = method if TYPE_CHECKING: sync_method = cast("Callable[..., Any]", method) return await async_(sync_method)(*args, **kwargs)