"""Migration execution engine for SQLSpec."""
import ast
import hashlib
import inspect
import logging
import re
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Union, cast, overload
from sqlspec.core import SQL
from sqlspec.loader import SQLFileLoader
from sqlspec.migrations.context import MigrationContext
from sqlspec.migrations.loaders import get_migration_loader
from sqlspec.migrations.templates import TemplateDescriptionHints
from sqlspec.migrations.version import parse_version
from sqlspec.observability import resolve_db_system
from sqlspec.utils.logging import get_logger, log_with_context
from sqlspec.utils.sync_tools import async_, await_
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Coroutine
from sqlspec.config import DatabaseConfigProtocol
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
from sqlspec.observability import ObservabilityRuntime
# Names exported by ``from <module> import *``: the two concrete runners plus a
# factory (``create_migration_runner`` is not visible in this part of the file;
# presumably defined further down).
__all__ = ("AsyncMigrationRunner", "SyncMigrationRunner", "create_migration_runner")
# Module-level logger handed to ``log_with_context`` by the runners' event logging.
logger = get_logger("sqlspec.migrations.runner")
class _CachedMigrationMetadata:
    """Parsed migration metadata stored together with the stat values it was built from.

    ``mtime_ns`` and ``size`` are kept next to ``metadata`` so a caller can compare
    them with a fresh ``stat()`` result and decide whether the entry is still valid.
    """

    __slots__ = ("metadata", "mtime_ns", "size")

    def __init__(self, metadata: "dict[str, Any]", mtime_ns: int, size: int) -> None:
        self.size = size
        self.mtime_ns = mtime_ns
        self.metadata = metadata

    def clone(self) -> "dict[str, Any]":
        """Return a shallow copy of the stored metadata.

        Returns:
            A new dict with the same keys and values, so top-level changes made by
            the caller do not touch the cached dict.
        """
        return {**self.metadata}
class _MigrationFileEntry:
    """A migration file found while scanning a directory, tagged with its owning extension.

    ``extension_name`` is ``None`` for files that do not belong to an extension.
    """

    __slots__ = ("extension_name", "path")

    def __init__(self, path: Path, extension_name: "str | None") -> None:
        self.extension_name = extension_name
        self.path = path
class BaseMigrationRunner(ABC):
"""Base migration runner with common functionality shared between sync and async implementations."""
def __init__(
    self,
    migrations_path: Path,
    extension_migrations: "dict[str, Path] | None" = None,
    context: "MigrationContext | None" = None,
    extension_configs: "dict[str, dict[str, Any]] | None" = None,
    runtime: "ObservabilityRuntime | None" = None,
    description_hints: "TemplateDescriptionHints | None" = None,
    summary_only: bool = False,
    use_logger: bool = False,
) -> None:
    """Set up runner state: locations, loader, caches and logging switches.

    Args:
        migrations_path: Directory that holds the migration files.
        extension_migrations: Extension name to migration directory mapping; a falsy
            value is replaced by a new empty dict.
        context: Migration context handed to Python migrations, if any.
        extension_configs: Extension name to configuration mapping; a falsy value is
            replaced by a new empty dict.
        runtime: Observability runtime; also passed to the SQL file loader.
        description_hints: Keys used when extracting descriptions; a default
            ``TemplateDescriptionHints()`` is created when falsy.
        summary_only: Whether summary-only logging is enabled.
        use_logger: Whether to emit log output. Defaults to False (CLI mode).
    """
    # Where migrations live and what they run against.
    self.migrations_path = migrations_path
    self.extension_migrations = extension_migrations if extension_migrations else {}
    self.runtime = runtime
    self.loader = SQLFileLoader(runtime=runtime)
    self.project_root: Path | None = None
    self.context = context
    self.extension_configs = extension_configs if extension_configs else {}

    # Directory-listing cache: digest of the last scan, its sorted result and the
    # per-file (mtime_ns, size) signatures it was derived from.
    self._listing_digest: str | None = None
    self._listing_cache: list[tuple[str, Path]] | None = None
    self._listing_signatures: dict[str, tuple[int, int]] = {}

    # Per-file metadata cache keyed by the file path string.
    self._metadata_cache: dict[str, _CachedMigrationMetadata] = {}

    self.description_hints = description_hints if description_hints else TemplateDescriptionHints()
    self.summary_only = summary_only
    self.use_logger = use_logger
def set_summary_only(self, value: bool) -> None:
    """Switch summary-only logging on or off.

    Args:
        value: New value stored on ``self.summary_only``. When True (and logging is
            enabled), INFO-level migration events are not emitted.
    """
    self.summary_only = value
def set_use_logger(self, value: bool) -> None:
    """Enable or disable log output from this runner.

    Args:
        value: New value stored on ``self.use_logger``; True emits migration events,
            False keeps the runner silent (CLI default).
    """
    self.use_logger = value
def _log_migration_event(self, level: int, event: str, **extra_fields: Any) -> None:
    """Emit a structured migration event unless the runner's settings suppress it.

    Nothing is logged when ``use_logger`` is off, or when ``summary_only`` is on
    and the event is at INFO level.

    Args:
        level: Logging level of the event.
        event: Event name.
        **extra_fields: Additional fields forwarded to ``log_with_context``.
    """
    suppressed = not self.use_logger or (self.summary_only and level == logging.INFO)
    if not suppressed:
        log_with_context(logger, level, event, **extra_fields)
def _metric(self, name: str, amount: float = 1.0) -> None:
    """Increment a metric on the observability runtime, if one is configured.

    Args:
        name: Metric name.
        amount: Increment passed to ``runtime.increment_metric``.
    """
    if self.runtime is not None:
        self.runtime.increment_metric(name, amount)
def _iter_directory_entries(self, base_path: Path, extension_name: "str | None") -> "list[_MigrationFileEntry]":
    """List the ``*.sql`` and ``*.py`` files directly under ``base_path``.

    Args:
        base_path: Directory to scan (non-recursive glob).
        extension_name: Extension name attached to every returned entry, or None.

    Returns:
        Entries for all ``*.sql`` files (sorted) followed by all ``*.py`` files
        (sorted), skipping names that start with a dot. Empty when ``base_path``
        does not exist.
    """
    if not base_path.exists():
        return []
    return [
        _MigrationFileEntry(path=candidate, extension_name=extension_name)
        for glob_pattern in ("*.sql", "*.py")
        for candidate in sorted(base_path.glob(glob_pattern))
        if not candidate.name.startswith(".")
    ]
def _collect_listing_entries(self) -> "tuple[list[_MigrationFileEntry], dict[str, tuple[int, int]], str]":
    """Scan all migration directories and fingerprint what was found.

    Returns:
        A tuple of (entries that could be stat'ed, ``path -> (mtime_ns, size)``
        signatures, hex digest over those paths and signatures). The digest is
        what ``_load_migration_listing`` compares to validate its cache.
    """
    found: list[_MigrationFileEntry] = []
    stat_tokens: dict[str, tuple[int, int]] = {}
    # MD5 is only a change fingerprint here, hence usedforsecurity=False.
    hasher = hashlib.md5(usedforsecurity=False)

    def scan_targets() -> Any:
        # Main migrations directory first (no extension name), then every
        # extension directory in mapping order.
        yield None, self.migrations_path
        yield from self.extension_migrations.items()

    for owner, directory in scan_targets():
        for entry in self._iter_directory_entries(directory, owner):
            self._record_entry(entry, found, stat_tokens, hasher)
    return found, stat_tokens, hasher.hexdigest()
def _record_entry(
    self,
    entry: _MigrationFileEntry,
    entries: "list[_MigrationFileEntry]",
    signatures: "dict[str, tuple[int, int]]",
    digest_source: Any,
) -> None:
    """Stat one entry and fold it into the listing, signature map and digest.

    Args:
        entry: Directory entry to record.
        entries: Output list; ``entry`` is appended when its file can be stat'ed.
        signatures: Output mapping of path string to ``(st_mtime_ns, st_size)``.
        digest_source: Hash object (anything with ``update(bytes)``) fed with the
            path and the ``mtime:size`` token.
    """
    try:
        file_stat = entry.path.stat()
    except FileNotFoundError:
        # The file disappeared between the glob and the stat; leave it out entirely.
        return
    key = str(entry.path)
    mtime_ns, size = file_stat.st_mtime_ns, file_stat.st_size
    signatures[key] = (mtime_ns, size)
    digest_source.update(key.encode("utf-8"))
    digest_source.update(f"{mtime_ns}:{size}".encode())
    entries.append(entry)
def _build_sorted_listing(self, entries: "list[_MigrationFileEntry]") -> "list[tuple[str, Path]]":
    """Turn directory entries into ``(version, path)`` tuples ordered by version.

    Files whose name yields no version are dropped. Entries that belong to an
    extension get their version prefixed with ``ext_<extension_name>_``.

    Args:
        entries: Entries produced by the directory scan.

    Returns:
        Sorted list of ``(version, path)`` tuples.
    """

    def sort_key(item: "tuple[str, Path]") -> "Any":
        label = item[0]
        try:
            return parse_version(label)
        except ValueError:
            # NOTE(review): falls back to the raw string; if parse_version returns a
            # type that is not comparable with str, mixing both in one listing
            # would make the sort raise - confirm against parse_version.
            return label

    listing: list[tuple[str, Path]] = []
    for entry in entries:
        base_version = self._extract_version(entry.path.name)
        if base_version:
            label = f"ext_{entry.extension_name}_{base_version}" if entry.extension_name else base_version
            listing.append((label, entry.path))
    listing.sort(key=sort_key)
    return listing
def _log_listing_invalidation(
    self, previous: "dict[str, tuple[int, int]]", current: "dict[str, tuple[int, int]]"
) -> None:
    """Report how the set of migration files changed since the cached listing.

    Emits one INFO event with added/removed/modified counts, bumps the
    invalidation counter, and records a per-category metric for each non-zero
    count.

    Args:
        previous: Path signatures captured with the old listing.
        current: Path signatures from the new scan.
    """
    added_count = len(current.keys() - previous.keys())
    removed_count = len(previous.keys() - current.keys())
    # A file counts as modified when it exists in both scans with a different signature.
    modified_count = sum(1 for path in previous.keys() & current.keys() if previous[path] != current[path])

    self._log_migration_event(
        logging.INFO,
        "migration.listing.invalidated",
        added_count=added_count,
        removed_count=removed_count,
        modified_count=modified_count,
    )
    self._metric("migrations.listing.cache_invalidations")
    per_category = (
        ("migrations.listing.added", added_count),
        ("migrations.listing.removed", removed_count),
        ("migrations.listing.modified", modified_count),
    )
    for metric_name, count in per_category:
        if count:
            self._metric(metric_name, float(count))
def _extract_version(self, filename: str) -> "str | None":
    """Derive the migration version from a file name.

    Handles three shapes: extension-prefixed names (``ext_litestar_0001_...``
    gives ``ext_litestar_0001``), sequential numbers (``1_init.sql`` gives
    ``0001``, zero-padded to four digits) and longer all-digit prefixes such as
    timestamps, which are returned unchanged.

    Args:
        filename: The migration filename.

    Returns:
        The version string, or None when the name has no recognizable version.
    """
    sequential_width = 4
    stem = filename.rsplit(".", 1)[0]

    if stem.startswith("ext_"):
        pieces = stem.split("_", 3)
        # Need "ext", the extension name and the version token.
        if len(pieces) < 3:
            return None
        return "_".join(pieces[:3])

    leading = stem.partition("_")[0]
    if not leading.isdigit():
        return None
    if len(leading) > sequential_width:
        return leading
    return leading.zfill(sequential_width)
def calculate_checksum(self, content: str) -> str:
    """Calculate the MD5 checksum of migration content.

    The content is canonicalized first: query name header lines of the form
    ``-- name: migrate-{version}-up`` / ``-down`` are blanked out, because the fix
    command rewrites them when converting timestamp versions to sequential ones.
    This keeps checksums stable across that conversion.

    Args:
        content: The migration file content.

    Returns:
        MD5 checksum as a hex string.
    """
    canonical_content = re.sub(r"^--\s*name:\s*migrate-[^-]+-(?:up|down)\s*$", "", content, flags=re.MULTILINE)
    # The digest is a change fingerprint, not a security primitive. Declaring that
    # (as the listing digest in this class already does) keeps the call usable on
    # builds that restrict MD5 for security purposes; the resulting value is identical.
    return hashlib.md5(canonical_content.encode(), usedforsecurity=False).hexdigest()
@abstractmethod
def load_migration(
    self, file_path: Path, version: "str | None" = None
) -> Union["dict[str, Any]", "Coroutine[Any, Any, dict[str, Any]]"]:
    """Load a migration file and extract its components.

    The signature mirrors the concrete sync and async runners, which both accept
    an optional ``version``.

    Args:
        file_path: Path to the migration file.
        version: Optional pre-extracted version (preserves prefixes like ext_adk_0001).

    Returns:
        Dictionary containing migration metadata and queries.
        For async implementations, returns a coroutine.
    """
def _load_migration_listing(self) -> "list[tuple[str, Path]]":
    """Return the sorted migration listing, reusing the cached one when nothing changed.

    Every call rescans the directories to compute a digest; the sorted listing is
    rebuilt only when that digest differs from the one stored with the cache.

    Returns:
        List of ``(version, path)`` tuples sorted by version.
    """
    entries, signatures, digest = self._collect_listing_entries()

    listing = self._listing_cache
    if listing is not None and self._listing_digest == digest:
        self._metric("migrations.listing.cache_hit")
        self._metric("migrations.listing.files_cached", float(len(listing)))
        self._log_migration_event(logging.DEBUG, "migration.listing.cache_hit", file_count=len(listing))
        return listing

    listing = self._build_sorted_listing(entries)
    # Keep the old state around: it tells a first load apart from an invalidation.
    stale_digest, stale_signatures = self._listing_digest, self._listing_signatures
    self._metric("migrations.listing.cache_miss")
    self._metric("migrations.listing.files_scanned", float(len(listing)))
    self._listing_cache, self._listing_signatures, self._listing_digest = listing, signatures, digest

    if stale_digest is not None:
        self._log_listing_invalidation(stale_signatures, signatures)
    else:
        self._log_migration_event(logging.DEBUG, "migration.listing.cache_primed", file_count=len(listing))
    return listing
@abstractmethod
def get_migration_files(self) -> "list[tuple[str, Path]] | Awaitable[list[tuple[str, Path]]]":
    """Return every discovered migration as ``(version, path)`` tuples ordered by version.

    Returns:
        The list itself for synchronous runners, or an awaitable resolving to it
        for asynchronous runners.
    """
def _load_migration_metadata_common(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
    """Load common migration metadata that doesn't require async operations.

    Results are cached per file path and reused while the file's ``st_mtime_ns``
    and ``st_size`` are unchanged. An explicitly passed ``version`` always wins,
    on cache hits as well as on misses, so the same file can be loaded under a
    prefixed version (e.g. ``ext_adk_0001``) after having been cached under a
    file-derived one.

    Args:
        file_path: Path to the migration file.
        version: Optional pre-extracted version (preserves prefixes like ext_adk_0001).

    Returns:
        Partial migration metadata dictionary with the keys ``version``,
        ``description``, ``file_path``, ``checksum``, ``content`` and
        ``transactional`` (True/False from a ``-- transactional:`` comment line,
        otherwise None).

    Raises:
        OSError: If the file cannot be stat'ed or read.
    """
    cache_key = str(file_path)
    stat_result = file_path.stat()
    cached_metadata = self._metadata_cache.get(cache_key)
    if (
        cached_metadata
        and cached_metadata.mtime_ns == stat_result.st_mtime_ns
        and cached_metadata.size == stat_result.st_size
    ):
        self._metric("migrations.metadata.cache_hit")
        self._log_migration_event(logging.DEBUG, "migration.metadata.cache_hit", file_path=cache_key)
        metadata = cached_metadata.clone()
        metadata["file_path"] = file_path
        if version is not None:
            # The cache is keyed by path only; without this the caller's explicit
            # version would be silently replaced by whatever was cached first.
            metadata["version"] = version
        return metadata

    self._metric("migrations.metadata.cache_miss")
    self._metric("migrations.metadata.bytes", float(stat_result.st_size))
    content = file_path.read_text(encoding="utf-8")
    checksum = self.calculate_checksum(content)
    if version is None:
        version = self._extract_version(file_path.name)
    description = self._extract_description(content, file_path)
    if not description:
        # Fall back to whatever follows the first underscore in the file stem.
        description = file_path.stem.split("_", 1)[1] if "_" in file_path.stem else ""
    transactional_match = re.search(
        r"^--\s*transactional:\s*(true|false)\s*$", content, re.MULTILINE | re.IGNORECASE
    )
    transactional = None
    if transactional_match:
        transactional = transactional_match.group(1).lower() == "true"
    metadata = {
        "version": version,
        "description": description,
        "file_path": file_path,
        "checksum": checksum,
        "content": content,
        "transactional": transactional,
    }
    # Store a separate dict so top-level changes to the returned one do not reach the cache.
    self._metadata_cache[cache_key] = _CachedMigrationMetadata(
        metadata=dict(metadata), mtime_ns=stat_result.st_mtime_ns, size=stat_result.st_size
    )
    if cached_metadata:
        # An entry existed but its stat signature no longer matched the file.
        self._log_migration_event(logging.DEBUG, "migration.metadata.cache_invalidated", file_path=cache_key)
    else:
        self._log_migration_event(logging.DEBUG, "migration.metadata.cached", file_path=cache_key)
    return metadata
def _extract_description(self, content: str, file_path: Path) -> str:
    """Pick the description extractor that matches the file's suffix.

    Args:
        content: Migration file content.
        file_path: Path of the migration file; only its suffix is inspected.

    Returns:
        The extracted description, or an empty string for suffixes other than
        ``.sql`` and ``.py``.
    """
    suffix = file_path.suffix
    if suffix == ".py":
        return self._extract_python_description(content)
    return self._extract_sql_description(content) if suffix == ".sql" else ""
def _extract_sql_description(self, content: str) -> str:
    """Find a description in the leading ``--`` comment lines of a SQL migration.

    Scans from the top, skipping blank lines, and stops at the first line that is
    not a ``--`` comment. A comment of the form ``key: value`` whose key is one of
    ``description_hints.sql_keys`` supplies the description.

    Args:
        content: SQL migration file content.

    Returns:
        The stripped value of the first matching comment, or an empty string.
    """
    recognised = self.description_hints.sql_keys
    for raw_line in content.splitlines():
        text = raw_line.strip()
        if not text:
            continue
        if not text.startswith("--"):
            # First real statement: the comment header is over.
            break
        comment = text.lstrip("-").strip()
        label, separator, remainder = comment.partition(":")
        if separator and label.strip() in recognised:
            return remainder.strip()
    return ""
def _extract_python_description(self, content: str) -> str:
    """Find a description in the module docstring of a Python migration.

    Only the first non-blank docstring line is considered. If it has the form
    ``key: value`` with a key from ``description_hints.python_keys``, the value is
    returned; otherwise the line itself is.

    Args:
        content: Python migration source.

    Returns:
        The description, or an empty string when the source does not parse or has
        no usable module docstring.
    """
    try:
        tree = ast.parse(content)
    except SyntaxError:
        return ""
    doc = ast.get_docstring(tree) or ""
    recognised = self.description_hints.python_keys
    first_line = next((text for text in map(str.strip, doc.splitlines()) if text), "")
    label, separator, remainder = first_line.partition(":")
    if separator and label.strip() in recognised:
        return remainder.strip()
    return first_line
def _get_context_for_migration(self, file_path: Path) -> "MigrationContext | None":
    """Choose the migration context for a file, adding extension config where it applies.

    An extension is matched in two ways: by an ``ext_<name>_`` file name prefix,
    and by the file living directly in one of the registered extension migration
    directories. When the matched extension has an entry in ``extension_configs``
    (and a base context exists), a new context is built from the base context's
    ``dialect``, ``config`` and ``driver``, a copy of its metadata, and that
    extension's config. The directory match is evaluated last and therefore wins.

    Args:
        file_path: Path to the migration file.

    Returns:
        The context to use; the runner's own context (possibly None) when no
        extension-specific context applies.
    """
    base_context = self.context

    def with_extension_config(extension: str) -> "MigrationContext":
        # Only reached when base_context is truthy.
        return MigrationContext(
            dialect=base_context.dialect,
            config=base_context.config,
            driver=base_context.driver,
            metadata=base_context.metadata.copy() if base_context.metadata else {},
            extension_config=self.extension_configs[extension],
        )

    selected = base_context

    # Match by file name: "ext_<name>_<version>...".
    if base_context and file_path.name.startswith("ext_"):
        version = self._extract_version(file_path.name)
        if version and version.startswith("ext_"):
            pieces = version.split("_", 2)
            if len(pieces) >= 3 and pieces[1] in self.extension_configs:
                selected = with_extension_config(pieces[1])

    # Match by directory: only the first extension whose path equals the parent is considered.
    for name, directory in self.extension_migrations.items():
        if file_path.parent != directory:
            continue
        if name in self.extension_configs and base_context:
            selected = with_extension_config(name)
        break
    return selected
def should_use_transaction(
    self, migration: "dict[str, Any]", config: "DatabaseConfigProtocol[Any, Any, Any]"
) -> bool:
    """Decide whether a migration runs inside a transaction.

    Precedence: a database without transactional DDL never uses one; otherwise a
    non-None ``transactional`` value on the migration decides; otherwise the
    ``transactional`` entry of ``config.migration_config`` (default True).

    Args:
        migration: Migration metadata dictionary.
        config: The database configuration instance.

    Returns:
        True if the migration should be wrapped in a transaction.
    """
    if not config.supports_transactional_ddl:
        return False
    per_migration = migration.get("transactional")
    if per_migration is not None:
        return bool(per_migration)
    settings = cast("dict[str, Any]", config.migration_config) or {}
    return bool(settings.get("transactional", True))
[docs]
class SyncMigrationRunner(BaseMigrationRunner):
"""Synchronous migration runner with pure sync methods."""
[docs]
def get_migration_files(self) -> "list[tuple[str, Path]]":
    """Return all migration files ordered by version.

    Delegates to the shared, cached listing logic of the base runner.

    Returns:
        List of (version, path) tuples sorted by version.
    """
    listing = self._load_migration_listing()
    return listing
[docs]
def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
    """Load a migration file and work out which directions it supports.

    For ``.sql`` files the upgrade/downgrade flags come from whether the runner's
    SQL loader knows the ``migrate-{version}-up`` / ``-down`` queries after the
    file was validated. For other files an upgrade is assumed and the downgrade
    flag is probed by trying to obtain the "down" SQL.

    Args:
        file_path: Path to the migration file.
        version: Optional pre-extracted version (preserves prefixes like ext_adk_0001).

    Returns:
        The common metadata dictionary extended with ``has_upgrade``,
        ``has_downgrade`` and ``loader``.
    """
    metadata = self._load_migration_metadata_common(file_path, version)
    migration_context = self._get_context_for_migration(file_path)
    loader = get_migration_loader(
        file_path, self.migrations_path, self.project_root, migration_context, self.loader
    )
    loader.validate_migration_file(file_path)

    if file_path.suffix == ".sql":
        resolved_version = metadata["version"]
        has_upgrade = self.loader.has_query(f"migrate-{resolved_version}-up")
        has_downgrade = self.loader.has_query(f"migrate-{resolved_version}-down")
    else:
        has_upgrade = True
        try:
            # The probe dict deliberately has no has_*grade keys, so the lookup
            # goes straight to the loader.
            probe = {"loader": loader, "file_path": file_path}
            has_downgrade = bool(self._get_migration_sql(probe, "down"))
        except Exception:
            # Best effort: any failure while probing means "no downgrade".
            has_downgrade = False

    metadata["has_upgrade"] = has_upgrade
    metadata["has_downgrade"] = has_downgrade
    metadata["loader"] = loader
    return metadata
[docs]
def execute_upgrade(
    self,
    driver: "SyncDriverAdapterBase",
    migration: "dict[str, Any]",
    *,
    use_transaction: "bool | None" = None,
    on_success: "Callable[[int], None] | None" = None,
) -> "tuple[str | None, int]":
    """Apply a migration's upgrade statements through a sync driver.

    Args:
        driver: The sync database driver to use.
        migration: Migration metadata dictionary.
        use_transaction: Explicit transaction choice. When None it is derived via
            ``should_use_transaction`` from the context's config, or is False when
            no context/config is available.
        on_success: Callback invoked with the elapsed milliseconds once all
            statements ran and, in transactional mode, before the commit.

    Returns:
        ``(None, execution_time_ms)``; the first element is always None here.
        ``(None, 0)`` when there is no upgrade SQL to run.

    Raises:
        Exception: Any error raised while executing is re-raised after rollback
            (transactional mode only), error metrics and logging.
    """
    statements = self._get_migration_sql(migration, "up")
    if statements is None:
        self._metric("migrations.upgrade.skipped")
        self._log_migration_event(
            logging.WARNING,
            "migration.apply",
            db_system=resolve_db_system(type(driver).__name__),
            version=migration.get("version"),
            status="missing",
        )
        return None, 0

    if use_transaction is None:
        config = self.context.config if self.context else None
        use_transaction = self.should_use_transaction(migration, config) if config else False

    runtime = self.runtime
    span = None
    if runtime is not None:
        span = runtime.start_migration_span("upgrade", version=cast("str | None", migration.get("version")))
        runtime.increment_metric("migrations.upgrade.invocations")
    self._log_migration_event(
        logging.INFO,
        "migration.apply",
        db_system=resolve_db_system(type(driver).__name__),
        version=migration.get("version"),
        use_transaction=use_transaction,
        status="start",
    )

    started = time.perf_counter()
    elapsed_ms = 0
    try:
        # Both modes run the same statements; only begin/commit differ.
        if use_transaction:
            driver.begin()
        for statement in statements:
            if statement.strip():
                driver.execute_script(statement)
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        if on_success:
            on_success(elapsed_ms)
        if use_transaction:
            driver.commit()
    except Exception as exc:
        if use_transaction:
            driver.rollback()
        if runtime is not None:
            duration_ms = int((time.perf_counter() - started) * 1000)
            runtime.increment_metric("migrations.upgrade.errors")
            runtime.end_migration_span(span, duration_ms=duration_ms, error=exc)
        self._log_migration_event(
            logging.ERROR,
            "migration.apply",
            db_system=resolve_db_system(type(driver).__name__),
            version=migration.get("version"),
            duration_ms=int((time.perf_counter() - started) * 1000),
            error_type=type(exc).__name__,
            status="failed",
        )
        raise

    if runtime is not None:
        runtime.increment_metric("migrations.upgrade.applied")
        runtime.increment_metric("migrations.upgrade.duration_ms", float(elapsed_ms))
        runtime.end_migration_span(span, duration_ms=elapsed_ms)
    self._log_migration_event(
        logging.INFO,
        "migration.apply",
        db_system=resolve_db_system(type(driver).__name__),
        version=migration.get("version"),
        duration_ms=elapsed_ms,
        status="complete",
    )
    return None, elapsed_ms
[docs]
def execute_downgrade(
    self,
    driver: "SyncDriverAdapterBase",
    migration: "dict[str, Any]",
    *,
    use_transaction: "bool | None" = None,
    on_success: "Callable[[int], None] | None" = None,
) -> "tuple[str | None, int]":
    """Run a migration's downgrade statements through a sync driver.

    Args:
        driver: The sync database driver to use.
        migration: Migration metadata dictionary.
        use_transaction: Explicit transaction choice. When None it is derived via
            ``should_use_transaction`` from the context's config, or is False when
            no context/config is available.
        on_success: Callback invoked with the elapsed milliseconds once all
            statements ran and, in transactional mode, before the commit.

    Returns:
        ``(None, execution_time_ms)``; the first element is always None here.
        ``(None, 0)`` when there is no downgrade SQL to run.

    Raises:
        Exception: Any error raised while executing is re-raised after rollback
            (transactional mode only), error metrics and logging.
    """
    statements = self._get_migration_sql(migration, "down")
    if statements is None:
        self._metric("migrations.downgrade.skipped")
        self._log_migration_event(
            logging.WARNING,
            "migration.rollback",
            db_system=resolve_db_system(type(driver).__name__),
            version=migration.get("version"),
            status="missing",
        )
        return None, 0

    if use_transaction is None:
        config = self.context.config if self.context else None
        use_transaction = self.should_use_transaction(migration, config) if config else False

    runtime = self.runtime
    span = None
    if runtime is not None:
        span = runtime.start_migration_span("downgrade", version=cast("str | None", migration.get("version")))
        runtime.increment_metric("migrations.downgrade.invocations")
    self._log_migration_event(
        logging.INFO,
        "migration.rollback",
        db_system=resolve_db_system(type(driver).__name__),
        version=migration.get("version"),
        use_transaction=use_transaction,
        status="start",
    )

    started = time.perf_counter()
    elapsed_ms = 0
    try:
        # Both modes run the same statements; only begin/commit differ.
        if use_transaction:
            driver.begin()
        for statement in statements:
            if statement.strip():
                driver.execute_script(statement)
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        if on_success:
            on_success(elapsed_ms)
        if use_transaction:
            driver.commit()
    except Exception as exc:
        if use_transaction:
            driver.rollback()
        if runtime is not None:
            duration_ms = int((time.perf_counter() - started) * 1000)
            runtime.increment_metric("migrations.downgrade.errors")
            runtime.end_migration_span(span, duration_ms=duration_ms, error=exc)
        self._log_migration_event(
            logging.ERROR,
            "migration.rollback",
            db_system=resolve_db_system(type(driver).__name__),
            version=migration.get("version"),
            duration_ms=int((time.perf_counter() - started) * 1000),
            error_type=type(exc).__name__,
            status="failed",
        )
        raise

    if runtime is not None:
        runtime.increment_metric("migrations.downgrade.applied")
        runtime.increment_metric("migrations.downgrade.duration_ms", float(elapsed_ms))
        runtime.end_migration_span(span, duration_ms=elapsed_ms)
    self._log_migration_event(
        logging.INFO,
        "migration.rollback",
        db_system=resolve_db_system(type(driver).__name__),
        version=migration.get("version"),
        duration_ms=elapsed_ms,
        status="complete",
    )
    return None, elapsed_ms
def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None":
    """Fetch the SQL statements of a migration for one direction (sync version).

    A missing or failing downgrade is tolerated: it is logged as a warning and
    reported as None. A missing or failing upgrade raises.

    Args:
        migration: Migration metadata; needs ``file_path`` and ``loader`` and may
            carry ``has_upgrade`` / ``has_downgrade`` flags.
        direction: Either 'up' or 'down'.

    Returns:
        The SQL statements, or None when there are none for this direction.

    Raises:
        ValueError: If the upgrade is flagged as absent or loading it fails.
    """
    flag = f"has_{direction}grade"
    # During migration loading the flag is not set yet; in that case skip this
    # check and simply ask the loader.
    if flag in migration and not migration.get(flag):
        if direction != "down":
            msg = f"Migration {migration.get('version')} has no upgrade query"
            raise ValueError(msg)
        self._log_migration_event(
            logging.WARNING, "migration.downgrade.missing", version=migration.get("version")
        )
        return None

    file_path, loader = migration["file_path"], migration["loader"]
    try:
        method = loader.get_up_sql if direction == "up" else loader.get_down_sql
        if inspect.iscoroutinefunction(method):
            # Loader methods may be coroutine functions; run them to completion here.
            sql_statements = await_(method, raise_sync_error=False)(file_path)
        else:
            sql_statements = method(file_path)
    except Exception as e:
        if direction == "down":
            self._log_migration_event(
                logging.WARNING, "migration.downgrade.load_failed", version=migration.get("version"), error=str(e)
            )
            return None
        msg = f"Failed to load upgrade for migration {migration.get('version')}: {e}"
        raise ValueError(msg) from e

    return cast("list[str]", sql_statements) if sql_statements else None
[docs]
def load_all_migrations(self) -> "dict[str, SQL]":
    """Collect the queries of every migration into one name-keyed mapping.

    SQL files are loaded through the runner's SQL loader and every query the
    loader currently lists is copied into the result. For other files the
    first up/down statement is wrapped in ``SQL`` under ``migrate-{version}-up``
    / ``-down``; failures there are logged at DEBUG and skipped.

    Returns:
        Dictionary mapping query names to SQL objects.
    """
    all_queries = {}
    for version, file_path in self.get_migration_files():
        if file_path.suffix == ".sql":
            self.loader.load_sql(file_path)
            all_queries.update((name, self.loader.get_sql(name)) for name in self.loader.list_queries())
            continue

        # NOTE(review): uses self.context directly, whereas load_migration goes
        # through _get_context_for_migration - confirm extension configs are not
        # needed for bulk loading.
        loader = get_migration_loader(
            file_path, self.migrations_path, self.project_root, self.context, self.loader
        )
        try:
            up_sql = await_(loader.get_up_sql, raise_sync_error=False)(file_path)
            down_sql = await_(loader.get_down_sql, raise_sync_error=False)(file_path)
            # NOTE(review): only the first statement of each direction is kept.
            if up_sql:
                all_queries[f"migrate-{version}-up"] = SQL(up_sql[0])
            if down_sql:
                all_queries[f"migrate-{version}-down"] = SQL(down_sql[0])
        except Exception as e:
            self._log_migration_event(
                logging.DEBUG, "migration.python.load_failed", file_path=str(file_path), error=str(e)
            )
    return all_queries
[docs]
class AsyncMigrationRunner(BaseMigrationRunner):
"""Asynchronous migration runner with pure async methods."""
[docs]
async def get_migration_files(self) -> "list[tuple[str, Path]]":
    """Return all migration files ordered by version.

    The shared synchronous listing logic is wrapped with ``async_`` and awaited.

    Returns:
        List of (version, path) tuples sorted by version.
    """
    load_listing = async_(self._load_migration_listing)
    return await load_listing()
[docs]
async def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
    """Load a migration file and work out which directions it supports.

    For ``.sql`` files the upgrade/downgrade flags come from whether the runner's
    SQL loader knows the ``migrate-{version}-up`` / ``-down`` queries after the
    file was validated. For other files an upgrade is assumed and the downgrade
    flag is probed by trying to obtain the "down" SQL.

    Args:
        file_path: Path to the migration file.
        version: Optional pre-extracted version (preserves prefixes like ext_adk_0001).

    Returns:
        The common metadata dictionary extended with ``has_upgrade``,
        ``has_downgrade`` and ``loader``.
    """
    metadata = self._load_migration_metadata_common(file_path, version)
    migration_context = self._get_context_for_migration(file_path)
    loader = get_migration_loader(
        file_path, self.migrations_path, self.project_root, migration_context, self.loader
    )
    loader.validate_migration_file(file_path)

    if file_path.suffix == ".sql":
        resolved_version = metadata["version"]
        has_upgrade = self.loader.has_query(f"migrate-{resolved_version}-up")
        has_downgrade = self.loader.has_query(f"migrate-{resolved_version}-down")
    else:
        has_upgrade = True
        try:
            # The probe dict deliberately has no has_*grade keys, so the lookup
            # goes straight to the loader.
            probe = {"loader": loader, "file_path": file_path}
            has_downgrade = bool(await self._get_migration_sql_async(probe, "down"))
        except Exception:
            # Best effort: any failure while probing means "no downgrade".
            has_downgrade = False

    metadata["has_upgrade"] = has_upgrade
    metadata["has_downgrade"] = has_downgrade
    metadata["loader"] = loader
    return metadata
[docs]
async def execute_upgrade(
    self,
    driver: "AsyncDriverAdapterBase",
    migration: "dict[str, Any]",
    *,
    use_transaction: "bool | None" = None,
    on_success: "Callable[[int], Awaitable[None]] | None" = None,
) -> "tuple[str | None, int]":
    """Apply a migration's upgrade statements through an async driver.

    Args:
        driver: The async database driver to use.
        migration: Migration metadata dictionary.
        use_transaction: Explicit transaction choice. When None it is derived via
            ``should_use_transaction`` from the context's config, or is False when
            no context/config is available.
        on_success: Async callback awaited with the elapsed milliseconds once all
            statements ran and, in transactional mode, before the commit.

    Returns:
        ``(None, execution_time_ms)``; the first element is always None here.
        ``(None, 0)`` when there is no upgrade SQL to run.

    Raises:
        Exception: Any error raised while executing is re-raised after rollback
            (transactional mode only), error metrics and logging.
    """
    statements = await self._get_migration_sql_async(migration, "up")
    if statements is None:
        self._metric("migrations.upgrade.skipped")
        self._log_migration_event(
            logging.WARNING,
            "migration.apply",
            db_system=resolve_db_system(type(driver).__name__),
            version=migration.get("version"),
            status="missing",
        )
        return None, 0

    if use_transaction is None:
        config = self.context.config if self.context else None
        use_transaction = self.should_use_transaction(migration, config) if config else False

    runtime = self.runtime
    span = None
    if runtime is not None:
        span = runtime.start_migration_span("upgrade", version=cast("str | None", migration.get("version")))
        runtime.increment_metric("migrations.upgrade.invocations")
    self._log_migration_event(
        logging.INFO,
        "migration.apply",
        db_system=resolve_db_system(type(driver).__name__),
        version=migration.get("version"),
        use_transaction=use_transaction,
        status="start",
    )

    started = time.perf_counter()
    elapsed_ms = 0
    try:
        # Both modes run the same statements; only begin/commit differ.
        if use_transaction:
            await driver.begin()
        for statement in statements:
            if statement.strip():
                await driver.execute_script(statement)
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        if on_success:
            await on_success(elapsed_ms)
        if use_transaction:
            await driver.commit()
    except Exception as exc:
        if use_transaction:
            await driver.rollback()
        if runtime is not None:
            duration_ms = int((time.perf_counter() - started) * 1000)
            runtime.increment_metric("migrations.upgrade.errors")
            runtime.end_migration_span(span, duration_ms=duration_ms, error=exc)
        self._log_migration_event(
            logging.ERROR,
            "migration.apply",
            db_system=resolve_db_system(type(driver).__name__),
            version=migration.get("version"),
            duration_ms=int((time.perf_counter() - started) * 1000),
            error_type=type(exc).__name__,
            status="failed",
        )
        raise

    if runtime is not None:
        runtime.increment_metric("migrations.upgrade.applied")
        runtime.increment_metric("migrations.upgrade.duration_ms", float(elapsed_ms))
        runtime.end_migration_span(span, duration_ms=elapsed_ms)
    self._log_migration_event(
        logging.INFO,
        "migration.apply",
        db_system=resolve_db_system(type(driver).__name__),
        version=migration.get("version"),
        duration_ms=elapsed_ms,
        status="complete",
    )
    return None, elapsed_ms
[docs]
async def execute_downgrade(
    self,
    driver: "AsyncDriverAdapterBase",
    migration: "dict[str, Any]",
    *,
    use_transaction: "bool | None" = None,
    on_success: "Callable[[int], Awaitable[None]] | None" = None,
) -> "tuple[str | None, int]":
    """Run a migration's downgrade statements through an async driver.

    Args:
        driver: The async database driver to use.
        migration: Migration metadata dictionary.
        use_transaction: Explicit transaction choice. When None it is derived via
            ``should_use_transaction`` from the context's config, or is False when
            no context/config is available.
        on_success: Async callback awaited with the elapsed milliseconds once all
            statements ran and, in transactional mode, before the commit.

    Returns:
        ``(None, execution_time_ms)``; the first element is always None here.
        ``(None, 0)`` when there is no downgrade SQL to run.

    Raises:
        Exception: Any error raised while executing is re-raised after rollback
            (transactional mode only), error metrics and logging.
    """
    statements = await self._get_migration_sql_async(migration, "down")
    if statements is None:
        self._metric("migrations.downgrade.skipped")
        self._log_migration_event(
            logging.WARNING,
            "migration.rollback",
            db_system=resolve_db_system(type(driver).__name__),
            version=migration.get("version"),
            status="missing",
        )
        return None, 0

    if use_transaction is None:
        config = self.context.config if self.context else None
        use_transaction = self.should_use_transaction(migration, config) if config else False

    runtime = self.runtime
    span = None
    if runtime is not None:
        span = runtime.start_migration_span("downgrade", version=cast("str | None", migration.get("version")))
        runtime.increment_metric("migrations.downgrade.invocations")
    self._log_migration_event(
        logging.INFO,
        "migration.rollback",
        db_system=resolve_db_system(type(driver).__name__),
        version=migration.get("version"),
        use_transaction=use_transaction,
        status="start",
    )

    started = time.perf_counter()
    elapsed_ms = 0
    try:
        # Both modes run the same statements; only begin/commit differ.
        if use_transaction:
            await driver.begin()
        for statement in statements:
            if statement.strip():
                await driver.execute_script(statement)
        elapsed_ms = int((time.perf_counter() - started) * 1000)
        if on_success:
            await on_success(elapsed_ms)
        if use_transaction:
            await driver.commit()
    except Exception as exc:
        if use_transaction:
            await driver.rollback()
        if runtime is not None:
            duration_ms = int((time.perf_counter() - started) * 1000)
            runtime.increment_metric("migrations.downgrade.errors")
            runtime.end_migration_span(span, duration_ms=duration_ms, error=exc)
        self._log_migration_event(
            logging.ERROR,
            "migration.rollback",
            db_system=resolve_db_system(type(driver).__name__),
            version=migration.get("version"),
            duration_ms=int((time.perf_counter() - started) * 1000),
            error_type=type(exc).__name__,
            status="failed",
        )
        raise

    if runtime is not None:
        runtime.increment_metric("migrations.downgrade.applied")
        runtime.increment_metric("migrations.downgrade.duration_ms", float(elapsed_ms))
        runtime.end_migration_span(span, duration_ms=elapsed_ms)
    self._log_migration_event(
        logging.INFO,
        "migration.rollback",
        db_system=resolve_db_system(type(driver).__name__),
        version=migration.get("version"),
        duration_ms=elapsed_ms,
        status="complete",
    )
    return None, elapsed_ms
async def _get_migration_sql_async(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None":
    """Resolve the SQL statements of one migration for one direction (async version).

    Args:
        migration: Migration metadata. ``file_path`` and ``loader`` must be present;
            ``version`` and the ``has_upgrade``/``has_downgrade`` flags are read when available.
        direction: Either 'up' or 'down'.

    Returns:
        The statements produced by the loader, or ``None`` when the loader yields
        nothing, or when a downgrade is flagged as missing or fails to load.

    Raises:
        ValueError: For any direction other than 'down', when the migration is
            flagged as lacking that direction or the loader raises while loading it.
    """
    flag_key = f"has_{direction}grade"
    is_downgrade = direction == "down"

    # Metadata that is still being loaded carries no has_*grade flag yet; in that
    # case skip the availability check and go straight to the loader.
    if flag_key in migration and not migration.get(flag_key):
        if not is_downgrade:
            msg = f"Migration {migration.get('version')} has no upgrade query"
            raise ValueError(msg)
        # A missing downgrade is tolerated: warn and report "nothing to run".
        self._log_migration_event(logging.WARNING, "migration.downgrade.missing", version=migration.get("version"))
        return None

    file_path = migration["file_path"]
    loader = migration["loader"]
    try:
        fetch_sql = loader.get_up_sql if direction == "up" else loader.get_down_sql
        statements = await fetch_sql(file_path)
    except Exception as e:
        if not is_downgrade:
            msg = f"Failed to load upgrade for migration {migration.get('version')}: {e}"
            raise ValueError(msg) from e
        # Downgrade load failures are downgraded to a warning rather than an error.
        self._log_migration_event(
            logging.WARNING, "migration.downgrade.load_failed", version=migration.get("version"), error=str(e)
        )
        return None

    # An empty result from the loader is normalised to None.
    return cast("list[str]", statements) if statements else None
[docs]
async def load_all_migrations(self) -> "dict[str, SQL]":
    """Collect the queries of every discovered migration into a single mapping.

    ``.sql`` files are loaded through the runner's shared loader and contribute
    every query that loader lists at that point. Any other file is resolved via
    ``get_migration_loader`` and contributes at most a ``migrate-<version>-up``
    and a ``migrate-<version>-down`` entry.

    Returns:
        Dictionary mapping query names to SQL objects.
    """
    collected: "dict[str, SQL]" = {}
    for version, file_path in await self.get_migration_files():
        if file_path.suffix != ".sql":
            file_loader = get_migration_loader(
                file_path, self.migrations_path, self.project_root, self.context, self.loader
            )
            try:
                upgrade_sql = await file_loader.get_up_sql(file_path)
                downgrade_sql = await file_loader.get_down_sql(file_path)
                # Only the first statement of each direction is registered.
                if upgrade_sql:
                    collected[f"migrate-{version}-up"] = SQL(upgrade_sql[0])
                if downgrade_sql:
                    collected[f"migrate-{version}-down"] = SQL(downgrade_sql[0])
            except Exception as e:
                # Best effort: a migration that cannot be loaded is skipped and
                # only reported at DEBUG level.
                self._log_migration_event(
                    logging.DEBUG, "migration.python.load_failed", file_path=str(file_path), error=str(e)
                )
            continue

        # SQL file: load it into the shared loader, then copy over every query
        # the loader currently knows about.
        await async_(self.loader.load_sql)(file_path)
        collected.update((name, self.loader.get_sql(name)) for name in self.loader.list_queries())
    return collected
# Typing-only overload: when ``is_async`` is omitted or passed as ``False``,
# the factory is declared to return the synchronous runner.
@overload
def create_migration_runner(
    migrations_path: Path,
    extension_migrations: "dict[str, Path]",
    context: "MigrationContext | None",
    extension_configs: "dict[str, Any]",
    is_async: "Literal[False]" = False,
    runtime: "ObservabilityRuntime | None" = None,
    description_hints: "TemplateDescriptionHints | None" = None,
) -> SyncMigrationRunner: ...
# Typing-only overload: when ``is_async`` is ``True`` (no default here, so it
# must be supplied), the factory is declared to return the asynchronous runner.
@overload
def create_migration_runner(
    migrations_path: Path,
    extension_migrations: "dict[str, Path]",
    context: "MigrationContext | None",
    extension_configs: "dict[str, Any]",
    is_async: "Literal[True]",
    runtime: "ObservabilityRuntime | None" = None,
    description_hints: "TemplateDescriptionHints | None" = None,
) -> AsyncMigrationRunner: ...
[docs]
def create_migration_runner(
    migrations_path: Path,
    extension_migrations: "dict[str, Path]",
    context: "MigrationContext | None",
    extension_configs: "dict[str, Any]",
    is_async: bool = False,
    runtime: "ObservabilityRuntime | None" = None,
    description_hints: "TemplateDescriptionHints | None" = None,
) -> "SyncMigrationRunner | AsyncMigrationRunner":
    """Build the sync or async migration runner for the given configuration.

    Both runner classes are constructed with the same arguments; ``is_async``
    only decides which class is instantiated.

    Args:
        migrations_path: Path to migrations directory.
        extension_migrations: Extension migration paths.
        context: Migration context.
        extension_configs: Extension configurations.
        is_async: Whether to create async or sync runner.
        runtime: Observability runtime, forwarded to the runner as a keyword argument.
        description_hints: Optional description extraction hints, forwarded to the runner
            as a keyword argument.

    Returns:
        An ``AsyncMigrationRunner`` when ``is_async`` is truthy, otherwise a
        ``SyncMigrationRunner``.
    """
    runner_cls = AsyncMigrationRunner if is_async else SyncMigrationRunner
    return runner_cls(
        migrations_path,
        extension_migrations,
        context,
        extension_configs,
        runtime=runtime,
        description_hints=description_hints,
    )