Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
113 changes: 59 additions & 54 deletions 113 src/databricks/sql/backend/sea/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

if TYPE_CHECKING:
from databricks.sql.client import Cursor
from databricks.sql.backend.sea.result_set import SeaResultSet

from databricks.sql.backend.sea.result_set import SeaResultSet

from databricks.sql.backend.databricks_client import DatabricksClient
from databricks.sql.backend.types import (
Expand Down Expand Up @@ -332,7 +333,7 @@ def _extract_description_from_manifest(
return columns

def _results_message_to_execute_response(
self, response: GetStatementResponse
self, response: Union[ExecuteStatementResponse, GetStatementResponse]
) -> ExecuteResponse:
"""
Convert a SEA response to an ExecuteResponse and extract result data.
Expand Down Expand Up @@ -366,6 +367,27 @@ def _results_message_to_execute_response(

return execute_response

def _response_to_result_set(
self,
response: Union[ExecuteStatementResponse, GetStatementResponse],
cursor: Cursor,
) -> SeaResultSet:
"""
Convert a SEA response to a SeaResultSet.
"""

execute_response = self._results_message_to_execute_response(response)

return SeaResultSet(
connection=cursor.connection,
execute_response=execute_response,
sea_client=self,
result_data=response.result,
manifest=response.manifest,
buffer_size_bytes=cursor.buffer_size_bytes,
arraysize=cursor.arraysize,
)

def _check_command_not_in_failed_or_closed_state(
self, state: CommandState, command_id: CommandId
) -> None:
Expand All @@ -386,21 +408,24 @@ def _check_command_not_in_failed_or_closed_state(

def _wait_until_command_done(
self, response: ExecuteStatementResponse
) -> CommandState:
) -> Union[ExecuteStatementResponse, GetStatementResponse]:
"""
Wait until a command is done.
"""

state = response.status.state
command_id = CommandId.from_sea_statement_id(response.statement_id)
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response

state = final_response.status.state
command_id = CommandId.from_sea_statement_id(final_response.statement_id)

while state in [CommandState.PENDING, CommandState.RUNNING]:
time.sleep(self.POLL_INTERVAL_SECONDS)
state = self.get_query_state(command_id)
final_response = self._poll_query(command_id)
state = final_response.status.state

self._check_command_not_in_failed_or_closed_state(state, command_id)

return state
return final_response

def execute_command(
self,
Expand Down Expand Up @@ -506,8 +531,11 @@ def execute_command(
if async_op:
return None

self._wait_until_command_done(response)
return self.get_execution_result(command_id, cursor)
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response
if response.status.state != CommandState.SUCCEEDED:
final_response = self._wait_until_command_done(response)

return self._response_to_result_set(final_response, cursor)

def cancel_command(self, command_id: CommandId) -> None:
"""
Expand Down Expand Up @@ -559,18 +587,9 @@ def close_command(self, command_id: CommandId) -> None:
data=request.to_dict(),
)

def get_query_state(self, command_id: CommandId) -> CommandState:
def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
"""
Get the state of a running query.

Args:
command_id: Command identifier

Returns:
CommandState: The current state of the command

Raises:
ValueError: If the command ID is invalid
Poll for the current command info.
"""

if command_id.backend_type != BackendType.SEA:
Expand All @@ -586,9 +605,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
data=request.to_dict(),
)

# Parse the response
response = GetStatementResponse.from_dict(response_data)

return response

def get_query_state(self, command_id: CommandId) -> CommandState:
"""
Get the state of a running query.

Args:
command_id: Command identifier

Returns:
CommandState: The current state of the command

Raises:
ProgrammingError: If the command ID is invalid
"""

response = self._poll_query(command_id)
return response.status.state

def get_execution_result(
Expand All @@ -610,38 +645,8 @@ def get_execution_result(
ValueError: If the command ID is invalid
"""

if command_id.backend_type != BackendType.SEA:
raise ValueError("Not a valid SEA command ID")

sea_statement_id = command_id.to_sea_statement_id()
if sea_statement_id is None:
raise ValueError("Not a valid SEA command ID")

# Create the request model
request = GetStatementRequest(statement_id=sea_statement_id)

# Get the statement result
response_data = self._http_client._make_request(
method="GET",
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
data=request.to_dict(),
)
response = GetStatementResponse.from_dict(response_data)

# Create and return a SeaResultSet
from databricks.sql.backend.sea.result_set import SeaResultSet

execute_response = self._results_message_to_execute_response(response)

return SeaResultSet(
connection=cursor.connection,
execute_response=execute_response,
sea_client=self,
result_data=response.result,
manifest=response.manifest,
buffer_size_bytes=cursor.buffer_size_bytes,
arraysize=cursor.arraysize,
)
response = self._poll_query(command_id)
return self._response_to_result_set(response, cursor)

def get_chunk_links(
self, statement_id: str, chunk_index: int
Expand Down
2 changes: 1 addition & 1 deletion 2 src/databricks/sql/backend/sea/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import logging

from databricks.sql.backend.sea.backend import SeaDatabricksClient
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter

Expand All @@ -15,6 +14,7 @@

if TYPE_CHECKING:
from databricks.sql.client import Connection
from databricks.sql.backend.sea.backend import SeaDatabricksClient
from databricks.sql.types import Row
from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory
from databricks.sql.backend.types import ExecuteResponse
Expand Down
9 changes: 3 additions & 6 deletions 9 tests/unit/test_sea_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_command_execution_sync(
mock_http_client._make_request.return_value = execute_response

with patch.object(
sea_client, "get_execution_result", return_value="mock_result_set"
sea_client, "_response_to_result_set", return_value="mock_result_set"
) as mock_get_result:
result = sea_client.execute_command(
operation="SELECT 1",
Expand All @@ -242,9 +242,6 @@ def test_command_execution_sync(
enforce_embedded_schema_correctness=False,
)
assert result == "mock_result_set"
cmd_id_arg = mock_get_result.call_args[0][0]
assert isinstance(cmd_id_arg, CommandId)
assert cmd_id_arg.guid == "test-statement-123"

# Test with invalid session ID
with pytest.raises(ValueError) as excinfo:
Expand Down Expand Up @@ -332,7 +329,7 @@ def test_command_execution_advanced(
mock_http_client._make_request.side_effect = [initial_response, poll_response]

with patch.object(
sea_client, "get_execution_result", return_value="mock_result_set"
sea_client, "_response_to_result_set", return_value="mock_result_set"
) as mock_get_result:
with patch("time.sleep"):
result = sea_client.execute_command(
Expand Down Expand Up @@ -360,7 +357,7 @@ def test_command_execution_advanced(
dbsql_param = IntegerParameter(name="param1", value=1)
param = dbsql_param.as_tspark_param(named=True)

with patch.object(sea_client, "get_execution_result"):
with patch.object(sea_client, "_response_to_result_set"):
sea_client.execute_command(
operation="SELECT * FROM table WHERE col = :param1",
session_id=sea_session_id,
Expand Down
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.