# Source code for sqlspec.migrations.utils

"""Utility functions for SQLSpec migrations."""

import importlib
import inspect
import os
import subprocess
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

from sqlspec.migrations.templates import MigrationTemplateSettings, TemplateValidationError, build_template_settings
from sqlspec.utils.logging import get_logger
from sqlspec.utils.text import slugify

if TYPE_CHECKING:
    from collections.abc import Callable

    from sqlspec.config import DatabaseConfigProtocol
    from sqlspec.driver import AsyncDriverAdapterBase

# Public names exported by `from sqlspec.migrations.utils import *`.
__all__ = ("create_migration_file", "drop_all", "get_author")

# Module logger obtained from sqlspec's get_logger; _get_git_config uses it to
# record git lookup failures at debug level.
logger = get_logger(__name__)


def create_migration_file(
    migrations_dir: Path,
    version: str,
    message: str,
    file_type: str | None = None,
    *,
    config: "DatabaseConfigProtocol[Any, Any, Any] | None" = None,
    template_settings: "MigrationTemplateSettings | None" = None,
) -> Path:
    """Render a migration template and write it into ``migrations_dir``.

    The file is named ``{version}_{slug}.{ext}``. Here ``slug`` is the slugified
    ``message``, or ``migration`` when the message slugifies to nothing. ``ext`` is
    ``py`` when the resolved template format is ``"py"`` and ``sql`` otherwise.

    Args:
        migrations_dir: Directory the new file is written into. It must already
            exist, because ``Path.write_text`` does not create parent directories.
        version: Version prefix used in the filename and the template context.
        message: Human-readable description of the migration.
        file_type: Requested format, resolved through ``settings.resolve_format``.
        config: Optional database config. It supplies ``migration_config`` (used
            for the author and template settings), the adapter name and the
            project slug.
        template_settings: Pre-built template settings. When omitted (or falsy),
            they are built from the config's ``migration_config``.

    Returns:
        Path of the written file. An existing file at that path is overwritten.
    """
    migration_config: dict[str, Any] = (
        {} if config is None else cast("dict[str, Any]", config.migration_config)
    )
    settings = template_settings if template_settings else build_template_settings(migration_config)
    author = get_author(migration_config.get("author"), config=config)
    slug = _slugify_message(message)
    is_python = settings.resolve_format(file_type) == "py"

    extension = "py" if is_python else "sql"
    target = migrations_dir / f"{version}_{slug or 'migration'}.{extension}"

    template_context = _build_template_context(
        settings=settings,
        version=version,
        message=message,
        author=author,
        adapter=_resolve_adapter_name(config),
        project_slug=_derive_project_slug(config),
        safe_message=slug,
    )
    template = settings.profile.python if is_python else settings.profile.sql
    target.write_text(template.render(template_context), encoding="utf-8")
    return target
[docs] def get_author( author_config: Any | None = None, *, config: "DatabaseConfigProtocol[Any, Any, Any] | None" = None ) -> str: """Resolve author metadata for migration templates.""" if isinstance(author_config, str): token = author_config.strip() if not token: return _resolve_git_author() lowered = token.lower() if lowered == "git": return _resolve_git_author() if lowered == "system": return _get_system_username() if lowered.startswith("env:"): env_var = token.split(":", 1)[1].strip() if not env_var: msg = "Environment author token requires a variable name" raise TemplateValidationError(msg) return _resolve_author_from_env(env_var) if lowered.startswith("callable:"): import_path = token.split(":", 1)[1].strip() if not import_path: msg = "Callable author token requires an import path" raise TemplateValidationError(msg) return _resolve_author_callable(import_path, config) if ":" in token and " " not in token: return _resolve_author_callable(token, config) return token if isinstance(author_config, dict): mode = str(author_config.get("mode") or "static").lower() value = author_config.get("value") if mode == "static": if not isinstance(value, str) or not value.strip(): msg = "Static author value must be a non-empty string" raise TemplateValidationError(msg) return value.strip() if mode == "env": if not isinstance(value, str) or not value.strip(): msg = "Environment author mode requires an environment variable name" raise TemplateValidationError(msg) return _resolve_author_from_env(value.strip()) if mode == "callable": if not isinstance(value, str) or not value.strip(): msg = "Callable author mode requires an import path" raise TemplateValidationError(msg) return _resolve_author_callable(value.strip(), config) if mode == "system": return _get_system_username() if mode == "git": return _resolve_git_author() msg = f"Unsupported author mode '{mode}'" raise TemplateValidationError(msg) return _resolve_git_author()
def _get_git_config(config_key: str) -> str | None: """Retrieve git configuration value. Args: config_key: Git config key (e.g., 'user.name', 'user.email'). Returns: Configuration value if found, None otherwise. """ try: result = subprocess.run( # noqa: S603 ["git", "config", config_key], # noqa: S607 capture_output=True, text=True, timeout=2, check=False, ) if result.returncode == 0 and result.stdout.strip(): return result.stdout.strip() except (subprocess.SubprocessError, FileNotFoundError, OSError) as e: logger.debug("Failed to get git config %s: %s", config_key, e) return None def _get_system_username() -> str: """Get system username from environment. Returns: Username from USER environment variable, or 'unknown' if not set. """ return os.environ.get("USER", "unknown") def _resolve_git_author() -> str: git_name = _get_git_config("user.name") git_email = _get_git_config("user.email") if git_name and git_email: return f"{git_name} <{git_email}>" return _get_system_username() def _resolve_author_from_env(env_var: str) -> str: value = os.environ.get(env_var) if value: return value.strip() msg = f"Environment variable '{env_var}' is not set for migration author" raise TemplateValidationError(msg) def _resolve_author_callable(import_path: str, config: "DatabaseConfigProtocol[Any, Any, Any] | None") -> str: def _raise_callable_error(message: str) -> None: msg = message raise TemplateValidationError(msg) module_name, _, attr_name = import_path.partition(":") if not module_name or not attr_name: _raise_callable_error("Callable author path must be in 'module:function' format") module = importlib.import_module(module_name) candidate_obj = module.__dict__.get(attr_name) if candidate_obj is None or not callable(candidate_obj): _raise_callable_error(f"Callable '{import_path}' is not callable") candidate = cast("Callable[..., Any]", candidate_obj) signature = inspect.signature(candidate) param_count = len(signature.parameters) if param_count > 1: _raise_callable_error("Author 
callable must accept zero or one positional argument") try: result_value: object = candidate() if param_count == 0 else candidate(config) except Exception as exc: # pragma: no cover - passthrough msg = f"Author callable '{import_path}' raised an error: {exc}" raise TemplateValidationError(msg) from exc result_str: str = str(result_value) return result_str def _build_template_context( *, settings: "MigrationTemplateSettings", version: str, message: str, author: str, adapter: str, project_slug: str, safe_message: str, ) -> "dict[str, str]": created_at = datetime.now(timezone.utc).isoformat() display_message = message or "New migration" description = display_message.strip() or safe_message or version return { "title": settings.profile.title, "version": version, "message": display_message, "description": description, "created_at": created_at, "author": author, "adapter": adapter, "project_slug": project_slug, "slug": safe_message, } def _derive_project_slug(config: "DatabaseConfigProtocol[Any, Any, Any] | None") -> str: if config and config.bind_key: source = config.bind_key elif config: source = config.__class__.__module__.split(".")[0] else: source = Path.cwd().name return _slugify_message(source) def _resolve_adapter_name(config: "DatabaseConfigProtocol[Any, Any, Any] | None") -> str: if config is None: return "UnknownAdapter" driver_type = config.driver_type if driver_type is not None: return str(driver_type.__name__) return type(config).__name__ def _slugify_message(message: str) -> str: slug = slugify(message or "", separator="_") return slug[:50]
[docs] async def drop_all(engine: "AsyncDriverAdapterBase", version_table_name: str, metadata: Any | None = None) -> None: """Drop all tables from the database. Args: engine: The database engine/driver. version_table_name: Name of the version tracking table. metadata: Optional metadata object. Raises: NotImplementedError: Always raised. """ msg = "drop_all functionality requires database-specific implementation" raise NotImplementedError(msg)