Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
17 changes: 0 additions & 17 deletions 17 src/databricks/sql/auth/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,6 @@
logger = logging.getLogger(__name__)


def parse_hostname(hostname: str) -> str:
    """
    Normalize the hostname to include scheme and trailing slash.

    The scheme check is case-insensitive, so a value such as
    ``"HTTPS://example.com"`` is recognised as already carrying a scheme
    and is not prefixed a second time. The caller's scheme is otherwise
    left exactly as given.

    Args:
        hostname: The hostname to normalize

    Returns:
        Normalized hostname with scheme and trailing slash
    """
    # Compare against a lowercased copy; the original string is not modified
    # by the check itself.
    if not hostname.lower().startswith(("http://", "https://")):
        hostname = f"https://{hostname}"
    if not hostname.endswith("/"):
        hostname = f"{hostname}/"
    return hostname


def decode_token(access_token: str) -> Optional[Dict]:
"""
Decode a JWT token without verification to extract claims.
Expand Down
6 changes: 3 additions & 3 deletions 6 src/databricks/sql/auth/token_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

from databricks.sql.auth.authenticators import AuthProvider
from databricks.sql.auth.auth_utils import (
parse_hostname,
decode_token,
is_same_host,
)
from databricks.sql.common.url_utils import normalize_host_with_protocol
from databricks.sql.common.http import HttpMethod

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(
if not http_client:
raise ValueError("http_client is required for TokenFederationProvider")

self.hostname = parse_hostname(hostname)
self.hostname = normalize_host_with_protocol(hostname)
self.external_provider = external_provider
self.http_client = http_client
self.identity_federation_client_id = identity_federation_client_id
Expand Down Expand Up @@ -164,7 +164,7 @@ def _should_exchange_token(self, access_token: str) -> bool:

def _exchange_token(self, access_token: str) -> Token:
"""Exchange the external token for a Databricks token."""
token_url = f"{self.hostname.rstrip('/')}{self.TOKEN_EXCHANGE_ENDPOINT}"
token_url = f"{self.hostname}{self.TOKEN_EXCHANGE_ENDPOINT}"

data = {
"grant_type": self.TOKEN_EXCHANGE_GRANT_TYPE,
Expand Down
6 changes: 4 additions & 2 deletions 6 src/databricks/sql/backend/sea/utils/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from databricks.sql.common.http_utils import (
detect_and_parse_proxy,
)
from databricks.sql.common.url_utils import normalize_host_with_protocol

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -66,8 +67,9 @@ def __init__(
self.auth_provider = auth_provider
self.ssl_options = ssl_options

# Build base URL
self.base_url = f"https://{server_hostname}:{self.port}"
# Build base URL using url_utils for consistent normalization
normalized_host = normalize_host_with_protocol(server_hostname)
self.base_url = f"{normalized_host}:{self.port}"

# Parse URL for proxy handling
parsed_url = urllib.parse.urlparse(self.base_url)
Expand Down
4 changes: 3 additions & 1 deletion 4 src/databricks/sql/common/feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict, Optional, List, Any, TYPE_CHECKING

from databricks.sql.common.http import HttpMethod
from databricks.sql.common.url_utils import normalize_host_with_protocol

if TYPE_CHECKING:
from databricks.sql.client import Connection
Expand Down Expand Up @@ -67,7 +68,8 @@ def __init__(

endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__)
self._feature_flag_endpoint = (
f"https://{self._connection.session.host}{endpoint_suffix}"
normalize_host_with_protocol(self._connection.session.host)
+ endpoint_suffix
)

# Use the provided HTTP client
Expand Down
45 changes: 45 additions & 0 deletions 45 src/databricks/sql/common/url_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
URL utility functions for the Databricks SQL connector.
"""


def normalize_host_with_protocol(host: str) -> str:
    """
    Normalize a connection hostname by ensuring it has a protocol.

    This is useful for handling cases where users may provide hostnames with or without protocols
    (common with dbt-databricks users copying URLs from their browser).

    Args:
        host: Connection hostname which may or may not include a protocol prefix (https:// or http://)
            and may or may not have a trailing slash. Leading and trailing
            whitespace is ignored.

    Returns:
        Normalized hostname with a lowercase protocol prefix, no surrounding
        whitespace and no trailing slashes

    Examples:
        normalize_host_with_protocol("myserver.com") -> "https://myserver.com"
        normalize_host_with_protocol("https://myserver.com") -> "https://myserver.com"
        normalize_host_with_protocol("HTTPS://myserver.com/") -> "https://myserver.com"
        normalize_host_with_protocol("http://localhost:8080/") -> "http://localhost:8080"
        normalize_host_with_protocol("  myserver.com/ ") -> "https://myserver.com"

    Raises:
        ValueError: If host is None, empty, whitespace-only, or consists
            solely of slashes
    """
    # Handle None or empty host
    if not host or not host.strip():
        raise ValueError("Host cannot be None or empty")

    # Strip whitespace before the slashes: otherwise a trailing space hides
    # the trailing slash and both end up embedded in the resulting URL.
    host = host.strip().rstrip("/")
    if not host:
        # Input such as "/" leaves nothing to build a URL from.
        raise ValueError("Host cannot be None or empty")

    # Only a leading http/https scheme counts as "protocol present"; a "://"
    # appearing later in the string (e.g. inside a query) does not.
    scheme, separator, remainder = host.partition("://")
    if separator and scheme.lower() in ("http", "https"):
        # Normalize protocol to lowercase, leave the rest untouched
        return f"{scheme.lower()}://{remainder}"

    return f"https://{host}"
3 changes: 2 additions & 1 deletion 3 src/databricks/sql/telemetry/telemetry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
TelemetryPushClient,
CircuitBreakerTelemetryPushClient,
)
from databricks.sql.common.url_utils import normalize_host_with_protocol

if TYPE_CHECKING:
from databricks.sql.client import Connection
Expand Down Expand Up @@ -278,7 +279,7 @@ def _send_telemetry(self, events):
if self._auth_provider
else self.TELEMETRY_UNAUTHENTICATED_PATH
)
url = f"https://{self._host_url}{path}"
url = normalize_host_with_protocol(self._host_url) + path

headers = {"Accept": "application/json", "Content-Type": "application/json"}

Expand Down
79 changes: 58 additions & 21 deletions 79 tests/e2e/test_circuit_breaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,34 @@
from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager


def wait_for_circuit_state(
    circuit_breaker, expected_states, timeout=5, poll_interval=0.1
):
    """
    Wait for circuit breaker to reach one of the expected states with polling.

    The state is always checked at least once, even when ``timeout`` is zero,
    and one last time when the deadline is reached, so a transition that
    happens during the final sleep is not missed.

    Args:
        circuit_breaker: The circuit breaker instance to monitor; only its
            ``current_state`` attribute is read
        expected_states: List of acceptable states
            (STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN)
        timeout: Maximum time to wait in seconds
        poll_interval: Delay between state checks in seconds (default 100ms)

    Returns:
        True if state reached, False if timeout

    Examples:
        # Single state - pass list with one element
        wait_for_circuit_state(cb, [STATE_OPEN])

        # Multiple states
        wait_for_circuit_state(cb, [STATE_CLOSED, STATE_HALF_OPEN])
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while True:
        if circuit_breaker.current_state in expected_states:
            return True
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # Never sleep past the deadline.
        time.sleep(min(poll_interval, remaining))


@pytest.fixture(autouse=True)
def aggressive_circuit_breaker_config():
"""
Expand Down Expand Up @@ -65,12 +93,17 @@ def create_mock_response(self, status_code):
}.get(status_code, b"Response")
return response

@pytest.mark.parametrize("status_code,should_trigger", [
(429, True),
(503, True),
(500, False),
])
def test_circuit_breaker_triggers_for_rate_limit_codes(self, status_code, should_trigger):
@pytest.mark.parametrize(
"status_code,should_trigger",
[
(429, True),
(503, True),
(500, False),
],
)
def test_circuit_breaker_triggers_for_rate_limit_codes(
self, status_code, should_trigger
):
"""
Verify circuit breaker opens for rate-limit codes (429/503) but not others (500).
"""
Expand Down Expand Up @@ -107,9 +140,14 @@ def mock_request(*args, **kwargs):
time.sleep(0.5)

if should_trigger:
# Circuit should be OPEN after 2 rate-limit failures
# Wait for circuit to open (async telemetry may take time)
assert wait_for_circuit_state(
circuit_breaker, [STATE_OPEN], timeout=5
), f"Circuit didn't open within 5s, state: {circuit_breaker.current_state}"

# Circuit should be OPEN after rate-limit failures
assert circuit_breaker.current_state == STATE_OPEN
assert circuit_breaker.fail_counter == 2
assert circuit_breaker.fail_counter >= 2 # At least 2 failures

# Track requests before another query
requests_before = request_count["count"]
Expand Down Expand Up @@ -197,7 +235,10 @@ def mock_conditional_request(*args, **kwargs):
cursor.fetchone()
time.sleep(2)

assert circuit_breaker.current_state == STATE_OPEN
# Wait for circuit to open
assert wait_for_circuit_state(
circuit_breaker, [STATE_OPEN], timeout=5
), f"Circuit didn't open, state: {circuit_breaker.current_state}"

# Wait for reset timeout (5 seconds in test)
time.sleep(6)
Expand All @@ -208,24 +249,20 @@ def mock_conditional_request(*args, **kwargs):
# Execute query to trigger HALF_OPEN state
cursor.execute("SELECT 3")
cursor.fetchone()
time.sleep(1)

# Circuit should be recovering
assert circuit_breaker.current_state in [
STATE_HALF_OPEN,
STATE_CLOSED,
], f"Circuit should be recovering, but is {circuit_breaker.current_state}"
# Wait for circuit to start recovering
assert wait_for_circuit_state(
circuit_breaker, [STATE_HALF_OPEN, STATE_CLOSED], timeout=5
), f"Circuit didn't recover, state: {circuit_breaker.current_state}"

# Execute more queries to fully recover
cursor.execute("SELECT 4")
cursor.fetchone()
time.sleep(1)

current_state = circuit_breaker.current_state
assert current_state in [
STATE_CLOSED,
STATE_HALF_OPEN,
], f"Circuit should recover to CLOSED or HALF_OPEN, got {current_state}"
# Wait for full recovery
assert wait_for_circuit_state(
circuit_breaker, [STATE_CLOSED, STATE_HALF_OPEN], timeout=5
), f"Circuit didn't fully recover, state: {circuit_breaker.current_state}"


if __name__ == "__main__":
Expand Down
36 changes: 32 additions & 4 deletions 36 tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,13 +646,31 @@ class TransactionTestSuite(unittest.TestCase):
"access_token": "tok",
}

def _setup_mock_session_with_http_client(self, mock_session):
"""
Helper to configure a mock session with HTTP client mocks.
This prevents feature flag network requests during Connection initialization.
"""
mock_session.host = "foo"

# Mock HTTP client to prevent feature flag network requests
mock_http_client = Mock()
mock_session.http_client = mock_http_client

# Mock feature flag response to prevent blocking HTTP calls
mock_ff_response = Mock()
mock_ff_response.status = 200
mock_ff_response.data = b'{"flags": [], "ttl_seconds": 900}'
mock_http_client.request.return_value = mock_ff_response

def _create_mock_connection(self, mock_session_class):
"""Helper to create a mocked connection for transaction tests."""
# Mock session
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"
mock_session.get_autocommit.return_value = True

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=False to test actual transaction functionality
Expand Down Expand Up @@ -736,9 +754,7 @@ def test_autocommit_setter_preserves_exception_chain(self, mock_session_class):
conn = self._create_mock_connection(mock_session_class)

mock_cursor = Mock()
original_error = DatabaseError(
"Original error", host_url="test-host"
)
original_error = DatabaseError("Original error", host_url="test-host")
mock_cursor.execute.side_effect = original_error

with patch.object(conn, "cursor", return_value=mock_cursor):
Expand Down Expand Up @@ -927,6 +943,8 @@ def test_fetch_autocommit_from_server_queries_server(self, mock_session_class):
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -959,6 +977,8 @@ def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_cla
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -986,6 +1006,8 @@ def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_cla
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -1015,6 +1037,8 @@ def test_commit_is_noop_when_ignore_transactions_true(self, mock_session_class):
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand Down Expand Up @@ -1043,6 +1067,8 @@ def test_rollback_raises_not_supported_when_ignore_transactions_true(
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand All @@ -1068,6 +1094,8 @@ def test_autocommit_setter_is_noop_when_ignore_transactions_true(
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand Down
Loading
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.