Source code for sqlspec.core.cache

"""Caching system for SQL statement processing.

This module provides a caching system with LRU eviction and TTL support for
SQL statement processing, parameter processing, and expression caching.

Components:
- CacheKey: Immutable cache key
- UnifiedCache: Cache implementation with LRU eviction and TTL
- CacheStats: Cache statistics tracking
- MultiLevelCache: Namespaced cache for statements, expressions, and parameters
- CachedStatement: Immutable compiled statement result
- ParametersView / FiltersView: Zero-copy read-only views of parameters and filters
"""

import threading
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Final, Optional

from mypy_extensions import mypyc_attr
from typing_extensions import TypeVar

from sqlspec.core.pipeline import get_statement_pipeline_metrics, reset_statement_pipeline_cache
from sqlspec.utils.logging import get_logger

if TYPE_CHECKING:
    from collections.abc import Iterator

    import sqlglot.expressions as exp


__all__ = (
    "CacheKey",
    "CacheStats",
    "CachedStatement",
    "FiltersView",
    "MultiLevelCache",
    "ParametersView",
    "UnifiedCache",
    "canonicalize_filters",
    "create_cache_key",
    "get_cache",
    "get_cache_config",
    "get_default_cache",
    "get_pipeline_metrics",
    "reset_pipeline_registry",
)

T = TypeVar("T")
CacheValueT = TypeVar("CacheValueT")


DEFAULT_MAX_SIZE: Final = 10000
DEFAULT_TTL_SECONDS: Final = 3600
CACHE_STATS_UPDATE_INTERVAL: Final = 100


CACHE_KEY_SLOTS: Final = ("_hash", "_key_data")
CACHE_NODE_SLOTS: Final = ("key", "value", "prev", "next", "timestamp", "access_count")
UNIFIED_CACHE_SLOTS: Final = ("_cache", "_lock", "_max_size", "_ttl", "_head", "_tail", "_stats")
CACHE_STATS_SLOTS: Final = ("hits", "misses", "evictions", "total_operations", "memory_usage")


@mypyc_attr(allow_interpreted_subclasses=False)
class CacheKey:
    """Immutable cache key.

    Args:
        key_data: Tuple of hashable values that uniquely identify the cached item
    """

    __slots__ = CACHE_KEY_SLOTS

    def __init__(self, key_data: tuple[Any, ...]) -> None:
        """Initialize cache key.

        Args:
            key_data: Tuple of hashable values for the cache key
        """
        self._key_data = key_data
        self._hash = hash(key_data)

    @property
    def key_data(self) -> tuple[Any, ...]:
        """Get the key data tuple."""
        return self._key_data

    def __hash__(self) -> int:
        """Return cached hash value."""
        return self._hash

    def __eq__(self, other: object) -> bool:
        """Equality comparison."""
        if type(other) is not CacheKey:
            return False
        other_key = other
        if self._hash != other_key._hash:
            return False
        return self._key_data == other_key._key_data

    def __repr__(self) -> str:
        """String representation of the cache key."""
        return f"CacheKey({self._key_data!r})"
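
# Usage sketch (illustrative, not part of the original module): the key
# hashes its tuple once at construction, so equal tuples produce
# interchangeable keys that are cheap to compare on every lookup.
#
#     key = CacheKey(("statement", "SELECT * FROM users", "postgres"))
#     same = CacheKey(("statement", "SELECT * FROM users", "postgres"))
#     assert key == same and hash(key) == hash(same)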
@mypyc_attr(allow_interpreted_subclasses=False)
class CacheStats:
    """Cache statistics tracking.

    Tracks cache metrics including hit rates, evictions, and memory usage.
    """

    __slots__ = CACHE_STATS_SLOTS

    def __init__(self) -> None:
        """Initialize cache statistics."""
        self.hits = 0
        self.misses = 0
        self.evictions = 0
        self.total_operations = 0
        self.memory_usage = 0

    @property
    def hit_rate(self) -> float:
        """Calculate cache hit rate as a percentage."""
        total = self.hits + self.misses
        return (self.hits / total * 100) if total > 0 else 0.0

    @property
    def miss_rate(self) -> float:
        """Calculate cache miss rate as a percentage."""
        return 100.0 - self.hit_rate

    def record_hit(self) -> None:
        """Record a cache hit."""
        self.hits += 1
        self.total_operations += 1

    def record_miss(self) -> None:
        """Record a cache miss."""
        self.misses += 1
        self.total_operations += 1

    def record_eviction(self) -> None:
        """Record a cache eviction."""
        self.evictions += 1

    def reset(self) -> None:
        """Reset all statistics."""
        self.hits = 0
        self.misses = 0
        self.evictions = 0
        self.total_operations = 0
        self.memory_usage = 0

    def __repr__(self) -> str:
        """String representation of cache statistics."""
        return (
            f"CacheStats(hit_rate={self.hit_rate:.1f}%, "
            f"hits={self.hits}, misses={self.misses}, "
            f"evictions={self.evictions}, ops={self.total_operations})"
        )
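
# Usage sketch (illustrative, not part of the original module): hit_rate is
# derived from the recorded counters, so two hits and one miss report
# roughly 66.7%.
#
#     stats = CacheStats()
#     stats.record_hit()
#     stats.record_hit()
#     stats.record_miss()
#     assert round(stats.hit_rate, 1) == 66.7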
@mypyc_attr(allow_interpreted_subclasses=False)
class CacheNode:
    """Internal cache node for LRU linked list implementation."""

    __slots__ = CACHE_NODE_SLOTS

    def __init__(self, key: CacheKey, value: Any) -> None:
        """Initialize cache node.

        Args:
            key: Cache key for this node
            value: Cached value
        """
        self.key = key
        self.value = value
        self.prev: CacheNode | None = None
        self.next: CacheNode | None = None
        self.timestamp = time.time()
        self.access_count = 1
@mypyc_attr(allow_interpreted_subclasses=False)
class UnifiedCache:
    """Cache with LRU eviction and TTL support.

    Args:
        max_size: Maximum number of items to cache (LRU eviction when exceeded)
        ttl_seconds: Time-to-live in seconds (None for no expiration)
    """

    __slots__ = UNIFIED_CACHE_SLOTS

    def __init__(self, max_size: int = DEFAULT_MAX_SIZE, ttl_seconds: int | None = DEFAULT_TTL_SECONDS) -> None:
        """Initialize unified cache.

        Args:
            max_size: Maximum number of cache entries
            ttl_seconds: Time-to-live in seconds (None for no expiration)
        """
        self._cache: dict[CacheKey, CacheNode] = {}
        self._lock = threading.RLock()
        self._max_size = max_size
        self._ttl = ttl_seconds
        self._stats = CacheStats()
        self._head = CacheNode(CacheKey(()), None)
        self._tail = CacheNode(CacheKey(()), None)
        self._head.next = self._tail
        self._tail.prev = self._head

    def get(self, key: CacheKey) -> Any | None:
        """Get value from cache.

        Args:
            key: Cache key to lookup

        Returns:
            Cached value or None if not found or expired
        """
        with self._lock:
            node = self._cache.get(key)
            if node is None:
                self._stats.record_miss()
                return None
            ttl = self._ttl
            if ttl is not None:
                current_time = time.time()
                if (current_time - node.timestamp) > ttl:
                    self._remove_node(node)
                    del self._cache[key]
                    self._stats.record_miss()
                    self._stats.record_eviction()
                    return None
            self._move_to_head(node)
            node.access_count += 1
            self._stats.record_hit()
            return node.value

    def put(self, key: CacheKey, value: Any) -> None:
        """Put value in cache.

        Args:
            key: Cache key
            value: Value to cache
        """
        with self._lock:
            existing_node = self._cache.get(key)
            if existing_node is not None:
                existing_node.value = value
                existing_node.timestamp = time.time()
                existing_node.access_count += 1
                self._move_to_head(existing_node)
                return
            new_node = CacheNode(key, value)
            self._cache[key] = new_node
            self._add_to_head(new_node)
            if len(self._cache) > self._max_size:
                tail_node = self._tail.prev
                if tail_node is not None and tail_node is not self._head:
                    self._remove_node(tail_node)
                    del self._cache[tail_node.key]
                    self._stats.record_eviction()

    def delete(self, key: CacheKey) -> bool:
        """Delete entry from cache.

        Args:
            key: Cache key to delete

        Returns:
            True if key was found and deleted, False otherwise
        """
        with self._lock:
            node: CacheNode | None = self._cache.get(key)
            if node is None:
                return False
            self._remove_node(node)
            del self._cache[key]
            return True

    def clear(self) -> None:
        """Clear all cache entries."""
        with self._lock:
            self._cache.clear()
            self._head.next = self._tail
            self._tail.prev = self._head
            self._stats.reset()

    def size(self) -> int:
        """Get current cache size."""
        return len(self._cache)

    def is_empty(self) -> bool:
        """Check if cache is empty."""
        return not self._cache

    def get_stats(self) -> CacheStats:
        """Get cache statistics."""
        return self._stats

    def _add_to_head(self, node: CacheNode) -> None:
        """Add node to head of list."""
        node.prev = self._head
        head_next: CacheNode | None = self._head.next
        node.next = head_next
        if head_next is not None:
            head_next.prev = node
        self._head.next = node

    def _remove_node(self, node: CacheNode) -> None:
        """Remove node from linked list."""
        node_prev: CacheNode | None = node.prev
        node_next: CacheNode | None = node.next
        if node_prev is not None:
            node_prev.next = node_next
        if node_next is not None:
            node_next.prev = node_prev

    def _move_to_head(self, node: CacheNode) -> None:
        """Move node to head of list."""
        self._remove_node(node)
        self._add_to_head(node)

    def __len__(self) -> int:
        """Get current cache size."""
        return len(self._cache)

    def __contains__(self, key: CacheKey) -> bool:
        """Check if key exists in cache."""
        with self._lock:
            node = self._cache.get(key)
            if node is None:
                return False
            ttl = self._ttl
            return not (ttl is not None and time.time() - node.timestamp > ttl)
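
# Usage sketch (illustrative, not part of the original module): entries are
# dropped either in LRU order once max_size is exceeded or lazily on access
# once the TTL has elapsed. With max_size=2, inserting a third key evicts
# the least recently used entry.
#
#     cache = UnifiedCache(max_size=2, ttl_seconds=None)
#     cache.put(CacheKey(("a",)), 1)
#     cache.put(CacheKey(("b",)), 2)
#     cache.get(CacheKey(("a",)))     # "a" is now most recently used
#     cache.put(CacheKey(("c",)), 3)  # evicts "b"
#     assert CacheKey(("b",)) not in cache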
_default_cache: UnifiedCache | None = None
_cache_lock = threading.Lock()


def get_default_cache() -> UnifiedCache:
    """Get the default unified cache instance.

    Returns:
        Singleton default cache instance
    """
    global _default_cache
    if _default_cache is None:
        with _cache_lock:
            if _default_cache is None:
                _default_cache = UnifiedCache()
    return _default_cache


def clear_all_caches() -> None:
    """Clear all cache instances."""
    if _default_cache is not None:
        _default_cache.clear()
    cache = get_cache()
    cache.clear()


def get_cache_statistics() -> dict[str, CacheStats]:
    """Get statistics from all cache instances.

    Returns:
        Dictionary mapping cache type to statistics
    """
    stats = {}
    if _default_cache is not None:
        stats["default"] = _default_cache.get_stats()
    cache = get_cache()
    stats["multi_level"] = cache.get_stats()
    return stats


_global_cache_config: "CacheConfig | None" = None
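
# Usage sketch (illustrative, not part of the original module): the
# double-checked locking in get_default_cache means callers share one
# process-wide cache without paying for the lock on every call.
#
#     cache = get_default_cache()
#     assert cache is get_default_cache()  # same singleton instance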
@mypyc_attr(allow_interpreted_subclasses=False)
class CacheConfig:
    """Global cache configuration for SQLSpec."""

    def __init__(
        self,
        *,
        compiled_cache_enabled: bool = True,
        sql_cache_enabled: bool = True,
        fragment_cache_enabled: bool = True,
        optimized_cache_enabled: bool = True,
        sql_cache_size: int = 1000,
        fragment_cache_size: int = 5000,
        optimized_cache_size: int = 2000,
    ) -> None:
        """Initialize cache configuration.

        Args:
            compiled_cache_enabled: Enable compiled SQL caching
            sql_cache_enabled: Enable SQL statement caching
            fragment_cache_enabled: Enable AST fragment caching
            optimized_cache_enabled: Enable optimized expression caching
            sql_cache_size: Maximum SQL cache entries
            fragment_cache_size: Maximum fragment cache entries
            optimized_cache_size: Maximum optimized cache entries
        """
        self.compiled_cache_enabled = compiled_cache_enabled
        self.sql_cache_enabled = sql_cache_enabled
        self.fragment_cache_enabled = fragment_cache_enabled
        self.optimized_cache_enabled = optimized_cache_enabled
        self.sql_cache_size = sql_cache_size
        self.fragment_cache_size = fragment_cache_size
        self.optimized_cache_size = optimized_cache_size
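
# Usage sketch (illustrative, not part of the original module): configuration
# is applied globally via update_cache_config (defined below), which also
# clears all existing caches.
#
#     update_cache_config(CacheConfig(sql_cache_size=500, fragment_cache_enabled=False))
#     assert get_cache_config().sql_cache_size == 500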
def get_cache_config() -> CacheConfig:
    """Get the global cache configuration.

    Returns:
        Current global cache configuration instance
    """
    global _global_cache_config
    if _global_cache_config is None:
        _global_cache_config = CacheConfig()
    return _global_cache_config


def update_cache_config(config: CacheConfig) -> None:
    """Update the global cache configuration.

    Clears all existing caches when configuration changes.

    Args:
        config: New cache configuration to apply globally
    """
    global _global_cache_config
    _global_cache_config = config
    unified_cache = get_default_cache()
    unified_cache.clear()
    cache = get_cache()
    cache.clear()
    logger = get_logger("sqlspec.cache")
    logger.info(
        "Cache configuration updated - all caches cleared",
        extra={
            "compiled_cache_enabled": config.compiled_cache_enabled,
            "sql_cache_enabled": config.sql_cache_enabled,
            "fragment_cache_enabled": config.fragment_cache_enabled,
            "optimized_cache_enabled": config.optimized_cache_enabled,
        },
    )


def get_cache_stats() -> dict[str, CacheStats]:
    """Get cache statistics from all caches.

    Returns:
        Dictionary of cache statistics
    """
    return get_cache_statistics()


def reset_cache_stats() -> None:
    """Reset all cache statistics by clearing all caches."""
    clear_all_caches()


def log_cache_stats() -> None:
    """Log cache statistics."""
    logger = get_logger("sqlspec.cache")
    stats = get_cache_stats()
    logger.info("Cache Statistics: %s", stats)


@mypyc_attr(allow_interpreted_subclasses=False)
class ParametersView:
    """Read-only view of parameters without copying.

    Provides read-only access to parameters without making copies,
    enabling zero-copy parameter access patterns.
    """

    __slots__ = ("_named_ref", "_positional_ref")

    def __init__(self, positional: list[Any], named: dict[str, Any]) -> None:
        """Initialize parameters view.

        Args:
            positional: List of positional parameters (will be referenced, not copied)
            named: Dictionary of named parameters (will be referenced, not copied)
        """
        self._positional_ref = positional
        self._named_ref = named

    def get_positional(self, index: int) -> Any:
        """Get positional parameter by index.

        Args:
            index: Parameter index

        Returns:
            Parameter value
        """
        return self._positional_ref[index]

    def get_named(self, key: str) -> Any:
        """Get named parameter by key.

        Args:
            key: Parameter name

        Returns:
            Parameter value
        """
        return self._named_ref[key]

    def has_named(self, key: str) -> bool:
        """Check if named parameter exists.

        Args:
            key: Parameter name

        Returns:
            True if parameter exists
        """
        return key in self._named_ref

    @property
    def positional_count(self) -> int:
        """Number of positional parameters."""
        return len(self._positional_ref)

    @property
    def named_count(self) -> int:
        """Number of named parameters."""
        return len(self._named_ref)


@mypyc_attr(allow_interpreted_subclasses=False)
@dataclass(frozen=True)
class CachedStatement:
    """Immutable cached statement result.

    This class stores compiled SQL and parameters in an immutable format
    that can be safely shared between different parts of the system without
    risk of mutation. Tuple parameters ensure no copying is needed.
    """

    compiled_sql: str
    parameters: tuple[Any, ...] | dict[str, Any] | None  # None allowed for static script compilation
    expression: Optional["exp.Expression"]

    def get_parameters_view(self) -> "ParametersView":
        """Get read-only parameter view.

        Returns:
            View object that provides read-only access to parameters
        """
        if self.parameters is None:
            return ParametersView([], {})
        if isinstance(self.parameters, dict):
            # Editor's note: dict parameters are named; passing a dict through
            # list() would keep only its keys, so expose it as the named mapping.
            return ParametersView([], self.parameters)
        return ParametersView(list(self.parameters), {})


def create_cache_key(level: str, key: str, dialect: str | None = None) -> str:
    """Create optimized cache key using string concatenation.

    Args:
        level: Cache level (statement, expression, parameter)
        key: Base cache key
        dialect: SQL dialect (optional)

    Returns:
        Optimized cache key string
    """
    return f"{level}:{dialect or 'default'}:{key}"
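
# Usage sketch (illustrative, not part of the original module):
# create_cache_key produces flat "level:dialect:key" strings, falling back
# to "default" when no dialect is given.
#
#     assert create_cache_key("statement", "abc123", "postgres") == "statement:postgres:abc123"
#     assert create_cache_key("expression", "abc123") == "expression:default:abc123"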
@mypyc_attr(allow_interpreted_subclasses=False)
class MultiLevelCache:
    """Single cache with namespace isolation - no connection pool complexity."""

    __slots__ = ("_cache",)

    def __init__(self, max_size: int = DEFAULT_MAX_SIZE, ttl_seconds: int | None = DEFAULT_TTL_SECONDS) -> None:
        """Initialize multi-level cache.

        Args:
            max_size: Maximum number of cache entries
            ttl_seconds: Time-to-live in seconds (None for no expiration)
        """
        self._cache = UnifiedCache(max_size, ttl_seconds)

    def get(self, level: str, key: str, dialect: str | None = None) -> Any | None:
        """Get value from cache with level and dialect namespace.

        Args:
            level: Cache level (e.g., "statement", "expression", "parameter")
            key: Cache key
            dialect: SQL dialect (optional)

        Returns:
            Cached value or None if not found
        """
        full_key = create_cache_key(level, key, dialect)
        cache_key = CacheKey((full_key,))
        return self._cache.get(cache_key)

    def put(self, level: str, key: str, value: Any, dialect: str | None = None) -> None:
        """Put value in cache with level and dialect namespace.

        Args:
            level: Cache level (e.g., "statement", "expression", "parameter")
            key: Cache key
            value: Value to cache
            dialect: SQL dialect (optional)
        """
        full_key = create_cache_key(level, key, dialect)
        cache_key = CacheKey((full_key,))
        self._cache.put(cache_key, value)

    def delete(self, level: str, key: str, dialect: str | None = None) -> bool:
        """Delete entry from cache.

        Args:
            level: Cache level
            key: Cache key to delete
            dialect: SQL dialect (optional)

        Returns:
            True if key was found and deleted, False otherwise
        """
        full_key = create_cache_key(level, key, dialect)
        cache_key = CacheKey((full_key,))
        return self._cache.delete(cache_key)

    def clear(self) -> None:
        """Clear all cache entries."""
        self._cache.clear()

    def get_stats(self) -> CacheStats:
        """Get cache statistics."""
        return self._cache.get_stats()
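
# Usage sketch (illustrative, not part of the original module): levels and
# dialects only namespace the underlying UnifiedCache, so the same base key
# can hold independent values per (level, dialect) pair.
#
#     cache = MultiLevelCache()
#     cache.put("statement", "q1", "SELECT 1", dialect="postgres")
#     cache.put("statement", "q1", "SELECT 1 FROM dual", dialect="oracle")
#     assert cache.get("statement", "q1", dialect="postgres") == "SELECT 1"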
_multi_level_cache: MultiLevelCache | None = None


def get_cache() -> MultiLevelCache:
    """Get the multi-level cache instance.

    Returns:
        Singleton multi-level cache instance
    """
    global _multi_level_cache
    if _multi_level_cache is None:
        with _cache_lock:
            if _multi_level_cache is None:
                _multi_level_cache = MultiLevelCache()
    return _multi_level_cache


@dataclass(frozen=True)
class Filter:
    """Immutable filter that can be safely shared."""

    field_name: str
    operation: str
    value: Any

    def __post_init__(self) -> None:
        """Validate filter parameters."""
        if not self.field_name:
            msg = "Field name cannot be empty"
            raise ValueError(msg)
        if not self.operation:
            msg = "Operation cannot be empty"
            raise ValueError(msg)


def canonicalize_filters(filters: "list[Filter]") -> "tuple[Filter, ...]":
    """Create canonical representation of filters for cache keys.

    Args:
        filters: List of filters to canonicalize

    Returns:
        Tuple of unique filters sorted by field_name, operation, then value
    """
    if not filters:
        return ()
    # Deduplicate and sort for canonical representation
    unique_filters = set(filters)
    return tuple(sorted(unique_filters, key=lambda f: (f.field_name, f.operation, str(f.value))))


@mypyc_attr(allow_interpreted_subclasses=False)
class FiltersView:
    """Read-only view of filters without copying.

    Provides zero-copy access to filters with methods for querying,
    iteration, and canonical representation generation.
    """

    __slots__ = ("_filters_ref",)

    def __init__(self, filters: "list[Any]") -> None:
        """Initialize filters view.

        Args:
            filters: List of filters (will be referenced, not copied)
        """
        self._filters_ref = filters

    def __len__(self) -> int:
        """Get number of filters."""
        return len(self._filters_ref)

    def __iter__(self) -> "Iterator[Any]":
        """Iterate over filters."""
        return iter(self._filters_ref)

    def get_by_field(self, field_name: str) -> "list[Any]":
        """Get all filters for a specific field.

        Args:
            field_name: Field name to filter by

        Returns:
            List of filters matching the field name
        """
        return [f for f in self._filters_ref if hasattr(f, "field_name") and f.field_name == field_name]

    def has_field(self, field_name: str) -> bool:
        """Check if any filter exists for a field.

        Args:
            field_name: Field name to check

        Returns:
            True if field has filters
        """
        return any(hasattr(f, "field_name") and f.field_name == field_name for f in self._filters_ref)

    def to_canonical(self) -> "tuple[Any, ...]":
        """Create canonical representation for cache keys.

        Returns:
            Canonical tuple representation of filters
        """
        # Convert to Filter objects if needed, then canonicalize
        filter_objects = []
        for f in self._filters_ref:
            if isinstance(f, Filter):
                filter_objects.append(f)
            elif hasattr(f, "field_name") and hasattr(f, "operation") and hasattr(f, "value"):
                filter_objects.append(Filter(f.field_name, f.operation, f.value))
        return canonicalize_filters(filter_objects)


def get_pipeline_metrics() -> "list[dict[str, Any]]":
    """Return metrics for the shared statement pipeline cache when enabled."""
    return get_statement_pipeline_metrics()


def reset_pipeline_registry() -> None:
    """Clear shared statement pipeline caches and metrics."""
    reset_statement_pipeline_cache()
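
# Usage sketch (illustrative, not part of the original module):
# canonicalization deduplicates and orders filters, so permutations of the
# same filter set map to one cache key.
#
#     a = Filter("age", "gt", 21)
#     b = Filter("name", "eq", "ada")
#     assert canonicalize_filters([b, a, a]) == canonicalize_filters([a, b])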