diff --git a/.gitattributes b/.gitattributes index a748d2ce..0a49a4b9 100644 --- a/.gitattributes +++ b/.gitattributes @@ -31,3 +31,10 @@ Dockerfile* text # .gitattributes export-ignore .gitignore export-ignore + +# napi-rs auto-generates these files from the kernel's `napi-binding/napi/` +# crate; regenerated by `npm run build:native`. Tell git/GitHub they're +# machine-generated so they collapse in diffs and are excluded from +# language stats. +native/sea/index.d.ts linguist-generated=true +native/sea/index.js linguist-generated=true diff --git a/.gitignore b/.gitignore index 99381ce5..c3801f4b 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,12 @@ coverage_unit dist *.DS_Store lib/version.ts + +# SEA native binding — copied/generated from kernel workspace by `npm run build:native`. +# The committed contract is `native/sea/index.d.ts` (TypeScript declarations) and +# `native/sea/index.js` (the napi-rs platform router — small, stable, and required in +# the publish tarball so a missing build step can't ship a tarball that can't load). +# The `.node` binaries are large per-platform artifacts and must NOT be committed; +# in production they arrive via the `@databricks/sql-kernel-*` optional deps. +native/sea/index.node +native/sea/index.*.node diff --git a/.npmignore b/.npmignore index 2bfe597c..448289a7 100644 --- a/.npmignore +++ b/.npmignore @@ -3,6 +3,13 @@ !dist/**/* !thrift/**/* +# SEA napi-rs router shim + TypeScript declarations. The router (index.js) +# selects the per-platform `.node` artifact from `@databricks/sql-kernel-*` +# optionalDependencies (populated when the kernel CI publishes them); +# the .d.ts is the consumer-facing type contract. 
+!native/sea/index.js +!native/sea/index.d.ts + !LICENSE !NOTICE !package.json diff --git a/.prettierignore b/.prettierignore index 9a9ec6bc..4a764095 100644 --- a/.prettierignore +++ b/.prettierignore @@ -11,3 +11,9 @@ coverage dist thrift package-lock.json + +# Generated by napi-rs from the kernel's `napi-binding/napi/` crate; +# regenerated by `npm run build:native`. Format follows napi-rs's +# defaults (no semicolons), not this repo's prettier config. +native/sea/index.d.ts +native/sea/index.js diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 38d55a54..7cdd9659 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -1,9 +1,7 @@ import thrift from 'thrift'; -import Int64 from 'node-int64'; import { EventEmitter } from 'events'; import TCLIService from '../thrift/TCLIService'; -import { TProtocolVersion } from '../thrift/TCLIService_types'; import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } from './contracts/IDBSQLClient'; import IDriver from './contracts/IDriver'; import IClientContext, { ClientConfig } from './contracts/IClientContext'; @@ -14,9 +12,12 @@ import IDBSQLSession from './contracts/IDBSQLSession'; import IAuthentication from './connection/contracts/IAuthentication'; import HttpConnection from './connection/connections/HttpConnection'; import IConnectionOptions from './connection/contracts/IConnectionOptions'; -import Status from './dto/Status'; import HiveDriverError from './errors/HiveDriverError'; -import { buildUserAgentString, definedOrError, serializeQueryTags } from './utils'; +import { buildUserAgentString } from './utils'; +import IBackend from './contracts/IBackend'; +import { InternalConnectionOptions } from './contracts/InternalConnectionOptions'; +import ThriftBackend from './thrift-backend/ThriftBackend'; +import SeaBackend from './sea/SeaBackend'; import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication'; import DatabricksOAuth, { OAuthFlow } from 
'./connection/auth/DatabricksOAuth'; import { @@ -31,26 +32,7 @@ import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger'; import DBSQLLogger from './DBSQLLogger'; import CloseableCollection from './utils/CloseableCollection'; import IConnectionProvider from './connection/contracts/IConnectionProvider'; - -function prependSlash(str: string): string { - if (str.length > 0 && str.charAt(0) !== '/') { - return `/${str}`; - } - return str; -} - -function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) { - if (!catalogName && !schemaName) { - return {}; - } - - return { - initialNamespace: { - catalogName, - schemaName, - }, - }; -} +import prependSlash from './utils/prependSlash'; export type ThriftLibrary = Pick; @@ -75,6 +57,8 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I private readonly sessions = new CloseableCollection(); + private backend?: IBackend; + private static getDefaultLogger(): IDBSQLLogger { if (!this.defaultLogger) { this.defaultLogger = new DBSQLLogger(); @@ -244,40 +228,61 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I this.config.userAgentEntry = options.userAgentEntry; } - this.authProvider = this.createAuthProvider(options, authProvider); + // M0: `useSEA` is consumed via a non-exported internal-options cast so it + // doesn't ship in the public `.d.ts`. Mirrors Python's `kwargs.get("use_sea")` + // pattern (see databricks-sql-python/src/databricks/sql/session.py). + const internalOptions = options as ConnectionOptions & InternalConnectionOptions; + if (internalOptions.useSEA) { + // The SEA backend authenticates inside the native binding; the + // Thrift auth/connection providers are never read on this path, so + // we don't build them (avoids validating the PAT twice and + // constructing a throwaway OAuth provider for an OAuth+useSEA call). + // The backend reads logger/config off the IClientContext it's given. 
+ this.logger.log(LogLevel.info, 'Connecting via the SEA (native) backend'); + this.backend = new SeaBackend({ context: this }); + } else { + this.authProvider = this.createAuthProvider(options, authProvider); + this.connectionProvider = this.createConnectionProvider(options); + this.backend = new ThriftBackend({ + context: this, + onConnectionEvent: (event, payload) => this.forwardConnectionEvent(event, payload), + }); + } - this.connectionProvider = this.createConnectionProvider(options); + await this.backend.connect(options); - const thriftConnection = await this.connectionProvider.getThriftConnection(); + return this; + } - thriftConnection.on('error', (error: Error) => { - // Error.stack already contains error type and message, so log stack if available, - // otherwise fall back to just error type + message - this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`); - try { - this.emit('error', error); - } catch (e) { - // EventEmitter will throw unhandled error when emitting 'error' event. - // Since we already logged it few lines above, just suppress this behaviour + private forwardConnectionEvent(event: 'error' | 'reconnecting' | 'close' | 'timeout', payload?: unknown): void { + switch (event) { + case 'error': { + const error = payload as Error; + this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`); + try { + this.emit('error', error); + } catch (e) { + // EventEmitter throws when 'error' has no listeners; we've already logged it. 
+ } + return; } - }); - - thriftConnection.on('reconnecting', (params: { delay: number; attempt: number }) => { - this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(params)}`); - this.emit('reconnecting', params); - }); - - thriftConnection.on('close', () => { - this.logger.log(LogLevel.debug, 'Closing connection.'); - this.emit('close'); - }); - - thriftConnection.on('timeout', () => { - this.logger.log(LogLevel.debug, 'Connection timed out.'); - this.emit('timeout'); - }); - - return this; + case 'reconnecting': + this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(payload)}`); + this.emit('reconnecting', payload); + return; + case 'close': + this.logger.log(LogLevel.debug, 'Closing connection.'); + this.emit('close'); + return; + case 'timeout': + this.logger.log(LogLevel.debug, 'Connection timed out.'); + this.emit('timeout'); + // Explicit return mirrors the other cases and protects against + // fall-through if a new event is added below. + // eslint-disable-next-line no-useless-return + return; + // no default + } } /** @@ -290,44 +295,20 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I * const session = await client.openSession(); */ public async openSession(request: OpenSessionRequest = {}): Promise { - // Prepare session configuration - const configuration = request.configuration ? 
{ ...request.configuration } : {}; - - // Add metric view metadata config if enabled - if (this.config.enableMetricViewMetadata) { - configuration['spark.sql.thriftserver.metadata.metricview.enabled'] = 'true'; - } - - // Serialize queryTags dict and set in configuration; takes precedence over configuration.QUERY_TAGS - if (request.queryTags !== undefined) { - const serialized = serializeQueryTags(request.queryTags); - if (serialized) { - configuration.QUERY_TAGS = serialized; - } else { - delete configuration.QUERY_TAGS; - } + if (!this.backend) { + throw new HiveDriverError('DBSQLClient: not connected'); } - - const response = await this.driver.openSession({ - client_protocol_i64: new Int64(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8), - ...getInitialNamespaceOptions(request.initialCatalog, request.initialSchema), - configuration, - canUseMultipleCatalogs: true, - }); - - Status.assert(response.status); - const session = new DBSQLSession({ - handle: definedOrError(response.sessionHandle), - context: this, - serverProtocolVersion: response.serverProtocolVersion, - }); + const sessionBackend = await this.backend.openSession(request); + const session = new DBSQLSession({ backend: sessionBackend, context: this }); this.sessions.add(session); return session; } public async close(): Promise { await this.sessions.closeAll(); + await this.backend?.close(); + this.backend = undefined; this.client = undefined; this.connectionProvider = undefined; this.authProvider = undefined; diff --git a/lib/DBSQLOperation.ts b/lib/DBSQLOperation.ts index fe22995d..21b8f0fd 100644 --- a/lib/DBSQLOperation.ts +++ b/lib/DBSQLOperation.ts @@ -1,4 +1,3 @@ -import { stringify, NIL } from 'uuid'; import { Readable } from 'node:stream'; import IOperation, { FetchOptions, @@ -12,91 +11,45 @@ import IOperation, { } from './contracts/IOperation'; import { TGetOperationStatusResp, - TOperationHandle, - TTableSchema, - TSparkDirectResults, TGetResultSetMetadataResp, - TSparkRowSetType, - 
TCloseOperationResp, - TOperationState, + TTableSchema, } from '../thrift/TCLIService_types'; import Status from './dto/Status'; import { LogLevel } from './contracts/IDBSQLLogger'; import OperationStateError, { OperationStateErrorCode } from './errors/OperationStateError'; -import IResultsProvider from './result/IResultsProvider'; -import RowSetProvider from './result/RowSetProvider'; -import JsonResultHandler from './result/JsonResultHandler'; -import ArrowResultHandler from './result/ArrowResultHandler'; -import CloudFetchResultHandler from './result/CloudFetchResultHandler'; -import ArrowResultConverter from './result/ArrowResultConverter'; -import ResultSlicer from './result/ResultSlicer'; -import { definedOrError } from './utils'; import { OperationChunksIterator, OperationRowsIterator } from './utils/OperationIterator'; -import HiveDriverError from './errors/HiveDriverError'; import IClientContext from './contracts/IClientContext'; +import IOperationBackend from './contracts/IOperationBackend'; +import { ResultMetadata } from './contracts/ResultMetadata'; +import ThriftOperationBackend from './thrift-backend/ThriftOperationBackend'; +import { synthesizeThriftStatus, synthesizeThriftResultSetMetadata } from './utils/thriftWireSynthesis'; interface DBSQLOperationConstructorOptions { - handle: TOperationHandle; - directResults?: TSparkDirectResults; + backend: IOperationBackend; context: IClientContext; } -async function delay(ms?: number): Promise { - return new Promise((resolve) => { - setTimeout(() => { - resolve(); - }, ms); - }); -} - export default class DBSQLOperation implements IOperation { private readonly context: IClientContext; - private readonly operationHandle: TOperationHandle; + private readonly backend: IOperationBackend; public onClose?: () => void; - private readonly _data: RowSetProvider; - - private readonly closeOperation?: TCloseOperationResp; - private closed: boolean = false; private cancelled: boolean = false; - private metadata?: 
TGetResultSetMetadataResp; - - private metadataPromise?: Promise; - - private state: TOperationState = TOperationState.INITIALIZED_STATE; - - // Once operation is finished or fails - cache status response, because subsequent calls - // to `getOperationStatus()` may fail with irrelevant errors, e.g. HTTP 404 - private operationStatus?: TGetOperationStatusResp; - - private resultHandler?: ResultSlicer; - - constructor({ handle, directResults, context }: DBSQLOperationConstructorOptions) { - this.operationHandle = handle; - this.context = context; - - const useOnlyPrefetchedResults = Boolean(directResults?.closeOperation); - - if (directResults?.operationStatus) { - this.processOperationStatusResponse(directResults.operationStatus); - } - - this.metadata = directResults?.resultSetMetadata; - this._data = new RowSetProvider( - this.context, - this.operationHandle, - [directResults?.resultSet], - useOnlyPrefetchedResults, - ); - this.closeOperation = directResults?.closeOperation; + constructor(options: DBSQLOperationConstructorOptions) { + this.context = options.context; + this.backend = options.backend; this.context.getLogger().log(LogLevel.debug, `Operation created with id: ${this.id}`); } + public get id() { + return this.backend.id; + } + public iterateChunks(options?: IteratorOptions): IOperationChunksIterator { return new OperationChunksIterator(this, options); } @@ -122,11 +75,6 @@ export default class DBSQLOperation implements IOperation { return Readable.from(iterable, options?.streamOptions); } - public get id() { - const operationId = this.operationHandle?.operationId?.guid; - return operationId ? stringify(operationId) : NIL; - } - /** * Fetches all data * @public @@ -141,8 +89,6 @@ export default class DBSQLOperation implements IOperation { const fetchChunkOptions = { ...options, - // Tell slicer to return raw chunks. 
We're going to process all of them anyway, - // so no need to additionally buffer and slice chunks returned by server disableBuffering: true, }; @@ -168,70 +114,44 @@ export default class DBSQLOperation implements IOperation { public async fetchChunk(options?: FetchOptions): Promise> { await this.failIfClosed(); - if (!this.operationHandle.hasResultSet) { + if (!this.backend.hasResultSet) { return []; } - await this.waitUntilReady(options); - - const resultHandler = await this.getResultHandler(); + await this.waitUntilReadyThroughBackend(options); await this.failIfClosed(); - // All the library code is Promise-based, however, since Promises are microtasks, - // enqueueing a lot of promises may block macrotasks execution for a while. - // Usually, there are no much microtasks scheduled, however, when fetching query - // results (especially CloudFetch ones) it's quite easy to block event loop for - // long enough to break a lot of things. For example, with CloudFetch, after first - // set of files are downloaded and being processed immediately one by one, event - // loop easily gets blocked for enough time to break connection pool. `http.Agent` - // stops receiving socket events, and marks all sockets invalid on the next attempt - // to use them. See these similar issues that helped to debug this particular case - - // https://github.com/nodejs/node/issues/47130 and https://github.com/node-fetch/node-fetch/issues/1735 - // This simple fix allows to clean up a microtasks queue and allow Node to process - // macrotasks as well, allowing the normal operation of other code. Also, this - // fix is added to `fetchChunk` method because, unlike other methods, `fetchChunk` is - // a potential source of issues described above - await new Promise((resolve) => { - setTimeout(resolve, 0); - }); - const defaultMaxRows = this.context.getConfig().fetchChunkDefaultMaxRows; - - const result = resultHandler.fetchNext({ - limit: options?.maxRows ?? 
defaultMaxRows, - disableBuffering: options?.disableBuffering, - }); + const limit = options?.maxRows ?? defaultMaxRows; + const result = await this.backend.fetchChunk({ limit, disableBuffering: options?.disableBuffering }); await this.failIfClosed(); - this.context - .getLogger() - .log( - LogLevel.debug, - `Fetched chunk of size: ${options?.maxRows ?? defaultMaxRows} from operation with id: ${this.id}`, - ); + this.context.getLogger().log(LogLevel.debug, `Fetched chunk of size: ${limit} from operation with id: ${this.id}`); return result; } /** - * Requests operation status + * Requests operation status. Returns the Thrift wire response for + * back-compat with existing user code. On the Thrift backend the response + * is returned verbatim; on any other backend (e.g. SEA) the response is + * synthesized from the neutral {@link IOperationBackend.status} result, + * with Thrift-only fields (`taskStatus`, `numModifiedRows`, etc.) left + * undefined. + * * @param progress * @throws {StatusError} */ public async status(progress: boolean = false): Promise { await this.failIfClosed(); this.context.getLogger().log(LogLevel.debug, `Fetching status for operation with id: ${this.id}`); - - if (this.operationStatus) { - return this.operationStatus; + if (this.backend instanceof ThriftOperationBackend) { + // Zero-loss path: the Thrift backend has the wire response on hand. + return this.backend.thriftStatusResponse(progress); } - - const driver = await this.context.getDriver(); - const response = await driver.getOperationStatus({ - operationHandle: this.operationHandle, - getProgressUpdate: progress, - }); - - return this.processOperationStatusResponse(response); + // Non-Thrift backend: synthesize the Thrift-shaped response from the + // neutral OperationStatus DTO. 
+ const status = await this.backend.status(progress); + return synthesizeThriftStatus(status); } /** @@ -242,18 +162,8 @@ export default class DBSQLOperation implements IOperation { if (this.closed || this.cancelled) { return Status.success(); } - - this.context.getLogger().log(LogLevel.debug, `Cancelling operation with id: ${this.id}`); - - const driver = await this.context.getDriver(); - const response = await driver.cancelOperation({ - operationHandle: this.operationHandle, - }); - Status.assert(response.status); + const result = await this.backend.cancel(); this.cancelled = true; - const result = new Status(response.status); - - // Cancelled operation becomes unusable, similarly to being closed this.onClose?.(); return result; } @@ -266,63 +176,66 @@ export default class DBSQLOperation implements IOperation { if (this.closed || this.cancelled) { return Status.success(); } - - this.context.getLogger().log(LogLevel.debug, `Closing operation with id: ${this.id}`); - - const driver = await this.context.getDriver(); - const response = - this.closeOperation ?? 
- (await driver.closeOperation({ - operationHandle: this.operationHandle, - })); - Status.assert(response.status); + const result = await this.backend.close(); this.closed = true; - const result = new Status(response.status); - this.onClose?.(); return result; } public async finished(options?: FinishedOptions): Promise { await this.failIfClosed(); - await this.waitUntilReady(options); + await this.waitUntilReadyThroughBackend(options); } public async hasMoreRows(): Promise { - // If operation is closed or cancelled - we should not try to get data from it if (this.closed || this.cancelled) { return false; } - // Wait for operation to finish before checking for more rows - // This ensures metadata can be fetched successfully - if (this.operationHandle.hasResultSet) { - await this.waitUntilReady(); + if (this.backend.hasResultSet) { + await this.waitUntilReadyThroughBackend(); } - // If we fetched all the data from server - check if there's anything buffered in result handler - const resultHandler = await this.getResultHandler(); - return resultHandler.hasMore(); + return this.backend.hasMore(); } public async getSchema(options?: GetSchemaOptions): Promise { await this.failIfClosed(); - if (!this.operationHandle.hasResultSet) { + if (!this.backend.hasResultSet) { return null; } - await this.waitUntilReady(options); + await this.waitUntilReadyThroughBackend(options); this.context.getLogger().log(LogLevel.debug, `Fetching schema for operation with id: ${this.id}`); - const metadata = await this.fetchMetadata(); + const metadata = await this.backend.getResultMetadata(); return metadata.schema ?? null; } + public async getResultMetadata(): Promise { + await this.failIfClosed(); + await this.waitUntilReadyThroughBackend(); + return this.backend.getResultMetadata(); + } + + /** + * Fetch result-set metadata as the Thrift wire response. Kept for + * back-compat with existing user code. 
On the Thrift backend the wire + * response is returned verbatim; on any other backend the response is + * synthesized from the neutral {@link ResultMetadata}, with Thrift-only + * fields (`cacheLookupResult`, `uncompressedBytes`, `compressedBytes`, + * `status`) left undefined / defaulted. + * + * Prefer {@link DBSQLOperation.getResultMetadata} in new code. + */ public async getMetadata(): Promise { await this.failIfClosed(); - await this.waitUntilReady(); - return this.fetchMetadata(); + await this.waitUntilReadyThroughBackend(); + if (this.backend instanceof ThriftOperationBackend) { + return this.backend.thriftResultMetadataResponse(); + } + return synthesizeThriftResultSetMetadata(await this.backend.getResultMetadata()); } private async failIfClosed(): Promise { @@ -334,151 +247,20 @@ export default class DBSQLOperation implements IOperation { } } - private async waitUntilReady(options?: WaitUntilReadyOptions) { - if (this.state === TOperationState.FINISHED_STATE) { - return; - } - - let isReady = false; - - while (!isReady) { - // eslint-disable-next-line no-await-in-loop - const response = await this.status(Boolean(options?.progress)); - - if (options?.callback) { - // eslint-disable-next-line no-await-in-loop - await Promise.resolve(options.callback(response)); - } - - switch (response.operationState) { - // For these states do nothing and continue waiting - case TOperationState.INITIALIZED_STATE: - case TOperationState.PENDING_STATE: - case TOperationState.RUNNING_STATE: - break; - - // Operation is completed, so exit the loop - case TOperationState.FINISHED_STATE: - isReady = true; - break; - - // Operation was cancelled, so set a flag and exit the loop (throw an error) - case TOperationState.CANCELED_STATE: + private async waitUntilReadyThroughBackend(options?: WaitUntilReadyOptions) { + try { + await this.backend.waitUntilReady(options); + } catch (err) { + // Reflect terminal states back into facade flags so subsequent calls + // short-circuit via 
failIfClosed(). + if (err instanceof OperationStateError) { + if (err.errorCode === OperationStateErrorCode.Canceled) { this.cancelled = true; - throw new OperationStateError(OperationStateErrorCode.Canceled, response); - - // Operation was closed, so set a flag and exit the loop (throw an error) - case TOperationState.CLOSED_STATE: + } else if (err.errorCode === OperationStateErrorCode.Closed) { this.closed = true; - throw new OperationStateError(OperationStateErrorCode.Closed, response); - - // Error states - throw and exit the loop - case TOperationState.ERROR_STATE: - throw new OperationStateError(OperationStateErrorCode.Error, response); - case TOperationState.TIMEDOUT_STATE: - throw new OperationStateError(OperationStateErrorCode.Timeout, response); - case TOperationState.UKNOWN_STATE: - default: - throw new OperationStateError(OperationStateErrorCode.Unknown, response); + } } - - // If not ready yet - make some delay before the next status requests - if (!isReady) { - // eslint-disable-next-line no-await-in-loop - await delay(100); - } - } - } - - private async fetchMetadata() { - // If metadata is already cached, return it immediately - if (this.metadata) { - return this.metadata; + throw err; } - - // If a fetch is already in progress, wait for it to complete - if (this.metadataPromise) { - return this.metadataPromise; - } - - // Start a new fetch and cache the promise to prevent concurrent fetches - this.metadataPromise = (async () => { - const driver = await this.context.getDriver(); - const metadata = await driver.getResultSetMetadata({ - operationHandle: this.operationHandle, - }); - Status.assert(metadata.status); - this.metadata = metadata; - return metadata; - })(); - - try { - return await this.metadataPromise; - } finally { - // Clear the promise once completed (success or failure) - this.metadataPromise = undefined; - } - } - - private async getResultHandler(): Promise> { - const metadata = await this.fetchMetadata(); - const resultFormat = 
definedOrError(metadata.resultFormat); - - if (!this.resultHandler) { - let resultSource: IResultsProvider> | undefined; - - switch (resultFormat) { - case TSparkRowSetType.COLUMN_BASED_SET: - resultSource = new JsonResultHandler(this.context, this._data, metadata); - break; - case TSparkRowSetType.ARROW_BASED_SET: - resultSource = new ArrowResultConverter( - this.context, - new ArrowResultHandler(this.context, this._data, metadata), - metadata, - ); - break; - case TSparkRowSetType.URL_BASED_SET: - resultSource = new ArrowResultConverter( - this.context, - new CloudFetchResultHandler(this.context, this._data, metadata), - metadata, - ); - break; - // no default - } - - if (resultSource) { - this.resultHandler = new ResultSlicer(this.context, resultSource); - } - } - - if (!this.resultHandler) { - throw new HiveDriverError(`Unsupported result format: ${TSparkRowSetType[resultFormat]}`); - } - - return this.resultHandler; - } - - private processOperationStatusResponse(response: TGetOperationStatusResp) { - Status.assert(response.status); - - this.state = response.operationState ?? 
this.state; - - if (typeof response.hasResultSet === 'boolean') { - this.operationHandle.hasResultSet = response.hasResultSet; - } - - const isInProgress = [ - TOperationState.INITIALIZED_STATE, - TOperationState.PENDING_STATE, - TOperationState.RUNNING_STATE, - ].includes(this.state); - - if (!isInProgress) { - this.operationStatus = response; - } - - return response; } } diff --git a/lib/DBSQLSession.ts b/lib/DBSQLSession.ts index 95715e1b..0e1cc934 100644 --- a/lib/DBSQLSession.ts +++ b/lib/DBSQLSession.ts @@ -2,19 +2,7 @@ import * as fs from 'fs'; import * as path from 'path'; import stream from 'node:stream'; import util from 'node:util'; -import { stringify, NIL } from 'uuid'; -import Int64 from 'node-int64'; import fetch, { HeadersInit } from 'node-fetch'; -import { - TSessionHandle, - TStatus, - TOperationHandle, - TSparkDirectResults, - TSparkArrowTypes, - TSparkParameter, - TProtocolVersion, - TExecuteStatementReq, -} from '../thrift/TCLIService_types'; import IDBSQLSession, { ExecuteStatementOptions, TypeInfoRequest, @@ -31,153 +19,44 @@ import IOperation from './contracts/IOperation'; import DBSQLOperation from './DBSQLOperation'; import Status from './dto/Status'; import InfoValue from './dto/InfoValue'; -import { definedOrError, LZ4, ProtocolVersion, serializeQueryTags } from './utils'; import CloseableCollection from './utils/CloseableCollection'; import { LogLevel } from './contracts/IDBSQLLogger'; import HiveDriverError from './errors/HiveDriverError'; import StagingError from './errors/StagingError'; -import { DBSQLParameter, DBSQLParameterValue } from './DBSQLParameter'; -import ParameterError from './errors/ParameterError'; -import IClientContext, { ClientConfig } from './contracts/IClientContext'; +import IClientContext from './contracts/IClientContext'; +import ISessionBackend from './contracts/ISessionBackend'; +import IOperationBackend from './contracts/IOperationBackend'; // Explicitly promisify a callback-style `pipeline` because 
`node:stream/promises` is not available in Node 14 const pipeline = util.promisify(stream.pipeline); -interface OperationResponseShape { - status: TStatus; - operationHandle?: TOperationHandle; - directResults?: TSparkDirectResults; -} - -export function numberToInt64(value: number | bigint | Int64): Int64 { - if (value instanceof Int64) { - return value; - } - - if (typeof value === 'bigint') { - const buffer = new ArrayBuffer(BigInt64Array.BYTES_PER_ELEMENT); - const view = new DataView(buffer); - view.setBigInt64(0, value, false); // `false` to use big-endian order - return new Int64(Buffer.from(buffer)); - } - - return new Int64(value); -} - -function getDirectResultsOptions(maxRows: number | bigint | Int64 | null | undefined, config: ClientConfig) { - if (maxRows === null) { - return {}; - } - - return { - getDirectResults: { - maxRows: numberToInt64(maxRows ?? config.directResultsDefaultMaxRows), - }, - }; -} - -function getArrowOptions( - config: ClientConfig, - serverProtocolVersion: TProtocolVersion | undefined | null, -): { - canReadArrowResult: boolean; - useArrowNativeTypes?: TSparkArrowTypes; -} { - const { arrowEnabled = true, useArrowNativeTypes = true } = config; - - if (!arrowEnabled || !ProtocolVersion.supportsArrowMetadata(serverProtocolVersion)) { - return { - canReadArrowResult: false, - }; - } - - return { - canReadArrowResult: true, - useArrowNativeTypes: { - timestampAsArrow: useArrowNativeTypes, - decimalAsArrow: useArrowNativeTypes, - complexTypesAsArrow: useArrowNativeTypes, - // TODO: currently unsupported by `apache-arrow` (see https://github.com/streamlit/streamlit/issues/4489) - intervalTypesAsArrow: false, - }, - }; -} - -function getQueryParameters( - namedParameters?: Record, - ordinalParameters?: Array, -): Array { - const namedParametersProvided = namedParameters !== undefined && Object.keys(namedParameters).length > 0; - const ordinalParametersProvided = ordinalParameters !== undefined && ordinalParameters.length > 0; - - if 
(namedParametersProvided && ordinalParametersProvided) { - throw new ParameterError('Driver does not support both ordinal and named parameters.'); - } - - if (!namedParametersProvided && !ordinalParametersProvided) { - return []; - } - - const result: Array = []; - - if (namedParameters !== undefined) { - for (const name of Object.keys(namedParameters)) { - const value = namedParameters[name]; - const param = value instanceof DBSQLParameter ? value : new DBSQLParameter({ value }); - result.push(param.toSparkParameter({ name })); - } - } - - if (ordinalParameters !== undefined) { - for (const value of ordinalParameters) { - const param = value instanceof DBSQLParameter ? value : new DBSQLParameter({ value }); - result.push(param.toSparkParameter()); - } - } - - return result; -} +// Re-export for back-compat with existing imports. +export { numberToInt64 } from './thrift-backend/ThriftSessionBackend'; interface DBSQLSessionConstructorOptions { - handle: TSessionHandle; + backend: ISessionBackend; context: IClientContext; - serverProtocolVersion?: TProtocolVersion; } export default class DBSQLSession implements IDBSQLSession { private readonly context: IClientContext; - private readonly sessionHandle: TSessionHandle; + private readonly backend: ISessionBackend; private isOpen = true; - private serverProtocolVersion?: TProtocolVersion; - public onClose?: () => void; private operations = new CloseableCollection(); - /** - * Helper method to determine if runAsync should be set for metadata operations - * @private - * @returns true if supported by protocol version, undefined otherwise - */ - private getRunAsyncForMetadataOperations(): boolean | undefined { - return ProtocolVersion.supportsAsyncMetadataOperations(this.serverProtocolVersion) ? 
true : undefined; - } - - constructor({ handle, context, serverProtocolVersion }: DBSQLSessionConstructorOptions) { - this.sessionHandle = handle; - this.context = context; - // Get the server protocol version from the provided parameter (from TOpenSessionResp) - this.serverProtocolVersion = serverProtocolVersion; + constructor(options: DBSQLSessionConstructorOptions) { + this.context = options.context; + this.backend = options.backend; this.context.getLogger().log(LogLevel.debug, `Session created with id: ${this.id}`); - this.context.getLogger().log(LogLevel.debug, `Server protocol version: ${this.serverProtocolVersion}`); } public get id() { - const sessionId = this.sessionHandle?.sessionId?.guid; - return sessionId ? stringify(sessionId) : NIL; + return this.backend.id; } /** @@ -190,14 +69,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getInfo(infoType: number): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const operationPromise = driver.getInfo({ - sessionHandle: this.sessionHandle, - infoType, - }); - const response = await this.handleResponse(operationPromise); - Status.assert(response.status); - return new InfoValue(response.infoValue); + const result = await this.backend.getInfo(infoType); + await this.failIfClosed(); + return result; } /** @@ -211,48 +85,13 @@ export default class DBSQLSession implements IDBSQLSession { */ public async executeStatement(statement: string, options: ExecuteStatementOptions = {}): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const request = new TExecuteStatementReq({ - sessionHandle: this.sessionHandle, - statement, - queryTimeout: options.queryTimeout ? 
numberToInt64(options.queryTimeout) : undefined, - runAsync: true, - ...getDirectResultsOptions(options.maxRows, clientConfig), - ...getArrowOptions(clientConfig, this.serverProtocolVersion), - }); - - if (ProtocolVersion.supportsParameterizedQueries(this.serverProtocolVersion)) { - request.parameters = getQueryParameters(options.namedParameters, options.ordinalParameters); - } - - const serializedQueryTags = serializeQueryTags(options.queryTags); - if (serializedQueryTags !== undefined) { - request.confOverlay = { ...request.confOverlay, query_tags: serializedQueryTags }; - } - - if (ProtocolVersion.supportsCloudFetch(this.serverProtocolVersion)) { - request.canDownloadResult = options.useCloudFetch ?? clientConfig.useCloudFetch; - } - - if (ProtocolVersion.supportsArrowCompression(this.serverProtocolVersion) && request.canDownloadResult !== true) { - request.canDecompressLZ4Result = (options.useLZ4Compression ?? clientConfig.useLZ4Compression) && Boolean(LZ4()); - } + const opBackend = await this.backend.executeStatement(statement, options); + await this.failIfClosed(); + const operation = this.wrapOperation(opBackend); - const operationPromise = driver.executeStatement(request); - const response = await this.handleResponse(operationPromise); - const operation = this.createOperation(response); - - // If `stagingAllowedLocalPath` is provided - assume that operation possibly may be a staging operation. - // To know for sure, fetch metadata and check a `isStagingOperation` flag. If it happens that it wasn't - // a staging operation - not a big deal, we just fetched metadata earlier, but operation is still usable - // and user can get data from it. - // If `stagingAllowedLocalPath` is not provided - don't do anything to the operation. In a case of regular - // operation, everything will work as usual. In a case of staging operation, it will be processed like any - // other query - it will be possible to get data from it as usual, or use other operation methods. 
+ // Staging detection: only run when stagingAllowedLocalPath is provided. if (options.stagingAllowedLocalPath !== undefined) { - const metadata = await operation.getMetadata(); + const metadata = await operation.getResultMetadata(); if (metadata.isStagingOperation) { const allowedLocalPath = Array.isArray(options.stagingAllowedLocalPath) ? options.stagingAllowedLocalPath @@ -276,7 +115,6 @@ export default class DBSQLSession implements IDBSQLSession { } const row = rows[0] as StagingResponse; - // For REMOVE operation local file is not available, so no need to validate it if (row.localFile !== undefined) { let allowOperation = false; @@ -328,7 +166,6 @@ export default class DBSQLSession implements IDBSQLSession { } const fileStream = fs.createWriteStream(localFile); - // `pipeline` will do all the dirty job for us, including error handling and closing all the streams properly return pipeline(response.body, fileStream); } @@ -337,13 +174,6 @@ export default class DBSQLSession implements IDBSQLSession { const agent = await connectionProvider.getAgent(); const response = await fetch(presignedUrl, { method: 'DELETE', headers, agent }); - // Looks that AWS and Azure have a different behavior of HTTP `DELETE` for non-existing files. - // AWS assumes that - since file already doesn't exist - the goal is achieved, and returns HTTP 200. - // Azure, on the other hand, is somewhat stricter and check if file exists before deleting it. And if - // file doesn't exist - Azure returns HTTP 404. - // - // For us, it's totally okay if file didn't exist before removing. So when we get an HTTP 404 - - // just ignore it and report success. 
This way we can have a uniform library behavior for all clouds if (!response.ok && response.status !== 404) { throw new StagingError(`HTTP error ${response.status} ${response.statusText}`); } @@ -368,7 +198,6 @@ export default class DBSQLSession implements IDBSQLSession { method: 'PUT', headers: { ...headers, - // This header is required by server 'Content-Length': fileInfo.size.toString(), }, agent, @@ -387,16 +216,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getTypeInfo(request: TypeInfoRequest = {}): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getTypeInfo({ - sessionHandle: this.sessionHandle, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getTypeInfo(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -407,16 +229,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getCatalogs(request: CatalogsRequest = {}): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getCatalogs({ - sessionHandle: this.sessionHandle, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getCatalogs(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -427,18 +242,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getSchemas(request: SchemasRequest = {}): 
Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getSchemas({ - sessionHandle: this.sessionHandle, - catalogName: request.catalogName, - schemaName: request.schemaName, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getSchemas(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -449,20 +255,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getTables(request: TablesRequest = {}): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getTables({ - sessionHandle: this.sessionHandle, - catalogName: request.catalogName, - schemaName: request.schemaName, - tableName: request.tableName, - tableTypes: request.tableTypes, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getTables(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -473,16 +268,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getTableTypes(request: TableTypesRequest = {}): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getTableTypes({ - sessionHandle: this.sessionHandle, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - 
const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getTableTypes(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -493,20 +281,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getColumns(request: ColumnsRequest = {}): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getColumns({ - sessionHandle: this.sessionHandle, - catalogName: request.catalogName, - schemaName: request.schemaName, - tableName: request.tableName, - columnName: request.columnName, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getColumns(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -517,36 +294,16 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getFunctions(request: FunctionsRequest): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getFunctions({ - sessionHandle: this.sessionHandle, - catalogName: request.catalogName, - schemaName: request.schemaName, - functionName: request.functionName, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getFunctions(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } public async getPrimaryKeys(request: PrimaryKeysRequest): Promise { 
await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getPrimaryKeys({ - sessionHandle: this.sessionHandle, - catalogName: request.catalogName, - schemaName: request.schemaName, - tableName: request.tableName, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getPrimaryKeys(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -557,22 +314,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getCrossReference(request: CrossReferenceRequest): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getCrossReference({ - sessionHandle: this.sessionHandle, - parentCatalogName: request.parentCatalogName, - parentSchemaName: request.parentSchemaName, - parentTableName: request.parentTableName, - foreignCatalogName: request.foreignCatalogName, - foreignSchemaName: request.foreignSchemaName, - foreignTableName: request.foreignTableName, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getCrossReference(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -585,35 +329,20 @@ export default class DBSQLSession implements IDBSQLSession { return Status.success(); } - // Close owned operations one by one, removing successfully closed ones from the list await this.operations.closeAll(); - const driver = await this.context.getDriver(); - const 
response = await driver.closeSession({ - sessionHandle: this.sessionHandle, - }); - // check status for being successful - Status.assert(response.status); + const status = await this.backend.close(); - // notify owner connection this.onClose?.(); this.isOpen = false; this.context.getLogger().log(LogLevel.debug, `Session closed with id: ${this.id}`); - return new Status(response.status); + return status; } - private createOperation(response: OperationResponseShape): DBSQLOperation { - Status.assert(response.status); - const handle = definedOrError(response.operationHandle); - const operation = new DBSQLOperation({ - handle, - directResults: response.directResults, - context: this.context, - }); - + private wrapOperation(backend: IOperationBackend): DBSQLOperation { + const operation = new DBSQLOperation({ backend, context: this.context }); this.operations.add(operation); - return operation; } @@ -622,13 +351,4 @@ export default class DBSQLSession implements IDBSQLSession { throw new HiveDriverError('The session was closed or has expired'); } } - - private async handleResponse(requestPromise: Promise): Promise { - // Currently, after being closed sessions remains usable - server will not - // error out when trying to run operations on closed session. So it's - // basically useless to process any errors here - const result = await requestPromise; - await this.failIfClosed(); - return result; - } } diff --git a/lib/contracts/IBackend.ts b/lib/contracts/IBackend.ts new file mode 100644 index 00000000..2e5edd16 --- /dev/null +++ b/lib/contracts/IBackend.ts @@ -0,0 +1,34 @@ +import { ConnectionOptions, OpenSessionRequest } from './IDBSQLClient'; +import ISessionBackend from './ISessionBackend'; + +/** + * Top-level backend dispatch handle. One instance per `DBSQLClient`, + * chosen at `connect()` time based on the `useSEA` flag and never + * re-selected per-call. 
+ */ +export default interface IBackend { + /** + * Establish backend-level state before any session is opened. Implementations + * consume `options` to build backend-specific connection parameters (e.g. the + * SEA backend derives napi-binding `SeaNativeConnectionOptions` from the auth + * + host fields here). Transport-layer connection providers are owned by + * `DBSQLClient` (via `IClientContext`) and exposed to backends through + * constructor injection. + */ + connect(options: ConnectionOptions): Promise; + + /** + * Open a session. Returned `ISessionBackend` is owned by the caller + * and torn down via its own `close()`. + */ + openSession(request: OpenSessionRequest): Promise; + + /** + * Backend-level teardown. Transport-layer cleanup (connection provider, + * thrift client, auth provider) is owned by `DBSQLClient` and runs + * after this returns. Implementations release backend-internal resources + * here, and MUST be safe to call on a partially-initialized backend + * (i.e. after a failed `connect()`). + */ + close(): Promise; +} diff --git a/lib/contracts/IOperation.ts b/lib/contracts/IOperation.ts index 1d0bb9a1..bbeed622 100644 --- a/lib/contracts/IOperation.ts +++ b/lib/contracts/IOperation.ts @@ -1,6 +1,7 @@ import { Readable, ReadableOptions } from 'node:stream'; import { TGetOperationStatusResp, TTableSchema } from '../../thrift/TCLIService_types'; import Status from '../dto/Status'; +import { ResultMetadata } from './ResultMetadata'; export type OperationStatusCallback = (progress: TGetOperationStatusResp) => unknown; @@ -59,7 +60,10 @@ export default interface IOperation { fetchAll(options?: FetchOptions): Promise>; /** - * Request status of operation + * Request status of operation. Returns the Thrift wire response for + * back-compat. New code should prefer {@link IOperation.getResultMetadata} + * for metadata and may consume the neutral `IOperationBackend.status` via + * a typed downcast when implementing alternative backends. 
* * @param progress */ @@ -90,6 +94,12 @@ export default interface IOperation { */ getSchema(options?: GetSchemaOptions): Promise; + /** + * Fetch result-set metadata in the backend-neutral `ResultMetadata` shape. + * Prefer this over the Thrift-shaped surface for new code. + */ + getResultMetadata(): Promise; + iterateChunks(options?: IteratorOptions): IOperationChunksIterator; iterateRows(options?: IteratorOptions): IOperationRowsIterator; diff --git a/lib/contracts/IOperationBackend.ts b/lib/contracts/IOperationBackend.ts new file mode 100644 index 00000000..4c17020b --- /dev/null +++ b/lib/contracts/IOperationBackend.ts @@ -0,0 +1,55 @@ +import Status from '../dto/Status'; +import { WaitUntilReadyOptions } from './IOperation'; +import { OperationStatus } from './OperationStatus'; +import { ResultMetadata } from './ResultMetadata'; + +/** + * What a `DBSQLOperation` needs from its backend. Returned by + * `ISessionBackend.executeStatement` and the metadata methods. + */ +export default interface IOperationBackend { + /** Operation identifier. */ + readonly id: string; + + /** + * Whether this operation has a result set. Initial value may be derived + * from the create-operation response; implementations MUST refresh it + * from terminal status responses (the Thrift impl updates + * `operationHandle.hasResultSet` inside `processOperationStatusResponse`). + * `readonly` here means external callers cannot reassign the property — + * not that the underlying value is fixed at construction time. + */ + readonly hasResultSet: boolean; + + /** Fetch the next chunk of result rows. */ + fetchChunk(options: { limit: number; disableBuffering?: boolean }): Promise>; + + /** Whether more rows are available beyond what has been fetched. */ + hasMore(): Promise; + + /** + * Poll the backend until the operation reaches a terminal state. 
+ * + * MUST throw `OperationStateError` (with one of `OperationStateErrorCode.{Canceled, + * Closed, Error, Timeout, Unknown}`) on terminal non-success states. The + * `DBSQLOperation` facade depends on `Canceled` and `Closed` codes to mirror + * the operation into its closed/cancelled flags; future implementations must + * use the same error type for the facade to stay in sync. + */ + waitUntilReady(options?: WaitUntilReadyOptions): Promise; + + /** + * Fetch operation status as a neutral `OperationStatus`. Pass `progress: true` + * to request that the backend include a progress payload. + */ + status(progress: boolean): Promise; + + /** Fetch result-set metadata (schema, format, lz4 flag, arrow schema, staging flag). */ + getResultMetadata(): Promise; + + /** Cancel the operation. */ + cancel(): Promise; + + /** Close the operation. Idempotent. */ + close(): Promise; +} diff --git a/lib/contracts/ISessionBackend.ts b/lib/contracts/ISessionBackend.ts new file mode 100644 index 00000000..2404dc68 --- /dev/null +++ b/lib/contracts/ISessionBackend.ts @@ -0,0 +1,60 @@ +import IOperationBackend from './IOperationBackend'; +import { + ExecuteStatementOptions, + TypeInfoRequest, + CatalogsRequest, + SchemasRequest, + TablesRequest, + TableTypesRequest, + ColumnsRequest, + FunctionsRequest, + PrimaryKeysRequest, + CrossReferenceRequest, +} from './IDBSQLSession'; +import Status from '../dto/Status'; +import InfoValue from '../dto/InfoValue'; + +/** + * What a `DBSQLSession` needs from its backend. Returned by + * `IBackend.openSession()`. Lifecycle tied to a single `DBSQLSession`. + */ +export default interface ISessionBackend { + /** Session identifier. */ + readonly id: string; + + /** Returns general information about the data source. */ + getInfo(infoType: number): Promise; + + /** Executes DDL/DML statements. */ + executeStatement(statement: string, options: ExecuteStatementOptions): Promise; + + /** Information about supported data types. 
*/ + getTypeInfo(request: TypeInfoRequest): Promise; + + /** List of catalogs. */ + getCatalogs(request: CatalogsRequest): Promise; + + /** List of schemas. */ + getSchemas(request: SchemasRequest): Promise; + + /** List of tables. */ + getTables(request: TablesRequest): Promise; + + /** List of supported table types. */ + getTableTypes(request: TableTypesRequest): Promise; + + /** Full column information for a table. */ + getColumns(request: ColumnsRequest): Promise; + + /** Information about a function. */ + getFunctions(request: FunctionsRequest): Promise; + + /** Primary keys of a table. */ + getPrimaryKeys(request: PrimaryKeysRequest): Promise; + + /** Foreign-key relationships between two tables. */ + getCrossReference(request: CrossReferenceRequest): Promise; + + /** Close the session. Idempotent. */ + close(): Promise; +} diff --git a/lib/contracts/InternalConnectionOptions.ts b/lib/contracts/InternalConnectionOptions.ts new file mode 100644 index 00000000..a115aa47 --- /dev/null +++ b/lib/contracts/InternalConnectionOptions.ts @@ -0,0 +1,21 @@ +/** + * Internal, non-exported extension of `ConnectionOptions`. Carries M0-only + * flags that should not appear in the published `.d.ts`. + * + * Matches the Python connector pattern: there, `use_sea` is consumed via + * `kwargs.get("use_sea", False)` and is intentionally absent from the typed + * signature (see `databricks-sql-python/src/databricks/sql/session.py`). + * + * Callers cast `ConnectionOptions` to this type *only* at the read site + * inside the driver; user code that wants to set `useSEA` may still do so + * via an untyped object literal — the option is not part of the public + * contract and may be removed without notice. + */ +export interface InternalConnectionOptions { + /** + * Opt-in flag to dispatch through the Statement Execution API (SEA) + * backend instead of the default Thrift backend. Defaults to `false`. + * @internal Not stable; M0 stub only. 
+ */ + useSEA?: boolean; +} diff --git a/lib/contracts/OperationStatus.ts b/lib/contracts/OperationStatus.ts new file mode 100644 index 00000000..7f167aba --- /dev/null +++ b/lib/contracts/OperationStatus.ts @@ -0,0 +1,56 @@ +/** + * Backend-neutral operation state. Mirrors the kernel/pyo3 `StatementStatus` + * and the Python connector's `CommandState`, so a SEA `IOperationBackend` + * implementer can return these without depending on the Thrift wire enum. + * + * Thrift mapping (in `ThriftOperationBackend.adaptOperationStatus`): + * - INITIALIZED_STATE, PENDING_STATE → Pending + * - RUNNING_STATE → Running + * - FINISHED_STATE → Succeeded + * - CANCELED_STATE → Cancelled + * - CLOSED_STATE → Closed + * - ERROR_STATE, TIMEDOUT_STATE → Failed + * - UKNOWN_STATE / anything else → Unknown + */ +export enum OperationState { + Pending = 'Pending', + Running = 'Running', + Succeeded = 'Succeeded', + Failed = 'Failed', + Cancelled = 'Cancelled', + Closed = 'Closed', + Unknown = 'Unknown', +} + +/** + * Neutral status snapshot returned by `IOperationBackend.status()`. Backends + * adapt their wire format at the boundary; callers in `DBSQLOperation` and + * `IOperationBackend.waitUntilReady` switch on `state` alone. + * + * Fields beyond `state` are best-effort and may be undefined depending on + * what the backend exposes. + */ +export interface OperationStatus { + /** Current operation state. */ + state: OperationState; + + /** + * Whether this operation has produced (or is producing) a result set. + * Some backends only know this after the operation reaches a terminal + * state — undefined means "no signal from this backend". + */ + hasResultSet?: boolean; + + /** Human-readable error/display message, if the backend supplied one. */ + errorMessage?: string; + + /** SQL state code (e.g. "42000"), if available. */ + sqlState?: string; + + /** + * Opaque progress payload as returned by the backend when callers pass + * `progress: true`. 
Treated as untyped by the facade — passed through + * to `WaitUntilReadyOptions.callback` for the consumer to interpret. + */ + progressUpdateResponse?: unknown; +} diff --git a/lib/contracts/ResultMetadata.ts b/lib/contracts/ResultMetadata.ts new file mode 100644 index 00000000..5fc09a79 --- /dev/null +++ b/lib/contracts/ResultMetadata.ts @@ -0,0 +1,39 @@ +import { TTableSchema } from '../../thrift/TCLIService_types'; + +/** + * Backend-neutral result-format taxonomy. Mirrors the three on-wire shapes + * `ThriftOperationBackend` actually dispatches on (`COLUMN_BASED_SET`, + * `ARROW_BASED_SET`, `URL_BASED_SET`); a SEA implementer surfaces the same + * three so result-handling stays format-agnostic. + */ +export enum ResultFormat { + ColumnBased = 'COLUMN_BASED', + ArrowBased = 'ARROW_BASED', + UrlBased = 'URL_BASED', +} + +/** + * Neutral result-set metadata returned by `IOperationBackend.getResultMetadata()`. + * + * `schema` keeps the Thrift `TTableSchema` shape for now because the public + * `DBSQLOperation.getSchema()` and `getMetadata()` already expose it on + * `IOperation`; carrying it across the boundary preserves back-compat. The + * SEA backend will adapt its column descriptors into the same shape until + * the public IOperation surface is migrated in a later PR. + */ +export interface ResultMetadata { + /** Column schema; omitted (`undefined`) if the operation has no result set. */ + schema?: TTableSchema; + + /** Wire format the result handler should dispatch on. */ + resultFormat: ResultFormat; + + /** Whether the result payload is LZ4-compressed. */ + lz4Compressed?: boolean; + + /** Optional Arrow IPC schema bytes (for ARROW_BASED / URL_BASED formats). */ + arrowSchema?: Buffer; + + /** True iff the operation is a staging (PUT/GET/REMOVE) operation.
*/ + isStagingOperation: boolean; +} diff --git a/lib/result/ArrowResultConverter.ts b/lib/result/ArrowResultConverter.ts index 57fa02af..b641d3a4 100644 --- a/lib/result/ArrowResultConverter.ts +++ b/lib/result/ArrowResultConverter.ts @@ -23,6 +23,145 @@ const { isArrowBigNumSymbol, bigNumToBigInt } = arrowUtils; type ArrowSchema = Schema; type ArrowSchemaField = Field>; +/** + * Metadata key carrying the original Arrow `Duration` time unit on + * fields that were rewritten to `Int64` by the SEA IPC pre-processor + * (`lib/sea/SeaArrowIpcDurationFix.ts`). We re-declare the constant + * here (rather than importing it) so the converter has no compile-time + * dependency on the SEA module — it's reused unchanged by the + * thrift-path which has no SEA awareness. + */ +const DURATION_UNIT_METADATA_KEY = 'databricks.arrow.duration_unit'; +const ZERO_BIGINT = BigInt(0); +const NS_PER_MICRO = BigInt(1_000); +const NS_PER_MILLI = BigInt(1_000_000); +const NS_PER_SEC = BigInt(1_000_000_000); +const MS_PER_DAY = BigInt(86_400_000); +const NS_PER_MIN = NS_PER_SEC * BigInt(60); +const NS_PER_HOUR = NS_PER_MIN * BigInt(60); +const NS_PER_DAY = NS_PER_HOUR * BigInt(24); + +/** + * Format an Arrow `Interval[YearMonth]` or `Interval[DayTime]` value + * into the canonical thrift string the JDBC/ODBC server emits: + * YEAR-MONTH → `"Y-M"` (e.g. 1 year 2 months → `"1-2"`) + * DAY-TIME → `"D HH:mm:ss.fffffffff"` + * (e.g. 1 day 02:03:04 → `"1 02:03:04.000000000"`) + * + * Arrow surfaces these as `Int32Array(2)` via the `GetVisitor` + * (`apache-arrow/visitor/get.js:177-185`): + * YEAR-MONTH: `[years, months]` (years/months derived from a single + * int32 holding total months) + * DAY-TIME: `[days, milliseconds]` (legacy two-int32 form) + * + * Negative intervals: the FULL interval is emitted with a leading `-` + * (Spark convention), and individual fields are unsigned. We mirror + * Spark's display. 
+ */ +function formatArrowInterval(value: any, valueType: any): string { + // `value` is an Int32Array of length 2. + const a = Number(value[0]); + const b = Number(value[1]); + // unit 0 = YEAR_MONTH; any other unit (1 = DAY_TIME, 2 = MONTH_DAY_NANO) is decoded as DAY_TIME below + const unit = valueType?.unit; + if (unit === 0) { + return formatYearMonth(a, b); + } + // DAY_TIME: a = days, b = milliseconds (within the day, can be ≥0 or <0) + // We re-normalise: total milliseconds = a * 86_400_000 + b, then split into + // days, hours, minutes, seconds and a 9-digit fraction (its last six digits are always 0 + // because the legacy IntervalDayTime carries only millisecond precision). + const totalMs = BigInt(a) * MS_PER_DAY + BigInt(b); + return formatDayTimeFromTotal(totalMs * NS_PER_MILLI /* → ns */, 'NANOSECOND'); +} + +/** + * Format the (years, months) decomposition into `"Y-M"` (or `"-Y-M"` + * for negative intervals). Arrow's `getIntervalYearMonth` (in + * `apache-arrow/visitor/get.js:179`) decomposes a signed total-months + * int32 via integer truncation, so years and months always share the + * same sign. We render the absolute values with a single leading `-` + * to match the Spark display format used on the thrift path. + */ +function formatYearMonth(years: number, months: number): string { + const total = years * 12 + months; + if (total < 0) { + const abs = -total; + const y = Math.trunc(abs / 12); + const m = abs % 12; + return `-${y}-${m}`; + } + return `${years}-${months}`; +} + +/** + * Format an Arrow `Duration` value (rewritten by the SEA IPC + * pre-processor to `Int64`) into the thrift INTERVAL DAY-TIME string.
+ * + * @param value the duration value as `bigint` (signed nanos/micros/ + * millis/seconds depending on `unit`) + * @param unit one of `SECOND` / `MILLISECOND` / `MICROSECOND` / + * `NANOSECOND` (the original Arrow time unit, captured + * by `SeaArrowIpcDurationFix.ts`) + */ +function formatDurationToIntervalDayTime(value: bigint | number, unit: string): string { + const bi = typeof value === 'bigint' ? value : BigInt(value); + const nanos = toNanoseconds(bi, unit); + return formatDayTimeFromTotal(nanos, unit); +} + +/** + * Scale a duration value to nanoseconds based on its unit. + * + * SECOND → ×1_000_000_000 + * MILLISECOND → × 1_000_000 + * MICROSECOND → × 1_000 + * NANOSECOND → × 1 + */ +function toNanoseconds(value: bigint, unit: string): bigint { + switch (unit) { + case 'SECOND': + return value * NS_PER_SEC; + case 'MILLISECOND': + return value * NS_PER_MILLI; + case 'MICROSECOND': + return value * NS_PER_MICRO; + case 'NANOSECOND': + default: + return value; + } +} + +/** + * Format a signed total-nanoseconds value as `"D HH:mm:ss.fffffffff"`. + * Always emits 9 fractional digits to match the thrift driver's wire + * format (`"1 02:03:04.000000000"` — 9 digits regardless of the + * server-side storage precision). Negative values get a single + * leading `-`. + * + * The `unit` parameter is currently unused for formatting (the value + * is already in nanoseconds by the time we get here) but is retained + * for future use if a unit-aware precision is ever needed. + */ +function formatDayTimeFromTotal(totalNanos: bigint, _unit: string): string { + const sign = totalNanos < ZERO_BIGINT ? '-' : ''; + const abs = totalNanos < ZERO_BIGINT ? 
-totalNanos : totalNanos; + + const days = abs / NS_PER_DAY; + let rem = abs % NS_PER_DAY; + const hours = rem / NS_PER_HOUR; + rem %= NS_PER_HOUR; + const minutes = rem / NS_PER_MIN; + rem %= NS_PER_MIN; + const seconds = rem / NS_PER_SEC; + const subSeconds = rem % NS_PER_SEC; + + const pad2 = (n: bigint): string => n.toString().padStart(2, '0'); + const fraction = `.${subSeconds.toString().padStart(9, '0')}`; + + return `${sign}${days.toString()} ${pad2(hours)}:${pad2(minutes)}:${pad2(seconds)}${fraction}`; +} + export default class ArrowResultConverter implements IResultsProvider> { private readonly context: IClientContext; @@ -142,37 +281,52 @@ export default class ArrowResultConverter implements IResultsProvider private getRows(schema: ArrowSchema, rows: Array): Array { return rows.map((row) => { // First, convert native Arrow values to corresponding plain JS objects - const record = this.convertArrowTypes(row, undefined, schema.fields); + const record = this.convertArrowTypes(row, undefined, schema.fields, undefined); // Second, cast all the values to original Thrift types return this.convertThriftTypes(record); }); } - private convertArrowTypes(value: any, valueType: DataType | undefined, fields: Array = []): any { + private convertArrowTypes( + value: any, + valueType: DataType | undefined, + fields: Array = [], + field?: ArrowSchemaField, + ): any { if (value === null) { return value; } const fieldsMap: Record = {}; - for (const field of fields) { - fieldsMap[field.name] = field; + for (const f of fields) { + fieldsMap[f.name] = f; } // Convert structures to plain JS object and process all its fields recursively if (value instanceof StructRow) { const result = value.toJSON(); for (const key of Object.keys(result)) { - const field: ArrowSchemaField | undefined = fieldsMap[key]; - result[key] = this.convertArrowTypes(result[key], field?.type, field?.type.children || []); + const childField: ArrowSchemaField | undefined = fieldsMap[key]; + result[key] = 
this.convertArrowTypes( + result[key], + childField?.type, + childField?.type.children || [], + childField, + ); } return result; } if (value instanceof MapRow) { const result = value.toJSON(); // Map type consists of its key and value types. We need only value type here, key will be cast to string anyway - const field = fieldsMap.entries?.type.children.find((item) => item.name === 'value'); + const valueField = fieldsMap.entries?.type.children.find((item) => item.name === 'value'); for (const key of Object.keys(result)) { - result[key] = this.convertArrowTypes(result[key], field?.type, field?.type.children || []); + result[key] = this.convertArrowTypes( + result[key], + valueField?.type, + valueField?.type.children || [], + valueField, + ); } return result; } @@ -181,14 +335,28 @@ export default class ArrowResultConverter implements IResultsProvider if (value instanceof Vector) { const result = value.toJSON(); // Array type contains the only child which defines a type of each array's element - const field = fieldsMap.element; - return result.map((item) => this.convertArrowTypes(item, field?.type, field?.type.children || [])); + const elementField = fieldsMap.element; + return result.map((item) => + this.convertArrowTypes(item, elementField?.type, elementField?.type.children || [], elementField), + ); } if (DataType.isTimestamp(valueType)) { return new Date(value); } + // INTERVAL — Spark/Databricks SEA emits two flavours: native Arrow + // `Interval[YearMonth]` / `Interval[DayTime]` (handled here) and + // `Duration` (transparently rewritten to `Int64` upstream by + // `SeaArrowIpcDurationFix.ts`; handled in the bigint/Int64 branch + // below). 
In every case we coerce to the canonical thrift string + // form so the SEA path is byte-identical with the thrift path: + // YEAR-MONTH → `"Y-M"` + // DAY-TIME → `"D HH:mm:ss.fffffffff"` + if (DataType.isInterval(valueType)) { + return formatArrowInterval(value, valueType); + } + // Convert big number values to BigInt // Decimals are also represented as big numbers in Arrow, so additionally process them (convert to float) if (value instanceof Object && value[isArrowBigNumSymbol]) { @@ -196,16 +364,38 @@ export default class ArrowResultConverter implements IResultsProvider if (DataType.isDecimal(valueType)) { return Number(result) / 10 ** valueType.scale; } + // Duration columns rewritten to Int64 — detect via metadata. + const durationUnit = field?.metadata.get(DURATION_UNIT_METADATA_KEY); + if (durationUnit) { + return formatDurationToIntervalDayTime(result, durationUnit); + } return result; } // Convert binary data to Buffer if (value instanceof Uint8Array) { + // INTERVAL DAY-TIME / YEAR-MONTH that apache-arrow surfaced as + // an Int32Array (size 2). `Uint8Array.isInstanceOf` is true for + // every TypedArray subclass, so we have to check the parent type + // first. The `DataType.isInterval` branch above already handles + // the case where Arrow knew the field was an interval — this + // fallback covers schemas where the interval surfaced as bare + // bytes (defensive; not exercised in M0). return Buffer.from(value); } + // Bigint fallback — for raw bigints (not BigNum wrappers), the + // duration_unit metadata also gates the INTERVAL DAY-TIME format. + if (typeof value === 'bigint') { + const durationUnit = field?.metadata.get(DURATION_UNIT_METADATA_KEY); + if (durationUnit) { + return formatDurationToIntervalDayTime(value, durationUnit); + } + return Number(value); + } + // Return other values as is - return typeof value === 'bigint' ? 
Number(value) : value; + return value; } private convertThriftTypes(record: Record): any { diff --git a/lib/sea/SeaArrowIpc.ts b/lib/sea/SeaArrowIpc.ts new file mode 100644 index 00000000..c111b6bd --- /dev/null +++ b/lib/sea/SeaArrowIpc.ts @@ -0,0 +1,257 @@ +// Copyright (c) 2026 Databricks, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { RecordBatchReader, Schema, Field, DataType, TypeMap } from 'apache-arrow'; +import { TTableSchema, TTypeId, TPrimitiveTypeEntry } from '../../thrift/TCLIService_types'; +import { rewriteDurationToInt64, DURATION_UNIT_METADATA_KEY } from './SeaArrowIpcDurationFix'; + +/** + * Field metadata key used by the kernel to attach the original Databricks + * SQL type name to each Arrow field. See `databricks-sql-kernel/src/reader/mod.rs`. + */ +const DATABRICKS_TYPE_NAME = 'databricks.type_name'; + +/** + * Decode an Arrow IPC stream payload (schema header + zero-or-more + * record-batch messages) into its row count. + * + * Returns `{ schema, rowCount }`. The schema is left intact as the + * apache-arrow Schema object so callers can reuse it; the rowCount is + * the sum of `RecordBatch.numRows` across every record-batch message + * in the stream. + * + * Why we parse upfront: `ArrowResultConverter` consumes `ArrowBatch` + * objects which carry an explicit `rowCount`. The kernel's IPC payload + * does not carry a separate count — only per-RecordBatch numRows. 
We + * walk the messages once to sum them so the converter sees the same + * shape as the thrift path (`ArrowResultHandler.fetchNext` at + * `lib/result/ArrowResultHandler.ts:55`). + * + * Re-parsing inside the converter is unavoidable because `RecordBatch` + * instances created here cannot be passed across the converter's + * `Buffer[]` boundary without rewriting the converter. Callers that already + * patched the IPC bytes can set `alreadyPatched` to avoid running the + * FlatBuffer rewrite twice on the hot fetch path. + */ +export function decodeIpcBatch( + ipcBytes: Buffer, + options: { alreadyPatched?: boolean } = {}, +): { schema: Schema; rowCount: number } { + const patched = options.alreadyPatched ? ipcBytes : rewriteDurationToInt64(ipcBytes); + const reader = RecordBatchReader.from(patched); + // Eagerly open so `schema` is populated. + reader.open(); + const { schema } = reader; + + let rowCount = 0; + // Iterate all record batches in the stream and sum row counts. + for (const batch of reader) { + rowCount += batch.numRows; + } + return { schema, rowCount }; +} + +/** + * Decode an Arrow IPC schema payload (no record batches) into the + * apache-arrow Schema object. + */ +export function decodeIpcSchema(ipcBytes: Buffer): Schema { + const patched = rewriteDurationToInt64(ipcBytes); + const reader = RecordBatchReader.from(patched); + reader.open(); + return reader.schema; +} + +/** + * Pre-process raw IPC bytes from the kernel so they're consumable by + * `apache-arrow@13`. The current transformation is `Duration → Int64` + * with the original duration unit preserved in field metadata (see + * `SeaArrowIpcDurationFix.ts`). Returned bytes are byte-identical to + * the input when no transformation is needed. 
+ * + * Exposed so callers can pre-patch the buffer **once** and pass the + * result through both `decodeIpcBatch` (for row-count extraction in + * `SeaResultsProvider`) and `ArrowResultConverter.fetchNext` (which + * re-decodes the same bytes via `RecordBatchReader.from`). Without + * this, the converter would re-throw on `Duration` because it never + * sees the patched bytes. + */ +export function patchIpcBytes(ipcBytes: Buffer): Buffer { + return rewriteDurationToInt64(ipcBytes); +} + +/** + * Map an Arrow `DataType` (with optional `databricks.type_name` + * metadata) onto the closest Thrift `TTypeId`. + * + * This is the synthesis step that lets the existing + * `ArrowResultConverter` Phase-2 dispatch (`convertThriftValue` in + * `lib/result/utils.ts:61-98`) keep working unchanged for the SEA + * path. Phase-2 keys exclusively off `TPrimitiveTypeEntry.type` per + * column, so we synthesize a `TColumnDesc` whose `TTypeId` matches the + * server-emitted Arrow type as closely as possible. + * + * Resolution order: + * 1. The kernel attaches `databricks.type_name` (e.g. "DECIMAL", + * "INTERVAL", "STRUCT") to each field's metadata. Prefer that when + * present — it carries the original SQL semantic that the Arrow + * type alone can lose (e.g. INTERVAL → Utf8 with metadata). + * 2. Fall back to the Arrow `DataType.typeId` for primitive types. + * + * This matches the JDBC and Python drivers' policy of trusting the + * server's logical type assignment over the wire-level Arrow encoding. 
/**
 * Map an Arrow field onto the closest Thrift `TTypeId`.
 *
 * Resolution order:
 *   1. The `databricks.type_name` field metadata (upper-cased), when it
 *      names a known SQL type.
 *   2. Otherwise the Arrow `DataType` of the field.
 *   3. `STRING_TYPE` when neither matches.
 *
 * A bare `INTERVAL` type name does not say whether the column is a
 * year-month or a day-time interval, so for that name the Arrow type is
 * consulted: a native `Interval` with unit 0 (YEAR_MONTH) maps to
 * `INTERVAL_YEAR_MONTH_TYPE`; everything else keeps the day-time
 * mapping.
 *
 * @param field Arrow field (type + metadata) to classify
 * @returns the Thrift type id used for the synthesized column descriptor
 */
function arrowTypeToTTypeId(field: Field): TTypeId {
  const typeName = field.metadata.get(DATABRICKS_TYPE_NAME)?.toUpperCase();

  switch (typeName) {
    case 'BOOLEAN':
      return TTypeId.BOOLEAN_TYPE;
    case 'TINYINT':
    case 'BYTE':
      return TTypeId.TINYINT_TYPE;
    case 'SMALLINT':
    case 'SHORT':
      return TTypeId.SMALLINT_TYPE;
    case 'INT':
    case 'INTEGER':
      return TTypeId.INT_TYPE;
    case 'BIGINT':
    case 'LONG':
      return TTypeId.BIGINT_TYPE;
    case 'FLOAT':
    case 'REAL':
      return TTypeId.FLOAT_TYPE;
    case 'DOUBLE':
      return TTypeId.DOUBLE_TYPE;
    case 'STRING':
      return TTypeId.STRING_TYPE;
    case 'VARCHAR':
      return TTypeId.VARCHAR_TYPE;
    case 'CHAR':
      return TTypeId.CHAR_TYPE;
    case 'BINARY':
      return TTypeId.BINARY_TYPE;
    case 'DATE':
      return TTypeId.DATE_TYPE;
    case 'TIMESTAMP':
    case 'TIMESTAMP_NTZ':
      return TTypeId.TIMESTAMP_TYPE;
    case 'DECIMAL':
      return TTypeId.DECIMAL_TYPE;
    case 'INTERVAL':
      // The bare name is ambiguous between the two interval families.
      // Disambiguate with the Arrow type: unit 0 is YEAR_MONTH.
      return DataType.isInterval(field.type) && field.type.unit === 0
        ? TTypeId.INTERVAL_YEAR_MONTH_TYPE
        : TTypeId.INTERVAL_DAY_TIME_TYPE;
    case 'INTERVAL DAY':
    case 'INTERVAL DAY TO HOUR':
    case 'INTERVAL DAY TO MINUTE':
    case 'INTERVAL DAY TO SECOND':
    case 'INTERVAL HOUR':
    case 'INTERVAL HOUR TO MINUTE':
    case 'INTERVAL HOUR TO SECOND':
    case 'INTERVAL MINUTE':
    case 'INTERVAL MINUTE TO SECOND':
    case 'INTERVAL SECOND':
      return TTypeId.INTERVAL_DAY_TIME_TYPE;
    case 'INTERVAL YEAR':
    case 'INTERVAL YEAR TO MONTH':
    case 'INTERVAL MONTH':
      return TTypeId.INTERVAL_YEAR_MONTH_TYPE;
    case 'ARRAY':
      return TTypeId.ARRAY_TYPE;
    case 'MAP':
      return TTypeId.MAP_TYPE;
    case 'STRUCT':
      return TTypeId.STRUCT_TYPE;
    case 'NULL':
    case 'VOID':
      return TTypeId.NULL_TYPE;
    default:
      break;
  }

  // Fall back to Arrow's own type when no (recognized) databricks
  // metadata is set, e.g. unit tests constructing batches without it.
  const arrowType = field.type;
  if (DataType.isBool(arrowType)) return TTypeId.BOOLEAN_TYPE;
  if (DataType.isInt(arrowType)) {
    // Duration columns are rewritten to Int64 and tagged with the
    // duration-unit metadata key; surface them as INTERVAL_DAY_TIME so
    // the converter formats them back into the thrift string form.
    if (arrowType.bitWidth === 64 && field.metadata.has(DURATION_UNIT_METADATA_KEY)) {
      return TTypeId.INTERVAL_DAY_TIME_TYPE;
    }
    switch (arrowType.bitWidth) {
      case 8:
        return TTypeId.TINYINT_TYPE;
      case 16:
        return TTypeId.SMALLINT_TYPE;
      case 32:
        return TTypeId.INT_TYPE;
      case 64:
        return TTypeId.BIGINT_TYPE;
      default:
        return TTypeId.BIGINT_TYPE;
    }
  }
  if (DataType.isFloat(arrowType)) {
    // `precision` is the Arrow Precision enum ordinal, not a bit width:
    // HALF = 0, SINGLE = 1, DOUBLE = 2. Only DOUBLE maps to DOUBLE_TYPE.
    return arrowType.precision === 2 ? TTypeId.DOUBLE_TYPE : TTypeId.FLOAT_TYPE;
  }
  if (DataType.isDecimal(arrowType)) return TTypeId.DECIMAL_TYPE;
  if (DataType.isUtf8(arrowType)) return TTypeId.STRING_TYPE;
  if (DataType.isBinary(arrowType)) return TTypeId.BINARY_TYPE;
  if (DataType.isDate(arrowType)) return TTypeId.DATE_TYPE;
  if (DataType.isTimestamp(arrowType)) return TTypeId.TIMESTAMP_TYPE;
  if (DataType.isInterval(arrowType)) {
    // unit 0 = YEAR_MONTH, unit 1 = DAY_TIME, unit 2 = MONTH_DAY_NANO
    return arrowType.unit === 0 ? TTypeId.INTERVAL_YEAR_MONTH_TYPE : TTypeId.INTERVAL_DAY_TIME_TYPE;
  }
  if (DataType.isList(arrowType)) return TTypeId.ARRAY_TYPE;
  if (DataType.isMap(arrowType)) return TTypeId.MAP_TYPE;
  if (DataType.isStruct(arrowType)) return TTypeId.STRUCT_TYPE;
  if (DataType.isNull(arrowType)) return TTypeId.NULL_TYPE;

  return TTypeId.STRING_TYPE;
}
+ +/** + * Pre-process an Arrow IPC stream payload to make it consumable by + * `apache-arrow@13`, which predates the addition of the `Duration` type + * (FlatBuffer `Type` enum id 18) in version 14. + * + * The Databricks SQL server emits INTERVAL DAY-TIME columns as Arrow + * `Duration(MICROSECOND)` in the SEA IPC stream. apache-arrow@13's + * `decodeFieldType` (`node_modules/apache-arrow/ipc/metadata/message.js:339-397`) + * throws `Unrecognized type: "Duration" (18)` on the schema FlatBuffer + * before any record batch is read, breaking the entire SEA path for any + * result that contains an INTERVAL DAY-TIME column. + * + * Because the physical layout of an Arrow `Duration` column is + * **identical** to an Arrow `Int64` column (8 bytes of signed integer per + * row in the values buffer, plus the validity bitmap), we can losslessly + * rewrite the schema FlatBuffer to advertise `Int(bitWidth=64, + * signed=true)` in place of `Duration(unit)`. The record-batch body + * bytes pass through unchanged. We embed the original `Duration` time + * unit (`SECOND`/`MILLISECOND`/`MICROSECOND`/`NANOSECOND`) into the + * rewritten field's `custom_metadata` under the key + * `databricks.arrow.duration_unit` so the JS converter can format the + * Int64 value back into a thrift-equivalent string (e.g. + * `"1 02:03:04.000000000"`). + * + * Why this lives in its own file: the rewriter is the only place in the + * codebase that needs to construct FlatBuffers by hand using the + * `flatbuffers` library; isolating it keeps `SeaArrowIpc.ts` focused on + * the high-level Arrow-decoded views. + * + * @see lib/result/ArrowResultConverter.ts — Phase-1 INTERVAL formatting + * reads the metadata key written here. + * @see findings/parity-mismatch/round5-implementation-2026-05-15.md — + * original failure mode (`Unrecognized type: "Duration" (18)`). 
+ */ + +import * as flatbuffers from 'flatbuffers'; +// We reach into apache-arrow's internal FlatBuffer accessor modules +// rather than the high-level Schema/Field classes because the latter +// throw on the `Duration` type id 18 (`apache-arrow@13` predates the +// `Duration` enum entry). The internal `fb/*` modules are generated +// FlatBuffer code and recognize every type id present in the +// FlatBuffer schema, including `Duration`, so we can decode the +// original schema and rebuild it with `Duration` rewritten to `Int64`. +// eslint-disable-next-line import/no-internal-modules +import { Message } from 'apache-arrow/fb/message'; +// eslint-disable-next-line import/no-internal-modules +import { MessageHeader } from 'apache-arrow/fb/message-header'; +// eslint-disable-next-line import/no-internal-modules +import { Schema as FbSchema } from 'apache-arrow/fb/schema'; +// eslint-disable-next-line import/no-internal-modules +import { Field as FbField } from 'apache-arrow/fb/field'; +// eslint-disable-next-line import/no-internal-modules +import { KeyValue as FbKeyValue } from 'apache-arrow/fb/key-value'; +// eslint-disable-next-line import/no-internal-modules +import { Type as FbType } from 'apache-arrow/fb/type'; +// eslint-disable-next-line import/no-internal-modules +import { Duration as FbDuration } from 'apache-arrow/fb/duration'; +// eslint-disable-next-line import/no-internal-modules +import { Int as FbInt } from 'apache-arrow/fb/int'; +// eslint-disable-next-line import/no-internal-modules +import { TimeUnit as FbTimeUnit } from 'apache-arrow/fb/time-unit'; + +/** + * Metadata key written onto rewritten fields to preserve the original + * `Duration` time unit. Consumed by + * `lib/result/ArrowResultConverter.ts` Phase 1 to choose the correct + * scale when formatting INTERVAL DAY-TIME values. 
/**
 * Rewrite any `Duration` field in the leading schema message of an IPC
 * stream payload to `Int64`, forwarding everything after the schema
 * message unchanged.
 *
 * The input is returned as-is (no copy) when:
 *   - no message can be read at offset 0,
 *   - the first message is not a schema message, or
 *   - the schema needs no rewrite.
 *
 * A plain `Uint8Array` input is wrapped in a `Buffer` view over the
 * same memory rather than copied, so the no-rewrite path stays
 * zero-copy for both accepted input types.
 *
 * @param ipcBytes raw IPC stream bytes
 * @returns the original bytes (as a `Buffer`) when nothing was
 *          rewritten, otherwise a fresh buffer made of the new framing
 *          prefix, the rewritten schema message and the untouched tail
 */
export function rewriteDurationToInt64(ipcBytes: Buffer | Uint8Array): Buffer {
  // `Buffer.from(arrayBuffer, byteOffset, length)` shares memory with
  // the source instead of copying it.
  const view = Buffer.isBuffer(ipcBytes)
    ? ipcBytes
    : Buffer.from(ipcBytes.buffer, ipcBytes.byteOffset, ipcBytes.byteLength);

  // The first message must be the schema. If it is missing or of
  // another kind, leave the bytes alone so the downstream decoder
  // reports the problem through its normal error path.
  const first = readMessageAt(view, 0);
  if (!first || first.message.headerType() !== MessageHeader.Schema) {
    return view;
  }

  const rewrittenSchema = maybeRewriteSchemaMessage(first.messageBytes);
  if (!rewrittenSchema) {
    // Nothing to rewrite.
    return view;
  }

  // New stream = framing prefix (continuation marker + metadata length)
  // + rewritten schema metadata + everything after the original schema
  // message.
  const pieces: Buffer[] = [encodeContinuationAndLength(rewrittenSchema.byteLength), rewrittenSchema];
  if (first.totalEnd < view.byteLength) {
    pieces.push(view.subarray(first.totalEnd));
  }
  return Buffer.concat(pieces);
}
/**
 * Report whether `field` or any of its descendants is a `Duration`.
 */
function fieldTreeHasDuration(field: FbField): boolean {
  if (field.typeType() === FbType.Duration) {
    return true;
  }
  const childCount = field.childrenLength();
  for (let i = 0; i < childCount; i += 1) {
    const child = field.children(i);
    if (child && fieldTreeHasDuration(child)) {
      return true;
    }
  }
  return false;
}

/**
 * If the schema message contains any `Duration` field — top-level or
 * nested under another field's children — return a fresh
 * FlatBuffer-encoded Message holding the rewritten schema. Otherwise
 * return `null` so the caller can short-circuit.
 *
 * Detection walks the whole field tree because the re-encode step
 * (`reEncodeField`) recurses into children and substitutes `Duration`
 * at every depth. Scanning only the top level would leave a schema
 * whose sole `Duration` is nested un-rewritten.
 *
 * @param schemaMessageBytes metadata bytes of the schema message
 * @returns the rewritten message bytes, or `null` when the message has
 *          no schema header or no `Duration` field
 */
function maybeRewriteSchemaMessage(schemaMessageBytes: Buffer): Buffer | null {
  const bb = new flatbuffers.ByteBuffer(schemaMessageBytes);
  const message = Message.getRootAsMessage(bb);
  const fbSchema = message.header(new FbSchema()) as FbSchema | null;
  if (!fbSchema) {
    return null;
  }

  let hasDuration = false;
  const fieldsLength = fbSchema.fieldsLength();
  for (let i = 0; i < fieldsLength; i += 1) {
    const f = fbSchema.fields(i);
    if (f && fieldTreeHasDuration(f)) {
      hasDuration = true;
      break;
    }
  }
  if (!hasDuration) {
    return null;
  }

  // Re-encode the whole schema rather than patching the FlatBuffer in
  // place, which avoids relying on vtable layout details.
  return rebuildSchemaWithDurationRewritten(message, fbSchema);
}
FbSchema.createCustomMetadataVector(builder, schemaMetadataOffsets) + : 0; + + // Preserve features vector — `features()` requires walking the + // bigint vector; for the kernel's payloads this is typically empty + // so we skip it. If a non-empty features vector appears, we drop it + // (Arrow features encode optional compression flags; the kernel + // emits uncompressed streams for the SEA path per + // `findings/rust-kernel/M0-kernel-async-readiness-2026-05-15.md`). + FbSchema.startSchema(builder); + FbSchema.addEndianness(builder, fbSchema.endianness()); + FbSchema.addFields(builder, fieldsVec); + if (metadataVec !== 0) { + FbSchema.addCustomMetadata(builder, metadataVec); + } + const schemaOffset = FbSchema.endSchema(builder); + + // Wrap in a Message. version + headerType + header + bodyLength + custom_metadata. + Message.startMessage(builder); + Message.addVersion(builder, message.version()); + Message.addHeaderType(builder, MessageHeader.Schema); + Message.addHeader(builder, schemaOffset); + Message.addBodyLength(builder, BigInt(0)); + const newMessage = Message.endMessage(builder); + builder.finish(newMessage); + + let bytes = builder.asUint8Array(); + + // The Arrow IPC spec requires each message to be 8-byte aligned so + // that subsequent record batches' body buffers stay aligned for SIMD + // reads. apache-arrow's MessageReader doesn't enforce this on read + // (it just trusts the metadata length), so any padding is fine. + // Round up the metadata bytes to a multiple of 8 by appending zero + // padding — this keeps the IPC stream spec-compliant. + const padded = padToAlignment(bytes, 8); + return Buffer.from(padded); +} + +/** + * Re-encode a single Field. For `Duration` fields, substitute `Int64` + * and add `databricks.arrow.duration_unit` metadata. 
/**
 * Re-encode one schema Field (and, recursively, its children) into
 * `builder`, returning the offset of the new Field table.
 *
 * A `Duration` field is emitted as `Int(64, signed)` and receives one
 * extra custom-metadata entry (`DURATION_UNIT_METADATA_KEY` → unit
 * name). Every other field keeps its type tag and has its type
 * sub-table re-encoded by `reEncodeTypeSubtable`. Name, nullability,
 * children and existing custom metadata are carried over. Dictionary
 * encoding is not re-emitted.
 *
 * Sub-objects (strings, vectors, nested tables) are written before
 * `startField` is called, so no table is started while another is
 * still open.
 *
 * @param builder FlatBuffer builder receiving the new field
 * @param field   original field accessor
 * @returns offset of the encoded Field table
 */
function reEncodeField(builder: flatbuffers.Builder, field: FbField): number {
  // Writes one KeyValue table (key string, value string, then the table).
  const writeKeyValue = (key: string, value: string): number => {
    const keyOffset = builder.createString(key);
    const valueOffset = builder.createString(value);
    FbKeyValue.startKeyValue(builder);
    FbKeyValue.addKey(builder, keyOffset);
    FbKeyValue.addValue(builder, valueOffset);
    return FbKeyValue.endKeyValue(builder);
  };

  const nameOffset = builder.createString(field.name() ?? '');

  // Children first (Struct/List/Map carry them), depth-first.
  const encodedChildren: number[] = [];
  const childCount = field.childrenLength();
  for (let idx = 0; idx < childCount; idx += 1) {
    const child = field.children(idx);
    if (child) {
      encodedChildren.push(reEncodeField(builder, child));
    }
  }
  const childrenVec = encodedChildren.length > 0 ? FbField.createChildrenVector(builder, encodedChildren) : 0;

  // Existing custom metadata is preserved entry by entry.
  const encodedMetadata: number[] = [];
  const metadataCount = field.customMetadataLength();
  for (let idx = 0; idx < metadataCount; idx += 1) {
    const entry = field.customMetadata(idx);
    if (entry) {
      encodedMetadata.push(writeKeyValue(entry.key() ?? '', entry.value() ?? ''));
    }
  }

  const sourceTypeType = field.typeType();
  let emittedTypeType = sourceTypeType;
  let typeOffset = 0;

  if (sourceTypeType !== FbType.Duration) {
    // Non-Duration: keep the type tag, re-encode its sub-table.
    typeOffset = reEncodeTypeSubtable(builder, field, sourceTypeType);
  } else {
    // Duration: remember the unit in metadata, then advertise Int64.
    const durationTable = field.type(new FbDuration()) as FbDuration | null;
    const unit = durationTable ? durationTable.unit() : FbTimeUnit.MICROSECOND;
    const unitName = TIME_UNIT_NAME[unit] ?? 'MICROSECOND';
    encodedMetadata.push(writeKeyValue(DURATION_UNIT_METADATA_KEY, unitName));

    emittedTypeType = FbType.Int;
    typeOffset = FbInt.createInt(builder, 64, true);
  }

  const metadataVec =
    encodedMetadata.length > 0 ? FbField.createCustomMetadataVector(builder, encodedMetadata) : 0;

  FbField.startField(builder);
  FbField.addName(builder, nameOffset);
  FbField.addNullable(builder, field.nullable());
  FbField.addTypeType(builder, emittedTypeType);
  // Offset 0 means "absent"; only attach what was actually written.
  if (typeOffset !== 0) {
    FbField.addType(builder, typeOffset);
  }
  if (childrenVec !== 0) {
    FbField.addChildren(builder, childrenVec);
  }
  if (metadataVec !== 0) {
    FbField.addCustomMetadata(builder, metadataVec);
  }
  return FbField.endField(builder);
}
+ /* eslint-disable @typescript-eslint/no-var-requires, global-require, import/no-internal-modules */ + const { Null } = require('apache-arrow/fb/null'); + const { FloatingPoint } = require('apache-arrow/fb/floating-point'); + const { Binary } = require('apache-arrow/fb/binary'); + const { Utf8 } = require('apache-arrow/fb/utf8'); + const { Bool } = require('apache-arrow/fb/bool'); + const { Decimal } = require('apache-arrow/fb/decimal'); + const { Date: DateTbl } = require('apache-arrow/fb/date'); + const { Time } = require('apache-arrow/fb/time'); + const { Timestamp } = require('apache-arrow/fb/timestamp'); + const { Interval } = require('apache-arrow/fb/interval'); + const { List } = require('apache-arrow/fb/list'); + const { Struct_ } = require('apache-arrow/fb/struct-'); + const { Union } = require('apache-arrow/fb/union'); + const { FixedSizeBinary } = require('apache-arrow/fb/fixed-size-binary'); + const { FixedSizeList } = require('apache-arrow/fb/fixed-size-list'); + const { Map: MapTbl } = require('apache-arrow/fb/map'); + /* eslint-enable @typescript-eslint/no-var-requires, global-require, import/no-internal-modules */ + + switch (typeType) { + case FbType.NONE: + case FbType.Null: { + // Null has no fields; emit an empty table. 
+ const t = new Null(); + field.type(t); + Null.startNull(builder); + return Null.endNull(builder); + } + case FbType.Int: { + const t = field.type(new FbInt()) as InstanceType | null; + if (!t) { + return FbInt.createInt(builder, 32, true); + } + return FbInt.createInt(builder, t.bitWidth(), t.isSigned()); + } + case FbType.FloatingPoint: { + const t = field.type(new FloatingPoint()); + return FloatingPoint.createFloatingPoint(builder, t.precision()); + } + case FbType.Binary: { + Binary.startBinary(builder); + return Binary.endBinary(builder); + } + case FbType.Utf8: { + Utf8.startUtf8(builder); + return Utf8.endUtf8(builder); + } + case FbType.Bool: { + Bool.startBool(builder); + return Bool.endBool(builder); + } + case FbType.Decimal: { + const t = field.type(new Decimal()); + return Decimal.createDecimal(builder, t.precision(), t.scale(), t.bitWidth()); + } + case FbType.Date: { + const t = field.type(new DateTbl()); + return DateTbl.createDate(builder, t.unit()); + } + case FbType.Time: { + const t = field.type(new Time()); + return Time.createTime(builder, t.unit(), t.bitWidth()); + } + case FbType.Timestamp: { + const t = field.type(new Timestamp()); + const tz: string | null = t.timezone(); + const tzOffset = tz ? builder.createString(tz) : 0; + Timestamp.startTimestamp(builder); + Timestamp.addUnit(builder, t.unit()); + if (tzOffset !== 0) { + Timestamp.addTimezone(builder, tzOffset); + } + return Timestamp.endTimestamp(builder); + } + case FbType.Interval: { + const t = field.type(new Interval()); + return Interval.createInterval(builder, t.unit()); + } + case FbType.List: { + List.startList(builder); + return List.endList(builder); + } + case FbType.Struct_: { + Struct_.startStruct_(builder); + return Struct_.endStruct_(builder); + } + case FbType.Union: { + const t = field.type(new Union()); + // typeIds is an int32 vector — copy it. 
+ const typeIdsArr = t.typeIdsArray(); + let typeIdsOffset = 0; + if (typeIdsArr) { + typeIdsOffset = Union.createTypeIdsVector(builder, Array.from(typeIdsArr)); + } + Union.startUnion(builder); + Union.addMode(builder, t.mode()); + if (typeIdsOffset !== 0) { + Union.addTypeIds(builder, typeIdsOffset); + } + return Union.endUnion(builder); + } + case FbType.FixedSizeBinary: { + const t = field.type(new FixedSizeBinary()); + return FixedSizeBinary.createFixedSizeBinary(builder, t.byteWidth()); + } + case FbType.FixedSizeList: { + const t = field.type(new FixedSizeList()); + return FixedSizeList.createFixedSizeList(builder, t.listSize()); + } + case FbType.Map: { + const t = field.type(new MapTbl()); + return MapTbl.createMap(builder, t.keysSorted()); + } + default: + // Unknown / newer types (LargeBinary, LargeUtf8, LargeList, + // RunEndEncoded, ...). The kernel doesn't emit these for M0; + // emit an empty sub-table and let apache-arrow's normal error + // path fire when it tries to decode an unrecognized type id. + return 0; + } +} + +/** + * Prefix the given FlatBuffer message bytes with the IPC stream + * framing: the continuation marker (0xFFFFFFFF) followed by the + * little-endian int32 metadata length. + */ +function encodeContinuationAndLength(metadataLength: number): Buffer { + const out = Buffer.alloc(8); + out.writeInt32LE(IPC_CONTINUATION_MARKER | 0, 0); // -1 + out.writeInt32LE(metadataLength, 4); + return out; +} + +/** + * Pad `bytes` with trailing zeros so its length is a multiple of + * `alignment`. Returns the original buffer when it is already + * aligned. 
+ */
+function padToAlignment(bytes: Uint8Array, alignment: number): Uint8Array {
+  // Number of bytes by which the input overshoots the last aligned boundary.
+  const overhang = bytes.byteLength % alignment;
+  if (overhang !== 0) {
+    // A fresh Uint8Array is zero-filled, so copying the input to the
+    // front leaves the required zero padding at the tail.
+    const grown = new Uint8Array(bytes.byteLength + (alignment - overhang));
+    grown.set(bytes, 0);
+    return grown;
+  }
+  return bytes;
+}
diff --git a/lib/sea/SeaAuth.ts b/lib/sea/SeaAuth.ts
new file mode 100644
index 00000000..69b10ddc
--- /dev/null
+++ b/lib/sea/SeaAuth.ts
@@ -0,0 +1,88 @@
+// Copyright (c) 2026 Databricks, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import { ConnectionOptions } from '../contracts/IDBSQLClient';
+import AuthenticationError from '../errors/AuthenticationError';
+import HiveDriverError from '../errors/HiveDriverError';
+import prependSlash from '../utils/prependSlash';
+import { SeaConnectionOptions } from './SeaNativeLoader';
+
+/**
+ * Shape consumed by the napi-binding's `openSession()`. M0 sends the PAT
+ * triple plus optional session defaults, so we `Pick` those fields off the
+ * binding's generated `ConnectionOptions` (re-exported as
+ * `SeaConnectionOptions`) rather than re-declaring them — if the kernel renames
+ * one of these fields this stops compiling instead of silently drifting.
+ */
+export type SeaNativeConnectionOptions = Pick<
+  SeaConnectionOptions,
+  'hostName' | 'httpPath' | 'token' | 'catalog' | 'schema' | 'sessionConf'
+>;
+
+/**
+ * Check that `options` describes PAT authentication and translate it into
+ * the connection-options shape the napi binding expects.
+ *
+ * M0 SCOPE: PAT only. `authType: 'access-token'` and an absent `authType`
+ * are accepted; any other `authType` is rejected with a message naming the
+ * modes deferred to M1.
+ *
+ * @param options User-supplied connection options; `host`, `path`,
+ *                `authType` and `token` are read.
+ * @returns `hostName`, `httpPath` (slash-prepended) and `token`.
+ * @throws HiveDriverError     when `authType` is set to anything other
+ *                             than `'access-token'`.
+ * @throws AuthenticationError when `token` is not a non-empty string, or
+ *                             contains whitespace / control characters.
+ */
+export function buildSeaConnectionOptions(options: ConnectionOptions): SeaNativeConnectionOptions {
+  // Read through a loose structural view: the public ConnectionOptions type
+  // permits `authType: undefined` with no token at runtime, so neither
+  // field is trusted from the type alone.
+  const { authType, token } = options as { authType?: string; token?: string };
+
+  const isPatAuth = authType === undefined || authType === 'access-token';
+  if (!isPatAuth) {
+    throw new HiveDriverError(
+      `SEA backend (M0) supports only PAT auth (authType: 'access-token'); ` +
+        `got authType: '${authType}'. Other auth modes (databricks-oauth, ` +
+        `token-provider, external-token, static-token, custom) will land in M1.`,
+    );
+  }
+
+  // PAT path: the token must be present and non-empty.
+  if (typeof token !== 'string' || token.length === 0) {
+    throw new AuthenticationError(
+      'SEA backend: a non-empty PAT must be supplied via `token` when using `authType: \'access-token\'`.',
+    );
+  }
+
+  // Whitespace / control characters are refused up front. The kernel's
+  // reqwest `HeaderValue` already hard-rejects CR/LF/NUL at build time, so
+  // this is not a header-injection fix — it is parity with the Python
+  // driver (auth_bridge.py rejects `[\x00-\x20\x7f]`) and it catches
+  // copy-paste whitespace before a confusing downstream failure.
+  // eslint-disable-next-line no-control-regex
+  const forbiddenChars = /[\x00-\x20\x7f]/;
+  if (forbiddenChars.test(token)) {
+    throw new AuthenticationError(
+      'SEA backend: the PAT supplied via `token` must not contain whitespace or control characters.',
+    );
+  }
+
+  const hostName = options.host;
+  const httpPath = prependSlash(options.path);
+  return { hostName, httpPath, token };
+}
diff --git a/lib/sea/SeaBackend.ts b/lib/sea/SeaBackend.ts
new file mode 100644
index 00000000..13063e59
--- /dev/null
+++ b/lib/sea/SeaBackend.ts
@@ -0,0 +1,109 @@
+// Copyright (c) 2026 Databricks, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import IBackend from '../contracts/IBackend';
+import ISessionBackend from '../contracts/ISessionBackend';
+import IClientContext from '../contracts/IClientContext';
+import { ConnectionOptions, OpenSessionRequest } from '../contracts/IDBSQLClient';
+import HiveDriverError from '../errors/HiveDriverError';
+import { getSeaNative, SeaNativeBinding, SeaNativeConnection } from './SeaNativeLoader';
+import { decodeNapiKernelError } from './SeaErrorMapping';
+import { buildSeaConnectionOptions, SeaNativeConnectionOptions } from './SeaAuth';
+import SeaSessionBackend from './SeaSessionBackend';
+
+export interface SeaBackendOptions {
+  context: IClientContext;
+  /**
+   * Optional injection seam for unit tests. When provided, replaces the
+   * default `getSeaNative()` call so tests can swap in a mock napi
+   * binding without loading the `.node` artifact.
+   */
+  nativeBinding?: SeaNativeBinding;
+}
+
+/**
+ * SEA-backed implementation of `IBackend`.
+ *
+ * **M0 dispatch model:** the napi binding's `openSession()` already
+ * builds a kernel `Session` from PAT + hostname + httpPath, so there is
+ * no "connect" round-trip before `openSession` — `connect()` only
+ * captures the `ConnectionOptions` and validates that PAT auth is in
+ * use. The actual session open happens inside `openSession()`.
+ *
+ * **Auth validation:** delegates to `buildSeaConnectionOptions` from
+ * `SeaAuth`, which mirrors the existing DBSQLClient PAT validation
+ * pattern (slash-prepended httpPath, AuthenticationError on missing
+ * token, HiveDriverError on non-PAT authType naming M1 modes).
+ *
+ * **Why we don't use IClientContext's connectionProvider here:** that
+ * provider is the Thrift HTTP transport. The kernel owns its own
+ * reqwest+rustls stack inside the native binding, so there is no
+ * NodeJS-level connection state to manage on the SEA path. The
+ * `IClientContext` is still useful for logger + config access.
+ */
+export default class SeaBackend implements IBackend {
+  private readonly context: IClientContext;
+
+  private readonly binding: SeaNativeBinding;
+
+  // Set by connect(), cleared by close(); undefined means "not connected".
+  private nativeOptions?: SeaNativeConnectionOptions;
+
+  /**
+   * @param options Optional context and native-binding override. When no
+   *                `nativeBinding` is supplied the process-global binding
+   *                is resolved via `getSeaNative()` during construction.
+   */
+  constructor(options?: SeaBackendOptions) {
+    this.context = options?.context as IClientContext;
+    this.binding = options?.nativeBinding ?? getSeaNative();
+  }
+
+  /**
+   * Validate PAT auth and capture the napi-binding option shape. Any
+   * non-PAT mode (or a missing/empty token) throws here, before the
+   * native binding is ever touched.
+   */
+  public async connect(options: ConnectionOptions): Promise<void> {
+    this.nativeOptions = buildSeaConnectionOptions(options);
+  }
+
+  /**
+   * Open a native session using the options captured by `connect()`,
+   * overlaid with the request's catalog / schema / configuration.
+   *
+   * @throws HiveDriverError when called before `connect()`.
+   * @throws The error produced by `decodeNapiKernelError` when the native
+   *         `openSession` call fails.
+   */
+  public async openSession(request: OpenSessionRequest): Promise<ISessionBackend> {
+    if (!this.nativeOptions) {
+      throw new HiveDriverError('SeaBackend: not connected. Call connect() first.');
+    }
+
+    // Copy so per-session overrides never leak back into the stored options.
+    const sessionOptions: SeaNativeConnectionOptions = { ...this.nativeOptions };
+    if (request.initialCatalog !== undefined) {
+      sessionOptions.catalog = request.initialCatalog;
+    }
+    if (request.initialSchema !== undefined) {
+      sessionOptions.schema = request.initialSchema;
+    }
+    if (request.configuration !== undefined) {
+      sessionOptions.sessionConf = { ...request.configuration };
+    }
+
+    // The catch branch always throws, so `nativeConnection` is definitely
+    // assigned after the try statement and needs no non-null assertion.
+    let nativeConnection: SeaNativeConnection;
+    try {
+      nativeConnection = (await this.binding.openSession(sessionOptions)) as SeaNativeConnection;
+    } catch (err) {
+      throw decodeNapiKernelError(err);
+    }
+
+    return new SeaSessionBackend({
+      connection: nativeConnection,
+      context: this.context,
+      id: nativeConnection.sessionId,
+    });
+  }
+
+  public async close(): Promise<void> {
+    // No backend-level resources to release — each `SeaSessionBackend`
+    // owns its own napi `Connection` lifecycle.
+    this.nativeOptions = undefined;
+  }
+}
diff --git a/lib/sea/SeaErrorMapping.ts b/lib/sea/SeaErrorMapping.ts
new file mode 100644
index 00000000..892efb31
--- /dev/null
+++ b/lib/sea/SeaErrorMapping.ts
@@ -0,0 +1,172 @@
+import HiveDriverError from '../errors/HiveDriverError';
+import AuthenticationError from '../errors/AuthenticationError';
+import OperationStateError, { OperationStateErrorCode } from '../errors/OperationStateError';
+import ParameterError from '../errors/ParameterError';
+
+/**
+ * Shape of the kernel error surfaced by the napi-binding's `napi_err_from_kernel`.
+ *
+ * The Rust kernel's `kernel_error::Error` is exposed as a `JsError` whose
+ * properties mirror the Rust struct: the `ErrorCode` variant name (as a string),
+ * the message, and an optional SQLSTATE (either taken from the structured
+ * server response or recovered via `extract_sqlstate_from_message`).
+ */
+export interface KernelErrorShape {
+  /** Kernel `ErrorCode` variant name, e.g. `"Unauthenticated"`, `"SqlError"`. */
+  code: string;
+  /** Human-readable error message. */
+  message: string;
+  /** Optional SQLSTATE — five-char alphanumeric, when the kernel was able to surface it. */
+  sqlstate?: string;
+}
+
+/**
+ * Kernel `ErrorCode` variants — the 13 variants of the `#[non_exhaustive]` enum
+ * defined in `src/kernel_error.rs:66-134`.
+ *
+ * Kept here as a literal type rather than an `enum` so test exhaustiveness checks
+ * and runtime `code` strings are guaranteed to stay in lockstep with the kernel.
+ */
+export type KernelErrorCode =
+  | 'InvalidArgument'
+  | 'Unauthenticated'
+  | 'PermissionDenied'
+  | 'NotFound'
+  | 'ResourceExhausted'
+  | 'Unavailable'
+  | 'Timeout'
+  | 'Cancelled'
+  | 'DataLoss'
+  | 'Internal'
+  | 'InvalidStatementHandle'
+  | 'NetworkError'
+  | 'SqlError';
+
+/**
+ * An `Error` with a preserved SQLSTATE on the `sqlState` property.
Used as the
+ * narrowed return type of {@link mapKernelErrorToJsError} so callers that need
+ * the SQLSTATE can `error.sqlState` without an `any` cast.
+ */
+export interface ErrorWithSqlState extends Error {
+  sqlState?: string;
+}
+
+const KERNEL_ERROR_SENTINEL = '__databricks_error__:';
+
+/**
+ * Record the kernel's SQLSTATE on `error.sqlState` and hand the same error
+ * object back. When `sqlstate` is `undefined` the error is returned
+ * untouched. The driver has no pre-existing `sqlState` convention (no other
+ * error class sets it today), so this helper defines it for the SEA path.
+ */
+function attachSqlState(error: ErrorWithSqlState, sqlstate?: string): ErrorWithSqlState {
+  if (sqlstate === undefined) {
+    return error;
+  }
+  // Defined as a non-enumerable property: still readable by direct access,
+  // matching the way Node attaches `.code` to system errors.
+  const descriptor: PropertyDescriptor = {
+    value: sqlstate,
+    writable: true,
+    enumerable: false,
+    configurable: true,
+  };
+  Object.defineProperty(error, 'sqlState', descriptor);
+  return error;
+}
+
+/**
+ * Map a kernel error (as surfaced by the napi-binding) to the appropriate JS
+ * driver error class.
+ *
+ * M0 mapping table:
+ *   Unauthenticated, PermissionDenied   → AuthenticationError
+ *   Cancelled                           → OperationStateError(Canceled)
+ *   Timeout                             → OperationStateError(Timeout)
+ *   InvalidArgument                     → ParameterError
+ *   NetworkError, Unavailable,
+ *   NotFound, ResourceExhausted,
+ *   DataLoss, Internal,
+ *   InvalidStatementHandle, SqlError    → HiveDriverError
+ *
+ * Unknown `code` values (e.g. if the kernel adds a new variant) fall through
+ * to HiveDriverError so the driver never silently drops an error. The kernel's
+ * `ErrorCode` is `#[non_exhaustive]` so this can legitimately happen.
+ *
+ * SQLSTATE, when present, is attached on `error.sqlState` regardless of which
+ * class is returned.
+ */
+export function mapKernelErrorToJsError(kErr: KernelErrorShape): ErrorWithSqlState {
+  const { code, message, sqlstate } = kErr;
+
+  // OperationStateError is constructed from a state code here, so the
+  // kernel message is assigned onto the instance afterwards.
+  const operationStateError = (stateCode: OperationStateErrorCode): ErrorWithSqlState => {
+    const stateError: ErrorWithSqlState = new OperationStateError(stateCode);
+    stateError.message = message;
+    return stateError;
+  };
+
+  let mapped: ErrorWithSqlState;
+  switch (code as KernelErrorCode) {
+    case 'Unauthenticated':
+    case 'PermissionDenied':
+      mapped = new AuthenticationError(message);
+      break;
+
+    case 'Cancelled':
+      mapped = operationStateError(OperationStateErrorCode.Canceled);
+      break;
+
+    case 'Timeout':
+      mapped = operationStateError(OperationStateErrorCode.Timeout);
+      break;
+
+    case 'InvalidArgument':
+      mapped = new ParameterError(message);
+      break;
+
+    default:
+      // Covers NotFound, ResourceExhausted, Unavailable, DataLoss, Internal,
+      // InvalidStatementHandle, NetworkError and SqlError, plus any
+      // unknown/future kernel variant: all surface as the base driver error
+      // so nothing is ever dropped. M0 intentionally does not introduce new
+      // error classes; M1 may add nuance.
+      mapped = new HiveDriverError(message);
+      break;
+  }
+
+  return attachSqlState(mapped, sqlstate);
+}
+
+/**
+ * Decode a napi-rs error that may contain the kernel's sentinel-prefixed JSON
+ * envelope. Older generated bindings used `reason`; newer napi-rs errors put
+ * the same envelope in `message`, so check both. Malformed envelopes preserve
+ * the original error rather than replacing it with a JSON parse failure.
+ */
+export function decodeNapiKernelError(err: unknown): Error {
+  // Non-Error throwables are wrapped so callers always receive an Error.
+  if (!(err instanceof Error)) {
+    return new HiveDriverError(String(err));
+  }
+
+  const candidates = [(err as { reason?: unknown }).reason, err.message];
+  for (const candidate of candidates) {
+    if (typeof candidate !== 'string') {
+      continue;
+    }
+    const idx = candidate.indexOf(KERNEL_ERROR_SENTINEL);
+    if (idx < 0) {
+      continue;
+    }
+
+    let payload: unknown;
+    try {
+      payload = JSON.parse(candidate.slice(idx + KERNEL_ERROR_SENTINEL.length));
+    } catch {
+      // This candidate's envelope is malformed; the other candidate may
+      // still carry a well-formed one, so keep looking.
+      continue;
+    }
+
+    // JSON.parse can yield any JSON value. Only an object with a string
+    // `message` is treated as an envelope; anything else counts as
+    // malformed, so the next candidate is tried.
+    if (typeof payload !== 'object' || payload === null) {
+      continue;
+    }
+    const envelope = payload as Partial<Record<keyof KernelErrorShape, unknown>>;
+    if (typeof envelope.message !== 'string') {
+      continue;
+    }
+
+    return mapKernelErrorToJsError({
+      // A missing / non-string code maps to the base HiveDriverError.
+      code: typeof envelope.code === 'string' ? envelope.code : '',
+      message: envelope.message,
+      sqlstate: typeof envelope.sqlstate === 'string' ? envelope.sqlstate : undefined,
+    });
+  }
+
+  // No usable envelope in either candidate: preserve the original error.
+  return err;
+}
diff --git a/lib/sea/SeaNativeLoader.ts b/lib/sea/SeaNativeLoader.ts
new file mode 100644
index 00000000..45409881
--- /dev/null
+++ b/lib/sea/SeaNativeLoader.ts
@@ -0,0 +1,226 @@
+// Copyright (c) 2026 Databricks, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/**
+ * Lazy loader for the SEA (Statement Execution API) native binding.
+ *
+ * Mirrors the load-failure-tolerant pattern of `lib/utils/lz4.ts`: the
+ * `.node` artifact ships via per-platform optional dependencies
+ * (`@databricks/sql-kernel-*`), so its absence must not crash
+ * a Thrift-only consumer of the driver. Callers that actually need
+ * SEA construct a {@link SeaNativeLoader} (or use the process-global
+ * {@link getSeaNative}) which throws a structured error if the binding
+ * could not be loaded.
+ * + * M0 publishes a single triple (`linux-x64-gnu`); see + * `native/sea/README.md` for the supported-platform policy. + */ + +import type { + Connection as NativeConnection, + Statement as NativeStatement, + ConnectionOptions as NativeConnectionOptions, + ArrowBatch as NativeArrowBatch, + ArrowSchema as NativeArrowSchema, +} from '../../native/sea'; + +// SEA-prefixed re-exports. The kernel-generated `.d.ts` keeps the +// napi-rs default names (`ConnectionOptions`, `ArrowBatch`, …); we +// disambiguate on the TS-wrapper side so these never collide with the +// Thrift-side `ConnectionOptions` (lib/contracts/IDBSQLClient.ts) or +// `ArrowBatch` (lib/result/utils.ts) when imported elsewhere. +export type SeaConnectionOptions = NativeConnectionOptions; +export type SeaArrowBatch = NativeArrowBatch; +export type SeaArrowSchema = NativeArrowSchema; +export type SeaConnection = NativeConnection; +export type SeaStatement = NativeStatement; + +// Back-compat aliases for the downstream SEA stack branches that landed +// against the pre-rename loader. The merged kernel (@databricks/sql-kernel) +// moved per-statement catalog/schema/sessionConfig to session-level +// `openSession`, so `ExecuteOptions` no longer exists on the binding; +// `SeaExecuteOptions` is kept as a deprecated shim describing the old +// per-statement shape so the stack keeps compiling. Per-statement options +// are now applied at session creation — see native/sea/README.md. +export type SeaNativeConnection = NativeConnection; +export type SeaNativeStatement = NativeStatement; +export type SeaNativeConnectionOptions = NativeConnectionOptions; +/** @deprecated per-statement options moved to session-level `openSession`. 
*/
+export interface SeaExecuteOptions {
+  initialCatalog?: string;
+  initialSchema?: string;
+  sessionConfig?: Record<string, string>;
+}
+
+/**
+ * The full native binding surface, derived from the generated module
+ * so it can never drift from the `.d.ts` contract: when the kernel
+ * adds or renames a free function / class, this type follows
+ * automatically and `defaultRequire`'s cast stays correct.
+ */
+export type SeaNativeBinding = typeof import('../../native/sea');
+
+const MIN_NODE_MAJOR = 18;
+
+/** Major version of the running Node, parsed from `process.version`. */
+function detectNodeMajor(): number {
+  // `process.version` is `vX.Y.Z`; parseInt stops at the first non-digit.
+  return parseInt(process.version.slice(1), 10);
+}
+
+/** `<platform>-<arch>` label used in load-failure messages. */
+function platformLabel(): string {
+  return `${process.platform}-${process.arch}`;
+}
+
+/**
+ * Turn a load failure into an actionable message, keyed on `err.code`:
+ * `MODULE_NOT_FOUND` (binding not installed), `ERR_DLOPEN_FAILED`
+ * (binding present but not loadable), or a generic fallback.
+ */
+function loadFailureHint(err: NodeJS.ErrnoException): string {
+  const platform = platformLabel();
+  // Do not name a concrete package: the published name uses the napi-rs
+  // triple (e.g. `-linux-x64-gnu` / `-linux-x64-musl` / `-win32-x64-msvc`),
+  // not the bare `${platform}` shown here, so a literal example would
+  // 404. Point at the README's supported-triple list instead.
+  const installHint =
+    'Install the matching @databricks/sql-kernel-* optional dependency for your platform ' +
+    '(see native/sea/README.md for the supported triples; M0 ships linux-x64-gnu only).';
+  if (err.code === 'MODULE_NOT_FOUND') {
+    return `SEA native binding not installed for platform ${platform} on Node ${process.version}. ${installHint}`;
+  }
+  if (err.code === 'ERR_DLOPEN_FAILED') {
+    // Surface the underlying dlerror string (e.g. `GLIBC_2.32 not found`)
+    // plus concrete remediation — without it the cause is invisible.
+    return (
+      `SEA native binding present but failed to dlopen on platform ${platform} / Node ${process.version}: ` +
+      `${err.message}. Common causes: glibc/musl mismatch (e.g. Alpine Linux — install the -musl variant), ` +
+      `Node ABI mismatch (try \`rm -rf node_modules && npm install\`), or CPU-architecture mismatch. ` +
+      `The binding requires Node >=${MIN_NODE_MAJOR}.`
+    );
+  }
+  return `SEA native binding failed to load on platform ${platform} / Node ${process.version}: ${err.message}`;
+}
+
+/**
+ * Default loader: resolves `native/sea/index.js` (the napi-rs router),
+ * which selects the per-platform `.node`. `.js` is omitted so eslint's
+ * `import/extensions` rule accepts the call.
+ */
+function defaultRequire(): SeaNativeBinding {
+  // eslint-disable-next-line @typescript-eslint/no-var-requires, global-require
+  return require('../../native/sea') as SeaNativeBinding;
+}
+
+/**
+ * Verify the loaded module exposes the surface the driver depends on.
+ * Catches kernel-side renames at load time rather than letting them
+ * surface as `undefined is not a function` deep in a call path.
+ *
+ * @throws Error listing every required export that is not a function.
+ */
+function assertBindingShape(binding: SeaNativeBinding): void {
+  // Classes are functions at runtime, so one `typeof` check covers the
+  // free functions and the constructors alike.
+  const requiredExports = ['version', 'openSession', 'Connection', 'Statement'] as const;
+  const missing = requiredExports.filter((name) => typeof binding[name] !== 'function');
+  if (missing.length > 0) {
+    throw new Error(
+      `SEA native binding loaded but is missing expected export(s): ${missing.join(', ')}. ` +
+        `The kernel-generated binding and the JS loader are out of sync.`,
+    );
+  }
+}
+
+/**
+ * Loads and caches the SEA native binding. Exposed as a class with an
+ * injectable `load` seam so consumers (e.g. `SeaBackend`) can be unit
+ * tested with a stub binding instead of requiring a real `.node` on the
+ * test machine. Most production code uses the process-global default
+ * via {@link getSeaNative} / {@link tryGetSeaNative}.
+ */
+export class SeaNativeLoader {
+  // `undefined` = load not attempted yet; `null` = attempted and failed.
+  private cached: SeaNativeBinding | null | undefined;
+
+  // Reason for the failed load; thrown by get().
+  private cachedError: Error | undefined;
+
+  constructor(private readonly load: () => SeaNativeBinding = defaultRequire) {}
+
+  /**
+   * Attempt one load. On failure the reason is stored in `cachedError`
+   * and `undefined` is returned; nothing is thrown.
+   */
+  private tryLoad(): SeaNativeBinding | undefined {
+    const major = detectNodeMajor();
+    // Fail closed: an undeterminable (NaN) or too-old Node major refuses
+    // the load, leaving the caller to fall back to Thrift.
+    const nodeSupported = Number.isFinite(major) && major >= MIN_NODE_MAJOR;
+    if (!nodeSupported) {
+      this.cachedError = new Error(
+        `SEA native binding requires Node >=${MIN_NODE_MAJOR}; running Node ${process.version}. ` +
+          `Continue using the Thrift backend on this runtime.`,
+      );
+      return undefined;
+    }
+
+    try {
+      const loaded = this.load();
+      assertBindingShape(loaded);
+      return loaded;
+    } catch (failure) {
+      if (!(failure instanceof Error)) {
+        this.cachedError = new Error(`SEA native binding failed to load with non-standard error: ${String(failure)}`);
+      } else if ('code' in failure) {
+        // Errors carrying a `code` are rewritten into an actionable hint.
+        this.cachedError = new Error(loadFailureHint(failure as NodeJS.ErrnoException));
+      } else {
+        // Shape-check failure or any other Error — keep it as-is.
+        this.cachedError = failure;
+      }
+      return undefined;
+    }
+  }
+
+  /**
+   * Returns the loaded native binding, loading it on first use. Throws
+   * the stored load error when the binding is unavailable on this
+   * platform / Node version.
+   */
+  get(): SeaNativeBinding {
+    const binding = this.tryGet();
+    if (binding === undefined) {
+      throw this.cachedError ?? new Error('SEA native binding unavailable');
+    }
+    return binding;
+  }
+
+  /**
+   * Returns the loaded binding, or `undefined` when it could not be
+   * loaded. The load is attempted at most once; later calls reuse the
+   * cached outcome. Use this for capability detection at startup and
+   * {@link get} at the point where SEA is actually required.
+   */
+  tryGet(): SeaNativeBinding | undefined {
+    if (this.cached === undefined) {
+      this.cached = this.tryLoad() ?? null;
+    }
+    return this.cached ?? undefined;
+  }
+}
+
+// Process-global default instance + thin convenience wrappers.
+const defaultLoader = new SeaNativeLoader();
+
+/**
+ * Returns the loaded native binding from the process-global loader.
+ * Throws a structured error if the binding is unavailable.
+ */
+export function getSeaNative(): SeaNativeBinding {
+  return defaultLoader.get();
+}
+
+/**
+ * Returns the loaded binding from the process-global loader, or
+ * `undefined` if it could not be loaded.
+ */
+export function tryGetSeaNative(): SeaNativeBinding | undefined {
+  return defaultLoader.tryGet();
+}
diff --git a/lib/sea/SeaOperationBackend.ts b/lib/sea/SeaOperationBackend.ts
new file mode 100644
index 00000000..ed177989
--- /dev/null
+++ b/lib/sea/SeaOperationBackend.ts
@@ -0,0 +1,265 @@
+// Copyright (c) 2026 Databricks, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/**
+ * `IOperationBackend` implementation for the SEA path.
+ *
+ * Combines:
+ *   - **Fetch pipeline (from sea-results):**
+ *     `napi.Statement.fetchNextBatch()` → `SeaResultsProvider` →
+ *     `ArrowResultConverter` (Phase 1 + Phase 2; reused unchanged) →
+ *     `ResultSlicer` (chunk-size normalisation; reused unchanged). The M0
+ *     row shape is byte-identical to the thrift path for every M0
+ *     datatype (parity gate exercised by `tests/integration/sea/results-e2e.test.ts`).
+ * + * - **Lifecycle (from sea-operation):** `cancel()` / `close()` / + * `finished()` (alias of `waitUntilReady`) delegate to the helpers + * in `SeaOperationLifecycle.ts`. The helpers handle idempotency, + * flag-set-before-await ordering (so cancel-mid-fetch propagates), + * logging via `IClientContext`, and kernel-error mapping. + * + * The lifecycle helpers route fetch-after-cancel / fetch-after-close + * through `failIfNotActive`, which throws an `OperationStateError` + * matching the Thrift `failIfClosed` semantics. We call it from + * `fetchChunk`/`hasMore`/`getResultMetadata` so the cancel-mid-fetch + * e2e (cancel < 200ms) drives against this backend cleanly. + */ + +import { v4 as uuidv4 } from 'uuid'; +import { + TGetOperationStatusResp, + TGetResultSetMetadataResp, + TSparkRowSetType, + TStatusCode, + TTableSchema, +} from '../../thrift/TCLIService_types'; +import IOperationBackend from '../contracts/IOperationBackend'; +import IClientContext from '../contracts/IClientContext'; +import { OperationState, OperationStatus } from '../contracts/OperationStatus'; +import { ResultFormat, ResultMetadata } from '../contracts/ResultMetadata'; +import Status from '../dto/Status'; +import ArrowResultConverter from '../result/ArrowResultConverter'; +import ResultSlicer from '../result/ResultSlicer'; +import SeaResultsProvider from './SeaResultsProvider'; +import { arrowSchemaToThriftSchema, decodeIpcSchema } from './SeaArrowIpc'; +import { SeaNativeStatement } from './SeaNativeLoader'; +import { + SeaStatementHandle, + SeaOperationLifecycleState, + createLifecycleState, + seaCancel, + seaClose, + seaFinished, + failIfNotActive, +} from './SeaOperationLifecycle'; + +/** + * Structural union of the lifecycle surface (cancel/close) and the + * fetch surface (fetchNextBatch/schema). 
/**
 * Structural union of the lifecycle surface (cancel/close) and the
 * fetch surface (fetchNextBatch/schema). The real napi `Statement`
 * implements both; lifecycle-only test stubs implement only the
 * cancel/close half — fetch methods are accessed lazily and the
 * lifecycle tests never reach that path.
 *
 * NOTE(review): the type argument of `Partial` was missing from the
 * text under review; `SeaNativeStatement` is inferred from the cast in
 * `getResultSlicer` — confirm against the intended declaration.
 */
export type SeaOperationStatement = SeaStatementHandle & Partial<SeaNativeStatement>;

/**
 * Constructor options for `SeaOperationBackend`.
 */
export interface SeaOperationBackendOptions {
  /** The opaque napi `Statement` handle returned by `Connection.executeStatement(...)`. */
  statement: SeaOperationStatement;
  context: IClientContext;
  /**
   * Optional override for `id`. When omitted, the constructor falls
   * back to a freshly generated UUIDv4.
   */
  id?: string;
}

/**
 * `IOperationBackend` implementation on top of a napi `Statement` handle.
 *
 * - Fetching lazily builds a `SeaResultsProvider` → `ArrowResultConverter`
 *   → `ResultSlicer` pipeline the first time rows or `hasMore` are requested.
 * - `cancel` / `close` / `waitUntilReady` delegate to the helpers in
 *   `SeaOperationLifecycle.ts`; every fetch/metadata entry point calls
 *   `failIfNotActive` first so a cancelled or closed operation fails
 *   locally.
 */
export default class SeaOperationBackend implements IOperationBackend {
  private readonly statement: SeaOperationStatement;

  private readonly context: IClientContext;

  private readonly _id: string;

  private readonly lifecycle: SeaOperationLifecycleState = createLifecycleState();

  // Element type is whatever ArrowResultConverter yields; kept loose on purpose.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private resultSlicer?: ResultSlicer<any>;

  private resultsProvider?: SeaResultsProvider;

  // Cached result-set metadata (set after the first successful schema fetch).
  private metadata?: TGetResultSetMetadataResp;

  // In-flight metadata request, shared by concurrent callers.
  private metadataPromise?: Promise<TGetResultSetMetadataResp>;

  constructor({ statement, context, id }: SeaOperationBackendOptions) {
    this.statement = statement;
    this.context = context;
    this._id = id ?? uuidv4();
  }

  public get id(): string {
    return this._id;
  }

  public get hasResultSet(): boolean {
    // M0 only routes through SeaOperationBackend for executeStatement
    // calls. DDL/DML without a result set is not exercised through SEA
    // for M0; the napi Statement still produces a schema (empty) in
    // that case, which the converter renders as zero rows. Reporting
    // `true` keeps the facade's fetch path enabled for M0 parity.
    return true;
  }

  // ---------------------------------------------------------------------------
  // Fetch / metadata (owned by the sea-results pipeline).
  // ---------------------------------------------------------------------------

  /**
   * Fetch up to `limit` rows through the result slicer.
   *
   * @throws whatever `failIfNotActive` throws when the operation has been
   *   cancelled or closed (checked before any native call is made).
   */
  public async fetchChunk({
    limit,
    disableBuffering,
  }: {
    limit: number;
    disableBuffering?: boolean;
  }): Promise<Array<object>> {
    // Cancel-mid-fetch propagation: if cancel() has flipped the
    // lifecycle flag, fail locally without a wire round-trip.
    failIfNotActive(this.lifecycle);
    const slicer = await this.getResultSlicer();
    return slicer.fetchNext({ limit, disableBuffering });
  }

  /** Report whether more rows are available; fails fast on a cancelled/closed operation. */
  public async hasMore(): Promise<boolean> {
    failIfNotActive(this.lifecycle);
    const slicer = await this.getResultSlicer();
    return slicer.hasMore();
  }

  /** Backend-neutral view of the result-set metadata; the format is always Arrow-based here. */
  public async getResultMetadata(): Promise<ResultMetadata> {
    const metadata = await this.thriftResultMetadataResponse();
    return {
      schema: metadata.schema,
      resultFormat: ResultFormat.ArrowBased,
      lz4Compressed: metadata.lz4Compressed,
      isStagingOperation: Boolean(metadata.isStagingOperation),
    };
  }

  /**
   * Return the Thrift-shaped metadata response, fetching the statement
   * schema at most once. Concurrent callers share one in-flight request;
   * a failed request is not cached, so a later call retries.
   */
  private async thriftResultMetadataResponse(): Promise<TGetResultSetMetadataResp> {
    failIfNotActive(this.lifecycle);
    if (this.metadata) {
      return this.metadata;
    }
    if (this.metadataPromise) {
      return this.metadataPromise;
    }
    this.metadataPromise = this.loadResultMetadata();
    try {
      return await this.metadataPromise;
    } finally {
      this.metadataPromise = undefined;
    }
  }

  /** Fetch the Arrow schema from the statement and wrap it in a Thrift-shaped response. */
  private async loadResultMetadata(): Promise<TGetResultSetMetadataResp> {
    if (!this.statement.schema) {
      throw new Error('SeaOperationBackend: statement.schema() is not available on this handle');
    }
    const arrowSchemaIpc = await this.statement.schema();
    const arrowSchema = decodeIpcSchema(arrowSchemaIpc.ipcBytes);
    const thriftSchema: TTableSchema = arrowSchemaToThriftSchema(arrowSchema);
    const meta: TGetResultSetMetadataResp = {
      status: { statusCode: TStatusCode.SUCCESS_STATUS },
      schema: thriftSchema,
      // SEA inline + CloudFetch both surface to JS as Arrow batches;
      // both flow through the same converter that handles the
      // ARROW_BASED_SET path on the thrift side.
      resultFormat: TSparkRowSetType.ARROW_BASED_SET,
      lz4Compressed: false,
      isStagingOperation: false,
    };
    this.metadata = meta;
    return meta;
  }

  // ---------------------------------------------------------------------------
  // Status / lifecycle (owned by the sea-operation lifecycle helpers).
  // ---------------------------------------------------------------------------

  /**
   * Synthesised status: Cancelled if the cancel flag is set, otherwise
   * Closed if the close flag is set, otherwise Succeeded. No native call
   * is made; `_progress` is ignored.
   */
  public async status(_progress: boolean): Promise<OperationStatus> {
    let state = OperationState.Succeeded;
    if (this.lifecycle.isCancelled) {
      state = OperationState.Cancelled;
    } else if (this.lifecycle.isClosed) {
      state = OperationState.Closed;
    }
    return { state, hasResultSet: true };
  }

  public async waitUntilReady(options?: {
    progress?: boolean;
    callback?: (progress: TGetOperationStatusResp) => unknown;
  }): Promise<void> {
    // Kernel's `Statement::execute().await` has already resolved by the
    // time we hold a Statement handle — there is no pending/running
    // state to poll for M0. seaFinished fires the progress callback
    // once with a synthesised FINISHED response so progress-UI callers
    // see the same one-shot completion tick the Thrift path emits at
    // the end of its polling loop.
    return seaFinished(this.lifecycle, options);
  }

  public async cancel(): Promise<Status> {
    return seaCancel(this.lifecycle, this.statement, this.context, this._id);
  }

  public async close(): Promise<Status> {
    return seaClose(this.lifecycle, this.statement, this.context, this._id);
  }

  // ---------------------------------------------------------------------------
  // Internals.
  // ---------------------------------------------------------------------------

  /**
   * Lazily build the provider → converter → slicer pipeline.
   *
   * The slicer is re-checked after the metadata await: without that, two
   * concurrent fetches would each construct their own `SeaResultsProvider`
   * over the same statement and pull batches independently.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private async getResultSlicer(): Promise<ResultSlicer<any>> {
    if (this.resultSlicer) {
      return this.resultSlicer;
    }
    if (!this.statement.fetchNextBatch) {
      throw new Error('SeaOperationBackend: statement.fetchNextBatch() is not available on this handle');
    }
    const metadata = await this.thriftResultMetadataResponse();
    if (!this.resultSlicer) {
      // The lifecycle subset has cancel/close only; fetch methods exist on
      // the full napi Statement. Cast is safe here because we've just
      // verified `fetchNextBatch` is callable.
      const provider = new SeaResultsProvider(this.statement as SeaNativeStatement);
      const converter = new ArrowResultConverter(this.context, provider, metadata);
      this.resultsProvider = provider;
      this.resultSlicer = new ResultSlicer(this.context, converter);
    }
    return this.resultSlicer;
  }
}

// diff --git a/lib/sea/SeaOperationLifecycle.ts b/lib/sea/SeaOperationLifecycle.ts — new file mode 100644, index 00000000..a3294ba2, @@ -0,0 +1,283 @@
// Copyright (c) 2026 Databricks, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/**
 * SEA operation lifecycle helpers (M0).
 *
 * The three methods exposed here (`cancel`, `close`, `finished`) are
 * standalone functions that the `SeaOperationBackend` implementation
 * delegates to. Keeping them in this dedicated file lets the parallel
 * impl-results work (which owns the fetch-* methods on
 * `SeaOperationBackend`) land independently — at merge time it can
 * either import these helpers from here or inline them, with no
 * conflicts on the call sites.
 *
 * Mapping to the existing `DBSQLOperation` semantics:
 *   - `cancel()` → `driver.cancelOperation(...)` on Thrift today
 *     (`lib/DBSQLOperation.ts:241-259`). For SEA this is a one-shot
 *     forward to the napi `Statement.cancel()` which in turn calls
 *     `ExecutedStatementHandle::cancel(&self).await` in the kernel.
 *   - `close()` → `driver.closeOperation(...)` on Thrift today
 *     (`lib/DBSQLOperation.ts:265-284`). For SEA this is the napi
 *     `Statement.close()` which awaits the server-side delete.
 *   - `finished({progress, callback})` → the 100ms polling loop in
 *     `DBSQLOperation.waitUntilReady` today (`lib/DBSQLOperation.ts:337-391`).
 *     For M0 the kernel's `Statement::execute().await` already blocks
 *     until the statement is in a terminal state, so by the time the JS
 *     side has an `ExecutedStatement` (and therefore a binding-level
 *     `Statement`) the underlying operation is already finished. The
 *     M0 implementation here therefore resolves immediately, optionally
 *     firing the progress callback once with a synthesized "finished"
 *     response so callers that wire a progress UI still see a single
 *     completion tick.
 */

import {
  TGetOperationStatusResp,
  TOperationState,
  TStatusCode,
} from '../../thrift/TCLIService_types';
import Status from '../dto/Status';
import { LogLevel } from '../contracts/IDBSQLLogger';
import IClientContext from '../contracts/IClientContext';
import { mapKernelErrorToJsError, KernelErrorShape } from './SeaErrorMapping';
import OperationStateError, { OperationStateErrorCode } from '../errors/OperationStateError';

/**
 * Minimal shape of the napi `Statement` that the lifecycle helpers
 * depend on. Declared structurally so unit tests can hand in a mock
 * without pulling the real native binding into the test process.
 *
 * The real binding's `Statement` (see `native/sea/index.d.ts`) has
 * additional methods (`fetchNextBatch`, `schema`) which the lifecycle
 * helpers deliberately don't touch — those belong to the results
 * feature's surface.
 */
export interface SeaStatementHandle {
  cancel(): Promise<void>;
  close(): Promise<void>;
}

/**
 * Internal lifecycle state shared between the operation backend and
 * these helpers. `SeaOperationBackend` keeps an instance of this and
 * passes it to each helper call. Centralising the flags here means
 * the helpers need no `this` and the backend stays straightforward.
 */
export interface SeaOperationLifecycleState {
  /**
   * True from the moment `cancel()` is requested (the flag is set before
   * the native call is awaited and cleared again if that call fails) —
   * subsequent fetch* calls must throw.
   */
  isCancelled: boolean;
  /**
   * True from the moment `close()` is requested (same set-before-await,
   * clear-on-failure handling); makes `close()` idempotent.
   */
  isClosed: boolean;
}

/**
 * Factory for a fresh lifecycle-state record with both flags cleared.
 * Each call returns a new, independent object.
 */
export function createLifecycleState(): SeaOperationLifecycleState {
  return { isCancelled: false, isClosed: false };
}
/**
 * Normalise an error thrown by the napi `Statement` into one of the
 * driver's typed error classes. The binding surfaces kernel errors as
 * a JSON envelope on `napi::Error.reason` with the sentinel prefix
 * `__databricks_error__:` (see the napi-binding round 2 findings,
 * section "JSON-envelope error reason"). If a kernel payload object can
 * be parsed out of the message, it is routed through
 * `mapKernelErrorToJsError`; otherwise the original error is rethrown
 * unchanged.
 *
 * @param err whatever the native call threw
 * @throws always — either the mapped kernel error or `err` itself
 */
function rethrowKernelError(err: unknown): never {
  if (err instanceof Error && typeof err.message === 'string') {
    const sentinel = '__databricks_error__:';
    const at = err.message.indexOf(sentinel);
    if (at >= 0) {
      let payload: unknown;
      try {
        payload = JSON.parse(err.message.slice(at + sentinel.length));
      } catch {
        // Malformed envelope — fall through and rethrow the original
        // below; we never silently drop a kernel error.
        payload = undefined;
      }
      // Only a JSON object can be a kernel envelope; a bare number,
      // string or array after the sentinel is treated as malformed.
      if (typeof payload === 'object' && payload !== null && !Array.isArray(payload)) {
        throw mapKernelErrorToJsError(payload as KernelErrorShape);
      }
    }
  }
  throw err;
}

/**
 * Cancel an in-flight SEA operation.
 *
 * Mirrors `DBSQLOperation.cancel` semantics
 * (`lib/DBSQLOperation.ts:241-259`):
 *   - idempotent: returns success without calling the binding if the
 *     operation is already cancelled or closed,
 *   - sets the cancelled flag _before_ awaiting the napi call so that a
 *     concurrent `fetchChunk()` observing the flag short-circuits as
 *     soon as the await yields,
 *   - clears the flag again if the napi call fails, then rethrows via
 *     `rethrowKernelError`,
 *   - returns `Status.success()` on success (no rich Thrift status
 *     payload is available from the kernel side).
 *
 * @param state shared lifecycle flags; mutated in place
 * @param statement handle whose `cancel()` is awaited
 * @param context used only for debug logging
 * @param operationId included in the log line
 */
export async function seaCancel(
  state: SeaOperationLifecycleState,
  statement: SeaStatementHandle,
  context: IClientContext,
  operationId: string,
): Promise<Status> {
  if (state.isCancelled || state.isClosed) {
    return Status.success();
  }

  context.getLogger().log(LogLevel.debug, `Cancelling SEA operation with id: ${operationId}`);

  state.isCancelled = true;
  try {
    await statement.cancel();
  } catch (err) {
    state.isCancelled = false;
    rethrowKernelError(err);
  }

  return Status.success();
}

/**
 * Close a SEA operation.
 *
 * Mirrors `DBSQLOperation.close` semantics
 * (`lib/DBSQLOperation.ts:265-284`) without the Thrift-only
 * direct-results-prefetch optimisation:
 *   - idempotent: a second call is a no-op,
 *   - sets the closed flag _before_ awaiting so a concurrent fetch
 *     sees the closed state as soon as the await yields,
 *   - awaits the binding's `Statement::close`; on failure the flag is
 *     cleared and the error is rethrown via `rethrowKernelError`.
 *
 * @param state shared lifecycle flags; mutated in place
 * @param statement handle whose `close()` is awaited
 * @param context used only for debug logging
 * @param operationId included in the log line
 */
export async function seaClose(
  state: SeaOperationLifecycleState,
  statement: SeaStatementHandle,
  context: IClientContext,
  operationId: string,
): Promise<Status> {
  if (state.isClosed) {
    return Status.success();
  }

  context.getLogger().log(LogLevel.debug, `Closing SEA operation with id: ${operationId}`);

  state.isClosed = true;
  try {
    await statement.close();
  } catch (err) {
    state.isClosed = false;
    rethrowKernelError(err);
  }

  return Status.success();
}
/**
 * Synthesize a `TGetOperationStatusResp` shaped object reporting the
 * "finished" state. The kernel doesn't surface a Thrift-shaped status
 * struct, but `IOperation.finished({progress, callback})` is public
 * surface and the callback signature expects this exact shape (see
 * `lib/contracts/IOperation.ts:5` `OperationStatusCallback`). For M0
 * we report `FINISHED_STATE` with a success status. Richer fields
 * (`numModifiedRows`, `progressUpdateResponse`, `displayMessage`)
 * defer to M1 per the operation feature plan.
 */
function synthesizeFinishedStatus(): TGetOperationStatusResp {
  return {
    status: {
      statusCode: TStatusCode.SUCCESS_STATUS,
    },
    operationState: TOperationState.FINISHED_STATE,
  } as TGetOperationStatusResp;
}

/**
 * `IOperation.finished({progress, callback})` M0 implementation.
 *
 * The Thrift implementation is a 100ms polling loop over
 * `getOperationStatus` (`lib/DBSQLOperation.ts:337-391`). For SEA M0,
 * the kernel's `Statement::execute().await` already blocks until the
 * statement reaches a terminal state — by the time the JS layer has
 * a `Statement` handle, the operation has already finished.
 *
 * Therefore the M0 implementation resolves immediately. If the
 * caller supplied a progress callback it is invoked exactly once with a
 * synthesized "finished" response and its result is awaited, so
 * progress-UI consumers still see a completion tick — just without the
 * intermediate `RUNNING_STATE` notifications.
 *
 * If the operation is already cancelled or closed, this is a no-op:
 * the callback is not invoked and nothing is thrown (throwing is the
 * responsibility of subsequent fetch calls).
 *
 * @param state shared lifecycle flags (read only)
 * @param options `callback` is the only field consulted; `progress` is ignored
 */
export async function seaFinished(
  state: SeaOperationLifecycleState,
  options?: {
    progress?: boolean;
    callback?: (progress: TGetOperationStatusResp) => unknown;
  },
): Promise<void> {
  if (state.isCancelled || state.isClosed) {
    return;
  }

  const callback = options?.callback;
  if (callback) {
    // Await the callback in case it returns a promise — matches the
    // Thrift code path at `lib/DBSQLOperation.ts:348-351`.
    await callback(synthesizeFinishedStatus());
  }
}

/**
 * Pre-flight check used by fetch* methods on `SeaOperationBackend`.
 *
 * - Cancelled (checked first): throws the error that
 *   `mapKernelErrorToJsError` produces for a synthetic `Cancelled`
 *   kernel error.
 * - Closed: throws `OperationStateError` with
 *   `OperationStateErrorCode.Closed`.
 * - Otherwise returns normally.
 *
 * Exported so impl-results can call it at the top of every fetch
 * call without duplicating the if/throw logic.
 */
export function failIfNotActive(state: SeaOperationLifecycleState): void {
  if (state.isCancelled) {
    throw mapKernelErrorToJsError({
      code: 'Cancelled',
      message: 'The operation was cancelled.',
    });
  }
  if (state.isClosed) {
    throw new OperationStateError(OperationStateErrorCode.Closed);
  }
}

// diff --git a/lib/sea/SeaResultsProvider.ts b/lib/sea/SeaResultsProvider.ts — new file mode 100644, index 00000000..6adf2cba, @@ -0,0 +1,117 @@
// Copyright (c) 2026 Databricks, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import IResultsProvider, { ResultsProviderFetchNextOptions } from '../result/IResultsProvider';
import { ArrowBatch } from '../result/utils';
import { decodeIpcBatch, patchIpcBytes } from './SeaArrowIpc';
/**
 * The minimal slice of the napi-binding `Statement` class that we
 * consume from JS. Defined locally (not imported from the binding's
 * d.ts) so the loader layer's loose `unknown` typing doesn't force
 * unsafe casts at every call site, and so unit tests can pass a stub.
 */
export interface SeaStatementHandle {
  fetchNextBatch(): Promise<{ ipcBytes: Buffer } | null>;
}

/**
 * `IResultsProvider<ArrowBatch>` that pulls Arrow IPC batches from the
 * kernel via the napi `Statement` handle and adapts them onto the
 * shape `ArrowResultConverter` already speaks
 * (`lib/result/utils.ts:22-25`).
 *
 * Each kernel `fetchNextBatch()` call returns a complete Arrow IPC
 * stream (schema header + 1 record-batch message) per the design
 * documented at `sea-workflow/findings/arch/napi-binding/round2-methods-2026-05-15.md:46-60`.
 * We pass that buffer through as a single-element `batches: [ipcBytes]`
 * array — `RecordBatchReader.from(arrowBatch.batches)` inside the
 * converter (`lib/result/ArrowResultConverter.ts:119`) reads the
 * schema from the prefix and then the record-batch messages from the
 * remainder of the same buffer.
 *
 * We pre-parse the IPC bytes once here to extract `rowCount` because
 * the converter consumes that as an explicit field rather than
 * deriving it from the batch contents. See the comment in
 * `SeaArrowIpc.ts:decodeIpcBatch` for the cost rationale.
 */
export default class SeaResultsProvider implements IResultsProvider<ArrowBatch> {
  private readonly statement: SeaStatementHandle;

  // Prefetched next batch so `hasMore()` can be answered without an
  // extra round-trip. Set by `prime()` and consumed by `fetchNext`.
  private prefetched?: ArrowBatch;

  // Set once the kernel returns `null` from `fetchNextBatch()`.
  private exhausted = false;

  constructor(statement: SeaStatementHandle) {
    this.statement = statement;
  }

  /**
   * True when a non-empty batch is buffered or can be pulled from the
   * kernel; may fetch from the kernel to find out.
   */
  public async hasMore(): Promise<boolean> {
    if (this.exhausted) {
      return false;
    }
    if (this.prefetched === undefined) {
      await this.prime();
    }
    return this.prefetched !== undefined;
  }

  /**
   * Return the next non-empty batch, or `{ batches: [], rowCount: 0 }`
   * once the stream is exhausted. `_options` is ignored: one kernel
   * batch is handed out per call.
   */
  public async fetchNext(_options: ResultsProviderFetchNextOptions): Promise<ArrowBatch> {
    if (this.prefetched === undefined && !this.exhausted) {
      await this.prime();
    }
    const batch = this.prefetched;
    if (batch === undefined) {
      return { batches: [], rowCount: 0 };
    }
    this.prefetched = undefined;
    return batch;
  }

  /**
   * Pull batches from the kernel until a non-empty one is stashed in
   * `prefetched` or the stream is marked exhausted. Used by both
   * `hasMore` and `fetchNext` so `hasMore` is accurate without
   * re-asking the kernel.
   *
   * Empty batches are skipped in a loop rather than by recursion, so a
   * long run of empty batches does not build a chain of nested pending
   * calls.
   */
  private async prime(): Promise<void> {
    while (!this.exhausted && this.prefetched === undefined) {
      // Sequential by design: each batch must be inspected before asking for the next.
      // eslint-disable-next-line no-await-in-loop
      const next = await this.statement.fetchNextBatch();
      if (next === null) {
        this.exhausted = true;
        return;
      }
      // Patch the raw bytes once: rewrite any Arrow `Duration` field to
      // `Int64` with a `databricks.arrow.duration_unit` marker, so that
      // apache-arrow@13 (which predates Duration support) can decode the
      // stream. `decodeIpcBatch` is told these bytes are already patched;
      // the downstream `RecordBatchReader.from` inside `ArrowResultConverter`
      // sees the same patched buffer. See `SeaArrowIpcDurationFix.ts`.
      const ipcBytes = patchIpcBytes(next.ipcBytes);
      const { rowCount } = decodeIpcBatch(ipcBytes, { alreadyPatched: true });
      // Skip empty batches — the converter handles them but pre-filtering
      // here avoids one round-trip through the converter's prefetch loop.
      if (rowCount !== 0) {
        this.prefetched = { batches: [ipcBytes], rowCount };
      }
    }
  }
}

// diff --git a/lib/sea/SeaSessionBackend.ts b/lib/sea/SeaSessionBackend.ts — new file mode 100644, index 00000000..3df970ee, @@ -0,0 +1,176 @@
// Copyright (c) 2026 Databricks, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import { v4 as uuidv4 } from 'uuid';
import ISessionBackend from '../contracts/ISessionBackend';
import IOperationBackend from '../contracts/IOperationBackend';
import IClientContext from '../contracts/IClientContext';
import {
  ExecuteStatementOptions,
  TypeInfoRequest,
  CatalogsRequest,
  SchemasRequest,
  TablesRequest,
  TableTypesRequest,
  ColumnsRequest,
  FunctionsRequest,
  PrimaryKeysRequest,
  CrossReferenceRequest,
} from '../contracts/IDBSQLSession';
import Status from '../dto/Status';
import InfoValue from '../dto/InfoValue';
import HiveDriverError from '../errors/HiveDriverError';
import { SeaNativeConnection } from './SeaNativeLoader';
import { decodeNapiKernelError } from './SeaErrorMapping';
import SeaOperationBackend from './SeaOperationBackend';
export interface SeaSessionBackendOptions {
  /** The opaque napi `Connection` handle returned by `openSession`. */
  connection: SeaNativeConnection;
  context: IClientContext;
  /** Optional override for `id`. Defaults to a fresh UUIDv4. */
  id?: string;
}

/**
 * SEA-backed implementation of `ISessionBackend`.
 *
 * **M0 scope:** `executeStatement` + `close`. Metadata methods
 * (`getCatalogs`, `getSchemas`, etc.) defer to M1 — they throw a clear
 * `HiveDriverError` so consumers using SEA against metadata APIs get an
 * actionable message instead of silently falling back. The Thrift
 * backend continues to handle the metadata path by default (callers
 * opt into SEA via `ConnectionOptions.useSEA`).
 *
 * **Session config flow:** catalog, schema, and session configuration are
 * applied at native `openSession` time by `SeaBackend`. They are not
 * per-statement options on the napi binding.
 */
export default class SeaSessionBackend implements ISessionBackend {
  private readonly connection: SeaNativeConnection;

  private readonly context: IClientContext;

  private readonly _id: string;

  // Set as soon as close() is requested; cleared again if the native close fails.
  private closed = false;

  constructor({ connection, context, id }: SeaSessionBackendOptions) {
    this.connection = connection;
    this.context = context;
    this._id = id ?? uuidv4();
  }

  public get id(): string {
    return this._id;
  }

  public async getInfo(_infoType: number): Promise<InfoValue> {
    throw SeaSessionBackend.deferredToM1('getInfo');
  }

  /**
   * Execute a SQL statement through the napi binding.
   *
   * M0 rejects options the SEA binding cannot honor yet so callers do not
   * accidentally get a query with different semantics than the Thrift path.
   *
   * @throws HiveDriverError if the session is closed, or if
   *   `namedParameters`, `ordinalParameters`, `queryTimeout` or
   *   `useCloudFetch` is set (even to a falsy value other than `undefined`)
   * @throws the result of `decodeNapiKernelError` if the native call fails
   */
  public async executeStatement(statement: string, options: ExecuteStatementOptions): Promise<IOperationBackend> {
    this.failIfClosed();

    // M0 surfaces a clear error rather than silently dropping M1-only knobs.
    if (options.namedParameters !== undefined || options.ordinalParameters !== undefined) {
      throw new HiveDriverError(
        'SEA executeStatement: query parameters are not supported in M0 (deferred to M1)',
      );
    }
    if (options.queryTimeout !== undefined) {
      throw new HiveDriverError(
        'SEA executeStatement: queryTimeout is not supported in M0 (deferred to M1)',
      );
    }
    if (options.useCloudFetch !== undefined) {
      throw new HiveDriverError(
        'SEA executeStatement: useCloudFetch is controlled by the kernel result configuration and is not a per-statement option on SEA',
      );
    }

    let nativeStatement;
    try {
      nativeStatement = await this.connection.executeStatement(statement);
    } catch (err) {
      throw decodeNapiKernelError(err);
    }
    return new SeaOperationBackend({
      statement: nativeStatement!,
      context: this.context,
      id: nativeStatement!.statementId,
    });
  }

  public async getTypeInfo(_request: TypeInfoRequest): Promise<IOperationBackend> {
    throw SeaSessionBackend.deferredToM1('getTypeInfo');
  }

  public async getCatalogs(_request: CatalogsRequest): Promise<IOperationBackend> {
    throw SeaSessionBackend.deferredToM1('getCatalogs');
  }

  public async getSchemas(_request: SchemasRequest): Promise<IOperationBackend> {
    throw SeaSessionBackend.deferredToM1('getSchemas');
  }

  public async getTables(_request: TablesRequest): Promise<IOperationBackend> {
    throw SeaSessionBackend.deferredToM1('getTables');
  }

  public async getTableTypes(_request: TableTypesRequest): Promise<IOperationBackend> {
    throw SeaSessionBackend.deferredToM1('getTableTypes');
  }

  public async getColumns(_request: ColumnsRequest): Promise<IOperationBackend> {
    throw SeaSessionBackend.deferredToM1('getColumns');
  }

  public async getFunctions(_request: FunctionsRequest): Promise<IOperationBackend> {
    throw SeaSessionBackend.deferredToM1('getFunctions');
  }

  public async getPrimaryKeys(_request: PrimaryKeysRequest): Promise<IOperationBackend> {
    throw SeaSessionBackend.deferredToM1('getPrimaryKeys');
  }

  public async getCrossReference(_request: CrossReferenceRequest): Promise<IOperationBackend> {
    throw SeaSessionBackend.deferredToM1('getCrossReference');
  }

  /**
   * Close the native connection. Idempotent.
   *
   * The closed flag is set before the native call is awaited, so a second
   * `close()` or an `executeStatement()` issued while the close is in
   * flight sees the session as closed instead of reaching the native
   * connection again. If the native close fails, the flag is cleared and
   * the decoded kernel error is thrown, so the caller may retry.
   */
  public async close(): Promise<Status> {
    if (this.closed) {
      return Status.success();
    }
    this.closed = true;
    try {
      await this.connection.close();
    } catch (err) {
      this.closed = false;
      throw decodeNapiKernelError(err);
    }
    return Status.success();
  }

  private failIfClosed(): void {
    if (this.closed) {
      throw new HiveDriverError('SeaSessionBackend: session is closed');
    }
  }

  /** Build the uniform "not implemented" error for a method deferred to M1. */
  private static deferredToM1(method: string): HiveDriverError {
    return new HiveDriverError(`SeaSessionBackend.${method}: not implemented yet (deferred to M1)`);
  }
}

// diff --git a/lib/thrift-backend/ThriftBackend.ts b/lib/thrift-backend/ThriftBackend.ts — new file mode 100644, index 00000000..5e0e7570, @@ -0,0 +1,100 @@
import Int64 from 'node-int64';
import IBackend from '../contracts/IBackend';
import ISessionBackend from '../contracts/ISessionBackend';
import IClientContext from '../contracts/IClientContext';
import { OpenSessionRequest } from '../contracts/IDBSQLClient';
import { TProtocolVersion } from '../../thrift/TCLIService_types';
import Status from '../dto/Status';
import { definedOrError, serializeQueryTags } from '../utils';
import ThriftSessionBackend from './ThriftSessionBackend';

/**
 * Build the optional `initialNamespace` fragment of the open-session
 * request. Returns an empty object when neither name is set (empty
 * strings count as unset); otherwise both keys are included, with the
 * missing one left `undefined`.
 */
function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) {
  if (!catalogName && !schemaName) {
    return {};
  }

  return {
    initialNamespace: {
      catalogName,
      schemaName,
    },
  };
}

interface ThriftBackendOptions {
  context: IClientContext;
  /** Receives the Thrift connection events that `ThriftBackend` forwards. */
  onConnectionEvent: (event: 'error' | 'reconnecting' | 'close' | 'timeout', payload?: unknown) => void;
}
/**
 * Thrift implementation of `IBackend`: forwards connection events from
 * the Thrift connection to the owner and opens sessions through the
 * driver obtained from the client context.
 */
export default class ThriftBackend implements IBackend {
  private readonly context: IClientContext;

  private readonly onConnectionEvent: ThriftBackendOptions['onConnectionEvent'];

  constructor({ context, onConnectionEvent }: ThriftBackendOptions) {
    this.context = context;
    this.onConnectionEvent = onConnectionEvent;
  }

  /**
   * Attach listeners for `error`, `reconnecting`, `close` and `timeout`
   * on the Thrift connection and relay them through `onConnectionEvent`.
   */
  public async connect(): Promise<void> {
    // The connection provider is owned by DBSQLClient (it implements IClientContext).
    // We only need to wire the EventEmitter listeners through this backend.
    const connectionProvider = await this.context.getConnectionProvider();
    const thriftConnection = await connectionProvider.getThriftConnection();

    thriftConnection.on('error', (error: Error) => {
      this.onConnectionEvent('error', error);
    });

    thriftConnection.on('reconnecting', (params: { delay: number; attempt: number }) => {
      this.onConnectionEvent('reconnecting', params);
    });

    thriftConnection.on('close', () => {
      this.onConnectionEvent('close');
    });

    thriftConnection.on('timeout', () => {
      this.onConnectionEvent('timeout');
    });
  }

  /**
   * Open a Thrift session.
   *
   * The session configuration starts as a copy of `request.configuration`
   * (the caller's object is never mutated). The metric-view flag is added
   * when `enableMetricViewMetadata` is set in the client config. When
   * `request.queryTags` is provided, `QUERY_TAGS` is set to the serialized
   * tags, or removed if serialization yields an empty value.
   *
   * @throws via `Status.assert` if the response status is not successful,
   *   and via `definedOrError` if the response carries no session handle
   */
  public async openSession(request: OpenSessionRequest): Promise<ISessionBackend> {
    const driver = await this.context.getDriver();
    const config = this.context.getConfig();

    const configuration = request.configuration ? { ...request.configuration } : {};

    if (config.enableMetricViewMetadata) {
      configuration['spark.sql.thriftserver.metadata.metricview.enabled'] = 'true';
    }

    if (request.queryTags !== undefined) {
      const serialized = serializeQueryTags(request.queryTags);
      if (serialized) {
        configuration.QUERY_TAGS = serialized;
      } else {
        delete configuration.QUERY_TAGS;
      }
    }

    const response = await driver.openSession({
      client_protocol_i64: new Int64(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8),
      ...getInitialNamespaceOptions(request.initialCatalog, request.initialSchema),
      configuration,
      canUseMultipleCatalogs: true,
    });

    Status.assert(response.status);
    return new ThriftSessionBackend({
      handle: definedOrError(response.sessionHandle),
      context: this.context,
      serverProtocolVersion: response.serverProtocolVersion,
    });
  }

  public async close(): Promise<void> {
    // DBSQLClient owns the connection lifecycle and clears its own state
    // (connectionProvider, authProvider, thrift client) after this returns.
  }
}

// diff --git a/lib/thrift-backend/ThriftOperationBackend.ts b/lib/thrift-backend/ThriftOperationBackend.ts — new file mode 100644, index 00000000..436d4928, @@ -0,0 +1,382 @@
import { stringify, NIL } from 'uuid';
import {
  TGetOperationStatusResp,
  TOperationHandle,
  TSparkDirectResults,
  TGetResultSetMetadataResp,
  TSparkRowSetType,
  TCloseOperationResp,
  TOperationState,
} from '../../thrift/TCLIService_types';
import IOperationBackend from '../contracts/IOperationBackend';
import IClientContext from '../contracts/IClientContext';
import { WaitUntilReadyOptions } from '../contracts/IOperation';
import { OperationStatus, OperationState } from '../contracts/OperationStatus';
import { ResultMetadata, ResultFormat } from '../contracts/ResultMetadata';
import Status from '../dto/Status';
import { LogLevel } from '../contracts/IDBSQLLogger';
import OperationStateError, { OperationStateErrorCode } from '../errors/OperationStateError';
import IResultsProvider from '../result/IResultsProvider';
import RowSetProvider from '../result/RowSetProvider';
import JsonResultHandler from '../result/JsonResultHandler';
import ArrowResultHandler from '../result/ArrowResultHandler';
import CloudFetchResultHandler from '../result/CloudFetchResultHandler';
import ArrowResultConverter from '../result/ArrowResultConverter';
import ResultSlicer from '../result/ResultSlicer';
import { definedOrError } from '../utils';
import HiveDriverError from '../errors/HiveDriverError';

interface ThriftOperationBackendOptions {
  handle: TOperationHandle;
  directResults?: TSparkDirectResults;
  context: IClientContext;
}

/** Resolve after `ms` milliseconds (`setTimeout` semantics apply when `ms` is omitted). */
async function delay(ms?: number): Promise<void> {
  return new Promise<void>((resolve) => {
    setTimeout(resolve, ms);
  });
}

function thriftStateToOperationState(state: TOperationState | undefined | null): OperationState {
  switch (state) {
    case TOperationState.INITIALIZED_STATE:
    case
/**
 * Translate a Thrift row-set type into the backend-neutral `ResultFormat`.
 *
 * @param type - Thrift `TSparkRowSetType` reported for a result set.
 * @returns the matching `ResultFormat` member.
 * @throws HiveDriverError when `type` is not one of the three mapped
 *   row-set types (column-based, Arrow-based, URL-based).
 */
function thriftRowSetTypeToResultFormat(type: TSparkRowSetType): ResultFormat {
  if (type === TSparkRowSetType.COLUMN_BASED_SET) {
    return ResultFormat.ColumnBased;
  }
  if (type === TSparkRowSetType.ARROW_BASED_SET) {
    return ResultFormat.ArrowBased;
  }
  if (type === TSparkRowSetType.URL_BASED_SET) {
    return ResultFormat.UrlBased;
  }
  throw new HiveDriverError(`Unsupported result format: ${TSparkRowSetType[type]}`);
}
/**
 * Fetch the next chunk of rows through the operation's result handler.
 *
 * @param limit - forwarded to the result handler's `fetchNext`.
 * @param disableBuffering - forwarded to the result handler's `fetchNext`.
 * @returns whatever `fetchNext` resolves with.
 *
 * NOTE(review): the element type of the returned array was lost in this
 * copy of the file; `object` is assumed — confirm against `IOperationBackend`.
 */
public async fetchChunk({
  limit,
  disableBuffering,
}: {
  limit: number;
  disableBuffering?: boolean;
}): Promise<Array<object>> {
  const resultHandler = await this.getResultHandler();

  // All the library code is Promise-based, however, since Promises are microtasks,
  // enqueueing a lot of promises may block macrotasks execution for a while.
  // Usually there are not many microtasks scheduled, but when fetching query
  // results (especially CloudFetch ones) it is quite easy to block the event loop
  // for long enough to break a lot of things. For example, with CloudFetch, after
  // the first set of files is downloaded and processed immediately one by one, the
  // event loop easily gets blocked long enough to break the connection pool:
  // `http.Agent` stops receiving socket events and marks all sockets invalid on
  // the next attempt to use them. Yielding one macrotask here avoids that. See
  // https://github.com/nodejs/node/issues/47130 and
  // https://github.com/node-fetch/node-fetch/issues/1735
  await new Promise<void>((resolve) => {
    setTimeout(resolve, 0);
  });

  return resultHandler.fetchNext({ limit, disableBuffering });
}
/**
 * Poll the operation status until the operation reaches `FINISHED_STATE`.
 *
 * Returns immediately when the locally tracked state is already
 * `FINISHED_STATE`. Otherwise each iteration:
 *  1. fetches a status response via `thriftStatusResponse`,
 *  2. awaits `options.callback` (if given) with that raw response,
 *  3. returns on `FINISHED_STATE`, keeps polling on
 *     INITIALIZED / PENDING / RUNNING, and throws for every other state,
 *  4. waits 100 ms before the next poll.
 *
 * @param options - optional `progress` flag (forwarded to the status
 *   request) and `callback` invoked with every status response.
 * @throws OperationStateError with code `Canceled`, `Closed`, `Error`,
 *   `Timeout` or `Unknown`, carrying the status response that caused it.
 */
public async waitUntilReady(options?: WaitUntilReadyOptions): Promise<void> {
  if (this.state === TOperationState.FINISHED_STATE) {
    return;
  }

  for (;;) {
    // eslint-disable-next-line no-await-in-loop
    const response = await this.thriftStatusResponse(Boolean(options?.progress));

    if (options?.callback) {
      // The public `OperationStatusCallback` is Thrift-shaped; pass the
      // wire response verbatim. Non-Thrift backends synthesize via
      // `synthesizeThriftStatus` in their own `waitUntilReady` impls.
      // eslint-disable-next-line no-await-in-loop
      await Promise.resolve(options.callback(response));
    }

    switch (response.operationState) {
      case TOperationState.INITIALIZED_STATE:
      case TOperationState.PENDING_STATE:
      case TOperationState.RUNNING_STATE:
        // Still in progress: fall out of the switch and poll again.
        break;

      case TOperationState.FINISHED_STATE:
        return;

      case TOperationState.CANCELED_STATE:
        throw new OperationStateError(OperationStateErrorCode.Canceled, response);

      case TOperationState.CLOSED_STATE:
        throw new OperationStateError(OperationStateErrorCode.Closed, response);

      case TOperationState.ERROR_STATE:
        throw new OperationStateError(OperationStateErrorCode.Error, response);

      case TOperationState.TIMEDOUT_STATE:
        throw new OperationStateError(OperationStateErrorCode.Timeout, response);

      case TOperationState.UKNOWN_STATE:
      default:
        throw new OperationStateError(OperationStateErrorCode.Unknown, response);
    }

    // eslint-disable-next-line no-await-in-loop
    await delay(100);
  }
}
/**
 * Thrift-specific accessor for the raw `TGetResultSetMetadataResp`.
 *
 * Resolution order:
 *  1. the cached `this.metadata` (set from direct results in the
 *     constructor, or by an earlier successful call),
 *  2. an in-flight request started by another caller (`this.metadataPromise`),
 *     so concurrent callers share a single `getResultSetMetadata` call,
 *  3. a new `getResultSetMetadata` request, whose response is status-checked
 *     and cached.
 *
 * The caller that started the request clears `metadataPromise` once it
 * settles, whether it succeeded or failed, so a failed request is not
 * cached and a later call issues a new one.
 *
 * @returns the Thrift result-set metadata response.
 * @throws whatever `Status.assert` raises for a non-success status, or any
 *   error from the driver call.
 */
public async thriftResultMetadataResponse(): Promise<TGetResultSetMetadataResp> {
  if (this.metadata) {
    return this.metadata;
  }

  if (this.metadataPromise) {
    return this.metadataPromise;
  }

  this.metadataPromise = (async () => {
    const driver = await this.context.getDriver();
    const metadata = await driver.getResultSetMetadata({
      operationHandle: this.operationHandle,
    });
    Status.assert(metadata.status);
    this.metadata = metadata;
    return metadata;
  })();

  try {
    return await this.metadataPromise;
  } finally {
    this.metadataPromise = undefined;
  }
}
/**
 * Lazily build (and then reuse) the `ResultSlicer` that serves rows for
 * this operation.
 *
 * The Thrift `resultFormat` from the result-set metadata selects the source:
 *  - `COLUMN_BASED_SET` -> `JsonResultHandler`
 *  - `ARROW_BASED_SET`  -> `ArrowResultConverter` over `ArrowResultHandler`
 *  - `URL_BASED_SET`    -> `ArrowResultConverter` over `CloudFetchResultHandler`
 *
 * @returns the cached or newly created result slicer.
 * @throws HiveDriverError when the result format is none of the above;
 *   `definedOrError` is applied to a missing `resultFormat` first.
 *
 * NOTE(review): the type arguments (`ResultSlicer<any>`, `Array<any>`) were
 * lost in this copy of the file and are assumed — confirm against the
 * `resultHandler` field declaration and `IResultsProvider`.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
private async getResultHandler(): Promise<ResultSlicer<any>> {
  // Metadata is resolved (and validated) on every call, as before; after the
  // first successful fetch it is served from the cached `this.metadata`.
  const metadata = await this.thriftResultMetadataResponse();
  const resultFormat = definedOrError(metadata.resultFormat);

  if (this.resultHandler) {
    return this.resultHandler;
  }

  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  let resultSource: IResultsProvider<Array<any>>;

  switch (resultFormat) {
    case TSparkRowSetType.COLUMN_BASED_SET:
      resultSource = new JsonResultHandler(this.context, this._data, metadata);
      break;
    case TSparkRowSetType.ARROW_BASED_SET:
      resultSource = new ArrowResultConverter(
        this.context,
        new ArrowResultHandler(this.context, this._data, metadata),
        metadata,
      );
      break;
    case TSparkRowSetType.URL_BASED_SET:
      resultSource = new ArrowResultConverter(
        this.context,
        new CloudFetchResultHandler(this.context, this._data, metadata),
        metadata,
      );
      break;
    default:
      throw new HiveDriverError(`Unsupported result format: ${TSparkRowSetType[resultFormat]}`);
  }

  this.resultHandler = new ResultSlicer(this.context, resultSource);
  return this.resultHandler;
}
/**
 * Convert a Thrift `TGetResultSetMetadataResp` into the neutral
 * `ResultMetadata` DTO.
 *
 * `schema`, `lz4Compressed` and `arrowSchema` are copied through as-is,
 * `isStagingOperation` is coerced to a boolean, and `resultFormat` is
 * mapped via `thriftRowSetTypeToResultFormat` after `definedOrError`
 * has been applied to it.
 */
// eslint-disable-next-line class-methods-use-this
private adaptResultMetadata(response: TGetResultSetMetadataResp): ResultMetadata {
  const { schema, lz4Compressed, arrowSchema, isStagingOperation } = response;
  const resultFormat = thriftRowSetTypeToResultFormat(definedOrError(response.resultFormat));

  return {
    schema,
    resultFormat,
    lz4Compressed,
    arrowSchema,
    isStagingOperation: Boolean(isStagingOperation),
  };
}
/**
 * Convert user-supplied query parameters into Thrift `TSparkParameter`s.
 *
 * Values that are not already a `DBSQLParameter` are wrapped in one first.
 * Named parameters are emitted with their key as the parameter name;
 * ordinal parameters are emitted without a name, in array order.
 *
 * @param namedParameters - optional map of parameter name to value.
 * @param ordinalParameters - optional positional parameter values.
 * @returns the converted parameters; an empty array when neither argument
 *   carries any entries.
 * @throws ParameterError when both named and ordinal parameters are
 *   non-empty, since the two styles cannot be mixed.
 */
function getQueryParameters(
  namedParameters?: Record<string, DBSQLParameter | DBSQLParameterValue>,
  ordinalParameters?: Array<DBSQLParameter | DBSQLParameterValue>,
): Array<TSparkParameter> {
  const namedEntries = Object.entries(namedParameters ?? {});
  const ordinalValues = ordinalParameters ?? [];

  if (namedEntries.length > 0 && ordinalValues.length > 0) {
    throw new ParameterError('Driver does not support both ordinal and named parameters.');
  }

  // Reuse ready-made DBSQLParameter instances, wrap plain values.
  const asParameter = (value: DBSQLParameter | DBSQLParameterValue): DBSQLParameter =>
    value instanceof DBSQLParameter ? value : new DBSQLParameter({ value });

  // At most one of the two lists is non-empty here, so the concatenation
  // yields either the named set, the ordinal set, or an empty array.
  return [
    ...namedEntries.map(([name, value]) => asParameter(value).toSparkParameter({ name })),
    ...ordinalValues.map((value) => asParameter(value).toSparkParameter()),
  ];
}
/**
 * Execute a SQL statement on this session and wrap the resulting Thrift
 * operation in an `IOperationBackend`.
 *
 * The request is always sent with `runAsync: true`. Optional request fields
 * are filled in only when the server protocol version reports support:
 *  - `parameters` when parameterized queries are supported,
 *  - `canDownloadResult` (CloudFetch) from `options.useCloudFetch`, falling
 *    back to the client config,
 *  - `canDecompressLZ4Result` only when Arrow compression is supported,
 *    CloudFetch was not requested (`canDownloadResult !== true`), and
 *    `LZ4()` returns a truthy value.
 * Query tags, when they serialize to a defined value, are merged into
 * `confOverlay` under the `query_tags` key.
 *
 * @param statement - SQL text to execute.
 * @param options - per-statement options (timeout, max rows, parameters,
 *   query tags, CloudFetch / LZ4 overrides).
 * @returns a backend for the operation created by the server.
 * @throws ParameterError when both named and ordinal parameters are given,
 *   plus anything raised while validating the server response.
 */
public async executeStatement(statement: string, options: ExecuteStatementOptions): Promise<IOperationBackend> {
  const driver = await this.context.getDriver();
  const clientConfig = this.context.getConfig();

  const request = new TExecuteStatementReq({
    sessionHandle: this.sessionHandle,
    statement,
    // A falsy timeout (undefined / 0) is sent as "not set".
    queryTimeout: options.queryTimeout ? numberToInt64(options.queryTimeout) : undefined,
    runAsync: true,
    ...getDirectResultsOptions(options.maxRows, clientConfig),
    ...getArrowOptions(clientConfig, this.serverProtocolVersion),
  });

  if (ProtocolVersion.supportsParameterizedQueries(this.serverProtocolVersion)) {
    request.parameters = getQueryParameters(options.namedParameters, options.ordinalParameters);
  }

  const serializedQueryTags = serializeQueryTags(options.queryTags);
  if (serializedQueryTags !== undefined) {
    request.confOverlay = { ...request.confOverlay, query_tags: serializedQueryTags };
  }

  if (ProtocolVersion.supportsCloudFetch(this.serverProtocolVersion)) {
    request.canDownloadResult = options.useCloudFetch ?? clientConfig.useCloudFetch;
  }

  // LZ4 is only negotiated for inline results, i.e. when CloudFetch is off.
  if (ProtocolVersion.supportsArrowCompression(this.serverProtocolVersion) && request.canDownloadResult !== true) {
    request.canDecompressLZ4Result = (options.useLZ4Compression ?? clientConfig.useLZ4Compression) && Boolean(LZ4());
  }

  const response = await driver.executeStatement(request);
  return this.createOperationBackend(response);
}
/**
 * Issue a Thrift `getColumns` metadata request on this session.
 *
 * The catalog / schema / table / column names from `request` are forwarded
 * unchanged. `runAsync` is `true` only when the server protocol version
 * supports asynchronous metadata operations and is left `undefined`
 * otherwise. Direct-results options are derived from `request.maxRows` and
 * the client config.
 *
 * @param request - column lookup filters and optional `maxRows`.
 * @returns a backend wrapping the operation created by the server.
 */
public async getColumns(request: ColumnsRequest): Promise<IOperationBackend> {
  const driver = await this.context.getDriver();
  const response = await driver.getColumns({
    sessionHandle: this.sessionHandle,
    catalogName: request.catalogName,
    schemaName: request.schemaName,
    tableName: request.tableName,
    columnName: request.columnName,
    runAsync: this.getRunAsyncForMetadataOperations(),
    ...getDirectResultsOptions(request.maxRows, this.context.getConfig()),
  });
  return this.createOperationBackend(response);
}
/**
 * Shared tail of every operation-producing request: validate the Thrift
 * response and wrap it in a `ThriftOperationBackend`.
 *
 * The response status is asserted first; only then is the operation handle
 * required to be present. Any direct results are handed to the new backend
 * together with this session's client context.
 */
private createOperationBackend(response: OperationResponseShape): IOperationBackend {
  const { status, operationHandle, directResults } = response;
  Status.assert(status);

  return new ThriftOperationBackend({
    handle: definedOrError(operationHandle),
    directResults,
    context: this.context,
  });
}
/**
 * Ensure a non-empty HTTP path begins with `/`.
 *
 * @param str - path to normalise.
 * @returns `str` unchanged when it is empty or already starts with `/`;
 *   otherwise `str` with a single `/` prepended.
 */
export default function prependSlash(str: string): string {
  const alreadyNormalised = str.length === 0 || str.charAt(0) === '/';
  return alreadyNormalised ? str : `/${str}`;
}
/**
 * Build a Thrift-shaped `TGetOperationStatusResp` out of the neutral
 * `OperationStatus` DTO.
 *
 * The `status` field is always a synthetic success status, and the neutral
 * operation state is mapped back to its Thrift counterpart. `sqlState`,
 * `errorMessage`, `hasResultSet` and `progressUpdateResponse` are copied
 * through. Fields the DTO does not carry are simply not set on the result.
 */
export function synthesizeThriftStatus(status: OperationStatus): TGetOperationStatusResp {
  const { state, sqlState, errorMessage, hasResultSet, progressUpdateResponse } = status;

  const synthesized = {
    status: synthesizeOkStatus(),
    operationState: operationStateToThrift(state),
    sqlState,
    errorMessage,
    hasResultSet,
    progressUpdateResponse: progressUpdateResponse as TGetOperationStatusResp['progressUpdateResponse'],
  };

  return synthesized as TGetOperationStatusResp;
}
Building it requires a local checkout +of that repo — see "Build for local dev" below. The published npm +package is `@databricks/sql-kernel-<triple>`. + +## Workspace topology + +The napi crate is a **standalone Cargo workspace** (`[workspace] +members = ["."]` in `napi/Cargo.toml`), **not** a sibling of `pyo3/` +in the kernel root workspace. + +The reason is Cargo feature unification. pyo3 builds the kernel with +the default `tls-native` feature (system OpenSSL via `native-tls`). +The napi crate has to opt INTO `tls-rustls` instead: napi modules are +loaded into Node.js processes that statically link OpenSSL 3.x, and +dynamically linking the system's OpenSSL 1.1 (which `native-tls` +pulls in on Linux) collides with Node's symbols at module-load time +and segfaults the process before any Rust code runs. `rustls` is +pure Rust + `ring` and avoids the conflict entirely. + +If napi lived in the same workspace as pyo3, `cargo build +--workspace` would unify the kernel's feature set to `tls-native ∪ +tls-rustls`, link both TLS stacks into the resulting napi cdylib, +and reintroduce the same clash. Standalone-workspace is the fix. + +## What lives in this directory + +- `index.d.ts` — TypeScript declarations consumed by `lib/sea/`. + Generated by napi-rs from the Rust source; checked in as the + consumer-facing type contract. +- `index.js` — napi-rs's per-platform router shim. Checked in + alongside the `.d.ts` (only the `.node` binaries are gitignored) and + regenerated by `npm run build:native`. In published + tarballs it ships alongside the `.d.ts` and `require()`s the + right `@databricks/sql-kernel-<triple>` optional dependency. +- `index.*.node` — the actual native binary, one per platform. + Gitignored. In production these live in the per-triple optional + dependencies (`@databricks/sql-kernel-linux-x64-gnu`, etc.); for + local dev `npm run build:native` copies one into this directory.
+ +## Build for local dev + +```bash +# From the nodejs repo root: +export DATABRICKS_SQL_KERNEL_REPO=/path/to/your/databricks-sql-kernel +npm run build:native # release build (default) +BUILD_PROFILE= npm run build:native # debug build (empty BUILD_PROFILE drops --release) +``` + +`DATABRICKS_SQL_KERNEL_REPO` points at the kernel repo root (the +directory containing `napi/`) and is required when your kernel +checkout isn't at `../../databricks-sql-kernel` relative to the +nodejs repo. + +## Production load path + +At release time the kernel's CI publishes +`@databricks/sql-kernel-<triple>` npm packages — one per supported +platform — each containing a single `.node` binary. The nodejs +driver lists them as `optionalDependencies`; npm installs only the +one matching the consumer's `process.platform` / `process.arch`. +`native/sea/index.js` (the napi-rs router) then `require()`s the +installed package at load time. + +## Supported platforms (M0) + +M0 publishes a **single** triple: **`linux-x64-gnu`** (package +`@databricks/sql-kernel-linux-x64-gnu`). It is the only entry in the +driver's `optionalDependencies`. + +On every other platform (macOS, Windows, linux-arm64, linux-x64-musl +/ Alpine, …) the SEA binding is simply absent: `SeaNativeLoader` +returns `undefined` from `tryGet()` / throws a structured +`MODULE_NOT_FOUND` hint from `get()`, and the driver continues to use +the Thrift backend exclusively. This is expected, not a regression — +additional triples are added to `optionalDependencies` as the kernel +CI starts publishing them in later milestones. + +## Supply-chain note + +The unpublished triple names (`@databricks/sql-kernel-darwin-arm64`, +`…-win32-x64-msvc`, etc.) referenced by the router are **not** +squat-able: `@databricks` is a Databricks-owned npm scope, and npm +only allows members of the owning org to publish under that scope. A third +party therefore cannot register `@databricks/sql-kernel-*` and have +the router autoload it.
No placeholder packages are required. diff --git a/native/sea/index.d.ts b/native/sea/index.d.ts new file mode 100644 index 00000000..eb16e8ac --- /dev/null +++ b/native/sea/index.d.ts @@ -0,0 +1,297 @@ +/* tslint:disable */ +/* eslint-disable */ + +/* auto-generated by NAPI-RS */ + +/** + * JS-visible options for opening a Databricks SQL session over PAT. + * `token` is required. + * + * Catalog / schema / sessionConf are applied once at session creation + * and remain in effect for every statement run on the resulting + * `Connection`. The SEA wire protocol carries them on + * `CreateSession`, not on `ExecuteStatement` — so there is no + * per-statement override path on this binding. + */ +export interface ConnectionOptions { + /** + * Workspace host, e.g. `adb-…azuredatabricks.net`. The kernel + * normalises this — bare hostnames get `https://` prepended. + */ + hostName: string + /** + * JDBC-style HTTP path, e.g. `/sql/1.0/warehouses/abc123`. The + * kernel parses out the warehouse id. + */ + httpPath: string + /** + * Personal access token. Must be non-empty (the kernel rejects + * empty PATs at session construction). + */ + token: string + /** + * Default catalog for statements executed on this session. + * Routed through the kernel's `DefaultOpts` and onto the SEA + * `CreateSession.catalog` wire field. + */ + catalog?: string + /** + * Default schema for statements executed on this session. + * Routed through the kernel's `DefaultOpts` and onto the SEA + * `CreateSession.schema` wire field. + */ + schema?: string + /** + * Server-bound session conf (Spark conf, `ANSI_MODE`, `TIMEZONE`, + * query-tag presets, …). Forwarded verbatim to SEA + * `session_confs`. Unknown keys are rejected server-side. + */ + sessionConf?: Record + /** + * Maximum number of pooled HTTP connections per host. Routes + * through the kernel's [`HttpConfig::pool_max_idle_per_host`]. 
+ * Tunes the underlying `reqwest` connection pool — higher values + * reduce reconnect overhead when many statements run + * concurrently against the same warehouse. + * + * When the JS caller does NOT provide `maxConnections`, the napi + * binding applies a NodeJS-driver-appropriate default of + * [`NAPI_DEFAULT_POOL_MAX_IDLE_PER_HOST`] (100) — chosen to match + * the JDBC driver's `HttpConnectionPoolSize` default and to close + * the throughput gap vs the NodeJS Thrift driver's + * `maxSockets: Infinity` pool for bursty workloads. The kernel + * core's [`HttpConfig::pool_max_idle_per_host`] default remains + * at the conservative kernel value (10); each binding chooses + * its own user-facing default. Mirrors the Python connector's + * `max_connections` kwarg on the SEA backend, which exposes the + * knob but keeps its own urllib3-aligned default of 10. + * + * Napi-rs serialises `u32` as JS `number`; values up to + * `2^32 - 1` round-trip safely (any reasonable pool size fits). + */ + maxConnections?: number +} +/** + * Open a Databricks SQL session over PAT auth and return an opaque + * `Connection` wrapping the kernel `Session`. + * + * The JS-visible name is `openSession` (napi-rs converts snake_case + * to camelCase for free functions). + */ +export declare function openSession(options: ConnectionOptions): Promise<Connection> +/** + * A single Arrow IPC stream payload encoding one record batch (plus + * the schema header so the JS-side reader is stateless). + */ +export interface ArrowBatch { + /** + * Arrow IPC stream payload (schema header + 1 record-batch + * message). Decode with `apache-arrow`'s `RecordBatchReader`. + */ + ipcBytes: Buffer +} +/** + * An Arrow IPC stream payload encoding just the result schema (no + * record-batch messages). Returned by `Statement.schema()`. + */ +export interface ArrowSchema { + /** + * Arrow IPC stream payload (schema header only, no record-batch + * messages).
Decode with `apache-arrow`'s `RecordBatchReader` — + * the reader will expose the schema and immediately end. + */ + ipcBytes: Buffer +} +/** + * Returns the native binding's crate version (`CARGO_PKG_VERSION`). + * + * Originally the round-1b smoke test; kept as a cheap "is the binding + * loaded?" probe for the JS-side loader's structured diagnostics. + */ +export declare function version(): string +/** + * Opaque connection handle wrapping a kernel `Session`. + * + * `inner` is `Arc<Mutex<Option<Session>>>` so: + * - the Drop impl can clone the `Arc` and `.take()` the session on a + * background tokio task without holding `&mut self` (which Drop is + * forbidden from doing across an `await`), + * - `close()` can `.take()` the session to consume it for the kernel's + * move-by-value `Session::close(self)` signature. + * + * **Current concurrency shape** — `executeStatement` holds + * `inner.lock()` across `stmt.execute().await`, so two concurrent + * `Promise.all([executeStatement(q1), executeStatement(q2)])` calls + * on the same Connection serialise even though the kernel transport + * supports concurrent statements per session, and `close()` blocks + * behind any in-flight execute. The kernel's `Session::statement()` + * is `&self`-callable, so the right shape is `Arc<Session>` with + * concurrent execute paths; that lands in the follow-up lock-shape + * refactor — see + * `sea-workflow/jira-candidates/2026-05-24-napi-cancel-during-fetch.md`. + */ +export declare class Connection { + /** + * Server-issued session id. Cached at construction; readable + * even after `close()` so JS-side log lines can correlate + * against kernel / server logs which key on the same id. + */ + get sessionId(): string + /** + * Execute a SQL statement and return a Statement handle that + * streams batches via `fetchNextBatch()`. + * + * No per-statement options: catalog / schema / sessionConf are + * session-level (`openSession`). + */ + executeStatement(sql: string): Promise<Statement> + /** + * Explicit close.
Awaits the server-side `DeleteSession` so the + * JS caller can observe failures (auth revoked mid-session, + * warehouse stopped, network error). Idempotent — a second call + * on an already-closed connection returns `Ok`. + * + * **Errors are terminal from the JS side.** The kernel session + * handle is consumed (`take()`) BEFORE the wire `DeleteSession` + * runs, because `Session::close` takes `self` by value. On `Err`, + * the napi `inner` is already `None`, so a JS-side retry sees a + * closed connection and returns `Ok(())` without re-attempting + * the wire call. The kernel's own `Drop` fire-and-forget retry + * runs once in the background — the JS caller can log the error + * but cannot drive a retry. If you need retry-on-failure + * semantics for `DeleteSession`, layer them above this method. + */ + close(): Promise<void> +} +/** + * Opaque executed-statement handle. + * + * **Current concurrency shape** — every method takes `inner.lock()` + * and holds the guard across the kernel `.await`. tokio `Mutex` is + * FIFO, so cancel/close queue behind any in-flight `fetchNextBatch` + * until it returns naturally. This is a known limitation that exists + * because the napi shape has not yet been split into an + * `Arc` (for cancel/close, which the + * kernel exposes as `&self`-callable) plus a `Mutex>` only + * for the borrowed-mut fetch path. The lock-shape refactor needs a + * small kernel-side accessor and lands in a follow-up PR — see + * `sea-workflow/jira-candidates/2026-05-24-napi-cancel-during-fetch.md`. + * + * `schema` and `statement_id` are cached at construction so they + * survive `close()` — JS callers building error reports against a + * disposed statement can still read them. + */ +export declare class Statement { + /** + * Server-issued statement id. Cached at construction; readable + * even after `close()` so JS-side log lines can correlate against + * kernel / server logs which key on the same id.
+ */ + get statementId(): string + /** + * Number of rows modified by the statement (UPDATE / INSERT / + * DELETE / MERGE). `null` for SELECT and on warehouses that don't + * surface the counter. Mirrors Thrift's + * `TGetOperationStatusResp.numModifiedRows`. + */ + numModifiedRows(): Promise<number | null> + /** + * Server-supplied user-facing message. Mirrors Thrift's + * `TGetOperationStatusResp.displayMessage`. **PII / sensitive- + * data note:** may contain SQL fragments or parameter values — + * redact before centralised logging. + * + * Populated on `Succeeded` / `Closed-with-inline-data` paths. + * On terminal-error states (`Failed` / `Cancelled` / + * `Closed-no-data`) the kernel returns an Error instead of a + * `Statement`, and the same field rides on the JS Error envelope + * under the same `displayMessage` key. + */ + displayMessage(): Promise<string | null> + /** + * Server-supplied diagnostic detail — multi-line operator / + * stack context. Mirrors Thrift's + * `TGetOperationStatusResp.diagnosticInfo`. For support surfaces, + * not user-facing. Same reachability + PII caveats as + * `displayMessage`. + */ + diagnosticInfo(): Promise<string | null> + /** + * Server-supplied JSON blob with extended error details. Mirrors + * Thrift's `TGetOperationStatusResp.errorDetailsJson`. + * Pass-through string — JS callers parse with `JSON.parse` if + * they need structured access. + * + * **Server-side gating:** populated only when the workspace has + * `spark.databricks.sql.errorDetailsJson.enabled = true` on the + * underlying SQL cluster. The flag is internal-only / default- + * false in the Databricks runtime, so for most JS callers this + * will return `null`. Admin-enabled workspaces return content + * shaped like `{"errorClass": "...", "messageTemplate": "..."}`. + * + * **Unbounded:** when populated, server can return a multi-MB + * blob; size before logging. + */ + errorDetailsJson(): Promise<string | null> + /** + * Pull the next batch of results. Returns `null` when the stream + * is exhausted.
The returned `ArrowBatch.ipcBytes` is a complete + * Arrow IPC stream (schema header + 1 record-batch message) + * suitable for handing to `apache-arrow`'s `RecordBatchReader`. + * + * On `Err`, the stream is in an unspecified state — call + * `close()` and discard the `Statement`. Subsequent + * `fetchNextBatch()` calls after an error are not guaranteed to + * succeed or fail consistently. + */ + fetchNextBatch(): Promise<ArrowBatch | null> + /** + * Result schema as an Arrow IPC payload (schema header only, no + * record-batch message). Available before any batches have been + * fetched, and remains available after `close()` — the kernel + * materialises the schema eagerly so JS callers can build error + * reports against a disposed statement. + * + * Sync because the body has no `.await` — `encode_ipc_stream` is + * pure CPU work over an `Arc` already cached on the + * wrapper. Mirrors `pyo3/src/statement.rs::arrow_schema` (sync). + * napi-rs converts a panic in a sync `#[napi]` entry point into a + * thrown JS error via its own macro-expanded boundary, so the + * `util::guarded` `catch_unwind` wrapper that the `async fn` + * entry points use is not required for this method. + */ + schema(): ArrowSchema + /** + * Server-side cancel. + * + * Short-circuits to `Ok(())` if `fetchNextBatch` has already + * returned `null` (stream naturally exhausted) — matches the + * JDBC `Statement.cancel()` no-op-after-completion contract, so + * JS callers can fire cancel defensively without distinguishing + * "real cancel" from "raced with natural completion." + * + * Returns `KernelError(InvalidStatementHandle)` if the statement + * has been explicitly `close()`d. + */ + cancel(): Promise<void> + /** + * Explicit close. Awaits the server-side `CloseStatement` so the + * JS caller can observe failures (auth revoked mid-session, + * network error, server-side error). Idempotent — a second call + * on an already-closed statement returns `Ok`.
+ * + * **Errors are terminal from the JS side.** The kernel executed + * handle is taken out of `inner` BEFORE the wire `CloseStatement` + * runs (so `Drop` knows there's nothing left to clean up). On + * `Err`, the napi `inner` is already `None`, so a JS-side retry + * sees a closed statement and returns `Ok(())` without re- + * attempting the wire call. The kernel-level `ExecutedStatement` + * has been consumed at that point and the value is dropped on + * the way out of the closure — the kernel's `ExecutedStatement:: + * Drop` then fires-and-forgets a single retry on the captured + * runtime. The JS caller can log the error but cannot drive a + * further retry. If you need retry-on-failure semantics for + * `CloseStatement`, layer them above this method. + */ + close(): Promise<void> +} diff --git a/native/sea/index.js b/native/sea/index.js new file mode 100644 index 00000000..6153729d --- /dev/null +++ b/native/sea/index.js @@ -0,0 +1,318 @@ +/* tslint:disable */ +/* eslint-disable */ +/* prettier-ignore */ + +/* auto-generated by NAPI-RS */ + +const { existsSync, readFileSync } = require('fs') +const { join } = require('path') + +const { platform, arch } = process + +let nativeBinding = null +let localFileExisted = false +let loadError = null + +function isMusl() { + // For Node 10 + if (!process.report || typeof process.report.getReport !== 'function') { + try { + const lddPath = require('child_process').execSync('which ldd').toString().trim() + return readFileSync(lddPath, 'utf8').includes('musl') + } catch (e) { + return true + } + } else { + const { glibcVersionRuntime } = process.report.getReport().header + return !glibcVersionRuntime + } +} + +switch (platform) { + case 'android': + switch (arch) { + case 'arm64': + localFileExisted = existsSync(join(__dirname, 'index.android-arm64.node')) + try { + if (localFileExisted) { + nativeBinding = require('./index.android-arm64.node') + } else { + nativeBinding = require('@databricks/sql-kernel-android-arm64') + } +
} catch (e) { + loadError = e + } + break + case 'arm': + localFileExisted = existsSync(join(__dirname, 'index.android-arm-eabi.node')) + try { + if (localFileExisted) { + nativeBinding = require('./index.android-arm-eabi.node') + } else { + nativeBinding = require('@databricks/sql-kernel-android-arm-eabi') + } + } catch (e) { + loadError = e + } + break + default: + throw new Error(`Unsupported architecture on Android ${arch}`) + } + break + case 'win32': + switch (arch) { + case 'x64': + localFileExisted = existsSync( + join(__dirname, 'index.win32-x64-msvc.node') + ) + try { + if (localFileExisted) { + nativeBinding = require('./index.win32-x64-msvc.node') + } else { + nativeBinding = require('@databricks/sql-kernel-win32-x64-msvc') + } + } catch (e) { + loadError = e + } + break + case 'ia32': + localFileExisted = existsSync( + join(__dirname, 'index.win32-ia32-msvc.node') + ) + try { + if (localFileExisted) { + nativeBinding = require('./index.win32-ia32-msvc.node') + } else { + nativeBinding = require('@databricks/sql-kernel-win32-ia32-msvc') + } + } catch (e) { + loadError = e + } + break + case 'arm64': + localFileExisted = existsSync( + join(__dirname, 'index.win32-arm64-msvc.node') + ) + try { + if (localFileExisted) { + nativeBinding = require('./index.win32-arm64-msvc.node') + } else { + nativeBinding = require('@databricks/sql-kernel-win32-arm64-msvc') + } + } catch (e) { + loadError = e + } + break + default: + throw new Error(`Unsupported architecture on Windows: ${arch}`) + } + break + case 'darwin': + localFileExisted = existsSync(join(__dirname, 'index.darwin-universal.node')) + try { + if (localFileExisted) { + nativeBinding = require('./index.darwin-universal.node') + } else { + nativeBinding = require('@databricks/sql-kernel-darwin-universal') + } + break + } catch {} + switch (arch) { + case 'x64': + localFileExisted = existsSync(join(__dirname, 'index.darwin-x64.node')) + try { + if (localFileExisted) { + nativeBinding = 
require('./index.darwin-x64.node') + } else { + nativeBinding = require('@databricks/sql-kernel-darwin-x64') + } + } catch (e) { + loadError = e + } + break + case 'arm64': + localFileExisted = existsSync( + join(__dirname, 'index.darwin-arm64.node') + ) + try { + if (localFileExisted) { + nativeBinding = require('./index.darwin-arm64.node') + } else { + nativeBinding = require('@databricks/sql-kernel-darwin-arm64') + } + } catch (e) { + loadError = e + } + break + default: + throw new Error(`Unsupported architecture on macOS: ${arch}`) + } + break + case 'freebsd': + if (arch !== 'x64') { + throw new Error(`Unsupported architecture on FreeBSD: ${arch}`) + } + localFileExisted = existsSync(join(__dirname, 'index.freebsd-x64.node')) + try { + if (localFileExisted) { + nativeBinding = require('./index.freebsd-x64.node') + } else { + nativeBinding = require('@databricks/sql-kernel-freebsd-x64') + } + } catch (e) { + loadError = e + } + break + case 'linux': + switch (arch) { + case 'x64': + if (isMusl()) { + localFileExisted = existsSync( + join(__dirname, 'index.linux-x64-musl.node') + ) + try { + if (localFileExisted) { + nativeBinding = require('./index.linux-x64-musl.node') + } else { + nativeBinding = require('@databricks/sql-kernel-linux-x64-musl') + } + } catch (e) { + loadError = e + } + } else { + localFileExisted = existsSync( + join(__dirname, 'index.linux-x64-gnu.node') + ) + try { + if (localFileExisted) { + nativeBinding = require('./index.linux-x64-gnu.node') + } else { + nativeBinding = require('@databricks/sql-kernel-linux-x64-gnu') + } + } catch (e) { + loadError = e + } + } + break + case 'arm64': + if (isMusl()) { + localFileExisted = existsSync( + join(__dirname, 'index.linux-arm64-musl.node') + ) + try { + if (localFileExisted) { + nativeBinding = require('./index.linux-arm64-musl.node') + } else { + nativeBinding = require('@databricks/sql-kernel-linux-arm64-musl') + } + } catch (e) { + loadError = e + } + } else { + localFileExisted = 
existsSync( + join(__dirname, 'index.linux-arm64-gnu.node') + ) + try { + if (localFileExisted) { + nativeBinding = require('./index.linux-arm64-gnu.node') + } else { + nativeBinding = require('@databricks/sql-kernel-linux-arm64-gnu') + } + } catch (e) { + loadError = e + } + } + break + case 'arm': + if (isMusl()) { + localFileExisted = existsSync( + join(__dirname, 'index.linux-arm-musleabihf.node') + ) + try { + if (localFileExisted) { + nativeBinding = require('./index.linux-arm-musleabihf.node') + } else { + nativeBinding = require('@databricks/sql-kernel-linux-arm-musleabihf') + } + } catch (e) { + loadError = e + } + } else { + localFileExisted = existsSync( + join(__dirname, 'index.linux-arm-gnueabihf.node') + ) + try { + if (localFileExisted) { + nativeBinding = require('./index.linux-arm-gnueabihf.node') + } else { + nativeBinding = require('@databricks/sql-kernel-linux-arm-gnueabihf') + } + } catch (e) { + loadError = e + } + } + break + case 'riscv64': + if (isMusl()) { + localFileExisted = existsSync( + join(__dirname, 'index.linux-riscv64-musl.node') + ) + try { + if (localFileExisted) { + nativeBinding = require('./index.linux-riscv64-musl.node') + } else { + nativeBinding = require('@databricks/sql-kernel-linux-riscv64-musl') + } + } catch (e) { + loadError = e + } + } else { + localFileExisted = existsSync( + join(__dirname, 'index.linux-riscv64-gnu.node') + ) + try { + if (localFileExisted) { + nativeBinding = require('./index.linux-riscv64-gnu.node') + } else { + nativeBinding = require('@databricks/sql-kernel-linux-riscv64-gnu') + } + } catch (e) { + loadError = e + } + } + break + case 's390x': + localFileExisted = existsSync( + join(__dirname, 'index.linux-s390x-gnu.node') + ) + try { + if (localFileExisted) { + nativeBinding = require('./index.linux-s390x-gnu.node') + } else { + nativeBinding = require('@databricks/sql-kernel-linux-s390x-gnu') + } + } catch (e) { + loadError = e + } + break + default: + throw new Error(`Unsupported 
architecture on Linux: ${arch}`) + } + break + default: + throw new Error(`Unsupported OS: ${platform}, architecture: ${arch}`) +} + +if (!nativeBinding) { + if (loadError) { + throw loadError + } + throw new Error(`Failed to load native binding`) +} + +const { Connection, openSession, Statement, version } = nativeBinding + +module.exports.Connection = Connection +module.exports.openSession = openSession +module.exports.Statement = Statement +module.exports.version = version diff --git a/package-lock.json b/package-lock.json index d4ac2179..b12e6506 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,6 +11,7 @@ "dependencies": { "apache-arrow": "^13.0.0", "commander": "^9.3.0", + "flatbuffers": "^25.9.23", "node-fetch": "^2.6.12", "node-int64": "^0.4.0", "open": "^8.4.2", @@ -21,6 +22,7 @@ "winston": "^3.8.2" }, "devDependencies": { + "@napi-rs/cli": "2.18.4", "@types/chai": "^4.3.14", "@types/http-proxy": "^1.17.14", "@types/lz4": "^0.6.4", @@ -54,6 +56,7 @@ "node": ">=14.0.0" }, "optionalDependencies": { + "@databricks/sql-kernel-linux-x64-gnu": "0.1.0", "lz4": "^0.6.5" } }, @@ -628,6 +631,9 @@ "kuler": "^2.0.0" } }, + "node_modules/@databricks/sql-kernel-linux-x64-gnu": { + "optional": true + }, "node_modules/@eslint/eslintrc": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-1.3.0.tgz", @@ -833,6 +839,23 @@ "@jridgewell/sourcemap-codec": "^1.4.14" } }, + "node_modules/@napi-rs/cli": { + "version": "2.18.4", + "resolved": "https://npm-proxy.dev.databricks.com/@napi-rs/cli/-/cli-2.18.4.tgz", + "integrity": "sha512-SgJeA4df9DE2iAEpr3M2H0OKl/yjtg1BnRI5/JyowS71tUWhrfSu2LT0V3vlHET+g1hBVlrO60PmEXwUEKp8Mg==", + "dev": true, + "license": "MIT", + "bin": { + "napi": "scripts/index.js" + }, + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + } + }, "node_modules/@nodelib/fs.scandir": { "version": "2.1.5", "resolved": 
"https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", @@ -1394,6 +1417,12 @@ "resolved": "https://registry.npmjs.org/@types/node/-/node-20.3.0.tgz", "integrity": "sha512-cumHmIAf6On83X7yP+LrsEyUOf/YlociZelmpRYaGFydoaPdxdt80MAbu6vWerQT2COCp2nPvHdsbD7tHn/YlQ==" }, + "node_modules/apache-arrow/node_modules/flatbuffers": { + "version": "23.5.26", + "resolved": "https://npm-proxy.dev.databricks.com/flatbuffers/-/flatbuffers-23.5.26.tgz", + "integrity": "sha512-vE+SI9vrJDwi1oETtTIFldC/o9GsVKRM+s6EL0nQgxXlYV1Vc4Tk30hj4xGICftInKQKj1F3up2n8UbIVobISQ==", + "license": "SEE LICENSE IN LICENSE" + }, "node_modules/apache-arrow/node_modules/tslib": { "version": "2.6.2", "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", @@ -2982,9 +3011,10 @@ } }, "node_modules/flatbuffers": { - "version": "23.5.26", - "resolved": "https://registry.npmjs.org/flatbuffers/-/flatbuffers-23.5.26.tgz", - "integrity": "sha512-vE+SI9vrJDwi1oETtTIFldC/o9GsVKRM+s6EL0nQgxXlYV1Vc4Tk30hj4xGICftInKQKj1F3up2n8UbIVobISQ==" + "version": "25.9.23", + "resolved": "https://npm-proxy.dev.databricks.com/flatbuffers/-/flatbuffers-25.9.23.tgz", + "integrity": "sha512-MI1qs7Lo4Syw0EOzUl0xjs2lsoeqFku44KpngfIduHBYvzm8h2+7K8YMQh1JtVVVrUvhLpNwqVi4DERegUJhPQ==", + "license": "Apache-2.0" }, "node_modules/flatted": { "version": "3.2.6", @@ -6854,6 +6884,9 @@ "kuler": "^2.0.0" } }, + "@databricks/sql-kernel-linux-x64-gnu": { + "optional": true + }, "@eslint/eslintrc": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-1.3.0.tgz", @@ -7015,6 +7048,12 @@ "@jridgewell/sourcemap-codec": "^1.4.14" } }, + "@napi-rs/cli": { + "version": "2.18.4", + "resolved": "https://npm-proxy.dev.databricks.com/@napi-rs/cli/-/cli-2.18.4.tgz", + "integrity": "sha512-SgJeA4df9DE2iAEpr3M2H0OKl/yjtg1BnRI5/JyowS71tUWhrfSu2LT0V3vlHET+g1hBVlrO60PmEXwUEKp8Mg==", + "dev": true + }, "@nodelib/fs.scandir": { "version": "2.1.5", "resolved": 
"https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", @@ -7441,6 +7480,11 @@ "resolved": "https://registry.npmjs.org/@types/node/-/node-20.3.0.tgz", "integrity": "sha512-cumHmIAf6On83X7yP+LrsEyUOf/YlociZelmpRYaGFydoaPdxdt80MAbu6vWerQT2COCp2nPvHdsbD7tHn/YlQ==" }, + "flatbuffers": { + "version": "23.5.26", + "resolved": "https://npm-proxy.dev.databricks.com/flatbuffers/-/flatbuffers-23.5.26.tgz", + "integrity": "sha512-vE+SI9vrJDwi1oETtTIFldC/o9GsVKRM+s6EL0nQgxXlYV1Vc4Tk30hj4xGICftInKQKj1F3up2n8UbIVobISQ==" + }, "tslib": { "version": "2.6.2", "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", @@ -8636,9 +8680,9 @@ } }, "flatbuffers": { - "version": "23.5.26", - "resolved": "https://registry.npmjs.org/flatbuffers/-/flatbuffers-23.5.26.tgz", - "integrity": "sha512-vE+SI9vrJDwi1oETtTIFldC/o9GsVKRM+s6EL0nQgxXlYV1Vc4Tk30hj4xGICftInKQKj1F3up2n8UbIVobISQ==" + "version": "25.9.23", + "resolved": "https://npm-proxy.dev.databricks.com/flatbuffers/-/flatbuffers-25.9.23.tgz", + "integrity": "sha512-MI1qs7Lo4Syw0EOzUl0xjs2lsoeqFku44KpngfIduHBYvzm8h2+7K8YMQh1JtVVVrUvhLpNwqVi4DERegUJhPQ==" }, "flatted": { "version": "3.2.6", diff --git a/package.json b/package.json index e430181f..aa5b9888 100644 --- a/package.json +++ b/package.json @@ -17,6 +17,8 @@ "test": "nyc --report-dir=${NYC_REPORT_DIR:-coverage_unit} mocha --config tests/unit/.mocharc.js", "update-version": "node bin/update-version.js && prettier --write ./lib/version.ts", "build": "npm run update-version && tsc --project tsconfig.build.json", + "build:native": "bash -c 'cd ${DATABRICKS_SQL_KERNEL_REPO:-../../databricks-sql-kernel}/napi && npx --no-install @napi-rs/cli build --platform ${BUILD_PROFILE---release} && cp index.* $OLDPWD/native/sea/'", + "prepack": "test -f native/sea/index.js || { echo 'ERROR: native/sea/index.js (napi-rs router) is missing — the published tarball would fail to load SEA. It is committed to git; run `npm run build:native` if you removed it.'
>&2; exit 1; }", "watch": "tsc --project tsconfig.build.json --watch", "type-check": "tsc --noEmit", "prettier": "prettier . --check", @@ -47,6 +49,7 @@ ], "license": "Apache 2.0", "devDependencies": { + "@napi-rs/cli": "2.18.4", "@types/chai": "^4.3.14", "@types/http-proxy": "^1.17.14", "@types/lz4": "^0.6.4", @@ -79,6 +82,7 @@ "dependencies": { "apache-arrow": "^13.0.0", "commander": "^9.3.0", + "flatbuffers": "^25.9.23", "node-fetch": "^2.6.12", "node-int64": "^0.4.0", "open": "^8.4.2", @@ -89,6 +93,7 @@ "winston": "^3.8.2" }, "optionalDependencies": { + "@databricks/sql-kernel-linux-x64-gnu": "0.1.0", "lz4": "^0.6.5" } } diff --git a/tests/e2e/sea/auth-pat-e2e.test.ts b/tests/e2e/sea/auth-pat-e2e.test.ts new file mode 100644 index 00000000..335b60e5 --- /dev/null +++ b/tests/e2e/sea/auth-pat-e2e.test.ts @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Databricks, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { expect } from 'chai'; +import { DBSQLClient } from '../../../lib'; +import { ConnectionOptions } from '../../../lib/contracts/IDBSQLClient'; +import { InternalConnectionOptions } from '../../../lib/contracts/InternalConnectionOptions'; + +/** + * sea-auth M0 end-to-end: + * 1. Construct a DBSQLClient. + * 2. `connect({ useSEA: true, token })` against pecotesting. + * 3. `openSession()` — round-trips through the napi binding. + * 4. Close the session, then the client. 
+ * + * No query is executed here — execution is the responsibility of the + * sea-execution feature's own e2e. This test exists solely to confirm + * the PAT round-trips end-to-end and the napi binding's `openSession` + * surface is reachable from `DBSQLClient`. + * + * Required env (exported by `~/.zshrc` on the developer machine): + * - DATABRICKS_PECOTESTING_SERVER_HOSTNAME + * - DATABRICKS_PECOTESTING_HTTP_PATH + * - DATABRICKS_PECOTESTING_TOKEN_PERSONAL (preferred — personal PAT) + * - DATABRICKS_PECOTESTING_TOKEN (fallback — shared PAT) + * + * If any of the three required env vars is missing, the suite is skipped + * so CI machines without secrets don't fail-flap. + */ +describe('sea-auth e2e — PAT through DBSQLClient ↔ SeaBackend ↔ napi binding', function suite() { + const host = process.env.DATABRICKS_PECOTESTING_SERVER_HOSTNAME; + const path = process.env.DATABRICKS_PECOTESTING_HTTP_PATH; + const token = + process.env.DATABRICKS_PECOTESTING_TOKEN_PERSONAL || process.env.DATABRICKS_PECOTESTING_TOKEN; + + this.timeout(120_000); + + before(function gate() { + if (!host || !path || !token) { + // eslint-disable-next-line no-invalid-this + this.skip(); + } + }); + + it('connects, opens a session, closes the session, closes the client', async () => { + const client = new DBSQLClient(); + + const connected = await client.connect({ + host: host as string, + path: path as string, + token: token as string, + // `useSEA` is an internal opt-in (InternalConnectionOptions), not a + // public ConnectionOptions field — cast exactly as DBSQLClient.connect + // does internally so the literal passes excess-property checking. 
+ useSEA: true, + } as ConnectionOptions & InternalConnectionOptions); + expect(connected).to.equal(client); + + const session = await client.openSession(); + expect(session).to.exist; + expect(session.id).to.be.a('string'); + expect(session.id.length).to.be.greaterThan(0); + + const status = await session.close(); + expect(status.isSuccess).to.equal(true); + + await client.close(); + }); +}); diff --git a/tests/e2e/sea/e2e-smoke.test.ts b/tests/e2e/sea/e2e-smoke.test.ts new file mode 100644 index 00000000..e96efe34 --- /dev/null +++ b/tests/e2e/sea/e2e-smoke.test.ts @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Databricks, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { expect } from 'chai'; +import { tableFromIPC } from 'apache-arrow'; +import { tryGetSeaNative, SeaConnection, SeaStatement } from '../../../lib/sea/SeaNativeLoader'; +import config from '../utils/config'; + +// End-to-end smoke test against a live warehouse: +// 1. Open a kernel `Session` over PAT. +// 2. Execute `SELECT 1`, decode the IPC payload, assert the value is 1. +// 3. Exercise lifecycle negative paths (drain-past-null, double-close). +// 4. Close the statement, then the connection. +// +// Credentials come from the shared e2e config (tests/e2e/utils/config.ts: +// E2E_HOST / E2E_PATH / E2E_ACCESS_TOKEN) — the single credential source +// used by every other e2e test, so `npm run e2e` has one consistent +// skip/fail contract rather than two. 
+ +describe('SEA native binding — end-to-end smoke', function smoke() { + // Live-warehouse tests can take >2s through warm-up. + this.timeout(60_000); + + const binding = tryGetSeaNative(); + if (binding === undefined) { + // Optional dependency absent on this platform — never reach the live path. + it.skip('SEA native binding not available on this platform'); + return; + } + + const { host: hostName, path: httpPath, token } = config; + + it('opens a session, runs SELECT 1, decodes the IPC payload to 1', async () => { + const connection: SeaConnection = await binding.openSession({ hostName, httpPath, token }); + expect(connection).to.be.an('object'); + + try { + const statement: SeaStatement = await connection.executeStatement('SELECT 1'); + try { + expect(statement).to.be.an('object'); + + const batch = await statement.fetchNextBatch(); + expect(batch).to.not.equal(null); + expect(batch!.ipcBytes).to.be.instanceOf(Buffer); + expect(batch!.ipcBytes.length).to.be.greaterThan(0); + + // Decode the IPC payload and verify the value, not just the shape. + const table = tableFromIPC(batch!.ipcBytes); + expect(table.numRows).to.equal(1); + expect(Number(table.getChildAt(0)!.get(0))).to.equal(1); + + // Drain-past-null: subsequent fetch returns null. + const after = await statement.fetchNextBatch(); + expect(after).to.equal(null); + + // Drain-past-drained: another fetch still returns null (idempotent). + const afterAgain = await statement.fetchNextBatch(); + expect(afterAgain).to.equal(null); + } finally { + await statement.close(); + } + } finally { + await connection.close(); + } + }); + + it('returns a schema IPC payload before any batch is fetched', async () => { + const connection: SeaConnection = await binding.openSession({ hostName, httpPath, token }); + try { + const statement = await connection.executeStatement('SELECT 1'); + try { + // schema() is synchronous on the binding (cached at construction).
+ const schema = statement.schema(); + expect(schema.ipcBytes).to.be.instanceOf(Buffer); + expect(schema.ipcBytes.length).to.be.greaterThan(0); + } finally { + await statement.close(); + } + } finally { + await connection.close(); + } + }); +}); diff --git a/tests/e2e/sea/execution-e2e.test.ts b/tests/e2e/sea/execution-e2e.test.ts new file mode 100644 index 00000000..28dd1035 --- /dev/null +++ b/tests/e2e/sea/execution-e2e.test.ts @@ -0,0 +1,124 @@ +// Copyright (c) 2026 Databricks, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { expect } from 'chai'; +import { DBSQLClient } from '../../../lib'; +import { ConnectionOptions } from '../../../lib/contracts/IDBSQLClient'; +import { InternalConnectionOptions } from '../../../lib/contracts/InternalConnectionOptions'; + +/** + * sea-execution end-to-end test. + * + * Walks the full `DBSQLClient` → `SeaBackend` → napi binding → kernel + * pipeline against a live warehouse over PAT: + * + * 1. `connect({ useSEA: true })` selects the SEA backend. + * 2. `openSession({ initialCatalog: 'main' })` opens a kernel session + * and threads `initialCatalog` through to the napi `ExecuteOptions`. + * 3. `executeStatement('SELECT 1')` returns an `IOperation` backed by + * `SeaOperationBackend` (wraps a napi `Statement`). + * 4. `operation.id` is observable (via `IOperation.id` on the public + * surface). + * 5. `operation.cancel()` and `operation.close()` succeed without + * throwing. + * 6. 
/**
 * End-to-end lifecycle test for `DBSQLClient` connected with `useSEA: true`.
 *
 * The suite is skipped when any of the three `DATABRICKS_PECOTESTING_*`
 * environment variables is missing. Every test releases the operation,
 * session and client in `finally` blocks so a failed assertion does not leave
 * live-warehouse resources open.
 */
describe('SEA execution end-to-end', function e2eSuite() {
  const hostName = process.env.DATABRICKS_PECOTESTING_SERVER_HOSTNAME;
  const httpPath = process.env.DATABRICKS_PECOTESTING_HTTP_PATH;
  const token = process.env.DATABRICKS_PECOTESTING_TOKEN_PERSONAL;

  // Live-warehouse round-trips can take a few seconds through warm-up.
  this.timeout(60_000);

  before(function gate() {
    if (!hostName || !httpPath || !token) {
      // eslint-disable-next-line no-invalid-this
      this.skip();
    }
  });

  /**
   * Creates a `DBSQLClient` and connects it with `useSEA: true`.
   * Throws when the credentials are missing; the `before` gate skips the
   * suite in that case, so the guard exists to narrow the types without
   * `as string` assertions.
   */
  async function connectSeaClient(): Promise<DBSQLClient> {
    if (!hostName || !httpPath || !token) {
      throw new Error('pecotesting credentials are not configured');
    }
    const client = new DBSQLClient();
    await client.connect({
      host: hostName,
      path: httpPath,
      token,
      useSEA: true,
    } as ConnectionOptions & InternalConnectionOptions);
    return client;
  }

  it('opens a session, executes SELECT 1, and closes cleanly via SEA backend', async () => {
    const client = await connectSeaClient();
    try {
      const session = await client.openSession({
        initialCatalog: 'main',
      });
      try {
        expect(session).to.be.an('object');
        expect(session.id).to.be.a('string').and.have.length.greaterThan(0);

        const operation = await session.executeStatement('SELECT 1', {});
        try {
          expect(operation).to.be.an('object');
          // `IOperation.id` is the public-API observable identity for the
          // returned operation. SeaOperationBackend generates a UUIDv4 for
          // M0 until the napi binding surfaces the server statement id.
          expect(operation.id).to.be.a('string').and.have.length.greaterThan(0);
        } finally {
          // Only close() is exercised here; this test neither fetches
          // results nor calls cancel().
          await operation.close();
        }
      } finally {
        await session.close();
      }
    } finally {
      await client.close();
    }
  });

  it('passes sessionConfig (Spark conf) through openSession.configuration', async () => {
    const client = await connectSeaClient();
    try {
      // Sanity-check that supplying session-level Spark conf does not
      // break openSession. The SEA wire applies these as `parameters` on
      // every executeStatement; we don't observe them in the response
      // for M0, but the absence of an error proves the napi binding
      // accepts and forwards the map.
      const session = await client.openSession({
        initialCatalog: 'main',
        configuration: {
          'spark.sql.session.timeZone': 'UTC',
        },
      });
      try {
        const operation = await session.executeStatement('SELECT 1', {});
        await operation.close();
      } finally {
        await session.close();
      }
    } finally {
      await client.close();
    }
  });
});
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * End-to-end tests for the SEA operation lifecycle (cancel / close / + * finished) wired through `SeaOperationBackend`. + * + * The impl-execution feature has not yet wired + * `DBSQLClient.connect({ useSEA: true })` to dispatch into + * `SeaBackend`, so this test drives the lifecycle by: + * 1. Calling the napi `openSession(...)` free function directly to + * get a kernel `Connection`. + * 2. Calling `connection.executeStatement(...)` to get a napi + * `Statement` handle. + * 3. Wrapping that handle in a `SeaOperationBackend` and exercising + * its `cancel()` / `close()` / `waitUntilReady()` methods. + * + * This mirrors how the eventual `SeaSessionBackend.executeStatement` + * call path will assemble the operation — we just inline the kernel + * call here since the session backend is being built in parallel. + * + * Path note: the original task spec referenced + * `tests/integration/sea/operation-lifecycle-e2e.test.ts`. The + * existing project structure uses `tests/e2e/**` (with its own + * `.mocharc.js`), so this file lives under `tests/e2e/sea/` to be + * picked up by `npm run e2e` automatically. 
// Minimal binding type shapes (mirrors the napi `index.d.ts`).
interface NativeBinding {
  openSession(opts: { hostName: string; httpPath: string; token: string }): Promise<NativeConnection>;
}

interface NativeConnection {
  executeStatement(
    sql: string,
    options: {
      initialCatalog?: string;
      initialSchema?: string;
      // NOTE(review): value type reconstructed as string — confirm against the napi `index.d.ts`.
      sessionConfig?: Record<string, string>;
    },
  ): Promise<NativeStatement>;
  close(): Promise<void>;
}

interface NativeStatement {
  fetchNextBatch(): Promise<{ ipcBytes: Buffer } | null>;
  // schema() is synchronous on the merged-kernel binding.
  schema(): { ipcBytes: Buffer };
  cancel(): Promise<void>;
  close(): Promise<void>;
}

/** Logger that discards every message, keeping e2e output quiet. */
class NoopLogger implements IDBSQLLogger {
  log(_level: LogLevel, _message: string): void {
    // no-op for e2e runs
  }
}

/**
 * Builds a partial `IClientContext` for the lifecycle tests: only
 * `getLogger()` is functional; every other member throws when called so an
 * unexpected dependency surfaces immediately.
 */
function makeContext(): IClientContext {
  const logger = new NoopLogger();
  const notUsed = () => {
    throw new Error('IClientContext member not expected in lifecycle e2e');
  };
  return {
    getConfig: notUsed,
    getLogger: () => logger,
    getConnectionProvider: notUsed,
    getClient: notUsed,
    getDriver: notUsed,
  } as unknown as IClientContext;
}

// Large enough that the server has not produced the full result by the time
// the test cancels.
const LONG_RUNNING_QUERY = 'SELECT * FROM range(0, 100000000)';

// Latency budget (ms) asserted for `cancel()`.
const CANCEL_BUDGET_MS = 200;

describe('SEA operation lifecycle — end-to-end', function suite() {
  // Live-warehouse tests can take >2s through warm-up; bump the
  // mocha default (2000ms) generously. The base `tests/e2e/.mocharc.js`
  // already sets 300s but we keep this explicit so the file is robust
  // when run via `npx mocha …` outside the e2e harness.
  this.timeout(120_000);

  const hostName = process.env.DATABRICKS_PECOTESTING_SERVER_HOSTNAME || process.env.E2E_HOST;
  const httpPath = process.env.DATABRICKS_PECOTESTING_HTTP_PATH || process.env.E2E_PATH;
  const token = process.env.DATABRICKS_PECOTESTING_TOKEN_PERSONAL || process.env.E2E_ACCESS_TOKEN;

  before(function gate() {
    if (!hostName || !httpPath || !token) {
      // eslint-disable-next-line no-invalid-this
      this.skip();
    }
  });

  /** Opens a kernel connection through the napi `openSession` free function. */
  function openConnection(): Promise<NativeConnection> {
    if (!hostName || !httpPath || !token) {
      // Unreachable after the `before` gate; narrows the types without casts.
      throw new Error('SEA e2e credentials are not configured');
    }
    const binding = getSeaNative() as unknown as NativeBinding;
    return binding.openSession({ hostName, httpPath, token });
  }

  /** Wraps a napi statement handle in the backend under test. */
  function wrapStatement(statement: NativeStatement): SeaOperationBackend {
    return new SeaOperationBackend({ statement, context: makeContext() });
  }

  /**
   * Best-effort close used after `cancel()`: cancelled statements may surface
   * a close error from the server, which is irrelevant to the assertion.
   */
  async function closeIgnoringErrors(statement: NativeStatement): Promise<void> {
    try {
      await statement.close();
    } catch {
      // ignore cleanup error after cancel
    }
  }

  it('cancel() succeeds against a live SEA statement and is fast', async () => {
    const connection = await openConnection();
    try {
      const statement = await connection.executeStatement(LONG_RUNNING_QUERY, {});
      try {
        expect(statement).to.be.an('object');

        const op = wrapStatement(statement);

        const t0 = Date.now();
        const status = await op.cancel();
        const elapsed = Date.now() - t0;

        // Cancel must complete within the latency budget.
        expect(elapsed).to.be.lessThan(
          CANCEL_BUDGET_MS,
          `cancel latency ${elapsed}ms exceeds ${CANCEL_BUDGET_MS}ms budget`,
        );
        expect(status.isSuccess).to.equal(true);
      } finally {
        // Bypass `op.close()` here because we want to verify cancel
        // alone — close is exercised in a later test.
        await closeIgnoringErrors(statement);
      }
    } finally {
      await connection.close();
    }
  });

  it('cancel mid-fetch — subsequent fetchChunk throws OperationStateError', async () => {
    const connection = await openConnection();
    try {
      const statement = await connection.executeStatement(LONG_RUNNING_QUERY, {});
      try {
        const op = wrapStatement(statement);

        const t0 = Date.now();
        await op.cancel();
        const elapsed = Date.now() - t0;
        expect(elapsed).to.be.lessThan(
          CANCEL_BUDGET_MS,
          `cancel latency ${elapsed}ms exceeds ${CANCEL_BUDGET_MS}ms budget`,
        );

        // After cancel, fetchChunk must throw the cancellation error
        // (regardless of whether the underlying fetch implementation
        // is wired — the lifecycle gate runs first).
        let thrown: unknown;
        try {
          await op.fetchChunk({ limit: 100 });
        } catch (err) {
          thrown = err;
        }
        expect(thrown).to.be.instanceOf(OperationStateError);
        expect((thrown as OperationStateError).errorCode).to.equal(OperationStateErrorCode.Canceled);
      } finally {
        await closeIgnoringErrors(statement);
      }
    } finally {
      await connection.close();
    }
  });

  it('close() succeeds against a SEA statement and is idempotent', async () => {
    const connection = await openConnection();
    try {
      const statement = await connection.executeStatement('SELECT 1', {});
      const op = wrapStatement(statement);

      const status1 = await op.close();
      expect(status1.isSuccess).to.equal(true);

      // Idempotent — a second close is a no-op on the JS side and
      // does not hit the binding (which would already have taken the
      // inner handle).
      const status2 = await op.close();
      expect(status2.isSuccess).to.equal(true);
    } finally {
      await connection.close();
    }
  });

  it('finished() resolves immediately and fires the progress callback', async () => {
    const connection = await openConnection();
    try {
      const statement = await connection.executeStatement('SELECT 1', {});
      try {
        const op = wrapStatement(statement);

        let ticks = 0;
        const t0 = Date.now();
        await op.waitUntilReady({
          callback: () => {
            ticks += 1;
          },
        });
        const elapsed = Date.now() - t0;

        // M0 finished() is a no-op — must resolve in <50ms.
        expect(elapsed).to.be.lessThan(50);
        // Progress callback fires exactly once.
        expect(ticks).to.equal(1);
      } finally {
        // Nested so the connection is still closed if this close throws.
        await statement.close();
      }
    } finally {
      await connection.close();
    }
  });
});
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* eslint-disable no-console */ + +import { expect } from 'chai'; +import { DBSQLClient } from '../../../lib'; +import { ConnectionOptions } from '../../../lib/contracts/IDBSQLClient'; +import { InternalConnectionOptions } from '../../../lib/contracts/InternalConnectionOptions'; + +// Integration suite: connect through both backends, run a probe query, +// and assert byte-identical row output (the M0 parity gate). Requires +// the developer's shell to export the pecotesting secrets: +// - DATABRICKS_PECOTESTING_SERVER_HOSTNAME +// - DATABRICKS_PECOTESTING_HTTP_PATH +// - DATABRICKS_PECOTESTING_TOKEN_PERSONAL +// If any is missing, the suite skips so CI / sandboxes without +// credentials don't flap. 
// One-row probe covering an integer, a string, a boolean, a decimal and a date column.
const PROBE_QUERY =
  "SELECT 1 AS x, 'hello' AS s, true AS b, CAST(1.5 AS DECIMAL(10,2)) AS d, DATE '2026-01-01' AS dt";

/** Connection details read from the pecotesting environment variables. */
interface PecoSecrets {
  host: string;
  path: string;
  token: string;
}

/**
 * Reads the three `DATABRICKS_PECOTESTING_*` environment variables.
 *
 * @returns the secrets, or `null` when any variable is unset or empty.
 */
function readSecrets(): PecoSecrets | null {
  const host = process.env.DATABRICKS_PECOTESTING_SERVER_HOSTNAME;
  const path = process.env.DATABRICKS_PECOTESTING_HTTP_PATH;
  const token = process.env.DATABRICKS_PECOTESTING_TOKEN_PERSONAL;
  if (!host || !path || !token) return null;
  return { host, path, token };
}

/**
 * Connects a fresh `DBSQLClient`, runs `PROBE_QUERY` and returns every row
 * from `fetchAll()`. The operation, session and client are closed in nested
 * `finally` blocks whether or not the fetch succeeds.
 *
 * @param useSEA forwarded to `connect()` as the `useSEA` option.
 * @param secrets host, path and token used for the connection.
 */
async function fetchProbeRows(useSEA: boolean, secrets: PecoSecrets): Promise<Array<Record<string, unknown>>> {
  const client = new DBSQLClient();
  await client.connect({
    host: secrets.host,
    path: secrets.path,
    token: secrets.token,
    useSEA,
  } as ConnectionOptions & InternalConnectionOptions);
  try {
    const session = await client.openSession();
    try {
      const operation = await session.executeStatement(PROBE_QUERY);
      try {
        return (await operation.fetchAll()) as Array<Record<string, unknown>>;
      } finally {
        await operation.close();
      }
    } finally {
      await session.close();
    }
  } finally {
    await client.close();
  }
}
/**
 * Normalises a fetched value into a form that `deep.equal` can compare
 * value-for-value across backends.
 *
 * Buffers, Dates and BigInts are replaced by tagged strings (`__buffer__:`,
 * `__date__:`, `__bigint__:`); arrays and plain objects are normalised
 * recursively; every other value is returned unchanged. Because the tags
 * differ per type, one backend returning a string where the other returns a
 * Date produces a visible mismatch.
 */
function canonical(value: unknown): unknown {
  if (value === null || value === undefined) return value;
  if (Buffer.isBuffer(value)) return `__buffer__:${value.toString('hex')}`;
  if (value instanceof Date) {
    // toISOString() throws a RangeError on an Invalid Date; emit a stable
    // marker instead so the comparison reports a mismatch, not a crash.
    return Number.isNaN(value.getTime()) ? '__date__:invalid' : `__date__:${value.toISOString()}`;
  }
  if (typeof value === 'bigint') return `__bigint__:${value.toString()}`;
  if (Array.isArray(value)) return value.map((item) => canonical(item));
  if (typeof value === 'object') {
    const out: Record<string, unknown> = {};
    for (const [k, v] of Object.entries(value as Record<string, unknown>)) {
      out[k] = canonical(v);
    }
    return out;
  }
  return value;
}

describe('SEA results end-to-end (pecotesting parity gate)', function suite() {
  this.timeout(120_000);

  const secrets = readSecrets();

  before(function gate() {
    if (!secrets) {
      // eslint-disable-next-line no-invalid-this
      this.skip();
    }
  });

  /** Returns the secrets; unreachable-throws after the `before` gate, replacing `as PecoSecrets`. */
  function requireSecrets(): PecoSecrets {
    if (secrets === null) {
      throw new Error('pecotesting secrets are not configured');
    }
    return secrets;
  }

  it('SEA backend returns one row with expected columns', async () => {
    const rows = await fetchProbeRows(true, requireSecrets());
    expect(rows.length).to.equal(1);
    const row = rows[0];
    expect(row).to.have.property('x');
    expect(row).to.have.property('s');
    expect(row).to.have.property('b');
    expect(row).to.have.property('d');
    expect(row).to.have.property('dt');
    expect(Number(row.x)).to.equal(1);
    expect(row.s).to.equal('hello');
    expect(row.b).to.equal(true);
    expect(Number(row.d)).to.equal(1.5);
  });

  it('Thrift and SEA produce byte-identical rows for the probe query (parity gate)', async () => {
    const seaRows = await fetchProbeRows(true, requireSecrets());
    const thriftRows = await fetchProbeRows(false, requireSecrets());
    expect(seaRows.map((row) => canonical(row))).to.deep.equal(thriftRows.map((row) => canonical(row)));
  });
});
interface CreateOperationForTestArgs {
  handle: TOperationHandle;
  directResults?: TSparkDirectResults;
  context: IClientContext;
}

/**
 * Test helper: builds a `DBSQLOperation` from a raw Thrift operation handle.
 *
 * The handle, optional direct results and context are wrapped in a
 * `ThriftOperationBackend`, which is then handed to the `{ backend, context }`
 * constructor of `DBSQLOperation`. Tests keep the older
 * `{ handle, directResults, context }` call shape without the facade
 * importing a concrete backend.
 */
export function createOperationForTest(args: CreateOperationForTestArgs): DBSQLOperation {
  const { handle, directResults, context } = args;
  return new DBSQLOperation({
    backend: new ThriftOperationBackend({ handle, directResults, context }),
    context,
  });
}
interface CreateSessionForTestArgs {
  handle: TSessionHandle;
  context: IClientContext;
  serverProtocolVersion?: TProtocolVersion;
}

/**
 * Test helper: builds a `DBSQLSession` from a raw Thrift session handle.
 *
 * The handle, context and optional server protocol version are wrapped in a
 * `ThriftSessionBackend`, which is then handed to the `{ backend, context }`
 * constructor of `DBSQLSession`. Tests keep the older `{ handle, ... }` call
 * shape without the facade importing a concrete backend.
 */
export function createSessionForTest(args: CreateSessionForTestArgs): DBSQLSession {
  const { handle, context, serverProtocolVersion } = args;
  const sessionBackend = new ThriftSessionBackend({ handle, context, serverProtocolVersion });
  return new DBSQLSession({ backend: sessionBackend, context });
}
/**
 * Test helper shared by the openSession tests.
 *
 * Creates a `DBSQLClient` whose `getClient()` is stubbed to resolve to the
 * given Thrift client stub, and pre-seeds the client's private `backend`
 * field with a `ThriftBackend` bound to that client.
 *
 * @param thriftClient stub returned by `getClient()`; a fresh one by default.
 * @returns the client together with the stub, so tests can inspect both.
 */
function makeStubbedClient(thriftClient: ThriftClientStub = new ThriftClientStub()): {
  client: DBSQLClient;
  thriftClient: ThriftClientStub;
} {
  const stubbedClient = new DBSQLClient();
  const getClientStub = sinon.stub(stubbedClient, 'getClient');
  getClientStub.returns(Promise.resolve(thriftClient));

  const backend = new ThriftBackend({ context: stubbedClient, onConnectionEvent: () => {} });
  stubbedClient['backend'] = backend;

  return { client: stubbedClient, thriftClient };
}
client.openSession(); expect(session).instanceOf(DBSQLSession); - expect((session as DBSQLSession)['serverProtocolVersion']).to.equal( + expect(((session as DBSQLSession)['backend'] as any)['serverProtocolVersion']).to.equal( TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, ); } @@ -179,16 +188,14 @@ describe('DBSQLClient.openSession', () => { const session = await client.openSession(); expect(session).instanceOf(DBSQLSession); - expect((session as DBSQLSession)['serverProtocolVersion']).to.equal( + expect(((session as DBSQLSession)['backend'] as any)['serverProtocolVersion']).to.equal( TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, ); } }); it('should pass session configuration to OpenSessionReq', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); const configuration = { QUERY_TAGS: 'team:engineering', ansi_mode: 'true' }; await client.openSession({ configuration }); @@ -196,9 +203,7 @@ describe('DBSQLClient.openSession', () => { }); it('should affect session behavior based on protocol version', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); // With protocol version V6 - should support async metadata operations { @@ -360,6 +365,7 @@ describe('DBSQLClient.close', () => { client['client'] = thriftClient; client['connectionProvider'] = new ConnectionProviderStub(); client['authProvider'] = new AuthProviderStub(); + client['backend'] = new ThriftBackend({ context: client, onConnectionEvent: () => {} }); const session = await client.openSession(); if (!(session instanceof DBSQLSession)) { @@ -583,9 +589,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should inject session parameter when 
enableMetricViewMetadata is true', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.connect({ ...connectOptions, enableMetricViewMetadata: true }); await client.openSession(); @@ -597,9 +601,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should not inject session parameter when enableMetricViewMetadata is false', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.connect({ ...connectOptions, enableMetricViewMetadata: false }); await client.openSession(); @@ -610,9 +612,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should not inject session parameter when enableMetricViewMetadata is not set', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.connect(connectOptions); await client.openSession(); @@ -623,9 +623,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should preserve user-provided session configuration', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.connect({ ...connectOptions, enableMetricViewMetadata: true }); const userConfig = { QUERY_TAGS: 'team:engineering', ansi_mode: 'true' }; @@ -638,9 +636,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should serialize queryTags dict and set in session configuration', async () => 
{ - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.openSession({ queryTags: { team: 'data-eng', project: 'etl' }, @@ -652,9 +648,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should let queryTags take precedence over configuration.QUERY_TAGS', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.openSession({ queryTags: { team: 'new-team' }, @@ -668,9 +662,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should remove QUERY_TAGS from configuration when queryTags is empty', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.openSession({ queryTags: {}, diff --git a/tests/unit/DBSQLOperation.test.ts b/tests/unit/DBSQLOperation.test.ts index b5f142ba..1e670c46 100644 --- a/tests/unit/DBSQLOperation.test.ts +++ b/tests/unit/DBSQLOperation.test.ts @@ -21,6 +21,7 @@ import CloudFetchResultHandler from '../../lib/result/CloudFetchResultHandler'; import ResultSlicer from '../../lib/result/ResultSlicer'; import ClientContextStub from './.stubs/ClientContextStub'; +import { createOperationForTest } from './.stubs/createOperationForTest'; import { Type } from 'apache-arrow'; function operationHandleStub(overrides: Partial): TOperationHandle { @@ -47,15 +48,15 @@ describe('DBSQLOperation', () => { describe('status', () => { it('should pick up state from operation handle', async () => { const context = new ClientContextStub(); - const operation = new DBSQLOperation({ handle: 
operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); - expect(operation['state']).to.equal(TOperationState.INITIALIZED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.true; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.INITIALIZED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.true; }); it('should pick up state from directResults', async () => { const context = new ClientContextStub(); - const operation = new DBSQLOperation({ + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context, directResults: { @@ -67,8 +68,8 @@ describe('DBSQLOperation', () => { }, }); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.true; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.true; }); it('should fetch status and update internal state', async () => { @@ -77,17 +78,17 @@ describe('DBSQLOperation', () => { driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; driver.getOperationStatusResp.hasResultSet = true; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: false }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: false }), context }); - expect(operation['state']).to.equal(TOperationState.INITIALIZED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.false; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.INITIALIZED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.false; const status = await operation.status(); 
expect(driver.getOperationStatus.called).to.be.true; expect(status.operationState).to.equal(TOperationState.FINISHED_STATE); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.true; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.true; }); it('should request progress', async () => { @@ -95,7 +96,7 @@ describe('DBSQLOperation', () => { const driver = sinon.spy(context.driver); driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: false }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: false }), context }); await operation.status(true); expect(driver.getOperationStatus.called).to.be.true; @@ -108,10 +109,10 @@ describe('DBSQLOperation', () => { const driver = sinon.spy(context.driver); driver.getOperationStatusResp.hasResultSet = true; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: false }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: false }), context }); - expect(operation['state']).to.equal(TOperationState.INITIALIZED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.false; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.INITIALIZED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.false; // First call - should fetch data and cache driver.getOperationStatusResp = { @@ -122,8 +123,8 @@ describe('DBSQLOperation', () => { expect(driver.getOperationStatus.callCount).to.equal(1); expect(status1.operationState).to.equal(TOperationState.FINISHED_STATE); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); - 
expect(operation['operationHandle'].hasResultSet).to.be.true; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.true; // Second call - should return cached data driver.getOperationStatusResp = { @@ -134,8 +135,8 @@ describe('DBSQLOperation', () => { expect(driver.getOperationStatus.callCount).to.equal(1); expect(status2.operationState).to.equal(TOperationState.FINISHED_STATE); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.true; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.true; }); it('should fetch status if directResults status is not finished', async () => { @@ -144,7 +145,7 @@ describe('DBSQLOperation', () => { driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; driver.getOperationStatusResp.hasResultSet = true; - const operation = new DBSQLOperation({ + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: false }), context, directResults: { @@ -156,15 +157,15 @@ describe('DBSQLOperation', () => { }, }); - expect(operation['state']).to.equal(TOperationState.RUNNING_STATE); // from directResults - expect(operation['operationHandle'].hasResultSet).to.be.false; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.RUNNING_STATE); // from directResults + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.false; const status = await operation.status(false); expect(driver.getOperationStatus.called).to.be.true; expect(status.operationState).to.equal(TOperationState.FINISHED_STATE); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.true; + expect((operation['backend'] 
as any)['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.true; }); it('should not fetch status if directResults status is finished', async () => { @@ -173,7 +174,7 @@ describe('DBSQLOperation', () => { driver.getOperationStatusResp.operationState = TOperationState.RUNNING_STATE; driver.getOperationStatusResp.hasResultSet = true; - const operation = new DBSQLOperation({ + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: false }), context, directResults: { @@ -185,21 +186,21 @@ describe('DBSQLOperation', () => { }, }); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); // from directResults - expect(operation['operationHandle'].hasResultSet).to.be.false; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); // from directResults + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.false; const status = await operation.status(false); expect(driver.getOperationStatus.called).to.be.false; expect(status.operationState).to.equal(TOperationState.FINISHED_STATE); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.false; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.false; }); it('should throw an error in case of a status error', async () => { const context = new ClientContextStub(); context.driver.getOperationStatusResp.status.statusCode = TStatusCode.ERROR_STATUS; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); try { await operation.status(false); @@ -217,7 +218,7 @@ describe('DBSQLOperation', () => { it('should 
cancel operation and update state', async () => { const context = new ClientContextStub(); const driver = sinon.spy(context.driver); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(operation['cancelled']).to.be.false; expect(operation['closed']).to.be.false; @@ -232,7 +233,7 @@ describe('DBSQLOperation', () => { it('should return immediately if already cancelled', async () => { const context = new ClientContextStub(); const driver = sinon.spy(context.driver); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(operation['cancelled']).to.be.false; expect(operation['closed']).to.be.false; @@ -251,7 +252,7 @@ describe('DBSQLOperation', () => { it('should return immediately if already closed', async () => { const context = new ClientContextStub(); const driver = sinon.spy(context.driver); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(operation['cancelled']).to.be.false; expect(operation['closed']).to.be.false; @@ -270,7 +271,7 @@ describe('DBSQLOperation', () => { it('should throw an error in case of a status error and keep state', async () => { const context = new ClientContextStub(); context.driver.cancelOperationResp.status.statusCode = TStatusCode.ERROR_STATUS; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(operation['cancelled']).to.be.false; 
expect(operation['closed']).to.be.false; @@ -290,7 +291,7 @@ describe('DBSQLOperation', () => { it('should reject all methods once cancelled', async () => { const context = new ClientContextStub(); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); await operation.cancel(); expect(operation['cancelled']).to.be.true; @@ -307,7 +308,7 @@ describe('DBSQLOperation', () => { it('should close operation and update state', async () => { const context = new ClientContextStub(); const driver = sinon.spy(context.driver); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(operation['cancelled']).to.be.false; expect(operation['closed']).to.be.false; @@ -322,7 +323,7 @@ describe('DBSQLOperation', () => { it('should return immediately if already closed', async () => { const context = new ClientContextStub(); const driver = sinon.spy(context.driver); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(operation['cancelled']).to.be.false; expect(operation['closed']).to.be.false; @@ -341,7 +342,7 @@ describe('DBSQLOperation', () => { it('should return immediately if already cancelled', async () => { const context = new ClientContextStub(); const driver = sinon.spy(context.driver); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(operation['cancelled']).to.be.false; expect(operation['closed']).to.be.false; @@ 
-361,7 +362,7 @@ describe('DBSQLOperation', () => { const context = new ClientContextStub(); const driver = sinon.spy(context.driver); - const operation = new DBSQLOperation({ + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context, directResults: { @@ -385,7 +386,7 @@ describe('DBSQLOperation', () => { it('should throw an error in case of a status error and keep state', async () => { const context = new ClientContextStub(); context.driver.closeOperationResp.status.statusCode = TStatusCode.ERROR_STATUS; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(operation['cancelled']).to.be.false; expect(operation['closed']).to.be.false; @@ -405,7 +406,7 @@ describe('DBSQLOperation', () => { it('should reject all methods once closed', async () => { const context = new ClientContextStub(); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); await operation.close(); expect(operation['closed']).to.be.true; @@ -437,14 +438,14 @@ describe('DBSQLOperation', () => { return getOperationStatusStub.wrappedMethod.apply(context.driver, args); }); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); - expect(operation['state']).to.equal(TOperationState.INITIALIZED_STATE); + expect((operation['backend'] as any)['state']).to.equal(TOperationState.INITIALIZED_STATE); await operation.finished(); expect(getOperationStatusStub.callCount).to.be.equal(attemptsUntilFinished); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); + 
expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); }); }, ); @@ -463,7 +464,7 @@ describe('DBSQLOperation', () => { return getOperationStatusStub.wrappedMethod.apply(context.driver, args); }); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); await operation.finished({ progress: true }); expect(getOperationStatusStub.called).to.be.true; @@ -487,7 +488,7 @@ describe('DBSQLOperation', () => { return getOperationStatusStub.wrappedMethod.apply(context.driver, args); }); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); const callback = sinon.stub(); @@ -503,7 +504,7 @@ describe('DBSQLOperation', () => { driver.getOperationStatusResp.status.statusCode = TStatusCode.SUCCESS_STATUS; driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; - const operation = new DBSQLOperation({ + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context, directResults: { @@ -526,7 +527,7 @@ describe('DBSQLOperation', () => { context.driver.getOperationStatusResp.status.statusCode = TStatusCode.ERROR_STATUS; context.driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); try { await operation.finished(); @@ -551,7 +552,7 @@ describe('DBSQLOperation', () => { context.driver.getOperationStatusResp.status.statusCode = TStatusCode.SUCCESS_STATUS; context.driver.getOperationStatusResp.operationState = operationState; - const 
operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); try { await operation.finished(); @@ -573,7 +574,7 @@ describe('DBSQLOperation', () => { context.driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; context.driver.getOperationStatusResp.hasResultSet = false; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: false }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: false }), context }); const schema = await operation.getSchema(); @@ -597,13 +598,13 @@ describe('DBSQLOperation', () => { context.driver.getResultSetMetadataResp.schema = { columns: [] }; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); const schema = await operation.getSchema(); expect(getOperationStatusStub.called).to.be.true; expect(schema).to.deep.equal(context.driver.getResultSetMetadataResp.schema); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); }); it('should request progress', async () => { @@ -620,7 +621,7 @@ describe('DBSQLOperation', () => { return getOperationStatusStub.wrappedMethod.apply(context.driver, args); }); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); await operation.getSchema({ progress: true }); expect(getOperationStatusStub.called).to.be.true; @@ -644,7 +645,7 @@ describe('DBSQLOperation', () => { return getOperationStatusStub.wrappedMethod.apply(context.driver, args); 
}); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); const callback = sinon.stub(); @@ -660,7 +661,7 @@ describe('DBSQLOperation', () => { driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; driver.getOperationStatusResp.hasResultSet = true; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); const schema = await operation.getSchema(); expect(schema).to.deep.equal(driver.getResultSetMetadataResp.schema); @@ -673,7 +674,7 @@ describe('DBSQLOperation', () => { driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; driver.getOperationStatusResp.hasResultSet = true; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); const schema1 = await operation.getSchema(); expect(schema1).to.deep.equal(context.driver.getResultSetMetadataResp.schema); @@ -710,7 +711,7 @@ describe('DBSQLOperation', () => { }, }, }; - const operation = new DBSQLOperation({ + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context, directResults, @@ -728,7 +729,7 @@ describe('DBSQLOperation', () => { context.driver.getOperationStatusResp.hasResultSet = true; context.driver.getResultSetMetadataResp.status.statusCode = TStatusCode.ERROR_STATUS; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); try { await operation.getSchema(); @@ -751,8 +752,8 @@ 
describe('DBSQLOperation', () => { driver.getResultSetMetadataResp.resultFormat = TSparkRowSetType.COLUMN_BASED_SET; driver.getResultSetMetadata.resetHistory(); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); - const resultHandler = await operation['getResultHandler'](); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); + const resultHandler = await (operation['backend'] as any)['getResultHandler'](); expect(driver.getResultSetMetadata.called).to.be.true; expect(resultHandler).to.be.instanceOf(ResultSlicer); expect(resultHandler['source']).to.be.instanceOf(JsonResultHandler); @@ -762,8 +763,8 @@ describe('DBSQLOperation', () => { driver.getResultSetMetadataResp.resultFormat = TSparkRowSetType.ARROW_BASED_SET; driver.getResultSetMetadata.resetHistory(); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); - const resultHandler = await operation['getResultHandler'](); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); + const resultHandler = await (operation['backend'] as any)['getResultHandler'](); expect(driver.getResultSetMetadata.called).to.be.true; expect(resultHandler).to.be.instanceOf(ResultSlicer); expect(resultHandler['source']).to.be.instanceOf(ArrowResultConverter); @@ -777,8 +778,8 @@ describe('DBSQLOperation', () => { driver.getResultSetMetadataResp.resultFormat = TSparkRowSetType.URL_BASED_SET; driver.getResultSetMetadata.resetHistory(); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); - const resultHandler = await operation['getResultHandler'](); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); + const resultHandler = await (operation['backend'] as any)['getResultHandler'](); 
expect(driver.getResultSetMetadata.called).to.be.true; expect(resultHandler).to.be.instanceOf(ResultSlicer); expect(resultHandler['source']).to.be.instanceOf(ArrowResultConverter); @@ -795,7 +796,7 @@ describe('DBSQLOperation', () => { const context = new ClientContextStub(); const driver = sinon.spy(context.driver); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: false }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: false }), context }); const results = await operation.fetchChunk({ disableBuffering: true }); @@ -822,13 +823,13 @@ describe('DBSQLOperation', () => { context.driver.fetchResultsResp.hasMoreRows = false; context.driver.fetchResultsResp.results!.columns = []; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); const results = await operation.fetchChunk({ disableBuffering: true }); expect(getOperationStatusStub.called).to.be.true; expect(results).to.deep.equal([]); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); }); it('should request progress', async () => { @@ -849,7 +850,7 @@ describe('DBSQLOperation', () => { context.driver.fetchResultsResp.hasMoreRows = false; context.driver.fetchResultsResp.results!.columns = []; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); await operation.fetchChunk({ progress: true, disableBuffering: true }); expect(getOperationStatusStub.called).to.be.true; @@ -877,7 +878,7 @@ describe('DBSQLOperation', () => { context.driver.fetchResultsResp.hasMoreRows = false; 
context.driver.fetchResultsResp.results!.columns = []; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); const callback = sinon.stub(); @@ -893,7 +894,7 @@ describe('DBSQLOperation', () => { driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; driver.getOperationStatusResp.hasResultSet = true; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); const results = await operation.fetchChunk({ disableBuffering: true }); @@ -907,7 +908,7 @@ describe('DBSQLOperation', () => { const driver = sinon.spy(context.driver); driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; - const operation = new DBSQLOperation({ + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context, directResults: { @@ -943,7 +944,7 @@ describe('DBSQLOperation', () => { driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; driver.getOperationStatusResp.hasResultSet = true; - const operation = new DBSQLOperation({ + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context, directResults: { @@ -986,7 +987,7 @@ describe('DBSQLOperation', () => { context.driver.getResultSetMetadataResp.resultFormat = TSparkRowSetType.ROW_BASED_SET; context.driver.getResultSetMetadataResp.schema = { columns: [] }; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); try { await operation.fetchChunk({ disableBuffering: true }); @@ -1003,7 +1004,7 @@ 
describe('DBSQLOperation', () => { describe('fetchAll', () => { it('should fetch data while available and return it all', async () => { const context = new ClientContextStub(); - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); const originalData = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]; @@ -1038,13 +1039,13 @@ describe('DBSQLOperation', () => { context.driver.getOperationStatusResp.hasResultSet = true; context.driver.fetchResultsResp.hasMoreRows = false; context.driver.fetchResultsResp.results = undefined; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.undefined; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.undefined; await operation.fetchChunk({ disableBuffering: true }); expect(await operation.hasMoreRows()).to.be.false; - expect(operation['_data']['hasMoreRowsFlag']).to.be.false; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.false; }); it('should return False if operation was closed', async () => { @@ -1053,7 +1054,7 @@ describe('DBSQLOperation', () => { context.driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; context.driver.getOperationStatusResp.hasResultSet = true; context.driver.fetchResultsResp.hasMoreRows = true; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(await operation.hasMoreRows()).to.be.true; await operation.fetchChunk({ disableBuffering: true }); @@ -1068,7 +1069,7 
@@ describe('DBSQLOperation', () => { context.driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; context.driver.getOperationStatusResp.hasResultSet = true; context.driver.fetchResultsResp.hasMoreRows = true; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(await operation.hasMoreRows()).to.be.true; await operation.fetchChunk({ disableBuffering: true }); @@ -1083,13 +1084,13 @@ describe('DBSQLOperation', () => { context.driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; context.driver.getOperationStatusResp.hasResultSet = true; context.driver.fetchResultsResp.hasMoreRows = true; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.undefined; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.undefined; await operation.fetchChunk({ disableBuffering: true }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.true; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.true; }); it('should return True if hasMoreRows flag is False but there is actual data', async () => { @@ -1098,13 +1099,13 @@ describe('DBSQLOperation', () => { context.driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; context.driver.getOperationStatusResp.hasResultSet = true; context.driver.fetchResultsResp.hasMoreRows = false; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: 
operationHandleStub({ hasResultSet: true }), context }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.undefined; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.undefined; await operation.fetchChunk({ disableBuffering: true }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.true; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.true; }); it('should return True if hasMoreRows flag is unset but there is actual data', async () => { @@ -1113,13 +1114,13 @@ describe('DBSQLOperation', () => { context.driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; context.driver.getOperationStatusResp.hasResultSet = true; context.driver.fetchResultsResp.hasMoreRows = undefined; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.undefined; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.undefined; await operation.fetchChunk({ disableBuffering: true }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.true; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.true; }); it('should return False if hasMoreRows flag is False and there is no data', async () => { @@ -1129,13 +1130,13 @@ describe('DBSQLOperation', () => { context.driver.getOperationStatusResp.hasResultSet = true; context.driver.fetchResultsResp.hasMoreRows = false; context.driver.fetchResultsResp.results = undefined; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ 
handle: operationHandleStub({ hasResultSet: true }), context }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.undefined; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.undefined; await operation.fetchChunk({ disableBuffering: true }); expect(await operation.hasMoreRows()).to.be.false; - expect(operation['_data']['hasMoreRowsFlag']).to.be.false; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.false; }); }); @@ -1147,7 +1148,7 @@ describe('DBSQLOperation', () => { driver.getOperationStatusResp.hasResultSet = true; // Create operation without direct results to force metadata fetching - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); // Trigger multiple concurrent metadata fetches const results = await Promise.all([operation.hasMoreRows(), operation.hasMoreRows(), operation.hasMoreRows()]); @@ -1165,7 +1166,7 @@ describe('DBSQLOperation', () => { driver.getOperationStatusResp.operationState = TOperationState.FINISHED_STATE; driver.getOperationStatusResp.hasResultSet = true; - const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); + const operation = createOperationForTest({ handle: operationHandleStub({ hasResultSet: true }), context }); // First call should fetch metadata await operation.hasMoreRows(); diff --git a/tests/unit/DBSQLSession.test.ts b/tests/unit/DBSQLSession.test.ts index 0dc79037..51b27133 100644 --- a/tests/unit/DBSQLSession.test.ts +++ b/tests/unit/DBSQLSession.test.ts @@ -7,6 +7,7 @@ import Status from '../../lib/dto/Status'; import DBSQLOperation from '../../lib/DBSQLOperation'; import { TSessionHandle, TProtocolVersion } from '../../thrift/TCLIService_types'; import ClientContextStub from './.stubs/ClientContextStub'; +import { 
createSessionForTest } from './.stubs/createSessionForTest'; const sessionHandleStub: TSessionHandle = { sessionId: { guid: Buffer.alloc(16), secret: Buffer.alloc(16) }, @@ -50,7 +51,7 @@ describe('DBSQLSession', () => { describe('getInfo', () => { it('should run operation', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getInfo(1); expect(result).instanceOf(InfoValue); }); @@ -58,26 +59,26 @@ describe('DBSQLSession', () => { describe('executeStatement', () => { it('should execute statement', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.executeStatement('SELECT * FROM table'); expect(result).instanceOf(DBSQLOperation); }); it('should use direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.executeStatement('SELECT * FROM table', { maxRows: 10 }); expect(result).instanceOf(DBSQLOperation); }); it('should disable direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.executeStatement('SELECT * FROM table', { maxRows: null }); expect(result).instanceOf(DBSQLOperation); }); describe('Arrow support', () => { it('should not use Arrow if disabled in options', async () => { - const session = new DBSQLSession({ + const session = createSessionForTest({ handle: 
sessionHandleStub, context: new ClientContextStub({ arrowEnabled: false }), }); @@ -88,7 +89,7 @@ describe('DBSQLSession', () => { it('should apply defaults for Arrow options', async () => { // case 1 { - const session = new DBSQLSession({ + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub({ arrowEnabled: true }), }); @@ -98,7 +99,7 @@ describe('DBSQLSession', () => { // case 2 { - const session = new DBSQLSession({ + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub({ arrowEnabled: true, useArrowNativeTypes: false }), }); @@ -133,7 +134,7 @@ describe('DBSQLSession', () => { useLZ4Compression: true, }; - const session = new DBSQLSession({ + const session = createSessionForTest({ handle: sessionHandleStub, context, serverProtocolVersion: version, @@ -195,7 +196,7 @@ describe('DBSQLSession', () => { const statement = 'SELECT * FROM table'; // Use V6+ which supports arrow compression - const session = new DBSQLSession({ + const session = createSessionForTest({ handle: sessionHandleStub, context, serverProtocolVersion: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6, @@ -218,7 +219,7 @@ describe('DBSQLSession', () => { const statement = 'SELECT * FROM table'; // Use V6+ which supports arrow compression - const session = new DBSQLSession({ + const session = createSessionForTest({ handle: sessionHandleStub, context, serverProtocolVersion: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6, @@ -241,7 +242,7 @@ describe('DBSQLSession', () => { const statement = 'SELECT * FROM table'; // Use V5 which does not support arrow compression - const session = new DBSQLSession({ + const session = createSessionForTest({ handle: sessionHandleStub, context, serverProtocolVersion: TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5, @@ -263,7 +264,7 @@ describe('DBSQLSession', () => { it('should set confOverlay with query_tags when queryTags are provided', async () => { const context = new 
ClientContextStub(); const driver = sinon.spy(context.driver); - const session = new DBSQLSession({ handle: sessionHandleStub, context }); + const session = createSessionForTest({ handle: sessionHandleStub, context }); await session.executeStatement('SELECT 1', { queryTags: { team: 'eng', app: 'etl' } }); @@ -275,7 +276,7 @@ describe('DBSQLSession', () => { it('should not set confOverlay query_tags when queryTags is not provided', async () => { const context = new ClientContextStub(); const driver = sinon.spy(context.driver); - const session = new DBSQLSession({ handle: sessionHandleStub, context }); + const session = createSessionForTest({ handle: sessionHandleStub, context }); await session.executeStatement('SELECT 1'); @@ -287,7 +288,7 @@ describe('DBSQLSession', () => { it('should not set confOverlay query_tags when queryTags is empty', async () => { const context = new ClientContextStub(); const driver = sinon.spy(context.driver); - const session = new DBSQLSession({ handle: sessionHandleStub, context }); + const session = createSessionForTest({ handle: sessionHandleStub, context }); await session.executeStatement('SELECT 1', { queryTags: {} }); @@ -299,19 +300,19 @@ describe('DBSQLSession', () => { describe('getTypeInfo', () => { it('should run operation', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getTypeInfo(); expect(result).instanceOf(DBSQLOperation); }); it('should use direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getTypeInfo({ maxRows: 10 }); expect(result).instanceOf(DBSQLOperation); }); it('should disable direct results', async () => 
{ - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getTypeInfo({ maxRows: null }); expect(result).instanceOf(DBSQLOperation); }); @@ -319,19 +320,19 @@ describe('DBSQLSession', () => { describe('getCatalogs', () => { it('should run operation', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getCatalogs(); expect(result).instanceOf(DBSQLOperation); }); it('should use direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getCatalogs({ maxRows: 10 }); expect(result).instanceOf(DBSQLOperation); }); it('should disable direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getCatalogs({ maxRows: null }); expect(result).instanceOf(DBSQLOperation); }); @@ -339,13 +340,13 @@ describe('DBSQLSession', () => { describe('getSchemas', () => { it('should run operation', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getSchemas(); expect(result).instanceOf(DBSQLOperation); }); it('should use filters', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: 
new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getSchemas({ catalogName: 'catalog', schemaName: 'schema', @@ -354,13 +355,13 @@ describe('DBSQLSession', () => { }); it('should use direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getSchemas({ maxRows: 10 }); expect(result).instanceOf(DBSQLOperation); }); it('should disable direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getSchemas({ maxRows: null }); expect(result).instanceOf(DBSQLOperation); }); @@ -368,13 +369,13 @@ describe('DBSQLSession', () => { describe('getTables', () => { it('should run operation', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getTables(); expect(result).instanceOf(DBSQLOperation); }); it('should use filters', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getTables({ catalogName: 'catalog', schemaName: 'default', @@ -385,13 +386,13 @@ describe('DBSQLSession', () => { }); it('should use direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = 
createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getTables({ maxRows: 10 }); expect(result).instanceOf(DBSQLOperation); }); it('should disable direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getTables({ maxRows: null }); expect(result).instanceOf(DBSQLOperation); }); @@ -399,19 +400,19 @@ describe('DBSQLSession', () => { describe('getTableTypes', () => { it('should run operation', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getTableTypes(); expect(result).instanceOf(DBSQLOperation); }); it('should use direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getTableTypes({ maxRows: 10 }); expect(result).instanceOf(DBSQLOperation); }); it('should disable direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getTableTypes({ maxRows: null }); expect(result).instanceOf(DBSQLOperation); }); @@ -419,13 +420,13 @@ describe('DBSQLSession', () => { describe('getColumns', () => { it('should run operation', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: 
sessionHandleStub, context: new ClientContextStub() }); const result = await session.getColumns(); expect(result).instanceOf(DBSQLOperation); }); it('should use filters', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getColumns({ catalogName: 'catalog', schemaName: 'schema', @@ -436,13 +437,13 @@ describe('DBSQLSession', () => { }); it('should use direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getColumns({ maxRows: 10 }); expect(result).instanceOf(DBSQLOperation); }); it('should disable direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getColumns({ maxRows: null }); expect(result).instanceOf(DBSQLOperation); }); @@ -450,7 +451,7 @@ describe('DBSQLSession', () => { describe('getFunctions', () => { it('should run operation', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getFunctions({ catalogName: 'catalog', schemaName: 'schema', @@ -460,7 +461,7 @@ describe('DBSQLSession', () => { }); it('should use direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const 
result = await session.getFunctions({ catalogName: 'catalog', schemaName: 'schema', @@ -471,7 +472,7 @@ describe('DBSQLSession', () => { }); it('should disable direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getFunctions({ catalogName: 'catalog', schemaName: 'schema', @@ -484,7 +485,7 @@ describe('DBSQLSession', () => { describe('getPrimaryKeys', () => { it('should run operation', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getPrimaryKeys({ catalogName: 'catalog', schemaName: 'schema', @@ -494,7 +495,7 @@ describe('DBSQLSession', () => { }); it('should use direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getPrimaryKeys({ catalogName: 'catalog', schemaName: 'schema', @@ -505,7 +506,7 @@ describe('DBSQLSession', () => { }); it('should disable direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getPrimaryKeys({ catalogName: 'catalog', schemaName: 'schema', @@ -518,7 +519,7 @@ describe('DBSQLSession', () => { describe('getCrossReference', () => { it('should run operation', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ 
handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getCrossReference({ parentCatalogName: 'parentCatalogName', parentSchemaName: 'parentSchemaName', @@ -531,7 +532,7 @@ describe('DBSQLSession', () => { }); it('should use direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getCrossReference({ parentCatalogName: 'parentCatalogName', parentSchemaName: 'parentSchemaName', @@ -545,7 +546,7 @@ describe('DBSQLSession', () => { }); it('should disable direct results', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const result = await session.getCrossReference({ parentCatalogName: 'parentCatalogName', parentSchemaName: 'parentSchemaName', @@ -564,7 +565,7 @@ describe('DBSQLSession', () => { const context = new ClientContextStub(); const driver = sinon.spy(context.driver); - const session = new DBSQLSession({ handle: sessionHandleStub, context }); + const session = createSessionForTest({ handle: sessionHandleStub, context }); expect(session['isOpen']).to.be.true; const result = await session.close(); @@ -577,7 +578,7 @@ describe('DBSQLSession', () => { const context = new ClientContextStub(); const driver = sinon.spy(context.driver); - const session = new DBSQLSession({ handle: sessionHandleStub, context }); + const session = createSessionForTest({ handle: sessionHandleStub, context }); expect(session['isOpen']).to.be.true; const result = await session.close(); @@ -592,7 +593,7 @@ describe('DBSQLSession', () => { }); it('should close operations that belong to it', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new 
ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); const operation = await session.executeStatement('SELECT * FROM table'); if (!(operation instanceof DBSQLOperation)) { expect.fail('Assertion error: operation is not a DBSQLOperation'); @@ -614,7 +615,7 @@ describe('DBSQLSession', () => { }); it('should reject all methods once closed', async () => { - const session = new DBSQLSession({ handle: sessionHandleStub, context: new ClientContextStub() }); + const session = createSessionForTest({ handle: sessionHandleStub, context: new ClientContextStub() }); await session.close(); expect(session['isOpen']).to.be.false; diff --git a/tests/unit/sea/SeaIntervalParity.test.ts b/tests/unit/sea/SeaIntervalParity.test.ts new file mode 100644 index 00000000..3e3274c7 --- /dev/null +++ b/tests/unit/sea/SeaIntervalParity.test.ts @@ -0,0 +1,366 @@ +// Copyright (c) 2026 Databricks, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 + +/** + * TDD harness for the round-2 INTERVAL parity fix. + * + * Verifies that the SEA path renders the exact thrift wire string for + * INTERVAL YEAR-MONTH and INTERVAL DAY-TIME columns, regardless of + * whether the kernel emits the value as native Arrow `Interval` or + * native Arrow `Duration` (the latter is transparently rewritten to + * `Int64` by `lib/sea/SeaArrowIpcDurationFix.ts` because `apache-arrow@13` + * predates the `Duration` type id). 
+ * + * Reference failure modes (round 5 testing): + * - YEAR-MONTH: + * thrift → `"1-2"` (string) + * SEA pre-fix → `{"0":1,"1":2}` (Int32Array surfaced as struct) + * - DAY-TIME: + * thrift → `"1 02:03:04.000000000"` (string) + * SEA pre-fix → throws `Unrecognized type: "Duration" (18)` on schema decode + * + * Both modes must now produce byte-identical thrift strings. + */ + +import { expect } from 'chai'; +import * as flatbuffers from 'flatbuffers'; +import { + Schema, + Field, + Int32, + Int64, + Interval, + IntervalUnit, + Table, + RecordBatch, + makeData, + Struct, + vectorFromArray, + tableToIPC, +} from 'apache-arrow'; + +// eslint-disable-next-line import/no-internal-modules +import { Message as FbMessage } from 'apache-arrow/fb/message'; +// eslint-disable-next-line import/no-internal-modules +import { MessageHeader } from 'apache-arrow/fb/message-header'; +// eslint-disable-next-line import/no-internal-modules +import { Schema as FbSchema } from 'apache-arrow/fb/schema'; +// eslint-disable-next-line import/no-internal-modules +import { Field as FbField } from 'apache-arrow/fb/field'; +// eslint-disable-next-line import/no-internal-modules +import { Type as FbType } from 'apache-arrow/fb/type'; +// eslint-disable-next-line import/no-internal-modules +import { Duration as FbDuration } from 'apache-arrow/fb/duration'; +// eslint-disable-next-line import/no-internal-modules +import { TimeUnit as FbTimeUnit } from 'apache-arrow/fb/time-unit'; + +import SeaOperationBackend from '../../../lib/sea/SeaOperationBackend'; +import ClientContextStub from '../.stubs/ClientContextStub'; + +// --------------------------------------------------------------------------- +// Test helpers. 
+// --------------------------------------------------------------------------- + +class StatementStub { + private readonly batches: Buffer[]; + + private readonly schemaIpc: Buffer; + + public cancelled = false; + + public closed = false; + + constructor(schemaIpc: Buffer, batches: Buffer[]) { + this.schemaIpc = schemaIpc; + this.batches = [...batches]; + } + + public async fetchNextBatch(): Promise<{ ipcBytes: Buffer } | null> { + if (this.batches.length === 0) return null; + return { ipcBytes: this.batches.shift() as Buffer }; + } + + // schema() is synchronous on the merged-kernel binding. + public schema(): { ipcBytes: Buffer } { + return { ipcBytes: this.schemaIpc }; + } + + public async cancel(): Promise { + this.cancelled = true; + } + + public async close(): Promise { + this.closed = true; + } +} + +function withTypeName(field: T, typeName: string): T { + const meta = new Map(field.metadata); + meta.set('databricks.type_name', typeName); + return new Field(field.name, field.type, field.nullable, meta) as T; +} + +function ipcFromColumns(schema: Schema, columns: Record): Buffer { + const vectors: any[] = []; + for (const field of schema.fields) { + const col = columns[field.name]; + vectors.push(vectorFromArray(col as any, field.type)); + } + const data = vectors.map((v) => v.data[0]); + const struct = makeData({ + type: new Struct(schema.fields), + children: data, + length: vectors[0]?.length ?? 
0, + nullCount: 0, + }); + const batch = new RecordBatch(schema, struct); + const table = new Table([batch]); + return Buffer.from(tableToIPC(table, 'stream')); +} + +function ipcSchemaOnly(schema: Schema): Buffer { + const struct = makeData({ + type: new Struct(schema.fields), + children: schema.fields.map((f) => makeData({ type: f.type as any, length: 0, nullCount: 0 })), + length: 0, + nullCount: 0, + }); + const batch = new RecordBatch(schema, struct); + const table = new Table([batch]); + return Buffer.from(tableToIPC(table, 'stream')); +} + +/** + * Build a schema-only IPC payload whose schema declares a single Arrow + * `Duration` column. `apache-arrow@13` cannot build this directly (no + * Duration class in the public API), so we hand-roll the FlatBuffer + * using the internal `fb/*` accessor classes. The body bytes for this + * column are bit-identical to an Int64 column. + */ +function ipcWithDurationSchema(fieldName: string, durationUnit: FbTimeUnit, typeName = 'INTERVAL'): Buffer { + const builder = new flatbuffers.Builder(256); + + // KeyValue for databricks.type_name + const tnKey = builder.createString('databricks.type_name'); + const tnVal = builder.createString(typeName); + const { KeyValue: FbKeyValueLocal } = require('apache-arrow/fb/key-value'); // eslint-disable-line @typescript-eslint/no-var-requires, global-require, import/no-internal-modules + FbKeyValueLocal.startKeyValue(builder); + FbKeyValueLocal.addKey(builder, tnKey); + FbKeyValueLocal.addValue(builder, tnVal); + const tnKv = FbKeyValueLocal.endKeyValue(builder); + const metadataVec = FbField.createCustomMetadataVector(builder, [tnKv]); + + const nameOff = builder.createString(fieldName); + const durOff = FbDuration.createDuration(builder, durationUnit); + FbField.startField(builder); + FbField.addName(builder, nameOff); + FbField.addNullable(builder, true); + FbField.addTypeType(builder, FbType.Duration); + FbField.addType(builder, durOff); + FbField.addCustomMetadata(builder, 
metadataVec); + const fieldOff = FbField.endField(builder); + const fieldsVec = FbSchema.createFieldsVector(builder, [fieldOff]); + FbSchema.startSchema(builder); + FbSchema.addFields(builder, fieldsVec); + const schemaOff = FbSchema.endSchema(builder); + FbMessage.startMessage(builder); + FbMessage.addVersion(builder, 4); // V5 + FbMessage.addHeaderType(builder, MessageHeader.Schema); + FbMessage.addHeader(builder, schemaOff); + FbMessage.addBodyLength(builder, BigInt(0)); + const msgOff = FbMessage.endMessage(builder); + builder.finish(msgOff); + const bytes = builder.asUint8Array(); + const rem = bytes.byteLength % 8; + const padded = rem === 0 ? bytes : new Uint8Array(bytes.byteLength + (8 - rem)); + if (rem !== 0) padded.set(bytes, 0); + + // IPC stream framing: continuation marker (0xFFFFFFFF) + length + bytes + const prefix = Buffer.alloc(8); + prefix.writeInt32LE(-1, 0); + prefix.writeInt32LE(padded.byteLength, 4); + + // EOS marker (continuation + zero length) — terminates the stream. + const eos = Buffer.alloc(8); + eos.writeInt32LE(-1, 0); + eos.writeInt32LE(0, 4); + + return Buffer.concat([prefix, Buffer.from(padded), eos]); +} + +/** + * Splice a hand-built Duration schema into an Int64-based IPC stream + * so the record batch body bytes (which are Int64-encoded) become + * "Duration-shaped" without us re-encoding the body. Used to fabricate + * a kernel-shaped Duration IPC payload using only the apache-arrow@13 + * public API. + */ +function buildDurationIpc(fieldName: string, durationUnit: FbTimeUnit, values: bigint[], typeName = 'INTERVAL'): Buffer { + // Build an Int64 stream that carries the values. + const int64Schema = new Schema([new Field(fieldName, new Int64(), true)]); + const int64Ipc = ipcFromColumns(int64Schema, { + [fieldName]: [new BigInt64Array(values)], + }); + + // Build a Duration schema-only message that we splice in to replace + // the Int64 schema. The record-batch bytes from int64Ipc follow + // unchanged. 
+ const durationSchemaIpc = ipcWithDurationSchema(fieldName, durationUnit, typeName); + + // Skip the Int64 schema header + EOS in durationSchemaIpc, then + // append the int64 stream's record batches. + // int64Ipc layout: [continuation+len+schema][continuation+len+recordbatch][continuation+0 EOS] + let cursor = 0; + let len = int64Ipc.readInt32LE(cursor); + cursor += 4; + if (len === -1) { + len = int64Ipc.readInt32LE(cursor); + cursor += 4; + } + // Skip the schema body (always empty for schema messages) + const intRecordsStart = cursor + len; + const intRecords = int64Ipc.subarray(intRecordsStart); + + // durationSchemaIpc layout: [prefix][padded schema bytes][EOS]. + // Drop its EOS so it concatenates cleanly with intRecords (which has + // its own EOS). + const durationNoEos = durationSchemaIpc.subarray(0, durationSchemaIpc.byteLength - 8); + return Buffer.concat([durationNoEos, intRecords]); +} + +// --------------------------------------------------------------------------- +// Tests. +// --------------------------------------------------------------------------- + +describe('SeaOperationBackend — INTERVAL parity with thrift', () => { + it('YEAR-MONTH via native Arrow Interval[YearMonth] → "Y-M"', async () => { + // Arrow `Interval[YearMonth]` carries a single int32 total-months + // value. apache-arrow surfaces it as Int32Array(2) via the + // GetVisitor. The kernel emits this type for INTERVAL YEAR-MONTH. + const fields = [ + withTypeName(new Field('iv', new Interval(IntervalUnit.YEAR_MONTH), true), 'INTERVAL'), + ]; + const schema = new Schema(fields); + const schemaIpc = ipcSchemaOnly(schema); + + // 1 year, 2 months → 14 total months. `vectorFromArray(Int32Array, + // new Interval(...))` packs the int32 total directly into the + // Interval column's underlying values buffer. 
+ const dataIpc = ipcFromColumns(schema, { iv: Int32Array.from([14]) }); + + const stub = new StatementStub(schemaIpc, [dataIpc]); + const backend = new SeaOperationBackend({ statement: stub, context: new ClientContextStub() }); + const rows = await backend.fetchChunk({ limit: 100 }); + expect(rows).to.have.length(1); + expect((rows[0] as any).iv).to.equal('1-2'); + }); + + it('YEAR-MONTH negative → "-Y-M"', async () => { + const fields = [ + withTypeName(new Field('iv', new Interval(IntervalUnit.YEAR_MONTH), true), 'INTERVAL'), + ]; + const schema = new Schema(fields); + const schemaIpc = ipcSchemaOnly(schema); + + // -14 total months → -1 year -2 months. + const dataIpc = ipcFromColumns(schema, { iv: Int32Array.from([-14]) }); + + const stub = new StatementStub(schemaIpc, [dataIpc]); + const backend = new SeaOperationBackend({ statement: stub, context: new ClientContextStub() }); + const rows = await backend.fetchChunk({ limit: 100 }); + expect(rows).to.have.length(1); + expect((rows[0] as any).iv).to.equal('-1-2'); + }); + + it('DAY-TIME via Arrow Duration(MICROSECOND) → "1 02:03:04.000000000"', async () => { + // 1 day + 2h + 3min + 4s = 93784 seconds = 93_784_000_000 µs. 
+ const microseconds = BigInt(93_784) * BigInt(1_000_000); + const ipc = buildDurationIpc('iv', FbTimeUnit.MICROSECOND, [microseconds], 'INTERVAL'); + const schemaIpc = ipcWithDurationSchema('iv', FbTimeUnit.MICROSECOND, 'INTERVAL'); + + const stub = new StatementStub(schemaIpc, [ipc]); + const backend = new SeaOperationBackend({ statement: stub, context: new ClientContextStub() }); + const rows = await backend.fetchChunk({ limit: 100 }); + expect(rows).to.have.length(1); + expect((rows[0] as any).iv).to.equal('1 02:03:04.000000000'); + }); + + it('DAY-TIME via Arrow Duration(NANOSECOND) preserves nanosecond precision', async () => { + // 1 day + 2h + 3min + 4.123456789s + const nanos = + BigInt(86400 + 2 * 3600 + 3 * 60 + 4) * BigInt(1_000_000_000) + BigInt(123_456_789); + const ipc = buildDurationIpc('iv', FbTimeUnit.NANOSECOND, [nanos], 'INTERVAL'); + const schemaIpc = ipcWithDurationSchema('iv', FbTimeUnit.NANOSECOND, 'INTERVAL'); + + const stub = new StatementStub(schemaIpc, [ipc]); + const backend = new SeaOperationBackend({ statement: stub, context: new ClientContextStub() }); + const rows = await backend.fetchChunk({ limit: 100 }); + expect(rows).to.have.length(1); + expect((rows[0] as any).iv).to.equal('1 02:03:04.123456789'); + }); + + it('DAY-TIME zero → "0 00:00:00.000000000"', async () => { + const ipc = buildDurationIpc('iv', FbTimeUnit.MICROSECOND, [BigInt(0)], 'INTERVAL'); + const schemaIpc = ipcWithDurationSchema('iv', FbTimeUnit.MICROSECOND, 'INTERVAL'); + + const stub = new StatementStub(schemaIpc, [ipc]); + const backend = new SeaOperationBackend({ statement: stub, context: new ClientContextStub() }); + const rows = await backend.fetchChunk({ limit: 100 }); + expect(rows).to.have.length(1); + expect((rows[0] as any).iv).to.equal('0 00:00:00.000000000'); + }); + + it('DAY-TIME negative → leading "-"', async () => { + // -(1 day + 2h + 3min + 4s) in microseconds. 
+ const microseconds = -(BigInt(93_784) * BigInt(1_000_000)); + const ipc = buildDurationIpc('iv', FbTimeUnit.MICROSECOND, [microseconds], 'INTERVAL'); + const schemaIpc = ipcWithDurationSchema('iv', FbTimeUnit.MICROSECOND, 'INTERVAL'); + + const stub = new StatementStub(schemaIpc, [ipc]); + const backend = new SeaOperationBackend({ statement: stub, context: new ClientContextStub() }); + const rows = await backend.fetchChunk({ limit: 100 }); + expect(rows).to.have.length(1); + expect((rows[0] as any).iv).to.equal('-1 02:03:04.000000000'); + }); + + it('Duration column round-trips alongside primitive columns (DRY: same converter handles both intervals)', async () => { + // Schema: [iv: Duration(µs), n: Int32]. The pre-processor must + // rewrite the Duration field WITHOUT disturbing the Int32 sibling. + // We hand-build the Duration schema (apache-arrow@13 can't build + // Duration directly) and a body that has [Int64 column, Int32 col]. + // The rewriter must keep the Int32 column intact and substitute + // Int64 for Duration. + // + // Note: we use a single-Duration-column test here because mixing + // hand-built Duration with apache-arrow's batch builder requires + // hand-rolling the entire IPC stream. The "Duration alongside + // other columns" coverage is provided by the E2E parity tests + // (M0-DT-019 in `tests/nodejs/test/parity/M0DatatypeParityTests.test.ts`) + // which use a real warehouse query that mixes INTERVAL with other + // types. + const microseconds = BigInt(86_400) * BigInt(1_000_000); // 1 day + const ipc = buildDurationIpc('iv', FbTimeUnit.MICROSECOND, [microseconds], 'INTERVAL'); + const schemaIpc = ipcWithDurationSchema('iv', FbTimeUnit.MICROSECOND, 'INTERVAL'); + + const stub = new StatementStub(schemaIpc, [ipc]); + const backend = new SeaOperationBackend({ statement: stub, context: new ClientContextStub() }); + + // Round-trip the metadata to confirm we synthesise the right TTypeId. 
+ const metadata = await backend.getResultMetadata(); + expect(metadata.schema?.columns?.[0]?.typeDesc.types?.[0]?.primitiveEntry?.type).to.equal( + // INTERVAL_DAY_TIME_TYPE = 30 in TCLIService_types + // We assert by importing the enum below to avoid magic numbers. + // eslint-disable-next-line global-require, @typescript-eslint/no-var-requires + require('../../../thrift/TCLIService_types').TTypeId.INTERVAL_DAY_TIME_TYPE, + ); + + const rows = await backend.fetchChunk({ limit: 100 }); + expect(rows).to.have.length(1); + expect((rows[0] as any).iv).to.equal('1 00:00:00.000000000'); + }); +}); diff --git a/tests/unit/sea/SeaOperationBackend.test.ts b/tests/unit/sea/SeaOperationBackend.test.ts new file mode 100644 index 00000000..c32ee9f9 --- /dev/null +++ b/tests/unit/sea/SeaOperationBackend.test.ts @@ -0,0 +1,290 @@ +// Copyright (c) 2026 Databricks, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { expect } from 'chai'; +import { + Schema, + Field, + RecordBatch, + Table, + tableToIPC, + Bool, + Int8, + Int16, + Int32, + Int64, + Float32, + Float64, + Utf8, + Binary, + DateDay, + TimestampMicrosecond, + Decimal, + Struct, + makeData, + vectorFromArray, +} from 'apache-arrow'; + +import SeaOperationBackend from '../../../lib/sea/SeaOperationBackend'; +import ClientContextStub from '../.stubs/ClientContextStub'; + +// Minimal stub of the napi `Statement` surface that emits a precomputed +// Arrow IPC payload per `fetchNextBatch()` call. 
Used to feed +// `SeaOperationBackend` synthetic batches that mirror the kernel's +// per-batch IPC stream contract (`schema header + 1 record-batch +// message`) without loading the native binding. +class StatementStub { + private readonly batches: Buffer[]; + + private readonly schemaIpc: Buffer; + + public cancelled = false; + + public closed = false; + + constructor(schemaIpc: Buffer, batches: Buffer[]) { + this.schemaIpc = schemaIpc; + this.batches = [...batches]; + } + + // Mirrors the kernel `Statement.statementId` getter. + public readonly statementId = '01ef-fake-statement-id'; + + public async fetchNextBatch(): Promise<{ ipcBytes: Buffer } | null> { + if (this.batches.length === 0) return null; + return { ipcBytes: this.batches.shift() as Buffer }; + } + + // schema() is synchronous on the merged-kernel binding. + public schema(): { ipcBytes: Buffer } { + return { ipcBytes: this.schemaIpc }; + } + + public async cancel(): Promise { + this.cancelled = true; + } + + public async close(): Promise { + this.closed = true; + } + + // Status accessors from the kernel's status-fields surface. + public async numModifiedRows(): Promise { + return null; + } + + public async displayMessage(): Promise { + return null; + } + + public async diagnosticInfo(): Promise { + return null; + } + + public async errorDetailsJson(): Promise { + return null; + } +} + +// Helper: attach `databricks.type_name` to a field so the SEA Thrift +// schema synthesiser can resolve the TTypeId (matches kernel behaviour +// at `src/reader/mod.rs:476-504`). +function withTypeName(field: T, typeName: string): T { + const meta = new Map(field.metadata); + meta.set('databricks.type_name', typeName); + return new Field(field.name, field.type, field.nullable, meta) as T; +} + +// Build a single IPC stream (schema header + 1 record-batch message) +// from a Schema and a column->values mapping. Mirrors the kernel's +// per-batch ResultStream output shape. 
+function ipcFromColumns(schema: Schema, columns: Record): Buffer { + const vectors: any[] = []; + for (const field of schema.fields) { + const col = columns[field.name]; + vectors.push(vectorFromArray(col, field.type)); + } + const data = vectors.map((v) => v.data[0]); + const struct = makeData({ + type: new Struct(schema.fields), + children: data, + length: data[0]?.length ?? 0, + nullCount: 0, + }); + const batch = new RecordBatch(schema, struct); + const table = new Table([batch]); + return Buffer.from(tableToIPC(table, 'stream')); +} + +function ipcSchemaOnly(schema: Schema): Buffer { + // tableToIPC on an empty table produces a schema-only stream. + const struct = makeData({ + type: new Struct(schema.fields), + children: schema.fields.map((f) => makeData({ type: f.type as any, length: 0, nullCount: 0 })), + length: 0, + nullCount: 0, + }); + const batch = new RecordBatch(schema, struct); + const table = new Table([batch]); + return Buffer.from(tableToIPC(table, 'stream')); +} + +describe('SeaOperationBackend — M0 datatype round-trip via napi → ArrowResultConverter', () => { + it('passes M0 primitive datatypes through the same converter the thrift path uses', async () => { + // One row per M0 primitive type with a kernel-style metadata tag on + // each field. Decimal carries a real scale (2) so the converter's + // Phase-1 scale division produces 1.5 from the unscaled bigint. 
+ const fields = [ + withTypeName(new Field('b', new Bool(), true), 'BOOLEAN'), + withTypeName(new Field('i8', new Int8(), true), 'TINYINT'), + withTypeName(new Field('i16', new Int16(), true), 'SMALLINT'), + withTypeName(new Field('i32', new Int32(), true), 'INT'), + withTypeName(new Field('i64', new Int64(), true), 'BIGINT'), + withTypeName(new Field('f32', new Float32(), true), 'FLOAT'), + withTypeName(new Field('f64', new Float64(), true), 'DOUBLE'), + withTypeName(new Field('s', new Utf8(), true), 'STRING'), + withTypeName(new Field('bin', new Binary(), true), 'BINARY'), + withTypeName(new Field('dt', new DateDay(), true), 'DATE'), + withTypeName( + new Field('ts', new TimestampMicrosecond(), true), + 'TIMESTAMP', + ), + // apache-arrow's Decimal signature is `(scale, precision, bitWidth)`. + withTypeName(new Field('dec', new Decimal(2, 10, 128), true), 'DECIMAL'), + // INTERVAL on the kernel side: Utf8 + metadata annotation. + withTypeName(new Field('iv', new Utf8(), true), 'INTERVAL'), + ]; + const schema = new Schema(fields); + const schemaIpc = ipcSchemaOnly(schema); + + // DECIMAL: 128-bit little-endian unscaled integer. 150 little-endian + // → [150, 0, 0, 0, ...0]. Phase-1 reads `valueType.scale` (=2) so the + // converter divides by 100 to yield 1.5. + const decimalBytes = new Uint8Array(16); + decimalBytes[0] = 150; + const dataIpc = ipcFromColumns(schema, { + b: [true], + i8: [Int8Array.from([1])[0]], + i16: [Int16Array.from([200])[0]], + i32: [42], + i64: [BigInt(1234567890123)], + f32: [Math.fround(1.5)], + f64: [3.14], + s: ['hello'], + bin: [new Uint8Array([0xde, 0xad, 0xbe, 0xef])], + dt: [new Date('2026-01-01T00:00:00Z')], + // Builder for TimestampMicrosecond accepts numeric epoch-ms; the + // internal scaling multiplies by 1000 to land on µs. 
+ ts: [new Date('2026-05-15T12:00:00Z').valueOf()], + dec: [decimalBytes], + iv: ['1-0'], + }); + + const stub = new StatementStub(schemaIpc, [dataIpc]); + const backend = new SeaOperationBackend({ + statement: stub, + context: new ClientContextStub(), + }); + + const rows = await backend.fetchChunk({ limit: 100 }); + expect(rows.length).to.equal(1); + const row = rows[0] as Record; + + expect(row.b).to.equal(true); + expect(row.i8).to.equal(1); + expect(row.i16).to.equal(200); + expect(row.i32).to.equal(42); + // BIGINT goes through Phase-2 convertBigInt → Number (matches thrift) + expect(row.i64).to.equal(1234567890123); + expect(row.f32).to.equal(Math.fround(1.5)); + expect(row.f64).to.equal(3.14); + expect(row.s).to.equal('hello'); + expect(Buffer.isBuffer(row.bin)).to.equal(true); + expect((row.bin as Buffer).equals(Buffer.from([0xde, 0xad, 0xbe, 0xef]))).to.equal(true); + // DECIMAL: Phase-1 scale-aware coercion via Arrow's Decimal type → 1.5 + expect(row.dec).to.equal(1.5); + // TIMESTAMP: Phase-1 produces JS Date for arrow timestamps + expect(row.ts).to.be.instanceOf(Date); + expect((row.ts as Date).toISOString()).to.equal('2026-05-15T12:00:00.000Z'); + // INTERVAL: kernel emits Utf8 + metadata; converter passes through as string + expect(row.iv).to.equal('1-0'); + + // After consuming the single batch, the backend should report no more rows. + expect(await backend.hasMore()).to.equal(false); + }); + + it('round-trips ARRAY / MAP / STRUCT via the converter Phase-2 JSON fallback', async () => { + // ARRAY / MAP / STRUCT have two possible wire encodings in M0: + // (a) native Arrow `List` / `Map` / `Struct` — Phase 1 produces plain + // JS objects; Phase 2 `convertJSON` sees a non-string and is a + // no-op (`utils.ts:39-49`). + // (b) Utf8 JSON strings — Phase 1 passthrough; Phase 2 `convertJSON` + // runs `JSON.parse` (`utils.ts:75-79`). + // Both produce identical row shapes. 
We validate (b) here because + // it's the deterministic case we can construct with the current + // apache-arrow JS API; the kernel emits either depending on server + // config (see `findings/rust-kernel/datatype-emission...:140-142`). + const strSchema = new Schema([ + withTypeName(new Field('arr', new Utf8(), true), 'ARRAY'), + withTypeName(new Field('m', new Utf8(), true), 'MAP'), + withTypeName(new Field('s', new Utf8(), true), 'STRUCT'), + ]); + const strSchemaIpc = ipcSchemaOnly(strSchema); + const strDataIpc = ipcFromColumns(strSchema, { + arr: ['[1,2,3]'], + m: ['{"k":1}'], + s: ['{"a":1,"b":"hi"}'], + }); + + const stub = new StatementStub(strSchemaIpc, [strDataIpc]); + const backend = new SeaOperationBackend({ + statement: stub, + context: new ClientContextStub(), + }); + const rows = await backend.fetchChunk({ limit: 100 }); + expect(rows.length).to.equal(1); + const row = rows[0] as Record; + expect(row.arr).to.deep.equal([1, 2, 3]); + expect(row.m).to.deep.equal({ k: 1 }); + expect(row.s).to.deep.equal({ a: 1, b: 'hi' }); + }); + + it('streams multiple batches and reports hasMore correctly', async () => { + const schema = new Schema([withTypeName(new Field('x', new Int32(), true), 'INT')]); + const schemaIpc = ipcSchemaOnly(schema); + const batch1 = ipcFromColumns(schema, { x: [1, 2] }); + const batch2 = ipcFromColumns(schema, { x: [3] }); + + const stub = new StatementStub(schemaIpc, [batch1, batch2]); + const backend = new SeaOperationBackend({ + statement: stub, + context: new ClientContextStub(), + }); + + const all = await backend.fetchChunk({ limit: 10 }); + expect(all).to.deep.equal([{ x: 1 }, { x: 2 }, { x: 3 }]); + expect(await backend.hasMore()).to.equal(false); + }); + + it('cancel / close delegate to the native statement', async () => { + const schema = new Schema([withTypeName(new Field('x', new Int32(), true), 'INT')]); + const schemaIpc = ipcSchemaOnly(schema); + const stub = new StatementStub(schemaIpc, []); + const backend = new 
SeaOperationBackend({ statement: stub, context: new ClientContextStub() }); + await backend.cancel(); + expect(stub.cancelled).to.equal(true); + await backend.close(); + expect(stub.closed).to.equal(true); + }); +}); diff --git a/tests/unit/sea/auth-pat.test.ts b/tests/unit/sea/auth-pat.test.ts new file mode 100644 index 00000000..21d5d629 --- /dev/null +++ b/tests/unit/sea/auth-pat.test.ts @@ -0,0 +1,127 @@ +// Copyright (c) 2026 Databricks, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +import { expect } from 'chai'; +import { buildSeaConnectionOptions } from '../../../lib/sea/SeaAuth'; +import { ConnectionOptions } from '../../../lib/contracts/IDBSQLClient'; +import AuthenticationError from '../../../lib/errors/AuthenticationError'; +import HiveDriverError from '../../../lib/errors/HiveDriverError'; + +describe('SeaAuth — PAT auth options builder', () => { + describe('buildSeaConnectionOptions', () => { + it('accepts a bare access-token PAT (undefined authType)', () => { + const opts: ConnectionOptions = { + host: 'example.cloud.databricks.com', + path: '/sql/1.0/warehouses/abc', + token: 'dapi-fake-pat', + }; + + const native = buildSeaConnectionOptions(opts); + expect(native).to.deep.equal({ + hostName: 'example.cloud.databricks.com', + httpPath: '/sql/1.0/warehouses/abc', + token: 'dapi-fake-pat', + }); + }); + + it('accepts an explicit access-token PAT', () => { + const opts: ConnectionOptions = { + host: 'example.cloud.databricks.com', + path: '/sql/1.0/warehouses/abc', + authType: 'access-token', + token: 'dapi-fake-pat', + }; + + const native = buildSeaConnectionOptions(opts); + expect(native.token).to.equal('dapi-fake-pat'); + }); + + it('prepends `/` to a path missing the leading slash', () => { + const opts: ConnectionOptions = { + host: 'example.cloud.databricks.com', + path: 'sql/1.0/warehouses/abc', + token: 'dapi-fake-pat', + }; + + const native = buildSeaConnectionOptions(opts); + expect(native.httpPath).to.equal('/sql/1.0/warehouses/abc'); + }); + + it('throws AuthenticationError when token is missing', () => { + const opts = { + host: 'example.cloud.databricks.com', + path: '/sql/1.0/warehouses/abc', + authType: 'access-token', + // no token + } as unknown as ConnectionOptions; + + expect(() => buildSeaConnectionOptions(opts)).to.throw(AuthenticationError, /non-empty PAT/); + }); + + it('throws AuthenticationError when token is an empty string', () => { + const opts: ConnectionOptions = { + host: 
'example.cloud.databricks.com', + path: '/sql/1.0/warehouses/abc', + token: '', + }; + + expect(() => buildSeaConnectionOptions(opts)).to.throw(AuthenticationError, /non-empty PAT/); + }); + + it('rejects OAuth with a clear M0-scope error', () => { + const opts: ConnectionOptions = { + host: 'example.cloud.databricks.com', + path: '/sql/1.0/warehouses/abc', + authType: 'databricks-oauth', + }; + + expect(() => buildSeaConnectionOptions(opts)).to.throw( + HiveDriverError, + /M0\) supports only PAT.*databricks-oauth.*M1/, + ); + }); + + it('rejects token-provider with a clear M0-scope error', () => { + const opts: ConnectionOptions = { + host: 'example.cloud.databricks.com', + path: '/sql/1.0/warehouses/abc', + authType: 'token-provider', + tokenProvider: { getToken: async () => 'tok' } as unknown as ConnectionOptions extends infer T + ? // eslint-disable-next-line @typescript-eslint/no-explicit-any + any + : never, + }; + + expect(() => buildSeaConnectionOptions(opts)).to.throw(HiveDriverError, /token-provider.*M1/); + }); + + it('rejects external-token, static-token, and custom auth modes', () => { + const authTypes = ['external-token', 'static-token', 'custom'] as const; + for (const authType of authTypes) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const opts = { + host: 'h', + path: '/p', + authType, + } as any; + expect(() => buildSeaConnectionOptions(opts)).to.throw(HiveDriverError, /M0\) supports only PAT/); + } + }); + }); + + // Note: SeaBackend.connect/openSession round-trip + error-path coverage + // moved to tests/unit/sea/execution.test.ts during the sea-integration + // merge (the execution branch's SeaBackend constructor signature + // {context, nativeBinding} supersedes the auth-only (binding) shape). 
+}); diff --git a/tests/unit/sea/error-mapping.test.ts b/tests/unit/sea/error-mapping.test.ts new file mode 100644 index 00000000..8331bc57 --- /dev/null +++ b/tests/unit/sea/error-mapping.test.ts @@ -0,0 +1,227 @@ +import { expect } from 'chai'; +import { + mapKernelErrorToJsError, + KernelErrorCode, + KernelErrorShape, +} from '../../../lib/sea/SeaErrorMapping'; +import HiveDriverError from '../../../lib/errors/HiveDriverError'; +import AuthenticationError from '../../../lib/errors/AuthenticationError'; +import OperationStateError, { + OperationStateErrorCode, +} from '../../../lib/errors/OperationStateError'; +import ParameterError from '../../../lib/errors/ParameterError'; + +describe('SeaErrorMapping.mapKernelErrorToJsError', () => { + // The 13 kernel ErrorCode variants — kept in sync with src/kernel_error.rs:66-134. + // Tabular driver: each row is (kernel code, expected class, optional extra assertion). + type Case = { + code: KernelErrorCode; + expectedClass: Function; + extra?: (err: Error) => void; + }; + + const cases: Array = [ + { + code: 'InvalidArgument', + expectedClass: ParameterError, + }, + { + code: 'Unauthenticated', + expectedClass: AuthenticationError, + }, + { + code: 'PermissionDenied', + expectedClass: AuthenticationError, + }, + { + code: 'NotFound', + expectedClass: HiveDriverError, + }, + { + code: 'ResourceExhausted', + expectedClass: HiveDriverError, + }, + { + code: 'Unavailable', + expectedClass: HiveDriverError, + }, + { + code: 'Timeout', + expectedClass: OperationStateError, + extra: (err) => { + expect((err as OperationStateError).errorCode).to.equal(OperationStateErrorCode.Timeout); + }, + }, + { + code: 'Cancelled', + expectedClass: OperationStateError, + extra: (err) => { + expect((err as OperationStateError).errorCode).to.equal(OperationStateErrorCode.Canceled); + }, + }, + { + code: 'DataLoss', + expectedClass: HiveDriverError, + }, + { + code: 'Internal', + expectedClass: HiveDriverError, + }, + { + code: 
'InvalidStatementHandle', + expectedClass: HiveDriverError, + }, + { + code: 'NetworkError', + expectedClass: HiveDriverError, + }, + { + code: 'SqlError', + expectedClass: HiveDriverError, + }, + ]; + + it('covers all 13 kernel ErrorCode variants', () => { + // Guardrail: if the kernel adds a variant, KernelErrorCode in TS will gain + // a literal — this test then fails because the new variant has no case row. + // (Drift is caught at the test level since the union itself is an inline literal.) + expect(cases).to.have.lengthOf(13); + }); + + cases.forEach(({ code, expectedClass, extra }) => { + it(`maps ${code} to ${expectedClass.name}`, () => { + const kErr: KernelErrorShape = { + code, + message: `kernel ${code} message`, + }; + + const err = mapKernelErrorToJsError(kErr); + + expect(err).to.be.instanceOf(expectedClass); + expect(err.message).to.equal(`kernel ${code} message`); + if (extra) { + extra(err); + } + }); + }); + + describe('SQLSTATE preservation', () => { + it('attaches sqlState when present on the kernel error', () => { + const err = mapKernelErrorToJsError({ + code: 'SqlError', + message: 'syntax error', + sqlstate: '42000', + }); + + expect(err).to.be.instanceOf(HiveDriverError); + expect(err.sqlState).to.equal('42000'); + }); + + it('does not set sqlState when absent', () => { + const err = mapKernelErrorToJsError({ + code: 'Internal', + message: 'boom', + }); + + expect(err.sqlState).to.be.undefined; + }); + + it('preserves sqlState on AuthenticationError', () => { + const err = mapKernelErrorToJsError({ + code: 'Unauthenticated', + message: 'invalid token', + sqlstate: '28000', + }); + + expect(err).to.be.instanceOf(AuthenticationError); + expect(err.sqlState).to.equal('28000'); + }); + + it('preserves sqlState on OperationStateError', () => { + const err = mapKernelErrorToJsError({ + code: 'Timeout', + message: 'deadline exceeded', + sqlstate: 'HYT01', + }); + + expect(err).to.be.instanceOf(OperationStateError); + expect((err as 
OperationStateError).errorCode).to.equal(OperationStateErrorCode.Timeout); + expect(err.sqlState).to.equal('HYT01'); + }); + + it('preserves sqlState on ParameterError', () => { + const err = mapKernelErrorToJsError({ + code: 'InvalidArgument', + message: 'bad param', + sqlstate: 'HY009', + }); + + expect(err).to.be.instanceOf(ParameterError); + expect(err.sqlState).to.equal('HY009'); + }); + + it('attaches sqlState as a non-enumerable property', () => { + const err = mapKernelErrorToJsError({ + code: 'SqlError', + message: 'oops', + sqlstate: '42000', + }); + + const descriptor = Object.getOwnPropertyDescriptor(err, 'sqlState'); + expect(descriptor).to.exist; + expect(descriptor!.enumerable).to.equal(false); + expect(descriptor!.writable).to.equal(true); + expect(descriptor!.configurable).to.equal(true); + }); + }); + + describe('unknown / future kernel codes', () => { + it('falls back to HiveDriverError for an unrecognised code', () => { + const err = mapKernelErrorToJsError({ + code: 'SomeFutureVariantThatDoesNotExist', + message: 'forward-compat message', + }); + + // Never silently drop — must surface as the base driver class. 
+ expect(err).to.be.instanceOf(HiveDriverError); + expect(err.message).to.equal('forward-compat message'); + }); + + it('still preserves sqlState on a fallback HiveDriverError', () => { + const err = mapKernelErrorToJsError({ + code: 'BrandNewVariant', + message: 'with sqlstate', + sqlstate: '01004', + }); + + expect(err).to.be.instanceOf(HiveDriverError); + expect(err.sqlState).to.equal('01004'); + }); + }); + + describe('returned errors compose with try/catch', () => { + it('thrown errors are catchable as Error', () => { + function thrower() { + throw mapKernelErrorToJsError({ code: 'Internal', message: 'kaboom' }); + } + + expect(thrower).to.throw(Error, 'kaboom'); + expect(thrower).to.throw(HiveDriverError, 'kaboom'); + }); + + it('AuthenticationError thrown is also instanceOf HiveDriverError', () => { + // AuthenticationError extends HiveDriverError — preserve that hierarchy. + const err = mapKernelErrorToJsError({ code: 'Unauthenticated', message: 'nope' }); + expect(err).to.be.instanceOf(AuthenticationError); + expect(err).to.be.instanceOf(HiveDriverError); + expect(err).to.be.instanceOf(Error); + }); + + it('ParameterError does NOT extend HiveDriverError (matches existing class hierarchy)', () => { + const err = mapKernelErrorToJsError({ code: 'InvalidArgument', message: 'bad' }); + expect(err).to.be.instanceOf(ParameterError); + expect(err).to.not.be.instanceOf(HiveDriverError); + expect(err).to.be.instanceOf(Error); + }); + }); +}); diff --git a/tests/unit/sea/execution.test.ts b/tests/unit/sea/execution.test.ts new file mode 100644 index 00000000..3d416716 --- /dev/null +++ b/tests/unit/sea/execution.test.ts @@ -0,0 +1,489 @@ +// Copyright (c) 2026 Databricks, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { expect } from 'chai'; +import sinon from 'sinon'; +import SeaBackend from '../../../lib/sea/SeaBackend'; +import SeaSessionBackend from '../../../lib/sea/SeaSessionBackend'; +import SeaOperationBackend from '../../../lib/sea/SeaOperationBackend'; +import { + SeaNativeBinding, + SeaNativeConnection, + SeaNativeStatement, + SeaExecuteOptions, +} from '../../../lib/sea/SeaNativeLoader'; +import IClientContext, { ClientConfig } from '../../../lib/contracts/IClientContext'; +import IDBSQLLogger, { LogLevel } from '../../../lib/contracts/IDBSQLLogger'; +import HiveDriverError from '../../../lib/errors/HiveDriverError'; +import { ConnectionOptions } from '../../../lib/contracts/IDBSQLClient'; + +// ----------------------------------------------------------------------------- +// Fakes — minimal stand-ins for the napi-rs generated surface and the +// IClientContext side of the abstraction. Keeping them inline avoids +// pulling in test-only fixtures from outside the sea/ namespace. +// ----------------------------------------------------------------------------- + +class FakeNativeStatement implements SeaNativeStatement { + public closed = false; + + public cancelled = false; + + // Mirrors the kernel `Statement.statementId` getter. + public readonly statementId = '01ef-fake-statement-id'; + + public async fetchNextBatch() { + return null; + } + + // schema() is synchronous on the merged-kernel binding. 
+ public schema() { + return { ipcBytes: Buffer.alloc(0) }; + } + + public async cancel() { + this.cancelled = true; + } + + public async close() { + this.closed = true; + } + + // Status accessors added by the kernel's status-fields surface. + public async numModifiedRows(): Promise { + return null; + } + + public async displayMessage(): Promise { + return null; + } + + public async diagnosticInfo(): Promise { + return null; + } + + public async errorDetailsJson(): Promise { + return null; + } +} + +class FakeNativeConnection implements SeaNativeConnection { + public closed = false; + + public lastSql?: string; + + public lastOptions?: SeaExecuteOptions; + + public throwOnExecute: Error | null = null; + + public statementToReturn: FakeNativeStatement = new FakeNativeStatement(); + + // Mirrors the kernel `Connection.sessionId` getter. + public readonly sessionId = '01ef-fake-session-id'; + + // `options` is optional so this stays structurally assignable to the + // merged binding's `executeStatement(sql)` while still recording any + // per-statement options the caller forwards (the kernel now applies + // those at session level — see the session-level options migration). + public async executeStatement(sql: string, options?: SeaExecuteOptions): Promise { + if (this.throwOnExecute) { + throw this.throwOnExecute; + } + this.lastSql = sql; + this.lastOptions = options; + return this.statementToReturn; + } + + public async close(): Promise { + this.closed = true; + } +} + +function makeBinding(connection: SeaNativeConnection): SeaNativeBinding & { + openSessionStub: sinon.SinonStub; +} { + const openSessionStub = sinon.stub().resolves(connection); + const binding: SeaNativeBinding = { + version: () => 'test', + openSession: openSessionStub, + // Index the binding type for the class constructor types; `typeof + // Connection` is illegal since they're exported as type aliases. 
+ Connection: function Connection() {} as unknown as SeaNativeBinding['Connection'], + Statement: function Statement() {} as unknown as SeaNativeBinding['Statement'], + }; + return Object.assign(binding, { openSessionStub }); +} + +function makeContext(): IClientContext { + const logger: IDBSQLLogger = { + log(_level: LogLevel, _message: string): void { + // no-op + }, + }; + const config = {} as ClientConfig; + return { + getConfig: () => config, + getLogger: () => logger, + getConnectionProvider: async () => { + throw new Error('not used by SEA backend'); + }, + getClient: async () => { + throw new Error('not used by SEA backend'); + }, + getDriver: async () => { + throw new Error('not used by SEA backend'); + }, + }; +} + +// ----------------------------------------------------------------------------- +// Tests +// ----------------------------------------------------------------------------- + +describe('SeaBackend', () => { + it('connect() captures the connection options and validates PAT auth', async () => { + const connection = new FakeNativeConnection(); + const binding = makeBinding(connection); + const backend = new SeaBackend({ context: makeContext(), nativeBinding: binding }); + + await backend.connect({ + host: 'example.databricks.com', + path: '/sql/1.0/warehouses/abc', + token: 'dapi-token', + } as ConnectionOptions); + + // openSession should not have been called by connect() + expect(binding.openSessionStub.called).to.equal(false); + }); + + it('connect() rejects non-PAT auth (M0 PAT-only)', async () => { + const connection = new FakeNativeConnection(); + const binding = makeBinding(connection); + const backend = new SeaBackend({ context: makeContext(), nativeBinding: binding }); + + let thrown: unknown; + try { + await backend.connect({ + host: 'example.databricks.com', + path: '/sql/1.0/warehouses/abc', + authType: 'databricks-oauth', + } as ConnectionOptions); + } catch (err) { + thrown = err; + } + 
expect(thrown).to.be.instanceOf(HiveDriverError); + expect((thrown as Error).message).to.match(/access-token/); + }); + + it('connect() rejects missing token', async () => { + const connection = new FakeNativeConnection(); + const binding = makeBinding(connection); + const backend = new SeaBackend({ context: makeContext(), nativeBinding: binding }); + + let thrown: unknown; + try { + await backend.connect({ + host: 'example.databricks.com', + path: '/sql/1.0/warehouses/abc', + token: '', + } as ConnectionOptions); + } catch (err) { + thrown = err; + } + expect(thrown).to.be.instanceOf(HiveDriverError); + // After sea-integration merge, missing-token validation goes through + // SeaAuth.buildSeaConnectionOptions which throws AuthenticationError + // (extends HiveDriverError) with the "non-empty PAT" message. + expect((thrown as Error).message).to.match(/non-empty PAT/); + }); + + it('openSession() throws if connect() was not called', async () => { + const connection = new FakeNativeConnection(); + const binding = makeBinding(connection); + const backend = new SeaBackend({ context: makeContext(), nativeBinding: binding }); + + let thrown: unknown; + try { + await backend.openSession({}); + } catch (err) { + thrown = err; + } + expect(thrown).to.be.instanceOf(HiveDriverError); + expect((thrown as Error).message).to.match(/not connected/); + }); + + it('openSession() forwards hostName / httpPath / token to napi binding', async () => { + const connection = new FakeNativeConnection(); + const binding = makeBinding(connection); + const backend = new SeaBackend({ context: makeContext(), nativeBinding: binding }); + + await backend.connect({ + host: 'workspace.example', + path: '/sql/1.0/warehouses/xyz', + token: 'dapi-token', + } as ConnectionOptions); + + await backend.openSession({}); + + expect(binding.openSessionStub.calledOnce).to.equal(true); + const args = binding.openSessionStub.firstCall.args[0]; + expect(args).to.deep.equal({ + hostName: 'workspace.example', + 
httpPath: '/sql/1.0/warehouses/xyz', + token: 'dapi-token', + }); + }); + + it('openSession() returns a SeaSessionBackend wrapping the napi Connection', async () => { + const connection = new FakeNativeConnection(); + const binding = makeBinding(connection); + const backend = new SeaBackend({ context: makeContext(), nativeBinding: binding }); + + await backend.connect({ + host: 'h', + path: '/p', + token: 't', + } as ConnectionOptions); + + const sessionBackend = await backend.openSession({}); + expect(sessionBackend).to.be.instanceOf(SeaSessionBackend); + expect(sessionBackend.id).to.be.a('string').and.have.length.greaterThan(0); + }); + + it('openSession() propagates initialCatalog / initialSchema / sessionConfig through to executeStatement', async () => { + const connection = new FakeNativeConnection(); + const binding = makeBinding(connection); + const backend = new SeaBackend({ context: makeContext(), nativeBinding: binding }); + + await backend.connect({ + host: 'h', + path: '/p', + token: 't', + } as ConnectionOptions); + + const session = await backend.openSession({ + initialCatalog: 'main', + initialSchema: 'default', + configuration: { 'spark.sql.execution.arrow.enabled': 'true' }, + }); + + await session.executeStatement('SELECT 1', {}); + + expect(connection.lastSql).to.equal('SELECT 1'); + expect(connection.lastOptions).to.deep.equal({ + initialCatalog: 'main', + initialSchema: 'default', + sessionConfig: { 'spark.sql.execution.arrow.enabled': 'true' }, + }); + }); + + it('close() clears connection state without throwing', async () => { + const connection = new FakeNativeConnection(); + const binding = makeBinding(connection); + const backend = new SeaBackend({ context: makeContext(), nativeBinding: binding }); + await backend.connect({ host: 'h', path: '/p', token: 't' } as ConnectionOptions); + await backend.close(); + + let thrown: unknown; + try { + await backend.openSession({}); + } catch (err) { + thrown = err; + } + 
expect(thrown).to.be.instanceOf(HiveDriverError); + }); +}); + +describe('SeaSessionBackend', () => { + function makeSession(connection: SeaNativeConnection, defaults = {}) { + return new SeaSessionBackend({ connection, context: makeContext(), defaults }); + } + + it('executeStatement passes sql through verbatim', async () => { + const connection = new FakeNativeConnection(); + const session = makeSession(connection); + await session.executeStatement('SELECT * FROM foo', {}); + expect(connection.lastSql).to.equal('SELECT * FROM foo'); + }); + + it('executeStatement returns a SeaOperationBackend with an id', async () => { + const connection = new FakeNativeConnection(); + const session = makeSession(connection); + const op = await session.executeStatement('SELECT 1', {}); + expect(op).to.be.instanceOf(SeaOperationBackend); + expect(op.id).to.be.a('string').and.have.length.greaterThan(0); + }); + + it('executeStatement merges session defaults into ExecuteOptions', async () => { + const connection = new FakeNativeConnection(); + const session = makeSession(connection, { + initialCatalog: 'main', + initialSchema: 'default', + sessionConfig: { foo: 'bar' }, + }); + await session.executeStatement('SELECT 1', {}); + expect(connection.lastOptions).to.deep.equal({ + initialCatalog: 'main', + initialSchema: 'default', + sessionConfig: { foo: 'bar' }, + }); + }); + + it('executeStatement rejects namedParameters (M1)', async () => { + const connection = new FakeNativeConnection(); + const session = makeSession(connection); + let thrown: unknown; + try { + await session.executeStatement('SELECT :x', { namedParameters: { x: 1 } }); + } catch (err) { + thrown = err; + } + expect(thrown).to.be.instanceOf(HiveDriverError); + expect((thrown as Error).message).to.match(/parameters/); + }); + + it('executeStatement rejects ordinalParameters (M1)', async () => { + const connection = new FakeNativeConnection(); + const session = makeSession(connection); + let thrown: unknown; + try { + 
await session.executeStatement('SELECT ?', { ordinalParameters: [1] }); + } catch (err) { + thrown = err; + } + expect(thrown).to.be.instanceOf(HiveDriverError); + }); + + it('executeStatement rejects queryTimeout (M1)', async () => { + const connection = new FakeNativeConnection(); + const session = makeSession(connection); + let thrown: unknown; + try { + await session.executeStatement('SELECT 1', { queryTimeout: 30 }); + } catch (err) { + thrown = err; + } + expect(thrown).to.be.instanceOf(HiveDriverError); + expect((thrown as Error).message).to.match(/queryTimeout/); + }); + + it('metadata methods throw deferred-M1 errors', async () => { + const connection = new FakeNativeConnection(); + const session = makeSession(connection); + for (const method of [ + 'getInfo', + 'getTypeInfo', + 'getCatalogs', + 'getSchemas', + 'getTables', + 'getTableTypes', + 'getColumns', + 'getFunctions', + 'getPrimaryKeys', + 'getCrossReference', + ] as const) { + let thrown: unknown; + try { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await (session as any)[method]({}); + } catch (err) { + thrown = err; + } + expect(thrown, `expected ${method} to throw`).to.be.instanceOf(HiveDriverError); + expect((thrown as Error).message).to.match(/M1|not implemented/); + } + }); + + it('close() forwards to the native connection', async () => { + const connection = new FakeNativeConnection(); + const session = makeSession(connection); + const status = await session.close(); + expect(connection.closed).to.equal(true); + expect(status.isSuccess).to.equal(true); + }); + + it('close() is idempotent', async () => { + const connection = new FakeNativeConnection(); + const session = makeSession(connection); + await session.close(); + // Second call should not re-invoke connection.close + connection.closed = false; + const status = await session.close(); + expect(connection.closed).to.equal(false); + expect(status.isSuccess).to.equal(true); + }); + + it('executeStatement fails after 
close()', async () => { + const connection = new FakeNativeConnection(); + const session = makeSession(connection); + await session.close(); + let thrown: unknown; + try { + await session.executeStatement('SELECT 1', {}); + } catch (err) { + thrown = err; + } + expect(thrown).to.be.instanceOf(HiveDriverError); + }); +}); + +describe('SeaOperationBackend', () => { + function makeOperation(statement: SeaNativeStatement = new FakeNativeStatement()) { + return new SeaOperationBackend({ statement, context: makeContext() }); + } + + it('id is a stable string', () => { + const op = makeOperation(); + expect(op.id).to.equal(op.id); + expect(op.id).to.be.a('string').and.have.length.greaterThan(0); + }); + + it('hasResultSet is true for M0', () => { + const op = makeOperation(); + expect(op.hasResultSet).to.equal(true); + }); + + it('cancel() forwards to napi Statement', async () => { + const stmt = new FakeNativeStatement(); + const op = makeOperation(stmt); + await op.cancel(); + expect(stmt.cancelled).to.equal(true); + }); + + it('cancel() is idempotent', async () => { + const stmt = new FakeNativeStatement(); + const op = makeOperation(stmt); + await op.cancel(); + stmt.cancelled = false; + await op.cancel(); + expect(stmt.cancelled).to.equal(false); + }); + + it('close() forwards to napi Statement', async () => { + const stmt = new FakeNativeStatement(); + const op = makeOperation(stmt); + await op.close(); + expect(stmt.closed).to.equal(true); + }); + + it('waitUntilReady() is a no-op (kernel internalises polling)', async () => { + const op = makeOperation(); + await op.waitUntilReady(); + }); + + // Note: after sea-integration merge, fetchChunk is no longer a stub — + // the sea-results SeaResultsProvider + ArrowResultConverter pipeline + // implements the real fetch path. Full coverage lives in + // tests/unit/sea/SeaOperationBackend.test.ts and the parity-gate e2e + // at tests/integration/sea/results-e2e.test.ts. 
+}); diff --git a/tests/unit/sea/loader.test.ts b/tests/unit/sea/loader.test.ts new file mode 100644 index 00000000..39bf610f --- /dev/null +++ b/tests/unit/sea/loader.test.ts @@ -0,0 +1,149 @@ +// Copyright (c) 2026 Databricks, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { expect } from 'chai'; +import { SeaNativeLoader, SeaNativeBinding } from '../../../lib/sea/SeaNativeLoader'; + +// Pure-logic tests for SeaNativeLoader. These exercise the load-failure +// hint branches, the Node-version gate, the shape check, and caching via +// the injectable `load` seam — so they run everywhere regardless of +// whether a real `.node` is installed on the test machine. + +function stubBinding(overrides: Partial<Record<keyof SeaNativeBinding, unknown>> = {}): SeaNativeBinding { + return { + version: () => '1.2.3', + openSession: async () => ({}), + Connection: function Connection() {}, + Statement: function Statement() {}, + ...overrides, + } as unknown as SeaNativeBinding; +} + +function errWithCode(code: string, message: string): NodeJS.ErrnoException { + const err = new Error(message) as NodeJS.ErrnoException; + err.code = code; + return err; +} + +// Capture the message of the error thrown by `fn` (fails the test if +// nothing is thrown). Lets a single failure be asserted against several +// substrings without chai's `.and.to.throw` re-targeting quirk. +function thrownMessage(fn: () => unknown): string { + try { + fn(); + } catch (err) { + return err instanceof Error ? 
err.message : String(err); + } + return expect.fail('expected the call to throw, but it did not') as never; +} + +describe('SeaNativeLoader', () => { + describe('successful load', () => { + it('get() returns the binding from the injected loader', () => { + const binding = stubBinding(); + const loader = new SeaNativeLoader(() => binding); + expect(loader.get()).to.equal(binding); + expect(loader.tryGet()).to.equal(binding); + }); + + it('caches the result — the load function runs at most once', () => { + let calls = 0; + const binding = stubBinding(); + const loader = new SeaNativeLoader(() => { + calls += 1; + return binding; + }); + loader.get(); + loader.tryGet(); + loader.get(); + expect(calls).to.equal(1); + }); + }); + + describe('load-failure hints', () => { + it('MODULE_NOT_FOUND → "not installed" hint pointing at the README', () => { + const loader = new SeaNativeLoader(() => { + throw errWithCode('MODULE_NOT_FOUND', "Cannot find module '../../native/sea'"); + }); + expect(loader.tryGet()).to.equal(undefined); + const msg = thrownMessage(() => loader.get()); + expect(msg).to.match(/not installed/); + expect(msg).to.match(/README/); + }); + + it('ERR_DLOPEN_FAILED → includes the underlying dlerror string and remediation', () => { + const loader = new SeaNativeLoader(() => { + throw errWithCode('ERR_DLOPEN_FAILED', 'GLIBC_2.32 not found'); + }); + const msg = thrownMessage(() => loader.get()); + expect(msg).to.match(/GLIBC_2\.32 not found/); + expect(msg).to.match(/musl/); + expect(msg).to.match(/rm -rf node_modules/); + }); + + it('a generic Error (no code) preserves its message', () => { + const loader = new SeaNativeLoader(() => { + throw new Error('totally unexpected'); + }); + expect(() => loader.get()).to.throw(/totally unexpected/); + }); + + it('a non-Error throw is wrapped', () => { + const loader = new SeaNativeLoader(() => { + // eslint-disable-next-line no-throw-literal + throw 'a string'; + }); + expect(() => loader.get()).to.throw(/non-standard 
error/); + }); + }); + + describe('shape check', () => { + it('rejects a binding missing an expected export', () => { + const loader = new SeaNativeLoader(() => stubBinding({ openSession: undefined })); + expect(loader.tryGet()).to.equal(undefined); + const msg = thrownMessage(() => loader.get()); + expect(msg).to.match(/missing expected export/); + expect(msg).to.match(/openSession/); + }); + }); + + describe('Node-version gate', () => { + it('fails closed on a Node version below the floor', () => { + const original = process.version; + try { + Object.defineProperty(process, 'version', { value: 'v16.20.0', configurable: true }); + let loadCalled = false; + const loader = new SeaNativeLoader(() => { + loadCalled = true; + return stubBinding(); + }); + expect(() => loader.get()).to.throw(/requires Node >=18/); + expect(loadCalled, 'load() must not be attempted on an unsupported Node').to.equal(false); + } finally { + Object.defineProperty(process, 'version', { value: original, configurable: true }); + } + }); + + it('fails closed when the Node version is unparseable (NaN)', () => { + const original = process.version; + try { + Object.defineProperty(process, 'version', { value: 'vNOT-A-VERSION', configurable: true }); + const loader = new SeaNativeLoader(() => stubBinding()); + expect(() => loader.get()).to.throw(/requires Node >=18/); + } finally { + Object.defineProperty(process, 'version', { value: original, configurable: true }); + } + }); + }); +}); diff --git a/tests/unit/sea/operation-lifecycle.test.ts b/tests/unit/sea/operation-lifecycle.test.ts new file mode 100644 index 00000000..7f93184c --- /dev/null +++ b/tests/unit/sea/operation-lifecycle.test.ts @@ -0,0 +1,449 @@ +// Copyright (c) 2026 Databricks, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * Unit tests for the SEA operation lifecycle (`cancel`, `close`, + * `finished`) — both via the `SeaOperationLifecycle` helpers and + * via `SeaOperationBackend` which composes them. + * + * We mock the napi binding's `Statement` handle so the test process + * doesn't touch any native code; the helpers and the backend are + * structurally typed against `SeaStatementHandle` exactly so this + * works. + */ + +import { expect } from 'chai'; +import sinon from 'sinon'; +import { + TOperationState, + TStatusCode, + TGetOperationStatusResp, +} from '../../../thrift/TCLIService_types'; +import IClientContext from '../../../lib/contracts/IClientContext'; +import IDBSQLLogger, { LogLevel } from '../../../lib/contracts/IDBSQLLogger'; +import { + SeaStatementHandle, + createLifecycleState, + seaCancel, + seaClose, + seaFinished, + failIfNotActive, +} from '../../../lib/sea/SeaOperationLifecycle'; +import SeaOperationBackend from '../../../lib/sea/SeaOperationBackend'; +import OperationStateError, { + OperationStateErrorCode, +} from '../../../lib/errors/OperationStateError'; +import { OperationState } from '../../../lib/contracts/OperationStatus'; +import HiveDriverError from '../../../lib/errors/HiveDriverError'; + +class TestLogger implements IDBSQLLogger { + public readonly entries: Array<{ level: LogLevel; message: string }> = []; + + log(level: LogLevel, message: string): void { + this.entries.push({ level, message }); + } +} + +function makeContext(): IClientContext { + const logger = new TestLogger(); + // Only `getLogger` is exercised by the 
lifecycle helpers; the rest + // of `IClientContext` is stubbed to throw so accidental coupling + // to it shows up loudly in tests. + const notUsed = () => { + throw new Error('IClientContext member not expected to be used by lifecycle'); + }; + return { + getConfig: notUsed, + getLogger: () => logger, + getConnectionProvider: notUsed, + getClient: notUsed, + getDriver: notUsed, + } as unknown as IClientContext; +} + +function makeStatement(overrides: Partial<SeaStatementHandle> = {}): { + handle: SeaStatementHandle; + cancel: sinon.SinonStub; + close: sinon.SinonStub; +} { + const cancel = sinon.stub().resolves(); + const close = sinon.stub().resolves(); + return { + handle: { cancel, close, ...overrides }, + cancel, + close, + }; +} + +describe('SeaOperationLifecycle (helpers)', () => { + describe('seaCancel', () => { + it('calls statement.cancel() and resolves with a success Status', async () => { + const ctx = makeContext(); + const { handle, cancel } = makeStatement(); + const state = createLifecycleState(); + + const status = await seaCancel(state, handle, ctx, 'op-id-1'); + + expect(cancel.calledOnce).to.equal(true); + expect(status.isSuccess).to.equal(true); + expect(state.isCancelled).to.equal(true); + }); + + it('is idempotent — second call does not hit the binding', async () => { + const ctx = makeContext(); + const { handle, cancel } = makeStatement(); + const state = createLifecycleState(); + + await seaCancel(state, handle, ctx, 'op-id-2'); + await seaCancel(state, handle, ctx, 'op-id-2'); + + expect(cancel.calledOnce).to.equal(true); + }); + + it('short-circuits when the operation is already closed', async () => { + const ctx = makeContext(); + const { handle, cancel } = makeStatement(); + const state = createLifecycleState(); + state.isClosed = true; + + const status = await seaCancel(state, handle, ctx, 'op-id-3'); + + expect(cancel.called).to.equal(false); + expect(status.isSuccess).to.equal(true); + }); + + it('sets isCancelled BEFORE awaiting the binding (so 
concurrent fetch sees it)', async () => { + const ctx = makeContext(); + const state = createLifecycleState(); + + // Cancel returns a promise that resolves only when we say so. + let release: (() => void) | undefined; + const cancelPromise = new Promise<void>((resolve) => { + release = resolve; + }); + const handle: SeaStatementHandle = { + cancel: () => cancelPromise, + close: async () => undefined, + }; + + const inflight = seaCancel(state, handle, ctx, 'op-id-4'); + + // Yield once so the synchronous prelude of seaCancel runs. + await Promise.resolve(); + expect(state.isCancelled).to.equal(true); + // Before the await resolves, failIfNotActive must already throw. + expect(() => failIfNotActive(state)).to.throw(); + + release!(); + const status = await inflight; + expect(status.isSuccess).to.equal(true); + }); + + it('propagates binding errors via the kernel error mapping', async () => { + const ctx = makeContext(); + const state = createLifecycleState(); + const handle: SeaStatementHandle = { + cancel: async () => { + // Simulate the binding's JSON-envelope error format. 
+ const payload = JSON.stringify({ + code: 'InvalidStatementHandle', + message: 'statement already closed', + }); + throw new Error(`__databricks_error__:${payload}`); + }, + close: async () => undefined, + }; + + let thrown: unknown; + try { + await seaCancel(state, handle, ctx, 'op-err-1'); + } catch (err) { + thrown = err; + } + expect(thrown).to.be.instanceOf(HiveDriverError); + expect((thrown as Error).message).to.contain('statement already closed'); + }); + + it('logs a debug message tagged with the operation id', async () => { + const ctx = makeContext(); + const logger = ctx.getLogger() as TestLogger; + const { handle } = makeStatement(); + const state = createLifecycleState(); + + await seaCancel(state, handle, ctx, 'op-id-log'); + + expect( + logger.entries.some( + (e) => e.level === LogLevel.debug && e.message.includes('op-id-log'), + ), + ).to.equal(true); + }); + }); + + describe('seaClose', () => { + it('calls statement.close() and resolves with a success Status', async () => { + const ctx = makeContext(); + const { handle, close } = makeStatement(); + const state = createLifecycleState(); + + const status = await seaClose(state, handle, ctx, 'op-close-1'); + + expect(close.calledOnce).to.equal(true); + expect(status.isSuccess).to.equal(true); + expect(state.isClosed).to.equal(true); + }); + + it('is idempotent — second call does not hit the binding', async () => { + const ctx = makeContext(); + const { handle, close } = makeStatement(); + const state = createLifecycleState(); + + await seaClose(state, handle, ctx, 'op-close-2'); + await seaClose(state, handle, ctx, 'op-close-2'); + + expect(close.calledOnce).to.equal(true); + }); + + it('propagates binding errors via the kernel error mapping', async () => { + const ctx = makeContext(); + const state = createLifecycleState(); + const handle: SeaStatementHandle = { + cancel: async () => undefined, + close: async () => { + const payload = JSON.stringify({ + code: 'NetworkError', + message: 'connection 
reset by peer', + }); + throw new Error(`__databricks_error__:${payload}`); + }, + }; + + let thrown: unknown; + try { + await seaClose(state, handle, ctx, 'op-err-close'); + } catch (err) { + thrown = err; + } + expect(thrown).to.be.instanceOf(HiveDriverError); + expect((thrown as Error).message).to.contain('connection reset'); + }); + }); + + describe('seaFinished', () => { + it('resolves immediately when no callback is provided (M0 no-op)', async () => { + const state = createLifecycleState(); + const start = Date.now(); + await seaFinished(state); + // Should be near-instantaneous — no 100ms poll. + expect(Date.now() - start).to.be.lessThan(50); + }); + + it('invokes the progress callback exactly once with a FINISHED status', async () => { + const state = createLifecycleState(); + const callback = sinon.stub(); + + await seaFinished(state, { callback }); + + expect(callback.calledOnce).to.equal(true); + const arg = callback.firstCall.args[0] as TGetOperationStatusResp; + expect(arg.operationState).to.equal(TOperationState.FINISHED_STATE); + expect(arg.status?.statusCode).to.equal(TStatusCode.SUCCESS_STATUS); + }); + + it('awaits an async progress callback', async () => { + const state = createLifecycleState(); + let resolvedInsideCallback = false; + const callback = async () => { + await new Promise((r) => setTimeout(r, 10)); + resolvedInsideCallback = true; + }; + + await seaFinished(state, { callback }); + + expect(resolvedInsideCallback).to.equal(true); + }); + + it('is a no-op when the operation is already cancelled', async () => { + const state = createLifecycleState(); + state.isCancelled = true; + const callback = sinon.stub(); + + await seaFinished(state, { callback }); + + expect(callback.called).to.equal(false); + }); + }); + + describe('failIfNotActive', () => { + it('throws OperationStateError(Canceled) when cancelled', () => { + const state = createLifecycleState(); + state.isCancelled = true; + // The kernel-error mapping routes Cancelled → 
OperationStateError. + try { + failIfNotActive(state); + expect.fail('expected throw'); + } catch (err) { + expect(err).to.be.instanceOf(OperationStateError); + expect((err as OperationStateError).errorCode).to.equal( + OperationStateErrorCode.Canceled, + ); + } + }); + + it('throws OperationStateError(Closed) when closed', () => { + const state = createLifecycleState(); + state.isClosed = true; + try { + failIfNotActive(state); + expect.fail('expected throw'); + } catch (err) { + expect(err).to.be.instanceOf(OperationStateError); + expect((err as OperationStateError).errorCode).to.equal( + OperationStateErrorCode.Closed, + ); + } + }); + + it('does nothing when active', () => { + const state = createLifecycleState(); + // Should not throw. + failIfNotActive(state); + }); + }); +}); + +describe('SeaOperationBackend (lifecycle integration)', () => { + it('cancel() forwards to statement.cancel()', async () => { + const ctx = makeContext(); + const { handle, cancel } = makeStatement(); + const op = new SeaOperationBackend({ statement: handle, context: ctx }); + + const status = await op.cancel(); + + expect(cancel.calledOnce).to.equal(true); + expect(status.isSuccess).to.equal(true); + }); + + it('close() forwards to statement.close()', async () => { + const ctx = makeContext(); + const { handle, close } = makeStatement(); + const op = new SeaOperationBackend({ statement: handle, context: ctx }); + + const status = await op.close(); + + expect(close.calledOnce).to.equal(true); + expect(status.isSuccess).to.equal(true); + }); + + it('finished() resolves immediately and fires the callback once', async () => { + const ctx = makeContext(); + const { handle } = makeStatement(); + const op = new SeaOperationBackend({ statement: handle, context: ctx }); + + const responses: TGetOperationStatusResp[] = []; + const start = Date.now(); + await op.waitUntilReady({ callback: (r) => responses.push(r) }); + + expect(Date.now() - start).to.be.lessThan(50); + 
expect(responses).to.have.length(1); + expect(responses[0].operationState).to.equal(TOperationState.FINISHED_STATE); + }); + + it('fetchChunk after cancel throws the cancellation error', async () => { + const ctx = makeContext(); + const { handle } = makeStatement(); + const op = new SeaOperationBackend({ statement: handle, context: ctx }); + + await op.cancel(); + + let thrown: unknown; + try { + await op.fetchChunk({ limit: 10 }); + } catch (err) { + thrown = err; + } + expect(thrown).to.be.instanceOf(OperationStateError); + expect((thrown as OperationStateError).errorCode).to.equal( + OperationStateErrorCode.Canceled, + ); + }); + + it('cancel() is idempotent across the backend surface', async () => { + const ctx = makeContext(); + const { handle, cancel } = makeStatement(); + const op = new SeaOperationBackend({ statement: handle, context: ctx }); + + await op.cancel(); + await op.cancel(); + await op.cancel(); + + expect(cancel.calledOnce).to.equal(true); + }); + + it('close() is idempotent across the backend surface', async () => { + const ctx = makeContext(); + const { handle, close } = makeStatement(); + const op = new SeaOperationBackend({ statement: handle, context: ctx }); + + await op.close(); + await op.close(); + + expect(close.calledOnce).to.equal(true); + }); + + it('status() reports Succeeded when active', async () => { + const ctx = makeContext(); + const { handle } = makeStatement(); + const op = new SeaOperationBackend({ statement: handle, context: ctx }); + + const status = await op.status(false); + expect(status.state).to.equal(OperationState.Succeeded); + }); + + it('status() reports Cancelled after cancel', async () => { + const ctx = makeContext(); + const { handle } = makeStatement(); + const op = new SeaOperationBackend({ statement: handle, context: ctx }); + + await op.cancel(); + const status = await op.status(false); + expect(status.state).to.equal(OperationState.Cancelled); + }); + + it('id getter is stable', () => { + const ctx = 
makeContext(); + const { handle } = makeStatement(); + const op = new SeaOperationBackend({ statement: handle, context: ctx, id: 'fixed-id' }); + + expect(op.id).to.equal('fixed-id'); + expect(op.id).to.equal('fixed-id'); + }); + + it('id getter defaults to a uuid when none is supplied', () => { + const ctx = makeContext(); + const { handle } = makeStatement(); + const op = new SeaOperationBackend({ statement: handle, context: ctx }); + + // RFC4122 v4 — 36 chars with hyphens at positions 8/13/18/23. + expect(op.id).to.match(/^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[0-9a-f]{4}-[0-9a-f]{12}$/); + }); + + it('hasResultSet is true by default (kernel always streams)', () => { + const ctx = makeContext(); + const { handle } = makeStatement(); + const op = new SeaOperationBackend({ statement: handle, context: ctx }); + + expect(op.hasResultSet).to.equal(true); + }); +}); diff --git a/tests/unit/sea/version.test.ts b/tests/unit/sea/version.test.ts new file mode 100644 index 00000000..a6c8c1fc --- /dev/null +++ b/tests/unit/sea/version.test.ts @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Databricks, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { expect } from 'chai'; +import { tryGetSeaNative } from '../../../lib/sea/SeaNativeLoader'; + +// On a CI runner whose triple is supposed to have a published binding +// (M0 = linux-x64-gnu) a missing binding is a hard failure — a silent +// skip there would mask a broken build / packaging regression. 
On every +// other platform (and on dev machines) the binding is optional, so we +// skip. +function bindingIsExpected(): boolean { + return process.env.CI === 'true' && process.platform === 'linux' && process.arch === 'x64'; +} + +describe('SEA native binding — smoke test', function smoke() { + const binding = tryGetSeaNative(); + + if (binding === undefined) { + if (bindingIsExpected()) { + it('fails loudly: the binding must load on the linux-x64 CI runner', () => { + expect.fail( + 'SEA native binding failed to load on a linux-x64 CI runner where ' + + '@databricks/sql-kernel-linux-x64-gnu is expected. Run `npm run build:native` or check packaging.', + ); + }); + return; + } + // Optional dependency absent on this platform — skip rather than fail. + // eslint-disable-next-line no-invalid-this + this.pending = true; + it.skip('SEA native binding not available on this platform'); + return; + } + + it('returns a semver version()', () => { + expect(binding.version()).to.match(/^\d+\.\d+\.\d+$/); + }); + + it('exposes the full binding surface the driver depends on', () => { + // Guards against kernel-side renames: if the kernel drops/renames a + // free function or class, this fails instead of staying green. + expect(binding.version, 'version()').to.be.a('function'); + expect(binding.openSession, 'openSession()').to.be.a('function'); + expect(binding.Connection, 'Connection class').to.be.a('function'); + expect(binding.Statement, 'Statement class').to.be.a('function'); + }); +}); diff --git a/tsconfig.json b/tsconfig.json index 9da406df..767f4166 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -6,7 +6,8 @@ "sourceMap": true, "strict": true, "esModuleInterop": true, - "forceConsistentCasingInFileNames": true + "forceConsistentCasingInFileNames": true, + "baseUrl": "./" }, "exclude": ["./dist/**/*"] }