Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 3b1b213

Browse filesBrowse files
authored
Add message queue for SSE messages POST endpoint (#459)
1 parent 58c5e72 commit 3b1b213
Copy full SHA for 3b1b213

File tree

Expand file treeCollapse file tree

26 files changed

+1247
-50
lines changed
Filter options
Expand file treeCollapse file tree

26 files changed

+1247
-50
lines changed

‎README.md

Copy file name to clipboardExpand all lines: README.md
+24Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,30 @@ app.router.routes.append(Host('mcp.acme.corp', app=mcp.sse_app()))
412412

413413
For more information on mounting applications in Starlette, see the [Starlette documentation](https://www.starlette.io/routing/#submounting-routes).
414414

415+
#### Message Dispatch Options
416+
417+
By default, the SSE server uses an in-memory message dispatch system for incoming POST messages. For production deployments or distributed scenarios, you can use Redis or implement your own message dispatch system that conforms to the `MessageDispatch` protocol:
418+
419+
```python
420+
# Using the built-in Redis message dispatch
421+
from mcp.server.fastmcp import FastMCP
422+
from mcp.server.message_queue import RedisMessageDispatch
423+
424+
# Create a Redis message dispatch
425+
redis_dispatch = RedisMessageDispatch(
426+
redis_url="redis://localhost:6379/0", prefix="mcp:pubsub:"
427+
)
428+
429+
# Pass the message dispatch instance to the server
430+
mcp = FastMCP("My App", message_queue=redis_dispatch)
431+
```
432+
433+
To use Redis, add the Redis dependency:
434+
435+
```bash
436+
uv add "mcp[redis]"
437+
```
438+
415439
## Examples
416440

417441
### Echo Server

‎examples/servers/simple-prompt/mcp_simple_prompt/server.py

Copy file name to clipboardExpand all lines: examples/servers/simple-prompt/mcp_simple_prompt/server.py
+4-1Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,15 @@ async def get_prompt(
8888
)
8989

9090
if transport == "sse":
91+
from mcp.server.message_queue.redis import RedisMessageDispatch
9192
from mcp.server.sse import SseServerTransport
9293
from starlette.applications import Starlette
9394
from starlette.responses import Response
9495
from starlette.routing import Mount, Route
9596

96-
sse = SseServerTransport("/messages/")
97+
message_dispatch = RedisMessageDispatch("redis://localhost:6379/0")
98+
99+
sse = SseServerTransport("/messages/", message_dispatch=message_dispatch)
97100

98101
async def handle_sse(request):
99102
async with sse.connect_sse(

‎pyproject.toml

Copy file name to clipboardExpand all lines: pyproject.toml
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies = [
3737
rich = ["rich>=13.9.4"]
3838
cli = ["typer>=0.12.4", "python-dotenv>=1.0.0"]
3939
ws = ["websockets>=15.0.1"]
40+
redis = ["redis>=5.2.1", "types-redis>=4.6.0.20241004"]
4041

4142
[project.scripts]
4243
mcp = "mcp.cli:app [cli]"
@@ -55,6 +56,7 @@ dev = [
5556
"pytest-xdist>=3.6.1",
5657
"pytest-examples>=0.0.14",
5758
"pytest-pretty>=1.2.0",
59+
"fakeredis==2.28.1",
5860
]
5961
docs = [
6062
"mkdocs>=1.6.1",

‎src/mcp/client/sse.py

Copy file name to clipboardExpand all lines: src/mcp/client/sse.py
+5-1Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ async def sse_reader(
9898
await read_stream_writer.send(exc)
9999
continue
100100

101-
session_message = SessionMessage(message)
101+
session_message = SessionMessage(
102+
message=message
103+
)
102104
await read_stream_writer.send(session_message)
103105
case _:
104106
logger.warning(
@@ -148,3 +150,5 @@ async def post_writer(endpoint_url: str):
148150
finally:
149151
await read_stream_writer.aclose()
150152
await write_stream.aclose()
153+
await read_stream.aclose()
154+
await write_stream_reader.aclose()

‎src/mcp/client/stdio/__init__.py

Copy file name to clipboardExpand all lines: src/mcp/client/stdio/__init__.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ async def stdout_reader():
144144
await read_stream_writer.send(exc)
145145
continue
146146

147-
session_message = SessionMessage(message)
147+
session_message = SessionMessage(message=message)
148148
await read_stream_writer.send(session_message)
149149
except anyio.ClosedResourceError:
150150
await anyio.lowlevel.checkpoint()

‎src/mcp/client/streamable_http.py

Copy file name to clipboardExpand all lines: src/mcp/client/streamable_http.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ async def _handle_sse_event(
153153
):
154154
message.root.id = original_request_id
155155

156-
session_message = SessionMessage(message)
156+
session_message = SessionMessage(message=message)
157157
await read_stream_writer.send(session_message)
158158

159159
# Call resumption token callback if we have an ID
@@ -286,7 +286,7 @@ async def _handle_json_response(
286286
try:
287287
content = await response.aread()
288288
message = JSONRPCMessage.model_validate_json(content)
289-
session_message = SessionMessage(message)
289+
session_message = SessionMessage(message=message)
290290
await read_stream_writer.send(session_message)
291291
except Exception as exc:
292292
logger.error(f"Error parsing JSON response: {exc}")
@@ -333,7 +333,7 @@ async def _send_session_terminated_error(
333333
id=request_id,
334334
error=ErrorData(code=32600, message="Session terminated"),
335335
)
336-
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
336+
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
337337
await read_stream_writer.send(session_message)
338338

339339
async def post_writer(

‎src/mcp/client/websocket.py

Copy file name to clipboardExpand all lines: src/mcp/client/websocket.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ async def ws_reader():
6060
async for raw_text in ws:
6161
try:
6262
message = types.JSONRPCMessage.model_validate_json(raw_text)
63-
session_message = SessionMessage(message)
63+
session_message = SessionMessage(message=message)
6464
await read_stream_writer.send(session_message)
6565
except ValidationError as exc:
6666
# If JSON parse or model validation fails, send the exception

‎src/mcp/server/fastmcp/server.py

Copy file name to clipboardExpand all lines: src/mcp/server/fastmcp/server.py
+28-3Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from mcp.server.lowlevel.server import LifespanResultT
4545
from mcp.server.lowlevel.server import Server as MCPServer
4646
from mcp.server.lowlevel.server import lifespan as default_lifespan
47+
from mcp.server.message_queue import MessageDispatch
4748
from mcp.server.session import ServerSession, ServerSessionT
4849
from mcp.server.sse import SseServerTransport
4950
from mcp.server.stdio import stdio_server
@@ -90,6 +91,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
9091
sse_path: str = "/sse"
9192
message_path: str = "/messages/"
9293

94+
# SSE message queue settings
95+
message_dispatch: MessageDispatch | None = Field(
96+
None, description="Custom message dispatch instance"
97+
)
98+
9399
# resource settings
94100
warn_on_duplicate_resources: bool = True
95101

@@ -569,12 +575,21 @@ async def run_sse_async(self) -> None:
569575

570576
def sse_app(self) -> Starlette:
571577
"""Return an instance of the SSE server app."""
578+
message_dispatch = self.settings.message_dispatch
579+
if message_dispatch is None:
580+
from mcp.server.message_queue import InMemoryMessageDispatch
581+
582+
message_dispatch = InMemoryMessageDispatch()
583+
logger.info("Using default in-memory message dispatch")
584+
572585
from starlette.middleware import Middleware
573586
from starlette.routing import Mount, Route
574587

575588
# Set up auth context and dependencies
576589

577-
sse = SseServerTransport(self.settings.message_path)
590+
sse = SseServerTransport(
591+
self.settings.message_path, message_dispatch=message_dispatch
592+
)
578593

579594
async def handle_sse(scope: Scope, receive: Receive, send: Send):
580595
# Add client ID from auth context into request context if available
@@ -589,7 +604,14 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send):
589604
streams[1],
590605
self._mcp_server.create_initialization_options(),
591606
)
592-
return Response()
607+
return Response()
608+
609+
@asynccontextmanager
610+
async def lifespan(app: Starlette):
611+
try:
612+
yield
613+
finally:
614+
await message_dispatch.close()
593615

594616
# Create routes
595617
routes: list[Route | Mount] = []
@@ -666,7 +688,10 @@ async def sse_endpoint(request: Request) -> None:
666688

667689
# Create Starlette app with routes and middleware
668690
return Starlette(
669-
debug=self.settings.debug, routes=routes, middleware=middleware
691+
debug=self.settings.debug,
692+
routes=routes,
693+
middleware=middleware,
694+
lifespan=lifespan,
670695
)
671696

672697
async def list_prompts(self) -> list[MCPPrompt]:
+16Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""
2+
Message Dispatch Module for MCP Server
3+
4+
This module implements dispatch interfaces for handling
5+
messages between clients and servers.
6+
"""
7+
8+
from mcp.server.message_queue.base import InMemoryMessageDispatch, MessageDispatch
9+
10+
# Try to import Redis implementation if available
11+
try:
12+
from mcp.server.message_queue.redis import RedisMessageDispatch
13+
except ImportError:
14+
RedisMessageDispatch = None
15+
16+
__all__ = ["MessageDispatch", "InMemoryMessageDispatch", "RedisMessageDispatch"]

‎src/mcp/server/message_queue/base.py

Copy file name to clipboard
+116Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import logging
2+
from collections.abc import Awaitable, Callable
3+
from contextlib import asynccontextmanager
4+
from typing import Protocol, runtime_checkable
5+
from uuid import UUID
6+
7+
from pydantic import ValidationError
8+
9+
from mcp.shared.message import SessionMessage
10+
11+
logger = logging.getLogger(__name__)
12+
13+
MessageCallback = Callable[[SessionMessage | Exception], Awaitable[None]]
14+
15+
16+
@runtime_checkable
17+
class MessageDispatch(Protocol):
18+
"""Abstract interface for SSE message dispatching.
19+
20+
This interface allows messages to be published to sessions and callbacks to be
21+
registered for message handling, enabling multiple servers to handle requests.
22+
"""
23+
24+
async def publish_message(
25+
self, session_id: UUID, message: SessionMessage | str
26+
) -> bool:
27+
"""Publish a message for the specified session.
28+
29+
Args:
30+
session_id: The UUID of the session this message is for
31+
message: The message to publish (SessionMessage or str for invalid JSON)
32+
33+
Returns:
34+
bool: True if message was published, False if session not found
35+
"""
36+
...
37+
38+
@asynccontextmanager
39+
async def subscribe(self, session_id: UUID, callback: MessageCallback):
40+
"""Request-scoped context manager that subscribes to messages for a session.
41+
42+
Args:
43+
session_id: The UUID of the session to subscribe to
44+
callback: Async callback function to handle messages for this session
45+
"""
46+
yield
47+
48+
async def session_exists(self, session_id: UUID) -> bool:
49+
"""Check if a session exists.
50+
51+
Args:
52+
session_id: The UUID of the session to check
53+
54+
Returns:
55+
bool: True if the session is active, False otherwise
56+
"""
57+
...
58+
59+
async def close(self) -> None:
60+
"""Close the message dispatch."""
61+
...
62+
63+
64+
class InMemoryMessageDispatch:
65+
"""Default in-memory implementation of the MessageDispatch interface.
66+
67+
This implementation immediately dispatches messages to registered callbacks when
68+
messages are received without any queuing behavior.
69+
"""
70+
71+
def __init__(self) -> None:
72+
self._callbacks: dict[UUID, MessageCallback] = {}
73+
74+
async def publish_message(
75+
self, session_id: UUID, message: SessionMessage | str
76+
) -> bool:
77+
"""Publish a message for the specified session."""
78+
if session_id not in self._callbacks:
79+
logger.warning(f"Message dropped: unknown session {session_id}")
80+
return False
81+
82+
# Parse string messages or recreate original ValidationError
83+
if isinstance(message, str):
84+
try:
85+
callback_argument = SessionMessage.model_validate_json(message)
86+
except ValidationError as exc:
87+
callback_argument = exc
88+
else:
89+
callback_argument = message
90+
91+
# Call the callback with either valid message or recreated ValidationError
92+
await self._callbacks[session_id](callback_argument)
93+
94+
logger.debug(f"Message dispatched to session {session_id}")
95+
return True
96+
97+
@asynccontextmanager
98+
async def subscribe(self, session_id: UUID, callback: MessageCallback):
99+
"""Request-scoped context manager that subscribes to messages for a session."""
100+
self._callbacks[session_id] = callback
101+
logger.debug(f"Subscribing to messages for session {session_id}")
102+
103+
try:
104+
yield
105+
finally:
106+
if session_id in self._callbacks:
107+
del self._callbacks[session_id]
108+
logger.debug(f"Unsubscribed from session {session_id}")
109+
110+
async def session_exists(self, session_id: UUID) -> bool:
111+
"""Check if a session exists."""
112+
return session_id in self._callbacks
113+
114+
async def close(self) -> None:
115+
"""Close the message dispatch."""
116+
pass

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.