# Source code for sqlspec.adapters.arrow_odbc.data_dictionary

"""Generic data dictionary for arrow-odbc connections."""

from typing import TYPE_CHECKING, Any, ClassVar

from mypy_extensions import mypyc_attr

from sqlspec.data_dictionary import get_data_dictionary_loader, get_dialect_config
from sqlspec.driver import SyncDataDictionaryBase
from sqlspec.exceptions import SQLFileNotFoundError
from sqlspec.typing import ColumnMetadata, ForeignKeyMetadata, IndexMetadata, TableMetadata, VersionInfo

if TYPE_CHECKING:
    from sqlspec.adapters.arrow_odbc.driver import ArrowOdbcDriver
    from sqlspec.core import SQL
    from sqlspec.data_dictionary._types import DialectConfig

__all__ = ("ArrowOdbcDataDictionary",)


@mypyc_attr(allow_interpreted_subclasses=True, native_class=False)
class ArrowOdbcDataDictionary(SyncDataDictionaryBase):
    """Runtime-dialect data dictionary for generic ODBC connections.

    An ODBC connection can point at many different backends, so the SQL dialect
    is supplied when the dictionary is constructed instead of being fixed per
    class. Queries and dialect configuration are looked up under that dialect
    through the bundled data-dictionary loader.

    When the dialect ships no bundled query for an operation, ``get_version``
    returns ``None`` and the catalog methods return ``[]``. They do not raise.
    """

    # Class-level default only. Every lookup in this class uses the
    # per-instance ``self._dialect`` assigned in ``__init__``.
    # NOTE(review): base-class code that reads ``self.dialect`` will see this
    # default rather than the runtime dialect. Confirm that is intended.
    dialect: ClassVar[str] = "sqlite"

    def __init__(self, dialect: str = "sqlite") -> None:
        """Bind the data dictionary to a runtime dialect.

        Args:
            dialect: Name of the dialect whose bundled queries and configuration
                are used for every lookup (defaults to ``"sqlite"``).
        """
        super().__init__()
        self._dialect = dialect

    def get_dialect_config(self) -> "DialectConfig":
        """Return the runtime dialect configuration for this data dictionary."""
        return get_dialect_config(self._dialect)

    def get_query(self, name: str) -> "SQL":
        """Return a named SQL query for the runtime dialect.

        Raises:
            SQLFileNotFoundError: Propagated from the loader when the dialect has
                no query with that name (callers in this class catch it).
        """
        loader = get_data_dictionary_loader()
        return loader.get_query(self._dialect, name)

    def get_query_text(self, name: str) -> str:
        """Return raw SQL text for a named runtime dialect query."""
        loader = get_data_dictionary_loader()
        return loader.get_query_text(self._dialect, name)

    def resolve_schema(self, schema: str | None) -> str | None:
        """Return ``schema`` if given, otherwise the dialect's default schema."""
        if schema is not None:
            return schema
        return self.get_dialect_config().default_schema

    def get_version(self, driver: "ArrowOdbcDriver") -> VersionInfo | None:
        """Get database version information when the runtime dialect provides a query.

        Results (including ``None``) are cached per driver object, keyed by
        ``id(driver)``. Repeated calls for the same driver return the cached
        value without querying again.

        Returns:
            The parsed version. Returns ``None`` when the dialect has no
            ``version`` query, the query fails, it returns an empty value, or the
            value does not match the dialect's version pattern.
        """
        driver_id = id(driver)
        if driver_id in self._version_fetch_attempted:
            return self._version_cache.get(driver_id)

        try:
            version_value = driver.select_value_or_none(self.get_query("version"))
        except SQLFileNotFoundError:
            # The dialect bundles no version query at all.
            self._log_version_unavailable(self._dialect, "no_query")
            self.cache_version(driver_id, None)
            return None
        except Exception:
            # Deliberate best-effort: version detection must never break callers.
            self._log_version_unavailable(self._dialect, "query_failed")
            self.cache_version(driver_id, None)
            return None

        if not version_value:
            self._log_version_unavailable(self._dialect, "missing")
            self.cache_version(driver_id, None)
            return None

        config = self.get_dialect_config()
        version_info = self.parse_version_with_pattern(config.version_pattern, str(version_value))
        if version_info is None:
            self._log_version_unavailable(self._dialect, "parse_failed")
        else:
            self._log_version_detected(self._dialect, version_info)
        self.cache_version(driver_id, version_info)
        return version_info

    def get_feature_flag(self, driver: "ArrowOdbcDriver", feature: str) -> bool:
        """Check whether the runtime dialect supports a feature for this driver's version."""
        return self.resolve_feature_flag(feature, self.get_version(driver))

    def get_optimal_type(self, driver: "ArrowOdbcDriver", type_category: str) -> str:
        """Get the optimal runtime dialect type for a category.

        ``driver`` is accepted for interface compatibility but is not used; the
        answer comes from the dialect configuration alone.
        """
        _ = driver
        return self.get_dialect_config().get_optimal_type(type_category)

    def _select_metadata(
        self,
        driver: "ArrowOdbcDriver",
        query_name: str,
        schema_type: Any,
        schema: str | None,
        table: str | None = None,
    ) -> list[Any]:
        """Run a bundled catalog query, returning ``[]`` if the dialect lacks it.

        Args:
            driver: Driver used to execute the query.
            query_name: Name of the bundled data-dictionary query.
            schema_type: Row type passed to ``driver.select`` as ``schema_type``.
            schema: Schema to inspect. ``None`` falls back to the dialect default.
            table: Optional table name. It is bound as ``table_name`` only when given.

        Returns:
            The rows returned by ``driver.select``, or ``[]`` when the dialect
            bundles no query named ``query_name``.
        """
        # Only the lookup may legitimately signal "no such query". Errors raised
        # while executing the query must propagate rather than masquerade as
        # "this dialect has no catalog support".
        try:
            query = self.get_query(query_name)
        except SQLFileNotFoundError:
            return []

        parameters: dict[str, Any] = {"schema_name": self.resolve_schema(schema)}
        if table is not None:
            parameters["table_name"] = table
        return driver.select(query, schema_type=schema_type, **parameters)

    def get_tables(self, driver: "ArrowOdbcDriver", schema: str | None = None) -> list[TableMetadata]:
        """Get table metadata for dialects with bundled catalog queries.

        Returns ``[]`` when the dialect has no ``tables_by_schema`` query.
        """
        return self._select_metadata(driver, "tables_by_schema", TableMetadata, schema)

    def get_columns(
        self, driver: "ArrowOdbcDriver", table: str | None = None, schema: str | None = None
    ) -> list[ColumnMetadata]:
        """Get column metadata for dialects with bundled catalog queries.

        Uses ``columns_by_table`` when ``table`` is given and ``columns_by_schema``
        otherwise. Returns ``[]`` when the dialect lacks that query.
        """
        query_name = "columns_by_table" if table is not None else "columns_by_schema"
        return self._select_metadata(driver, query_name, ColumnMetadata, schema, table)

    def get_indexes(
        self, driver: "ArrowOdbcDriver", table: str | None = None, schema: str | None = None
    ) -> list[IndexMetadata]:
        """Get index metadata for dialects with bundled catalog queries.

        Uses ``indexes_by_table`` when ``table`` is given and ``indexes_by_schema``
        otherwise. Returns ``[]`` when the dialect lacks that query.
        """
        query_name = "indexes_by_table" if table is not None else "indexes_by_schema"
        return self._select_metadata(driver, query_name, IndexMetadata, schema, table)

    def get_foreign_keys(
        self, driver: "ArrowOdbcDriver", table: str | None = None, schema: str | None = None
    ) -> list[ForeignKeyMetadata]:
        """Get foreign-key metadata for dialects with bundled catalog queries.

        Uses ``foreign_keys_by_table`` when ``table`` is given and
        ``foreign_keys_by_schema`` otherwise. Returns ``[]`` when the dialect
        lacks that query.
        """
        query_name = "foreign_keys_by_table" if table is not None else "foreign_keys_by_schema"
        return self._select_metadata(driver, query_name, ForeignKeyMetadata, schema, table)