Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 5d8eaf7

Browse filesBrowse files
authored
Streamable Http - clean up server memory streams (#604)
1 parent 74f5fcf commit 5d8eaf7
Copy full SHA for 5d8eaf7

File tree

4 files changed

+108
-69
lines changed
Filter options

4 files changed

+108
-69
lines changed

‎examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py

Copy file name to clipboardExpand all lines: examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py
+11-9Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -185,20 +185,22 @@ async def handle_streamable_http(scope, receive, send):
185185
)
186186
server_instances[http_transport.mcp_session_id] = http_transport
187187
logger.info(f"Created new transport with session ID: {new_session_id}")
188-
async with http_transport.connect() as streams:
189-
read_stream, write_stream = streams
190188

191-
async def run_server():
192-
await app.run(
193-
read_stream,
194-
write_stream,
195-
app.create_initialization_options(),
196-
)
189+
async def run_server(task_status=None):
190+
async with http_transport.connect() as streams:
191+
read_stream, write_stream = streams
192+
if task_status:
193+
task_status.started()
194+
await app.run(
195+
read_stream,
196+
write_stream,
197+
app.create_initialization_options(),
198+
)
197199

198200
if not task_group:
199201
raise RuntimeError("Task group is not initialized")
200202

201-
task_group.start_soon(run_server)
203+
await task_group.start(run_server)
202204

203205
# Handle the HTTP request and return the response
204206
await http_transport.handle_request(scope, receive, send)

‎src/mcp/server/lowlevel/server.py

Copy file name to clipboardExpand all lines: src/mcp/server/lowlevel/server.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ async def run(
480480
# but also make tracing exceptions much easier during testing and when using
481481
# in-process servers.
482482
raise_exceptions: bool = False,
483-
# When True, the server as stateless deployments where
483+
# When True, the server is stateless and
484484
# clients can perform initialization with any node. The client must still follow
485485
# the initialization lifecycle, but can do so with any available node
486486
# rather than requiring initialization for each connection.

‎src/mcp/server/streamable_http.py

Copy file name to clipboardExpand all lines: src/mcp/server/streamable_http.py
+80-44Lines changed: 80 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ class StreamableHTTPServerTransport:
129129
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = (
130130
None
131131
)
132+
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
133+
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
132134
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
133135

134136
def __init__(
@@ -163,7 +165,11 @@ def __init__(
163165
self.is_json_response_enabled = is_json_response_enabled
164166
self._event_store = event_store
165167
self._request_streams: dict[
166-
RequestId, MemoryObjectSendStream[EventMessage]
168+
RequestId,
169+
tuple[
170+
MemoryObjectSendStream[EventMessage],
171+
MemoryObjectReceiveStream[EventMessage],
172+
],
167173
] = {}
168174
self._terminated = False
169175

@@ -239,6 +245,19 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
239245

240246
return event_data
241247

248+
async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
249+
"""Clean up memory streams for a given request ID."""
250+
if request_id in self._request_streams:
251+
try:
252+
# Close the request stream
253+
await self._request_streams[request_id][0].aclose()
254+
await self._request_streams[request_id][1].aclose()
255+
except Exception as e:
256+
logger.debug(f"Error closing memory streams: {e}")
257+
finally:
258+
# Remove the request stream from the mapping
259+
self._request_streams.pop(request_id, None)
260+
242261
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
243262
"""Application entry point that handles all HTTP requests"""
244263
request = Request(scope, receive)
@@ -386,13 +405,11 @@ async def _handle_post_request(
386405

387406
# Extract the request ID outside the try block for proper scope
388407
request_id = str(message.root.id)
389-
# Create promise stream for getting response
390-
request_stream_writer, request_stream_reader = (
391-
anyio.create_memory_object_stream[EventMessage](0)
392-
)
393-
394408
# Register this stream for the request ID
395-
self._request_streams[request_id] = request_stream_writer
409+
self._request_streams[request_id] = anyio.create_memory_object_stream[
410+
EventMessage
411+
](0)
412+
request_stream_reader = self._request_streams[request_id][1]
396413

397414
if self.is_json_response_enabled:
398415
# Process the message
@@ -441,11 +458,7 @@ async def _handle_post_request(
441458
)
442459
await response(scope, receive, send)
443460
finally:
444-
# Clean up the request stream
445-
if request_id in self._request_streams:
446-
self._request_streams.pop(request_id, None)
447-
await request_stream_reader.aclose()
448-
await request_stream_writer.aclose()
461+
await self._clean_up_memory_streams(request_id)
449462
else:
450463
# Create SSE stream
451464
sse_stream_writer, sse_stream_reader = (
@@ -467,16 +480,12 @@ async def sse_writer():
467480
event_message.message.root,
468481
JSONRPCResponse | JSONRPCError,
469482
):
470-
if request_id:
471-
self._request_streams.pop(request_id, None)
472483
break
473484
except Exception as e:
474485
logger.exception(f"Error in SSE writer: {e}")
475486
finally:
476487
logger.debug("Closing SSE writer")
477-
# Clean up the request-specific streams
478-
if request_id and request_id in self._request_streams:
479-
self._request_streams.pop(request_id, None)
488+
await self._clean_up_memory_streams(request_id)
480489

481490
# Create and start EventSourceResponse
482491
# SSE stream mode (original behavior)
@@ -507,9 +516,9 @@ async def sse_writer():
507516
await writer.send(session_message)
508517
except Exception:
509518
logger.exception("SSE response error")
510-
# Clean up the request stream if something goes wrong
511-
if request_id and request_id in self._request_streams:
512-
self._request_streams.pop(request_id, None)
519+
await sse_stream_writer.aclose()
520+
await sse_stream_reader.aclose()
521+
await self._clean_up_memory_streams(request_id)
513522

514523
except Exception as err:
515524
logger.exception("Error handling POST request")
@@ -581,12 +590,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
581590
async def standalone_sse_writer():
582591
try:
583592
# Create a standalone message stream for server-initiated messages
584-
standalone_stream_writer, standalone_stream_reader = (
593+
594+
self._request_streams[GET_STREAM_KEY] = (
585595
anyio.create_memory_object_stream[EventMessage](0)
586596
)
587-
588-
# Register this stream using the special key
589-
self._request_streams[GET_STREAM_KEY] = standalone_stream_writer
597+
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]
590598

591599
async with sse_stream_writer, standalone_stream_reader:
592600
# Process messages from the standalone stream
@@ -603,8 +611,7 @@ async def standalone_sse_writer():
603611
logger.exception(f"Error in standalone SSE writer: {e}")
604612
finally:
605613
logger.debug("Closing standalone SSE writer")
606-
# Remove the stream from request_streams
607-
self._request_streams.pop(GET_STREAM_KEY, None)
614+
await self._clean_up_memory_streams(GET_STREAM_KEY)
608615

609616
# Create and start EventSourceResponse
610617
response = EventSourceResponse(
@@ -618,8 +625,9 @@ async def standalone_sse_writer():
618625
await response(request.scope, request.receive, send)
619626
except Exception as e:
620627
logger.exception(f"Error in standalone SSE response: {e}")
621-
# Clean up the request stream
622-
self._request_streams.pop(GET_STREAM_KEY, None)
628+
await sse_stream_writer.aclose()
629+
await sse_stream_reader.aclose()
630+
await self._clean_up_memory_streams(GET_STREAM_KEY)
623631

624632
async def _handle_delete_request(self, request: Request, send: Send) -> None:
625633
"""Handle DELETE requests for explicit session termination."""
@@ -636,15 +644,15 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
636644
if not await self._validate_session(request, send):
637645
return
638646

639-
self._terminate_session()
647+
await self._terminate_session()
640648

641649
response = self._create_json_response(
642650
None,
643651
HTTPStatus.OK,
644652
)
645653
await response(request.scope, request.receive, send)
646654

647-
def _terminate_session(self) -> None:
655+
async def _terminate_session(self) -> None:
648656
"""Terminate the current session, closing all streams.
649657
650658
Once terminated, all requests with this session ID will receive 404 Not Found.
@@ -656,19 +664,26 @@ def _terminate_session(self) -> None:
656664
# We need a copy of the keys to avoid modification during iteration
657665
request_stream_keys = list(self._request_streams.keys())
658666

659-
# Close all request streams (synchronously)
667+
# Close all request streams asynchronously
660668
for key in request_stream_keys:
661669
try:
662-
# Get the stream
663-
stream = self._request_streams.get(key)
664-
if stream:
665-
# We must use close() here, not aclose() since this is a sync method
666-
stream.close()
670+
await self._clean_up_memory_streams(key)
667671
except Exception as e:
668672
logger.debug(f"Error closing stream {key} during termination: {e}")
669673

670674
# Clear the request streams dictionary immediately
671675
self._request_streams.clear()
676+
try:
677+
if self._read_stream_writer is not None:
678+
await self._read_stream_writer.aclose()
679+
if self._read_stream is not None:
680+
await self._read_stream.aclose()
681+
if self._write_stream_reader is not None:
682+
await self._write_stream_reader.aclose()
683+
if self._write_stream is not None:
684+
await self._write_stream.aclose()
685+
except Exception as e:
686+
logger.debug(f"Error closing streams: {e}")
672687

673688
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
674689
"""Handle unsupported HTTP methods."""
@@ -756,10 +771,10 @@ async def send_event(event_message: EventMessage) -> None:
756771

757772
# If stream ID not in mapping, create it
758773
if stream_id and stream_id not in self._request_streams:
759-
msg_writer, msg_reader = anyio.create_memory_object_stream[
760-
EventMessage
761-
](0)
762-
self._request_streams[stream_id] = msg_writer
774+
self._request_streams[stream_id] = (
775+
anyio.create_memory_object_stream[EventMessage](0)
776+
)
777+
msg_reader = self._request_streams[stream_id][1]
763778

764779
# Forward messages to SSE
765780
async with msg_reader:
@@ -781,6 +796,9 @@ async def send_event(event_message: EventMessage) -> None:
781796
await response(request.scope, request.receive, send)
782797
except Exception as e:
783798
logger.exception(f"Error in replay response: {e}")
799+
finally:
800+
await sse_stream_writer.aclose()
801+
await sse_stream_reader.aclose()
784802

785803
except Exception as e:
786804
logger.exception(f"Error replaying events: {e}")
@@ -818,7 +836,9 @@ async def connect(
818836

819837
# Store the streams
820838
self._read_stream_writer = read_stream_writer
839+
self._read_stream = read_stream
821840
self._write_stream_reader = write_stream_reader
841+
self._write_stream = write_stream
822842

823843
# Start a task group for message routing
824844
async with anyio.create_task_group() as tg:
@@ -863,7 +883,7 @@ async def message_router():
863883
if request_stream_id in self._request_streams:
864884
try:
865885
# Send both the message and the event ID
866-
await self._request_streams[request_stream_id].send(
886+
await self._request_streams[request_stream_id][0].send(
867887
EventMessage(message, event_id)
868888
)
869889
except (
@@ -872,6 +892,12 @@ async def message_router():
872892
):
873893
# Stream might be closed, remove from registry
874894
self._request_streams.pop(request_stream_id, None)
895+
else:
896+
logging.debug(
897+
f"""Request stream {request_stream_id} not found
898+
for message. Still processing message as the client
899+
might reconnect and replay."""
900+
)
875901
except Exception as e:
876902
logger.exception(f"Error in message router: {e}")
877903

@@ -882,9 +908,19 @@ async def message_router():
882908
# Yield the streams for the caller to use
883909
yield read_stream, write_stream
884910
finally:
885-
for stream in list(self._request_streams.values()):
911+
for stream_id in list(self._request_streams.keys()):
886912
try:
887-
await stream.aclose()
888-
except Exception:
913+
await self._clean_up_memory_streams(stream_id)
914+
except Exception as e:
915+
logger.debug(f"Error closing request stream: {e}")
889916
pass
890917
self._request_streams.clear()
918+
919+
# Clean up the read and write streams
920+
try:
921+
await read_stream_writer.aclose()
922+
await read_stream.aclose()
923+
await write_stream_reader.aclose()
924+
await write_stream.aclose()
925+
except Exception as e:
926+
logger.debug(f"Error closing streams: {e}")

‎tests/shared/test_streamable_http.py

Copy file name to clipboardExpand all lines: tests/shared/test_streamable_http.py
+16-15Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -234,29 +234,30 @@ async def handle_streamable_http(scope, receive, send):
234234
event_store=event_store,
235235
)
236236

237-
async with http_transport.connect() as streams:
238-
read_stream, write_stream = streams
239-
240-
async def run_server():
237+
async def run_server(task_status=None):
238+
async with http_transport.connect() as streams:
239+
read_stream, write_stream = streams
240+
if task_status:
241+
task_status.started()
241242
await server.run(
242243
read_stream,
243244
write_stream,
244245
server.create_initialization_options(),
245246
)
246247

247-
if task_group is None:
248-
response = Response(
249-
"Internal Server Error: Task group is not initialized",
250-
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
251-
)
252-
await response(scope, receive, send)
253-
return
248+
if task_group is None:
249+
response = Response(
250+
"Internal Server Error: Task group is not initialized",
251+
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
252+
)
253+
await response(scope, receive, send)
254+
return
254255

255-
# Store the instance before starting the task to prevent races
256-
server_instances[http_transport.mcp_session_id] = http_transport
257-
task_group.start_soon(run_server)
256+
# Store the instance before starting the task to prevent races
257+
server_instances[http_transport.mcp_session_id] = http_transport
258+
await task_group.start(run_server)
258259

259-
await http_transport.handle_request(scope, receive, send)
260+
await http_transport.handle_request(scope, receive, send)
260261
else:
261262
response = Response(
262263
"Bad Request: No valid session ID provided",

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.