Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
6 changes: 5 additions & 1 deletion 6 durabletask/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
log_formatter: Optional[logging.Formatter] = None,
secure_channel: bool = False,
interceptors: Optional[Sequence[ClientInterceptor]] = None,
channel_options: Optional[Sequence[tuple[str, Any]]] = None,
):
if interceptors is not None:
interceptors = list(interceptors)
Expand All @@ -46,7 +47,10 @@ def __init__(
interceptors = None

channel = get_grpc_aio_channel(
host_address=host_address, secure_channel=secure_channel, interceptors=interceptors
host_address=host_address,
secure_channel=secure_channel,
interceptors=interceptors,
options=channel_options,
)
self._channel = channel
self._stub = stubs.TaskHubSidecarServiceStub(channel)
Expand Down
16 changes: 14 additions & 2 deletions 16 durabletask/aio/internal/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import grpc
from grpc import aio as grpc_aio
from grpc.aio import ChannelArgumentType

from durabletask.internal.shared import (
INSECURE_PROTOCOLS,
Expand All @@ -24,7 +25,16 @@ def get_grpc_aio_channel(
host_address: Optional[str],
secure_channel: bool = False,
interceptors: Optional[Sequence[ClientInterceptor]] = None,
options: Optional[ChannelArgumentType] = None,
) -> grpc_aio.Channel:
"""create a grpc asyncio channel

Args:
host_address: The host address of the gRPC server. If None, uses the default address.
secure_channel: Whether to use a secure channel (TLS/SSL). Defaults to False.
interceptors: Optional sequence of client interceptors to apply to the channel.
options: Optional sequence of gRPC channel options as (key, value) tuples. Keys defined in https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
"""
if host_address is None:
host_address = get_default_host_address()

Expand All @@ -42,9 +52,11 @@ def get_grpc_aio_channel(

if secure_channel:
channel = grpc_aio.secure_channel(
host_address, grpc.ssl_channel_credentials(), interceptors=interceptors
host_address, grpc.ssl_channel_credentials(), interceptors=interceptors, options=options
)
else:
channel = grpc_aio.insecure_channel(host_address, interceptors=interceptors)
channel = grpc_aio.insecure_channel(
host_address, interceptors=interceptors, options=options
)

return channel
6 changes: 5 additions & 1 deletion 6 durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
log_formatter: Optional[logging.Formatter] = None,
secure_channel: bool = False,
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
channel_options: Optional[Sequence[tuple[str, Any]]] = None,
):
# If the caller provided metadata, we need to create a new interceptor for it and
# add it to the list of interceptors.
Expand All @@ -121,7 +122,10 @@ def __init__(
interceptors = None

channel = shared.get_grpc_channel(
host_address=host_address, secure_channel=secure_channel, interceptors=interceptors
host_address=host_address,
secure_channel=secure_channel,
interceptors=interceptors,
options=channel_options,
)
self._stub = stubs.TaskHubSidecarServiceStub(channel)
self._logger = shared.get_logger("client", log_handler, log_formatter)
Expand Down
18 changes: 13 additions & 5 deletions 18 durabletask/internal/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def get_default_host_address() -> str:
Honors environment variables if present; otherwise defaults to localhost:4001.

Supported environment variables (checked in order):
- DURABLETASK_GRPC_ENDPOINT (e.g., "localhost:4001", "grpcs://host:443")
- DURABLETASK_GRPC_HOST and DURABLETASK_GRPC_PORT
- DAPR_GRPC_ENDPOINT (e.g., "localhost:4001", "grpcs://host:443")
- DAPR_GRPC_HOST/DAPR_RUNTIME_HOST and DAPR_GRPC_PORT
"""

# Full endpoint overrides
Expand All @@ -54,7 +54,16 @@ def get_grpc_channel(
host_address: Optional[str],
secure_channel: bool = False,
interceptors: Optional[Sequence[ClientInterceptor]] = None,
options: Optional[Sequence[tuple[str, Any]]] = None,
) -> grpc.Channel:
"""create a grpc channel

Args:
host_address: The host address of the gRPC server. If None, uses the default address (as defined in get_default_host_address above).
secure_channel: Whether to use a secure channel (TLS/SSL). Defaults to False.
interceptors: Optional sequence of client interceptors to apply to the channel.
options: Optional sequence of gRPC channel options as (key, value) tuples. Keys defined in https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
"""
if host_address is None:
host_address = get_default_host_address()

Expand All @@ -72,11 +81,10 @@ def get_grpc_channel(
host_address = host_address[len(protocol) :]
break

# Create the base channel
if secure_channel:
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials())
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials(), options=options)
else:
channel = grpc.insecure_channel(host_address)
channel = grpc.insecure_channel(host_address, options=options)

# Apply interceptors ONLY if they exist
if interceptors:
Expand Down
7 changes: 6 additions & 1 deletion 7 durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,15 @@ def __init__(
secure_channel: bool = False,
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
concurrency_options: Optional[ConcurrencyOptions] = None,
channel_options: Optional[Sequence[tuple[str, Any]]] = None,
):
self._registry = _Registry()
self._host_address = host_address if host_address else shared.get_default_host_address()
self._logger = shared.get_logger("worker", log_handler, log_formatter)
self._shutdown = Event()
self._is_running = False
self._secure_channel = secure_channel
self._channel_options = channel_options

# Use provided concurrency options or create default ones
self._concurrency_options = (
Expand Down Expand Up @@ -306,7 +308,10 @@ def create_fresh_connection():
current_stub = None
try:
current_channel = shared.get_grpc_channel(
self._host_address, self._secure_channel, self._interceptors
self._host_address,
self._secure_channel,
self._interceptors,
options=self._channel_options,
)
current_stub = stubs.TaskHubSidecarServiceStub(current_channel)
current_stub.Hello(empty_pb2.Empty())
Expand Down
79 changes: 64 additions & 15 deletions 79 tests/durabletask/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import ANY, patch
from unittest.mock import patch

from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
from durabletask.internal.shared import get_default_host_address, get_grpc_channel
Expand All @@ -11,7 +11,9 @@
def test_get_grpc_channel_insecure():
with patch("grpc.insecure_channel") as mock_channel:
get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS)
mock_channel.assert_called_once_with(HOST_ADDRESS)
args, kwargs = mock_channel.call_args
assert args[0] == HOST_ADDRESS
assert "options" in kwargs and kwargs["options"] is None


def test_get_grpc_channel_secure():
Expand All @@ -20,13 +22,18 @@ def test_get_grpc_channel_secure():
patch("grpc.ssl_channel_credentials") as mock_credentials,
):
get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS)
mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value)
args, kwargs = mock_channel.call_args
assert args[0] == HOST_ADDRESS
assert args[1] == mock_credentials.return_value
assert "options" in kwargs and kwargs["options"] is None


def test_get_grpc_channel_default_host_address():
with patch("grpc.insecure_channel") as mock_channel:
get_grpc_channel(None, False, interceptors=INTERCEPTORS)
mock_channel.assert_called_once_with(get_default_host_address())
args, kwargs = mock_channel.call_args
assert args[0] == get_default_host_address()
assert "options" in kwargs and kwargs["options"] is None


def test_get_grpc_channel_with_metadata():
Expand All @@ -35,7 +42,9 @@ def test_get_grpc_channel_with_metadata():
patch("grpc.intercept_channel") as mock_intercept_channel,
):
get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS)
mock_channel.assert_called_once_with(HOST_ADDRESS)
args, kwargs = mock_channel.call_args
assert args[0] == HOST_ADDRESS
assert "options" in kwargs and kwargs["options"] is None
mock_intercept_channel.assert_called_once()

# Capture and check the arguments passed to intercept_channel()
Expand All @@ -54,40 +63,80 @@ def test_grpc_channel_with_host_name_protocol_stripping():

prefix = "grpc://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_insecure_channel.assert_called_with(host_name)
args, kwargs = mock_insecure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "http://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_insecure_channel.assert_called_with(host_name)
args, kwargs = mock_insecure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "HTTP://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_insecure_channel.assert_called_with(host_name)
args, kwargs = mock_insecure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "GRPC://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_insecure_channel.assert_called_with(host_name)
args, kwargs = mock_insecure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = ""
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_insecure_channel.assert_called_with(host_name)
args, kwargs = mock_insecure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "grpcs://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_secure_channel.assert_called_with(host_name, ANY)
args, kwargs = mock_secure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "https://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_secure_channel.assert_called_with(host_name, ANY)
args, kwargs = mock_secure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "HTTPS://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_secure_channel.assert_called_with(host_name, ANY)
args, kwargs = mock_secure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "GRPCS://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_secure_channel.assert_called_with(host_name, ANY)
args, kwargs = mock_secure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = ""
get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS)
mock_secure_channel.assert_called_with(host_name, ANY)
args, kwargs = mock_secure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None


def test_sync_channel_passes_base_options_and_max_lengths():
base_options = [
("grpc.max_send_message_length", 1234),
("grpc.max_receive_message_length", 5678),
("grpc.primary_user_agent", "durabletask-tests"),
]
with patch("grpc.insecure_channel") as mock_channel:
get_grpc_channel(HOST_ADDRESS, False, options=base_options)
# Ensure called with options kwarg
assert mock_channel.call_count == 1
args, kwargs = mock_channel.call_args
assert args[0] == HOST_ADDRESS
assert "options" in kwargs
opts = kwargs["options"]
# Check our base options made it through
assert ("grpc.max_send_message_length", 1234) in opts
assert ("grpc.max_receive_message_length", 5678) in opts
assert ("grpc.primary_user_agent", "durabletask-tests") in opts
Loading
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.