# Source code for sqlspec.adapters.mssql_python.migrations

"""mssql-python-specific migration tracker."""

import logging
import os
from contextlib import suppress
from typing import TYPE_CHECKING

from sqlspec.builder import CreateTable, sql
from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker
from sqlspec.migrations.version import parse_version
from sqlspec.observability import resolve_db_system
from sqlspec.utils.logging import get_logger, log_with_context

if TYPE_CHECKING:
    from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase

# Public names exported by this module: the async and sync tracker classes.
__all__ = ("MssqlPythonAsyncMigrationTracker", "MssqlPythonSyncMigrationTracker")

# Module-level logger; used to report failed schema checks on the tracking table.
logger = get_logger("sqlspec.migrations.mssql_python")


class MssqlPythonMigrationTrackerMixin:
    """T-SQL-specific migration table DDL and schema maintenance.

    Builds the SQL text used by the tracker classes it is mixed into:
    ``CREATE TABLE`` (plain and wrapped in an existence check), a catalog
    query listing the tracking table's current columns, and
    ``ALTER TABLE ... ADD`` statements for columns that are missing.
    """

    __slots__ = ()

    # Name of the tracking table, optionally schema-qualified ("schema.table").
    # Only declared here; presumably assigned by the tracker base class.
    version_table: str

    def _get_create_table_sql(self) -> CreateTable:
        """Return T-SQL-compatible migration tracking table DDL.

        Returns:
            A ``CreateTable`` builder describing every tracking column with
            SQL Server column types.
        """
        return (
            sql
            .create_table(self.version_table)
            .column("version_num", "NVARCHAR(32)", primary_key=True)
            .column("version_type", "NVARCHAR(16)")
            .column("execution_sequence", "INT")
            .column("description", "NVARCHAR(MAX)")
            .column("applied_at", "DATETIME2(6)", default="SYSUTCDATETIME()", not_null=True)
            .column("execution_time_ms", "INT")
            .column("checksum", "NVARCHAR(64)")
            .column("applied_by", "NVARCHAR(255)")
            .column("replaces", "NVARCHAR(MAX)")
        )

    def _get_idempotent_create_table_sql_text(self) -> str:
        """Wrap CREATE TABLE in a T-SQL sys.tables existence probe.

        Returns:
            An ``IF NOT EXISTS (...) BEGIN CREATE TABLE ...; END;`` batch that
            only creates the table when ``sys.tables`` has no row with the
            same name and schema.
        """
        schema_name, table_name = _split_schema_table(self.version_table)
        # Drop trailing whitespace and any ';' so exactly one terminator is emitted below.
        create_sql = self._get_create_table_sql_text().rstrip().rstrip(";")
        return (
            "IF NOT EXISTS (SELECT 1 FROM sys.tables "
            f"WHERE name = '{_escape_sql_literal(table_name)}' "
            f"AND schema_id = SCHEMA_ID('{_escape_sql_literal(schema_name)}')) "
            f"BEGIN {create_sql}; END;"
        )

    def _get_create_table_sql_text(self) -> str:
        """Render CREATE TABLE text without routing SQL Server types through sqlglot.

        Returns:
            A ``CREATE TABLE`` statement (no trailing semicolon) with one
            column definition per line, assembled directly from the column
            definitions of :meth:`_get_create_table_sql`.
        """
        column_lines: list[str] = []
        for column_def in self._get_create_table_sql().columns:
            default_clause = f" DEFAULT {column_def.default}" if column_def.default else ""
            not_null_clause = " NOT NULL" if column_def.not_null else ""
            primary_key_clause = " PRIMARY KEY" if column_def.primary_key else ""
            column_lines.append(
                f"    {column_def.name} {column_def.dtype}{primary_key_clause}{default_clause}{not_null_clause}"
            )
        return f"CREATE TABLE {self.version_table} (\n" + ",\n".join(column_lines) + "\n)"

    def _get_existing_columns_sql(self) -> str:
        """Return T-SQL query text for migration tracking table columns.

        Returns:
            A ``SELECT`` over ``sys.columns`` joined to ``sys.tables`` that
            yields one ``column_name`` row per column of the tracking table.
        """
        schema_name, table_name = _split_schema_table(self.version_table)
        return f"""
            SELECT c.name AS column_name
            FROM sys.columns c
            INNER JOIN sys.tables t ON c.object_id = t.object_id
            WHERE t.name = '{_escape_sql_literal(table_name)}'
              AND t.schema_id = SCHEMA_ID('{_escape_sql_literal(schema_name)}')
        """

    def _get_add_column_sql_text(self, column_name: str) -> str | None:
        """Return T-SQL ALTER TABLE text for a missing migration column.

        Args:
            column_name: Name of the column to add; matched case-insensitively
                against the columns of :meth:`_get_create_table_sql`.

        Returns:
            The ``ALTER TABLE ... ADD ...;`` statement, or ``None`` when the
            target table definition has no column with that name.
        """
        # Normalise both sides: the previous one-sided ``.lower()`` silently
        # returned None for any caller passing a mixed-case column name.
        wanted_name = column_name.lower()
        target_create = self._get_create_table_sql()
        column_def = next((col for col in target_create.columns if col.name.lower() == wanted_name), None)
        if column_def is None:
            return None
        default_clause = f" DEFAULT {column_def.default}" if column_def.default else ""
        # Nullability is always spelled out explicitly for added columns.
        nullable_clause = " NOT NULL" if column_def.not_null else " NULL"
        return f"ALTER TABLE {self.version_table} ADD {column_def.name} {column_def.dtype}{default_clause}{nullable_clause};"


class MssqlPythonSyncMigrationTracker(MssqlPythonMigrationTrackerMixin, SyncMigrationTracker):
    """T-SQL sync migration tracker."""

    def ensure_tracking_table(self, driver: "SyncDriverAdapterBase") -> None:
        """Create the migration tracking table if it does not exist.

        Runs the idempotent CREATE TABLE batch, commits, and then adds any
        tracking columns missing from an already existing table.

        Args:
            driver: Sync driver used to execute the DDL.
        """
        driver.execute_script(self._get_idempotent_create_table_sql_text())
        driver.commit()
        self._migrate_schema_if_needed(driver)

    def record_migration(
        self, driver: "SyncDriverAdapterBase", version: str, description: str, execution_time_ms: int, checksum: str
    ) -> None:
        """Record a successfully applied migration with T-SQL-compatible metadata.

        Args:
            driver: Sync driver used to execute the statements.
            version: Migration version string; parsed to derive its version type.
            description: Description stored with the migration row.
            execution_time_ms: Execution time in milliseconds.
            checksum: Checksum stored with the migration row.
        """
        parsed_version = parse_version(version)
        version_type = parsed_version.type.value

        result = driver.execute(self._get_next_execution_sequence_sql())
        # Fall back to 1 when the sequence query returns no data.
        next_sequence = result.get_data()[0]["next_seq"] if result.data else 1

        driver.execute(
            self._get_record_migration_sql(
                version,
                version_type,
                next_sequence,
                description,
                execution_time_ms,
                checksum,
                os.environ.get("USER", "unknown"),
            )
        )
        driver.commit()

    def _migrate_schema_if_needed(self, driver: "SyncDriverAdapterBase") -> None:
        """Check and add missing tracking table columns through SQL Server catalog views.

        Failures are not raised: the transaction is rolled back on a
        best-effort basis and the failure is logged.

        Args:
            driver: Sync driver used to query the catalog and run the DDL.
        """
        try:
            rows = driver.select(self._get_existing_columns_sql())
            # Lower-case the names so the comparison does not depend on catalog casing.
            existing_columns = {str(row["column_name"]).lower() for row in rows if row.get("column_name") is not None}
            missing_columns = self._detect_missing_columns(existing_columns)
            if not missing_columns:
                return
            # Sorted so columns are added in a deterministic order.
            for column_name in sorted(missing_columns):
                self._add_column(driver, column_name)
            driver.commit()
        except Exception as exc:
            # A failing rollback must not mask the original schema-check failure.
            with suppress(Exception):
                driver.rollback()
            log_with_context(
                logger,
                logging.ERROR,
                "migration.track",
                db_system=resolve_db_system(type(driver).__name__),
                table=self.version_table,
                operation="schema_check",
                status="failed",
                error_type=type(exc).__name__,
            )

    def _add_column(self, driver: "SyncDriverAdapterBase", column_name: str) -> None:
        """Add a single missing migration tracking column.

        Does nothing when no ALTER TABLE text can be produced for the column.

        Args:
            driver: Sync driver used to run the ALTER TABLE statement.
            column_name: Name of the column to add.
        """
        add_column_sql = self._get_add_column_sql_text(column_name)
        if add_column_sql is None:
            return
        driver.execute_script(add_column_sql)
class MssqlPythonAsyncMigrationTracker(MssqlPythonMigrationTrackerMixin, AsyncMigrationTracker):
    """T-SQL async migration tracker."""

    async def ensure_tracking_table(self, driver: "AsyncDriverAdapterBase") -> None:
        """Create the migration tracking table if it does not exist.

        Runs the idempotent CREATE TABLE batch, commits, and then adds any
        tracking columns missing from an already existing table.

        Args:
            driver: Async driver used to execute the DDL.
        """
        await driver.execute_script(self._get_idempotent_create_table_sql_text())
        await driver.commit()
        await self._migrate_schema_if_needed(driver)

    async def record_migration(
        self, driver: "AsyncDriverAdapterBase", version: str, description: str, execution_time_ms: int, checksum: str
    ) -> None:
        """Record a successfully applied migration with T-SQL-compatible metadata.

        Args:
            driver: Async driver used to execute the statements.
            version: Migration version string; parsed to derive its version type.
            description: Description stored with the migration row.
            execution_time_ms: Execution time in milliseconds.
            checksum: Checksum stored with the migration row.
        """
        parsed_version = parse_version(version)
        version_type = parsed_version.type.value

        result = await driver.execute(self._get_next_execution_sequence_sql())
        # Fall back to 1 when the sequence query returns no data.
        next_sequence = result.get_data()[0]["next_seq"] if result.data else 1

        await driver.execute(
            self._get_record_migration_sql(
                version,
                version_type,
                next_sequence,
                description,
                execution_time_ms,
                checksum,
                os.environ.get("USER", "unknown"),
            )
        )
        await driver.commit()

    async def _migrate_schema_if_needed(self, driver: "AsyncDriverAdapterBase") -> None:
        """Check and add missing tracking table columns through SQL Server catalog views.

        Failures are not raised: the transaction is rolled back on a
        best-effort basis and the failure is logged.

        Args:
            driver: Async driver used to query the catalog and run the DDL.
        """
        try:
            rows = await driver.select(self._get_existing_columns_sql())
            # Lower-case the names so the comparison does not depend on catalog casing.
            existing_columns = {str(row["column_name"]).lower() for row in rows if row.get("column_name") is not None}
            missing_columns = self._detect_missing_columns(existing_columns)
            if not missing_columns:
                return
            # Sorted so columns are added in a deterministic order.
            for column_name in sorted(missing_columns):
                await self._add_column(driver, column_name)
            await driver.commit()
        except Exception as exc:
            # A failing rollback must not mask the original schema-check failure.
            with suppress(Exception):
                await driver.rollback()
            log_with_context(
                logger,
                logging.ERROR,
                "migration.track",
                db_system=resolve_db_system(type(driver).__name__),
                table=self.version_table,
                operation="schema_check",
                status="failed",
                error_type=type(exc).__name__,
            )

    async def _add_column(self, driver: "AsyncDriverAdapterBase", column_name: str) -> None:
        """Add a single missing migration tracking column.

        Does nothing when no ALTER TABLE text can be produced for the column.

        Args:
            driver: Async driver used to run the ALTER TABLE statement.
            column_name: Name of the column to add.
        """
        add_column_sql = self._get_add_column_sql_text(column_name)
        if add_column_sql is None:
            return
        await driver.execute_script(add_column_sql)
def _escape_sql_literal(value: str) -> str: """Escape a string for inclusion in a T-SQL string literal.""" return value.replace("'", "''") def _split_schema_table(table_name: str) -> tuple[str, str]: """Split a schema-qualified table name into schema and table parts.""" if "." not in table_name: return "dbo", table_name schema_name, bare_table_name = table_name.rsplit(".", 1) return schema_name.strip("[]") or "dbo", bare_table_name.strip("[]")