"""Migration squash engine for combining multiple migrations into a single file.
This module provides utilities to consolidate multiple sequential migrations
into a single "release" migration file, following the Django-style squash workflow.
"""
import inspect
import shutil
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING
from sqlspec.exceptions import SquashValidationError
from sqlspec.utils.logging import get_logger
from sqlspec.utils.sync_tools import await_
from sqlspec.utils.text import slugify
if TYPE_CHECKING:
from sqlspec.migrations.runner import SyncMigrationRunner
from sqlspec.migrations.templates import MigrationTemplateSettings
__all__ = ("MigrationSquasher", "SquashPlan", "group_migrations_by_type", "parse_version_range")
logger = get_logger("sqlspec.migrations.squash")
def parse_version_range(range_str: str) -> tuple[str, str]:
    """Parse a version range string into a (start, end) tuple.

    Accepts multiple formats: ``START:END``, ``START..END``, or ``START-END``.
    The separators are tried in that order and the first one present in
    ``range_str`` decides how the string is split.

    Args:
        range_str: Version range string.

    Returns:
        Tuple of (start_version, end_version), each stripped of surrounding
        whitespace and zero-padded to at least 4 characters.

    Raises:
        ValueError: If no separator is present, or if either side of the
            separator is empty.
    """
    for sep in (":", "..", "-"):
        start, found, end = range_str.partition(sep)
        if not found:
            continue
        start, end = start.strip(), end.strip()
        if start and end:
            return start.zfill(4), end.zfill(4)
        # A missing endpoint must not be padded into a fake version "0000";
        # fall through to the format error instead.
        break
    msg = f"Invalid VERSION_RANGE format: '{range_str}'. Use START:END, START..END, or START-END (e.g., 1:7)"
    raise ValueError(msg)
def group_migrations_by_type(migrations: list[tuple[str, Path]]) -> list[tuple[str, list[tuple[str, Path]]]]:
    """Split migrations into runs of consecutive files sharing a file type.

    A file whose suffix is ``.py`` counts as "py"; every other file counts as
    "sql". Each time the type changes a new run is started, so mixed SQL and
    Python migrations end up in separate groups while keeping their order.

    Args:
        migrations: List of (version, path) tuples to group.

    Returns:
        List of (type, migrations) tuples where type is "sql" or "py" and
        migrations is the list of (version, path) tuples in that run. An
        empty input gives an empty list.
    """
    runs: list[tuple[str, list[tuple[str, Path]]]] = []
    for version, path in migrations:
        kind = "py" if path.suffix == ".py" else "sql"
        if runs and runs[-1][0] == kind:
            # Same type as the previous file: extend the open run.
            runs[-1][1].append((version, path))
        else:
            runs.append((kind, [(version, path)]))
    return runs
[docs]
@dataclass(slots=True)
class SquashPlan:
    """Represents a planned squash operation.

    One plan describes a single output file: the migrations folded into it
    and the version, path and description of the file that replaces them.

    Attributes:
        source_migrations: List of (version, path) tuples for migrations being squashed.
        target_version: The version string for the squashed migration.
        target_path: Output file path for the squashed migration.
        description: Combined description for the squashed migration.
        source_versions: List of version strings being replaced (for tracking table updates).
    """

    # Migrations folded into the output file, as (version, path) pairs.
    source_migrations: list[tuple[str, Path]]
    # Version string of the squashed migration.
    target_version: str
    # Where the squashed migration file is written.
    target_path: Path
    # Description recorded in the squashed migration's header.
    description: str
    # Version strings of the migrations being replaced.
    source_versions: list[str]
[docs]
class MigrationSquasher:
    """Engine that plans, validates and applies migration squashes.

    Several sequential migrations are combined into a single file whose
    UP and DOWN sections hold the merged SQL of the source migrations.
    """

    __slots__ = ("backup_path", "migrations_path", "runner", "template_settings")

    def __init__(
        self,
        migrations_path: Path,
        runner: "SyncMigrationRunner",
        template_settings: "MigrationTemplateSettings | None" = None,
    ) -> None:
        """Store the collaborators needed to squash migrations.

        Args:
            migrations_path: Path to the migrations directory.
            runner: SyncMigrationRunner instance for loading migrations.
            template_settings: Optional template settings for generating squashed file.
        """
        # No backup directory exists until one is created for a squash.
        self.backup_path: Path | None = None
        self.template_settings = template_settings
        self.runner = runner
        self.migrations_path = migrations_path
[docs]
def plan_squash(
    self,
    start_version: str,
    end_version: str,
    description: str,
    *,
    allow_gaps: bool = False,
    output_format: str = "sql",
) -> list[SquashPlan]:
    """Work out which files a squash of the given version range produces.

    When ``output_format`` is "py" a single plan covering every migration in
    the range is returned. Otherwise the migrations are grouped into
    consecutive runs of the same file type and one plan is returned per run,
    numbered upwards from the start version.

    Args:
        start_version: First version in the range to squash (inclusive).
        end_version: Last version in the range to squash (inclusive).
        description: Description for the squashed migration file.
        allow_gaps: If True, allow gaps in version sequence.
        output_format: Output file format ("sql" or "py").

    Returns:
        List of SquashPlan objects with details of planned operations.

    Raises:
        SquashValidationError: If the start version is greater than the end
            version, either version is unknown to the runner, or the range
            has a gap while ``allow_gaps`` is False.
    """
    first = int(start_version)
    last = int(end_version)
    if first > last:
        msg = f"Invalid range: start version {start_version} is greater than end version {end_version}"
        raise SquashValidationError(msg)

    available = self.runner.get_migration_files()
    known_versions = dict(available)
    if start_version not in known_versions:
        msg = f"Start version {start_version} not found in migrations"
        raise SquashValidationError(msg)
    if end_version not in known_versions:
        msg = f"End version {end_version} not found in migrations"
        raise SquashValidationError(msg)

    # Keep the migrations whose numeric version falls inside the range.
    selected: list[tuple[str, Path]] = []
    selected_numbers: list[int] = []
    for version, path in available:
        try:
            number = int(version)
        except ValueError:
            continue  # Skip non-sequential versions (ext_*, timestamps)
        if first <= number <= last:
            selected.append((version, path))
            selected_numbers.append(number)

    if not allow_gaps:
        # Neighbouring versions must differ by exactly one.
        selected_numbers.sort()
        for previous, current in zip(selected_numbers, selected_numbers[1:]):
            if current - previous != 1:
                msg = f"Gap detected in version sequence between {previous:04d} and {current:04d}"
                raise SquashValidationError(msg)

    # The slug becomes part of the file name, so cap its length and never leave it empty.
    safe_description = slugify(description, separator="_")[:50] or "migration"

    if output_format == "py":
        # Forced Python output: everything goes into one file.
        extension = ".py"
        target_version = f"{first:04d}"
        return [
            SquashPlan(
                source_migrations=selected,
                target_version=target_version,
                target_path=self.migrations_path / f"{target_version}_{safe_description}{extension}",
                description=description,
                source_versions=[version for version, _ in selected],
            )
        ]

    # Default: one output file per run of same-type migrations.
    plans: list[SquashPlan] = []
    for offset, (file_type, members) in enumerate(group_migrations_by_type(selected)):
        target_version = f"{first + offset:04d}"
        extension = ".py" if file_type == "py" else ".sql"
        plans.append(
            SquashPlan(
                source_migrations=members,
                target_version=target_version,
                target_path=self.migrations_path / f"{target_version}_{safe_description}{extension}",
                description=description,
                source_versions=[version for version, _ in members],
            )
        )
    return plans
[docs]
def generate_squashed_content(self, plan: SquashPlan, up_sql: list[str], down_sql: list[str]) -> str:
    """Generate the content for a squashed SQL migration file.

    The file starts with a ``--`` comment header (title, version,
    description, squashed versions), followed by a
    ``-- name: migrate-<version>-up`` section and, when ``down_sql`` is not
    empty, a ``-- name: migrate-<version>-down`` section. Statements are
    written as given apart from stripping trailing whitespace; no
    semicolons are added.

    Args:
        plan: The SquashPlan describing the squash operation.
        up_sql: List of UP SQL statements (in execution order).
        down_sql: List of DOWN SQL statements (in rollback order).

    Returns:
        Complete SQL file content as a string.
    """

    def comment_safe(value: object) -> str:
        # A line break inside a "--" comment would push the rest of the text
        # out of the comment and into the SQL body, so fold it onto one line.
        return " ".join(str(value).splitlines())

    title = "SQLSpec Migration"
    if self.template_settings and self.template_settings.profile:
        title = self.template_settings.profile.title

    # Header section followed by the UP section marker.
    lines: list[str] = [
        f"-- {comment_safe(title)}",
        f"-- Version: {plan.target_version}",
        f"-- Description: {comment_safe(plan.description)}",
        f"-- Squashed from: {', '.join(plan.source_versions)}",
        "",
        f"-- name: migrate-{plan.target_version}-up",
    ]
    lines.extend(statement.rstrip() for statement in up_sql)
    lines.append("")

    # DOWN section (only if there are statements)
    if down_sql:
        lines.append(f"-- name: migrate-{plan.target_version}-down")
        lines.extend(statement.rstrip() for statement in down_sql)
        lines.append("")

    return "\n".join(lines)
[docs]
def generate_python_squash(self, plan: SquashPlan, up_sql: list[str], down_sql: list[str]) -> str:
"""Generate Python migration file content instead of SQL.
Creates a Python migration file with up() and down() functions
that return the SQL statements as lists.
Args:
plan: The SquashPlan describing the squash operation.
up_sql: List of UP SQL statements (in execution order).
down_sql: List of DOWN SQL statements (in rollback order).
Returns:
Complete Python file content as a string.
"""
lines: list[str] = []
# Module docstring
title = "SQLSpec Migration"
if self.template_settings and self.template_settings.profile:
title = self.template_settings.profile.title
lines.extend([
'"""' + title + ".",
"",
f"Version: {plan.target_version}",
f"Description: {plan.description}",
f"Squashed from: {', '.join(plan.source_versions)}",
'"""',
"",
])
# Generate up() function
lines.extend(["def up() -> list[str]:", ' """Return UP migration SQL statements."""', " return ["])
lines.extend(f" {statement!r}," for statement in up_sql)
lines.extend([" ]", ""])
# Generate down() function
lines.extend(["def down() -> list[str] | None:", ' """Return DOWN migration SQL statements."""'])
if down_sql:
lines.append(" return [")
lines.extend(f" {statement!r}," for statement in down_sql)
lines.extend([" ]", ""])
else:
lines.extend([" return None", ""])
return "\n".join(lines)
[docs]
def apply_squash(self, plans: list[SquashPlan], *, dry_run: bool = False) -> None:
    """Apply the squash operation for one or more plans.

    Creates a backup, writes every squashed file, deletes the source
    migrations and removes the backup on success. If anything raises while
    writing or deleting, the migrations directory is restored from the
    backup and the exception is re-raised.

    A source file whose path is also the target of one of the plans (for
    example when the squashed file keeps the name of the first source
    migration) has already been overwritten with the squashed content and is
    therefore kept rather than deleted.

    Args:
        plans: List of SquashPlan objects to execute.
        dry_run: If True, no files are modified (preview only).
    """
    if dry_run:
        logger.debug("Dry run mode - no changes will be made")
        return

    # Create backup before making changes
    self._create_backup()
    try:
        for plan in plans:
            # Extract SQL from source migrations
            up_sql, down_sql = self.extract_sql(plan.source_migrations)

            # Generate squashed content based on target file type
            if plan.target_path.suffix == ".py":
                content = self.generate_python_squash(plan, up_sql, down_sql)
            else:
                content = self.generate_squashed_content(plan, up_sql, down_sql)

            # Write the squashed file
            plan.target_path.write_text(content, encoding="utf-8")
            logger.debug("Wrote squashed migration to %s", plan.target_path)

        # Resolved so that differently spelled paths to the same file still match.
        written_paths = {plan.target_path.resolve() for plan in plans}

        # Collect all source paths to delete (avoid duplicates across plans)
        all_source_paths = {source_path for plan in plans for _, source_path in plan.source_migrations}

        # Delete the source migration files, but never a file that was just
        # written as a squash target.
        for source_path in all_source_paths:
            if source_path.resolve() in written_paths:
                continue
            if source_path.exists():
                source_path.unlink()
                logger.debug("Deleted source migration %s", source_path)

        # Clean up backup on success
        self._cleanup_backup()
    except Exception:
        # Rollback on error
        self._rollback_backup()
        raise
def _create_backup(self) -> Path:
    """Copy the migration files into a fresh timestamped backup directory.

    The directory is named ``.backup_<UTC timestamp>`` and is created inside
    the migrations directory; creation fails if it already exists. Regular
    files whose names do not start with a dot are copied into it, and the
    directory is remembered in ``self.backup_path``.

    Returns:
        Path to created backup directory.
    """
    stamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
    destination = self.migrations_path / f".backup_{stamp}"
    destination.mkdir(parents=True, exist_ok=False)

    for entry in self.migrations_path.iterdir():
        # Sub-directories and dot-entries (including the backup itself) are skipped.
        if not entry.is_file() or entry.name.startswith("."):
            continue
        shutil.copy2(entry, destination / entry.name)

    self.backup_path = destination
    logger.debug("Created backup at %s", destination)
    return destination
def _cleanup_backup(self) -> None:
    """Delete the backup directory once it is no longer needed.

    Does nothing when no backup is recorded or the recorded directory no
    longer exists; otherwise removes the directory tree and clears
    ``self.backup_path``.
    """
    if self.backup_path and self.backup_path.exists():
        shutil.rmtree(self.backup_path)
        logger.debug("Cleaned up backup at %s", self.backup_path)
        self.backup_path = None
def _rollback_backup(self) -> None:
    """Put the migrations directory back into its backed-up state.

    Does nothing when no backup is recorded or the recorded directory no
    longer exists. Otherwise every regular, non-dot file in the migrations
    directory is removed, the files from the backup are copied back, and the
    backup directory is deleted and forgotten.
    """
    if not (self.backup_path and self.backup_path.exists()):
        return

    saved = self.backup_path

    # Remove whatever is there now, including partially written output.
    for current in self.migrations_path.iterdir():
        if not current.is_file() or current.name.startswith("."):
            continue
        current.unlink()

    # Copy the backed-up files back into place.
    for stored in saved.iterdir():
        if stored.is_file():
            shutil.copy2(stored, self.migrations_path / stored.name)

    # The backup has served its purpose; drop it.
    shutil.rmtree(saved)
    self.backup_path = None
    logger.debug("Rolled back from backup at %s", saved)