# Source code for sqlspec.storage.backends.base

"""Base class for storage backends."""

import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator
from typing import Any, NoReturn, cast

from mypy_extensions import mypyc_attr
from typing_extensions import Self

from sqlspec.typing import ArrowRecordBatch, ArrowTable
from sqlspec.utils.sync_tools import CapacityLimiter

# Explicit public API of this module (consumed by star-imports and docs).
__all__ = (
    "AsyncArrowBatchIterator",
    "AsyncBytesIterator",
    "AsyncChunkedBytesIterator",
    "AsyncObStoreStreamIterator",
    "AsyncThreadedBytesIterator",
    "ObjectStoreBase",
    "storage_limiter",
)

# Dedicated capacity limiter for storage I/O operations (100 concurrent ops)
# This is shared across all storage backends to prevent overwhelming the system
storage_limiter = CapacityLimiter(100)


class _ExhaustedSentinel:
    """Sentinel value to signal iterator exhaustion across thread boundaries.

    StopIteration cannot be raised into asyncio Futures, so we use this sentinel
    to signal iterator exhaustion from the thread pool back to the async context.
    """

    __slots__ = ()


_EXHAUSTED = _ExhaustedSentinel()


def _next_or_sentinel(iterator: "Iterator[Any]") -> "Any":
    """Get next item or return sentinel if exhausted.

    This helper wraps next() to catch StopIteration in the thread,
    since StopIteration cannot propagate through asyncio Futures.
    """
    try:
        return next(iterator)
    except StopIteration:
        return _EXHAUSTED


class AsyncArrowBatchIterator:
    """Adapt a synchronous Arrow batch iterator to the async protocol.

    Implemented as an explicit class rather than an async generator so the
    module stays compilable by mypyc (which does not support async
    generators). Instances can be consumed with ``async for``.
    """

    __slots__ = ("_sync_iter",)

    def __init__(self, sync_iterator: "Iterator[ArrowRecordBatch]") -> None:
        """Store the synchronous iterator to adapt.

        Args:
            sync_iterator: Source of Arrow record batches.
        """
        self._sync_iter = sync_iterator

    def __aiter__(self) -> "AsyncArrowBatchIterator":
        """Async-iterator protocol: an instance is its own iterator."""
        return self

    async def __anext__(self) -> "ArrowRecordBatch":
        """Fetch the next batch without blocking the event loop.

        The blocking ``next()`` call is dispatched to a worker thread via
        ``asyncio.to_thread``; exhaustion comes back as a sentinel value
        because StopIteration cannot cross the Future boundary.

        Returns:
            The next Arrow record batch.

        Raises:
            StopAsyncIteration: Once the wrapped iterator is exhausted.
        """
        item = await asyncio.to_thread(_next_or_sentinel, self._sync_iter)
        if item is not _EXHAUSTED:
            return cast("ArrowRecordBatch", item)
        raise StopAsyncIteration


class AsyncBytesIterator:
    """Adapt a synchronous bytes iterator to the async iterator protocol.

    Written as an explicit class instead of an async generator so mypyc
    can compile it. Instances can be consumed with ``async for``.

    Note: each step calls ``next()`` directly on the event loop thread, so
    blocking I/O in the wrapped iterator will stall the loop. For
    non-blocking streaming, use AsyncChunkedBytesIterator with pre-loaded
    data instead.
    """

    __slots__ = ("_sync_iter",)

    def __init__(self, sync_iterator: "Iterator[bytes]") -> None:
        """Store the synchronous iterator to adapt.

        Args:
            sync_iterator: Source of byte chunks.
        """
        self._sync_iter = sync_iterator

    def __aiter__(self) -> "AsyncBytesIterator":
        """Async-iterator protocol: an instance is its own iterator."""
        return self

    async def __anext__(self) -> bytes:
        """Return the next chunk, translating exhaustion signals.

        Returns:
            The next chunk of bytes.

        Raises:
            StopAsyncIteration: When the wrapped iterator is exhausted.
        """
        try:
            chunk = next(self._sync_iter)
        except StopIteration:
            raise StopAsyncIteration from None
        return chunk


class AsyncChunkedBytesIterator:
    """Async iterator that yields pre-loaded bytes data in chunks.

    This class implements the async iterator protocol without using async
    generators, allowing it to be compiled by mypyc (which doesn't support
    async generators).

    Unlike AsyncBytesIterator, this class works with pre-loaded data and
    yields control to the event loop between chunks via asyncio.sleep(0),
    ensuring the event loop is not blocked during iteration.

    Usage pattern:
        # Load data in thread pool to avoid blocking
        data = await asyncio.to_thread(read_bytes, path)
        # Stream chunks without blocking event loop
        return AsyncChunkedBytesIterator(data, chunk_size=65536)
    """

    __slots__ = ("_chunk_size", "_data", "_offset")

    def __init__(self, data: bytes, chunk_size: int = 65536) -> None:
        """Initialize the chunked bytes iterator.

        Args:
            data: The bytes data to iterate over in chunks.
            chunk_size: Size of each chunk to yield (default: 65536 bytes).

        Raises:
            ValueError: If chunk_size is not positive. A non-positive size
                would otherwise produce an endless stream of empty chunks
                because the offset would never advance.
        """
        if chunk_size <= 0:
            msg = f"chunk_size must be positive, got {chunk_size}"
            raise ValueError(msg)
        self._data = data
        self._chunk_size = chunk_size
        self._offset = 0

    def __aiter__(self) -> "AsyncChunkedBytesIterator":
        """Return self as the async iterator."""
        return self

    async def __anext__(self) -> bytes:
        """Get the next chunk of bytes asynchronously.

        Yields control to the event loop via asyncio.sleep(0) before
        returning each chunk, ensuring other tasks can run during
        iteration.

        Returns:
            The next chunk of bytes.

        Raises:
            StopAsyncIteration: When all data has been yielded.
        """
        if self._offset >= len(self._data):
            raise StopAsyncIteration

        # Yield to event loop to allow other tasks to run
        await asyncio.sleep(0)

        chunk = self._data[self._offset : self._offset + self._chunk_size]
        # Advance by the actual number of bytes yielded (the final chunk
        # may be shorter than chunk_size).
        self._offset += len(chunk)
        return chunk


class AsyncObStoreStreamIterator:
    """Async iterator wrapper around obstore's native async stream.

    Guarantees that each yielded item is a plain ``bytes`` object and,
    when a chunk size is supplied, re-chunks the stream into fixed-size
    pieces. Written without async generators so mypyc can compile it.
    """

    __slots__ = ("_buffer", "_chunk_size", "_stream", "_stream_exhausted")

    def __init__(self, stream: Any, chunk_size: "int | None" = None) -> None:
        """Wrap the native obstore stream.

        Args:
            stream: The native obstore async stream to wrap.
            chunk_size: Optional chunk size to re-chunk streamed data;
                values that are None or non-positive disable re-chunking.
        """
        self._stream = stream
        self._buffer = bytearray()
        if chunk_size is not None and chunk_size > 0:
            self._chunk_size = chunk_size
        else:
            self._chunk_size = None
        self._stream_exhausted = False

    def __aiter__(self) -> "AsyncObStoreStreamIterator":
        """Async-iterator protocol: an instance is its own iterator."""
        return self

    async def __anext__(self) -> bytes:
        """Return the next chunk from the wrapped stream.

        Returns:
            The next chunk of bytes.

        Raises:
            StopAsyncIteration: When the stream is exhausted.
        """
        size = self._chunk_size
        if size is None:
            # Pass-through mode: forward chunks as the stream produces
            # them, coercing to bytes. StopAsyncIteration propagates.
            raw = await self._stream.__anext__()
            return bytes(raw)

        # Re-chunk mode: fill the buffer until one full chunk is ready
        # or the underlying stream runs dry.
        while len(self._buffer) < size and not self._stream_exhausted:
            try:
                raw = await self._stream.__anext__()
            except StopAsyncIteration:
                self._stream_exhausted = True
            else:
                self._buffer.extend(bytes(raw))

        if len(self._buffer) >= size:
            out = bytes(self._buffer[:size])
            del self._buffer[:size]
            return out
        if self._buffer and self._stream_exhausted:
            # Flush the final, short chunk.
            out = bytes(self._buffer)
            self._buffer.clear()
            return out
        raise StopAsyncIteration from None


class AsyncThreadedBytesIterator:
    """Async iterator that reads from a synchronous file-like object in a thread pool.

    This class implements the async iterator protocol without using async generators,
    allowing it to be compiled by mypyc. It offloads blocking read/close calls
    to a thread pool to avoid blocking the event loop.

    Call aclose() or use as an async context manager to ensure cleanup when
    consumers exit early.
    """

    __slots__ = ("_chunk_size", "_closed", "_file_obj")

    def __init__(self, file_obj: Any, chunk_size: int = 65536) -> None:
        """Initialize the threaded bytes iterator.

        Args:
            file_obj: Synchronous file-like object supporting read() and close().
            chunk_size: Size of each chunk to read (default: 65536 bytes).
        """
        self._file_obj = file_obj
        self._chunk_size = chunk_size
        self._closed = False

    def __aiter__(self) -> "AsyncThreadedBytesIterator":
        """Return self as the async iterator."""
        return self

    async def __aenter__(self) -> Self:
        """Return the iterator for async context manager usage."""
        return self

    async def __aexit__(
        self, exc_type: "type[BaseException] | None", exc: "BaseException | None", tb: "Any | None"
    ) -> None:
        """Close the underlying file when exiting a context."""
        await self.aclose()

    def __del__(self) -> None:
        """Best-effort cleanup for early exit."""
        self._close_sync()

    def _raise_stop(self) -> NoReturn:
        raise StopAsyncIteration

    def _close_sync(self) -> None:
        if self._closed:
            return
        self._closed = True
        try:
            self._file_obj.close()
        except Exception:
            return

    async def _close_async(self) -> None:
        if self._closed:
            return
        self._closed = True
        await asyncio.to_thread(self._file_obj.close)

    async def aclose(self) -> None:
        """Close the underlying file object."""
        await self._close_async()

    async def __anext__(self) -> bytes:
        """Read the next chunk of bytes in a thread pool.

        Returns:
            The next chunk of bytes.
        """
        try:
            chunk = await asyncio.to_thread(self._file_obj.read, self._chunk_size)
        except EOFError:
            await self._close_async()
            self._raise_stop()
        except BaseException:
            await asyncio.shield(self._close_async())
            raise

        if not chunk:
            await self._close_async()
            self._raise_stop()

        return cast("bytes", chunk)


@mypyc_attr(allow_interpreted_subclasses=True)
class ObjectStoreBase(ABC):
    """Base class for storage backends.

    All synchronous methods follow the *_sync naming convention for
    consistency with their async counterparts.
    """

    __slots__ = ()

    @abstractmethod
    def read_bytes_sync(self, path: str, **kwargs: Any) -> bytes:
        """Read bytes from storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    def write_bytes_sync(self, path: str, data: bytes, **kwargs: Any) -> None:
        """Write bytes to storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    def stream_read_sync(self, path: str, chunk_size: "int | None" = None, **kwargs: Any) -> Iterator[bytes]:
        """Stream bytes from storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    def read_text_sync(self, path: str, encoding: str = "utf-8", **kwargs: Any) -> str:
        """Read text from storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    def write_text_sync(self, path: str, data: str, encoding: str = "utf-8", **kwargs: Any) -> None:
        """Write text to storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    def list_objects_sync(self, prefix: str = "", recursive: bool = True, **kwargs: Any) -> "list[str]":
        """List objects in storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    def exists_sync(self, path: str, **kwargs: Any) -> bool:
        """Check if object exists in storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    def delete_sync(self, path: str, **kwargs: Any) -> None:
        """Delete object from storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    def copy_sync(self, source: str, destination: str, **kwargs: Any) -> None:
        """Copy object within storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    def move_sync(self, source: str, destination: str, **kwargs: Any) -> None:
        """Move object within storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    def glob_sync(self, pattern: str, **kwargs: Any) -> "list[str]":
        """Find objects matching pattern synchronously."""
        raise NotImplementedError

    @abstractmethod
    def get_metadata_sync(self, path: str, **kwargs: Any) -> "dict[str, object]":
        """Get object metadata from storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    def is_object_sync(self, path: str) -> bool:
        """Check if path points to an object synchronously."""
        raise NotImplementedError

    @abstractmethod
    def is_path_sync(self, path: str) -> bool:
        """Check if path points to a directory synchronously."""
        raise NotImplementedError

    @abstractmethod
    def read_arrow_sync(self, path: str, **kwargs: Any) -> ArrowTable:
        """Read Arrow table from storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    def write_arrow_sync(self, path: str, table: ArrowTable, **kwargs: Any) -> None:
        """Write Arrow table to storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    def stream_arrow_sync(self, pattern: str, **kwargs: Any) -> Iterator[ArrowRecordBatch]:
        """Stream Arrow record batches from storage synchronously."""
        raise NotImplementedError

    @abstractmethod
    async def read_bytes_async(self, path: str, **kwargs: Any) -> bytes:
        """Read bytes from storage asynchronously."""
        raise NotImplementedError

    @abstractmethod
    async def write_bytes_async(self, path: str, data: bytes, **kwargs: Any) -> None:
        """Write bytes to storage asynchronously."""
        raise NotImplementedError

    @abstractmethod
    async def read_text_async(self, path: str, encoding: str = "utf-8", **kwargs: Any) -> str:
        """Read text from storage asynchronously."""
        raise NotImplementedError

    @abstractmethod
    async def write_text_async(self, path: str, data: str, encoding: str = "utf-8", **kwargs: Any) -> None:
        """Write text to storage asynchronously."""
        raise NotImplementedError

    @abstractmethod
    async def stream_read_async(
        self, path: str, chunk_size: "int | None" = None, **kwargs: Any
    ) -> AsyncIterator[bytes]:
        """Stream bytes from storage asynchronously."""
        raise NotImplementedError

    @abstractmethod
    async def list_objects_async(self, prefix: str = "", recursive: bool = True, **kwargs: Any) -> "list[str]":
        """List objects in storage asynchronously."""
        raise NotImplementedError

    @abstractmethod
    async def exists_async(self, path: str, **kwargs: Any) -> bool:
        """Check if object exists in storage asynchronously."""
        raise NotImplementedError

    @abstractmethod
    async def delete_async(self, path: str, **kwargs: Any) -> None:
        """Delete object from storage asynchronously."""
        raise NotImplementedError

    @abstractmethod
    async def copy_async(self, source: str, destination: str, **kwargs: Any) -> None:
        """Copy object within storage asynchronously."""
        raise NotImplementedError

    @abstractmethod
    async def move_async(self, source: str, destination: str, **kwargs: Any) -> None:
        """Move object within storage asynchronously."""
        raise NotImplementedError

    @abstractmethod
    async def get_metadata_async(self, path: str, **kwargs: Any) -> "dict[str, object]":
        """Get object metadata from storage asynchronously."""
        raise NotImplementedError

    @abstractmethod
    async def read_arrow_async(self, path: str, **kwargs: Any) -> ArrowTable:
        """Read Arrow table from storage asynchronously."""
        raise NotImplementedError

    @abstractmethod
    async def write_arrow_async(self, path: str, table: ArrowTable, **kwargs: Any) -> None:
        """Write Arrow table to storage asynchronously."""
        raise NotImplementedError

    @abstractmethod
    def stream_arrow_async(self, pattern: str, **kwargs: Any) -> AsyncIterator[ArrowRecordBatch]:
        """Stream Arrow record batches from storage asynchronously.

        Note: deliberately a plain ``def`` returning an AsyncIterator
        (not ``async def``), so implementations return an iterator object
        directly rather than awaiting its creation.
        """
        raise NotImplementedError