"""Migration context for passing runtime information to migrations."""
import asyncio
import inspect
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from sqlglot.dialects.dialect import Dialect
from sqlspec.protocols import HasStatementConfigProtocol
from sqlspec.utils.logging import get_logger
from sqlspec.utils.type_guards import has_statement_config_factory
if TYPE_CHECKING:
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
logger = get_logger("sqlspec.migrations.context")
__all__ = ("MigrationContext",)
def _normalize_dialect_name(dialect: Any | None) -> "str | None":
if dialect is None:
return None
if isinstance(dialect, str):
return dialect
if isinstance(dialect, type):
return dialect.__name__
if isinstance(dialect, Dialect):
return dialect.__class__.__name__
return None
[docs]
@dataclass
class MigrationContext:
"""Context object passed to migration functions.
Provides runtime information about the database environment
to migration functions, allowing them to generate dialect-specific SQL.
"""
config: "Any | None" = None
"""Database configuration object."""
dialect: "str | None" = None
"""Database dialect (e.g., 'postgres', 'mysql', 'sqlite')."""
metadata: "dict[str, Any] | None" = None
"""Additional metadata for the migration."""
extension_config: "dict[str, Any] | None" = None
"""Extension-specific configuration options."""
driver: "SyncDriverAdapterBase | AsyncDriverAdapterBase | None" = None
"""Database driver instance (available during execution)."""
_execution_metadata: "dict[str, Any]" = field(default_factory=dict)
"""Internal execution metadata for tracking async operations."""
[docs]
def __post_init__(self) -> None:
"""Initialize metadata and extension config if not provided."""
if not self.metadata:
self.metadata = {}
if not self.extension_config:
self.extension_config = {}
[docs]
@classmethod
def from_config(cls, config: Any) -> "MigrationContext":
"""Create context from database configuration.
Args:
config: Database configuration object.
Returns:
Migration context with dialect information.
"""
dialect: Any | None = None
try:
if isinstance(config, HasStatementConfigProtocol) and config.statement_config:
dialect = config.statement_config.dialect
elif has_statement_config_factory(config):
stmt_config = config._create_statement_config() # pyright: ignore[reportPrivateUsage]
dialect = stmt_config.dialect
except Exception:
logger.debug("Unable to extract dialect from config")
return cls(dialect=_normalize_dialect_name(dialect), config=config)
@property
def is_async_execution(self) -> bool:
"""Check if migrations are running in an async execution context.
Returns:
True if executing in an async context.
"""
try:
asyncio.current_task()
except RuntimeError:
return False
else:
return True
@property
def is_async_driver(self) -> bool:
"""Check if the current driver is async.
Returns:
True if driver supports async operations.
"""
if self.driver is None:
return False
execute_method = self.driver.execute_script
return inspect.iscoroutinefunction(execute_method)
@property
def execution_mode(self) -> str:
"""Get the current execution mode.
Returns:
'async' if in async context, 'sync' otherwise.
"""
return "async" if self.is_async_execution else "sync"
[docs]
def set_execution_metadata(self, key: str, value: Any) -> None:
"""Set execution metadata for tracking migration state.
Args:
key: Metadata key.
value: Metadata value.
"""
self._execution_metadata[key] = value
[docs]
def get_execution_metadata(self, key: str, default: Any = None) -> Any:
"""Get execution metadata.
Args:
key: Metadata key.
default: Default value if key not found.
Returns:
Metadata value or default.
"""
return self._execution_metadata.get(key, default)
[docs]
def validate_async_usage(self, migration_func: Any) -> None:
"""Validate proper usage of async functions in migration context.
Args:
migration_func: The migration function to validate.
"""
if inspect.iscoroutinefunction(migration_func) and not self.is_async_execution and not self.is_async_driver:
msg = (
"Async migration function detected but execution context is sync. "
"Consider using async database configuration or sync migration functions."
)
logger.warning(msg)
if not inspect.iscoroutinefunction(migration_func) and self.is_async_driver:
self.set_execution_metadata("mixed_execution", value=True)
logger.debug("Sync migration function in async driver context - using compatibility mode")