"""Migration template rendering and configuration utilities."""
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from sqlspec.exceptions import SQLSpecError
if TYPE_CHECKING:
from collections.abc import Mapping
# Names exported by this module; this is what ``from <module> import *`` exposes.
# The private builders/helpers and the ``_DEFAULT_*`` constants are left out.
__all__ = (
    "MigrationTemplateProfile",
    "MigrationTemplateSettings",
    "PythonTemplateDefinition",
    "SQLTemplateDefinition",
    "TemplateDescriptionHints",
    "TemplateValidationError",
    "build_template_settings",
)
class TemplateValidationError(SQLSpecError):
    """Raised when a migration template definition is invalid.

    Within this module it is raised when a template fragment cannot be
    formatted with the supplied context and when a configured template
    override has an unsupported type.
    """
@dataclass(slots=True)
class TemplateDescriptionHints:
"""Hints for extracting descriptions from rendered templates."""
sql_keys: "tuple[str, ...]" = ("Description",)
python_keys: "tuple[str, ...]" = ("Description",)
[docs]
@dataclass(slots=True)
class SQLTemplateDefinition:
"""SQL migration template fragments."""
header: str
metadata: "list[str]" = field(default_factory=list)
body: str = ""
description_keys: "tuple[str, ...]" = ("Description",)
[docs]
def render(self, context: "Mapping[str, str]") -> str:
"""Render the SQL template using the supplied context."""
rendered_lines: list[str] = [self._format(self.header, context)]
rendered_lines.extend(self._format(line, context) for line in self.metadata if line)
rendered_lines.append("")
rendered_lines.append(self._format(self.body, context))
return "\n".join(_normalize_newlines(rendered_lines)).rstrip() + "\n"
def _format(self, template: str, context: "Mapping[str, str]") -> str:
try:
return template.format_map(context)
except KeyError as exc: # pragma: no cover - defensive
missing = str(exc).strip("'")
msg = f"Missing template variable '{missing}' in SQL template"
raise TemplateValidationError(msg) from exc
except ValueError as exc: # pragma: no cover - defensive
msg = f"Invalid SQL template fragment: {exc}"
raise TemplateValidationError(msg) from exc
[docs]
@dataclass(slots=True)
class PythonTemplateDefinition:
"""Python migration template fragments."""
docstring: str
body: str
imports: "list[str]" = field(default_factory=list)
description_keys: "tuple[str, ...]" = ("Description",)
[docs]
def render(self, context: "Mapping[str, str]") -> str:
"""Render the Python template using the supplied context."""
docstring_block = f'"""{self._format(self.docstring, context)}"""'
rendered_lines: list[str] = [docstring_block, ""]
rendered_lines.extend(self.imports)
if self.imports:
rendered_lines.append("")
rendered_lines.append(self._format(self.body, context))
return "\n".join(_normalize_newlines(rendered_lines)).rstrip() + "\n"
def _format(self, template: str, context: "Mapping[str, str]") -> str:
try:
return template.format_map(context)
except KeyError as exc: # pragma: no cover - defensive
missing = str(exc).strip("'")
msg = f"Missing template variable '{missing}' in Python template"
raise TemplateValidationError(msg) from exc
except ValueError as exc: # pragma: no cover - defensive
msg = f"Invalid Python template fragment: {exc}"
raise TemplateValidationError(msg) from exc
[docs]
@dataclass(slots=True)
class MigrationTemplateProfile:
"""Concrete template profile selected via configuration."""
name: str
title: str
sql: "SQLTemplateDefinition"
python: "PythonTemplateDefinition"
[docs]
@dataclass(slots=True)
class MigrationTemplateSettings:
"""Resolved template configuration for a migration command context."""
default_format: str
profile: "MigrationTemplateProfile"
@property
def description_hints(self) -> "TemplateDescriptionHints":
"""Expose description extraction hints derived from the active profile."""
return TemplateDescriptionHints(
sql_keys=self.profile.sql.description_keys, python_keys=self.profile.python.description_keys
)
def build_template_settings(migration_config: dict[str, Any] | None) -> "MigrationTemplateSettings":
    """Build template settings from migration configuration.

    Args:
        migration_config: Migration configuration. ``None`` or an empty
            mapping selects the built-in defaults. The keys read are
            ``"templates"``, ``"default_format"`` and ``"title"``.

    Returns:
        Settings holding the resolved default format and a profile named
        ``"default"`` built from the configured (or built-in) templates.

    Raises:
        TemplateValidationError: If ``"templates"`` or one of its ``"sql"`` /
            ``"py"`` entries is not a mapping, or a list override has an
            unsupported type.
    """
    config = migration_config or {}
    templates_config = config.get("templates") or {}
    if not isinstance(templates_config, dict):
        # Reject a truthy non-mapping here, as the per-format builders do,
        # instead of failing later with a raw AttributeError on ``.get``.
        msg = "Template configuration must be a mapping"
        raise TemplateValidationError(msg)

    default_format = str(config.get("default_format") or "sql").lower()
    if default_format not in {"sql", "py"}:
        # Unrecognised formats fall back to SQL instead of raising.
        default_format = "sql"

    # The top-level title wins over the one nested under "templates".
    title = str(config.get("title") or templates_config.get("title") or _DEFAULT_TITLE)
    sql_definition = _build_sql_definition(templates_config.get("sql"))
    python_definition = _build_python_definition(templates_config.get("py"))
    profile = MigrationTemplateProfile(name="default", title=title, sql=sql_definition, python=python_definition)
    return MigrationTemplateSettings(default_format=default_format, profile=profile)
def _build_sql_definition(overrides: Any) -> "SQLTemplateDefinition":
    """Build the SQL template definition, applying optional overrides.

    Args:
        overrides: ``None`` for the defaults, or a dict that may contain
            ``"header"``, ``"metadata"``, ``"body"`` and ``"description_key"``.
            Missing or falsy entries fall back to the default template.

    Returns:
        A new definition. The module-level default instance is never returned
        directly, so mutating the result cannot alter the defaults.

    Raises:
        TemplateValidationError: If ``overrides`` is neither ``None`` nor a
            dict, or a list-valued entry has an unsupported type.
    """
    if overrides is None:
        # Treat "no overrides" as an empty mapping so the default values are
        # copied into a fresh instance rather than handing out the shared one.
        overrides = {}
    elif not isinstance(overrides, dict):
        msg = "SQL template override must be a mapping"
        raise TemplateValidationError(msg)

    defaults = _DEFAULT_SQL_TEMPLATE
    header = str(overrides.get("header") or defaults.header)
    metadata = _coerce_string_list(overrides.get("metadata"), defaults.metadata)
    body = str(overrides.get("body") or defaults.body)
    description = _coerce_string_list(overrides.get("description_key"), list(defaults.description_keys))
    # An override that coerces to an empty list keeps the default keys.
    description_keys = tuple(description) if description else defaults.description_keys
    return SQLTemplateDefinition(header=header, metadata=metadata, body=body, description_keys=description_keys)
def _build_python_definition(overrides: Any) -> "PythonTemplateDefinition":
    """Build the Python template definition, applying optional overrides.

    Args:
        overrides: ``None`` for the defaults, or a dict that may contain
            ``"docstring"``, ``"body"``, ``"imports"`` and
            ``"description_key"``. Missing or falsy entries fall back to the
            default template.

    Returns:
        A new definition. The module-level default instance is never returned
        directly, so mutating the result cannot alter the defaults.

    Raises:
        TemplateValidationError: If ``overrides`` is neither ``None`` nor a
            dict, or a list-valued entry has an unsupported type.
    """
    if overrides is None:
        # Treat "no overrides" as an empty mapping so the default values are
        # copied into a fresh instance rather than handing out the shared one.
        overrides = {}
    elif not isinstance(overrides, dict):
        msg = "Python template override must be a mapping"
        raise TemplateValidationError(msg)

    defaults = _DEFAULT_PY_TEMPLATE
    docstring = str(overrides.get("docstring") or defaults.docstring)
    body = str(overrides.get("body") or defaults.body)
    imports = _coerce_string_list(overrides.get("imports"), defaults.imports)
    description = _coerce_string_list(overrides.get("description_key"), list(defaults.description_keys))
    # An override that coerces to an empty list keeps the default keys.
    description_keys = tuple(description) if description else defaults.description_keys
    return PythonTemplateDefinition(docstring=docstring, body=body, imports=imports, description_keys=description_keys)
def _coerce_string_list(value: Any, default: "list[str]") -> "list[str]":
if value is None:
return list(default)
if isinstance(value, str):
return [line for line in value.splitlines() if line]
if isinstance(value, (list, tuple)):
return [str(item) for item in value if str(item)]
msg = "Template list override must be a string or list"
raise TemplateValidationError(msg)
def _normalize_newlines(lines: "list[str]") -> "list[str]":
normalized: list[str] = [line.rstrip("\r") for line in lines]
return normalized
# Title used when neither the top-level configuration nor the "templates"
# configuration supplies one.
_DEFAULT_TITLE = "SQLSpec Migration"

# Built-in SQL template, used when no "sql" override is configured.
# Placeholders ({title}, {version}, {description}, {created_at}, {author}) are
# filled with str.format_map when the template is rendered.
# NOTE(review): the leading whitespace inside the body literals is a single
# space here; confirm against the upstream file that it was not collapsed.
_DEFAULT_SQL_TEMPLATE = SQLTemplateDefinition(
    header="-- {title}",
    metadata=[
        "-- Version: {version}",
        "-- Description: {description}",
        "-- Created: {created_at}",
        "-- Author: {author}",
    ],
    body=(
        "-- name: migrate-{version}-up\n"
        "CREATE TABLE placeholder (\n"
        " id INTEGER PRIMARY KEY\n"
        ");\n\n"
        "-- name: migrate-{version}-down\n"
        "DROP TABLE placeholder;"
    ),
)

# Built-in Python template, used when no "py" override is configured.
# The docstring uses {title}, {message}, {description}, {version},
# {created_at} and {author}; the import lines are emitted verbatim and the
# body contains no placeholders.
# NOTE(review): the leading whitespace inside the body literals is a single
# space here; confirm against the upstream file that it was not collapsed.
_DEFAULT_PY_TEMPLATE = PythonTemplateDefinition(
    docstring=(
        "{title} - {message}\n"
        "Description: {description}\n"
        "Version: {version}\n"
        "Created: {created_at}\n"
        "Author: {author}\n\n"
        "Replace 'def' with 'async def' if you need awaitables. The optional"
        " context argument receives the SQLSpec migration context when provided."
    ),
    imports=["from typing import Iterable"],
    body=(
        "def up(context: object | None = None) -> str | Iterable[str]:\n"
        ' """Apply the migration (upgrade)."""\n'
        ' return """\n'
        " CREATE TABLE example (\n"
        " id INTEGER PRIMARY KEY,\n"
        " name TEXT NOT NULL\n"
        " );\n"
        ' """\n\n'
        "def down(context: object | None = None) -> str | Iterable[str]:\n"
        ' """Reverse the migration."""\n'
        ' return "DROP TABLE example;"'
    ),
)