Source code for sqlspec.extensions.adk.memory.store

"""Base store classes for ADK memory backend (sync and async)."""

import logging
import re
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar, cast

from typing_extensions import NotRequired, TypedDict

from sqlspec.observability import resolve_db_system
from sqlspec.utils.logging import get_logger, log_with_context

if TYPE_CHECKING:
    from sqlspec.config import ADKConfig, DatabaseConfigProtocol
    from sqlspec.extensions.adk.memory._types import MemoryRecord

# Type variable for the concrete SQLSpec database config a store is bound to.
ConfigT = TypeVar("ConfigT", bound="DatabaseConfigProtocol[Any, Any, Any]")

logger = get_logger("sqlspec.extensions.adk.memory.store")

__all__ = ("BaseAsyncADKMemoryStore", "BaseSyncADKMemoryStore")

# Table names must be plain ASCII identifiers: a letter or underscore followed
# by letters, digits, or underscores.
# NOTE: ``$`` also matches just before a trailing "\n"; use ``fullmatch`` (or
# ``\Z``) when the whole string must conform to the pattern.
VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
# Captures the leading run of word characters of an owner_id_column DDL string
# (the column name). ``\w`` is Unicode-aware and also admits a leading digit.
COLUMN_NAME_PATTERN: Final = re.compile(r"^(\w+)")
# Upper bound on table-name length; 63 presumably mirrors PostgreSQL's default
# identifier limit (NAMEDATALEN - 1) — confirm for other dialects.
MAX_TABLE_NAME_LENGTH: Final = 63


class _ADKMemoryStoreConfig(TypedDict):
    """Normalized ADK memory store configuration.

    Produced by the stores' ``_get_store_config_from_extension`` from
    ``config.extension_config["adk"]`` with defaults already applied, so every
    key except ``owner_id_column`` is always present.
    """

    enable_memory: bool  # Whether the memory store is active.
    memory_table: str  # Name of the memory table (validated by the stores).
    use_fts: bool  # Whether full-text search should be used when supported.
    max_results: int  # Default cap on the number of search results.
    owner_id_column: NotRequired[str]  # Optional owner column DDL; absent when not configured.


def _parse_owner_id_column(owner_id_column_ddl: str) -> str:
    """Return the column name from an owner ID column DDL fragment.

    The name is the leading run of word characters after surrounding
    whitespace is stripped, e.g. ``"account_id INTEGER REFERENCES accounts(id)"``
    yields ``"account_id"``.

    Args:
        owner_id_column_ddl: Full column DDL string.

    Returns:
        Column name only (first word).

    Raises:
        ValueError: If the DDL does not start with a column name.
    """
    name_match = COLUMN_NAME_PATTERN.match(owner_id_column_ddl.strip())
    if name_match is not None:
        return name_match.group(1)

    msg = f"Invalid owner_id_column DDL: {owner_id_column_ddl!r}. Must start with column name."
    raise ValueError(msg)


def _validate_table_name(table_name: str) -> None:
    """Validate table name for SQL safety.

    Args:
        table_name: Table name to validate.

    Raises:
        ValueError: If table name is invalid.
    """
    if not table_name:
        msg = "Table name cannot be empty"
        raise ValueError(msg)

    if len(table_name) > MAX_TABLE_NAME_LENGTH:
        msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})"
        raise ValueError(msg)

    if not VALID_TABLE_NAME_PATTERN.match(table_name):
        msg = (
            f"Invalid table name: {table_name!r}. "
            "Must start with letter/underscore and contain only alphanumeric characters and underscores"
        )
        raise ValueError(msg)


class BaseAsyncADKMemoryStore(ABC, Generic[ConfigT]):
    """Base class for async SQLSpec-backed ADK memory stores.

    Implements storage operations for Google ADK memory entries using
    SQLSpec database adapters with async/await.

    This abstract base class provides common functionality for all
    database-specific memory store implementations including:
    - Connection management via SQLSpec configs
    - Table name validation
    - Memory entry CRUD operations
    - Text search with optional full-text search support

    Subclasses must implement dialect-specific SQL queries and will be created
    in each adapter directory (e.g., sqlspec/adapters/asyncpg/adk/store.py).

    Args:
        config: SQLSpec database configuration with extension_config["adk"] settings.

    Notes:
        Configuration is read from config.extension_config["adk"]:
        - memory_table: Memory table name (default: "adk_memory_entries")
        - memory_use_fts: Enable full-text search when supported (default: False)
        - memory_max_results: Max search results (default: 20)
        - owner_id_column: Optional owner FK column DDL (default: None)
        - enable_memory: Whether memory is enabled (default: True)
    """

    __slots__ = (
        "_config",
        "_enabled",
        "_max_results",
        "_memory_table",
        "_owner_id_column_ddl",
        "_owner_id_column_name",
        "_use_fts",
    )

    def __init__(self, config: ConfigT) -> None:
        """Initialize the ADK memory store.

        Args:
            config: SQLSpec database configuration.

        Raises:
            ValueError: If the memory table name or owner_id_column DDL is invalid.

        Notes:
            Reads configuration from config.extension_config["adk"]:
            - memory_table: Memory table name (default: "adk_memory_entries")
            - memory_use_fts: Enable full-text search when supported (default: False)
            - memory_max_results: Max search results (default: 20)
            - owner_id_column: Optional owner FK column DDL (default: None)
            - enable_memory: Whether memory is enabled (default: True)
        """
        self._config = config
        store_config = self._get_store_config_from_extension()
        # ``.get`` with defaults keeps subclasses that override the config
        # extraction (and may omit keys) working.
        self._enabled: bool = store_config.get("enable_memory", True)
        self._memory_table: str = str(store_config["memory_table"])
        self._use_fts: bool = bool(store_config.get("use_fts", False))
        self._max_results: int = store_config.get("max_results", 20)
        # Normalize an empty DDL string to None so ``owner_id_column_ddl`` and
        # ``owner_id_column_name`` always agree on whether an owner column exists.
        owner_id_column_ddl = store_config.get("owner_id_column") or None
        self._owner_id_column_ddl: str | None = owner_id_column_ddl
        self._owner_id_column_name: str | None = (
            _parse_owner_id_column(owner_id_column_ddl) if owner_id_column_ddl is not None else None
        )
        _validate_table_name(self._memory_table)

    def _get_store_config_from_extension(self) -> "_ADKMemoryStoreConfig":
        """Extract ADK memory configuration from config.extension_config.

        Missing (``None``) values fall back to their defaults.
        ``memory_max_results`` must be a positive ``int``; ``bool`` values and
        non-positive numbers fall back to 20 like any other invalid value. An
        empty ``owner_id_column`` is treated as not configured.

        Returns:
            Dict with enable_memory, memory_table, use_fts, max_results, and
            optionally owner_id_column.
        """
        extension_config = self._config.extension_config
        adk_config = cast("ADKConfig", extension_config.get("adk", {}))
        enable_memory = adk_config.get("enable_memory")
        memory_table = adk_config.get("memory_table")
        use_fts = adk_config.get("memory_use_fts")
        max_results = adk_config.get("memory_max_results")

        # bool subclasses int, so exclude it explicitly: ``True`` must not
        # silently become a result cap of 1. A cap below 1 is meaningless.
        if isinstance(max_results, int) and not isinstance(max_results, bool) and max_results > 0:
            resolved_max_results = max_results
        else:
            resolved_max_results = 20

        result: _ADKMemoryStoreConfig = {
            "enable_memory": bool(enable_memory) if enable_memory is not None else True,
            "memory_table": str(memory_table) if memory_table is not None else "adk_memory_entries",
            "use_fts": bool(use_fts) if use_fts is not None else False,
            "max_results": resolved_max_results,
        }
        owner_id = adk_config.get("owner_id_column")
        if owner_id:
            result["owner_id_column"] = owner_id
        return result

    @property
    def config(self) -> ConfigT:
        """Return the database configuration."""
        return self._config

    @property
    def memory_table(self) -> str:
        """Return the memory table name."""
        return self._memory_table

    @property
    def enabled(self) -> bool:
        """Return whether memory store is enabled."""
        return self._enabled

    @property
    def use_fts(self) -> bool:
        """Return whether full-text search is enabled."""
        return self._use_fts

    @property
    def max_results(self) -> int:
        """Return the max search results limit."""
        return self._max_results

    @property
    def owner_id_column_ddl(self) -> "str | None":
        """Return the full owner ID column DDL (or None if not configured)."""
        return self._owner_id_column_ddl

    @property
    def owner_id_column_name(self) -> "str | None":
        """Return the owner ID column name only (or None if not configured)."""
        return self._owner_id_column_name

    @abstractmethod
    async def create_tables(self) -> None:
        """Create the memory table and indexes if they don't exist.

        Should check self._enabled and skip table creation if False.
        """
        raise NotImplementedError

    async def ensure_tables(self) -> None:
        """Create tables when enabled and emit a standardized log entry."""
        if not self._enabled:
            self._log_memory_table_skipped()
            return
        await self.create_tables()
        self._log_memory_table_created()

    @abstractmethod
    async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
        """Bulk insert memory entries with deduplication.

        Uses UPSERT pattern to skip duplicates based on event_id.

        Args:
            entries: List of memory records to insert.
            owner_id: Optional owner ID value for owner_id_column (if configured).

        Returns:
            Number of entries actually inserted (excludes duplicates).

        Raises:
            RuntimeError: If memory store is disabled.
        """
        raise NotImplementedError

    def _log_memory_table_created(self) -> None:
        """Emit the structured ``adk.memory.table.ready`` debug event."""
        log_with_context(
            logger,
            logging.DEBUG,
            "adk.memory.table.ready",
            db_system=resolve_db_system(type(self).__name__),
            memory_table=self._memory_table,
        )

    def _log_memory_table_skipped(self) -> None:
        """Emit the structured ``adk.memory.table.skipped`` debug event."""
        log_with_context(
            logger,
            logging.DEBUG,
            "adk.memory.table.skipped",
            db_system=resolve_db_system(type(self).__name__),
            memory_table=self._memory_table,
            reason="disabled",
        )

    @abstractmethod
    async def search_entries(
        self, query: str, app_name: str, user_id: str, limit: "int | None" = None
    ) -> "list[MemoryRecord]":
        """Search memory entries by text query.

        Uses the configured search strategy (simple ILIKE or FTS).

        Args:
            query: Text query to search for.
            app_name: Application name to filter by.
            user_id: User ID to filter by.
            limit: Maximum number of results (defaults to max_results config).

        Returns:
            List of matching memory records ordered by relevance/timestamp.

        Raises:
            RuntimeError: If memory store is disabled.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete_entries_by_session(self, session_id: str) -> int:
        """Delete all memory entries for a specific session.

        Args:
            session_id: Session ID to delete entries for.

        Returns:
            Number of entries deleted.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete_entries_older_than(self, days: int) -> int:
        """Delete memory entries older than specified days.

        Used for TTL cleanup operations.

        Args:
            days: Number of days to retain entries.

        Returns:
            Number of entries deleted.
        """
        raise NotImplementedError

    @abstractmethod
    async def _get_create_memory_table_sql(self) -> "str | list[str]":
        """Get the CREATE TABLE SQL for the memory table.

        Returns:
            SQL statement(s) to create the memory table with indexes.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_drop_memory_table_sql(self) -> "list[str]":
        """Get the DROP TABLE SQL statements for this database dialect.

        Returns:
            List of SQL statements to drop the memory table and indexes.
        """
        raise NotImplementedError
class BaseSyncADKMemoryStore(ABC, Generic[ConfigT]):
    """Base class for sync SQLSpec-backed ADK memory stores.

    Implements storage operations for Google ADK memory entries using
    SQLSpec database adapters with synchronous execution.

    This abstract base class provides common functionality for sync
    database-specific memory store implementations including:
    - Connection management via SQLSpec configs
    - Table name validation
    - Memory entry CRUD operations
    - Text search with optional full-text search support

    Subclasses must implement dialect-specific SQL queries and will be created
    in each adapter directory (e.g., sqlspec/adapters/sqlite/adk/store.py).

    Args:
        config: SQLSpec database configuration with extension_config["adk"] settings.

    Notes:
        Configuration is read from config.extension_config["adk"]:
        - memory_table: Memory table name (default: "adk_memory_entries")
        - memory_use_fts: Enable full-text search when supported (default: False)
        - memory_max_results: Max search results (default: 20)
        - owner_id_column: Optional owner FK column DDL (default: None)
        - enable_memory: Whether memory is enabled (default: True)
    """

    __slots__ = (
        "_config",
        "_enabled",
        "_max_results",
        "_memory_table",
        "_owner_id_column_ddl",
        "_owner_id_column_name",
        "_use_fts",
    )

    def __init__(self, config: ConfigT) -> None:
        """Initialize the sync ADK memory store.

        Args:
            config: SQLSpec database configuration.

        Raises:
            ValueError: If the memory table name or owner_id_column DDL is invalid.

        Notes:
            Reads configuration from config.extension_config["adk"]:
            - memory_table: Memory table name (default: "adk_memory_entries")
            - memory_use_fts: Enable full-text search when supported (default: False)
            - memory_max_results: Max search results (default: 20)
            - owner_id_column: Optional owner FK column DDL (default: None)
            - enable_memory: Whether memory is enabled (default: True)
        """
        self._config = config
        store_config = self._get_store_config_from_extension()
        # ``.get`` with defaults keeps subclasses that override the config
        # extraction (and may omit keys) working.
        self._enabled: bool = store_config.get("enable_memory", True)
        self._memory_table: str = str(store_config["memory_table"])
        self._use_fts: bool = bool(store_config.get("use_fts", False))
        self._max_results: int = store_config.get("max_results", 20)
        # Normalize an empty DDL string to None so ``owner_id_column_ddl`` and
        # ``owner_id_column_name`` always agree on whether an owner column exists.
        owner_id_column_ddl = store_config.get("owner_id_column") or None
        self._owner_id_column_ddl: str | None = owner_id_column_ddl
        self._owner_id_column_name: str | None = (
            _parse_owner_id_column(owner_id_column_ddl) if owner_id_column_ddl is not None else None
        )
        _validate_table_name(self._memory_table)

    def _get_store_config_from_extension(self) -> "_ADKMemoryStoreConfig":
        """Extract ADK memory configuration from config.extension_config.

        Missing (``None``) values fall back to their defaults.
        ``memory_max_results`` must be a positive ``int``; ``bool`` values and
        non-positive numbers fall back to 20 like any other invalid value. An
        empty ``owner_id_column`` is treated as not configured.

        Returns:
            Dict with enable_memory, memory_table, use_fts, max_results, and
            optionally owner_id_column.
        """
        extension_config = self._config.extension_config
        adk_config = cast("ADKConfig", extension_config.get("adk", {}))
        enable_memory = adk_config.get("enable_memory")
        memory_table = adk_config.get("memory_table")
        use_fts = adk_config.get("memory_use_fts")
        max_results = adk_config.get("memory_max_results")

        # bool subclasses int, so exclude it explicitly: ``True`` must not
        # silently become a result cap of 1. A cap below 1 is meaningless.
        if isinstance(max_results, int) and not isinstance(max_results, bool) and max_results > 0:
            resolved_max_results = max_results
        else:
            resolved_max_results = 20

        result: _ADKMemoryStoreConfig = {
            "enable_memory": bool(enable_memory) if enable_memory is not None else True,
            "memory_table": str(memory_table) if memory_table is not None else "adk_memory_entries",
            "use_fts": bool(use_fts) if use_fts is not None else False,
            "max_results": resolved_max_results,
        }
        owner_id = adk_config.get("owner_id_column")
        if owner_id:
            result["owner_id_column"] = owner_id
        return result

    @property
    def config(self) -> ConfigT:
        """Return the database configuration."""
        return self._config

    @property
    def memory_table(self) -> str:
        """Return the memory table name."""
        return self._memory_table

    @property
    def enabled(self) -> bool:
        """Return whether memory store is enabled."""
        return self._enabled

    @property
    def use_fts(self) -> bool:
        """Return whether full-text search is enabled."""
        return self._use_fts

    @property
    def max_results(self) -> int:
        """Return the max search results limit."""
        return self._max_results

    @property
    def owner_id_column_ddl(self) -> "str | None":
        """Return the full owner ID column DDL (or None if not configured)."""
        return self._owner_id_column_ddl

    @property
    def owner_id_column_name(self) -> "str | None":
        """Return the owner ID column name only (or None if not configured)."""
        return self._owner_id_column_name

    @abstractmethod
    def create_tables(self) -> None:
        """Create the memory table and indexes if they don't exist.

        Should check self._enabled and skip table creation if False.
        """
        raise NotImplementedError

    def ensure_tables(self) -> None:
        """Create tables when enabled and emit a standardized log entry."""
        if not self._enabled:
            self._log_memory_table_skipped()
            return
        self.create_tables()
        self._log_memory_table_created()

    @abstractmethod
    def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int:
        """Bulk insert memory entries with deduplication.

        Uses UPSERT pattern to skip duplicates based on event_id.

        Args:
            entries: List of memory records to insert.
            owner_id: Optional owner ID value for owner_id_column (if configured).

        Returns:
            Number of entries actually inserted (excludes duplicates).

        Raises:
            RuntimeError: If memory store is disabled.
        """
        raise NotImplementedError

    def _log_memory_table_created(self) -> None:
        """Emit the structured ``adk.memory.table.ready`` debug event.

        Uses the same event name and fields as the async store so sync and
        async backends produce identical log records.
        """
        log_with_context(
            logger,
            logging.DEBUG,
            "adk.memory.table.ready",
            db_system=resolve_db_system(type(self).__name__),
            memory_table=self._memory_table,
        )

    def _log_memory_table_skipped(self) -> None:
        """Emit the structured ``adk.memory.table.skipped`` debug event.

        Uses the same event name and fields as the async store so sync and
        async backends produce identical log records.
        """
        log_with_context(
            logger,
            logging.DEBUG,
            "adk.memory.table.skipped",
            db_system=resolve_db_system(type(self).__name__),
            memory_table=self._memory_table,
            reason="disabled",
        )

    @abstractmethod
    def search_entries(
        self, query: str, app_name: str, user_id: str, limit: "int | None" = None
    ) -> "list[MemoryRecord]":
        """Search memory entries by text query.

        Uses the configured search strategy (simple ILIKE or FTS).

        Args:
            query: Text query to search for.
            app_name: Application name to filter by.
            user_id: User ID to filter by.
            limit: Maximum number of results (defaults to max_results config).

        Returns:
            List of matching memory records ordered by relevance/timestamp.

        Raises:
            RuntimeError: If memory store is disabled.
        """
        raise NotImplementedError

    @abstractmethod
    def delete_entries_by_session(self, session_id: str) -> int:
        """Delete all memory entries for a specific session.

        Args:
            session_id: Session ID to delete entries for.

        Returns:
            Number of entries deleted.
        """
        raise NotImplementedError

    @abstractmethod
    def delete_entries_older_than(self, days: int) -> int:
        """Delete memory entries older than specified days.

        Used for TTL cleanup operations.

        Args:
            days: Number of days to retain entries.

        Returns:
            Number of entries deleted.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_create_memory_table_sql(self) -> "str | list[str]":
        """Get the CREATE TABLE SQL for the memory table.

        Returns:
            SQL statement(s) to create the memory table with indexes.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_drop_memory_table_sql(self) -> "list[str]":
        """Get the DROP TABLE SQL statements for this database dialect.

        Returns:
            List of SQL statements to drop the memory table and indexes.
        """
        raise NotImplementedError