# Source code for sqlspec.adapters.oracledb.driver

"""Oracle Driver"""

import logging
from typing import TYPE_CHECKING, Any, NamedTuple, cast

import oracledb

from sqlspec.adapters.oracledb._typing import (
    OracleAsyncConnection,
    OracleAsyncCursor,
    OracleAsyncSessionContext,
    OracleSyncConnection,
    OracleSyncCursor,
    OracleSyncSessionContext,
)
from sqlspec.adapters.oracledb.core import (
    ORACLEDB_VERSION,
    build_insert_statement,
    build_pipeline_stack_result,
    build_truncate_statement,
    coerce_large_parameters_async,
    coerce_large_parameters_sync,
    collect_async_rows,
    collect_sync_rows,
    create_mapped_exception,
    default_statement_config,
    driver_profile,
    normalize_column_names,
    normalize_execute_many_parameters_async,
    normalize_execute_many_parameters_sync,
    resolve_row_metadata,
    resolve_rowcount,
)
from sqlspec.adapters.oracledb.data_dictionary import OracledbAsyncDataDictionary, OracledbSyncDataDictionary
from sqlspec.core import (
    SQL,
    StackResult,
    StatementConfig,
    StatementStack,
    build_arrow_result_from_table,
    get_cache_config,
    register_driver_profile,
)
from sqlspec.driver import (
    AsyncDriverAdapterBase,
    BaseAsyncExceptionHandler,
    BaseSyncExceptionHandler,
    StackExecutionObserver,
    SyncDriverAdapterBase,
    describe_stack_statement,
    hash_stack_operations,
)
from sqlspec.exceptions import ImproperConfigurationError, SQLSpecError, StackExecutionError
from sqlspec.utils.logging import get_logger, log_with_context
from sqlspec.utils.module_loader import ensure_pyarrow
from sqlspec.utils.text import normalize_identifier, quote_identifier
from sqlspec.utils.type_guards import has_pipeline_capability

if TYPE_CHECKING:
    from collections.abc import Sequence

    from sqlspec.adapters.oracledb._typing import OraclePipelineDriver
    from sqlspec.builder import QueryBuilder
    from sqlspec.core import ArrowResult, Statement, StatementConfig, StatementFilter
    from sqlspec.core.stack import StackOperation
    from sqlspec.data_dictionary import VersionInfo
    from sqlspec.driver import ExecutionResult
    from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry
    from sqlspec.typing import ArrowReturnFormat, StatementParameters

__all__ = (
    "OracleAsyncDriver",
    "OracleAsyncExceptionHandler",
    "OracleAsyncSessionContext",
    "OracleSyncDriver",
    "OracleSyncExceptionHandler",
    "OracleSyncSessionContext",
)


logger = get_logger(__name__)

# Oracle SQL-context byte thresholds (4000 / 2000) live in driver_features so users
# on MAX_STRING_SIZE=EXTENDED databases can override them; defaults are wired in
# core.apply_driver_features and read at the dispatch_execute call sites below.


PIPELINE_MIN_DRIVER_VERSION: "tuple[int, int, int]" = (2, 4, 0)
PIPELINE_MIN_DATABASE_MAJOR: int = 23


class _CompiledStackOperation(NamedTuple):
    """Stack operation after compilation, ready to be queued on a pipeline.

    Instances are produced by ``OraclePipelineMixin._prepare_pipeline_operation``
    and consumed when the pipeline is populated and when results are built.
    """

    # Prepared SQL object the compiled text was generated from.
    statement: SQL
    # Compiled SQL text handed to the pipeline.
    sql: str
    # Prepared parameters returned alongside the compiled SQL.
    parameters: Any
    # Originating stack method name ("execute" or "execute_many").
    method: str
    # Value of ``statement.returns_rows()``; selects add_fetchall over add_execute.
    returns_rows: bool
    # Description of the original statement, reported in StackExecutionError.
    summary: str


class OraclePipelineMixin:
    """Pipeline helpers mixed into the Oracle driver.

    The methods compile stack operations, queue them on a pipeline object and
    convert the pipeline's results back into ``StackResult`` instances.
    """

    __slots__ = ()

    def _stack_native_blocker(self, stack: "StatementStack") -> "str | None":
        """Return why ``stack`` cannot run natively, or ``None`` if it can.

        The first arrow or script operation found determines the reason.
        """
        blockers = {"execute_arrow": "arrow_operation", "execute_script": "script_operation"}
        for entry in stack.operations:
            reason = blockers.get(entry.method)
            if reason is not None:
                return reason
        return None

    def _log_pipeline_skip(self, reason: str, stack: "StatementStack") -> None:
        """Log that the native pipeline was bypassed for ``stack``.

        An ``env_override`` skip is logged at INFO; any other reason at DEBUG.
        """
        if reason == "env_override":
            level = logging.INFO
        else:
            level = logging.DEBUG
        log_with_context(
            logger,
            level,
            "stack.native_pipeline.skip",
            driver=type(self).__name__,
            reason=reason,
            hashed_operations=hash_stack_operations(stack),
        )

    def _prepare_pipeline_operation(self, operation: "StackOperation") -> _CompiledStackOperation:
        """Compile a single stack operation for pipeline submission.

        Args:
            operation: Stack operation using the ``execute`` or ``execute_many`` method.

        Returns:
            The compiled SQL text, prepared parameters and bookkeeping fields.

        Raises:
            ValueError: If the method is unsupported, or an ``execute_many``
                operation carries no arguments.
        """
        driver = cast("OraclePipelineDriver", self)
        extra = dict(operation.keyword_arguments) if operation.keyword_arguments else {}
        # A per-operation statement_config wins over the driver default.
        override = extra.pop("statement_config", None)
        config = override or driver.statement_config
        method = operation.method

        if method == "execute":
            prepared = driver.prepare_statement(
                operation.statement, operation.arguments, statement_config=config, kwargs=extra
            )
        elif method == "execute_many":
            if not operation.arguments:
                msg = "execute_many stack operation requires parameter sets"
                raise ValueError(msg)
            parameter_sets = operation.arguments[0]
            source = operation.statement
            if not isinstance(source, SQL):
                # Non-SQL statements are prepared first, with the trailing
                # arguments passed along as filters.
                source = driver.prepare_statement(
                    operation.statement, operation.arguments[1:], statement_config=config, kwargs=extra
                )
            seed = source.raw_expression or source.raw_sql
            prepared = SQL(seed, parameter_sets, statement_config=config, is_many=True, **extra)
        else:
            msg = f"Unsupported stack operation method: {operation.method}"
            raise ValueError(msg)

        compiled_sql, compiled_parameters = driver._get_compiled_sql(  # pyright: ignore[reportPrivateUsage]
            prepared, config
        )
        summary = describe_stack_statement(operation.statement)
        return _CompiledStackOperation(
            statement=prepared,
            sql=compiled_sql,
            parameters=compiled_parameters,
            method=operation.method,
            returns_rows=prepared.returns_rows(),
            summary=summary,
        )

    def _add_pipeline_operation(self, pipeline: Any, operation: _CompiledStackOperation) -> None:
        """Queue a compiled operation on ``pipeline``.

        Raises:
            ValueError: If the operation method is neither ``execute`` nor ``execute_many``.
        """
        bound = operation.parameters or []
        method = operation.method
        if method == "execute_many":
            pipeline.add_executemany(operation.sql, bound)
        elif method != "execute":
            msg = f"Unsupported pipeline operation: {operation.method}"
            raise ValueError(msg)
        elif operation.returns_rows:
            pipeline.add_fetchall(operation.sql, bound)
        else:
            pipeline.add_execute(operation.sql, bound)

    def _build_stack_results_from_pipeline(
        self,
        compiled_operations: "Sequence[_CompiledStackOperation]",
        pipeline_results: "Sequence[Any]",
        continue_on_error: bool,
        observer: StackExecutionObserver,
    ) -> "list[StackResult]":
        """Convert pipeline results into stack results, pairing them by position.

        Args:
            compiled_operations: Operations in the order they were queued.
            pipeline_results: Results returned by the pipeline run.
            continue_on_error: When true, failures are recorded and collected
                instead of raised.
            observer: Receives each recorded operation error.

        Returns:
            One ``StackResult`` per paired operation and result.

        Raises:
            StackExecutionError: On the first failed result when
                ``continue_on_error`` is false.
        """
        driver = cast("OraclePipelineDriver", self)
        adapter_name = type(self).__name__
        mode = "continue-on-error" if continue_on_error else "fail-fast"
        collected: list[StackResult] = []
        for position, (compiled, outcome) in enumerate(zip(compiled_operations, pipeline_results, strict=False)):
            # Results without an ``error`` attribute are treated as successful.
            failure = getattr(outcome, "error", None)
            if failure is None:
                collected.append(
                    build_pipeline_stack_result(
                        compiled.statement,
                        compiled.method,
                        compiled.returns_rows,
                        compiled.parameters,
                        outcome,
                        driver.driver_features,
                    )
                )
                continue

            stack_error = StackExecutionError(position, compiled.summary, failure, adapter=adapter_name, mode=mode)
            if not continue_on_error:
                raise stack_error
            observer.record_operation_error(stack_error)
            collected.append(StackResult.from_error(stack_error))
        return collected

    def _wrap_pipeline_error(
        self, error: Exception, stack: "StatementStack", continue_on_error: bool
    ) -> StackExecutionError:
        """Wrap ``error`` in a ``StackExecutionError`` with operation index ``-1``."""
        return StackExecutionError(
            -1,
            "Oracle pipeline execution failed",
            error,
            adapter=type(self).__name__,
            mode="continue-on-error" if continue_on_error else "fail-fast",
        )


class OracleSyncExceptionHandler(BaseSyncExceptionHandler):
    """Synchronous context manager that translates Oracle database errors.

    Oracle ORA-XXXXX error codes are mapped to specific SQLSpec exceptions so
    application code can handle them more precisely.

    The mapped exception is stored in ``pending_exception`` rather than raised
    from ``__exit__`` (deferred exception pattern for mypyc compatibility,
    avoiding ABI boundary violations with compiled code).
    """

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Store a mapped exception for ``oracledb.DatabaseError`` subclasses.

        Returns:
            True when the exception was mapped and stored, False otherwise.
        """
        is_database_error = exc_type is not None and issubclass(exc_type, oracledb.DatabaseError)
        if not is_database_error:
            return False
        self.pending_exception = create_mapped_exception(exc_val)
        return True


class OracleAsyncExceptionHandler(BaseAsyncExceptionHandler):
    """Asynchronous context manager that translates Oracle database errors.

    Oracle ORA-XXXXX error codes are mapped to specific SQLSpec exceptions so
    application code can handle them more precisely.

    The mapped exception is stored in ``pending_exception`` rather than raised
    from ``__aexit__`` (deferred exception pattern for mypyc compatibility,
    avoiding ABI boundary violations with compiled code).
    """

    __slots__ = ()

    def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool:
        """Store a mapped exception for ``oracledb.DatabaseError`` subclasses.

        Returns:
            True when the exception was mapped and stored, False otherwise.
        """
        is_database_error = exc_type is not None and issubclass(exc_type, oracledb.DatabaseError)
        if not is_database_error:
            return False
        self.pending_exception = create_mapped_exception(exc_val)
        return True


[docs] class OracleSyncDriver(OraclePipelineMixin, SyncDriverAdapterBase): """Synchronous Oracle Database driver. Provides Oracle Database connectivity with parameter style conversion, error handling, and transaction management. """ __slots__ = ( "_data_dictionary", "_oracle_version", "_pipeline_support", "_pipeline_support_reason", "_row_metadata_cache", ) dialect = "oracle"
[docs] def __init__( self, connection: OracleSyncConnection, statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, ) -> None: if statement_config is None: statement_config = default_statement_config.replace( enable_caching=get_cache_config().compiled_cache_enabled ) super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) self._data_dictionary: OracledbSyncDataDictionary | None = None self._pipeline_support: bool | None = None self._pipeline_support_reason: str | None = None self._oracle_version: VersionInfo | None = None self._row_metadata_cache: dict[int, tuple[Any, list[str], bool]] = {}
# ───────────────────────────────────────────────────────────────────────────── # CORE DISPATCH METHODS # ─────────────────────────────────────────────────────────────────────────────
[docs] def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute single SQL statement with Oracle data handling. Args: cursor: Oracle cursor object statement: SQL statement to execute Returns: Execution result containing data for SELECT statements or row count for others """ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) prepared_parameters = coerce_large_parameters_sync( self.connection, prepared_parameters, clob_type=oracledb.DB_TYPE_CLOB, blob_type=oracledb.DB_TYPE_BLOB, varchar2_byte_limit=self.driver_features.get("oracle_varchar2_byte_limit", 4000), raw_byte_limit=self.driver_features.get("oracle_raw_byte_limit", 2000), ) prepared_parameters = cast("list[Any] | tuple[Any, ...] | dict[Any, Any] | None", prepared_parameters) cursor.execute(sql, prepared_parameters or {}) # SELECT result processing for Oracle is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor) if is_select_like: fetched_data = cursor.fetchall() column_names, requires_lob_coercion = self._resolve_row_metadata(cursor.description) data, column_names = collect_sync_rows( cast("list[Any] | None", fetched_data), cursor.description, self.driver_features, column_names=column_names, requires_lob_coercion=requires_lob_coercion, ) return self.create_execution_result( cursor, selected_data=data, column_names=column_names, data_row_count=len(data), is_select_result=True, row_format="tuple", ) # Non-SELECT result processing affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows)
[docs] def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL with multiple parameter sets using Oracle batch processing. Args: cursor: Oracle cursor object statement: SQL statement with multiple parameter sets Returns: Execution result with affected row count """ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) prepared_parameters = normalize_execute_many_parameters_sync(prepared_parameters) cursor.executemany(sql, prepared_parameters) affected_rows = len(prepared_parameters) return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True)
[docs] def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL script with statement splitting and parameter handling. Parameters are embedded as static values for script execution compatibility. Args: cursor: Oracle cursor object statement: SQL script statement to execute Returns: Execution result containing statement count and success information """ sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) prepared_parameters = cast("list[Any] | tuple[Any, ...] | dict[Any, Any] | None", prepared_parameters) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) successful_count = 0 last_cursor = cursor for stmt in statements: cursor.execute(stmt, prepared_parameters or {}) successful_count += 1 return self.create_execution_result( last_cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True )
# ───────────────────────────────────────────────────────────────────────────── # TRANSACTION MANAGEMENT # ─────────────────────────────────────────────────────────────────────────────
[docs] def begin(self) -> None: """Begin a database transaction. Oracle handles transactions automatically, so this is a no-op. """
# Oracle handles transactions implicitly
[docs] def commit(self) -> None: """Commit the current transaction. Raises: SQLSpecError: If commit fails """ try: self.connection.commit() except oracledb.Error as e: msg = f"Failed to commit Oracle transaction: {e}" raise SQLSpecError(msg) from e
[docs] def rollback(self) -> None: """Rollback the current transaction. Raises: SQLSpecError: If rollback fails """ try: self.connection.rollback() except oracledb.Error as e: msg = f"Failed to rollback Oracle transaction: {e}" raise SQLSpecError(msg) from e
[docs] def set_migration_session_schema(self, schema: str) -> None: """Set Oracle CURRENT_SCHEMA for migration SQL.""" normalized_schema = normalize_identifier(schema, "oracle") quoted_schema = quote_identifier(normalized_schema) with self.with_cursor(self.connection) as cursor: cursor.execute(f"ALTER SESSION SET CURRENT_SCHEMA = {quoted_schema}")
[docs] def has_schema(self, schema: str) -> bool: """Return whether an Oracle schema/user exists.""" normalized_schema = normalize_identifier(schema, "oracle") with self.with_cursor(self.connection) as cursor: cursor.execute("SELECT 1 FROM ALL_USERS WHERE USERNAME = :schema_name", {"schema_name": normalized_schema}) return cursor.fetchone() is not None
[docs] def with_cursor(self, connection: OracleSyncConnection) -> OracleSyncCursor: """Create context manager for Oracle cursor. Args: connection: Oracle database connection Returns: Context manager for cursor operations """ return OracleSyncCursor(connection)
[docs] def handle_database_exceptions(self) -> "OracleSyncExceptionHandler": """Handle database-specific exceptions and wrap them appropriately.""" return OracleSyncExceptionHandler()
# ───────────────────────────────────────────────────────────────────────────── # ARROW API METHODS # ─────────────────────────────────────────────────────────────────────────────
[docs] def select_to_arrow( self, statement: "Statement | QueryBuilder", /, *parameters: "StatementParameters | StatementFilter", statement_config: "StatementConfig | None" = None, return_format: "ArrowReturnFormat" = "table", native_only: bool = False, batch_size: int | None = None, arrow_schema: Any = None, **kwargs: Any, ) -> "Any": """Execute query and return results as Apache Arrow format using Oracle native support. This implementation uses Oracle's native execute_df()/fetch_df_all() methods which return OracleDataFrame objects with Arrow PyCapsule interface, providing zero-copy data transfer and 5-10x performance improvement over dict conversion. If native Arrow is unavailable and native_only is False, it falls back to the conversion path. Args: statement: SQL query string, Statement, or QueryBuilder *parameters: Query parameters (same format as execute()/select()) statement_config: Optional statement configuration override return_format: "table" for pyarrow.Table (default), "batch" for RecordBatch, "batches" for list of RecordBatch, "reader" for RecordBatchReader native_only: If True, raise error if native Arrow is unavailable batch_size: Rows per batch when using "batch" or "batches" format arrow_schema: Optional pyarrow.Schema for type casting **kwargs: Additional keyword arguments Returns: ArrowResult containing pyarrow.Table or RecordBatch """ ensure_pyarrow() import pyarrow as pa config = statement_config or self.statement_config prepared_statement = self.prepare_statement(statement, parameters, statement_config=config, kwargs=kwargs) sql, prepared_parameters = self._get_compiled_sql(prepared_statement, config) try: oracle_df = self._execute_arrow_dataframe(sql, prepared_parameters, batch_size) except AttributeError as exc: if native_only: msg = "Oracle native Arrow support is not available for this connection." 
raise ImproperConfigurationError(msg) from exc return super().select_to_arrow( prepared_statement, statement_config=config, return_format=return_format, native_only=native_only, batch_size=batch_size, arrow_schema=arrow_schema, ) arrow_table = pa.table(oracle_df) column_names = normalize_column_names(arrow_table.column_names, self.driver_features) if column_names != arrow_table.column_names: arrow_table = arrow_table.rename_columns(column_names) return build_arrow_result_from_table( prepared_statement, arrow_table, return_format=return_format, batch_size=batch_size, arrow_schema=arrow_schema, )
# ───────────────────────────────────────────────────────────────────────────── # STACK EXECUTION METHODS # ─────────────────────────────────────────────────────────────────────────────
[docs] def execute_stack(self, stack: "StatementStack", *, continue_on_error: bool = False) -> "tuple[StackResult, ...]": """Execute a StatementStack using Oracle's pipeline when available.""" if not isinstance(stack, StatementStack) or not stack: return super().execute_stack(stack, continue_on_error=continue_on_error) blocker = self._stack_native_blocker(stack) if blocker is not None: self._log_pipeline_skip(blocker, stack) return super().execute_stack(stack, continue_on_error=continue_on_error) if not self._pipeline_native_supported(): self._log_pipeline_skip(self._pipeline_support_reason or "database_version", stack) return super().execute_stack(stack, continue_on_error=continue_on_error) return self._execute_stack_native(stack, continue_on_error=continue_on_error)
# ───────────────────────────────────────────────────────────────────────────── # STORAGE API METHODS # ─────────────────────────────────────────────────────────────────────────────
[docs] def select_to_storage( self, statement: "Statement | QueryBuilder | SQL | str", destination: "StorageDestination", /, *parameters: "StatementParameters | StatementFilter", statement_config: "StatementConfig | None" = None, partitioner: "dict[str, object] | None" = None, format_hint: "StorageFormat | None" = None, telemetry: "StorageTelemetry | None" = None, **kwargs: Any, ) -> "StorageBridgeJob": """Execute a query and stream Arrow-formatted output to storage (sync).""" self._require_capability("arrow_export_enabled") arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) sync_pipeline = self._storage_pipeline() telemetry_payload = self._write_result_to_storage_sync( arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline ) self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry)
[docs] def load_from_arrow( self, table: str, source: "ArrowResult | Any", *, partitioner: "dict[str, object] | None" = None, overwrite: bool = False, telemetry: "StorageTelemetry | None" = None, ) -> "StorageBridgeJob": """Load Arrow data into Oracle using batched executemany calls.""" self._require_capability("arrow_import_enabled") arrow_table = self._coerce_arrow_table(source) if overwrite: statement = build_truncate_statement(table) exc_handler = self.handle_database_exceptions() with exc_handler: self.connection.execute(statement) if exc_handler.pending_exception is not None: raise exc_handler.pending_exception from None columns, records = self._arrow_table_to_rows(arrow_table) if records: statement = build_insert_statement(table, columns) exc_handler = self.handle_database_exceptions() with self.with_cursor(self.connection) as cursor, exc_handler: cursor.executemany(statement, records) if exc_handler.pending_exception is not None: raise exc_handler.pending_exception from None telemetry_payload = self._build_ingest_telemetry(arrow_table) telemetry_payload["destination"] = table self._attach_partition_telemetry(telemetry_payload, partitioner) return self._create_storage_job(telemetry_payload, telemetry)
[docs] def load_from_storage( self, table: str, source: "StorageDestination", *, file_format: "StorageFormat", partitioner: "dict[str, object] | None" = None, overwrite: bool = False, ) -> "StorageBridgeJob": """Load staged artifacts into Oracle.""" arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format) return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound)
# ───────────────────────────────────────────────────────────────────────────── # UTILITY METHODS # ───────────────────────────────────────────────────────────────────────────── @property def data_dictionary(self) -> "OracledbSyncDataDictionary": """Get the data dictionary for this driver. Returns: Data dictionary instance for metadata queries """ if self._data_dictionary is None: self._data_dictionary = OracledbSyncDataDictionary() return self._data_dictionary # ───────────────────────────────────────────────────────────────────────────── # PRIVATE/INTERNAL METHODS # ─────────────────────────────────────────────────────────────────────────────
[docs] def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect Oracle sync rows for the direct execution path.""" column_names, requires_lob_coercion = self._resolve_row_metadata(cursor.description) data, column_names = collect_sync_rows( cast("list[Any] | None", fetched), cursor.description, self.driver_features, column_names=column_names, requires_lob_coercion=requires_lob_coercion, ) return data, column_names, len(data)
[docs] def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from Oracle cursor for the direct execution path.""" return resolve_rowcount(cursor)
def _connection_in_transaction(self) -> bool:
    """Report whether the connection is inside a transaction (always False here)."""
    return False


def _detect_oracle_version(self) -> "VersionInfo | None":
    """Return the database version, reusing a previously stored non-None result."""
    known = self._oracle_version
    if known is not None:
        return known
    detected = self.data_dictionary.get_version(self)
    self._oracle_version = detected
    return detected


def _detect_oracledb_version(self) -> "tuple[int, int, int]":
    """Return the module-level ``ORACLEDB_VERSION`` tuple."""
    return ORACLEDB_VERSION


def _resolve_row_metadata(self, description: Any) -> "tuple[list[str], bool]":
    """Resolve column names and the LOB-coercion flag via the per-driver metadata cache."""
    return resolve_row_metadata(description, self.driver_features, self._row_metadata_cache)


def _execute_arrow_dataframe(self, sql: str, parameters: "Any", batch_size: int | None) -> "Any":
    """Execute SQL and return an Oracle DataFrame.

    Uses ``connection.execute_df`` when the connection exposes it (retrying
    without ``arraysize`` if the first call raises ``TypeError``); otherwise
    falls back to ``connection.fetch_df_all``.
    """
    bind_values = [] if parameters is None else parameters
    array_size = batch_size or 1000
    execute_df = getattr(self.connection, "execute_df", None)
    if execute_df is None:
        return self.connection.fetch_df_all(statement=sql, parameters=bind_values, arraysize=array_size)
    try:
        return execute_df(sql, bind_values, arraysize=array_size)
    except TypeError:
        return execute_df(sql, bind_values)


def _execute_stack_native(self, stack: "StatementStack", *, continue_on_error: bool) -> "tuple[StackResult, ...]":
    """Run every operation of ``stack`` through a single oracledb pipeline.

    When ``continue_on_error`` is False the run is bracketed by ``begin()`` and
    ``commit()``, with a ``rollback()`` attempted on failure. Any exception is
    re-raised wrapped by ``_wrap_pipeline_error``.
    """
    prepared_ops = [self._prepare_pipeline_operation(operation) for operation in stack.operations]
    pipeline = oracledb.create_pipeline()
    for prepared in prepared_ops:
        self._add_pipeline_operation(pipeline, prepared)

    stack_results: list[StackResult] = []
    owns_transaction = False

    with StackExecutionObserver(self, stack, continue_on_error, native_pipeline=True) as observer:
        try:
            if not (continue_on_error or self._connection_in_transaction()):
                self.begin()
                owns_transaction = True

            raw_results = self.connection.run_pipeline(pipeline, continue_on_error=continue_on_error)
            stack_results = self._build_stack_results_from_pipeline(
                prepared_ops, raw_results, continue_on_error, observer
            )

            if owns_transaction:
                self.commit()
        except Exception as exc:
            if owns_transaction:
                # Best-effort rollback; a rollback failure is only logged so the
                # original pipeline error is the one that propagates.
                try:
                    self.rollback()
                except Exception as rollback_error:  # pragma: no cover
                    logger.debug("Rollback after pipeline failure failed: %s", rollback_error)
            raise self._wrap_pipeline_error(exc, stack, continue_on_error) from exc

    return tuple(stack_results)


def _pipeline_native_supported(self) -> bool:
    """Decide whether native pipelining can be used and store the verdict.

    The result is kept in ``_pipeline_support``; when unsupported, the cause is
    recorded in ``_pipeline_support_reason``.
    """
    cached = self._pipeline_support
    if cached is not None:
        return cached

    reason: str | None
    if self.stack_native_disabled:
        reason = "env_override"
    elif self._detect_oracledb_version() < PIPELINE_MIN_DRIVER_VERSION:
        reason = "driver_version"
    elif not has_pipeline_capability(self.connection):
        reason = "driver_api_missing"
    else:
        version_info = self._detect_oracle_version()
        database_ok = bool(version_info and version_info.major >= PIPELINE_MIN_DATABASE_MAJOR)
        reason = None if database_ok else "database_version"

    supported = reason is None
    self._pipeline_support = supported
    self._pipeline_support_reason = reason
    return supported
class OracleAsyncDriver(OraclePipelineMixin, AsyncDriverAdapterBase):
    """Asynchronous Oracle Database driver.

    Provides Oracle Database connectivity with parameter style conversion,
    error handling, and transaction management for async operations.
    """

    # Per-instance state; every name here is assigned in ``__init__``.
    __slots__ = (
        "_data_dictionary",
        "_oracle_version",
        "_pipeline_support",
        "_pipeline_support_reason",
        "_row_metadata_cache",
    )

    dialect = "oracle"
def __init__(
    self,
    connection: OracleAsyncConnection,
    statement_config: "StatementConfig | None" = None,
    driver_features: "dict[str, Any] | None" = None,
) -> None:
    """Initialise the async driver.

    Args:
        connection: Async Oracle connection used by this driver.
        statement_config: Statement configuration; when omitted, the default
            configuration is used with caching taken from the cache config.
        driver_features: Optional driver feature flags passed to the base class.
    """
    if statement_config is None:
        caching_enabled = get_cache_config().compiled_cache_enabled
        statement_config = default_statement_config.replace(enable_caching=caching_enabled)

    super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)

    # Lazily populated caches.
    self._data_dictionary: OracledbAsyncDataDictionary | None = None
    self._pipeline_support: bool | None = None
    self._pipeline_support_reason: str | None = None
    self._oracle_version: VersionInfo | None = None
    self._row_metadata_cache: dict[int, tuple[Any, list[str], bool]] = {}
# ───────────────────────────────────────────────────────────────────────────── # CORE DISPATCH METHODS # ─────────────────────────────────────────────────────────────────────────────
async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
    """Execute single SQL statement with Oracle data handling.

    Large parameters are coerced first (using the ``oracle_varchar2_byte_limit``
    and ``oracle_raw_byte_limit`` driver features, defaulting to 4000 and 2000).

    Args:
        cursor: Oracle cursor object
        statement: SQL statement to execute

    Returns:
        Execution result containing data for SELECT statements or row count for others
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
    coerced = await coerce_large_parameters_async(
        self.connection,
        prepared_parameters,
        clob_type=oracledb.DB_TYPE_CLOB,
        blob_type=oracledb.DB_TYPE_BLOB,
        varchar2_byte_limit=self.driver_features.get("oracle_varchar2_byte_limit", 4000),
        raw_byte_limit=self.driver_features.get("oracle_raw_byte_limit", 2000),
    )
    bind_values = cast("list[Any] | tuple[Any, ...] | dict[Any, Any] | None", coerced)

    await cursor.execute(sql, bind_values or {})

    # Statements that return no rows only report a row count.
    if not (statement.returns_rows() or self._should_force_select(statement, cursor)):
        return self.create_execution_result(cursor, rowcount_override=resolve_rowcount(cursor))

    # Row-returning statements: fetch everything, then resolve names and LOB handling.
    fetched_rows = await cursor.fetchall()
    names, needs_lob_coercion = self._resolve_row_metadata(cursor.description)
    rows, names = await collect_async_rows(
        cast("list[Any] | None", fetched_rows),
        cursor.description,
        self.driver_features,
        column_names=names,
        requires_lob_coercion=needs_lob_coercion,
    )
    return self.create_execution_result(
        cursor,
        selected_data=rows,
        column_names=names,
        data_row_count=len(rows),
        is_select_result=True,
        row_format="tuple",
    )
async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
    """Execute SQL with multiple parameter sets using Oracle batch processing.

    Args:
        cursor: Oracle cursor object
        statement: SQL statement with multiple parameter sets

    Returns:
        Execution result whose row count is the number of parameter sets
        submitted to ``executemany``.
    """
    sql, raw_parameters = self._get_compiled_sql(statement, self.statement_config)
    parameter_sets = normalize_execute_many_parameters_async(raw_parameters)

    await cursor.executemany(sql, parameter_sets)

    return self.create_execution_result(
        cursor,
        rowcount_override=len(parameter_sets),
        is_many_result=True,
    )
async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
    """Execute SQL script with statement splitting and parameter handling.

    The compiled script is split into individual statements (trailing
    semicolons stripped) and each one is executed in order with the same
    parameter mapping. Execution stops at the first statement that raises.

    Args:
        cursor: Oracle cursor object
        statement: SQL script statement to execute

    Returns:
        Execution result containing statement count and success information
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
    script_statements = self.split_script_statements(
        sql, statement.statement_config, strip_trailing_semicolon=True
    )
    bind_values = cast("dict[str, Any]", prepared_parameters or {})

    executed = 0
    for executed, script_statement in enumerate(script_statements, start=1):
        await cursor.execute(script_statement, bind_values)

    return self.create_execution_result(
        cursor,
        statement_count=len(script_statements),
        successful_statements=executed,
        is_script_result=True,
    )
# ───────────────────────────────────────────────────────────────────────────── # TRANSACTION MANAGEMENT # ─────────────────────────────────────────────────────────────────────────────
async def begin(self) -> None:
    """Begin a database transaction.

    Intentionally does nothing: Oracle handles transactions implicitly, so
    there is no explicit "begin" to issue.
    """
    return None
async def commit(self) -> None:
    """Commit the current transaction.

    Raises:
        SQLSpecError: If commit fails
    """
    try:
        await self.connection.commit()
    except oracledb.Error as exc:
        # Surface driver errors as the library's own exception type.
        raise SQLSpecError(f"Failed to commit Oracle transaction: {exc}") from exc
async def rollback(self) -> None:
    """Rollback the current transaction.

    Raises:
        SQLSpecError: If rollback fails
    """
    try:
        await self.connection.rollback()
    except oracledb.Error as exc:
        # Surface driver errors as the library's own exception type.
        raise SQLSpecError(f"Failed to rollback Oracle transaction: {exc}") from exc
async def set_migration_session_schema(self, schema: str) -> None:
    """Set Oracle CURRENT_SCHEMA for migration SQL.

    Args:
        schema: Schema name; normalized for the Oracle dialect and quoted
            before being placed into the ALTER SESSION statement.
    """
    quoted_schema = quote_identifier(normalize_identifier(schema, "oracle"))
    alter_session = f"ALTER SESSION SET CURRENT_SCHEMA = {quoted_schema}"
    async with self.with_cursor(self.connection) as cursor:
        await cursor.execute(alter_session)
async def has_schema(self, schema: str) -> bool:
    """Return whether an Oracle schema/user exists.

    Looks the normalized name up in ``ALL_USERS`` using a bind variable.
    """
    lookup_sql = "SELECT 1 FROM ALL_USERS WHERE USERNAME = :schema_name"
    bind_values = {"schema_name": normalize_identifier(schema, "oracle")}
    async with self.with_cursor(self.connection) as cursor:
        await cursor.execute(lookup_sql, bind_values)
        match = await cursor.fetchone()
    return match is not None
def with_cursor(self, connection: OracleAsyncConnection) -> OracleAsyncCursor:
    """Create context manager for Oracle cursor.

    Args:
        connection: Oracle database connection

    Returns:
        Context manager for cursor operations
    """
    cursor_context = OracleAsyncCursor(connection)
    return cursor_context
def handle_database_exceptions(self) -> "OracleAsyncExceptionHandler":
    """Return a new exception handler used to wrap database-specific exceptions."""
    handler = OracleAsyncExceptionHandler()
    return handler
# ───────────────────────────────────────────────────────────────────────────── # ARROW API METHODS # ─────────────────────────────────────────────────────────────────────────────
async def select_to_arrow(
    self,
    statement: "Statement | QueryBuilder",
    /,
    *parameters: "StatementParameters | StatementFilter",
    statement_config: "StatementConfig | None" = None,
    return_format: "ArrowReturnFormat" = "table",
    native_only: bool = False,
    batch_size: int | None = None,
    arrow_schema: Any = None,
    **kwargs: Any,
) -> "Any":
    """Execute query and return results as Apache Arrow format using Oracle native support.

    The query is run through ``_execute_arrow_dataframe`` and the resulting
    DataFrame is converted with ``pyarrow.table``. If that call raises
    ``AttributeError`` (native DataFrame API not available) the base-class
    conversion path is used instead, unless ``native_only`` is True.

    Args:
        statement: SQL query string, Statement, or QueryBuilder
        *parameters: Query parameters (same format as execute()/select())
        statement_config: Optional statement configuration override
        return_format: "table" for pyarrow.Table (default), "batch" for RecordBatch,
            "batches" for list of RecordBatch, "reader" for RecordBatchReader
        native_only: If True, raise error if native Arrow is unavailable
        batch_size: Rows per batch when using "batch" or "batches" format
        arrow_schema: Optional pyarrow.Schema for type casting
        **kwargs: Additional keyword arguments

    Returns:
        ArrowResult containing pyarrow.Table or RecordBatch

    Raises:
        ImproperConfigurationError: If ``native_only`` is True and the native
            DataFrame call raised ``AttributeError``.
    """
    ensure_pyarrow()

    import pyarrow as pa

    active_config = statement_config or self.statement_config
    prepared = self.prepare_statement(statement, parameters, statement_config=active_config, kwargs=kwargs)
    sql, bind_values = self._get_compiled_sql(prepared, active_config)

    try:
        oracle_df = await self._execute_arrow_dataframe(sql, bind_values, batch_size)
    except AttributeError as exc:
        if not native_only:
            # Native DataFrame API missing: fall back to the generic conversion path.
            return await super().select_to_arrow(
                prepared,
                statement_config=active_config,
                return_format=return_format,
                native_only=native_only,
                batch_size=batch_size,
                arrow_schema=arrow_schema,
            )
        msg = "Oracle native Arrow support is not available for this connection."
        raise ImproperConfigurationError(msg) from exc

    arrow_table = pa.table(oracle_df)
    normalized_names = normalize_column_names(arrow_table.column_names, self.driver_features)
    if normalized_names != arrow_table.column_names:
        arrow_table = arrow_table.rename_columns(normalized_names)

    return build_arrow_result_from_table(
        prepared,
        arrow_table,
        return_format=return_format,
        batch_size=batch_size,
        arrow_schema=arrow_schema,
    )
# ───────────────────────────────────────────────────────────────────────────── # STACK EXECUTION METHODS # ─────────────────────────────────────────────────────────────────────────────
async def execute_stack(
    self, stack: "StatementStack", *, continue_on_error: bool = False
) -> "tuple[StackResult, ...]":
    """Execute a StatementStack using Oracle's pipeline when available.

    Falls back to the base-class implementation when ``stack`` is not a
    non-empty ``StatementStack``, when ``_stack_native_blocker`` reports a
    blocker, or when native pipelining is not supported; the last two cases
    are logged through ``_log_pipeline_skip``.
    """
    if isinstance(stack, StatementStack) and stack:
        skip_reason = self._stack_native_blocker(stack)
        if skip_reason is None and not await self._pipeline_native_supported():
            skip_reason = self._pipeline_support_reason or "database_version"

        if skip_reason is None:
            return await self._execute_stack_native(stack, continue_on_error=continue_on_error)

        self._log_pipeline_skip(skip_reason, stack)

    return await super().execute_stack(stack, continue_on_error=continue_on_error)
# ───────────────────────────────────────────────────────────────────────────── # STORAGE API METHODS # ─────────────────────────────────────────────────────────────────────────────
async def select_to_storage(
    self,
    statement: "Statement | QueryBuilder | SQL | str",
    destination: "StorageDestination",
    /,
    *parameters: "StatementParameters | StatementFilter",
    statement_config: "StatementConfig | None" = None,
    partitioner: "dict[str, object] | None" = None,
    format_hint: "StorageFormat | None" = None,
    telemetry: "StorageTelemetry | None" = None,
    **kwargs: Any,
) -> "StorageBridgeJob":
    """Execute a query and write Arrow-compatible output to storage (async).

    Args:
        statement: Query to execute via ``select_to_arrow``.
        destination: Storage destination for the result.
        *parameters: Query parameters forwarded to ``select_to_arrow``.
        statement_config: Optional statement configuration override.
        partitioner: Optional partition description attached to the telemetry payload.
        format_hint: Optional storage format hint for the writer.
        telemetry: Optional telemetry passed through to the created storage job.
        **kwargs: Additional keyword arguments forwarded to ``select_to_arrow``.

    Returns:
        Storage job built from the write telemetry payload.
    """
    self._require_capability("arrow_export_enabled")
    query_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
    payload = await self._write_result_to_storage_async(
        query_result,
        destination,
        format_hint=format_hint,
        pipeline=self._storage_pipeline(),
    )
    self._attach_partition_telemetry(payload, partitioner)
    return self._create_storage_job(payload, telemetry)
async def load_from_arrow(
    self,
    table: str,
    source: "ArrowResult | Any",
    *,
    partitioner: "dict[str, object] | None" = None,
    overwrite: bool = False,
    telemetry: "StorageTelemetry | None" = None,
) -> "StorageBridgeJob":
    """Asynchronously load Arrow data into Oracle.

    Args:
        table: Destination table name.
        source: Arrow result or Arrow-compatible object; coerced with
            ``_coerce_arrow_table``.
        partitioner: Optional partition description attached to the telemetry payload.
        overwrite: When True, the destination table is truncated before inserting.
        telemetry: Optional telemetry passed through to the created storage job.

    Returns:
        Storage job built from the ingest telemetry payload.

    Raises:
        Exception: Whatever exception the database exception handler recorded as
            ``pending_exception`` while truncating or inserting.
    """
    self._require_capability("arrow_import_enabled")
    arrow_table = self._coerce_arrow_table(source)

    if overwrite:
        truncate_sql = build_truncate_statement(table)
        truncate_guard = self.handle_database_exceptions()
        async with truncate_guard:
            await self.connection.execute(truncate_sql)
        failure = truncate_guard.pending_exception
        if failure is not None:
            raise failure from None

    columns, records = self._arrow_table_to_rows(arrow_table)
    if records:
        insert_sql = build_insert_statement(table, columns)
        insert_guard = self.handle_database_exceptions()
        async with self.with_cursor(self.connection) as cursor, insert_guard:
            await cursor.executemany(insert_sql, records)
        failure = insert_guard.pending_exception
        if failure is not None:
            raise failure from None

    payload = self._build_ingest_telemetry(arrow_table)
    payload["destination"] = table
    self._attach_partition_telemetry(payload, partitioner)
    return self._create_storage_job(payload, telemetry)
async def load_from_storage(
    self,
    table: str,
    source: "StorageDestination",
    *,
    file_format: "StorageFormat",
    partitioner: "dict[str, object] | None" = None,
    overwrite: bool = False,
) -> "StorageBridgeJob":
    """Asynchronously read a staged artifact as Arrow and load it into Oracle.

    Args:
        table: Destination table name.
        source: Storage location of the staged artifact.
        file_format: Format used to read the artifact.
        partitioner: Optional partition description forwarded to ``load_from_arrow``.
        overwrite: Forwarded to ``load_from_arrow``.

    Returns:
        Storage job produced by ``load_from_arrow``.
    """
    staged = await self._read_arrow_from_storage_async(source, file_format=file_format)
    staged_table, inbound_telemetry = staged
    return await self.load_from_arrow(
        table,
        staged_table,
        partitioner=partitioner,
        overwrite=overwrite,
        telemetry=inbound_telemetry,
    )
# ─────────────────────────────────────────────────────────────────────────────
# UTILITY METHODS
# ─────────────────────────────────────────────────────────────────────────────
@property
def data_dictionary(self) -> "OracledbAsyncDataDictionary":
    """Return the driver's data dictionary, creating it on first access.

    Returns:
        Data dictionary instance for metadata queries
    """
    dictionary = self._data_dictionary
    if dictionary is None:
        # Created lazily and kept on the instance for later accesses.
        dictionary = self._data_dictionary = OracledbAsyncDataDictionary()
    return dictionary

# ─────────────────────────────────────────────────────────────────────────────
# PRIVATE/INTERNAL METHODS
# ─────────────────────────────────────────────────────────────────────────────
def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
    """Collect Oracle async rows for the direct execution path.

    Uses synchronous LOB coercion. For async LOB coercion, the standard
    dispatch path via collect_async_rows is used instead.

    Args:
        cursor: Cursor whose ``description`` describes the fetched columns.
        fetched: Rows already fetched from the cursor.

    Returns:
        Tuple of (rows, column names, number of rows).
    """
    names, needs_lob_coercion = self._resolve_row_metadata(cursor.description)
    rows, names = collect_sync_rows(
        cast("list[Any] | None", fetched),
        cursor.description,
        self.driver_features,
        column_names=names,
        requires_lob_coercion=needs_lob_coercion,
    )
    row_total = len(rows)
    return rows, names, row_total
def resolve_rowcount(self, cursor: Any) -> int:
    """Resolve rowcount from Oracle cursor for the direct execution path.

    Delegates to the module-level ``resolve_rowcount`` helper.
    """
    affected = resolve_rowcount(cursor)
    return affected
def _connection_in_transaction(self) -> bool:
    """Report whether the connection is inside a transaction (always False here)."""
    return False


async def _detect_oracle_version(self) -> "VersionInfo | None":
    """Return the database version, reusing a previously stored non-None result."""
    known = self._oracle_version
    if known is not None:
        return known
    detected = await self.data_dictionary.get_version(self)
    self._oracle_version = detected
    return detected


def _detect_oracledb_version(self) -> "tuple[int, int, int]":
    """Return the module-level ``ORACLEDB_VERSION`` tuple."""
    return ORACLEDB_VERSION


def _resolve_row_metadata(self, description: Any) -> "tuple[list[str], bool]":
    """Resolve column names and the LOB-coercion flag via the per-driver metadata cache."""
    return resolve_row_metadata(description, self.driver_features, self._row_metadata_cache)


async def _execute_arrow_dataframe(self, sql: str, parameters: "Any", batch_size: int | None) -> "Any":
    """Execute SQL and return an Oracle DataFrame.

    Uses ``connection.execute_df`` when the connection exposes it (retrying
    without ``arraysize`` if the first call raises ``TypeError``); otherwise
    falls back to ``connection.fetch_df_all``.
    """
    bind_values = [] if parameters is None else parameters
    array_size = batch_size or 1000
    execute_df = getattr(self.connection, "execute_df", None)
    if execute_df is None:
        return await self.connection.fetch_df_all(statement=sql, parameters=bind_values, arraysize=array_size)
    try:
        return await execute_df(sql, bind_values, arraysize=array_size)
    except TypeError:
        return await execute_df(sql, bind_values)


async def _execute_stack_native(
    self, stack: "StatementStack", *, continue_on_error: bool
) -> "tuple[StackResult, ...]":
    """Run every operation of ``stack`` through a single oracledb pipeline.

    When ``continue_on_error`` is False the run is bracketed by ``begin()`` and
    ``commit()``, with a ``rollback()`` attempted on failure. Any exception is
    re-raised wrapped by ``_wrap_pipeline_error``.
    """
    prepared_ops = [self._prepare_pipeline_operation(operation) for operation in stack.operations]
    pipeline = oracledb.create_pipeline()
    for prepared in prepared_ops:
        self._add_pipeline_operation(pipeline, prepared)

    stack_results: list[StackResult] = []
    owns_transaction = False

    with StackExecutionObserver(self, stack, continue_on_error, native_pipeline=True) as observer:
        try:
            if not (continue_on_error or self._connection_in_transaction()):
                await self.begin()
                owns_transaction = True

            raw_results = await self.connection.run_pipeline(pipeline, continue_on_error=continue_on_error)
            stack_results = self._build_stack_results_from_pipeline(
                prepared_ops, raw_results, continue_on_error, observer
            )

            if owns_transaction:
                await self.commit()
        except Exception as exc:
            if owns_transaction:
                # Best-effort rollback; a rollback failure is only logged so the
                # original pipeline error is the one that propagates.
                try:
                    await self.rollback()
                except Exception as rollback_error:  # pragma: no cover
                    logger.debug("Rollback after pipeline failure failed: %s", rollback_error)
            raise self._wrap_pipeline_error(exc, stack, continue_on_error) from exc

    return tuple(stack_results)


async def _pipeline_native_supported(self) -> bool:
    """Decide whether native pipelining can be used and store the verdict.

    The result is kept in ``_pipeline_support``; when unsupported, the cause is
    recorded in ``_pipeline_support_reason``.
    """
    cached = self._pipeline_support
    if cached is not None:
        return cached

    reason: str | None
    if self.stack_native_disabled:
        reason = "env_override"
    elif self._detect_oracledb_version() < PIPELINE_MIN_DRIVER_VERSION:
        reason = "driver_version"
    elif not has_pipeline_capability(self.connection):
        reason = "driver_api_missing"
    else:
        version_info = await self._detect_oracle_version()
        database_ok = bool(version_info and version_info.major >= PIPELINE_MIN_DATABASE_MAJOR)
        reason = None if database_ok else "database_version"

    supported = reason is None
    self._pipeline_support = supported
    self._pipeline_support_reason = reason
    return supported
# Register this adapter's driver profile under the "oracledb" key.
register_driver_profile("oracledb", driver_profile)