# Source code for sqlspec.migrations.tracker

"""Migration version tracking for SQLSpec.

This module provides functionality to track applied migrations in the database.
"""

import logging
import os
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any

from rich.console import Console

from sqlspec.builder import sql
from sqlspec.migrations.base import BaseMigrationTracker
from sqlspec.migrations.version import parse_version
from sqlspec.observability import resolve_db_system
from sqlspec.utils.logging import get_logger, log_with_context

if TYPE_CHECKING:
    from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase

# Public names exported by this module: the async and sync tracker classes.
__all__ = ("AsyncMigrationTracker", "SyncMigrationTracker")

# Module-level logger; every structured log call in this module goes through it.
logger = get_logger("sqlspec.migrations.tracker")


def _extract_column_name(metadata: Any) -> "str | None":
    """Extract column name from a metadata entry."""
    if isinstance(metadata, Mapping):
        value = metadata.get("column_name")
        if value is None:
            value = metadata.get("COLUMN_NAME")
        return str(value).lower() if value is not None else None
    value = getattr(metadata, "column_name", None)
    if value is not None:
        return str(value).lower()
    return None


class SyncMigrationTracker(BaseMigrationTracker["SyncDriverAdapterBase"]):
    """Synchronous migration version tracker.

    Reads and writes rows of the migration tracking table (``self.version_table``)
    through a synchronous driver. The SQL statements themselves come from the
    ``_get_*_sql`` helpers inherited from ``BaseMigrationTracker``.
    """

    def _migrate_schema_if_needed(self, driver: "SyncDriverAdapterBase") -> None:
        """Check for and add any missing columns to the tracking table.

        Uses the adapter's data_dictionary to query existing columns, then compares
        to the target schema and adds missing columns one by one.

        Any exception raised along the way is logged at ERROR level and
        swallowed; this method never raises.

        Args:
            driver: The database driver to use.
        """
        console = Console()

        try:
            columns_data = driver.data_dictionary.get_columns(driver, self.version_table)
            if not columns_data:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    table=self.version_table,
                    operation="table_check",
                    status="missing",
                )
                columns_data = []

            # Entries without a resolvable column name are ignored.
            existing_columns = {name for col in columns_data if (name := _extract_column_name(col)) is not None}
            missing_columns = self._detect_missing_columns(existing_columns)

            if not missing_columns:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    table=self.version_table,
                    operation="schema_check",
                    status="current",
                )
                return

            if self._should_echo():
                console.print(
                    f"[cyan]Migrating tracking table schema, adding columns: {', '.join(sorted(missing_columns))}[/]"
                )

            # Sorted so columns are added in a deterministic order.
            for col_name in sorted(missing_columns):
                self._add_column(driver, col_name)

            # Commit through _safe_commit like every other write path in this
            # class. A bare driver.commit() on an autocommit connection may raise,
            # which the broad handler below would misreport as a failed upgrade.
            self._safe_commit(driver)

            if self._should_echo():
                console.print("[green]Migration tracking table schema updated successfully[/]")

        except Exception as exc:
            # Best-effort upgrade: record the failure, do not propagate it.
            log_with_context(
                logger,
                logging.ERROR,
                "migration.track",
                db_system=resolve_db_system(type(driver).__name__),
                table=self.version_table,
                operation="schema_check",
                status="failed",
                error_type=type(exc).__name__,
            )

    def _add_column(self, driver: "SyncDriverAdapterBase", column_name: str) -> None:
        """Add a single column to the tracking table.

        The column definition is looked up (case-insensitively) in the target
        CREATE TABLE statement; if no definition matches, nothing is executed.

        Args:
            driver: The database driver to use.
            column_name: Name of the column to add (lowercase).
        """
        target_create = self._get_create_table_sql()

        column_def = next((col for col in target_create.columns if col.name.lower() == column_name), None)
        if not column_def:
            return

        alter_sql = sql.alter_table(self.version_table).add_column(
            name=column_def.name, dtype=column_def.dtype, default=column_def.default, not_null=column_def.not_null
        )
        driver.execute(alter_sql)

        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            table=self.version_table,
            column_name=column_name,
            operation="schema_update",
            status="column_added",
        )

    def ensure_tracking_table(self, driver: "SyncDriverAdapterBase") -> None:
        """Create the migration tracking table if it doesn't exist.

        Also checks for and adds any missing columns to support schema migrations.

        Args:
            driver: The database driver to use.
        """
        driver.execute(self._get_create_table_sql())
        self._safe_commit(driver)

        self._migrate_schema_if_needed(driver)

    def get_current_version(self, driver: "SyncDriverAdapterBase") -> str | None:
        """Get the latest applied migration version.

        Args:
            driver: The database driver to use.

        Returns:
            The current version number or None if no migrations applied.
        """
        result = driver.execute(self._get_current_version_sql())
        current = result.get_data()[0]["version_num"] if result.data else None
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.history",
            db_system=resolve_db_system(type(driver).__name__),
            current_version=current,
            status="current",
        )
        return current

    def get_applied_migrations(self, driver: "SyncDriverAdapterBase") -> "list[dict[str, Any]]":
        """Get all applied migrations in order.

        Args:
            driver: The database driver to use.

        Returns:
            List of migration records.
        """
        result = driver.execute(self._get_applied_migrations_sql())
        applied = result.get_data()
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.history",
            db_system=resolve_db_system(type(driver).__name__),
            applied_count=len(applied),
            status="listed",
        )
        return applied

    def record_migration(
        self, driver: "SyncDriverAdapterBase", version: str, description: str, execution_time_ms: int, checksum: str
    ) -> None:
        """Record a successfully applied migration.

        Parses version to determine type (sequential or timestamp) and
        auto-increments execution_sequence for application order tracking.

        Args:
            driver: The database driver to use.
            version: Version number of the migration.
            description: Description of the migration.
            execution_time_ms: Execution time in milliseconds.
            checksum: MD5 checksum of the migration content.
        """
        parsed_version = parse_version(version)
        version_type = parsed_version.type.value

        # Fall back to sequence 1 when the sequence query returns no rows.
        result = driver.execute(self._get_next_execution_sequence_sql())
        next_sequence = result.get_data()[0]["next_seq"] if result.data else 1

        driver.execute(
            self._get_record_migration_sql(
                version,
                version_type,
                next_sequence,
                description,
                execution_time_ms,
                checksum,
                # Value of the USER environment variable, or "unknown" when unset.
                os.environ.get("USER", "unknown"),
            )
        )
        self._safe_commit(driver)
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            version=version,
            operation="record",
            status="recorded",
        )

    def remove_migration(self, driver: "SyncDriverAdapterBase", version: str) -> None:
        """Remove a migration record (used during downgrade).

        Args:
            driver: The database driver to use.
            version: Version number to remove.
        """
        driver.execute(self._get_remove_migration_sql(version))
        self._safe_commit(driver)
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            version=version,
            operation="remove",
            status="removed",
        )

    def update_version_record(self, driver: "SyncDriverAdapterBase", old_version: str, new_version: str) -> None:
        """Update migration version record from timestamp to sequential.

        Updates version_num and version_type while preserving execution_sequence,
        applied_at, and other tracking metadata. Used during fix command.

        Idempotent: If the version is already updated, logs and continues without error.
        This allows fix command to be safely re-run after pulling changes.

        Args:
            driver: The database driver to use.
            old_version: Current timestamp version string.
            new_version: New sequential version string.

        Raises:
            ValueError: If neither old_version nor new_version found in database.
        """
        parsed_new_version = parse_version(new_version)
        new_version_type = parsed_new_version.type.value

        result = driver.execute(self._get_update_version_sql(old_version, new_version, new_version_type))

        if result.rows_affected == 0:
            # Nothing was updated: either the rename already happened (new_version
            # is present) or old_version was never recorded.
            check_result = driver.execute(self._get_applied_migrations_sql())
            applied_versions = {row["version_num"] for row in check_result.get_data()} if check_result.data else set()

            if new_version in applied_versions:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    old_version=old_version,
                    new_version=new_version,
                    operation="version_update",
                    status="skipped",
                )
                return

            msg = f"Migration version {old_version} not found in database"
            raise ValueError(msg)

        self._safe_commit(driver)
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            old_version=old_version,
            new_version=new_version,
            operation="version_update",
            status="updated",
        )

    def replace_with_squash(
        self,
        driver: "SyncDriverAdapterBase",
        squashed_version: str,
        replaced_versions: "list[str]",
        description: str,
        checksum: str,
    ) -> None:
        """Replace multiple migration records with a single squashed record.

        Deletes all replaced version records and inserts a new record for the
        squashed migration with metadata about which versions it replaces.

        Args:
            driver: The database driver to use.
            squashed_version: Version number of the squashed migration.
            replaced_versions: List of version strings being replaced.
            description: Description of the squashed migration.
            checksum: MD5 checksum of the squashed migration content.
        """
        driver.execute(self._get_delete_versions_sql(replaced_versions))

        # The sequence is read after the delete; falls back to 1 when no rows come back.
        result = driver.execute(self._get_next_execution_sequence_sql())
        next_sequence = result.get_data()[0]["next_seq"] if result.data else 1

        parsed_version = parse_version(squashed_version)
        version_type = parsed_version.type.value

        # Replaced versions are stored as one comma-separated string.
        replaces_str = ",".join(replaced_versions)

        driver.execute(
            self._get_record_squashed_migration_sql(
                squashed_version,
                version_type,
                next_sequence,
                description,
                # NOTE(review): 0 sits in the slot where record_migration passes
                # execution_time_ms; presumably no timing is recorded for a squash.
                0,
                checksum,
                os.environ.get("USER", "unknown"),
                replaces_str,
            )
        )
        self._safe_commit(driver)

        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            squashed_version=squashed_version,
            replaced_count=len(replaced_versions),
            operation="squash",
            status="recorded",
        )

    def is_squash_already_applied(
        self, driver: "SyncDriverAdapterBase", squashed_version: str, replaced_versions: "list[str]"
    ) -> bool:
        """Check if a squash operation has already been applied.

        Determines if any of the replaced versions exist in the database,
        indicating that the original migrations were applied before the squash.

        Args:
            driver: The database driver to use.
            squashed_version: Version number of the squashed migration (unused but kept for API consistency).
            replaced_versions: List of version strings that would be replaced.

        Returns:
            True if any replaced version exists (squash already applied), False otherwise.
        """
        result = driver.execute(self._get_applied_migrations_sql())
        applied_versions = {row["version_num"] for row in result.get_data()} if result.data else set()

        # Check if any replaced version exists in applied migrations
        return any(version in applied_versions for version in replaced_versions)

    def _safe_commit(self, driver: "SyncDriverAdapterBase") -> None:
        """Safely commit a transaction only if autocommit is disabled.

        Returns immediately when ``driver.driver_features`` reports a truthy
        ``autocommit`` entry. Otherwise commits; an exception whose lower-cased
        message contains "autocommit" or "cannot commit" is logged at DEBUG
        level and suppressed.

        Args:
            driver: The database driver to use.

        Raises:
            Exception: Re-raises non-autocommit related exceptions to prevent
                silent failures (e.g., SQLite isolation errors).
        """
        if driver.driver_features.get("autocommit", False):
            return

        try:
            driver.commit()
        except Exception as exc:
            # Only suppress autocommit-related exceptions
            exc_str = str(exc).lower()
            if "autocommit" in exc_str or "cannot commit" in exc_str:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    operation="commit",
                    status="skipped",
                    reason="autocommit",
                    error_type=type(exc).__name__,
                )
            else:
                # Re-raise non-autocommit exceptions to prevent silent failures
                raise
class AsyncMigrationTracker(BaseMigrationTracker["AsyncDriverAdapterBase"]):
    """Asynchronous migration version tracker.

    Reads and writes rows of the migration tracking table (``self.version_table``)
    through an asynchronous driver. The SQL statements themselves come from the
    ``_get_*_sql`` helpers inherited from ``BaseMigrationTracker``.
    """

    async def _migrate_schema_if_needed(self, driver: "AsyncDriverAdapterBase") -> None:
        """Check for and add any missing columns to the tracking table.

        Uses the driver's data_dictionary to query existing columns, then compares
        to the target schema and adds missing columns one by one.

        Any exception raised along the way is logged at ERROR level and
        swallowed; this method never raises.

        Args:
            driver: The database driver to use.
        """
        console = Console()

        try:
            columns_data = await driver.data_dictionary.get_columns(driver, self.version_table)
            if not columns_data:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    table=self.version_table,
                    operation="table_check",
                    status="missing",
                )
                columns_data = []

            # Entries without a resolvable column name are ignored.
            existing_columns = {name for col in columns_data if (name := _extract_column_name(col)) is not None}
            missing_columns = self._detect_missing_columns(existing_columns)

            if not missing_columns:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    table=self.version_table,
                    operation="schema_check",
                    status="current",
                )
                return

            if self._should_echo():
                console.print(
                    f"[cyan]Migrating tracking table schema, adding columns: {', '.join(sorted(missing_columns))}[/]"
                )

            # Sorted so columns are added in a deterministic order.
            for col_name in sorted(missing_columns):
                await self._add_column(driver, col_name)

            # Commit through _safe_commit_async like every other write path in
            # this class. A bare driver.commit() on an autocommit connection may
            # raise, which the broad handler below would misreport as a failed upgrade.
            await self._safe_commit_async(driver)

            if self._should_echo():
                console.print("[green]Migration tracking table schema updated successfully[/]")

        except Exception as exc:
            # Best-effort upgrade: record the failure, do not propagate it.
            log_with_context(
                logger,
                logging.ERROR,
                "migration.track",
                db_system=resolve_db_system(type(driver).__name__),
                table=self.version_table,
                operation="schema_check",
                status="failed",
                error_type=type(exc).__name__,
            )

    async def _add_column(self, driver: "AsyncDriverAdapterBase", column_name: str) -> None:
        """Add a single column to the tracking table.

        The column definition is looked up (case-insensitively) in the target
        CREATE TABLE statement; if no definition matches, nothing is executed.

        Args:
            driver: The database driver to use.
            column_name: Name of the column to add (lowercase).
        """
        target_create = self._get_create_table_sql()

        column_def = next((col for col in target_create.columns if col.name.lower() == column_name), None)
        if not column_def:
            return

        alter_sql = sql.alter_table(self.version_table).add_column(
            name=column_def.name, dtype=column_def.dtype, default=column_def.default, not_null=column_def.not_null
        )
        await driver.execute(alter_sql)

        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            table=self.version_table,
            column_name=column_name,
            operation="schema_update",
            status="column_added",
        )

    async def ensure_tracking_table(self, driver: "AsyncDriverAdapterBase") -> None:
        """Create the migration tracking table if it doesn't exist.

        Also checks for and adds any missing columns to support schema migrations.

        Args:
            driver: The database driver to use.
        """
        await driver.execute(self._get_create_table_sql())
        await self._safe_commit_async(driver)

        await self._migrate_schema_if_needed(driver)

    async def get_current_version(self, driver: "AsyncDriverAdapterBase") -> str | None:
        """Get the latest applied migration version.

        Args:
            driver: The database driver to use.

        Returns:
            The current version number or None if no migrations applied.
        """
        result = await driver.execute(self._get_current_version_sql())
        current = result.get_data()[0]["version_num"] if result.data else None
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.history",
            db_system=resolve_db_system(type(driver).__name__),
            current_version=current,
            status="current",
        )
        return current

    async def get_applied_migrations(self, driver: "AsyncDriverAdapterBase") -> "list[dict[str, Any]]":
        """Get all applied migrations in order.

        Args:
            driver: The database driver to use.

        Returns:
            List of migration records.
        """
        result = await driver.execute(self._get_applied_migrations_sql())
        applied = result.get_data()
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.history",
            db_system=resolve_db_system(type(driver).__name__),
            applied_count=len(applied),
            status="listed",
        )
        return applied

    async def record_migration(
        self, driver: "AsyncDriverAdapterBase", version: str, description: str, execution_time_ms: int, checksum: str
    ) -> None:
        """Record a successfully applied migration.

        Parses version to determine type (sequential or timestamp) and
        auto-increments execution_sequence for application order tracking.

        Args:
            driver: The database driver to use.
            version: Version number of the migration.
            description: Description of the migration.
            execution_time_ms: Execution time in milliseconds.
            checksum: MD5 checksum of the migration content.
        """
        parsed_version = parse_version(version)
        version_type = parsed_version.type.value

        # Fall back to sequence 1 when the sequence query returns no rows.
        result = await driver.execute(self._get_next_execution_sequence_sql())
        next_sequence = result.get_data()[0]["next_seq"] if result.data else 1

        await driver.execute(
            self._get_record_migration_sql(
                version,
                version_type,
                next_sequence,
                description,
                execution_time_ms,
                checksum,
                # Value of the USER environment variable, or "unknown" when unset.
                os.environ.get("USER", "unknown"),
            )
        )
        await self._safe_commit_async(driver)
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            version=version,
            operation="record",
            status="recorded",
        )

    async def remove_migration(self, driver: "AsyncDriverAdapterBase", version: str) -> None:
        """Remove a migration record (used during downgrade).

        Args:
            driver: The database driver to use.
            version: Version number to remove.
        """
        await driver.execute(self._get_remove_migration_sql(version))
        await self._safe_commit_async(driver)
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            version=version,
            operation="remove",
            status="removed",
        )

    async def update_version_record(self, driver: "AsyncDriverAdapterBase", old_version: str, new_version: str) -> None:
        """Update migration version record from timestamp to sequential.

        Updates version_num and version_type while preserving execution_sequence,
        applied_at, and other tracking metadata. Used during fix command.

        Idempotent: If the version is already updated, logs and continues without error.
        This allows fix command to be safely re-run after pulling changes.

        Args:
            driver: The database driver to use.
            old_version: Current timestamp version string.
            new_version: New sequential version string.

        Raises:
            ValueError: If neither old_version nor new_version found in database.
        """
        parsed_new_version = parse_version(new_version)
        new_version_type = parsed_new_version.type.value

        result = await driver.execute(self._get_update_version_sql(old_version, new_version, new_version_type))

        if result.rows_affected == 0:
            # Nothing was updated: either the rename already happened (new_version
            # is present) or old_version was never recorded.
            check_result = await driver.execute(self._get_applied_migrations_sql())
            applied_versions = {row["version_num"] for row in check_result.get_data()} if check_result.data else set()

            if new_version in applied_versions:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    old_version=old_version,
                    new_version=new_version,
                    operation="version_update",
                    status="skipped",
                )
                return

            msg = f"Migration version {old_version} not found in database"
            raise ValueError(msg)

        await self._safe_commit_async(driver)
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            old_version=old_version,
            new_version=new_version,
            operation="version_update",
            status="updated",
        )

    async def replace_with_squash(
        self,
        driver: "AsyncDriverAdapterBase",
        squashed_version: str,
        replaced_versions: "list[str]",
        description: str,
        checksum: str,
    ) -> None:
        """Replace multiple migration records with a single squashed record.

        Deletes all replaced version records and inserts a new record for the
        squashed migration with metadata about which versions it replaces.

        Args:
            driver: The database driver to use.
            squashed_version: Version number of the squashed migration.
            replaced_versions: List of version strings being replaced.
            description: Description of the squashed migration.
            checksum: MD5 checksum of the squashed migration content.
        """
        await driver.execute(self._get_delete_versions_sql(replaced_versions))

        # The sequence is read after the delete; falls back to 1 when no rows come back.
        result = await driver.execute(self._get_next_execution_sequence_sql())
        next_sequence = result.get_data()[0]["next_seq"] if result.data else 1

        parsed_version = parse_version(squashed_version)
        version_type = parsed_version.type.value

        # Replaced versions are stored as one comma-separated string.
        replaces_str = ",".join(replaced_versions)

        await driver.execute(
            self._get_record_squashed_migration_sql(
                squashed_version,
                version_type,
                next_sequence,
                description,
                # NOTE(review): 0 sits in the slot where record_migration passes
                # execution_time_ms; presumably no timing is recorded for a squash.
                0,
                checksum,
                os.environ.get("USER", "unknown"),
                replaces_str,
            )
        )
        await self._safe_commit_async(driver)

        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            squashed_version=squashed_version,
            replaced_count=len(replaced_versions),
            operation="squash",
            status="recorded",
        )

    async def is_squash_already_applied(
        self, driver: "AsyncDriverAdapterBase", squashed_version: str, replaced_versions: "list[str]"
    ) -> bool:
        """Check if a squash operation has already been applied.

        Determines if any of the replaced versions exist in the database,
        indicating that the original migrations were applied before the squash.

        Args:
            driver: The database driver to use.
            squashed_version: Version number of the squashed migration (unused but kept for API consistency).
            replaced_versions: List of version strings that would be replaced.

        Returns:
            True if any replaced version exists (squash already applied), False otherwise.
        """
        result = await driver.execute(self._get_applied_migrations_sql())
        applied_versions = {row["version_num"] for row in result.get_data()} if result.data else set()

        # Check if any replaced version exists in applied migrations
        return any(version in applied_versions for version in replaced_versions)

    async def _safe_commit_async(self, driver: "AsyncDriverAdapterBase") -> None:
        """Safely commit a transaction only if autocommit is disabled.

        Returns immediately when ``driver.driver_features`` reports a truthy
        ``autocommit`` entry. Otherwise awaits the commit; an exception whose
        lower-cased message contains "autocommit" or "cannot commit" is logged
        at DEBUG level and suppressed.

        Args:
            driver: The database driver to use.

        Raises:
            Exception: Re-raises non-autocommit related exceptions to prevent
                silent failures (e.g., SQLite isolation errors).
        """
        if driver.driver_features.get("autocommit", False):
            return

        try:
            await driver.commit()
        except Exception as exc:
            # Only suppress autocommit-related exceptions
            exc_str = str(exc).lower()
            if "autocommit" in exc_str or "cannot commit" in exc_str:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    operation="commit",
                    status="skipped",
                    reason="autocommit",
                    error_type=type(exc).__name__,
                )
            else:
                # Re-raise non-autocommit exceptions to prevent silent failures
                raise