"""Migration command implementations for SQLSpec.
This module provides the main command interface for database migrations.
"""
import functools
import inspect
import logging
import time
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
from rich.console import Console
from rich.table import Table
from sqlspec.builder import sql
from sqlspec.migrations.base import BaseMigrationCommands
from sqlspec.migrations.context import MigrationContext
from sqlspec.migrations.fix import MigrationFixer
from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner
from sqlspec.migrations.squash import MigrationSquasher
from sqlspec.migrations.utils import create_migration_file
from sqlspec.migrations.validation import validate_migration_order
from sqlspec.migrations.version import generate_conversion_map, generate_timestamp_version, parse_version
from sqlspec.observability import resolve_db_system
from sqlspec.utils.logging import get_logger, log_with_context
if TYPE_CHECKING:
from pathlib import Path
from sqlspec.config import AsyncConfigT, SyncConfigT
__all__ = ("AsyncMigrationCommands", "SyncMigrationCommands", "create_migration_commands")
logger = get_logger("sqlspec.migrations.commands")
console = Console()
P = ParamSpec("P")
R = TypeVar("R")
def _output_info(
use_logger: bool, echo: bool, summary_only: bool, message: str, *args: Any, rich_message: str | None = None
) -> None:
"""Output an info message to logger or console."""
if use_logger:
if summary_only:
return
logger.info(message, *args)
else:
if not echo:
return
console.print(rich_message or message % args if args else message)
def _output_warning(
use_logger: bool, echo: bool, summary_only: bool, message: str, *args: Any, rich_message: str | None = None
) -> None:
"""Output a warning message to logger or console."""
if use_logger:
logger.warning(message, *args)
else:
if not echo:
return
console.print(rich_message or message % args if args else message)
def _output_error(
use_logger: bool, echo: bool, summary_only: bool, message: str, *args: Any, rich_message: str | None = None
) -> None:
"""Output an error message to logger or console."""
if use_logger:
logger.error(message, *args)
else:
if not echo:
return
console.print(rich_message or message % args if args else message)
def _output_exception(
use_logger: bool, echo: bool, summary_only: bool, message: str, *args: Any, rich_message: str | None = None
) -> None:
"""Output an exception message to logger or console."""
if use_logger:
logger.exception(message, *args)
else:
if not echo:
return
console.print(rich_message or message % args if args else message)
def _log_command_summary(
*,
use_logger: bool,
summary_only: bool,
command: str,
status: str,
revision: str,
dry_run: bool,
pending_count: int,
applied_count: int | None,
reverted_count: int | None,
duration_ms: int,
db_system: str | None,
bind_key: str | None,
config_name: str,
error: Exception | None = None,
allow_missing: bool | None = None,
auto_sync: bool | None = None,
) -> None:
"""Emit a single summary log entry for migration commands."""
if not use_logger or not summary_only:
return
level = logging.ERROR if status == "failed" else logging.INFO
extra_fields: dict[str, Any] = {
"command": command,
"status": status,
"revision": revision,
"dry_run": dry_run,
"pending_count": pending_count,
"duration_ms": duration_ms,
"db_system": db_system,
"bind_key": bind_key,
"config_name": config_name,
}
if applied_count is not None:
extra_fields["applied_count"] = applied_count
if reverted_count is not None:
extra_fields["reverted_count"] = reverted_count
if allow_missing is not None:
extra_fields["allow_missing"] = allow_missing
if auto_sync is not None:
extra_fields["auto_sync"] = auto_sync
if error is not None:
extra_fields["error_type"] = type(error).__name__
log_with_context(logger, level, "migration.command.summary", **extra_fields)
MetadataBuilder = Callable[[dict[str, Any]], tuple[str | None, dict[str, Any]]]
def _bind_arguments(signature: inspect.Signature, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]:
bound = signature.bind_partial(*args, **kwargs)
arguments = dict(bound.arguments)
arguments.pop("self", None)
return arguments
def _with_command_span(
event: str, metadata_fn: "MetadataBuilder | None" = None, *, dry_run_param: str | None = "dry_run"
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Attach span lifecycle and command metric management to command methods."""
metric_prefix = f"migrations.command.{event}"
def decorator(func: Callable[P, R]) -> Callable[P, R]:
signature = inspect.signature(func)
def _prepare(self: Any, args: tuple[Any, ...], kwargs: dict[str, Any]) -> tuple[Any, bool, Any]:
runtime = self._runtime
metadata_args = _bind_arguments(signature, args, kwargs)
dry_run = False
if dry_run_param is not None:
dry_run = bool(metadata_args.get(dry_run_param, False))
metadata: dict[str, Any] | None = None
version: str | None = None
span = None
if runtime is not None:
runtime.increment_metric(f"{metric_prefix}.invocations")
if dry_run_param is not None and dry_run:
runtime.increment_metric(f"{metric_prefix}.dry_run")
if metadata_fn is not None:
version, metadata = metadata_fn(metadata_args)
span = runtime.start_migration_span(f"command.{event}", version=version, metadata=metadata)
return runtime, dry_run, span
def _finalize(
self: Any,
runtime: Any,
span: Any,
start: float,
error: "Exception | None",
recorded_error: bool,
dry_run: bool,
) -> None:
command_error = self._last_command_error
self._last_command_error = None
command_metrics = self._last_command_metrics
self._last_command_metrics = None
if runtime is None:
return
if command_error is not None and not recorded_error:
runtime.increment_metric(f"{metric_prefix}.errors")
if not dry_run and command_metrics:
for metric, value in command_metrics.items():
runtime.increment_metric(f"{metric_prefix}.{metric}", value)
duration_ms = int((time.perf_counter() - start) * 1000)
runtime.end_migration_span(span, duration_ms=duration_ms, error=error or command_error)
if inspect.iscoroutinefunction(func):
@functools.wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
self = args[0]
runtime, dry_run, span = _prepare(self, args, kwargs)
start = time.perf_counter()
error: Exception | None = None
error_recorded = False
try:
async_func = cast("Callable[P, Awaitable[R]]", func)
return await async_func(*args, **kwargs)
except Exception as exc: # pragma: no cover - passthrough
error = exc
if runtime is not None:
runtime.increment_metric(f"{metric_prefix}.errors")
error_recorded = True
raise
finally:
_finalize(self, runtime, span, start, error, error_recorded, dry_run)
return cast("Callable[P, R]", async_wrapper)
@functools.wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
self = args[0]
runtime, dry_run, span = _prepare(self, args, kwargs)
start = time.perf_counter()
error: Exception | None = None
error_recorded = False
try:
return func(*args, **kwargs)
except Exception as exc: # pragma: no cover - passthrough
error = exc
if runtime is not None:
runtime.increment_metric(f"{metric_prefix}.errors")
error_recorded = True
raise
finally:
_finalize(self, runtime, span, start, error, error_recorded, dry_run)
return cast("Callable[P, R]", sync_wrapper)
return decorator
def _upgrade_metadata(args: dict[str, Any]) -> tuple[str | None, dict[str, Any]]:
revision = cast("str | None", args.get("revision"))
metadata = {"dry_run": str(args.get("dry_run", False)).lower()}
return revision, metadata
def _downgrade_metadata(args: dict[str, Any]) -> tuple[str | None, dict[str, Any]]:
revision = cast("str | None", args.get("revision"))
metadata = {"dry_run": str(args.get("dry_run", False)).lower()}
return revision, metadata
[docs]
class SyncMigrationCommands(BaseMigrationCommands["SyncConfigT", Any]):
"""Synchronous migration commands."""
[docs]
def __init__(self, config: "SyncConfigT") -> None:
"""Initialize migration commands.
Args:
config: The SQLSpec configuration.
"""
super().__init__(config)
self.tracker = config.migration_tracker_type(self.version_table)
# Create context with extension configurations
context = MigrationContext.from_config(config)
context.extension_config = self.extension_configs
self.runner = SyncMigrationRunner(
self.migrations_path,
self._discover_extension_migrations(),
context,
self.extension_configs,
runtime=self._runtime,
description_hints=self._template_settings.description_hints,
)
[docs]
def init(self, directory: str, package: bool = True) -> None:
"""Initialize migration directory structure.
Args:
directory: Directory to initialize migrations in.
package: Whether to create __init__.py file.
"""
self.init_directory(directory, package)
[docs]
def current(self, verbose: bool = False) -> "str | None":
"""Show current migration version.
Args:
verbose: Whether to show detailed migration history.
Returns:
The current migration version or None if no migrations applied.
"""
with self.config.provide_session() as driver:
self.tracker.ensure_tracking_table(driver)
current = self.tracker.get_current_version(driver)
if not current:
log_with_context(
logger,
logging.DEBUG,
"migration.list",
db_system=resolve_db_system(type(driver).__name__),
current_version=None,
applied_count=0,
verbose=verbose,
status="empty",
)
console.print("[yellow]No migrations applied yet[/]")
return None
console.print(f"[green]Current version:[/] {current}")
applied: list[dict[str, Any]] = []
if verbose:
applied = self.tracker.get_applied_migrations(driver)
table = Table(title="Applied Migrations")
table.add_column("Version", style="cyan")
table.add_column("Description")
table.add_column("Applied At")
table.add_column("Time (ms)", justify="right")
table.add_column("Applied By")
for migration in applied:
table.add_row(
migration["version_num"],
migration.get("description", ""),
str(migration.get("applied_at", "")),
str(migration.get("execution_time_ms", "")),
migration.get("applied_by", ""),
)
console.print(table)
applied_count = len(applied) if verbose else None
log_with_context(
logger,
logging.DEBUG,
"migration.list",
db_system=resolve_db_system(type(driver).__name__),
current_version=current,
applied_count=applied_count,
verbose=verbose,
status="complete",
)
return cast("str | None", current)
def _load_single_migration_checksum(self, version: str, file_path: "Path") -> "tuple[str, tuple[str, Path]] | None":
"""Load checksum for a single migration.
Args:
version: Migration version.
file_path: Path to migration file.
Returns:
Tuple of (version, (checksum, file_path)) or None if load fails.
"""
try:
migration = self.runner.load_migration(file_path, version)
return (version, (migration["checksum"], file_path))
except Exception as exc:
log_with_context(
logger,
logging.DEBUG,
"migration.list",
db_system=resolve_db_system(type(self.config).__name__),
version=version,
file_path=str(file_path),
error_type=type(exc).__name__,
status="failed",
operation="load_checksum",
)
return None
def _load_migration_checksums(self, all_migrations: "list[tuple[str, Path]]") -> "dict[str, tuple[str, Path]]":
"""Load checksums for all migrations.
Args:
all_migrations: List of (version, file_path) tuples.
Returns:
Dictionary mapping version to (checksum, file_path) tuples.
"""
file_checksums = {}
for version, file_path in all_migrations:
result = self._load_single_migration_checksum(version, file_path)
if result:
file_checksums[result[0]] = result[1]
return file_checksums
def _collect_pending_migrations(
self, all_migrations: "list[tuple[str, Path]]", applied_set: set[str], revision: str
) -> "list[tuple[str, Path]]":
"""Collect pending migrations that need to be applied.
Args:
all_migrations: List of all (version, file_path) tuples.
applied_set: Set of already applied version strings.
revision: Target revision ("head" or specific version).
Returns:
List of (version, file_path) tuples for pending migrations.
"""
pending = []
for version, file_path in all_migrations:
if version not in applied_set:
if revision == "head":
pending.append((version, file_path))
else:
parsed_version = parse_version(version)
parsed_revision = parse_version(revision)
if parsed_version <= parsed_revision:
pending.append((version, file_path))
return pending
def _report_no_pending_migrations(
self, use_logger: bool, echo: bool, summary_only: bool, has_migrations: bool
) -> None:
"""Report that there are no pending migrations.
Args:
use_logger: Whether to output to logger instead of console.
echo: Whether to echo output to the console.
summary_only: Whether summary-only logging is enabled.
has_migrations: Whether any migrations exist at all.
"""
if not has_migrations:
_output_info(
use_logger,
echo,
summary_only,
"No migrations found. Create your first migration with 'sqlspec create-migration'.",
rich_message="[yellow]No migrations found. Create your first migration with 'sqlspec create-migration'.[/]",
)
else:
_output_info(
use_logger,
echo,
summary_only,
"Already at latest version",
rich_message="[green]Already at latest version[/]",
)
def _apply_single_migration(
self, driver: Any, migration: "dict[str, Any]", version: str, use_logger: bool, echo: bool, summary_only: bool
) -> int | None:
"""Apply a single migration and record it.
Args:
driver: Database driver instance.
migration: Migration dictionary with version, description, checksum.
version: Version string.
use_logger: Whether to output to logger instead of console.
echo: Whether to echo output to the console.
summary_only: Whether summary-only logging is enabled.
Returns:
Execution time in ms on success, None on failure.
"""
try:
def record_version(exec_time: int, migration: "dict[str, Any]" = migration) -> None:
self.tracker.record_migration(
driver, migration["version"], migration["description"], exec_time, migration["checksum"]
)
_, execution_time = self.runner.execute_upgrade(driver, migration, on_success=record_version)
except Exception as exc:
use_txn = self.runner.should_use_transaction(migration, self.config)
rollback_msg = " (transaction rolled back)" if use_txn else ""
_output_exception(
use_logger,
echo,
summary_only,
"Migration %s failed%s",
version,
rollback_msg,
rich_message=f"[red]✗ Failed{rollback_msg}: {exc}[/]",
)
self._last_command_error = exc
return None
else:
_output_info(
use_logger,
echo,
summary_only,
"Applied migration %s in %dms",
version,
execution_time,
rich_message=f"[green]✓ Applied in {execution_time}ms[/]",
)
return execution_time
def _collect_revert_migrations(self, applied: "list[dict[str, Any]]", revision: str) -> "list[dict[str, Any]]":
"""Collect migrations to revert based on target revision.
Args:
applied: List of applied migration records.
revision: Target revision ("-1", "base", or specific version).
Returns:
List of migration records to revert.
"""
if revision == "-1":
return [applied[-1]]
if revision == "base":
return list(reversed(applied))
parsed_revision = parse_version(revision)
to_revert = []
for migration in reversed(applied):
parsed_migration_version = parse_version(migration["version_num"])
if parsed_migration_version > parsed_revision:
to_revert.append(migration)
return to_revert
def _revert_single_migration(
self, driver: Any, migration: "dict[str, Any]", version: str, use_logger: bool, echo: bool, summary_only: bool
) -> int | None:
"""Revert a single migration.
Args:
driver: Database driver instance.
migration: Migration dictionary.
version: Version string.
use_logger: Whether to output to logger instead of console.
echo: Whether to echo output to the console.
summary_only: Whether summary-only logging is enabled.
Returns:
Execution time in ms on success, None on failure.
"""
try:
def remove_version(exec_time: int, version: str = version) -> None:
self.tracker.remove_migration(driver, version)
_, execution_time = self.runner.execute_downgrade(driver, migration, on_success=remove_version)
except Exception as exc:
use_txn = self.runner.should_use_transaction(migration, self.config)
rollback_msg = " (transaction rolled back)" if use_txn else ""
_output_exception(
use_logger,
echo,
summary_only,
"Migration %s failed%s",
version,
rollback_msg,
rich_message=f"[red]✗ Failed{rollback_msg}: {exc}[/]",
)
self._last_command_error = exc
return None
else:
_output_info(
use_logger,
echo,
summary_only,
"Reverted migration %s in %dms",
version,
execution_time,
rich_message=f"[green]✓ Reverted in {execution_time}ms[/]",
)
return execution_time
def _synchronize_version_records(
self, driver: Any, *, use_logger: bool = False, echo: bool = True, summary_only: bool = False
) -> int:
"""Synchronize database version records with migration files.
Auto-updates DB tracking when migrations have been renamed by fix command.
This allows developers to just run upgrade after pulling changes without
manually running fix.
Validates checksums match before updating to prevent incorrect matches.
Args:
driver: Database driver instance.
use_logger: If True, output to logger instead of Rich console.
echo: Whether to echo output to the console.
summary_only: Whether summary-only logging is enabled.
Returns:
Number of version records updated.
"""
all_migrations = self.runner.get_migration_files()
try:
applied_migrations = self.tracker.get_applied_migrations(driver)
except Exception as exc:
log_with_context(
logger,
logging.DEBUG,
"migration.list",
db_system=resolve_db_system(type(driver).__name__),
error_type=type(exc).__name__,
status="failed",
operation="applied_fetch",
)
return 0
applied_map = {m["version_num"]: m for m in applied_migrations}
conversion_map = generate_conversion_map(all_migrations)
updated_count = 0
if conversion_map:
for old_version, new_version in conversion_map.items():
if old_version in applied_map and new_version not in applied_map:
applied_checksum = applied_map[old_version]["checksum"]
file_path = next((path for v, path in all_migrations if v == new_version), None)
if file_path:
migration = self.runner.load_migration(file_path, new_version)
if migration["checksum"] == applied_checksum:
self.tracker.update_version_record(driver, old_version, new_version)
if use_logger:
if not summary_only:
logger.info("Reconciled version: %s -> %s", old_version, new_version)
elif echo:
console.print(f" [dim]Reconciled version:[/] {old_version} → {new_version}")
updated_count += 1
elif use_logger:
logger.warning(
"Checksum mismatch for %s -> %s, skipping auto-sync", old_version, new_version
)
elif echo:
console.print(
f" [yellow]Warning: Checksum mismatch for {old_version} → {new_version}, skipping auto-sync[/]"
)
else:
file_checksums = self._load_migration_checksums(all_migrations)
for applied_version, applied_record in applied_map.items():
for file_version, (file_checksum, _) in file_checksums.items():
if file_version not in applied_map and applied_record["checksum"] == file_checksum:
self.tracker.update_version_record(driver, applied_version, file_version)
if use_logger:
if not summary_only:
logger.info("Reconciled version: %s -> %s", applied_version, file_version)
elif echo:
console.print(f" [dim]Reconciled version:[/] {applied_version} → {file_version}")
updated_count += 1
break
if updated_count > 0:
if use_logger:
if not summary_only:
logger.info("Reconciled %d version record(s)", updated_count)
elif echo:
console.print(f"[cyan]Reconciled {updated_count} version record(s)[/]")
return updated_count
[docs]
@_with_command_span("upgrade", metadata_fn=_upgrade_metadata)
def upgrade(
self,
revision: str = "head",
allow_missing: bool = False,
auto_sync: bool = True,
dry_run: bool = False,
*,
use_logger: bool = False,
echo: bool | None = None,
summary_only: bool | None = None,
) -> None:
"""Upgrade to a target revision.
Validates migration order and warns if out-of-order migrations are detected.
Out-of-order migrations can occur when branches merge in different orders
across environments.
Args:
revision: Target revision or "head" for latest.
allow_missing: If True, allow out-of-order migrations even in strict mode.
Defaults to False.
auto_sync: If True, automatically reconcile renamed migrations in database.
Defaults to True. Can be disabled via --no-auto-sync flag.
dry_run: If True, show what would be done without making changes.
use_logger: If True, output to logger instead of Rich console.
Defaults to False. Can be set via MigrationConfig for persistent default.
echo: Echo output to the console. Defaults to True when unset.
summary_only: Emit a single summary log entry when logger output is enabled.
"""
runtime = self._runtime
applied_count = 0
pending_count = 0
db_system: str | None = None
error: Exception | None = None
ul, echo_value, summary_value = self._resolve_output_policy(use_logger, echo, summary_only)
self.runner.set_use_logger(ul)
self.runner.set_summary_only(summary_value)
self.tracker.set_output_policy(use_logger=ul, echo=echo_value, summary_only=summary_value)
output_info = functools.partial(_output_info, ul, echo_value, summary_value)
start_time = time.perf_counter()
try:
if dry_run:
output_info(
"DRY RUN MODE: No database changes will be applied",
rich_message="[bold yellow]DRY RUN MODE:[/] No database changes will be applied\n",
)
with self.config.provide_session() as driver:
db_system = resolve_db_system(type(driver).__name__)
self.tracker.ensure_tracking_table(driver)
if auto_sync and self.config.migration_config.get("auto_sync", True):
self._synchronize_version_records(
driver, use_logger=ul, echo=echo_value, summary_only=summary_value
)
applied_migrations = self.tracker.get_applied_migrations(driver)
applied_versions = [m["version_num"] for m in applied_migrations]
applied_set = set(applied_versions)
all_migrations = self.runner.get_migration_files()
if runtime is not None:
runtime.increment_metric("migrations.command.upgrade.available", float(len(all_migrations)))
pending = self._collect_pending_migrations(all_migrations, applied_set, revision)
pending_count = len(pending)
if runtime is not None:
runtime.increment_metric("migrations.command.upgrade.pending", float(len(pending)))
if not pending:
self._report_no_pending_migrations(ul, echo_value, summary_value, bool(all_migrations))
return
migration_config = cast("dict[str, Any]", self.config.migration_config) or {}
strict_ordering = migration_config.get("strict_ordering", False) and not allow_missing
validate_migration_order(
[v for v, _ in pending],
applied_versions,
strict_ordering,
use_logger=ul,
echo=echo_value,
summary_only=summary_value,
)
output_info(
"Found %d pending migrations",
len(pending),
rich_message=f"[yellow]Found {len(pending)} pending migrations[/]",
)
for version, file_path in pending:
migration = self.runner.load_migration(file_path, version)
action_verb = "Would apply" if dry_run else "Applying"
output_info(
"%s %s: %s",
action_verb,
version,
migration["description"],
rich_message=f"\n[cyan]{action_verb} {version}:[/] {migration['description']}",
)
if dry_run:
output_info(
"Migration file: %s", file_path, rich_message=f"[dim]Migration file: {file_path}[/]"
)
continue
result = self._apply_single_migration(driver, migration, version, ul, echo_value, summary_value)
if result is None:
return
applied_count += 1
except Exception as exc: # pragma: no cover - passthrough
error = exc
raise
finally:
duration_ms = int((time.perf_counter() - start_time) * 1000)
_log_command_summary(
use_logger=ul,
summary_only=summary_value,
command="upgrade",
status="failed" if error else "complete",
revision=revision,
dry_run=dry_run,
pending_count=pending_count,
applied_count=applied_count,
reverted_count=None,
duration_ms=duration_ms,
db_system=db_system,
bind_key=getattr(self.config, "bind_key", None),
config_name=type(self.config).__name__,
error=error,
allow_missing=allow_missing,
auto_sync=auto_sync,
)
if dry_run:
output_info(
"Dry run complete. No changes were made to the database.",
rich_message="\n[bold yellow]Dry run complete.[/] No changes were made to the database.",
)
elif applied_count:
self._record_command_metric("applied", float(applied_count))
[docs]
@_with_command_span("downgrade", metadata_fn=_downgrade_metadata)
def downgrade(
self,
revision: str = "-1",
*,
dry_run: bool = False,
use_logger: bool = False,
echo: bool | None = None,
summary_only: bool | None = None,
) -> None:
"""Downgrade to a target revision.
Args:
revision: Target revision or "-1" for one step back.
dry_run: If True, show what would be done without making changes.
use_logger: If True, output to logger instead of Rich console.
Defaults to False. Can be set via MigrationConfig for persistent default.
echo: Echo output to the console. Defaults to True when unset.
summary_only: Emit a single summary log entry when logger output is enabled.
"""
runtime = self._runtime
reverted_count = 0
pending_count = 0
db_system: str | None = None
error: Exception | None = None
ul, echo_value, summary_value = self._resolve_output_policy(use_logger, echo, summary_only)
self.runner.set_use_logger(ul)
self.runner.set_summary_only(summary_value)
self.tracker.set_output_policy(use_logger=ul, echo=echo_value, summary_only=summary_value)
output_info = functools.partial(_output_info, ul, echo_value, summary_value)
output_error = functools.partial(_output_error, ul, echo_value, summary_value)
start_time = time.perf_counter()
try:
if dry_run:
output_info(
"DRY RUN MODE: No database changes will be applied",
rich_message="[bold yellow]DRY RUN MODE:[/] No database changes will be applied\n",
)
with self.config.provide_session() as driver:
db_system = resolve_db_system(type(driver).__name__)
self.tracker.ensure_tracking_table(driver)
applied = self.tracker.get_applied_migrations(driver)
if runtime is not None:
runtime.increment_metric("migrations.command.downgrade.available", float(len(applied)))
if not applied:
output_info("No migrations to downgrade", rich_message="[yellow]No migrations to downgrade[/]")
return
to_revert = self._collect_revert_migrations(applied, revision)
pending_count = len(to_revert)
if runtime is not None:
runtime.increment_metric("migrations.command.downgrade.pending", float(len(to_revert)))
if not to_revert:
output_info("Nothing to downgrade", rich_message="[yellow]Nothing to downgrade[/]")
return
output_info(
"Reverting %d migrations",
len(to_revert),
rich_message=f"[yellow]Reverting {len(to_revert)} migrations[/]",
)
all_files = dict(self.runner.get_migration_files())
for migration_record in to_revert:
version = migration_record["version_num"]
if version not in all_files:
output_error(
"Migration file not found for %s",
version,
rich_message=f"[red]Migration file not found for {version}[/]",
)
if runtime is not None:
runtime.increment_metric("migrations.command.downgrade.missing_files")
continue
migration = self.runner.load_migration(all_files[version], version)
action_verb = "Would revert" if dry_run else "Reverting"
output_info(
"%s %s: %s",
action_verb,
version,
migration["description"],
rich_message=f"\n[cyan]{action_verb} {version}:[/] {migration['description']}",
)
if dry_run:
output_info(
"Migration file: %s",
all_files[version],
rich_message=f"[dim]Migration file: {all_files[version]}[/]",
)
continue
result = self._revert_single_migration(driver, migration, version, ul, echo_value, summary_value)
if result is None:
return
reverted_count += 1
except Exception as exc: # pragma: no cover - passthrough
error = exc
raise
finally:
duration_ms = int((time.perf_counter() - start_time) * 1000)
_log_command_summary(
use_logger=ul,
summary_only=summary_value,
command="downgrade",
status="failed" if error else "complete",
revision=revision,
dry_run=dry_run,
pending_count=pending_count,
applied_count=None,
reverted_count=reverted_count,
duration_ms=duration_ms,
db_system=db_system,
bind_key=getattr(self.config, "bind_key", None),
config_name=type(self.config).__name__,
error=error,
)
if dry_run:
output_info(
"Dry run complete. No changes were made to the database.",
rich_message="\n[bold yellow]Dry run complete.[/] No changes were made to the database.",
)
elif reverted_count:
self._record_command_metric("applied", float(reverted_count))
[docs]
def stamp(self, revision: str) -> None:
"""Mark database as being at a specific revision without running migrations.
Args:
revision: The revision to stamp.
"""
with self.config.provide_session() as driver:
self.tracker.ensure_tracking_table(driver)
all_migrations = dict(self.runner.get_migration_files())
if revision not in all_migrations:
console.print(f"[red]Unknown revision: {revision}[/]")
return
clear_sql = sql.delete().from_(self.tracker.version_table)
driver.execute(clear_sql)
self.tracker.record_migration(driver, revision, f"Stamped to {revision}", 0, "manual-stamp")
console.print(f"[green]Database stamped at revision {revision}[/]")
[docs]
def revision(self, message: str, file_type: str | None = None) -> None:
"""Create a new migration file with timestamp-based versioning.
Generates a unique timestamp version (YYYYMMDDHHmmss format) to avoid
conflicts when multiple developers create migrations concurrently.
Args:
message: Description for the migration.
file_type: Type of migration file to create ('sql' or 'py').
"""
version = generate_timestamp_version()
selected_format = file_type or self._template_settings.default_format
file_path = create_migration_file(
self.migrations_path,
version,
message,
selected_format,
config=self.config,
template_settings=self._template_settings,
)
log_with_context(
logger,
logging.DEBUG,
"migration.create",
db_system=resolve_db_system(type(self.config).__name__),
version=version,
file_path=str(file_path),
file_type=selected_format,
description=message,
)
console.print(f"[green]Created migration:[/] {file_path}")
[docs]
def squash(
self,
start_version: str | None = None,
end_version: str | None = None,
description: str | None = None,
*,
dry_run: bool = False,
update_database: bool = True,
yes: bool = False,
allow_gaps: bool = False,
output_format: str = "sql",
) -> None:
"""Squash a range of migrations into a single file.
Combines multiple sequential migrations into a single "release" migration.
UP statements are merged in version order, DOWN statements in reverse order.
Args:
start_version: First version in the range to squash (inclusive).
When None, defaults to the first sequential migration found.
end_version: Last version in the range to squash (inclusive).
When None, defaults to the last sequential migration found.
description: Description for the squashed migration file.
When None, prompts interactively.
dry_run: Preview changes without applying.
update_database: Update migration records in database.
yes: Skip confirmation prompt.
allow_gaps: Allow gaps in version sequence.
output_format: Output format ("sql" or "py").
Raises:
SquashValidationError: If validation fails (invalid range, gaps, etc.).
"""
squasher = MigrationSquasher(self.migrations_path, self.runner, self._template_settings)
# Infer start/end from all sequential migrations when not provided
if start_version is None or end_version is None:
all_migrations = self.runner.get_migration_files()
sequential = [(v, p) for v, p in all_migrations if v.isdigit() or v.lstrip("0").isdigit()]
if not sequential:
console.print("[yellow]No sequential migrations found to squash[/]")
return
if start_version is None:
start_version = sequential[0][0]
if end_version is None:
end_version = sequential[-1][0]
console.print(f"[cyan]Squashing range: {start_version} to {end_version}[/]")
# Prompt for description when not provided
if description is None:
from rich.prompt import Prompt
description = Prompt.ask("Migration description", default="squashed_migrations")
plans = squasher.plan_squash(
start_version, end_version, description, allow_gaps=allow_gaps, output_format=output_format
)
# Display plan for each squash group
table = Table(title="Squash Plan")
table.add_column("Version", style="cyan")
table.add_column("File")
table.add_column("Target", style="green")
total_migrations = 0
for plan in plans:
for version, file_path in plan.source_migrations:
table.add_row(version, file_path.name, plan.target_path.name)
total_migrations += 1
console.print(table)
target_files = ", ".join(p.target_path.name for p in plans)
console.print(
f"\n[yellow]{total_migrations} migrations will be squashed into {len(plans)} file(s): {target_files}[/]"
)
if dry_run:
console.print("[yellow][Preview Mode - No changes made][/]")
return
if not yes:
response = input("\nProceed with squash? [y/N]: ")
if response.lower() != "y":
console.print("[yellow]Squash cancelled[/]")
return
squasher.apply_squash(plans)
for plan in plans:
console.print(f"[green]✓ Created squashed migration: {plan.target_path.name}[/]")
if update_database:
with self.config.provide_session() as driver:
self.tracker.ensure_tracking_table(driver)
for plan in plans:
if self.tracker.is_squash_already_applied(driver, plan.target_version, plan.source_versions):
up_sql, down_sql = squasher.extract_sql(plan.source_migrations)
if plan.target_path.suffix == ".py":
content = squasher.generate_python_squash(plan, up_sql, down_sql)
else:
content = squasher.generate_squashed_content(plan, up_sql, down_sql)
checksum = self.runner.calculate_checksum(content)
self.tracker.replace_with_squash(
driver, plan.target_version, plan.source_versions, description, checksum
)
console.print("[green]✓ Updated migration tracking table[/]")
console.print("[green]✓ Squash complete![/]")
[docs]
def fix(self, dry_run: bool = False, update_database: bool = True, yes: bool = False) -> None:
"""Convert timestamp migrations to sequential format.
Implements hybrid versioning workflow where development uses timestamps
and production uses sequential numbers. Creates backup before changes
and provides rollback on errors.
Args:
dry_run: Preview changes without applying.
update_database: Update migration records in database.
yes: Skip confirmation prompt.
Examples:
>>> commands.fix(dry_run=True) # Preview only
>>> commands.fix(yes=True) # Auto-approve
>>> commands.fix(update_database=False) # Files only
"""
all_migrations = self.runner.get_migration_files()
conversion_map = generate_conversion_map(all_migrations)
if not conversion_map:
console.print("[yellow]No timestamp migrations found - nothing to convert[/]")
return
fixer = MigrationFixer(self.migrations_path)
renames = fixer.plan_renames(conversion_map)
table = Table(title="Migration Conversions")
table.add_column("Current Version", style="cyan")
table.add_column("New Version", style="green")
table.add_column("File")
for rename in renames:
table.add_row(rename.old_version, rename.new_version, rename.old_path.name)
console.print(table)
console.print(f"\n[yellow]{len(renames)} migrations will be converted[/]")
if dry_run:
console.print("[yellow][Preview Mode - No changes made][/]")
return
if not yes:
response = input("\nProceed with conversion? [y/N]: ")
if response.lower() != "y":
console.print("[yellow]Conversion cancelled[/]")
return
try:
backup_path = fixer.create_backup()
console.print(f"[green]✓ Created backup in {backup_path.name}[/]")
fixer.apply_renames(renames)
for rename in renames:
console.print(f"[green]✓ Renamed {rename.old_path.name} → {rename.new_path.name}[/]")
if update_database:
with self.config.provide_session() as driver:
self.tracker.ensure_tracking_table(driver)
applied_migrations = self.tracker.get_applied_migrations(driver)
applied_versions = {m["version_num"] for m in applied_migrations}
updated_count = 0
for old_version, new_version in conversion_map.items():
if old_version in applied_versions:
self.tracker.update_version_record(driver, old_version, new_version)
updated_count += 1
if updated_count > 0:
console.print(
f"[green]✓ Updated {updated_count} version records in migration tracking table[/]"
)
else:
console.print("[green]✓ No applied migrations to update in tracking table[/]")
fixer.cleanup()
console.print("[green]✓ Conversion complete![/]")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/]")
fixer.rollback()
console.print("[yellow]Restored files from backup[/]")
raise
[docs]
class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]):
"""Asynchronous migration commands."""
[docs]
def __init__(self, config: "AsyncConfigT") -> None:
"""Initialize migration commands.
Args:
config: The SQLSpec configuration.
"""
super().__init__(config)
self.tracker = config.migration_tracker_type(self.version_table)
# Create context with extension configurations
context = MigrationContext.from_config(config)
context.extension_config = self.extension_configs
self.runner = AsyncMigrationRunner(
self.migrations_path,
self._discover_extension_migrations(),
context,
self.extension_configs,
runtime=self._runtime,
description_hints=self._template_settings.description_hints,
)
[docs]
async def init(self, directory: str, package: bool = True) -> None:
"""Initialize migration directory structure.
Args:
directory: Directory path for migrations.
package: Whether to create __init__.py in the directory.
"""
self.init_directory(directory, package)
[docs]
async def current(self, verbose: bool = False) -> "str | None":
"""Show current migration version.
Args:
verbose: Whether to show detailed migration history.
Returns:
The current migration version or None if no migrations applied.
"""
async with self.config.provide_session() as driver:
await self.tracker.ensure_tracking_table(driver)
current = await self.tracker.get_current_version(driver)
if not current:
log_with_context(
logger,
logging.DEBUG,
"migration.list",
db_system=resolve_db_system(type(driver).__name__),
current_version=None,
applied_count=0,
verbose=verbose,
status="empty",
)
console.print("[yellow]No migrations applied yet[/]")
return None
console.print(f"[green]Current version:[/] {current}")
applied: list[dict[str, Any]] = []
if verbose:
applied = await self.tracker.get_applied_migrations(driver)
table = Table(title="Applied Migrations")
table.add_column("Version", style="cyan")
table.add_column("Description")
table.add_column("Applied At")
table.add_column("Time (ms)", justify="right")
table.add_column("Applied By")
for migration in applied:
table.add_row(
migration["version_num"],
migration.get("description", ""),
str(migration.get("applied_at", "")),
str(migration.get("execution_time_ms", "")),
migration.get("applied_by", ""),
)
console.print(table)
applied_count = len(applied) if verbose else None
log_with_context(
logger,
logging.DEBUG,
"migration.list",
db_system=resolve_db_system(type(driver).__name__),
current_version=current,
applied_count=applied_count,
verbose=verbose,
status="complete",
)
return cast("str | None", current)
async def _load_single_migration_checksum(
self, version: str, file_path: "Path"
) -> "tuple[str, tuple[str, Path]] | None":
"""Load checksum for a single migration.
Args:
version: Migration version.
file_path: Path to migration file.
Returns:
Tuple of (version, (checksum, file_path)) or None if load fails.
"""
try:
migration = await self.runner.load_migration(file_path, version)
return (version, (migration["checksum"], file_path))
except Exception as exc:
log_with_context(
logger,
logging.DEBUG,
"migration.list",
db_system=resolve_db_system(type(self.config).__name__),
version=version,
file_path=str(file_path),
error_type=type(exc).__name__,
status="failed",
operation="load_checksum",
)
return None
async def _load_migration_checksums(
self, all_migrations: "list[tuple[str, Path]]"
) -> "dict[str, tuple[str, Path]]":
"""Load checksums for all migrations.
Args:
all_migrations: List of (version, file_path) tuples.
Returns:
Dictionary mapping version to (checksum, file_path) tuples.
"""
file_checksums = {}
for version, file_path in all_migrations:
result = await self._load_single_migration_checksum(version, file_path)
if result:
file_checksums[result[0]] = result[1]
return file_checksums
def _collect_pending_migrations(
self, all_migrations: "list[tuple[str, Path]]", applied_set: set[str], revision: str
) -> "list[tuple[str, Path]]":
"""Collect pending migrations that need to be applied.
Args:
all_migrations: List of all (version, file_path) tuples.
applied_set: Set of already applied version strings.
revision: Target revision ("head" or specific version).
Returns:
List of (version, file_path) tuples for pending migrations.
"""
pending = []
for version, file_path in all_migrations:
if version not in applied_set:
if revision == "head":
pending.append((version, file_path))
else:
parsed_version = parse_version(version)
parsed_revision = parse_version(revision)
if parsed_version <= parsed_revision:
pending.append((version, file_path))
return pending
def _report_no_pending_migrations(
self, use_logger: bool, echo: bool, summary_only: bool, has_migrations: bool
) -> None:
"""Report that there are no pending migrations.
Args:
use_logger: Whether to output to logger instead of console.
echo: Whether to echo output to the console.
summary_only: Whether summary-only logging is enabled.
has_migrations: Whether any migrations exist at all.
"""
if not has_migrations:
_output_info(
use_logger,
echo,
summary_only,
"No migrations found. Create your first migration with 'sqlspec create-migration'.",
rich_message="[yellow]No migrations found. Create your first migration with 'sqlspec create-migration'.[/]",
)
else:
_output_info(
use_logger,
echo,
summary_only,
"Already at latest version",
rich_message="[green]Already at latest version[/]",
)
async def _apply_single_migration(
self, driver: Any, migration: "dict[str, Any]", version: str, use_logger: bool, echo: bool, summary_only: bool
) -> int | None:
"""Apply a single migration and record it.
Args:
driver: Database driver instance.
migration: Migration dictionary with version, description, checksum.
version: Version string.
use_logger: Whether to output to logger instead of console.
echo: Whether to echo output to the console.
summary_only: Whether summary-only logging is enabled.
Returns:
Execution time in ms on success, None on failure.
"""
try:
async def record_version(exec_time: int, migration: "dict[str, Any]" = migration) -> None:
await self.tracker.record_migration(
driver, migration["version"], migration["description"], exec_time, migration["checksum"]
)
_, execution_time = await self.runner.execute_upgrade(driver, migration, on_success=record_version)
except Exception as exc:
use_txn = self.runner.should_use_transaction(migration, self.config)
rollback_msg = " (transaction rolled back)" if use_txn else ""
_output_exception(
use_logger,
echo,
summary_only,
"Migration %s failed%s",
version,
rollback_msg,
rich_message=f"[red]✗ Failed{rollback_msg}: {exc}[/]",
)
self._last_command_error = exc
return None
else:
_output_info(
use_logger,
echo,
summary_only,
"Applied migration %s in %dms",
version,
execution_time,
rich_message=f"[green]✓ Applied in {execution_time}ms[/]",
)
return execution_time
def _collect_revert_migrations(self, applied: "list[dict[str, Any]]", revision: str) -> "list[dict[str, Any]]":
"""Collect migrations to revert based on target revision.
Args:
applied: List of applied migration records.
revision: Target revision ("-1", "base", or specific version).
Returns:
List of migration records to revert.
"""
if revision == "-1":
return [applied[-1]]
if revision == "base":
return list(reversed(applied))
parsed_revision = parse_version(revision)
to_revert = []
for migration in reversed(applied):
parsed_migration_version = parse_version(migration["version_num"])
if parsed_migration_version > parsed_revision:
to_revert.append(migration)
return to_revert
async def _revert_single_migration(
self, driver: Any, migration: "dict[str, Any]", version: str, use_logger: bool, echo: bool, summary_only: bool
) -> int | None:
"""Revert a single migration.
Args:
driver: Database driver instance.
migration: Migration dictionary.
version: Version string.
use_logger: Whether to output to logger instead of console.
echo: Whether to echo output to the console.
summary_only: Whether summary-only logging is enabled.
Returns:
Execution time in ms on success, None on failure.
"""
try:
async def remove_version(exec_time: int, version: str = version) -> None:
await self.tracker.remove_migration(driver, version)
_, execution_time = await self.runner.execute_downgrade(driver, migration, on_success=remove_version)
except Exception as exc:
use_txn = self.runner.should_use_transaction(migration, self.config)
rollback_msg = " (transaction rolled back)" if use_txn else ""
_output_exception(
use_logger,
echo,
summary_only,
"Migration %s failed%s",
version,
rollback_msg,
rich_message=f"[red]✗ Failed{rollback_msg}: {exc}[/]",
)
self._last_command_error = exc
return None
else:
_output_info(
use_logger,
echo,
summary_only,
"Reverted migration %s in %dms",
version,
execution_time,
rich_message=f"[green]✓ Reverted in {execution_time}ms[/]",
)
return execution_time
async def _synchronize_version_records(
self, driver: Any, *, use_logger: bool = False, echo: bool = True, summary_only: bool = False
) -> int:
"""Synchronize database version records with migration files.
Auto-updates DB tracking when migrations have been renamed by fix command.
This allows developers to just run upgrade after pulling changes without
manually running fix.
Validates checksums match before updating to prevent incorrect matches.
Args:
driver: Database driver instance.
use_logger: If True, output to logger instead of Rich console.
echo: Whether to echo output to the console.
summary_only: Whether summary-only logging is enabled.
Returns:
Number of version records updated.
"""
all_migrations = await self.runner.get_migration_files()
try:
applied_migrations = await self.tracker.get_applied_migrations(driver)
except Exception as exc:
log_with_context(
logger,
logging.DEBUG,
"migration.list",
db_system=resolve_db_system(type(driver).__name__),
error_type=type(exc).__name__,
status="failed",
operation="applied_fetch",
)
return 0
applied_map = {m["version_num"]: m for m in applied_migrations}
conversion_map = generate_conversion_map(all_migrations)
updated_count = 0
if conversion_map:
for old_version, new_version in conversion_map.items():
if old_version in applied_map and new_version not in applied_map:
applied_checksum = applied_map[old_version]["checksum"]
file_path = next((path for v, path in all_migrations if v == new_version), None)
if file_path:
migration = await self.runner.load_migration(file_path, new_version)
if migration["checksum"] == applied_checksum:
await self.tracker.update_version_record(driver, old_version, new_version)
if use_logger:
if not summary_only:
logger.info("Reconciled version: %s -> %s", old_version, new_version)
elif echo:
console.print(f" [dim]Reconciled version:[/] {old_version} → {new_version}")
updated_count += 1
elif use_logger:
logger.warning(
"Checksum mismatch for %s -> %s, skipping auto-sync", old_version, new_version
)
elif echo:
console.print(
f" [yellow]Warning: Checksum mismatch for {old_version} → {new_version}, skipping auto-sync[/]"
)
else:
file_checksums = await self._load_migration_checksums(all_migrations)
for applied_version, applied_record in applied_map.items():
for file_version, (file_checksum, _) in file_checksums.items():
if file_version not in applied_map and applied_record["checksum"] == file_checksum:
await self.tracker.update_version_record(driver, applied_version, file_version)
if use_logger:
if not summary_only:
logger.info("Reconciled version: %s -> %s", applied_version, file_version)
elif echo:
console.print(f" [dim]Reconciled version:[/] {applied_version} → {file_version}")
updated_count += 1
break
if updated_count > 0:
if use_logger:
if not summary_only:
logger.info("Reconciled %d version record(s)", updated_count)
elif echo:
console.print(f"[cyan]Reconciled {updated_count} version record(s)[/]")
return updated_count
[docs]
@_with_command_span("upgrade", metadata_fn=_upgrade_metadata)
async def upgrade(
self,
revision: str = "head",
allow_missing: bool = False,
auto_sync: bool = True,
dry_run: bool = False,
*,
use_logger: bool = False,
echo: bool | None = None,
summary_only: bool | None = None,
) -> None:
"""Upgrade to a target revision.
Validates migration order and warns if out-of-order migrations are detected.
Out-of-order migrations can occur when branches merge in different orders
across environments.
Args:
revision: Target revision or "head" for latest.
allow_missing: If True, allow out-of-order migrations even in strict mode.
Defaults to False.
auto_sync: If True, automatically reconcile renamed migrations in database.
Defaults to True. Can be disabled via --no-auto-sync flag.
dry_run: If True, show what would be done without making changes.
use_logger: If True, output to logger instead of Rich console.
Defaults to False. Can be set via MigrationConfig for persistent default.
echo: Echo output to the console. Defaults to True when unset.
summary_only: Emit a single summary log entry when logger output is enabled.
"""
runtime = self._runtime
applied_count = 0
pending_count = 0
db_system: str | None = None
error: Exception | None = None
ul, echo_value, summary_value = self._resolve_output_policy(use_logger, echo, summary_only)
self.runner.set_use_logger(ul)
self.runner.set_summary_only(summary_value)
self.tracker.set_output_policy(use_logger=ul, echo=echo_value, summary_only=summary_value)
output_info = functools.partial(_output_info, ul, echo_value, summary_value)
start_time = time.perf_counter()
try:
if dry_run:
output_info(
"DRY RUN MODE: No database changes will be applied",
rich_message="[bold yellow]DRY RUN MODE:[/] No database changes will be applied\n",
)
async with self.config.provide_session() as driver:
db_system = resolve_db_system(type(driver).__name__)
await self.tracker.ensure_tracking_table(driver)
if auto_sync and self.config.migration_config.get("auto_sync", True):
await self._synchronize_version_records(
driver, use_logger=ul, echo=echo_value, summary_only=summary_value
)
applied_migrations = await self.tracker.get_applied_migrations(driver)
applied_versions = [m["version_num"] for m in applied_migrations]
applied_set = set(applied_versions)
all_migrations = await self.runner.get_migration_files()
if runtime is not None:
runtime.increment_metric("migrations.command.upgrade.available", float(len(all_migrations)))
pending = self._collect_pending_migrations(all_migrations, applied_set, revision)
pending_count = len(pending)
if runtime is not None:
runtime.increment_metric("migrations.command.upgrade.pending", float(len(pending)))
if not pending:
self._report_no_pending_migrations(ul, echo_value, summary_value, bool(all_migrations))
return
migration_config = cast("dict[str, Any]", self.config.migration_config) or {}
strict_ordering = migration_config.get("strict_ordering", False) and not allow_missing
validate_migration_order(
[v for v, _ in pending],
applied_versions,
strict_ordering,
use_logger=ul,
echo=echo_value,
summary_only=summary_value,
)
output_info(
"Found %d pending migrations",
len(pending),
rich_message=f"[yellow]Found {len(pending)} pending migrations[/]",
)
for version, file_path in pending:
migration = await self.runner.load_migration(file_path, version)
action_verb = "Would apply" if dry_run else "Applying"
output_info(
"%s %s: %s",
action_verb,
version,
migration["description"],
rich_message=f"\n[cyan]{action_verb} {version}:[/] {migration['description']}",
)
if dry_run:
output_info(
"Migration file: %s", file_path, rich_message=f"[dim]Migration file: {file_path}[/]"
)
continue
result = await self._apply_single_migration(
driver, migration, version, ul, echo_value, summary_value
)
if result is None:
return
applied_count += 1
except Exception as exc: # pragma: no cover - passthrough
error = exc
raise
finally:
duration_ms = int((time.perf_counter() - start_time) * 1000)
_log_command_summary(
use_logger=ul,
summary_only=summary_value,
command="upgrade",
status="failed" if error else "complete",
revision=revision,
dry_run=dry_run,
pending_count=pending_count,
applied_count=applied_count,
reverted_count=None,
duration_ms=duration_ms,
db_system=db_system,
bind_key=getattr(self.config, "bind_key", None),
config_name=type(self.config).__name__,
error=error,
allow_missing=allow_missing,
auto_sync=auto_sync,
)
if dry_run:
output_info(
"Dry run complete. No changes were made to the database.",
rich_message="\n[bold yellow]Dry run complete.[/] No changes were made to the database.",
)
elif applied_count:
self._record_command_metric("applied", float(applied_count))
[docs]
@_with_command_span("downgrade", metadata_fn=_downgrade_metadata)
async def downgrade(
self,
revision: str = "-1",
*,
dry_run: bool = False,
use_logger: bool = False,
echo: bool | None = None,
summary_only: bool | None = None,
) -> None:
"""Downgrade to a target revision.
Args:
revision: Target revision or "-1" for one step back.
dry_run: If True, show what would be done without making changes.
use_logger: If True, output to logger instead of Rich console.
Defaults to False. Can be set via MigrationConfig for persistent default.
echo: Echo output to the console. Defaults to True when unset.
summary_only: Emit a single summary log entry when logger output is enabled.
"""
runtime = self._runtime
reverted_count = 0
pending_count = 0
db_system: str | None = None
error: Exception | None = None
ul, echo_value, summary_value = self._resolve_output_policy(use_logger, echo, summary_only)
self.runner.set_use_logger(ul)
self.runner.set_summary_only(summary_value)
self.tracker.set_output_policy(use_logger=ul, echo=echo_value, summary_only=summary_value)
output_info = functools.partial(_output_info, ul, echo_value, summary_value)
output_error = functools.partial(_output_error, ul, echo_value, summary_value)
start_time = time.perf_counter()
try:
if dry_run:
output_info(
"DRY RUN MODE: No database changes will be applied",
rich_message="[bold yellow]DRY RUN MODE:[/] No database changes will be applied\n",
)
async with self.config.provide_session() as driver:
db_system = resolve_db_system(type(driver).__name__)
await self.tracker.ensure_tracking_table(driver)
applied = await self.tracker.get_applied_migrations(driver)
if runtime is not None:
runtime.increment_metric("migrations.command.downgrade.available", float(len(applied)))
if not applied:
output_info("No migrations to downgrade", rich_message="[yellow]No migrations to downgrade[/]")
return
to_revert = self._collect_revert_migrations(applied, revision)
pending_count = len(to_revert)
if runtime is not None:
runtime.increment_metric("migrations.command.downgrade.pending", float(len(to_revert)))
if not to_revert:
output_info("Nothing to downgrade", rich_message="[yellow]Nothing to downgrade[/]")
return
output_info(
"Reverting %d migrations",
len(to_revert),
rich_message=f"[yellow]Reverting {len(to_revert)} migrations[/]",
)
all_files = dict(await self.runner.get_migration_files())
for migration_record in to_revert:
version = migration_record["version_num"]
if version not in all_files:
output_error(
"Migration file not found for %s",
version,
rich_message=f"[red]Migration file not found for {version}[/]",
)
if runtime is not None:
runtime.increment_metric("migrations.command.downgrade.missing_files")
continue
migration = await self.runner.load_migration(all_files[version], version)
action_verb = "Would revert" if dry_run else "Reverting"
output_info(
"%s %s: %s",
action_verb,
version,
migration["description"],
rich_message=f"\n[cyan]{action_verb} {version}:[/] {migration['description']}",
)
if dry_run:
output_info(
"Migration file: %s",
all_files[version],
rich_message=f"[dim]Migration file: {all_files[version]}[/]",
)
continue
result = await self._revert_single_migration(
driver, migration, version, ul, echo_value, summary_value
)
if result is None:
return
reverted_count += 1
except Exception as exc: # pragma: no cover - passthrough
error = exc
raise
finally:
duration_ms = int((time.perf_counter() - start_time) * 1000)
_log_command_summary(
use_logger=ul,
summary_only=summary_value,
command="downgrade",
status="failed" if error else "complete",
revision=revision,
dry_run=dry_run,
pending_count=pending_count,
applied_count=None,
reverted_count=reverted_count,
duration_ms=duration_ms,
db_system=db_system,
bind_key=getattr(self.config, "bind_key", None),
config_name=type(self.config).__name__,
error=error,
)
if dry_run:
output_info(
"Dry run complete. No changes were made to the database.",
rich_message="\n[bold yellow]Dry run complete.[/] No changes were made to the database.",
)
elif reverted_count:
self._record_command_metric("applied", float(reverted_count))
[docs]
async def stamp(self, revision: str) -> None:
"""Mark database as being at a specific revision without running migrations.
Args:
revision: The revision to stamp.
"""
async with self.config.provide_session() as driver:
await self.tracker.ensure_tracking_table(driver)
all_migrations = dict(await self.runner.get_migration_files())
if revision not in all_migrations:
console.print(f"[red]Unknown revision: {revision}[/]")
return
clear_sql = sql.delete().from_(self.tracker.version_table)
await driver.execute(clear_sql)
await self.tracker.record_migration(driver, revision, f"Stamped to {revision}", 0, "manual-stamp")
console.print(f"[green]Database stamped at revision {revision}[/]")
[docs]
async def revision(self, message: str, file_type: str | None = None) -> None:
"""Create a new migration file with timestamp-based versioning.
Generates a unique timestamp version (YYYYMMDDHHmmss format) to avoid
conflicts when multiple developers create migrations concurrently.
Args:
message: Description for the migration.
file_type: Type of migration file to create ('sql' or 'py').
"""
version = generate_timestamp_version()
selected_format = file_type or self._template_settings.default_format
file_path = create_migration_file(
self.migrations_path,
version,
message,
selected_format,
config=self.config,
template_settings=self._template_settings,
)
log_with_context(
logger,
logging.DEBUG,
"migration.create",
db_system=resolve_db_system(type(self.config).__name__),
version=version,
file_path=str(file_path),
file_type=selected_format,
description=message,
)
console.print(f"[green]Created migration:[/] {file_path}")
[docs]
async def squash(
self,
start_version: str | None = None,
end_version: str | None = None,
description: str | None = None,
*,
dry_run: bool = False,
update_database: bool = True,
yes: bool = False,
allow_gaps: bool = False,
output_format: str = "sql",
) -> None:
"""Squash a range of migrations into a single file.
Combines multiple sequential migrations into a single "release" migration.
UP statements are merged in version order, DOWN statements in reverse order.
Args:
start_version: First version in the range to squash (inclusive).
When None, defaults to the first sequential migration found.
end_version: Last version in the range to squash (inclusive).
When None, defaults to the last sequential migration found.
description: Description for the squashed migration file.
When None, prompts interactively.
dry_run: Preview changes without applying.
update_database: Update migration records in database.
yes: Skip confirmation prompt.
allow_gaps: Allow gaps in version sequence.
output_format: Output format ("sql" or "py").
Raises:
SquashValidationError: If validation fails (invalid range, gaps, etc.).
"""
sync_runner = SyncMigrationRunner(
self.migrations_path,
self._discover_extension_migrations(),
None,
self.extension_configs,
runtime=self._runtime,
description_hints=self._template_settings.description_hints,
)
squasher = MigrationSquasher(self.migrations_path, sync_runner, self._template_settings)
# Infer start/end from all sequential migrations when not provided
if start_version is None or end_version is None:
all_migrations = sync_runner.get_migration_files()
sequential = [(v, p) for v, p in all_migrations if v.isdigit() or v.lstrip("0").isdigit()]
if not sequential:
console.print("[yellow]No sequential migrations found to squash[/]")
return
if start_version is None:
start_version = sequential[0][0]
if end_version is None:
end_version = sequential[-1][0]
console.print(f"[cyan]Squashing range: {start_version} to {end_version}[/]")
# Prompt for description when not provided
if description is None:
import anyio
from rich.prompt import Prompt
description = await anyio.to_thread.run_sync( # pyright: ignore[reportAttributeAccessIssue]
lambda: Prompt.ask("Migration description", default="squashed_migrations")
)
plans = squasher.plan_squash(
start_version,
end_version,
description, # pyright: ignore[reportArgumentType]
allow_gaps=allow_gaps,
output_format=output_format,
)
# Display plan for each squash group
table = Table(title="Squash Plan")
table.add_column("Version", style="cyan")
table.add_column("File")
table.add_column("Target", style="green")
total_migrations = 0
for plan in plans:
for version, file_path in plan.source_migrations:
table.add_row(version, file_path.name, plan.target_path.name)
total_migrations += 1
console.print(table)
target_files = ", ".join(p.target_path.name for p in plans)
console.print(
f"\n[yellow]{total_migrations} migrations will be squashed into {len(plans)} file(s): {target_files}[/]"
)
if dry_run:
console.print("[yellow][Preview Mode - No changes made][/]")
return
if not yes:
import anyio
response = await anyio.to_thread.run_sync(input, "\nProceed with squash? [y/N]: ") # pyright: ignore[reportAttributeAccessIssue]
if response.lower() != "y":
console.print("[yellow]Squash cancelled[/]")
return
squasher.apply_squash(plans)
for plan in plans:
console.print(f"[green]✓ Created squashed migration: {plan.target_path.name}[/]")
if update_database:
async with self.config.provide_session() as driver:
await self.tracker.ensure_tracking_table(driver)
for plan in plans:
if await self.tracker.is_squash_already_applied(driver, plan.target_version, plan.source_versions):
up_sql, down_sql = squasher.extract_sql(plan.source_migrations)
if plan.target_path.suffix == ".py":
content = squasher.generate_python_squash(plan, up_sql, down_sql)
else:
content = squasher.generate_squashed_content(plan, up_sql, down_sql)
checksum = sync_runner.calculate_checksum(content)
await self.tracker.replace_with_squash(
driver, plan.target_version, plan.source_versions, description, checksum
)
console.print("[green]✓ Updated migration tracking table[/]")
console.print("[green]✓ Squash complete![/]")
[docs]
async def fix(self, dry_run: bool = False, update_database: bool = True, yes: bool = False) -> None:
"""Convert timestamp migrations to sequential format.
Implements hybrid versioning workflow where development uses timestamps
and production uses sequential numbers. Creates backup before changes
and provides rollback on errors.
Args:
dry_run: Preview changes without applying.
update_database: Update migration records in database.
yes: Skip confirmation prompt.
Examples:
>>> await commands.fix(dry_run=True) # Preview only
>>> await commands.fix(yes=True) # Auto-approve
>>> await commands.fix(update_database=False) # Files only
"""
all_migrations = await self.runner.get_migration_files()
conversion_map = generate_conversion_map(all_migrations)
if not conversion_map:
console.print("[yellow]No timestamp migrations found - nothing to convert[/]")
return
fixer = MigrationFixer(self.migrations_path)
renames = fixer.plan_renames(conversion_map)
table = Table(title="Migration Conversions")
table.add_column("Current Version", style="cyan")
table.add_column("New Version", style="green")
table.add_column("File")
for rename in renames:
table.add_row(rename.old_version, rename.new_version, rename.old_path.name)
console.print(table)
console.print(f"\n[yellow]{len(renames)} migrations will be converted[/]")
if dry_run:
console.print("[yellow][Preview Mode - No changes made][/]")
return
if not yes:
import anyio
response = await anyio.to_thread.run_sync(input, "\nProceed with conversion? [y/N]: ") # pyright: ignore[reportAttributeAccessIssue]
if response.lower() != "y":
console.print("[yellow]Conversion cancelled[/]")
return
try:
backup_path = fixer.create_backup()
console.print(f"[green]✓ Created backup in {backup_path.name}[/]")
fixer.apply_renames(renames)
for rename in renames:
console.print(f"[green]✓ Renamed {rename.old_path.name} → {rename.new_path.name}[/]")
if update_database:
async with self.config.provide_session() as driver:
await self.tracker.ensure_tracking_table(driver)
applied_migrations = await self.tracker.get_applied_migrations(driver)
applied_versions = {m["version_num"] for m in applied_migrations}
updated_count = 0
for old_version, new_version in conversion_map.items():
if old_version in applied_versions:
await self.tracker.update_version_record(driver, old_version, new_version)
updated_count += 1
if updated_count > 0:
console.print(
f"[green]✓ Updated {updated_count} version records in migration tracking table[/]"
)
else:
console.print("[green]✓ No applied migrations to update in tracking table[/]")
fixer.cleanup()
console.print("[green]✓ Conversion complete![/]")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/]")
fixer.rollback()
console.print("[yellow]Restored files from backup[/]")
raise
[docs]
def create_migration_commands(
config: "SyncConfigT | AsyncConfigT",
) -> "SyncMigrationCommands[SyncConfigT] | AsyncMigrationCommands[AsyncConfigT]":
"""Factory function to create the appropriate migration commands.
Args:
config: The SQLSpec configuration.
Returns:
Appropriate migration commands instance.
"""
if config.is_async:
return cast("AsyncMigrationCommands[AsyncConfigT]", AsyncMigrationCommands(cast("AsyncConfigT", config)))
return cast("SyncMigrationCommands[SyncConfigT]", SyncMigrationCommands(cast("SyncConfigT", config)))