Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
7 changes: 6 additions & 1 deletion 7 src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

logger = logging.getLogger(__name__)

DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760
DEFAULT_RESULT_BUFFER_SIZE_BYTES = 104857600
Comment thread
mattdeekay marked this conversation as resolved.
DEFAULT_ARRAY_SIZE = 100000


Expand Down Expand Up @@ -153,6 +153,8 @@ def read(self) -> Optional[OAuthToken]:
# _use_arrow_native_timestamps
# Databricks runtime will return native Arrow types for timestamps instead of Arrow strings
# (True by default)
# use_cloud_fetch
# Enable use of cloud fetch to extract large query results in parallel via cloud storage

if access_token:
access_token_kv = {"access_token": access_token}
Expand Down Expand Up @@ -189,6 +191,7 @@ def read(self) -> Optional[OAuthToken]:
self._session_handle = self.thrift_backend.open_session(
session_configuration, catalog, schema
)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", False)
self.open = True
logger.info("Successfully opened session " + str(self.get_session_id_hex()))
self._cursors = [] # type: List[Cursor]
Expand Down Expand Up @@ -497,6 +500,7 @@ def execute(
max_bytes=self.buffer_size_bytes,
lz4_compression=self.connection.lz4_compression,
cursor=self,
use_cloud_fetch=self.connection.use_cloud_fetch,
)
self.active_result_set = ResultSet(
self.connection,
Expand Down Expand Up @@ -822,6 +826,7 @@ def __iter__(self):
break

def _fill_results_buffer(self):
# At initialization or if the server does not have cloud fetch result links available
results, has_more_rows = self.thrift_backend.fetch_results(
op_handle=self.command_id,
max_rows=self.arraysize,
Expand Down
4 changes: 2 additions & 2 deletions 4 src/databricks/sql/cloudfetch/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,6 @@ def _check_if_download_successful(self, handler: ResultSetDownloadHandler):
return True

def _shutdown_manager(self):
# Clear download handlers and shutdown the thread pool to cancel pending futures
# Clear download handlers and shutdown the thread pool
self.download_handlers = []
self.thread_pool.shutdown(wait=False, cancel_futures=True)
self.thread_pool.shutdown(wait=False)
151 changes: 39 additions & 112 deletions 151 src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import time
import uuid
import threading
import lz4.frame
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
from typing import List, Union

Expand All @@ -26,11 +25,14 @@
)

from databricks.sql.utils import (
ArrowQueue,
ExecuteResponse,
_bound,
RequestErrorInfo,
NoRetryReason,
ResultSetQueueFactory,
convert_arrow_based_set_to_arrow_table,
convert_decimals_in_arrow_table,
convert_column_based_set_to_arrow_table,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,7 +69,6 @@
class ThriftBackend:
CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE
ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE
BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]

def __init__(
self,
Expand Down Expand Up @@ -115,6 +116,8 @@ def __init__(
# _socket_timeout
# The timeout in seconds for socket send, recv and connect operations. Should be a positive float or integer.
# (defaults to 900)
# max_download_threads
# Number of threads for handling cloud fetch downloads. Defaults to 10

port = port or 443
if kwargs.get("_connection_uri"):
Expand All @@ -136,6 +139,9 @@ def __init__(
"_use_arrow_native_timestamps", True
)

# Cloud fetch
self.max_download_threads = kwargs.get("max_download_threads", 10)
Comment thread
mattdeekay marked this conversation as resolved.

# Configure tls context
ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
if kwargs.get("_tls_no_verify") is True:
Expand Down Expand Up @@ -558,108 +564,14 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
(
arrow_table,
num_rows,
) = ThriftBackend._convert_column_based_set_to_arrow_table(
t_row_set.columns, description
)
) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
elif t_row_set.arrowBatches is not None:
(
arrow_table,
num_rows,
) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
(arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
t_row_set.arrowBatches, lz4_compressed, schema_bytes
)
else:
raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
return self._convert_decimals_in_arrow_table(arrow_table, description), num_rows

@staticmethod
def _convert_decimals_in_arrow_table(table, description):
for (i, col) in enumerate(table.itercolumns()):
if description[i][1] == "decimal":
decimal_col = col.to_pandas().apply(
lambda v: v if v is None else Decimal(v)
)
precision, scale = description[i][4], description[i][5]
assert scale is not None
assert precision is not None
# Spark limits decimal to a maximum scale of 38,
# so 128 is guaranteed to be big enough
dtype = pyarrow.decimal128(precision, scale)
col_data = pyarrow.array(decimal_col, type=dtype)
field = table.field(i).with_type(dtype)
table = table.set_column(i, field, col_data)
return table

@staticmethod
def _convert_arrow_based_set_to_arrow_table(
arrow_batches, lz4_compressed, schema_bytes
):
ba = bytearray()
ba += schema_bytes
n_rows = 0
if lz4_compressed:
for arrow_batch in arrow_batches:
n_rows += arrow_batch.rowCount
ba += lz4.frame.decompress(arrow_batch.batch)
else:
for arrow_batch in arrow_batches:
n_rows += arrow_batch.rowCount
ba += arrow_batch.batch
arrow_table = pyarrow.ipc.open_stream(ba).read_all()
return arrow_table, n_rows

@staticmethod
def _convert_column_based_set_to_arrow_table(columns, description):
arrow_table = pyarrow.Table.from_arrays(
[ThriftBackend._convert_column_to_arrow_array(c) for c in columns],
# Only use the column names from the schema, the types are determined by the
# physical types used in column based set, as they can differ from the
# mapping used in _hive_schema_to_arrow_schema.
names=[c[0] for c in description],
)
return arrow_table, arrow_table.num_rows

@staticmethod
def _convert_column_to_arrow_array(t_col):
"""
Return a pyarrow array from the values in a TColumn instance.
Note that ColumnBasedSet has no native support for complex types, so they will be converted
to strings server-side.
"""
field_name_to_arrow_type = {
"boolVal": pyarrow.bool_(),
"byteVal": pyarrow.int8(),
"i16Val": pyarrow.int16(),
"i32Val": pyarrow.int32(),
"i64Val": pyarrow.int64(),
"doubleVal": pyarrow.float64(),
"stringVal": pyarrow.string(),
"binaryVal": pyarrow.binary(),
}
for field in field_name_to_arrow_type.keys():
wrapper = getattr(t_col, field)
if wrapper:
return ThriftBackend._create_arrow_array(
wrapper, field_name_to_arrow_type[field]
)

raise OperationalError("Empty TColumn instance {}".format(t_col))

@staticmethod
def _create_arrow_array(t_col_value_wrapper, arrow_type):
result = t_col_value_wrapper.values
nulls = t_col_value_wrapper.nulls # bitfield describing which values are null
assert isinstance(nulls, bytes)

# The number of bits in nulls can be both larger or smaller than the number of
# elements in result, so take the minimum of both to iterate over.
length = min(len(result), len(nulls) * 8)

for i in range(length):
if nulls[i >> 3] & ThriftBackend.BIT_MASKS[i & 0x7]:
result[i] = None

return pyarrow.array(result, type=arrow_type)
return convert_decimals_in_arrow_table(arrow_table, description), num_rows

def _get_metadata_resp(self, op_handle):
req = ttypes.TGetResultSetMetadataReq(operationHandle=op_handle)
Expand Down Expand Up @@ -752,6 +664,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
if t_result_set_metadata_resp.resultFormat not in [
ttypes.TSparkRowSetType.ARROW_BASED_SET,
ttypes.TSparkRowSetType.COLUMN_BASED_SET,
ttypes.TSparkRowSetType.URL_BASED_SET,
]:
raise OperationalError(
"Expected results to be in Arrow or column based format, "
Expand Down Expand Up @@ -783,13 +696,14 @@ def _results_message_to_execute_response(self, resp, operation_state):
assert direct_results.resultSet.results.startRowOffset == 0
assert direct_results.resultSetMetadata

arrow_results, n_rows = self._create_arrow_table(
direct_results.resultSet.results,
lz4_compressed,
schema_bytes,
description,
arrow_queue_opt = ResultSetQueueFactory.build_queue(
row_set_type=t_result_set_metadata_resp.resultFormat,
t_row_set=direct_results.resultSet.results,
arrow_schema_bytes=schema_bytes,
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
)
arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
else:
arrow_queue_opt = None
return ExecuteResponse(
Expand Down Expand Up @@ -843,7 +757,14 @@ def _check_direct_results_for_error(t_spark_direct_results):
)

def execute_command(
self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor
self,
operation,
session_handle,
max_rows,
max_bytes,
lz4_compression,
cursor,
use_cloud_fetch=False,
):
assert session_handle is not None

Expand All @@ -864,7 +785,7 @@ def execute_command(
),
canReadArrowResult=True,
canDecompressLZ4Result=lz4_compression,
canDownloadResult=False,
canDownloadResult=use_cloud_fetch,
confOverlay={
# We want to receive proper Timestamp arrow types.
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
Expand Down Expand Up @@ -993,6 +914,7 @@ def fetch_results(
maxRows=max_rows,
maxBytes=max_bytes,
orientation=ttypes.TFetchOrientation.FETCH_NEXT,
includeResultSetMetadata=True,
)

resp = self.make_request(self._client.FetchResults, req)
Expand All @@ -1002,12 +924,17 @@ def fetch_results(
expected_row_start_offset, resp.results.startRowOffset
)
)
arrow_results, n_rows = self._create_arrow_table(
resp.results, lz4_compressed, arrow_schema_bytes, description

queue = ResultSetQueueFactory.build_queue(
row_set_type=resp.resultSetMetadata.resultFormat,
t_row_set=resp.results,
arrow_schema_bytes=arrow_schema_bytes,
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
)
arrow_queue = ArrowQueue(arrow_results, n_rows)

return arrow_queue, resp.hasMoreRows
return queue, resp.hasMoreRows

def close_command(self, op_handle):
req = ttypes.TCloseOperationReq(operationHandle=op_handle)
Expand Down
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.