Source code for sqlspec.observability._config

"""Configuration objects for the observability suite."""

from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:  # pragma: no cover - import cycle guard
    from sqlspec.config import LifecycleConfig
    from sqlspec.observability._formatters._base import CloudLogFormatter
    from sqlspec.observability._observer import StatementEvent
    from sqlspec.observability._sampling import SamplingConfig


# Signature for statement observers: called with a ``StatementEvent`` and
# annotated to return ``None``.
StatementObserver = Callable[["StatementEvent"], None]
# Signature for lifecycle hooks: called with a ``dict[str, Any]`` payload.
# Lifecycle configs map an event name to a list of these (see ``_normalize_lifecycle``).
LifecycleHook = Callable[[dict[str, Any]], None]


[docs] class RedactionConfig: """Controls SQL and parameter redaction before observers run.""" __slots__ = ("mask_literals", "mask_parameters", "parameter_allow_list")
[docs] def __init__( self, *, mask_parameters: bool | None = None, mask_literals: bool | None = None, parameter_allow_list: tuple[str, ...] | Iterable[str] | None = None, ) -> None: self.mask_parameters = mask_parameters self.mask_literals = mask_literals self.parameter_allow_list = tuple(parameter_allow_list) if parameter_allow_list is not None else None
def __hash__(self) -> int: # pragma: no cover - explicit to mirror dataclass behavior msg = "RedactionConfig objects are mutable and unhashable" raise TypeError(msg)
[docs] def copy(self) -> "RedactionConfig": """Return a copy to avoid sharing mutable state.""" allow_list = tuple(self.parameter_allow_list) if self.parameter_allow_list else None return RedactionConfig( mask_parameters=self.mask_parameters, mask_literals=self.mask_literals, parameter_allow_list=allow_list )
def __repr__(self) -> str: return f"RedactionConfig(mask_parameters={self.mask_parameters!r}, mask_literals={self.mask_literals!r}, parameter_allow_list={self.parameter_allow_list!r})" def __eq__(self, other: object) -> bool: if not isinstance(other, RedactionConfig): return NotImplemented return ( self.mask_parameters == other.mask_parameters and self.mask_literals == other.mask_literals and self.parameter_allow_list == other.parameter_allow_list )
[docs] class TelemetryConfig: """Span emission and tracer provider settings.""" __slots__ = ("enable_spans", "provider_factory", "resource_attributes")
[docs] def __init__( self, *, enable_spans: bool = False, provider_factory: Callable[[], Any] | None = None, resource_attributes: dict[str, Any] | None = None, ) -> None: self.enable_spans = enable_spans self.provider_factory = provider_factory self.resource_attributes = dict(resource_attributes) if resource_attributes else None
def __hash__(self) -> int: # pragma: no cover - explicit to mirror dataclass behavior msg = "TelemetryConfig objects are mutable and unhashable" raise TypeError(msg)
[docs] def copy(self) -> "TelemetryConfig": """Return a shallow copy preserving optional dictionaries.""" attributes = dict(self.resource_attributes) if self.resource_attributes else None return TelemetryConfig( enable_spans=self.enable_spans, provider_factory=self.provider_factory, resource_attributes=attributes )
def __repr__(self) -> str: return f"TelemetryConfig(enable_spans={self.enable_spans!r}, provider_factory={self.provider_factory!r}, resource_attributes={self.resource_attributes!r})" def __eq__(self, other: object) -> bool: if not isinstance(other, TelemetryConfig): return NotImplemented return ( self.enable_spans == other.enable_spans and self.provider_factory == other.provider_factory and self.resource_attributes == other.resource_attributes )
class LoggingConfig:
    """Controls log output format and verbosity.

    Attributes:
        include_sql_hash: Flag for including a hash of the SQL text; ``True`` by default.
        sql_truncation_length: Truncation limit for SQL text (unit presumably
            characters; confirm with the log formatter); ``2000`` by default.
        parameter_truncation_count: Truncation limit for emitted parameters
            (presumably a count of parameters; confirm); ``100`` by default.
        include_trace_context: Flag for attaching trace context; ``True`` by default.
        include_driver_name: Flag for attaching the driver name; ``False`` by default.
    """

    __slots__ = (
        "include_driver_name",
        "include_sql_hash",
        "include_trace_context",
        "parameter_truncation_count",
        "sql_truncation_length",
    )

    # Constructor / repr ordering of the fields (``__slots__`` above is alphabetical).
    _FIELD_ORDER = (
        "include_sql_hash",
        "sql_truncation_length",
        "parameter_truncation_count",
        "include_trace_context",
        "include_driver_name",
    )

    def __init__(
        self,
        *,
        include_sql_hash: bool = True,
        sql_truncation_length: int = 2000,
        parameter_truncation_count: int = 100,
        include_trace_context: bool = True,
        include_driver_name: bool = False,
    ) -> None:
        self.include_driver_name = include_driver_name
        self.include_sql_hash = include_sql_hash
        self.include_trace_context = include_trace_context
        self.parameter_truncation_count = parameter_truncation_count
        self.sql_truncation_length = sql_truncation_length

    def __hash__(self) -> int:  # pragma: no cover - explicit to mirror dataclass behavior
        msg = "LoggingConfig objects are mutable and unhashable"
        raise TypeError(msg)

    def copy(self) -> "LoggingConfig":
        """Return a new ``LoggingConfig`` carrying the same field values."""
        values = {name: getattr(self, name) for name in LoggingConfig._FIELD_ORDER}
        return LoggingConfig(**values)

    def __repr__(self) -> str:
        rendered = ", ".join(f"{name}={getattr(self, name)!r}" for name in LoggingConfig._FIELD_ORDER)
        return f"LoggingConfig({rendered})"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, LoggingConfig):
            return all(getattr(self, name) == getattr(other, name) for name in LoggingConfig._FIELD_ORDER)
        return NotImplemented
class ObservabilityConfig:
    """Aggregates lifecycle hooks, observers, and telemetry toggles.

    Every field defaults to ``None``, meaning "not configured at this level";
    :meth:`merge` relies on that to layer an adapter-level config over a
    registry-level one.
    """

    __slots__ = (
        "cloud_formatter",
        "lifecycle",
        "logging",
        "print_sql",
        "redaction",
        "sampling",
        "statement_observers",
        "telemetry",
    )

    def __init__(
        self,
        *,
        lifecycle: "LifecycleConfig | None" = None,
        print_sql: bool | None = None,
        statement_observers: tuple[StatementObserver, ...] | Iterable[StatementObserver] | None = None,
        telemetry: "TelemetryConfig | None" = None,
        redaction: "RedactionConfig | None" = None,
        logging: "LoggingConfig | None" = None,
        sampling: "SamplingConfig | None" = None,
        cloud_formatter: "CloudLogFormatter | None" = None,
    ) -> None:
        self.lifecycle = lifecycle
        self.print_sql = print_sql
        # Materialize into a tuple so the stored value is immutable and a one-shot
        # iterable supplied by the caller is not consumed later.
        self.statement_observers = tuple(statement_observers) if statement_observers is not None else None
        self.telemetry = telemetry
        self.redaction = redaction
        self.logging = logging
        self.sampling = sampling
        self.cloud_formatter = cloud_formatter

    def __hash__(self) -> int:  # pragma: no cover - explicit to mirror dataclass behavior
        msg = "ObservabilityConfig objects are mutable and unhashable"
        raise TypeError(msg)

    def copy(self) -> "ObservabilityConfig":
        """Return an independent copy of the configuration.

        Nested telemetry/redaction/logging/sampling configs are copied via their
        own ``copy()`` methods, and the lifecycle mapping and its per-event hook
        lists are rebuilt. The hook callables, observer callables and
        ``cloud_formatter`` are shared with the original, not duplicated.
        An explicitly empty observer tuple is kept as ``()`` rather than being
        collapsed to ``None``, so the copy compares equal on that field.
        """
        lifecycle_copy = _normalize_lifecycle(self.lifecycle)
        # ``is not None`` (not truthiness): ``()`` is an explicit setting.
        observers = tuple(self.statement_observers) if self.statement_observers is not None else None
        telemetry_copy = self.telemetry.copy() if self.telemetry is not None else None
        redaction_copy = self.redaction.copy() if self.redaction is not None else None
        logging_copy = self.logging.copy() if self.logging is not None else None
        sampling_copy = self.sampling.copy() if self.sampling is not None else None
        return ObservabilityConfig(
            lifecycle=lifecycle_copy,
            print_sql=self.print_sql,
            statement_observers=observers,
            telemetry=telemetry_copy,
            redaction=redaction_copy,
            logging=logging_copy,
            sampling=sampling_copy,
            cloud_formatter=self.cloud_formatter,
        )

    @classmethod
    def merge(
        cls, base_config: "ObservabilityConfig | None", override_config: "ObservabilityConfig | None"
    ) -> "ObservabilityConfig":
        """Merge registry-level and adapter-level configuration objects.

        ``override_config`` takes precedence over ``base_config``:

        * ``lifecycle`` hooks and ``statement_observers`` are concatenated, base first.
        * ``print_sql`` and ``cloud_formatter`` come from the override when it sets them (not ``None``).
        * ``telemetry`` and ``logging`` are taken wholesale from the override when it sets them.
        * ``redaction`` and ``sampling`` are merged field by field.

        Neither input is mutated. When both inputs are ``None`` an empty config is returned.
        """
        if base_config is None and override_config is None:
            return cls()

        base = base_config.copy() if base_config is not None else cls()
        override = override_config
        if override is None:
            return base

        lifecycle = _merge_lifecycle(base.lifecycle, override.lifecycle)

        observers: tuple[StatementObserver, ...] | None
        if base.statement_observers and override.statement_observers:
            observers = base.statement_observers + tuple(override.statement_observers)
        elif override.statement_observers:
            observers = tuple(override.statement_observers)
        else:
            observers = base.statement_observers

        print_sql = base.print_sql
        if override.print_sql is not None:
            print_sql = override.print_sql

        telemetry = override.telemetry.copy() if override.telemetry is not None else base.telemetry
        redaction = _merge_redaction(base.redaction, override.redaction)
        logging = _merge_logging(base.logging, override.logging)
        sampling = _merge_sampling(base.sampling, override.sampling)
        cloud_formatter = override.cloud_formatter if override.cloud_formatter is not None else base.cloud_formatter

        return ObservabilityConfig(
            lifecycle=lifecycle,
            print_sql=print_sql,
            statement_observers=observers,
            telemetry=telemetry,
            redaction=redaction,
            logging=logging,
            sampling=sampling,
            cloud_formatter=cloud_formatter,
        )

    def __repr__(self) -> str:
        return (
            f"ObservabilityConfig(lifecycle={self.lifecycle!r}, print_sql={self.print_sql!r}, statement_observers={self.statement_observers!r}, telemetry={self.telemetry!r}, "
            f"redaction={self.redaction!r}, logging={self.logging!r}, sampling={self.sampling!r}, cloud_formatter={self.cloud_formatter!r})"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ObservabilityConfig):
            return NotImplemented
        # Lifecycle configs are normalized so that differing hook container
        # types (e.g. tuple vs list) with the same hooks compare equal.
        return (
            _normalize_lifecycle(self.lifecycle) == _normalize_lifecycle(other.lifecycle)
            and self.print_sql == other.print_sql
            and self.statement_observers == other.statement_observers
            and self.telemetry == other.telemetry
            and self.redaction == other.redaction
            and self.logging == other.logging
            and self.sampling == other.sampling
            and self.cloud_formatter == other.cloud_formatter
        )
def _merge_redaction(base: "RedactionConfig | None", override: "RedactionConfig | None") -> "RedactionConfig | None":
    """Merge two redaction configs field by field.

    Each field of ``override`` that is not ``None`` replaces the matching field
    on a copy of ``base``. With only one side present a copy of that side is
    returned; with neither, ``None``.
    """
    if override is None:
        return None if base is None else base.copy()
    if base is None:
        return override.copy()

    merged = base.copy()
    for field_name in ("mask_parameters", "mask_literals"):
        candidate = getattr(override, field_name)
        if candidate is not None:
            setattr(merged, field_name, candidate)
    if override.parameter_allow_list is not None:
        merged.parameter_allow_list = tuple(override.parameter_allow_list)
    return merged


def _merge_logging(base: "LoggingConfig | None", override: "LoggingConfig | None") -> "LoggingConfig | None":
    """Return a copy of ``override`` when given, else a copy of ``base`` (or ``None``).

    Logging configs are replaced wholesale rather than merged per field.
    """
    chosen = override if override is not None else base
    return None if chosen is None else chosen.copy()


def _merge_sampling(base: "SamplingConfig | None", override: "SamplingConfig | None") -> "SamplingConfig | None":
    """Merge two sampling configs, letting "non-default" override fields win.

    NOTE(review): ``sample_rate == 1.0`` and falsy ``deterministic`` /
    ``force_sample_on_error`` on the override are treated as "not set", so an
    override cannot reset a base rate back to 1.0 or switch those flags off.
    Confirm this is intended against ``SamplingConfig``'s defaults.
    """
    if override is None:
        return base.copy() if base else None
    if base is None:
        return override.copy()

    merged = base.copy()
    if override.sample_rate != 1.0:
        merged.sample_rate = override.sample_rate
    for flag_name in ("deterministic", "force_sample_on_error"):
        flag_value = getattr(override, flag_name)
        if flag_value:
            setattr(merged, flag_name, flag_value)
    if override.force_sample_slow_queries_ms is not None:
        merged.force_sample_slow_queries_ms = override.force_sample_slow_queries_ms
    return merged


def _normalize_lifecycle(config: "LifecycleConfig | None") -> "LifecycleConfig | None":
    """Return a fresh mapping of event name to a new ``list`` of that event's hooks.

    The hook callables themselves are shared; only the containers are rebuilt.
    """
    if config is None:
        return None
    normalized: dict[str, list[LifecycleHook]] = {
        event: list(cast("Iterable[LifecycleHook]", hooks)) for event, hooks in config.items()
    }
    return cast("LifecycleConfig", normalized)


def _merge_lifecycle(base: "LifecycleConfig | None", override: "LifecycleConfig | None") -> "LifecycleConfig | None":
    """Combine two lifecycle configs, appending override hooks after base hooks per event.

    Returns a normalized copy when only one side is present and ``None`` when
    neither is.
    """
    if override is None:
        return _normalize_lifecycle(base)
    if base is None:
        return _normalize_lifecycle(override)

    combined: dict[str, list[LifecycleHook]] = (
        cast("dict[str, list[LifecycleHook]]", _normalize_lifecycle(base)) or {}
    )
    for event, hooks in override.items():
        combined.setdefault(event, []).extend(cast("Iterable[LifecycleHook]", hooks))
    return cast("LifecycleConfig", combined)


__all__ = (
    "LifecycleHook",
    "LoggingConfig",
    "ObservabilityConfig",
    "RedactionConfig",
    "StatementObserver",
    "TelemetryConfig",
)