# Source code for sqlspec.adapters.sqlite.driver

"""SQLite driver implementation."""

import contextlib
import sqlite3
from typing import TYPE_CHECKING, Any

from typing_extensions import Self

from sqlspec.adapters.sqlite._typing import SqliteSessionContext
from sqlspec.adapters.sqlite.core import (
    build_insert_statement,
    collect_rows,
    create_mapped_exception,
    default_statement_config,
    driver_profile,
    format_identifier,
    normalize_execute_many_parameters,
    normalize_execute_parameters,
    resolve_rowcount,
)
from sqlspec.adapters.sqlite.data_dictionary import SqliteDataDictionary
from sqlspec.core import ArrowResult, ParameterStyle, TypedParameter, get_cache_config, register_driver_profile
from sqlspec.core.result import DMLResult
from sqlspec.driver import SyncDriverAdapterBase
from sqlspec.exceptions import SQLSpecError

if TYPE_CHECKING:
    from collections.abc import Sequence

    from sqlspec.adapters.sqlite._typing import SqliteConnection
    from sqlspec.builder import QueryBuilder
    from sqlspec.core import SQL, SQLResult, Statement, StatementConfig, StatementFilter
    from sqlspec.core.compiler import OperationType
    from sqlspec.driver import ExecutionResult
    from sqlspec.driver._query_cache import CachedQuery
    from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry
    from sqlspec.typing import StatementParameters

# Public names exported by ``from <this module> import *``; SqliteSessionContext is re-exported from ``_typing``.
__all__ = ("SqliteCursor", "SqliteDriver", "SqliteExceptionHandler", "SqliteSessionContext")


class SqliteCursor:
    """Context manager that owns a single SQLite cursor.

    A fresh cursor is opened from ``connection`` on entry and closed on exit.
    Errors raised while closing are suppressed so that cleanup never masks an
    exception coming from the ``with`` body.
    """

    __slots__ = ("connection", "cursor")

    def __init__(self, connection: "SqliteConnection") -> None:
        """Remember the connection; no cursor is opened until ``__enter__``.

        Args:
            connection: SQLite database connection to open the cursor on.
        """
        self.connection = connection
        self.cursor: sqlite3.Cursor | None = None

    def __enter__(self) -> "sqlite3.Cursor":
        """Open a cursor on the stored connection and hand it to the caller.

        Returns:
            The newly created SQLite cursor (also kept on ``self.cursor``).
        """
        new_cursor = self.connection.cursor()
        self.cursor = new_cursor
        return new_cursor

    def __exit__(self, *_: Any) -> None:
        """Close the cursor opened by ``__enter__``, if any.

        The exception details passed by the ``with`` statement are ignored, and
        returning ``None`` lets any in-flight exception propagate unchanged.
        """
        active_cursor = self.cursor
        if active_cursor is None:
            return
        # Best-effort cleanup: a failing close must not replace the real error.
        with contextlib.suppress(Exception):
            active_cursor.close()


class SqliteExceptionHandler:
    """Context manager for handling SQLite database exceptions.

    Maps SQLite extended result codes to specific SQLSpec exceptions
    for better error handling in application code.

    Uses deferred exception pattern for mypyc compatibility: exceptions
    are stored in pending_exception rather than raised from __exit__
    to avoid ABI boundary violations with compiled code.
    """

    __slots__ = ("pending_exception",)

    def __init__(self) -> None:
        self.pending_exception: Exception | None = None

    def __enter__(self) -> Self:
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        if exc_type is None:
            return False
        if issubclass(exc_type, sqlite3.Error):
            self.pending_exception = create_mapped_exception(exc_val)
            return True
        return False


class SqliteDriver(SyncDriverAdapterBase):
    """SQLite driver implementation.

    Provides SQL statement execution, transaction management, and result handling
    for SQLite databases using the standard sqlite3 module.
    """

    __slots__ = ("_data_dictionary",)
    dialect = "sqlite"

    def __init__(
        self,
        connection: "SqliteConnection",
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
    ) -> None:
        """Initialize SQLite driver.

        Args:
            connection: SQLite database connection
            statement_config: Statement configuration settings. When omitted,
                ``default_statement_config`` is used with ``enable_caching`` taken
                from the global cache configuration.
            driver_features: Driver-specific feature flags
        """
        if statement_config is None:
            statement_config = default_statement_config.replace(
                enable_caching=get_cache_config().compiled_cache_enabled
            )
        super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
        # Created lazily by the ``data_dictionary`` property.
        self._data_dictionary: SqliteDataDictionary | None = None

    # ─────────────────────────────────────────────────────────────────────────────
    # CORE DISPATCH METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def dispatch_execute(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "ExecutionResult":
        """Execute single SQL statement.

        Args:
            cursor: SQLite cursor object
            statement: SQL statement to execute

        Returns:
            ExecutionResult with statement execution details
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        cursor.execute(sql, normalize_execute_parameters(prepared_parameters))

        if statement.returns_rows():
            fetched_data = cursor.fetchall()
            data, column_names, row_count = collect_rows(fetched_data, cursor.description)
            return self.create_execution_result(
                cursor,
                selected_data=data,
                column_names=column_names,
                data_row_count=row_count,
                is_select_result=True,
                row_format="tuple",
            )

        affected_rows = resolve_rowcount(cursor)
        return self.create_execution_result(cursor, rowcount_override=affected_rows)

    def dispatch_execute_many(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "ExecutionResult":
        """Execute SQL with multiple parameter sets.

        Args:
            cursor: SQLite cursor object
            statement: SQL statement with multiple parameter sets

        Returns:
            ExecutionResult with batch execution details
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        cursor.executemany(sql, normalize_execute_many_parameters(prepared_parameters))
        affected_rows = resolve_rowcount(cursor)
        return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True)

    def dispatch_execute_script(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "ExecutionResult":
        """Execute SQL script with statement splitting and parameter handling.

        The compiled script is split into individual statements, which are run
        one after another on ``cursor``. The first failure propagates, so every
        counted statement has succeeded.

        Args:
            cursor: SQLite cursor object
            statement: SQL statement containing multiple statements

        Returns:
            ExecutionResult with script execution details
        """
        sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
        statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)

        successful_count = 0
        for stmt in statements:
            # NOTE(review): the same prepared parameters are bound to every split
            # statement; confirm this is intended for parameterized scripts.
            cursor.execute(stmt, normalize_execute_parameters(prepared_parameters))
            successful_count += 1

        return self.create_execution_result(
            cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True
        )

    def execute_many(
        self,
        statement: "SQL | Statement | QueryBuilder",
        /,
        parameters: "Sequence[StatementParameters]",
        *filters: "StatementParameters | StatementFilter",
        statement_config: "StatementConfig | None" = None,
        **kwargs: Any,
    ) -> "SQLResult":
        """Execute many with a SQLite thin path for simple qmark batches.

        Plain-string ``?`` statements with eligible positional parameter rows go
        straight to ``connection.executemany`` when there are no filters or
        kwargs, the driver's own config is used, and observability is idle.
        Everything else falls back to the generic base implementation.
        """
        config = statement_config or self.statement_config
        if (
            isinstance(statement, str)
            and not filters
            and not kwargs
            and config is self.statement_config
            and self.observability.is_idle
            and self._can_use_execute_many_thin_path(statement, parameters, config)
        ):
            try:
                cursor = self.connection.executemany(statement, parameters)
            except sqlite3.Error as exc:
                raise create_mapped_exception(exc) from exc
            operation = self._resolve_dml_operation_type(statement)
            return DMLResult(operation, self._consume_affected_rows(cursor))

        return super().execute_many(statement, parameters, *filters, statement_config=statement_config, **kwargs)

    def _stmt_cache_execute_direct(
        self, sql: str, params: "tuple[Any, ...] | list[Any]", cached: "CachedQuery"
    ) -> "SQLResult":
        """Execute cached query through SQLite connection.execute fast path.

        This bypasses cursor context-manager overhead for repeated cached statements
        while preserving driver exception mapping behavior.

        Args:
            sql: Original SQL text used to rebuild the pooled statement object.
            params: Positional parameters bound to the cached compiled SQL.
            cached: Cached compilation artifacts (compiled SQL, operation profile,
                optional column names).

        Returns:
            A ``DMLResult`` for statements that do not return rows, otherwise the
            full statement result built from the fetched rows.

        Raises:
            The exception produced by ``create_mapped_exception`` for any
            ``sqlite3.Error`` raised while executing or fetching.
        """
        try:
            cursor = self.connection.execute(cached.compiled_sql, params)
        except sqlite3.Error as exc:
            raise create_mapped_exception(exc) from exc

        if not cached.operation_profile.returns_rows:
            return DMLResult(cached.operation_type, self._consume_affected_rows(cursor))

        try:
            # Errors can also surface while SQLite steps through result rows;
            # map them exactly like errors raised by ``execute``.
            fetched_data = cursor.fetchall()
        except sqlite3.Error as exc:
            raise create_mapped_exception(exc) from exc

        column_names = cached.column_names
        if column_names is None:
            description = cursor.description
            column_names = [col[0] for col in description] if description else []

        execution_result = self.create_execution_result(
            cursor,
            selected_data=fetched_data,
            column_names=column_names,
            data_row_count=len(fetched_data),
            is_select_result=True,
            row_format="tuple",
        )

        direct_statement: SQL | None = None
        try:
            direct_statement = self._stmt_cache_build_direct(
                sql, params, cached, params, params_are_simple=True, compiled_sql=cached.compiled_sql
            )
            return self.build_statement_result(direct_statement, execution_result)
        finally:
            # The pooled statement is only needed while building the result.
            if direct_statement is not None:
                self._release_pooled_statement(direct_statement)

    @staticmethod
    def _consume_affected_rows(cursor: "sqlite3.Cursor") -> int:
        """Read a DML row count from ``cursor``, close it, and clamp it at zero.

        Used by the fast paths that obtain a cursor straight from the connection
        (outside ``SqliteCursor``), so the cursor is closed here instead.
        Any non-int or non-positive ``rowcount`` (sqlite3 uses -1 when no count
        is available) is reported as 0.
        """
        with contextlib.closing(cursor):
            rowcount = cursor.rowcount
        return rowcount if isinstance(rowcount, int) and rowcount > 0 else 0

    def _can_use_execute_many_thin_path(
        self, statement: str, parameters: "Sequence[StatementParameters]", config: "StatementConfig"
    ) -> bool:
        """Return True when ``statement``/``parameters`` can skip SQL compilation.

        Requires a non-empty ``list`` of parameter rows, a ``?`` placeholder in
        the SQL, qmark as both the default and (if set) execution parameter style,
        no AST/output/statement transformers, and no static script compilation.
        """
        if type(parameters) is not list:
            return False
        if not parameters:
            return False
        if "?" not in statement:
            return False
        parameter_config = config.parameter_config
        if parameter_config.default_parameter_style is not ParameterStyle.QMARK:
            return False
        if (
            parameter_config.default_execution_parameter_style is not None
            and parameter_config.default_execution_parameter_style is not ParameterStyle.QMARK
        ):
            return False
        if parameter_config.ast_transformer is not None or parameter_config.output_transformer is not None:
            return False
        if parameter_config.needs_static_script_compilation:
            return False
        if config.output_transformer is not None or config.statement_transformers:
            return False
        return self._thin_path_parameters_are_eligible(parameters, parameter_config.type_coercion_map)

    @staticmethod
    def _thin_path_parameters_are_eligible(
        parameters: "list[StatementParameters]", type_coercion_map: "dict[type, Any] | None"
    ) -> bool:
        """Validate parameter payload for the SQLite execute-many thin path.

        Every row must be a list/tuple of exactly the same type and length as the
        first row, and no value may be a ``TypedParameter`` or have a type listed
        in ``type_coercion_map`` (those need the full compilation pipeline).

        Args:
            parameters: Non-empty list of parameter rows.
            type_coercion_map: Optional mapping of Python types that require
                coercion; an empty or missing map disables that check.

        Returns:
            True when all rows can be passed to ``executemany`` unchanged.
        """
        first_sequence = SqliteDriver._as_sequence_parameter_set(parameters[0])
        if first_sequence is None:
            return False
        row_type = type(first_sequence)
        row_len = len(first_sequence)
        # Normalize an empty map to None so the per-value check is a single test.
        coercion_map = type_coercion_map or None

        for param_set in parameters:
            sequence = SqliteDriver._as_sequence_parameter_set(param_set)
            if sequence is None or type(sequence) is not row_type or len(sequence) != row_len:
                return False
            for value in sequence:
                value_type = type(value)
                if value_type is TypedParameter:
                    return False
                if coercion_map is not None and value_type in coercion_map:
                    return False
        return True

    @staticmethod
    def _as_sequence_parameter_set(param_set: "StatementParameters") -> "list[Any] | tuple[Any, ...] | None":
        """Return ``param_set`` if it is a list or tuple (incl. subclasses), else None."""
        if isinstance(param_set, (list, tuple)):
            return param_set
        return None

    @staticmethod
    def _resolve_dml_operation_type(statement: str) -> "OperationType":
        """Classify ``statement`` by its first keyword (case-insensitive).

        Only a leading INSERT, UPDATE or DELETE is recognized; anything else,
        including blank SQL, yields ``"COMMAND"``.
        """
        command_keyword = statement.lstrip().split(None, 1)[0].upper() if statement.strip() else "COMMAND"
        if command_keyword == "INSERT":
            return "INSERT"
        if command_keyword == "UPDATE":
            return "UPDATE"
        if command_keyword == "DELETE":
            return "DELETE"
        return "COMMAND"

    # ─────────────────────────────────────────────────────────────────────────────
    # TRANSACTION MANAGEMENT
    # ─────────────────────────────────────────────────────────────────────────────

    def begin(self) -> None:
        """Begin a database transaction.

        Does nothing when the connection already reports an open transaction.

        Raises:
            SQLSpecError: If transaction cannot be started
        """
        try:
            if not self.connection.in_transaction:
                self.connection.execute("BEGIN")
        except sqlite3.Error as e:
            msg = f"Failed to begin transaction: {e}"
            raise SQLSpecError(msg) from e

    def commit(self) -> None:
        """Commit the current transaction.

        Raises:
            SQLSpecError: If transaction cannot be committed
        """
        try:
            self.connection.commit()
        except sqlite3.Error as e:
            msg = f"Failed to commit transaction: {e}"
            raise SQLSpecError(msg) from e

    def rollback(self) -> None:
        """Rollback the current transaction.

        Raises:
            SQLSpecError: If transaction cannot be rolled back
        """
        try:
            self.connection.rollback()
        except sqlite3.Error as e:
            msg = f"Failed to rollback transaction: {e}"
            raise SQLSpecError(msg) from e

    def with_cursor(self, connection: "SqliteConnection") -> "SqliteCursor":
        """Create context manager for SQLite cursor.

        Args:
            connection: SQLite database connection

        Returns:
            Cursor context manager for safe cursor operations
        """
        return SqliteCursor(connection)

    def handle_database_exceptions(self) -> "SqliteExceptionHandler":
        """Handle database-specific exceptions and wrap them appropriately.

        Returns:
            Exception handler with deferred exception pattern for mypyc compatibility.
        """
        return SqliteExceptionHandler()

    # ─────────────────────────────────────────────────────────────────────────────
    # STORAGE API
    # ─────────────────────────────────────────────────────────────────────────────

    def select_to_storage(
        self,
        statement: "SQL | str",
        destination: "StorageDestination",
        /,
        *parameters: Any,
        statement_config: "StatementConfig | None" = None,
        partitioner: "dict[str, object] | None" = None,
        format_hint: "StorageFormat | None" = None,
        telemetry: "StorageTelemetry | None" = None,
        **kwargs: Any,
    ) -> "StorageBridgeJob":
        """Execute a query and write Arrow-compatible output to storage (sync)."""
        self._require_capability("arrow_export_enabled")
        arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
        sync_pipeline = self._storage_pipeline()
        telemetry_payload = self._write_result_to_storage_sync(
            arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline
        )
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    def load_from_arrow(
        self,
        table: str,
        source: "ArrowResult | Any",
        *,
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
        telemetry: "StorageTelemetry | None" = None,
    ) -> "StorageBridgeJob":
        """Load Arrow data into SQLite using batched inserts.

        When ``overwrite`` is True, all existing rows of ``table`` are deleted
        first. SQLite errors are mapped via ``SqliteExceptionHandler`` and
        re-raised after the cursor has been closed.
        """
        self._require_capability("arrow_import_enabled")
        arrow_table = self._coerce_arrow_table(source)
        if overwrite:
            statement = f"DELETE FROM {format_identifier(table)}"
            exc_handler = self.handle_database_exceptions()
            with exc_handler, self.with_cursor(self.connection) as cursor:
                cursor.execute(statement)
            if exc_handler.pending_exception is not None:
                raise exc_handler.pending_exception from None

        columns, records = self._arrow_table_to_rows(arrow_table)
        if records:
            insert_sql = build_insert_statement(table, columns)
            exc_handler = self.handle_database_exceptions()
            with exc_handler, self.with_cursor(self.connection) as cursor:
                cursor.executemany(insert_sql, records)
            if exc_handler.pending_exception is not None:
                raise exc_handler.pending_exception from None

        telemetry_payload = self._build_ingest_telemetry(arrow_table)
        telemetry_payload["destination"] = table
        self._attach_partition_telemetry(telemetry_payload, partitioner)
        return self._create_storage_job(telemetry_payload, telemetry)

    def load_from_storage(
        self,
        table: str,
        source: "StorageDestination",
        *,
        file_format: "StorageFormat",
        partitioner: "dict[str, object] | None" = None,
        overwrite: bool = False,
    ) -> "StorageBridgeJob":
        """Load staged artifacts from storage into SQLite."""
        arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format)
        return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound)

    # ─────────────────────────────────────────────────────────────────────────────
    # UTILITY METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    @property
    def data_dictionary(self) -> "SqliteDataDictionary":
        """Get the data dictionary for this driver.

        Returns:
            Data dictionary instance for metadata queries (created on first access)
        """
        if self._data_dictionary is None:
            self._data_dictionary = SqliteDataDictionary()
        return self._data_dictionary

    # ─────────────────────────────────────────────────────────────────────────────
    # PRIVATE/INTERNAL METHODS
    # ─────────────────────────────────────────────────────────────────────────────

    def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]":
        """Collect SQLite rows for the direct execution path."""
        # Delegates to the module-level ``collect_rows`` helper from ``core``.
        return collect_rows(fetched, cursor.description)

    def resolve_rowcount(self, cursor: Any) -> int:
        """Resolve rowcount from SQLite cursor for the direct execution path."""
        # Delegates to the module-level ``resolve_rowcount`` helper from ``core``.
        return resolve_rowcount(cursor)

    def _connection_in_transaction(self) -> bool:
        """Check if connection is in transaction.

        Returns:
            True if connection is in an active transaction.
        """
        return bool(self.connection.in_transaction)
# Module-level side effect: registers this adapter's driver profile under the
# "sqlite" key as soon as the module is imported.
register_driver_profile("sqlite", driver_profile)