Source code for sqlspec.adapters.mssql_python.data_dictionary

"""mssql-python data dictionary."""

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, ClassVar, cast

from mypy_extensions import mypyc_attr

from sqlspec.data_dictionary import get_dialect_config
from sqlspec.data_dictionary.dialects.mssql import (
    extract_mssql_version_value,
    is_mssql_azure_sql,
    list_mssql_available_features,
    merge_mssql_table_lists,
    mssql_supports_native_json,
    parse_mssql_engine_edition,
    parse_mssql_version_components,
    resolve_mssql_feature_flag,
)
from sqlspec.driver import AsyncDataDictionaryBase, SyncDataDictionaryBase
from sqlspec.typing import ColumnMetadata, ForeignKeyMetadata, IndexMetadata, TableMetadata, VersionInfo
from sqlspec.utils.logging import get_logger

if TYPE_CHECKING:
    from sqlspec.data_dictionary._types import DialectConfig

# Logger named after this module's dotted path.
logger = get_logger("sqlspec.adapters.mssql_python.data_dictionary")

# Public API of this module.
__all__ = (
    "MssqlPythonAsyncDataDictionary",
    "MssqlPythonSyncDataDictionary",
    "MssqlVersionInfo",
)


class MssqlVersionInfo(VersionInfo):
    """Version details for a SQL Server instance.

    Extends :class:`VersionInfo` with the fourth SQL Server version component
    (``revision``), the edition string reported by the server, the numeric engine
    edition, and an ``is_azure_sql`` flag derived from that engine edition via
    ``is_mssql_azure_sql``.
    """

    def __init__(
        self,
        major: int,
        minor: int = 0,
        build: int = 0,
        revision: int = 0,
        edition: str | None = None,
        engine_edition: int | None = None,
    ) -> None:
        """Store the version components and derive the Azure SQL flag.

        Args:
            major: Major version number.
            minor: Minor version number.
            build: Build number; passed to the base class as its third component
                and also kept as ``self.build``.
            revision: Revision number (fourth component).
            edition: Edition text reported by the server, if any.
            engine_edition: Numeric engine edition, if known.
        """
        super().__init__(major, minor, build)
        self.build = build
        self.revision = revision
        self.edition = edition
        self.engine_edition = engine_edition
        self.is_azure_sql = is_mssql_azure_sql(engine_edition)

    def supports_native_json(self) -> bool:
        """Report whether the native ``JSON`` column type is available on this server."""
        return mssql_supports_native_json(self.major, is_azure_sql=self.is_azure_sql)

    def __str__(self) -> str:
        """Render as ``major.minor.build.revision``, with optional edition and Azure markers."""
        components = (self.major, self.minor, self.build, self.revision)
        rendered = ".".join(f"{component}" for component in components)
        if self.edition:
            rendered = f"{rendered} ({self.edition})"
        if self.is_azure_sql:
            rendered = f"{rendered} [Azure]"
        return rendered


class _MssqlDataDictionaryMixin:
    """Behaviour shared by the sync and async MSSQL data dictionaries.

    Covers dialect-configuration lookup, default-schema resolution, feature listing,
    and the conversion of raw version-query values into :class:`MssqlVersionInfo`.
    """

    dialect: ClassVar[str] = "mssql"

    def get_dialect_config(self) -> "DialectConfig":
        """Look up the dialect configuration registered under this class's ``dialect``."""
        dialect_name = type(self).dialect
        return get_dialect_config(dialect_name)

    def resolve_schema(self, schema: str | None) -> str | None:
        """Return ``schema`` as given, or the dialect's ``default_schema`` when it is ``None``."""
        return self.get_dialect_config().default_schema if schema is None else schema

    def list_available_features(self) -> list[str]:
        """Return the feature-flag names known for this dialect."""
        config = self.get_dialect_config()
        return list_mssql_available_features(config)

    def _build_version_info(
        self, version_value: str | None, edition: str | None, engine_edition_value: Any
    ) -> MssqlVersionInfo | None:
        """Assemble an :class:`MssqlVersionInfo` from raw version-query values.

        Args:
            version_value: Version string; split by ``parse_mssql_version_components``.
            edition: Edition text to attach, if any.
            engine_edition_value: Raw engine-edition value; converted by
                ``parse_mssql_engine_edition``.

        Returns:
            The version info, or ``None`` when ``version_value`` is empty or ``None``.
        """
        if version_value:
            major, minor, build, revision = parse_mssql_version_components(version_value)
            return MssqlVersionInfo(
                major=major,
                minor=minor,
                build=build,
                revision=revision,
                edition=edition,
                engine_edition=parse_mssql_engine_edition(engine_edition_value),
            )
        return None

    def _get_optimal_type_from_version(self, version_info: MssqlVersionInfo | None, type_category: str) -> str:
        """Pick a column type for ``type_category``.

        JSON-like categories map to the native ``JSON`` type when the detected server
        supports it; everything else defers to the dialect configuration.
        """
        if type_category in {"json", "jsonb"} and version_info is not None:
            if version_info.supports_native_json():
                return "JSON"
        return self.get_dialect_config().get_optimal_type(type_category)


@mypyc_attr(allow_interpreted_subclasses=True, native_class=False)
class MssqlPythonSyncDataDictionary(_MssqlDataDictionaryMixin, SyncDataDictionaryBase):
    """Synchronous MSSQL data dictionary for the mssql-python adapter.

    Issues SQL Server catalog queries through a synchronous driver and caches the
    detected server version keyed by ``id(driver)``.
    """

    dialect: ClassVar[str] = "mssql"

    def __init__(self) -> None:
        super().__init__()

    def get_version(self, driver: Any) -> MssqlVersionInfo | None:
        """Detect the SQL Server version reachable through ``driver``.

        The lookup is attempted at most once per driver id. Later calls return the
        cached result, which is ``None`` when the version row was missing or could
        not be parsed.
        """
        key = id(driver)
        if key in self._version_fetch_attempted:
            return cast("MssqlVersionInfo | None", self._version_cache.get(key))

        dialect_name = type(self).dialect
        row = driver.select_one_or_none(self.get_query_text("version"))
        if not row:
            self._log_version_unavailable(dialect_name, "missing")
            self.cache_version(key, None)
            return None

        # Prefer the product_version column; otherwise fall back to version_string/version.
        raw_version = _row_value(row, "product_version") or _row_value(row, "version_string", "version")
        version_value = extract_mssql_version_value(raw_version)
        raw_edition = _row_value(row, "edition")
        edition = None if raw_edition is None else str(raw_edition)
        version_info = self._build_version_info(version_value, edition, _row_value(row, "engine_edition"))

        if version_info is not None:
            self._log_version_detected(dialect_name, version_info)
            self.cache_version(key, version_info)
            return version_info

        self._log_version_unavailable(dialect_name, "parse_failed")
        self.cache_version(key, None)
        return None

    def get_feature_flag(self, driver: Any, feature: str) -> bool:
        """Return whether the connected SQL Server supports ``feature``.

        An undetected version is passed to the resolver as major version ``0``,
        not Azure.
        """
        version_info = self.get_version(driver)
        major = 0 if version_info is None else version_info.major
        azure = bool(version_info and version_info.is_azure_sql)
        return resolve_mssql_feature_flag(
            feature,
            major=major,
            is_azure_sql=azure,
            config=self.get_dialect_config(),
            version_info=version_info,
        )

    def get_optimal_type(self, driver: Any, type_category: str) -> str:
        """Return the best SQL Server column type for ``type_category`` on this server."""
        version_info = self.get_version(driver)
        return self._get_optimal_type_from_version(version_info, type_category)

    def get_tables(self, driver: Any, schema: str | None = None) -> list[TableMetadata]:
        """List tables in the schema in dependency order, with catalog fallback.

        Both the dependency-ordered listing and the full catalog listing are fetched,
        then combined by ``merge_mssql_table_lists``.
        """
        schema_name = self.resolve_schema(schema)
        self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="tables")
        ordered = cast(
            "list[TableMetadata]",
            driver.select(self.get_query("tables_by_schema"), schema_name=schema_name, schema_type=TableMetadata),
        )
        catalog = cast(
            "list[TableMetadata]",
            driver.select(self.get_query("all_tables_by_schema"), schema_name=schema_name, schema_type=TableMetadata),
        )
        return merge_mssql_table_lists(ordered, catalog)

    def _select_scoped_metadata(
        self,
        driver: Any,
        *,
        table: str | None,
        schema: str | None,
        operation: str,
        schema_query: str,
        table_query: str,
        schema_type: Any,
    ) -> Any:
        """Run a schema-wide or single-table catalog query and return the driver's result.

        When ``table`` is ``None``, ``schema_query`` runs for the resolved schema after
        a schema-introspection log call. Otherwise ``table_query`` runs for that table
        after a table-describe log call.
        """
        schema_name = self.resolve_schema(schema)
        if table is not None:
            self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation=operation)
            return driver.select(
                self.get_query(table_query), schema_name=schema_name, table_name=table, schema_type=schema_type
            )
        self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation=operation)
        return driver.select(self.get_query(schema_query), schema_name=schema_name, schema_type=schema_type)

    def get_columns(self, driver: Any, table: str | None = None, schema: str | None = None) -> list[ColumnMetadata]:
        """Return column metadata for ``table``, or for the whole schema when ``table`` is None."""
        rows = self._select_scoped_metadata(
            driver,
            table=table,
            schema=schema,
            operation="columns",
            schema_query="columns_by_schema",
            table_query="columns_by_table",
            schema_type=ColumnMetadata,
        )
        return cast("list[ColumnMetadata]", rows)

    def get_indexes(self, driver: Any, table: str | None = None, schema: str | None = None) -> list[IndexMetadata]:
        """Return index metadata for ``table``, or for the whole schema when ``table`` is None."""
        rows = self._select_scoped_metadata(
            driver,
            table=table,
            schema=schema,
            operation="indexes",
            schema_query="indexes_by_schema",
            table_query="indexes_by_table",
            schema_type=IndexMetadata,
        )
        return cast("list[IndexMetadata]", rows)

    def get_foreign_keys(
        self, driver: Any, table: str | None = None, schema: str | None = None
    ) -> list[ForeignKeyMetadata]:
        """Return foreign-key metadata for ``table``, or for the whole schema when ``table`` is None."""
        rows = self._select_scoped_metadata(
            driver,
            table=table,
            schema=schema,
            operation="foreign_keys",
            schema_query="foreign_keys_by_schema",
            table_query="foreign_keys_by_table",
            schema_type=ForeignKeyMetadata,
        )
        return cast("list[ForeignKeyMetadata]", rows)
@mypyc_attr(allow_interpreted_subclasses=True, native_class=False)
class MssqlPythonAsyncDataDictionary(_MssqlDataDictionaryMixin, AsyncDataDictionaryBase):
    """Asynchronous MSSQL data dictionary for the mssql-python adapter.

    Issues SQL Server catalog queries through an async driver (every driver call is
    awaited) and caches the detected server version keyed by ``id(driver)``.
    """

    dialect: ClassVar[str] = "mssql"

    def __init__(self) -> None:
        super().__init__()

    async def get_version(self, driver: Any) -> MssqlVersionInfo | None:
        """Detect the SQL Server version reachable through ``driver``.

        The lookup is attempted at most once per driver id. Later calls return the
        cached result, which is ``None`` when the version row was missing or could
        not be parsed.
        """
        key = id(driver)
        if key in self._version_fetch_attempted:
            return cast("MssqlVersionInfo | None", self._version_cache.get(key))

        dialect_name = type(self).dialect
        row = await driver.select_one_or_none(self.get_query_text("version"))
        if not row:
            self._log_version_unavailable(dialect_name, "missing")
            self.cache_version(key, None)
            return None

        # Prefer the product_version column; otherwise fall back to version_string/version.
        raw_version = _row_value(row, "product_version") or _row_value(row, "version_string", "version")
        version_value = extract_mssql_version_value(raw_version)
        raw_edition = _row_value(row, "edition")
        edition = None if raw_edition is None else str(raw_edition)
        version_info = self._build_version_info(version_value, edition, _row_value(row, "engine_edition"))

        if version_info is not None:
            self._log_version_detected(dialect_name, version_info)
            self.cache_version(key, version_info)
            return version_info

        self._log_version_unavailable(dialect_name, "parse_failed")
        self.cache_version(key, None)
        return None

    async def get_feature_flag(self, driver: Any, feature: str) -> bool:
        """Return whether the connected SQL Server supports ``feature``.

        An undetected version is passed to the resolver as major version ``0``,
        not Azure.
        """
        version_info = await self.get_version(driver)
        major = 0 if version_info is None else version_info.major
        azure = bool(version_info and version_info.is_azure_sql)
        return resolve_mssql_feature_flag(
            feature,
            major=major,
            is_azure_sql=azure,
            config=self.get_dialect_config(),
            version_info=version_info,
        )

    async def get_optimal_type(self, driver: Any, type_category: str) -> str:
        """Return the best SQL Server column type for ``type_category`` on this server."""
        version_info = await self.get_version(driver)
        return self._get_optimal_type_from_version(version_info, type_category)

    async def get_tables(self, driver: Any, schema: str | None = None) -> list[TableMetadata]:
        """List tables in the schema in dependency order, with catalog fallback.

        Both the dependency-ordered listing and the full catalog listing are fetched,
        then combined by ``merge_mssql_table_lists``.
        """
        schema_name = self.resolve_schema(schema)
        self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="tables")
        ordered = cast(
            "list[TableMetadata]",
            await driver.select(self.get_query("tables_by_schema"), schema_name=schema_name, schema_type=TableMetadata),
        )
        catalog = cast(
            "list[TableMetadata]",
            await driver.select(
                self.get_query("all_tables_by_schema"), schema_name=schema_name, schema_type=TableMetadata
            ),
        )
        return merge_mssql_table_lists(ordered, catalog)

    async def _select_scoped_metadata(
        self,
        driver: Any,
        *,
        table: str | None,
        schema: str | None,
        operation: str,
        schema_query: str,
        table_query: str,
        schema_type: Any,
    ) -> Any:
        """Run a schema-wide or single-table catalog query and return the awaited result.

        When ``table`` is ``None``, ``schema_query`` runs for the resolved schema after
        a schema-introspection log call. Otherwise ``table_query`` runs for that table
        after a table-describe log call.
        """
        schema_name = self.resolve_schema(schema)
        if table is not None:
            self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation=operation)
            return await driver.select(
                self.get_query(table_query), schema_name=schema_name, table_name=table, schema_type=schema_type
            )
        self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation=operation)
        return await driver.select(self.get_query(schema_query), schema_name=schema_name, schema_type=schema_type)

    async def get_columns(
        self, driver: Any, table: str | None = None, schema: str | None = None
    ) -> list[ColumnMetadata]:
        """Return column metadata for ``table``, or for the whole schema when ``table`` is None."""
        rows = await self._select_scoped_metadata(
            driver,
            table=table,
            schema=schema,
            operation="columns",
            schema_query="columns_by_schema",
            table_query="columns_by_table",
            schema_type=ColumnMetadata,
        )
        return cast("list[ColumnMetadata]", rows)

    async def get_indexes(
        self, driver: Any, table: str | None = None, schema: str | None = None
    ) -> list[IndexMetadata]:
        """Return index metadata for ``table``, or for the whole schema when ``table`` is None."""
        rows = await self._select_scoped_metadata(
            driver,
            table=table,
            schema=schema,
            operation="indexes",
            schema_query="indexes_by_schema",
            table_query="indexes_by_table",
            schema_type=IndexMetadata,
        )
        return cast("list[IndexMetadata]", rows)

    async def get_foreign_keys(
        self, driver: Any, table: str | None = None, schema: str | None = None
    ) -> list[ForeignKeyMetadata]:
        """Return foreign-key metadata for ``table``, or for the whole schema when ``table`` is None."""
        rows = await self._select_scoped_metadata(
            driver,
            table=table,
            schema=schema,
            operation="foreign_keys",
            schema_query="foreign_keys_by_schema",
            table_query="foreign_keys_by_table",
            schema_type=ForeignKeyMetadata,
        )
        return cast("list[ForeignKeyMetadata]", rows)
def _row_value(row: object, *names: str) -> Any: """Return the first named value from a row-like object.""" if isinstance(row, dict): for name in names: if name in row: return row[name] upper_name = name.upper() if upper_name in row: return row[upper_name] return None return getattr(row, names[0], None) if names else None