Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 08f4e01

Browse filesBrowse files
authored
add callback for logging message notification (#314)
1 parent a9aca20 commit 08f4e01
Copy full SHA for 08f4e01

File tree

Expand file treeCollapse file tree

3 files changed

+113
-1
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+113
-1
lines changed

‎src/mcp/client/session.py

Copy file name to clipboardExpand all lines: src/mcp/client/session.py
+25Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ async def __call__(
2424
) -> types.ListRootsResult | types.ErrorData: ...
2525

2626

27+
class LoggingFnT(Protocol):
28+
async def __call__(
29+
self,
30+
params: types.LoggingMessageNotificationParams,
31+
) -> None: ...
32+
33+
2734
async def _default_sampling_callback(
2835
context: RequestContext["ClientSession", Any],
2936
params: types.CreateMessageRequestParams,
@@ -43,6 +50,12 @@ async def _default_list_roots_callback(
4350
)
4451

4552

53+
async def _default_logging_callback(
54+
params: types.LoggingMessageNotificationParams,
55+
) -> None:
56+
pass
57+
58+
4659
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
4760
types.ClientResult | types.ErrorData
4861
)
@@ -64,6 +77,7 @@ def __init__(
6477
read_timeout_seconds: timedelta | None = None,
6578
sampling_callback: SamplingFnT | None = None,
6679
list_roots_callback: ListRootsFnT | None = None,
80+
logging_callback: LoggingFnT | None = None,
6781
) -> None:
6882
super().__init__(
6983
read_stream,
@@ -74,6 +88,7 @@ def __init__(
7488
)
7589
self._sampling_callback = sampling_callback or _default_sampling_callback
7690
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
91+
self._logging_callback = logging_callback or _default_logging_callback
7792

7893
async def initialize(self) -> types.InitializeResult:
7994
sampling = types.SamplingCapability()
@@ -321,3 +336,13 @@ async def _received_request(
321336
return await responder.respond(
322337
types.ClientResult(root=types.EmptyResult())
323338
)
339+
340+
async def _received_notification(
341+
self, notification: types.ServerNotification
342+
) -> None:
343+
"""Handle notifications from the server."""
344+
match notification.root:
345+
case types.LoggingMessageNotification(params=params):
346+
await self._logging_callback(params)
347+
case _:
348+
pass

‎src/mcp/shared/memory.py

Copy file name to clipboardExpand all lines: src/mcp/shared/memory.py
+3-1Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import anyio
1010
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1111

12-
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
12+
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT
1313
from mcp.server import Server
1414
from mcp.types import JSONRPCMessage
1515

@@ -56,6 +56,7 @@ async def create_connected_server_and_client_session(
5656
read_timeout_seconds: timedelta | None = None,
5757
sampling_callback: SamplingFnT | None = None,
5858
list_roots_callback: ListRootsFnT | None = None,
59+
logging_callback: LoggingFnT | None = None,
5960
raise_exceptions: bool = False,
6061
) -> AsyncGenerator[ClientSession, None]:
6162
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -84,6 +85,7 @@ async def create_connected_server_and_client_session(
8485
read_timeout_seconds=read_timeout_seconds,
8586
sampling_callback=sampling_callback,
8687
list_roots_callback=list_roots_callback,
88+
logging_callback=logging_callback,
8789
) as client_session:
8890
await client_session.initialize()
8991
yield client_session

‎tests/client/test_logging_callback.py

Copy file name to clipboard
+85Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import List, Literal
2+
3+
import anyio
4+
import pytest
5+
6+
from mcp.shared.memory import (
7+
create_connected_server_and_client_session as create_session,
8+
)
9+
from mcp.types import (
10+
LoggingMessageNotificationParams,
11+
TextContent,
12+
)
13+
14+
15+
class LoggingCollector:
16+
def __init__(self):
17+
self.log_messages: List[LoggingMessageNotificationParams] = []
18+
19+
async def __call__(self, params: LoggingMessageNotificationParams) -> None:
20+
self.log_messages.append(params)
21+
22+
23+
@pytest.mark.anyio
24+
async def test_logging_callback():
25+
from mcp.server.fastmcp import FastMCP
26+
27+
server = FastMCP("test")
28+
logging_collector = LoggingCollector()
29+
30+
# Create a simple test tool
31+
@server.tool("test_tool")
32+
async def test_tool() -> bool:
33+
# The actual tool is very simple and just returns True
34+
return True
35+
36+
# Create a function that can send a log notification
37+
@server.tool("test_tool_with_log")
38+
async def test_tool_with_log(
39+
message: str, level: Literal["debug", "info", "warning", "error"], logger: str
40+
) -> bool:
41+
"""Send a log notification to the client."""
42+
await server.get_context().log(
43+
level=level,
44+
message=message,
45+
logger_name=logger,
46+
)
47+
return True
48+
49+
async with anyio.create_task_group() as tg:
50+
async with create_session(
51+
server._mcp_server, logging_callback=logging_collector
52+
) as client_session:
53+
54+
async def listen_session():
55+
try:
56+
async for message in client_session.incoming_messages:
57+
if isinstance(message, Exception):
58+
raise message
59+
except anyio.EndOfStream:
60+
pass
61+
62+
tg.start_soon(listen_session)
63+
64+
# First verify our test tool works
65+
result = await client_session.call_tool("test_tool", {})
66+
assert result.isError is False
67+
assert isinstance(result.content[0], TextContent)
68+
assert result.content[0].text == "true"
69+
70+
# Now send a log message via our tool
71+
log_result = await client_session.call_tool(
72+
"test_tool_with_log",
73+
{
74+
"message": "Test log message",
75+
"level": "info",
76+
"logger": "test_logger",
77+
},
78+
)
79+
assert log_result.isError is False
80+
assert len(logging_collector.log_messages) == 1
81+
assert logging_collector.log_messages[
82+
0
83+
] == LoggingMessageNotificationParams(
84+
level="info", logger="test_logger", data="Test log message"
85+
)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.