Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c6fb822

Browse filesBrowse files
authored
Fix streamable http sampling (#693)
1 parent ed25167 commit c6fb822
Copy full SHA for c6fb822

File tree

7 files changed

+152
-23
lines changed
Filter options

7 files changed

+152
-23
lines changed

‎src/mcp/cli/claude.py

Copy file name to clipboardExpand all lines: src/mcp/cli/claude.py
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def get_claude_config_path() -> Path | None:
3131
return path
3232
return None
3333

34+
3435
def get_uv_path() -> str:
3536
"""Get the full path to the uv executable."""
3637
uv_path = shutil.which("uv")
@@ -42,6 +43,7 @@ def get_uv_path() -> str:
4243
return "uv" # Fall back to just "uv" if not found
4344
return uv_path
4445

46+
4547
def update_claude_config(
4648
file_spec: str,
4749
server_name: str,

‎src/mcp/client/streamable_http.py

Copy file name to clipboardExpand all lines: src/mcp/client/streamable_http.py
+19-5Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import anyio
1717
import httpx
18+
from anyio.abc import TaskGroup
1819
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1920
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
2021

@@ -239,7 +240,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
239240
break
240241

241242
async def _handle_post_request(self, ctx: RequestContext) -> None:
242-
"""Handle a POST request with response processing."""
243+
"""Handle a POST request with response processing."""
243244
headers = self._update_headers_with_session(ctx.headers)
244245
message = ctx.session_message.message
245246
is_initialization = self._is_initialization_request(message)
@@ -300,7 +301,7 @@ async def _handle_sse_response(
300301
try:
301302
event_source = EventSource(response)
302303
async for sse in event_source.aiter_sse():
303-
await self._handle_sse_event(
304+
is_complete = await self._handle_sse_event(
304305
sse,
305306
ctx.read_stream_writer,
306307
resumption_callback=(
@@ -309,6 +310,10 @@ async def _handle_sse_response(
309310
else None
310311
),
311312
)
313+
# If the SSE event indicates completion, like returning respose/error
314+
# break the loop
315+
if is_complete:
316+
break
312317
except Exception as e:
313318
logger.exception("Error reading SSE stream:")
314319
await ctx.read_stream_writer.send(e)
@@ -344,6 +349,7 @@ async def post_writer(
344349
read_stream_writer: StreamWriter,
345350
write_stream: MemoryObjectSendStream[SessionMessage],
346351
start_get_stream: Callable[[], None],
352+
tg: TaskGroup,
347353
) -> None:
348354
"""Handle writing requests to the server."""
349355
try:
@@ -375,10 +381,17 @@ async def post_writer(
375381
sse_read_timeout=self.sse_read_timeout,
376382
)
377383

378-
if is_resumption:
379-
await self._handle_resumption_request(ctx)
384+
async def handle_request_async():
385+
if is_resumption:
386+
await self._handle_resumption_request(ctx)
387+
else:
388+
await self._handle_post_request(ctx)
389+
390+
# If this is a request, start a new task to handle it
391+
if isinstance(message.root, JSONRPCRequest):
392+
tg.start_soon(handle_request_async)
380393
else:
381-
await self._handle_post_request(ctx)
394+
await handle_request_async()
382395

383396
except Exception as exc:
384397
logger.error(f"Error in post_writer: {exc}")
@@ -466,6 +479,7 @@ def start_get_stream() -> None:
466479
read_stream_writer,
467480
write_stream,
468481
start_get_stream,
482+
tg,
469483
)
470484

471485
try:

‎src/mcp/server/session.py

Copy file name to clipboardExpand all lines: src/mcp/server/session.py
+7-3Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4747

4848
import mcp.types as types
4949
from mcp.server.models import InitializationOptions
50-
from mcp.shared.message import SessionMessage
50+
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5151
from mcp.shared.session import (
5252
BaseSession,
5353
RequestResponder,
@@ -230,10 +230,11 @@ async def create_message(
230230
stop_sequences: list[str] | None = None,
231231
metadata: dict[str, Any] | None = None,
232232
model_preferences: types.ModelPreferences | None = None,
233+
related_request_id: types.RequestId | None = None,
233234
) -> types.CreateMessageResult:
234235
"""Send a sampling/create_message request."""
235236
return await self.send_request(
236-
types.ServerRequest(
237+
request=types.ServerRequest(
237238
types.CreateMessageRequest(
238239
method="sampling/createMessage",
239240
params=types.CreateMessageRequestParams(
@@ -248,7 +249,10 @@ async def create_message(
248249
),
249250
)
250251
),
251-
types.CreateMessageResult,
252+
result_type=types.CreateMessageResult,
253+
metadata=ServerMessageMetadata(
254+
related_request_id=related_request_id,
255+
),
252256
)
253257

254258
async def list_roots(self) -> types.ListRootsResult:

‎src/mcp/server/streamable_http.py

Copy file name to clipboardExpand all lines: src/mcp/server/streamable_http.py
+14-7Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
ErrorData,
3434
JSONRPCError,
3535
JSONRPCMessage,
36-
JSONRPCNotification,
3736
JSONRPCRequest,
3837
JSONRPCResponse,
3938
RequestId,
@@ -849,9 +848,15 @@ async def message_router():
849848
# Determine which request stream(s) should receive this message
850849
message = session_message.message
851850
target_request_id = None
852-
if isinstance(
853-
message.root, JSONRPCNotification | JSONRPCRequest
854-
):
851+
# Check if this is a response
852+
if isinstance(message.root, JSONRPCResponse | JSONRPCError):
853+
response_id = str(message.root.id)
854+
# If this response is for an existing request stream,
855+
# send it there
856+
if response_id in self._request_streams:
857+
target_request_id = response_id
858+
859+
else:
855860
# Extract related_request_id from meta if it exists
856861
if (
857862
session_message.metadata is not None
@@ -865,10 +870,12 @@ async def message_router():
865870
target_request_id = str(
866871
session_message.metadata.related_request_id
867872
)
868-
else:
869-
target_request_id = str(message.root.id)
870873

871-
request_stream_id = target_request_id or GET_STREAM_KEY
874+
request_stream_id = (
875+
target_request_id
876+
if target_request_id is not None
877+
else GET_STREAM_KEY
878+
)
872879

873880
# Store the event if we have an event store,
874881
# regardless of whether a client is connected

‎src/mcp/shared/session.py

Copy file name to clipboardExpand all lines: src/mcp/shared/session.py
-1Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,6 @@ async def send_request(
223223
Do not use this method to emit notifications! Use send_notification()
224224
instead.
225225
"""
226-
227226
request_id = self._request_id
228227
self._request_id = request_id + 1
229228

‎tests/client/test_config.py

Copy file name to clipboardExpand all lines: tests/client/test_config.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_absolute_uv_path(mock_config_path: Path):
5454
"""Test that the absolute path to uv is used when available."""
5555
# Mock the shutil.which function to return a fake path
5656
mock_uv_path = "/usr/local/bin/uv"
57-
57+
5858
with patch("mcp.cli.claude.get_uv_path", return_value=mock_uv_path):
5959
# Setup
6060
server_name = "test_server"
@@ -71,5 +71,5 @@ def test_absolute_uv_path(mock_config_path: Path):
7171
# Verify the command is the absolute path
7272
server_config = config["mcpServers"][server_name]
7373
command = server_config["command"]
74-
75-
assert command == mock_uv_path
74+
75+
assert command == mock_uv_path

‎tests/shared/test_streamable_http.py

Copy file name to clipboardExpand all lines: tests/shared/test_streamable_http.py
+107-4Lines changed: 107 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import socket
99
import time
1010
from collections.abc import Generator
11+
from typing import Any
1112

1213
import anyio
1314
import httpx
@@ -33,6 +34,7 @@
3334
StreamId,
3435
)
3536
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
37+
from mcp.shared.context import RequestContext
3638
from mcp.shared.exceptions import McpError
3739
from mcp.shared.message import (
3840
ClientMessageMetadata,
@@ -139,6 +141,11 @@ async def handle_list_tools() -> list[Tool]:
139141
description="A long-running tool that sends periodic notifications",
140142
inputSchema={"type": "object", "properties": {}},
141143
),
144+
Tool(
145+
name="test_sampling_tool",
146+
description="A tool that triggers server-side sampling",
147+
inputSchema={"type": "object", "properties": {}},
148+
),
142149
]
143150

144151
@self.call_tool()
@@ -174,6 +181,34 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
174181

175182
return [TextContent(type="text", text="Completed!")]
176183

184+
elif name == "test_sampling_tool":
185+
# Test sampling by requesting the client to sample a message
186+
sampling_result = await ctx.session.create_message(
187+
messages=[
188+
types.SamplingMessage(
189+
role="user",
190+
content=types.TextContent(
191+
type="text", text="Server needs client sampling"
192+
),
193+
)
194+
],
195+
max_tokens=100,
196+
related_request_id=ctx.request_id,
197+
)
198+
199+
# Return the sampling result in the tool response
200+
response = (
201+
sampling_result.content.text
202+
if sampling_result.content.type == "text"
203+
else None
204+
)
205+
return [
206+
TextContent(
207+
type="text",
208+
text=f"Response from sampling: {response}",
209+
)
210+
]
211+
177212
return [TextContent(type="text", text=f"Called {name}")]
178213

179214

@@ -754,7 +789,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
754789
"""Test client tool invocation."""
755790
# First list tools
756791
tools = await initialized_client_session.list_tools()
757-
assert len(tools.tools) == 3
792+
assert len(tools.tools) == 4
758793
assert tools.tools[0].name == "test_tool"
759794

760795
# Call the tool
@@ -795,7 +830,7 @@ async def test_streamablehttp_client_session_persistence(
795830

796831
# Make multiple requests to verify session persistence
797832
tools = await session.list_tools()
798-
assert len(tools.tools) == 3
833+
assert len(tools.tools) == 4
799834

800835
# Read a resource
801836
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
@@ -826,7 +861,7 @@ async def test_streamablehttp_client_json_response(
826861

827862
# Check tool listing
828863
tools = await session.list_tools()
829-
assert len(tools.tools) == 3
864+
assert len(tools.tools) == 4
830865

831866
# Call a tool and verify JSON response handling
832867
result = await session.call_tool("test_tool", {})
@@ -905,7 +940,7 @@ async def test_streamablehttp_client_session_termination(
905940

906941
# Make a request to confirm session is working
907942
tools = await session.list_tools()
908-
assert len(tools.tools) == 3
943+
assert len(tools.tools) == 4
909944

910945
headers = {}
911946
if captured_session_id:
@@ -1054,3 +1089,71 @@ async def run_tool():
10541089
assert not any(
10551090
n in captured_notifications_pre for n in captured_notifications
10561091
)
1092+
1093+
1094+
@pytest.mark.anyio
1095+
async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
1096+
"""Test server-initiated sampling request through streamable HTTP transport."""
1097+
print("Testing server sampling...")
1098+
# Variable to track if sampling callback was invoked
1099+
sampling_callback_invoked = False
1100+
captured_message_params = None
1101+
1102+
# Define sampling callback that returns a mock response
1103+
async def sampling_callback(
1104+
context: RequestContext[ClientSession, Any],
1105+
params: types.CreateMessageRequestParams,
1106+
) -> types.CreateMessageResult:
1107+
nonlocal sampling_callback_invoked, captured_message_params
1108+
sampling_callback_invoked = True
1109+
captured_message_params = params
1110+
message_received = (
1111+
params.messages[0].content.text
1112+
if params.messages[0].content.type == "text"
1113+
else None
1114+
)
1115+
1116+
return types.CreateMessageResult(
1117+
role="assistant",
1118+
content=types.TextContent(
1119+
type="text",
1120+
text=f"Received message from server: {message_received}",
1121+
),
1122+
model="test-model",
1123+
stopReason="endTurn",
1124+
)
1125+
1126+
# Create client with sampling callback
1127+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
1128+
read_stream,
1129+
write_stream,
1130+
_,
1131+
):
1132+
async with ClientSession(
1133+
read_stream,
1134+
write_stream,
1135+
sampling_callback=sampling_callback,
1136+
) as session:
1137+
# Initialize the session
1138+
result = await session.initialize()
1139+
assert isinstance(result, InitializeResult)
1140+
1141+
# Call the tool that triggers server-side sampling
1142+
tool_result = await session.call_tool("test_sampling_tool", {})
1143+
1144+
# Verify the tool result contains the expected content
1145+
assert len(tool_result.content) == 1
1146+
assert tool_result.content[0].type == "text"
1147+
assert (
1148+
"Response from sampling: Received message from server"
1149+
in tool_result.content[0].text
1150+
)
1151+
1152+
# Verify sampling callback was invoked
1153+
assert sampling_callback_invoked
1154+
assert captured_message_params is not None
1155+
assert len(captured_message_params.messages) == 1
1156+
assert (
1157+
captured_message_params.messages[0].content.text
1158+
== "Server needs client sampling"
1159+
)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.