"""Storage pipeline scaffolding for driver-aware storage bridge."""
from collections import deque
from functools import partial
from pathlib import Path
from time import perf_counter, time
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, TypeAlias, cast
from uuid import uuid4
from mypy_extensions import mypyc_attr
from typing_extensions import NotRequired, TypedDict
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.storage._utils import import_pyarrow, import_pyarrow_csv, import_pyarrow_parquet
from sqlspec.storage.errors import execute_async_storage_operation, execute_sync_storage_operation
from sqlspec.storage.registry import StorageRegistry, storage_registry
from sqlspec.utils.serializers import from_json, get_serializer_metrics, serialize_collection, to_json
from sqlspec.utils.sync_tools import async_
from sqlspec.utils.type_guards import supports_async_delete, supports_async_read_bytes, supports_async_write_bytes
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator
from sqlspec.protocols import ObjectStoreProtocol
from sqlspec.typing import ArrowTable
# Public API of the storage bridge module.
__all__ = (
    "AsyncStoragePipeline",
    "PartitionStrategyConfig",
    "StagedArtifact",
    "StorageBridgeJob",
    "StorageCapabilities",
    "StorageDestination",
    "StorageDiagnostics",
    "StorageFormat",
    "StorageLoadRequest",
    "StorageTelemetry",
    "SyncStoragePipeline",
    "create_storage_bridge_job",
    "get_recent_storage_events",
    "get_storage_bridge_diagnostics",
    "get_storage_bridge_metrics",
    "record_storage_diagnostic_event",
    "reset_storage_bridge_events",
    "reset_storage_bridge_metrics",
)
# Serialization formats the pipelines can encode/decode.
StorageFormat = Literal["jsonl", "json", "parquet", "arrow-ipc", "csv"]
# A destination is a URI / alias string or a local filesystem path.
StorageDestination: TypeAlias = str | Path
# Diagnostics snapshots map metric names to float values.
StorageDiagnostics: TypeAlias = dict[str, float]
class StorageCapabilities(TypedDict):
    """Runtime-evaluated driver storage capabilities."""
    # Whether the driver supports Arrow-table export/import.
    arrow_export_enabled: bool
    arrow_import_enabled: bool
    # Whether the driver supports Parquet export/import.
    parquet_export_enabled: bool
    parquet_import_enabled: bool
    # True when loads must pass through a staging location first.
    requires_staging_for_load: bool
    # NOTE(review): presumably URI schemes usable for staging — confirm against drivers.
    staging_protocols: "list[str]"
    # Partition fan-out strategy identifiers supported by the driver.
    partition_strategies: "list[str]"
    # Optional default storage profile name (may be absent or None).
    default_storage_profile: NotRequired[str | None]
class PartitionStrategyConfig(TypedDict, total=False):
    """Configuration for partition fan-out strategies.

    All keys are optional (``total=False``).
    """
    # Strategy identifier.
    kind: str
    # Number of partitions to fan out to.
    partitions: int
    # Row cap per emitted chunk.
    rows_per_chunk: int
    # Location of the partition manifest.
    manifest_path: str
class StorageLoadRequest(TypedDict):
    """Request describing a staging allocation."""
    # Partition this allocation serves.
    partition_id: str
    # URI the staged artifact will be written to.
    destination_uri: str
    # Time-to-live in seconds; negative values are clamped to 0 by allocate_staging_artifacts.
    ttl_seconds: int
    # Correlates the request with telemetry and cleanup tokens.
    correlation_id: str
    # Optional origin of the data being staged.
    source_uri: NotRequired[str]
class StagedArtifact(TypedDict):
    """Metadata describing a staged artifact managed by the pipeline."""
    # Partition the artifact belongs to.
    partition_id: str
    # Destination URI; used to resolve the backend during cleanup.
    uri: str
    # "<correlation_id>::<partition_id>" token identifying the artifact.
    cleanup_token: str
    # Clamped (>= 0) time-to-live in seconds.
    ttl_seconds: int
    # Absolute epoch timestamp (time.time based) at which the artifact expires.
    expires_at: float
    correlation_id: str
class StorageTelemetry(TypedDict, total=False):
    """Telemetry payload for storage bridge operations.

    All keys are optional (``total=False``); pipelines populate the subset
    relevant to each operation.
    """
    # Normalized backend path the operation targeted.
    destination: str
    # Payload size in bytes.
    bytes_processed: int
    # Row count processed by the operation.
    rows_processed: int
    # Number of staging partitions created.
    partitions_created: int
    # Wall-clock duration in seconds (perf_counter based).
    duration_s: float
    # Storage format label (e.g. "jsonl", "parquet").
    format: str
    # Free-form extra metadata.
    extra: "dict[str, object]"
    # Backend type identifier (ObjectStoreProtocol.backend_type).
    backend: str
    correlation_id: str
    # NOTE(review): config/bind_key appear to identify the owning SQL config — confirm with callers.
    config: str
    bind_key: str
class StorageBridgeJob(NamedTuple):
    """Handle representing a storage bridge operation."""
    # Unique identifier (uuid4 string assigned by create_storage_bridge_job).
    job_id: str
    # Status label supplied by the caller.
    status: str
    # Telemetry captured for the operation.
    telemetry: StorageTelemetry
class _StorageBridgeMetrics:
__slots__ = ("bytes_written", "partitions_created")
def __init__(self) -> None:
self.bytes_written = 0
self.partitions_created = 0
def record_bytes(self, count: int) -> None:
self.bytes_written += max(count, 0)
def record_partitions(self, count: int) -> None:
self.partitions_created += max(count, 0)
def snapshot(self) -> "dict[str, int]":
return {
"storage_bridge.bytes_written": self.bytes_written,
"storage_bridge.partitions_created": self.partitions_created,
}
def reset(self) -> None:
self.bytes_written = 0
self.partitions_created = 0
# Module-level aggregate counters shared by both pipelines.
_METRICS = _StorageBridgeMetrics()
# Bounded ring buffer of recent telemetry events (oldest entries evicted at 25).
_RECENT_STORAGE_EVENTS: "deque[StorageTelemetry]" = deque(maxlen=25)
def get_storage_bridge_metrics() -> "dict[str, int]":
    """Return a point-in-time snapshot of aggregated storage bridge metrics."""
    snapshot = _METRICS.snapshot()
    return snapshot
def reset_storage_bridge_metrics() -> None:
    """Zero the aggregated storage bridge counters."""
    _METRICS.reset()
def record_storage_diagnostic_event(telemetry: StorageTelemetry) -> None:
    """Record telemetry for inclusion in diagnostics snapshots.

    The payload is shallow-copied so later caller mutations do not leak
    into the buffered history.
    """
    entry = cast("StorageTelemetry", dict(telemetry))
    _RECENT_STORAGE_EVENTS.append(entry)
def get_recent_storage_events() -> "list[StorageTelemetry]":
    """Return recent storage telemetry events (most recent first).

    Events are appended to the ring buffer in arrival order, so the buffer is
    iterated in reverse to honour the documented most-recent-first ordering.
    Each entry is copied so callers cannot mutate the buffered telemetry.
    """
    return [cast("StorageTelemetry", dict(entry)) for entry in reversed(_RECENT_STORAGE_EVENTS)]
def reset_storage_bridge_events() -> None:
    """Drop all buffered storage telemetry events."""
    _RECENT_STORAGE_EVENTS.clear()
def create_storage_bridge_job(status: str, telemetry: StorageTelemetry) -> StorageBridgeJob:
    """Create a storage bridge job handle with a unique identifier.

    The telemetry is also recorded into the diagnostics event buffer.
    """
    job_id = str(uuid4())
    job = StorageBridgeJob(job_id=job_id, status=status, telemetry=telemetry)
    record_storage_diagnostic_event(job.telemetry)
    return job
def get_storage_bridge_diagnostics() -> "StorageDiagnostics":
    """Return aggregated storage bridge + serializer cache metrics.

    Serializer metric names are namespaced under a ``serializer.`` prefix;
    all values are coerced to float.
    """
    diagnostics: dict[str, float] = {}
    for name, count in get_storage_bridge_metrics().items():
        diagnostics[name] = float(count)
    for name, count in get_serializer_metrics().items():
        diagnostics[f"serializer.{name}"] = float(count)
    return diagnostics
def _encode_row_payload(rows: "list[Any]", format_hint: StorageFormat) -> bytes:
    """Serialize rows to bytes: a single JSON array for ``"json"``, otherwise JSONL.

    Args:
        rows: Already-serialized row dictionaries.
        format_hint: ``"json"`` for one JSON document; any other hint produces
            newline-delimited JSON.

    Returns:
        Encoded payload bytes.
    """
    if format_hint == "json":
        data = to_json(rows, as_bytes=True)
        return data if isinstance(data, bytes) else data.encode()
    buffer = bytearray()
    for row in rows:
        encoded = to_json(row, as_bytes=True)
        # to_json may fall back to returning str; normalize exactly like the
        # "json" branch above (previously a str here raised TypeError).
        buffer.extend(encoded if isinstance(encoded, bytes) else encoded.encode())
        buffer.extend(b"\n")
    return bytes(buffer)
def _encode_arrow_payload(
    table: "ArrowTable",
    format_choice: StorageFormat,
    *,
    compression: str | None,
    write_options: "dict[str, Any] | None" = None,
) -> bytes:
    """Encode an Arrow table into bytes for the requested format.

    Args:
        table: Table to encode.
        format_choice: One of ``"arrow-ipc"``, ``"csv"``, or ``"parquet"``.
        compression: Parquet compression codec (ignored by other formats).
        write_options: Keyword arguments for pyarrow's CSV ``WriteOptions``.

    Returns:
        Encoded payload bytes.

    Raises:
        ValueError: If *format_choice* has no Arrow encoder (``"json"``/``"jsonl"``).
    """
    pa = import_pyarrow()
    sink = pa.BufferOutputStream()
    if format_choice == "arrow-ipc":
        writer = pa.ipc.new_file(sink, table.schema)
        writer.write_table(table)
        writer.close()
    elif format_choice == "csv":
        pa_csv = import_pyarrow_csv()
        csv_opts: Any = None
        if write_options:
            csv_opts = pa_csv.WriteOptions(**write_options)
        pa_csv.write_csv(table, sink, write_options=csv_opts)
    elif format_choice == "parquet":
        pq = import_pyarrow_parquet()
        pq.write_table(table, sink, compression=compression)
    else:
        # Previously "json"/"jsonl" silently fell through to the parquet
        # writer, producing mislabeled output; mirror _decode_arrow_payload's
        # explicit error instead.
        msg = f"Unsupported storage format for Arrow encoding: {format_choice}"
        raise ValueError(msg)
    buffer = sink.getvalue()
    result_bytes: bytes = buffer.to_pybytes()
    return result_bytes
def _delete_backend_sync(backend: "ObjectStoreProtocol", path: str, *, backend_name: str) -> None:
    """Delete *path* via the backend's sync API, wrapping failures with storage context."""
    call = partial(backend.delete_sync, path)
    execute_sync_storage_operation(call, backend=backend_name, operation="delete", path=path)
def _write_backend_sync(backend: "ObjectStoreProtocol", path: str, payload: bytes, *, backend_name: str) -> None:
    """Write *payload* to *path* via the backend's sync API, wrapping failures with storage context."""
    call = partial(backend.write_bytes_sync, path, payload)
    execute_sync_storage_operation(call, backend=backend_name, operation="write_bytes", path=path)
def _read_backend_sync(backend: "ObjectStoreProtocol", path: str, *, backend_name: str) -> bytes:
    """Read *path* via the backend's sync API, wrapping failures with storage context."""
    call = partial(backend.read_bytes_sync, path)
    return execute_sync_storage_operation(call, backend=backend_name, operation="read_bytes", path=path)
def _decode_arrow_payload(payload: bytes, format_choice: StorageFormat) -> "ArrowTable":
    """Decode raw bytes into an Arrow table according to *format_choice*.

    Raises:
        ValueError: If *format_choice* is not a supported storage format.
    """
    pa = import_pyarrow()
    if format_choice == "parquet":
        pq = import_pyarrow_parquet()
        return cast("ArrowTable", pq.read_table(pa.BufferReader(payload)))
    if format_choice == "arrow-ipc":
        return cast("ArrowTable", pa.ipc.open_file(pa.BufferReader(payload)).read_all())
    if format_choice == "csv":
        pa_csv = import_pyarrow_csv()
        return cast("ArrowTable", pa_csv.read_csv(pa.BufferReader(payload)))
    if format_choice in ("json", "jsonl"):
        text = payload.decode()
        if format_choice == "json":
            parsed = from_json(text)
            # A single JSON object becomes a one-row table.
            rows = parsed if isinstance(parsed, list) else [parsed]
        else:
            rows = [from_json(line) for line in text.splitlines() if line.strip()]
        return cast("ArrowTable", pa.Table.from_pylist(rows))
    msg = f"Unsupported storage format for Arrow decoding: {format_choice}"
    raise ValueError(msg)
def _resolve_alias_destination(
    registry: StorageRegistry, destination: str, backend_options: "dict[str, Any]"
) -> "tuple[ObjectStoreProtocol, str] | None":
    """Resolve an ``alias://<name>/<path>`` destination against the registry.

    Returns:
        The resolved backend and relative object path, or ``None`` when the
        destination does not use the alias scheme.

    Raises:
        ImproperConfigurationError: If the alias name or path component is missing.
    """
    prefix = "alias://"
    if not destination.startswith(prefix):
        return None
    alias_name, _, relative = destination[len(prefix):].partition("/")
    alias = alias_name.strip()
    path_segment = relative.strip()
    if not alias:
        msg = "Alias destinations must include a registry alias before the path component"
        raise ImproperConfigurationError(msg)
    if not path_segment:
        msg = "Alias destinations must include an object path after the alias name"
        raise ImproperConfigurationError(msg)
    return registry.get(alias, **backend_options), path_segment.lstrip("/")
def _normalize_path_for_backend(destination: str) -> str:
if destination.startswith("file://"):
return destination.removeprefix("file://")
if "://" in destination:
_, remainder = destination.split("://", 1)
return remainder.lstrip("/")
return destination
def _resolve_storage_backend(
    registry: StorageRegistry, destination: StorageDestination, backend_options: "dict[str, Any] | None"
) -> "tuple[ObjectStoreProtocol, str]":
    """Resolve *destination* to a backend instance plus a normalized object path.

    Alias destinations (``alias://``) are resolved first; everything else is
    handed to the registry directly with its scheme stripped from the path.
    """
    if isinstance(destination, Path):
        destination_str = destination.as_posix()
    else:
        destination_str = str(destination)
    options = backend_options or {}
    resolved = _resolve_alias_destination(registry, destination_str, options)
    if resolved is not None:
        return resolved
    backend = registry.get(destination_str, **options)
    return backend, _normalize_path_for_backend(destination_str)
@mypyc_attr(allow_interpreted_subclasses=True)
class SyncStoragePipeline:
    """Pipeline coordinating storage registry operations and telemetry.

    Resolves destinations through a :class:`StorageRegistry`, performs storage
    I/O synchronously, and returns :class:`StorageTelemetry` payloads
    describing each operation.
    """

    __slots__ = ("registry",)

    def __init__(self, *, registry: StorageRegistry | None = None) -> None:
        """Initialize the pipeline.

        Args:
            registry: Registry used to resolve backends; defaults to the
                module-level ``storage_registry``.
        """
        self.registry = registry or storage_registry

    def _resolve_backend(
        self, destination: StorageDestination, backend_options: "dict[str, Any] | None"
    ) -> "tuple[ObjectStoreProtocol, str]":
        """Resolve storage backend and normalized path for a destination."""
        return _resolve_storage_backend(self.registry, destination, backend_options)

    def write_rows(
        self,
        rows: "list[dict[str, Any]]",
        destination: StorageDestination,
        *,
        format_hint: StorageFormat | None = None,
        storage_options: "dict[str, Any] | None" = None,
    ) -> StorageTelemetry:
        """Write dictionary rows to storage using cached serializers.

        Args:
            rows: Rows to serialize and persist.
            destination: Target path, URI, or ``alias://`` destination.
            format_hint: Serialization format; defaults to ``"jsonl"``.
            storage_options: Options forwarded to backend resolution.

        Returns:
            Telemetry describing the completed write.
        """
        serialized = serialize_collection(rows)
        format_choice = format_hint or "jsonl"
        payload = _encode_row_payload(serialized, format_choice)
        return self._write_bytes(
            payload,
            destination,
            rows=len(serialized),
            format_label=format_choice,
            storage_options=storage_options or {},
        )

    def write_arrow(
        self,
        table: "ArrowTable",
        destination: StorageDestination,
        *,
        format_hint: StorageFormat | None = None,
        storage_options: "dict[str, Any] | None" = None,
        compression: str | None = None,
    ) -> StorageTelemetry:
        """Write an Arrow table to storage using zero-copy buffers.

        Args:
            table: Arrow table to encode.
            destination: Target path, URI, or ``alias://`` destination.
            format_hint: Encoding format; defaults to ``"parquet"``.
            storage_options: Options forwarded to backend resolution; for CSV
                output the ``write_options`` entry is forwarded to pyarrow.
            compression: Parquet compression codec (ignored by other formats).

        Returns:
            Telemetry describing the completed write.
        """
        format_choice = format_hint or "parquet"
        # write_options only applies to pyarrow's CSV writer.
        format_write_options = (storage_options or {}).get("write_options") if format_choice == "csv" else None
        payload = _encode_arrow_payload(
            table, format_choice, compression=compression, write_options=format_write_options
        )
        return self._write_bytes(
            payload,
            destination,
            rows=int(table.num_rows),
            format_label=format_choice,
            storage_options=storage_options or {},
        )

    def read_arrow(
        self, source: StorageDestination, *, file_format: StorageFormat, storage_options: "dict[str, Any] | None" = None
    ) -> "tuple[ArrowTable, StorageTelemetry]":
        """Read an artifact from storage and decode it into an Arrow table.

        Returns:
            The decoded table and telemetry describing the read.
        """
        backend, path = self._resolve_backend(source, storage_options)
        backend_name = backend.backend_type
        payload = _read_backend_sync(backend, path, backend_name=backend_name)
        table = _decode_arrow_payload(payload, file_format)
        rows_processed = int(table.num_rows)
        telemetry: StorageTelemetry = {
            "destination": path,
            "bytes_processed": len(payload),
            "rows_processed": rows_processed,
            "format": file_format,
            "backend": backend_name,
        }
        return table, telemetry

    def stream_read(
        self,
        source: StorageDestination,
        *,
        chunk_size: int | None = None,
        storage_options: "dict[str, Any] | None" = None,
    ) -> "Iterator[bytes]":
        """Stream bytes from an artifact."""
        backend, path = self._resolve_backend(source, storage_options)
        return backend.stream_read_sync(path, chunk_size=chunk_size)

    def allocate_staging_artifacts(self, requests: "list[StorageLoadRequest]") -> "list[StagedArtifact]":
        """Allocate staging metadata for upcoming loads.

        Negative TTLs are clamped to zero; ``expires_at`` is computed as
        ``now + ttl`` in epoch seconds. The partition counter metric is bumped
        once per allocation batch.
        """
        artifacts: list[StagedArtifact] = []
        now = time()
        for request in requests:
            ttl = max(request["ttl_seconds"], 0)
            cleanup_token = f"{request['correlation_id']}::{request['partition_id']}"
            artifacts.append({
                "partition_id": request["partition_id"],
                "uri": request["destination_uri"],
                "cleanup_token": cleanup_token,
                "ttl_seconds": ttl,
                "expires_at": now + ttl if ttl else now,
                "correlation_id": request["correlation_id"],
            })
        if artifacts:
            _METRICS.record_partitions(len(artifacts))
        return artifacts

    def cleanup_staging_artifacts(self, artifacts: "list[StagedArtifact]", *, ignore_errors: bool = True) -> None:
        """Delete staged artifacts best-effort.

        Args:
            artifacts: Artifacts to delete.
            ignore_errors: When True (default), deletion failures are
                swallowed so cleanup continues for remaining artifacts.
        """
        for artifact in artifacts:
            backend, path = self._resolve_backend(artifact["uri"], None)
            try:
                _delete_backend_sync(backend, path, backend_name=backend.backend_type)
            except Exception:
                if not ignore_errors:
                    raise

    def _write_bytes(
        self,
        payload: bytes,
        destination: StorageDestination,
        *,
        rows: int,
        format_label: str,
        storage_options: "dict[str, Any]",
    ) -> StorageTelemetry:
        """Write *payload* to *destination*, recording metrics and timing telemetry."""
        backend, path = self._resolve_backend(destination, storage_options)
        backend_name = backend.backend_type
        start = perf_counter()
        _write_backend_sync(backend, path, payload, backend_name=backend_name)
        elapsed = perf_counter() - start
        bytes_written = len(payload)
        _METRICS.record_bytes(bytes_written)
        telemetry: StorageTelemetry = {
            "destination": path,
            "bytes_processed": bytes_written,
            "rows_processed": rows,
            "duration_s": elapsed,
            "format": format_label,
            "backend": backend_name,
        }
        return telemetry
@mypyc_attr(allow_interpreted_subclasses=True)
class AsyncStoragePipeline:
    """Async variant of the storage pipeline leveraging async-capable backends when available.

    Backends exposing native async methods are used directly; otherwise the
    synchronous helpers are offloaded to a worker thread via ``async_``.
    """

    __slots__ = ("registry",)

    def __init__(self, *, registry: StorageRegistry | None = None) -> None:
        """Initialize the pipeline, defaulting to the module-level ``storage_registry``."""
        self.registry = registry or storage_registry

    async def write_rows(
        self,
        rows: "list[dict[str, Any]]",
        destination: StorageDestination,
        *,
        format_hint: StorageFormat | None = None,
        storage_options: "dict[str, Any] | None" = None,
    ) -> StorageTelemetry:
        """Write dictionary rows to storage using cached serializers.

        Encoding is offloaded to a worker thread; see
        :meth:`SyncStoragePipeline.write_rows` for argument semantics.
        """
        serialized = serialize_collection(rows)
        format_choice = format_hint or "jsonl"
        payload = await async_(_encode_row_payload)(serialized, format_choice)
        return await self._write_bytes_async(
            payload,
            destination,
            rows=len(serialized),
            format_label=format_choice,
            storage_options=storage_options or {},
        )

    async def write_arrow(
        self,
        table: "ArrowTable",
        destination: StorageDestination,
        *,
        format_hint: StorageFormat | None = None,
        storage_options: "dict[str, Any] | None" = None,
        compression: str | None = None,
    ) -> StorageTelemetry:
        """Write an Arrow table to storage.

        Encoding is offloaded to a worker thread; see
        :meth:`SyncStoragePipeline.write_arrow` for argument semantics.
        """
        format_choice = format_hint or "parquet"
        # write_options only applies to pyarrow's CSV writer.
        format_write_options = (storage_options or {}).get("write_options") if format_choice == "csv" else None
        payload = await async_(_encode_arrow_payload)(
            table, format_choice, compression=compression, write_options=format_write_options
        )
        return await self._write_bytes_async(
            payload,
            destination,
            rows=int(table.num_rows),
            format_label=format_choice,
            storage_options=storage_options or {},
        )

    async def cleanup_staging_artifacts(self, artifacts: "list[StagedArtifact]", *, ignore_errors: bool = True) -> None:
        """Delete staged artifacts best-effort, preferring native async deletes.

        Args:
            artifacts: Artifacts to delete.
            ignore_errors: When True (default), deletion failures are swallowed.
        """
        for artifact in artifacts:
            backend, path = _resolve_storage_backend(self.registry, artifact["uri"], None)
            backend_name = backend.backend_type
            if supports_async_delete(backend):
                try:
                    await execute_async_storage_operation(
                        partial(backend.delete_async, path), backend=backend_name, operation="delete", path=path
                    )
                except Exception:
                    if not ignore_errors:
                        raise
                continue
            # No native async delete: run the sync helper in a worker thread.
            try:
                await async_(_delete_backend_sync)(backend=backend, path=path, backend_name=backend_name)
            except Exception:
                if not ignore_errors:
                    raise

    async def _write_bytes_async(
        self,
        payload: bytes,
        destination: StorageDestination,
        *,
        rows: int,
        format_label: str,
        storage_options: "dict[str, Any]",
    ) -> StorageTelemetry:
        """Write *payload* to *destination*, recording metrics and timing telemetry."""
        backend, path = _resolve_storage_backend(self.registry, destination, storage_options)
        backend_name = backend.backend_type
        start = perf_counter()
        if supports_async_write_bytes(backend):
            await execute_async_storage_operation(
                partial(backend.write_bytes_async, path, payload),
                backend=backend_name,
                operation="write_bytes",
                path=path,
            )
        else:
            await async_(_write_backend_sync)(backend=backend, path=path, payload=payload, backend_name=backend_name)
        elapsed = perf_counter() - start
        bytes_written = len(payload)
        _METRICS.record_bytes(bytes_written)
        telemetry: StorageTelemetry = {
            "destination": path,
            "bytes_processed": bytes_written,
            "rows_processed": rows,
            "duration_s": elapsed,
            "format": format_label,
            "backend": backend_name,
        }
        return telemetry

    async def read_arrow_async(
        self, source: StorageDestination, *, file_format: StorageFormat, storage_options: "dict[str, Any] | None" = None
    ) -> "tuple[ArrowTable, StorageTelemetry]":
        """Read an artifact from storage and decode it into an Arrow table.

        Uses the backend's native async read when available; decoding runs in
        a worker thread.
        """
        backend, path = _resolve_storage_backend(self.registry, source, storage_options)
        backend_name = backend.backend_type
        if supports_async_read_bytes(backend):
            payload = await execute_async_storage_operation(
                partial(backend.read_bytes_async, path), backend=backend_name, operation="read_bytes", path=path
            )
        else:
            payload = await async_(_read_backend_sync)(backend=backend, path=path, backend_name=backend_name)
        table = await async_(_decode_arrow_payload)(payload, file_format)
        rows_processed = int(table.num_rows)
        telemetry: StorageTelemetry = {
            "destination": path,
            "bytes_processed": len(payload),
            "rows_processed": rows_processed,
            "format": file_format,
            "backend": backend_name,
        }
        return table, telemetry

    async def stream_read_async(
        self,
        source: StorageDestination,
        *,
        chunk_size: int | None = None,
        storage_options: "dict[str, Any] | None" = None,
    ) -> "AsyncIterator[bytes]":
        """Stream bytes from an artifact asynchronously."""
        backend, path = _resolve_storage_backend(self.registry, source, storage_options)
        return await backend.stream_read_async(path, chunk_size=chunk_size)