"""Migration version tracking for SQLSpec.
This module provides functionality to track applied migrations in the database.
"""
import logging
import os
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from rich.console import Console
from sqlspec.builder import sql
from sqlspec.migrations.base import BaseMigrationTracker
from sqlspec.migrations.version import parse_version
from sqlspec.observability import resolve_db_system
from sqlspec.utils.logging import get_logger, log_with_context
if TYPE_CHECKING:
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
# Public API of this module: the asynchronous and synchronous tracker classes.
__all__ = ("AsyncMigrationTracker", "SyncMigrationTracker")
# Module-level logger used by both tracker classes for structured log events.
logger = get_logger("sqlspec.migrations.tracker")
def _extract_column_name(metadata: Any) -> "str | None":
"""Extract column name from a metadata entry."""
if isinstance(metadata, Mapping):
value = metadata.get("column_name")
if value is None:
value = metadata.get("COLUMN_NAME")
return str(value).lower() if value is not None else None
value = getattr(metadata, "column_name", None)
if value is not None:
return str(value).lower()
return None
[docs]
class SyncMigrationTracker(BaseMigrationTracker["SyncDriverAdapterBase"]):
    """Synchronous migration version tracker.

    Executes the SQL statements built by the base tracker through a
    synchronous driver and emits a structured log event for each operation.
    """

    def _migrate_schema_if_needed(self, driver: "SyncDriverAdapterBase") -> None:
        """Check for and add any missing columns to the tracking table.

        Uses the driver's data_dictionary to query existing columns,
        then compares to the target schema and adds missing columns one by one.

        This is best-effort: any exception raised while inspecting or altering
        the table is logged at ERROR level and not propagated.

        Args:
            driver: The database driver to use.
        """
        console = Console()
        try:
            columns_data = driver.data_dictionary.get_columns(driver, self.version_table)
            if not columns_data:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    table=self.version_table,
                    operation="table_check",
                    status="missing",
                )
                # Normalise a falsy result so the comprehension below can iterate it.
                columns_data = []

            existing_columns = {name for col in columns_data if (name := _extract_column_name(col)) is not None}
            missing_columns = self._detect_missing_columns(existing_columns)
            if not missing_columns:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    table=self.version_table,
                    operation="schema_check",
                    status="current",
                )
                return

            if self._should_echo():
                console.print(
                    f"[cyan]Migrating tracking table schema, adding columns: {', '.join(sorted(missing_columns))}[/]"
                )

            # Sorted so columns are added in a deterministic order.
            for col_name in sorted(missing_columns):
                self._add_column(driver, col_name)

            # Commit through the autocommit-aware helper, like every other write
            # in this class; a bare commit here could raise on an autocommit
            # driver and be misreported below as a failed schema check.
            self._safe_commit(driver)
            if self._should_echo():
                console.print("[green]Migration tracking table schema updated successfully[/]")
        except Exception as exc:
            log_with_context(
                logger,
                logging.ERROR,
                "migration.track",
                db_system=resolve_db_system(type(driver).__name__),
                table=self.version_table,
                operation="schema_check",
                status="failed",
                error_type=type(exc).__name__,
            )

    def _add_column(self, driver: "SyncDriverAdapterBase", column_name: str) -> None:
        """Add a single column to the tracking table.

        The column definition is looked up (case-insensitively) in the target
        CREATE TABLE statement; if no matching definition exists the call is
        a no-op.

        Args:
            driver: The database driver to use.
            column_name: Name of the column to add (lowercase).
        """
        target_create = self._get_create_table_sql()
        column_def = next((col for col in target_create.columns if col.name.lower() == column_name), None)
        if not column_def:
            return

        alter_sql = sql.alter_table(self.version_table).add_column(
            name=column_def.name, dtype=column_def.dtype, default=column_def.default, not_null=column_def.not_null
        )
        driver.execute(alter_sql)
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            table=self.version_table,
            column_name=column_name,
            operation="schema_update",
            status="column_added",
        )

    def ensure_tracking_table(self, driver: "SyncDriverAdapterBase") -> None:
        """Create the migration tracking table if it doesn't exist.

        Also checks for and adds any missing columns to support schema migrations.

        Args:
            driver: The database driver to use.
        """
        driver.execute(self._get_create_table_sql())
        self._safe_commit(driver)
        self._migrate_schema_if_needed(driver)

    def get_current_version(self, driver: "SyncDriverAdapterBase") -> str | None:
        """Get the latest applied migration version.

        Args:
            driver: The database driver to use.

        Returns:
            The current version number or None if no migrations applied.
        """
        result = driver.execute(self._get_current_version_sql())
        current = result.get_data()[0]["version_num"] if result.data else None
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.history",
            db_system=resolve_db_system(type(driver).__name__),
            current_version=current,
            status="current",
        )
        return current

    def get_applied_migrations(self, driver: "SyncDriverAdapterBase") -> "list[dict[str, Any]]":
        """Get all applied migrations in order.

        Args:
            driver: The database driver to use.

        Returns:
            List of migration records.
        """
        result = driver.execute(self._get_applied_migrations_sql())
        applied = result.get_data()
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.history",
            db_system=resolve_db_system(type(driver).__name__),
            applied_count=len(applied),
            status="listed",
        )
        return applied

    def record_migration(
        self, driver: "SyncDriverAdapterBase", version: str, description: str, execution_time_ms: int, checksum: str
    ) -> None:
        """Record a successfully applied migration.

        Parses version to determine type (sequential or timestamp) and
        auto-increments execution_sequence for application order tracking.

        Args:
            driver: The database driver to use.
            version: Version number of the migration.
            description: Description of the migration.
            execution_time_ms: Execution time in milliseconds.
            checksum: MD5 checksum of the migration content.
        """
        parsed_version = parse_version(version)
        version_type = parsed_version.type.value

        result = driver.execute(self._get_next_execution_sequence_sql())
        # Start the sequence at 1 when the query yields no rows.
        next_sequence = result.get_data()[0]["next_seq"] if result.data else 1

        driver.execute(
            self._get_record_migration_sql(
                version,
                version_type,
                next_sequence,
                description,
                execution_time_ms,
                checksum,
                # Recorded from the USER environment variable; "unknown" when unset.
                os.environ.get("USER", "unknown"),
            )
        )
        self._safe_commit(driver)
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            version=version,
            operation="record",
            status="recorded",
        )

    def remove_migration(self, driver: "SyncDriverAdapterBase", version: str) -> None:
        """Remove a migration record (used during downgrade).

        Args:
            driver: The database driver to use.
            version: Version number to remove.
        """
        driver.execute(self._get_remove_migration_sql(version))
        self._safe_commit(driver)
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            version=version,
            operation="remove",
            status="removed",
        )

    def update_version_record(self, driver: "SyncDriverAdapterBase", old_version: str, new_version: str) -> None:
        """Update migration version record from timestamp to sequential.

        Updates version_num and version_type while preserving execution_sequence,
        applied_at, and other tracking metadata. Used during fix command.

        Idempotent: If the version is already updated, logs and continues without error.
        This allows fix command to be safely re-run after pulling changes.

        Args:
            driver: The database driver to use.
            old_version: Current timestamp version string.
            new_version: New sequential version string.

        Raises:
            ValueError: If neither old_version nor new_version found in database.
        """
        parsed_new_version = parse_version(new_version)
        new_version_type = parsed_new_version.type.value

        result = driver.execute(self._get_update_version_sql(old_version, new_version, new_version_type))

        if result.rows_affected == 0:
            # Nothing was updated: either the rename already happened
            # (new_version present) or old_version was never recorded.
            check_result = driver.execute(self._get_applied_migrations_sql())
            applied_versions = {row["version_num"] for row in check_result.get_data()} if check_result.data else set()

            if new_version in applied_versions:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    old_version=old_version,
                    new_version=new_version,
                    operation="version_update",
                    status="skipped",
                )
                return

            msg = f"Migration version {old_version} not found in database"
            raise ValueError(msg)

        self._safe_commit(driver)
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            old_version=old_version,
            new_version=new_version,
            operation="version_update",
            status="updated",
        )

    def replace_with_squash(
        self,
        driver: "SyncDriverAdapterBase",
        squashed_version: str,
        replaced_versions: "list[str]",
        description: str,
        checksum: str,
    ) -> None:
        """Replace multiple migration records with a single squashed record.

        Deletes all replaced version records and inserts a new record for the
        squashed migration with metadata about which versions it replaces.
        The delete and the insert are committed together at the end.

        Args:
            driver: The database driver to use.
            squashed_version: Version number of the squashed migration.
            replaced_versions: List of version strings being replaced.
            description: Description of the squashed migration.
            checksum: MD5 checksum of the squashed migration content.
        """
        driver.execute(self._get_delete_versions_sql(replaced_versions))

        # Computed after the delete, so the sequence reflects the remaining rows.
        result = driver.execute(self._get_next_execution_sequence_sql())
        next_sequence = result.get_data()[0]["next_seq"] if result.data else 1

        parsed_version = parse_version(squashed_version)
        version_type = parsed_version.type.value
        replaces_str = ",".join(replaced_versions)

        driver.execute(
            self._get_record_squashed_migration_sql(
                squashed_version,
                version_type,
                next_sequence,
                description,
                0,  # execution time: the squashed record is inserted without running anything here
                checksum,
                os.environ.get("USER", "unknown"),
                replaces_str,
            )
        )
        self._safe_commit(driver)
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            squashed_version=squashed_version,
            replaced_count=len(replaced_versions),
            operation="squash",
            status="recorded",
        )

    def is_squash_already_applied(
        self, driver: "SyncDriverAdapterBase", squashed_version: str, replaced_versions: "list[str]"
    ) -> bool:
        """Check if a squash operation has already been applied.

        Determines if any of the replaced versions exist in the database,
        indicating that the original migrations were applied before the squash.

        Args:
            driver: The database driver to use.
            squashed_version: Version number of the squashed migration (unused but kept for API consistency).
            replaced_versions: List of version strings that would be replaced.

        Returns:
            True if any replaced version exists (squash already applied), False otherwise.
        """
        result = driver.execute(self._get_applied_migrations_sql())
        applied_versions = {row["version_num"] for row in result.get_data()} if result.data else set()
        # Check if any replaced version exists in applied migrations
        return any(version in applied_versions for version in replaced_versions)

    def _safe_commit(self, driver: "SyncDriverAdapterBase") -> None:
        """Safely commit a transaction only if autocommit is disabled.

        Skips the commit entirely when ``driver.driver_features`` reports
        ``autocommit``. Otherwise commits, and suppresses only those errors
        whose message mentions "autocommit" or "cannot commit".

        Args:
            driver: The database driver to use.

        Raises:
            Exception: Re-raises non-autocommit related exceptions to prevent
                silent failures (e.g., SQLite isolation errors).
        """
        if driver.driver_features.get("autocommit", False):
            return

        try:
            driver.commit()
        except Exception as exc:
            # Only suppress autocommit-related exceptions
            exc_str = str(exc).lower()
            if "autocommit" in exc_str or "cannot commit" in exc_str:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    operation="commit",
                    status="skipped",
                    reason="autocommit",
                    error_type=type(exc).__name__,
                )
            else:
                # Re-raise non-autocommit exceptions to prevent silent failures
                raise
[docs]
class AsyncMigrationTracker(BaseMigrationTracker["AsyncDriverAdapterBase"]):
    """Asynchronous migration version tracker.

    Executes the SQL statements built by the base tracker through an
    asynchronous driver and emits a structured log event for each operation.
    """

    async def _migrate_schema_if_needed(self, driver: "AsyncDriverAdapterBase") -> None:
        """Check for and add any missing columns to the tracking table.

        Uses the driver's data_dictionary to query existing columns,
        then compares to the target schema and adds missing columns one by one.

        This is best-effort: any exception raised while inspecting or altering
        the table is logged at ERROR level and not propagated.

        Args:
            driver: The database driver to use.
        """
        console = Console()
        try:
            columns_data = await driver.data_dictionary.get_columns(driver, self.version_table)
            if not columns_data:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    table=self.version_table,
                    operation="table_check",
                    status="missing",
                )
                # Normalise a falsy result so the comprehension below can iterate it.
                columns_data = []

            existing_columns = {name for col in columns_data if (name := _extract_column_name(col)) is not None}
            missing_columns = self._detect_missing_columns(existing_columns)
            if not missing_columns:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    table=self.version_table,
                    operation="schema_check",
                    status="current",
                )
                return

            if self._should_echo():
                console.print(
                    f"[cyan]Migrating tracking table schema, adding columns: {', '.join(sorted(missing_columns))}[/]"
                )

            # Sorted so columns are added in a deterministic order.
            for col_name in sorted(missing_columns):
                await self._add_column(driver, col_name)

            # Commit through the autocommit-aware helper, like every other write
            # in this class; a bare commit here could raise on an autocommit
            # driver and be misreported below as a failed schema check.
            await self._safe_commit_async(driver)
            if self._should_echo():
                console.print("[green]Migration tracking table schema updated successfully[/]")
        except Exception as exc:
            log_with_context(
                logger,
                logging.ERROR,
                "migration.track",
                db_system=resolve_db_system(type(driver).__name__),
                table=self.version_table,
                operation="schema_check",
                status="failed",
                error_type=type(exc).__name__,
            )

    async def _add_column(self, driver: "AsyncDriverAdapterBase", column_name: str) -> None:
        """Add a single column to the tracking table.

        The column definition is looked up (case-insensitively) in the target
        CREATE TABLE statement; if no matching definition exists the call is
        a no-op.

        Args:
            driver: The database driver to use.
            column_name: Name of the column to add (lowercase).
        """
        target_create = self._get_create_table_sql()
        column_def = next((col for col in target_create.columns if col.name.lower() == column_name), None)
        if not column_def:
            return

        alter_sql = sql.alter_table(self.version_table).add_column(
            name=column_def.name, dtype=column_def.dtype, default=column_def.default, not_null=column_def.not_null
        )
        await driver.execute(alter_sql)
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            table=self.version_table,
            column_name=column_name,
            operation="schema_update",
            status="column_added",
        )

    async def ensure_tracking_table(self, driver: "AsyncDriverAdapterBase") -> None:
        """Create the migration tracking table if it doesn't exist.

        Also checks for and adds any missing columns to support schema migrations.

        Args:
            driver: The database driver to use.
        """
        await driver.execute(self._get_create_table_sql())
        await self._safe_commit_async(driver)
        await self._migrate_schema_if_needed(driver)

    async def get_current_version(self, driver: "AsyncDriverAdapterBase") -> str | None:
        """Get the latest applied migration version.

        Args:
            driver: The database driver to use.

        Returns:
            The current version number or None if no migrations applied.
        """
        result = await driver.execute(self._get_current_version_sql())
        current = result.get_data()[0]["version_num"] if result.data else None
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.history",
            db_system=resolve_db_system(type(driver).__name__),
            current_version=current,
            status="current",
        )
        return current

    async def get_applied_migrations(self, driver: "AsyncDriverAdapterBase") -> "list[dict[str, Any]]":
        """Get all applied migrations in order.

        Args:
            driver: The database driver to use.

        Returns:
            List of migration records.
        """
        result = await driver.execute(self._get_applied_migrations_sql())
        applied = result.get_data()
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.history",
            db_system=resolve_db_system(type(driver).__name__),
            applied_count=len(applied),
            status="listed",
        )
        return applied

    async def record_migration(
        self, driver: "AsyncDriverAdapterBase", version: str, description: str, execution_time_ms: int, checksum: str
    ) -> None:
        """Record a successfully applied migration.

        Parses version to determine type (sequential or timestamp) and
        auto-increments execution_sequence for application order tracking.

        Args:
            driver: The database driver to use.
            version: Version number of the migration.
            description: Description of the migration.
            execution_time_ms: Execution time in milliseconds.
            checksum: MD5 checksum of the migration content.
        """
        parsed_version = parse_version(version)
        version_type = parsed_version.type.value

        result = await driver.execute(self._get_next_execution_sequence_sql())
        # Start the sequence at 1 when the query yields no rows.
        next_sequence = result.get_data()[0]["next_seq"] if result.data else 1

        await driver.execute(
            self._get_record_migration_sql(
                version,
                version_type,
                next_sequence,
                description,
                execution_time_ms,
                checksum,
                # Recorded from the USER environment variable; "unknown" when unset.
                os.environ.get("USER", "unknown"),
            )
        )
        await self._safe_commit_async(driver)
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            version=version,
            operation="record",
            status="recorded",
        )

    async def remove_migration(self, driver: "AsyncDriverAdapterBase", version: str) -> None:
        """Remove a migration record (used during downgrade).

        Args:
            driver: The database driver to use.
            version: Version number to remove.
        """
        await driver.execute(self._get_remove_migration_sql(version))
        await self._safe_commit_async(driver)
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            version=version,
            operation="remove",
            status="removed",
        )

    async def update_version_record(self, driver: "AsyncDriverAdapterBase", old_version: str, new_version: str) -> None:
        """Update migration version record from timestamp to sequential.

        Updates version_num and version_type while preserving execution_sequence,
        applied_at, and other tracking metadata. Used during fix command.

        Idempotent: If the version is already updated, logs and continues without error.
        This allows fix command to be safely re-run after pulling changes.

        Args:
            driver: The database driver to use.
            old_version: Current timestamp version string.
            new_version: New sequential version string.

        Raises:
            ValueError: If neither old_version nor new_version found in database.
        """
        parsed_new_version = parse_version(new_version)
        new_version_type = parsed_new_version.type.value

        result = await driver.execute(self._get_update_version_sql(old_version, new_version, new_version_type))

        if result.rows_affected == 0:
            # Nothing was updated: either the rename already happened
            # (new_version present) or old_version was never recorded.
            check_result = await driver.execute(self._get_applied_migrations_sql())
            applied_versions = {row["version_num"] for row in check_result.get_data()} if check_result.data else set()

            if new_version in applied_versions:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    old_version=old_version,
                    new_version=new_version,
                    operation="version_update",
                    status="skipped",
                )
                return

            msg = f"Migration version {old_version} not found in database"
            raise ValueError(msg)

        await self._safe_commit_async(driver)
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            old_version=old_version,
            new_version=new_version,
            operation="version_update",
            status="updated",
        )

    async def replace_with_squash(
        self,
        driver: "AsyncDriverAdapterBase",
        squashed_version: str,
        replaced_versions: "list[str]",
        description: str,
        checksum: str,
    ) -> None:
        """Replace multiple migration records with a single squashed record.

        Deletes all replaced version records and inserts a new record for the
        squashed migration with metadata about which versions it replaces.
        The delete and the insert are committed together at the end.

        Args:
            driver: The database driver to use.
            squashed_version: Version number of the squashed migration.
            replaced_versions: List of version strings being replaced.
            description: Description of the squashed migration.
            checksum: MD5 checksum of the squashed migration content.
        """
        await driver.execute(self._get_delete_versions_sql(replaced_versions))

        # Computed after the delete, so the sequence reflects the remaining rows.
        result = await driver.execute(self._get_next_execution_sequence_sql())
        next_sequence = result.get_data()[0]["next_seq"] if result.data else 1

        parsed_version = parse_version(squashed_version)
        version_type = parsed_version.type.value
        replaces_str = ",".join(replaced_versions)

        await driver.execute(
            self._get_record_squashed_migration_sql(
                squashed_version,
                version_type,
                next_sequence,
                description,
                0,  # execution time: the squashed record is inserted without running anything here
                checksum,
                os.environ.get("USER", "unknown"),
                replaces_str,
            )
        )
        await self._safe_commit_async(driver)
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            squashed_version=squashed_version,
            replaced_count=len(replaced_versions),
            operation="squash",
            status="recorded",
        )

    async def is_squash_already_applied(
        self, driver: "AsyncDriverAdapterBase", squashed_version: str, replaced_versions: "list[str]"
    ) -> bool:
        """Check if a squash operation has already been applied.

        Determines if any of the replaced versions exist in the database,
        indicating that the original migrations were applied before the squash.

        Args:
            driver: The database driver to use.
            squashed_version: Version number of the squashed migration (unused but kept for API consistency).
            replaced_versions: List of version strings that would be replaced.

        Returns:
            True if any replaced version exists (squash already applied), False otherwise.
        """
        result = await driver.execute(self._get_applied_migrations_sql())
        applied_versions = {row["version_num"] for row in result.get_data()} if result.data else set()
        # Check if any replaced version exists in applied migrations
        return any(version in applied_versions for version in replaced_versions)

    async def _safe_commit_async(self, driver: "AsyncDriverAdapterBase") -> None:
        """Safely commit a transaction only if autocommit is disabled.

        Skips the commit entirely when ``driver.driver_features`` reports
        ``autocommit``. Otherwise commits, and suppresses only those errors
        whose message mentions "autocommit" or "cannot commit".

        Args:
            driver: The database driver to use.

        Raises:
            Exception: Re-raises non-autocommit related exceptions to prevent
                silent failures (e.g., SQLite isolation errors).
        """
        if driver.driver_features.get("autocommit", False):
            return

        try:
            await driver.commit()
        except Exception as exc:
            # Only suppress autocommit-related exceptions
            exc_str = str(exc).lower()
            if "autocommit" in exc_str or "cannot commit" in exc_str:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    operation="commit",
                    status="skipped",
                    reason="autocommit",
                    error_type=type(exc).__name__,
                )
            else:
                # Re-raise non-autocommit exceptions to prevent silent failures
                raise