"""Migration version tracking for SQLSpec.
This module provides functionality to track applied migrations in the database.
"""
import logging
import os
from collections.abc import Mapping
from contextlib import suppress
from typing import TYPE_CHECKING, Any
from rich.console import Console
from sqlspec.builder import sql
from sqlspec.migrations.base import BaseMigrationTracker
from sqlspec.migrations.version import parse_version
from sqlspec.observability import resolve_db_system
from sqlspec.utils.logging import get_logger, log_with_context
# Driver adapter types are only needed for annotations; importing them lazily
# under TYPE_CHECKING avoids a runtime import of sqlspec.driver from this module.
if TYPE_CHECKING:
    from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase

# Public API of this module.
__all__ = ("AsyncMigrationTracker", "SyncMigrationTracker")

# Module logger shared by both tracker implementations.
logger = get_logger("sqlspec.migrations.tracker")
def _extract_column_name(metadata: Any) -> "str | None":
"""Extract column name from a metadata entry."""
if isinstance(metadata, Mapping):
value = metadata.get("column_name")
if value is None:
value = metadata.get("COLUMN_NAME")
return str(value).lower() if value is not None else None
value = getattr(metadata, "column_name", None)
if value is not None:
return str(value).lower()
return None
class SyncMigrationTracker(BaseMigrationTracker["SyncDriverAdapterBase"]):
    """Synchronous migration version tracker.

    Reads and writes rows of the migration tracking table named by
    ``self.version_table`` through a synchronous driver adapter.
    """

    def _migrate_schema_if_needed(self, driver: "SyncDriverAdapterBase") -> None:
        """Check for and add any missing columns to the tracking table.

        Uses the adapter's data_dictionary to query existing columns, then
        compares them to the target schema and adds missing columns one by one.
        This is best effort. On failure the transaction is rolled back and the
        error type is logged at ERROR level, but the exception is not propagated.

        Args:
            driver: The database driver to use.
        """
        console = Console()
        try:
            columns_data = driver.data_dictionary.get_columns(driver, self.version_table)
            if not columns_data:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    table=self.version_table,
                    operation="table_check",
                    status="missing",
                )
                columns_data = []
            existing_columns = {name for col in columns_data if (name := _extract_column_name(col)) is not None}
            missing_columns = self._detect_missing_columns(existing_columns)
            if not missing_columns:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    table=self.version_table,
                    operation="schema_check",
                    status="current",
                )
                return
            if self._should_echo():
                console.print(
                    f"[cyan]Migrating tracking table schema, adding columns: {', '.join(sorted(missing_columns))}[/]"
                )
            for col_name in sorted(missing_columns):
                self._add_column(driver, col_name)
            # Commit through _safe_commit like every other write in this class.
            # A bare commit() on an autocommit connection can raise, which would
            # trigger a rollback and a failure log even though the columns were added.
            self._safe_commit(driver)
            if self._should_echo():
                console.print("[green]Migration tracking table schema updated successfully[/]")
        except Exception as exc:
            # Best effort: undo any partial work and record the failure without raising.
            with suppress(Exception):
                driver.rollback()
            log_with_context(
                logger,
                logging.ERROR,
                "migration.track",
                db_system=resolve_db_system(type(driver).__name__),
                table=self.version_table,
                operation="schema_check",
                status="failed",
                error_type=type(exc).__name__,
            )

    def _add_column(self, driver: "SyncDriverAdapterBase", column_name: str) -> None:
        """Add a single column to the tracking table.

        The column's type, default, and NOT NULL flag are copied from the target
        CREATE TABLE definition. Names not found there are silently ignored.

        Args:
            driver: The database driver to use.
            column_name: Name of the column to add (lowercase).
        """
        target_create = self._get_create_table_sql()
        column_def = next((col for col in target_create.columns if col.name.lower() == column_name), None)
        if not column_def:
            return
        alter_sql = sql.alter_table(self.version_table).add_column(
            name=column_def.name, dtype=column_def.dtype, default=column_def.default, not_null=column_def.not_null
        )
        driver.execute(alter_sql)
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            table=self.version_table,
            column_name=column_name,
            operation="schema_update",
            status="column_added",
        )

    def ensure_tracking_table(self, driver: "SyncDriverAdapterBase") -> None:
        """Create the migration tracking table if it doesn't exist.

        Also checks for and adds any missing columns to support schema migrations.

        Args:
            driver: The database driver to use.
        """
        driver.execute(self._get_create_table_sql())
        self._safe_commit(driver)
        self._migrate_schema_if_needed(driver)

    def get_current_version(self, driver: "SyncDriverAdapterBase") -> str | None:
        """Get the latest applied migration version.

        Args:
            driver: The database driver to use.

        Returns:
            The current version number or None if no migrations applied.
        """
        result = driver.execute(self._get_current_version_sql())
        current = result.get_data()[0]["version_num"] if result.data else None
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.history",
            db_system=resolve_db_system(type(driver).__name__),
            current_version=current,
            status="current",
        )
        return current

    def get_applied_migrations(self, driver: "SyncDriverAdapterBase") -> "list[dict[str, Any]]":
        """Get all applied migrations in order.

        Args:
            driver: The database driver to use.

        Returns:
            List of migration records.
        """
        result = driver.execute(self._get_applied_migrations_sql())
        applied = result.get_data()
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.history",
            db_system=resolve_db_system(type(driver).__name__),
            applied_count=len(applied),
            status="listed",
        )
        return applied

    def record_migration(
        self, driver: "SyncDriverAdapterBase", version: str, description: str, execution_time_ms: int, checksum: str
    ) -> None:
        """Record a successfully applied migration.

        Parses version to determine type (sequential or timestamp) and
        auto-increments execution_sequence for application order tracking.

        Args:
            driver: The database driver to use.
            version: Version number of the migration.
            description: Description of the migration.
            execution_time_ms: Execution time in milliseconds.
            checksum: MD5 checksum of the migration content.
        """
        parsed_version = parse_version(version)
        version_type = parsed_version.type.value
        result = driver.execute(self._get_next_execution_sequence_sql())
        # Fall back to 1 when the sequence query returns no rows.
        next_sequence = result.get_data()[0]["next_seq"] if result.data else 1
        driver.execute(
            self._get_record_migration_sql(
                version,
                version_type,
                next_sequence,
                description,
                execution_time_ms,
                checksum,
                os.environ.get("USER", "unknown"),
            )
        )
        self._safe_commit(driver)
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            version=version,
            operation="record",
            status="recorded",
        )

    def remove_migration(self, driver: "SyncDriverAdapterBase", version: str) -> None:
        """Remove a migration record (used during downgrade).

        Args:
            driver: The database driver to use.
            version: Version number to remove.
        """
        driver.execute(self._get_remove_migration_sql(version))
        self._safe_commit(driver)
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            version=version,
            operation="remove",
            status="removed",
        )

    def update_version_record(self, driver: "SyncDriverAdapterBase", old_version: str, new_version: str) -> None:
        """Update migration version record from timestamp to sequential.

        Updates version_num and version_type while preserving execution_sequence,
        applied_at, and other tracking metadata. Used during fix command.

        Idempotent: If the version is already updated, logs and continues without error.
        This allows fix command to be safely re-run after pulling changes.

        Args:
            driver: The database driver to use.
            old_version: Current timestamp version string.
            new_version: New sequential version string.

        Raises:
            ValueError: If neither old_version nor new_version found in database.
        """
        parsed_new_version = parse_version(new_version)
        new_version_type = parsed_new_version.type.value
        result = driver.execute(self._get_update_version_sql(old_version, new_version, new_version_type))
        if result.rows_affected == 0:
            # Nothing matched old_version. Treat the call as a no-op only if the
            # rename already happened on a previous run.
            check_result = driver.execute(self._get_applied_migrations_sql())
            applied_versions = {row["version_num"] for row in check_result.get_data()} if check_result.data else set()
            if new_version in applied_versions:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    old_version=old_version,
                    new_version=new_version,
                    operation="version_update",
                    status="skipped",
                )
                return
            msg = f"Migration version {old_version} not found in database"
            raise ValueError(msg)
        self._safe_commit(driver)
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            old_version=old_version,
            new_version=new_version,
            operation="version_update",
            status="updated",
        )

    def replace_with_squash(
        self,
        driver: "SyncDriverAdapterBase",
        squashed_version: str,
        replaced_versions: "list[str]",
        description: str,
        checksum: str,
    ) -> None:
        """Replace multiple migration records with a single squashed record.

        Deletes all replaced version records and inserts a new record for the
        squashed migration with metadata about which versions it replaces.
        The squashed version is parsed before anything is written, so an invalid
        version cannot leave the replaced rows deleted without a replacement.

        Args:
            driver: The database driver to use.
            squashed_version: Version number of the squashed migration.
            replaced_versions: List of version strings being replaced.
            description: Description of the squashed migration.
            checksum: MD5 checksum of the squashed migration content.

        Raises:
            Exception: Any error from parsing or the database is re-raised
                after a best-effort rollback.
        """
        # Validate and compute everything that can fail before mutating the table.
        parsed_version = parse_version(squashed_version)
        version_type = parsed_version.type.value
        replaces_str = ",".join(replaced_versions)
        try:
            # Skip the delete for an empty list. There is nothing to remove, and
            # an empty IN-list may not be valid SQL on every backend.
            if replaced_versions:
                driver.execute(self._get_delete_versions_sql(replaced_versions))
            result = driver.execute(self._get_next_execution_sequence_sql())
            next_sequence = result.get_data()[0]["next_seq"] if result.data else 1
            driver.execute(
                self._get_record_squashed_migration_sql(
                    squashed_version,
                    version_type,
                    next_sequence,
                    description,
                    0,
                    checksum,
                    os.environ.get("USER", "unknown"),
                    replaces_str,
                )
            )
            self._safe_commit(driver)
        except Exception:
            # Do not leave a pending delete behind without its replacement record.
            with suppress(Exception):
                driver.rollback()
            raise
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            squashed_version=squashed_version,
            replaced_count=len(replaced_versions),
            operation="squash",
            status="recorded",
        )

    def is_squash_already_applied(
        self, driver: "SyncDriverAdapterBase", squashed_version: str, replaced_versions: "list[str]"
    ) -> bool:
        """Check if a squash operation has already been applied.

        Determines if any of the replaced versions exist in the database,
        indicating that the original migrations were applied before the squash.

        Args:
            driver: The database driver to use.
            squashed_version: Version number of the squashed migration (unused but kept for API consistency).
            replaced_versions: List of version strings that would be replaced.

        Returns:
            True if any replaced version exists (squash already applied), False otherwise.
        """
        result = driver.execute(self._get_applied_migrations_sql())
        applied_versions = {row["version_num"] for row in result.get_data()} if result.data else set()
        # Check if any replaced version exists in applied migrations
        return any(version in applied_versions for version in replaced_versions)

    def _safe_commit(self, driver: "SyncDriverAdapterBase") -> None:
        """Safely commit a transaction only if autocommit is disabled.

        Args:
            driver: The database driver to use.

        Raises:
            Exception: Re-raises non-autocommit related exceptions to prevent
                silent failures (e.g., SQLite isolation errors).
        """
        if driver.driver_features.get("autocommit", False):
            return
        try:
            driver.commit()
        except Exception as exc:
            # Only suppress autocommit-related exceptions
            exc_str = str(exc).lower()
            if "autocommit" in exc_str or "cannot commit" in exc_str:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    operation="commit",
                    status="skipped",
                    reason="autocommit",
                    error_type=type(exc).__name__,
                )
            else:
                # Re-raise non-autocommit exceptions to prevent silent failures
                raise
class AsyncMigrationTracker(BaseMigrationTracker["AsyncDriverAdapterBase"]):
    """Asynchronous migration version tracker.

    Reads and writes rows of the migration tracking table named by
    ``self.version_table`` through an asynchronous driver adapter.
    """

    async def _migrate_schema_if_needed(self, driver: "AsyncDriverAdapterBase") -> None:
        """Check for and add any missing columns to the tracking table.

        Uses the driver's data_dictionary to query existing columns, then
        compares them to the target schema and adds missing columns one by one.
        This is best effort. On failure the transaction is rolled back and the
        error type is logged at ERROR level, but the exception is not propagated.

        Args:
            driver: The database driver to use.
        """
        console = Console()
        try:
            columns_data = await driver.data_dictionary.get_columns(driver, self.version_table)
            if not columns_data:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    table=self.version_table,
                    operation="table_check",
                    status="missing",
                )
                columns_data = []
            existing_columns = {name for col in columns_data if (name := _extract_column_name(col)) is not None}
            missing_columns = self._detect_missing_columns(existing_columns)
            if not missing_columns:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    table=self.version_table,
                    operation="schema_check",
                    status="current",
                )
                return
            if self._should_echo():
                console.print(
                    f"[cyan]Migrating tracking table schema, adding columns: {', '.join(sorted(missing_columns))}[/]"
                )
            for col_name in sorted(missing_columns):
                await self._add_column(driver, col_name)
            # Commit through _safe_commit_async like every other write in this class.
            # A bare commit() on an autocommit connection can raise, which would
            # trigger a rollback and a failure log even though the columns were added.
            await self._safe_commit_async(driver)
            if self._should_echo():
                console.print("[green]Migration tracking table schema updated successfully[/]")
        except Exception as exc:
            # Best effort: undo any partial work and record the failure without raising.
            with suppress(Exception):
                await driver.rollback()
            log_with_context(
                logger,
                logging.ERROR,
                "migration.track",
                db_system=resolve_db_system(type(driver).__name__),
                table=self.version_table,
                operation="schema_check",
                status="failed",
                error_type=type(exc).__name__,
            )

    async def _add_column(self, driver: "AsyncDriverAdapterBase", column_name: str) -> None:
        """Add a single column to the tracking table.

        The column's type, default, and NOT NULL flag are copied from the target
        CREATE TABLE definition. Names not found there are silently ignored.

        Args:
            driver: The database driver to use.
            column_name: Name of the column to add (lowercase).
        """
        target_create = self._get_create_table_sql()
        column_def = next((col for col in target_create.columns if col.name.lower() == column_name), None)
        if not column_def:
            return
        alter_sql = sql.alter_table(self.version_table).add_column(
            name=column_def.name, dtype=column_def.dtype, default=column_def.default, not_null=column_def.not_null
        )
        await driver.execute(alter_sql)
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            table=self.version_table,
            column_name=column_name,
            operation="schema_update",
            status="column_added",
        )

    async def ensure_tracking_table(self, driver: "AsyncDriverAdapterBase") -> None:
        """Create the migration tracking table if it doesn't exist.

        Also checks for and adds any missing columns to support schema migrations.

        Args:
            driver: The database driver to use.
        """
        await driver.execute(self._get_create_table_sql())
        await self._safe_commit_async(driver)
        await self._migrate_schema_if_needed(driver)

    async def get_current_version(self, driver: "AsyncDriverAdapterBase") -> str | None:
        """Get the latest applied migration version.

        Args:
            driver: The database driver to use.

        Returns:
            The current version number or None if no migrations applied.
        """
        result = await driver.execute(self._get_current_version_sql())
        current = result.get_data()[0]["version_num"] if result.data else None
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.history",
            db_system=resolve_db_system(type(driver).__name__),
            current_version=current,
            status="current",
        )
        return current

    async def get_applied_migrations(self, driver: "AsyncDriverAdapterBase") -> "list[dict[str, Any]]":
        """Get all applied migrations in order.

        Args:
            driver: The database driver to use.

        Returns:
            List of migration records.
        """
        result = await driver.execute(self._get_applied_migrations_sql())
        applied = result.get_data()
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.history",
            db_system=resolve_db_system(type(driver).__name__),
            applied_count=len(applied),
            status="listed",
        )
        return applied

    async def record_migration(
        self, driver: "AsyncDriverAdapterBase", version: str, description: str, execution_time_ms: int, checksum: str
    ) -> None:
        """Record a successfully applied migration.

        Parses version to determine type (sequential or timestamp) and
        auto-increments execution_sequence for application order tracking.

        Args:
            driver: The database driver to use.
            version: Version number of the migration.
            description: Description of the migration.
            execution_time_ms: Execution time in milliseconds.
            checksum: MD5 checksum of the migration content.
        """
        parsed_version = parse_version(version)
        version_type = parsed_version.type.value
        result = await driver.execute(self._get_next_execution_sequence_sql())
        # Fall back to 1 when the sequence query returns no rows.
        next_sequence = result.get_data()[0]["next_seq"] if result.data else 1
        await driver.execute(
            self._get_record_migration_sql(
                version,
                version_type,
                next_sequence,
                description,
                execution_time_ms,
                checksum,
                os.environ.get("USER", "unknown"),
            )
        )
        await self._safe_commit_async(driver)
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            version=version,
            operation="record",
            status="recorded",
        )

    async def remove_migration(self, driver: "AsyncDriverAdapterBase", version: str) -> None:
        """Remove a migration record (used during downgrade).

        Args:
            driver: The database driver to use.
            version: Version number to remove.
        """
        await driver.execute(self._get_remove_migration_sql(version))
        await self._safe_commit_async(driver)
        log_with_context(
            logger,
            logging.DEBUG,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            version=version,
            operation="remove",
            status="removed",
        )

    async def update_version_record(self, driver: "AsyncDriverAdapterBase", old_version: str, new_version: str) -> None:
        """Update migration version record from timestamp to sequential.

        Updates version_num and version_type while preserving execution_sequence,
        applied_at, and other tracking metadata. Used during fix command.

        Idempotent: If the version is already updated, logs and continues without error.
        This allows fix command to be safely re-run after pulling changes.

        Args:
            driver: The database driver to use.
            old_version: Current timestamp version string.
            new_version: New sequential version string.

        Raises:
            ValueError: If neither old_version nor new_version found in database.
        """
        parsed_new_version = parse_version(new_version)
        new_version_type = parsed_new_version.type.value
        result = await driver.execute(self._get_update_version_sql(old_version, new_version, new_version_type))
        if result.rows_affected == 0:
            # Nothing matched old_version. Treat the call as a no-op only if the
            # rename already happened on a previous run.
            check_result = await driver.execute(self._get_applied_migrations_sql())
            applied_versions = {row["version_num"] for row in check_result.get_data()} if check_result.data else set()
            if new_version in applied_versions:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    old_version=old_version,
                    new_version=new_version,
                    operation="version_update",
                    status="skipped",
                )
                return
            msg = f"Migration version {old_version} not found in database"
            raise ValueError(msg)
        await self._safe_commit_async(driver)
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            old_version=old_version,
            new_version=new_version,
            operation="version_update",
            status="updated",
        )

    async def replace_with_squash(
        self,
        driver: "AsyncDriverAdapterBase",
        squashed_version: str,
        replaced_versions: "list[str]",
        description: str,
        checksum: str,
    ) -> None:
        """Replace multiple migration records with a single squashed record.

        Deletes all replaced version records and inserts a new record for the
        squashed migration with metadata about which versions it replaces.
        The squashed version is parsed before anything is written, so an invalid
        version cannot leave the replaced rows deleted without a replacement.

        Args:
            driver: The database driver to use.
            squashed_version: Version number of the squashed migration.
            replaced_versions: List of version strings being replaced.
            description: Description of the squashed migration.
            checksum: MD5 checksum of the squashed migration content.

        Raises:
            Exception: Any error from parsing or the database is re-raised
                after a best-effort rollback.
        """
        # Validate and compute everything that can fail before mutating the table.
        parsed_version = parse_version(squashed_version)
        version_type = parsed_version.type.value
        replaces_str = ",".join(replaced_versions)
        try:
            # Skip the delete for an empty list. There is nothing to remove, and
            # an empty IN-list may not be valid SQL on every backend.
            if replaced_versions:
                await driver.execute(self._get_delete_versions_sql(replaced_versions))
            result = await driver.execute(self._get_next_execution_sequence_sql())
            next_sequence = result.get_data()[0]["next_seq"] if result.data else 1
            await driver.execute(
                self._get_record_squashed_migration_sql(
                    squashed_version,
                    version_type,
                    next_sequence,
                    description,
                    0,
                    checksum,
                    os.environ.get("USER", "unknown"),
                    replaces_str,
                )
            )
            await self._safe_commit_async(driver)
        except Exception:
            # Do not leave a pending delete behind without its replacement record.
            with suppress(Exception):
                await driver.rollback()
            raise
        log_with_context(
            logger,
            logging.INFO,
            "migration.track",
            db_system=resolve_db_system(type(driver).__name__),
            squashed_version=squashed_version,
            replaced_count=len(replaced_versions),
            operation="squash",
            status="recorded",
        )

    async def is_squash_already_applied(
        self, driver: "AsyncDriverAdapterBase", squashed_version: str, replaced_versions: "list[str]"
    ) -> bool:
        """Check if a squash operation has already been applied.

        Determines if any of the replaced versions exist in the database,
        indicating that the original migrations were applied before the squash.

        Args:
            driver: The database driver to use.
            squashed_version: Version number of the squashed migration (unused but kept for API consistency).
            replaced_versions: List of version strings that would be replaced.

        Returns:
            True if any replaced version exists (squash already applied), False otherwise.
        """
        result = await driver.execute(self._get_applied_migrations_sql())
        applied_versions = {row["version_num"] for row in result.get_data()} if result.data else set()
        # Check if any replaced version exists in applied migrations
        return any(version in applied_versions for version in replaced_versions)

    async def _safe_commit_async(self, driver: "AsyncDriverAdapterBase") -> None:
        """Safely commit a transaction only if autocommit is disabled.

        Args:
            driver: The database driver to use.

        Raises:
            Exception: Re-raises non-autocommit related exceptions to prevent
                silent failures (e.g., SQLite isolation errors).
        """
        if driver.driver_features.get("autocommit", False):
            return
        try:
            await driver.commit()
        except Exception as exc:
            # Only suppress autocommit-related exceptions
            exc_str = str(exc).lower()
            if "autocommit" in exc_str or "cannot commit" in exc_str:
                log_with_context(
                    logger,
                    logging.DEBUG,
                    "migration.track",
                    db_system=resolve_db_system(type(driver).__name__),
                    operation="commit",
                    status="skipped",
                    reason="autocommit",
                    error_type=type(exc).__name__,
                )
            else:
                # Re-raise non-autocommit exceptions to prevent silent failures
                raise