From eb8e3744a7903932ab02982e506d6a75e79fe3ab Mon Sep 17 00:00:00 2001 From: Renxia Wang Date: Sat, 29 Mar 2025 20:49:51 -0400 Subject: [PATCH 001/249] feat(transport): Add customizable HTTP request builder support (#86) Enhances FlowSseClient and HttpClientSseClientTransport to accept a custom HttpRequest.Builder, allowing for greater flexibility when configuring HTTP requests. This enables clients to customize headers, timeouts, and other request properties across all SSE connections and message sending operations. Signed-off-by: Christian Tzolov --- .../client/transport/FlowSseClient.java | 15 ++++++- .../HttpClientSseClientTransport.java | 41 +++++++++++++++++-- 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java index 7fc679937..50af35c70 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java @@ -39,6 +39,8 @@ public class FlowSseClient { private final HttpClient httpClient; + private final HttpRequest.Builder requestBuilder; + /** * Pattern to extract the data content from SSE data field lines. Matches lines * starting with "data:" and captures the remaining content. @@ -92,7 +94,17 @@ public interface SseEventHandler { * @param httpClient the {@link HttpClient} instance to use for SSE connections */ public FlowSseClient(HttpClient httpClient) { + this(httpClient, HttpRequest.newBuilder()); + } + + /** + * Creates a new FlowSseClient with the specified HTTP client and request builder. + * @param httpClient the {@link HttpClient} instance to use for SSE connections + * @param requestBuilder the {@link HttpRequest.Builder} to use for SSE requests + */ + public FlowSseClient(HttpClient httpClient, HttpRequest.Builder requestBuilder) { this.httpClient = httpClient; + this.requestBuilder = requestBuilder; } /** @@ -109,8 +121,7 @@ public FlowSseClient(HttpClient httpClient) { * @throws RuntimeException if the connection fails with a non-200 status code */ public void subscribe(String url, SseEventHandler eventHandler) { - HttpRequest request = HttpRequest.newBuilder() - .uri(URI.create(url)) + HttpRequest request = this.requestBuilder.uri(URI.create(url)) .header("Accept", "text/event-stream") .header("Cache-Control", "no-cache") .GET() diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 696efdffd..0b482533d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -82,6 +82,9 @@ public class HttpClientSseClientTransport implements McpClientTransport { */ private final HttpClient httpClient; + /** HTTP request builder for building requests to send messages to the server */ + private final HttpRequest.Builder requestBuilder; + /** JSON object mapper for message serialization/deserialization */ protected ObjectMapper objectMapper; @@ -126,15 +129,33 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String bas */ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { + this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper); + } + + /** + * Creates a new transport instance with custom HTTP client builder, object mapper, + * and headers. + * @param clientBuilder the HTTP client builder to use + * @param requestBuilder the HTTP request builder to use + * @param baseUri the base URI of the MCP server + * @param sseEndpoint the SSE endpoint path + * @param objectMapper the object mapper for JSON serialization/deserialization + * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + */ + public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder, + String baseUri, String sseEndpoint, ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(clientBuilder, "clientBuilder must not be null"); + Assert.notNull(requestBuilder, "requestBuilder must not be null"); this.baseUri = baseUri; this.sseEndpoint = sseEndpoint; this.objectMapper = objectMapper; this.httpClient = clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(); - this.sseClient = new FlowSseClient(this.httpClient); + this.requestBuilder = requestBuilder; + + this.sseClient = new FlowSseClient(this.httpClient, requestBuilder); } /** @@ -159,6 +180,8 @@ public static class Builder { private ObjectMapper objectMapper = new ObjectMapper(); + private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server @@ -190,6 +213,17 @@ public Builder clientBuilder(HttpClient.Builder clientBuilder) { return this; } + /** + * Sets the HTTP request builder. + * @param requestBuilder the HTTP request builder + * @return this builder + */ + public Builder requestBuilder(HttpRequest.Builder requestBuilder) { + Assert.notNull(requestBuilder, "requestBuilder must not be null"); + this.requestBuilder = requestBuilder; + return this; + } + /** * Sets the object mapper for JSON serialization/deserialization. * @param objectMapper the object mapper @@ -206,7 +240,7 @@ public Builder objectMapper(ObjectMapper objectMapper) { * @return a new transport instance */ public HttpClientSseClientTransport build() { - return new HttpClientSseClientTransport(clientBuilder, baseUri, sseEndpoint, objectMapper); + return new HttpClientSseClientTransport(clientBuilder, requestBuilder, baseUri, sseEndpoint, objectMapper); } } @@ -301,8 +335,7 @@ public Mono sendMessage(JSONRPCMessage message) { try { String jsonText = this.objectMapper.writeValueAsString(message); - HttpRequest request = HttpRequest.newBuilder() - .uri(URI.create(this.baseUri + endpoint)) + HttpRequest request = this.requestBuilder.uri(URI.create(this.baseUri + endpoint)) .header("Content-Type", "application/json") .POST(HttpRequest.BodyPublishers.ofString(jsonText)) .build(); From c3a7c1ac1e04c141e95df1b1a77dc127b7ce0311 Mon Sep 17 00:00:00 2001 From: jitokim Date: Sun, 6 Apr 2025 02:12:34 +0900 Subject: [PATCH 002/249] perf(webflux): optimize session broadcasting with Flux.fromIterable (#109) Replace Flux.fromStream(sessions.values().stream()) with more efficient Flux.fromIterable(sessions.values()) to eliminate unnecessary stream conversion when broadcasting messages to active sessions Signed-off-by: jitokim --- .../server/transport/WebFluxSseServerTransportProvider.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 85a39a82f..af2ff06a3 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -171,10 +171,10 @@ public Mono notifyClients(String method, Map params) { logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - return Flux.fromStream(sessions.values().stream()) + return Flux.fromIterable(sessions.values()) .flatMap(session -> session.sendNotification(method, params) - .doOnError(e -> logger.error("Failed to " + "send message to session " + "{}: {}", session.getId(), - e.getMessage())) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) .onErrorComplete()) .then(); } From cd624a7d5719db9648711c986be9bc9a149a34e4 Mon Sep 17 00:00:00 2001 From: Oleksandr Popov Date: Sun, 6 Apr 2025 10:29:29 +0200 Subject: [PATCH 003/249] fix: correct typos and improve documentation (#35) Signed-off-by: Christian Tzolov --- mcp/pom.xml | 2 +- .../io/modelcontextprotocol/client/McpAsyncClient.java | 10 +++++++++- .../client/transport/HttpClientSseClientTransport.java | 2 +- .../client/transport/StdioClientTransport.java | 2 +- .../io/modelcontextprotocol/spec/McpClientSession.java | 4 ++-- .../spec/McpClientSessionTests.java | 2 +- 6 files changed, 15 insertions(+), 7 deletions(-) diff --git a/mcp/pom.xml b/mcp/pom.xml index f6e93b39c..edb1c8f07 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -97,7 +97,7 @@ test - diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 379b47e23..ce49b0a5e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -364,7 +364,7 @@ private Mono withInitializationCheck(String actionName, } // -------------------------- - // Basic Utilites + // Basic Utilities // -------------------------- /** @@ -751,6 +751,14 @@ private NotificationHandler asyncPromptsChangeNotificationHandler( // -------------------------- // Logging // -------------------------- + /** + * Create a notification handler for logging notifications from the server. This + * handler automatically distributes logging messages to all registered consumers. + * @param loggingConsumers List of consumers that will be notified when a logging + * message is received. Each consumer receives the logging message notification. + * @return A NotificationHandler that processes log notifications by distributing the + * message to all registered consumers + */ private NotificationHandler asyncLoggingNotificationHandler( List>> loggingConsumers) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 0b482533d..a5bdd43e2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -376,7 +376,7 @@ public Mono closeGracefully() { } /** - * Unmarshals data to the specified type using the configured object mapper. + * Unmarshal data to the specified type using the configured object mapper. * @param data the data to unmarshal * @param typeRef the type reference for the target type * @param the target type diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index f9a97849f..9d71cbb48 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -292,7 +292,7 @@ private void startInboundProcessing() { */ private void startOutboundProcessing() { this.handleOutbound(messages -> messages - // this bit is important since writes come from user threads and we + // this bit is important since writes come from user threads, and we // want to ensure that the actual writing happens on a dedicated thread .publishOn(outboundScheduler) .handle((message, s) -> { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index e29646e6a..719a78001 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -107,7 +107,7 @@ public interface NotificationHandler { public McpClientSession(Duration requestTimeout, McpClientTransport transport, Map> requestHandlers, Map notificationHandlers) { - Assert.notNull(requestTimeout, "The requstTimeout can not be null"); + Assert.notNull(requestTimeout, "The requestTimeout can not be null"); Assert.notNull(transport, "The transport can not be null"); Assert.notNull(requestHandlers, "The requestHandlers can not be null"); Assert.notNull(notificationHandlers, "The notificationHandlers can not be null"); @@ -127,7 +127,7 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, logger.debug("Received Response: {}", response); var sink = pendingResponses.remove(response.id()); if (sink == null) { - logger.warn("Unexpected response for unkown id {}", response.id()); + logger.warn("Unexpected response for unknown id {}", response.id()); } else { sink.success(response); diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index 715d6651e..f72be43e0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -61,7 +61,7 @@ void tearDown() { void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> new McpClientSession(null, transport, Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("requstTimeout can not be null"); + .hasMessageContaining("The requestTimeout can not be null"); assertThatThrownBy(() -> new McpClientSession(TIMEOUT, null, Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) From 734153a445585cd452f041520583e6517f3674f3 Mon Sep 17 00:00:00 2001 From: jitokim Date: Sun, 6 Apr 2025 02:10:43 +0900 Subject: [PATCH 004/249] fix typo in WebFluxSseIntegrationTests Signed-off-by: jitokim --- .../WebFluxSseIntegrationTests.java | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 2be2f81f2..dbfad821f 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -52,8 +52,6 @@ public class WebFluxSseIntegrationTests { private static final int PORT = 8182; - // private static final String MESSAGE_ENDPOINT = "/mcp/message"; - private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; @@ -62,7 +60,7 @@ public class WebFluxSseIntegrationTests { private WebFluxSseServerTransportProvider mcpServerTransportProvider; - ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); + ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); @BeforeEach public void before() { @@ -77,11 +75,11 @@ public void before() { ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - clientBulders.put("httpclient", + clientBuilders.put("httpclient", McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) .sseEndpoint(CUSTOM_SSE_ENDPOINT) .build())); - clientBulders.put("webflux", + clientBuilders.put("webflux", McpClient .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) .sseEndpoint(CUSTOM_SSE_ENDPOINT) @@ -103,7 +101,7 @@ public void after() { @ValueSource(strings = { "httpclient", "webflux" }) void testCreateMessageWithoutSamplingCapabilities(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { @@ -134,7 +132,7 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { void testCreateMessageSuccess(String clientType) throws InterruptedException { // Client - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); @@ -203,7 +201,7 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testRootsSuccess(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); @@ -250,7 +248,7 @@ void testRootsSuccess(String clientType) { @ValueSource(strings = { "httpclient", "webflux" }) void testRootsWithoutCapability(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { @@ -284,7 +282,7 @@ void testRootsWithoutCapability(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testRootsNotifciationWithEmptyRootsList(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); AtomicReference> rootsRef = new AtomicReference<>(); var mcpServer = McpServer.sync(mcpServerTransportProvider) @@ -311,7 +309,7 @@ void testRootsNotifciationWithEmptyRootsList(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testRootsWithMultipleHandlers(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); List roots = List.of(new Root("uri1://", "root1")); @@ -345,7 +343,7 @@ void testRootsWithMultipleHandlers(String clientType) { @ValueSource(strings = { "httpclient", "webflux" }) void testRootsServerCloseWithActiveSubscription(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); List roots = List.of(new Root("uri1://", "root1")); @@ -390,7 +388,7 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { @ValueSource(strings = { "httpclient", "webflux" }) void testToolCallSuccess(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( @@ -430,7 +428,7 @@ void testToolCallSuccess(String clientType) { @ValueSource(strings = { "httpclient", "webflux" }) void testToolListChangeHandlingSuccess(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( @@ -500,7 +498,7 @@ void testToolListChangeHandlingSuccess(String clientType) { @ValueSource(strings = { "httpclient", "webflux" }) void testInitialize(String clientType) { - var clientBuilder = clientBulders.get(clientType); + var clientBuilder = clientBuilders.get(clientType); var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); From 0db4c0f70d28c72ec94750c289e001d54a5bace6 Mon Sep 17 00:00:00 2001 From: minguncle <57527858+minguncle@users.noreply.github.com> Date: Thu, 27 Mar 2025 15:53:30 +0800 Subject: [PATCH 005/249] feat(webmvc): Add support for custom context paths in WebMvcSseServerTransportProvider Adds the ability to specify a base URL for message endpoints in WebMvcSseServerTransportProvider, enabling proper handling of custom servlet context paths in Spring WebMVC applications. This ensures that clients receive the correct full endpoint URL when connecting through SSE. - Add messageBaseUrl field to WebMvcSseServerTransportProvider - Create new constructor that accepts messageBaseUrl parameter - Update endpoint event to include base URL in the message endpoint - Add TomcatTestUtil class to simplify test server creation - Add WebMvcSseCustomContextPathTests to verify custom context path functionality - Refactor WebMvcSseIntegrationTests to use the new TomcatTestUtil Co-authored-by: Christian Tzolov Signed-off-by: Christian Tzolov --- .../WebMvcSseServerTransportProvider.java | 51 ++++++--- .../server/TomcatTestUtil.java | 60 ++++++++++ .../WebMvcSseCustomContextPathTests.java | 105 ++++++++++++++++++ .../server/WebMvcSseIntegrationTests.java | 62 +++-------- mcp-test/pom.xml | 1 + 5 files changed, 216 insertions(+), 63 deletions(-) create mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java create mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 65416b256..f6dbd4779 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -91,6 +91,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi private final String sseEndpoint; + private final String messageBaseUrl; + private final RouterFunction routerFunction; private McpServerSession.Factory sessionFactory; @@ -105,6 +107,20 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ private volatile boolean isClosing = false; + /** + * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE + * endpoint. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + /** * Constructs a new WebMvcSseServerTransportProvider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization @@ -116,11 +132,30 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi * @throws IllegalArgumentException if any parameter is null */ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + this(objectMapper, "", messageEndpoint, sseEndpoint); + } + + /** + * Constructs a new WebMvcSseServerTransportProvider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param messageBaseUrl The base URL for the message endpoint, used to construct the + * full endpoint URL for clients. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @throws IllegalArgumentException if any parameter is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageBaseUrl, String messageEndpoint, + String sseEndpoint) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(messageBaseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); this.objectMapper = objectMapper; + this.messageBaseUrl = messageBaseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.routerFunction = RouterFunctions.route() @@ -129,20 +164,6 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag .build(); } - /** - * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE - * endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null - */ - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - @Override public void setSessionFactory(McpServerSession.Factory sessionFactory) { this.sessionFactory = sessionFactory; @@ -248,7 +269,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { try { sseBuilder.id(sessionId) .event(ENDPOINT_EVENT_TYPE) - .data(messageEndpoint + "?sessionId=" + sessionId); + .data(this.messageBaseUrl + this.messageEndpoint + "?sessionId=" + sessionId); } catch (Exception e) { logger.error("Failed to send initial endpoint event: {}", e.getMessage()); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java new file mode 100644 index 000000000..fcd7fb4dc --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java @@ -0,0 +1,60 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.server; + +import org.apache.catalina.Context; +import org.apache.catalina.startup.Tomcat; + +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; + +/** + * @author Christian Tzolov + */ +public class TomcatTestUtil { + + public record TomcatServer(Tomcat tomcat, AnnotationConfigWebApplicationContext appContext) { + } + + public TomcatServer createTomcatServer(String contextPath, int port, Class componentClass) { + + // Set up Tomcat first + var tomcat = new Tomcat(); + tomcat.setPort(port); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext(contextPath, baseDir); + + // Create and configure Spring WebMvc context + var appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(componentClass); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + // Configure and start the connector with async support + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return new TomcatServer(tomcat, appContext); + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java new file mode 100644 index 000000000..0e81104b9 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java @@ -0,0 +1,105 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import static org.assertj.core.api.Assertions.assertThat; + +public class WebMvcSseCustomContextPathTests { + + private static final String CUSTOM_CONTEXT_PATH = "/app/1"; + + private static final int PORT = 8183; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private TomcatTestUtil.TomcatServer tomcatServer; + + @BeforeEach + public void before() { + + tomcatServer = new TomcatTestUtil().createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState() == LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_CONTEXT_PATH + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .build(); + + clientBuilder = McpClient.sync(clientTransport); + + mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void testCustomContextPath() { + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + assertThat(client.initialize()).isNotNull(); + } + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + + return new WebMvcSseServerTransportProvider(new ObjectMapper(), CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT, + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 3ff755ca9..f9190fd70 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -25,10 +25,8 @@ import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; -import org.apache.catalina.startup.Tomcat; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -38,15 +36,12 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.client.RestClient; -import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerResponse; import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; -import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; public class WebMvcSseIntegrationTests { @@ -75,55 +70,26 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro } - private Tomcat tomcat; - - private AnnotationConfigWebApplicationContext appContext; + private TomcatTestUtil.TomcatServer tomcatServer; @BeforeEach public void before() { - // Set up Tomcat first - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - // Set Tomcat base directory to java.io.tmpdir to avoid permission issues - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - // Use the same directory for document base - Context context = tomcat.addContext("", baseDir); - - // Create and configure Spring WebMvc context - appContext = new AnnotationConfigWebApplicationContext(); - appContext.register(TestConfig.class); - appContext.setServletContext(context.getServletContext()); - appContext.refresh(); - - // Get the transport from Spring context - mcpServerTransportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); - - // Create DispatcherServlet with our Spring context - DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); - - // Add servlet to Tomcat and get the wrapper - var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); - wrapper.setLoadOnStartup(1); - wrapper.setAsyncSupported(true); - context.addServletMappingDecoded("/*", "dispatcherServlet"); + tomcatServer = new TomcatTestUtil().createTomcatServer("", PORT, TestConfig.class); try { - // Configure and start the connector with async support - var connector = tomcat.getConnector(); - connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests - tomcat.start(); - assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState() == LifecycleState.STARTED); } catch (Exception e) { throw new RuntimeException("Failed to start Tomcat", e); } - this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + + // Get the transport from Spring context + mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); + } @AfterEach @@ -131,13 +97,13 @@ public void after() { if (mcpServerTransportProvider != null) { mcpServerTransportProvider.closeGracefully().block(); } - if (appContext != null) { - appContext.close(); + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); } - if (tomcat != null) { + if (tomcatServer.tomcat() != null) { try { - tomcat.stop(); - tomcat.destroy(); + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); } catch (LifecycleException e) { throw new RuntimeException("Failed to stop Tomcat", e); diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index b995618af..95f5dc30a 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -80,6 +80,7 @@ logback-classic ${logback.version} + From 3fa415228757c78bdbcabdfeaf548b3b09882b2f Mon Sep 17 00:00:00 2001 From: zhangzhenhua Date: Wed, 2 Apr 2025 13:54:14 +0800 Subject: [PATCH 006/249] feat(webflux): Add base URL support to WebFluxSseServerTransportProvider (#102) Adds the ability to specify a base URL prefix for message endpoints in the WebFlux SSE server transport provider. This enhancement allows for proper URL construction when the server is running behind a proxy or in a context with a base path. - Add new constructor with baseUrl parameter - Add basePath() method to Builder class - Modify SSE endpoint event to include baseUrl prefix Signed-off-by: Christian Tzolov --- .../WebFluxSseServerTransportProvider.java | 75 ++++++++++++++----- 1 file changed, 58 insertions(+), 17 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index af2ff06a3..df8dd0211 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -82,8 +82,16 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + public static final String DEFAULT_BASE_URL = ""; + private final ObjectMapper objectMapper; + /** + * Base URL for the message endpoint. This is used to construct the full URL for + * clients to send their JSON-RPC messages. + */ + private final String baseUrl; + private final String messageEndpoint; private final String sseEndpoint; @@ -102,6 +110,20 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ private volatile boolean isClosing = false; + /** + * Constructs a new WebFlux SSE server transport provider instance with the default + * SSE endpoint. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + /** * Constructs a new WebFlux SSE server transport provider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization @@ -112,11 +134,28 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv * @throws IllegalArgumentException if either parameter is null */ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param baseUrl webflux messag base path + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); this.objectMapper = objectMapper; + this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.routerFunction = RouterFunctions.route() @@ -125,20 +164,6 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa .build(); } - /** - * Constructs a new WebFlux SSE server transport provider instance with the default - * SSE endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null - */ - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - @Override public void setSessionFactory(McpServerSession.Factory sessionFactory) { this.sessionFactory = sessionFactory; @@ -179,7 +204,8 @@ public Mono notifyClients(String method, Map params) { .then(); } - // FIXME: This javadoc makes claims about using isClosing flag but it's not actually + // FIXME: This javadoc makes claims about using isClosing flag but it's not + // actually // doing that. /** * Initiates a graceful shutdown of all the sessions. This method ensures all active @@ -245,7 +271,7 @@ private Mono handleSseConnection(ServerRequest request) { logger.debug("Sending initial endpoint event to session: {}", sessionId); sink.next(ServerSentEvent.builder() .event(ENDPOINT_EVENT_TYPE) - .data(messageEndpoint + "?sessionId=" + sessionId) + .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId) .build()); sink.onCancel(() -> { logger.debug("Session {} cancelled", sessionId); @@ -360,6 +386,8 @@ public static class Builder { private ObjectMapper objectMapper; + private String baseUrl = DEFAULT_BASE_URL; + private String messageEndpoint; private String sseEndpoint = DEFAULT_SSE_ENDPOINT; @@ -377,6 +405,19 @@ public Builder objectMapper(ObjectMapper objectMapper) { return this; } + /** + * Sets the project basePath as endpoint prefix where clients should send their + * JSON-RPC messages + * @param baseUrl the message basePath . Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if basePath is null + */ + public Builder basePath(String baseUrl) { + Assert.notNull(baseUrl, "basePath must not be null"); + this.baseUrl = baseUrl; + return this; + } + /** * Sets the endpoint URI where clients should send their JSON-RPC messages. * @param messageEndpoint The message endpoint URI. Must not be null. @@ -411,7 +452,7 @@ public WebFluxSseServerTransportProvider build() { Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(messageEndpoint, "Message endpoint must be set"); - return new WebFluxSseServerTransportProvider(objectMapper, messageEndpoint, sseEndpoint); + return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); } } From b21cfab10ec9d51c8f57541767337dfd790a43b2 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 6 Apr 2025 16:33:25 +0200 Subject: [PATCH 007/249] refactor(webmvc): Rename messageBaseUrl to baseUrl for consistency Signed-off-by: Christian Tzolov --- .../WebMvcSseServerTransportProvider.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index f6dbd4779..fa2e357f9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -91,7 +91,7 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi private final String sseEndpoint; - private final String messageBaseUrl; + private final String baseUrl; private final RouterFunction routerFunction; @@ -139,23 +139,23 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * Constructs a new WebMvcSseServerTransportProvider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization * of messages. - * @param messageBaseUrl The base URL for the message endpoint, used to construct the - * full endpoint URL for clients. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages via HTTP POST. This endpoint will be communicated to clients through the * SSE connection's initial endpoint event. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. * @throws IllegalArgumentException if any parameter is null */ - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageBaseUrl, String messageEndpoint, + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.notNull(messageBaseUrl, "Message base URL must not be null"); + Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); this.objectMapper = objectMapper; - this.messageBaseUrl = messageBaseUrl; + this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.routerFunction = RouterFunctions.route() @@ -269,7 +269,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { try { sseBuilder.id(sessionId) .event(ENDPOINT_EVENT_TYPE) - .data(this.messageBaseUrl + this.messageEndpoint + "?sessionId=" + sessionId); + .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); } catch (Exception e) { logger.error("Failed to send initial endpoint event: {}", e.getMessage()); From 8fc72aed88616cfe4ba4fe8adae038b32fcc9f8b Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 6 Apr 2025 18:41:25 +0200 Subject: [PATCH 008/249] feat(mcp): Add support for custom context paths in HTTP Servlet SSE server transport Enhance HttpServletSseServerTransportProvider to support deployment under non-root context paths by: - Adding baseUrl field and DEFAULT_BASE_URL constant - Creating new constructor that accepts a baseUrl parameter - Extending Builder with baseUrl configuration method - Prepending baseUrl to message endpoint in SSE events - Add HttpServletSseServerCustomContextPathTests to verify custom context path functionality - Extract common Tomcat server setup code to TomcatTestUtil for test reuse Related to #79 Signed-off-by: Christian Tzolov --- ...HttpServletSseServerTransportProvider.java | 37 +++++++- ...ervletSseServerCustomContextPathTests.java | 86 +++++++++++++++++++ ...rverTransportProviderIntegrationTests.java | 21 +---- .../server/transport/TomcatTestUtil.java | 45 ++++++++++ 4 files changed, 167 insertions(+), 22 deletions(-) create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index a64b4a353..e52fc88b7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -80,9 +80,14 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement /** Event type for endpoint information */ public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + public static final String DEFAULT_BASE_URL = ""; + /** JSON object mapper for serialization/deserialization */ private final ObjectMapper objectMapper; + /** Base URL for the server transport */ + private final String baseUrl; + /** The endpoint path for handling client messages */ private final String messageEndpoint; @@ -108,7 +113,22 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement */ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + } + + /** + * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param baseUrl The base URL for the server transport + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + */ + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint) { this.objectMapper = objectMapper; + this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; } @@ -203,7 +223,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) this.sessions.put(sessionId, session); // Send initial endpoint event - this.sendEvent(writer, ENDPOINT_EVENT_TYPE, messageEndpoint + "?sessionId=" + sessionId); + this.sendEvent(writer, ENDPOINT_EVENT_TYPE, this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); } /** @@ -449,6 +469,8 @@ public static class Builder { private ObjectMapper objectMapper = new ObjectMapper(); + private String baseUrl = DEFAULT_BASE_URL; + private String messageEndpoint; private String sseEndpoint = DEFAULT_SSE_ENDPOINT; @@ -464,6 +486,17 @@ public Builder objectMapper(ObjectMapper objectMapper) { return this; } + /** + * Sets the base URL for the server transport. + * @param baseUrl The base URL to use + * @return This builder instance for method chaining + */ + public Builder baseUrl(String baseUrl) { + Assert.notNull(baseUrl, "Base URL must not be null"); + this.baseUrl = baseUrl; + return this; + } + /** * Sets the endpoint path where clients will send their messages. * @param messageEndpoint The message endpoint path @@ -502,7 +535,7 @@ public HttpServletSseServerTransportProvider build() { if (messageEndpoint == null) { throw new IllegalStateException("MessageEndpoint must be set"); } - return new HttpServletSseServerTransportProvider(objectMapper, messageEndpoint, sseEndpoint); + return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java new file mode 100644 index 000000000..1254e2ad8 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java @@ -0,0 +1,86 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +public class HttpServletSseServerCustomContextPathTests { + + private static final int PORT = 8195; + + private static final String CUSTOM_CONTEXT_PATH = "/api/v1"; + + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private Tomcat tomcat; + + @BeforeEach + public void before() { + + // Create and configure the transport provider + mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .baseUrl(CUSTOM_CONTEXT_PATH) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, mcpServerTransportProvider); + + try { + tomcat.start(); + assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_CONTEXT_PATH + CUSTOM_SSE_ENDPOINT) + .build()); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void testCustomContextPath() { + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + assertThat(client.initialize()).isNotNull(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 1cd395e74..b04940c79 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -26,7 +26,6 @@ import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.apache.catalina.startup.Tomcat; @@ -59,14 +58,6 @@ public class HttpServletSseServerTransportProviderIntegrationTests { @BeforeEach public void before() { - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - Context context = tomcat.addContext("", baseDir); - // Create and configure the transport provider mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() .objectMapper(new ObjectMapper()) @@ -74,18 +65,8 @@ public void before() { .sseEndpoint(CUSTOM_SSE_ENDPOINT) .build(); - // Add transport servlet to Tomcat - org.apache.catalina.Wrapper wrapper = context.createWrapper(); - wrapper.setName("mcpServlet"); - wrapper.setServlet(mcpServerTransportProvider); - wrapper.setLoadOnStartup(1); - wrapper.setAsyncSupported(true); - context.addChild(wrapper); - context.addServletMappingDecoded("/*", "mcpServlet"); - + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); try { - var connector = tomcat.getConnector(); - connector.setAsyncTimeout(3000); tomcat.start(); assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java new file mode 100644 index 000000000..6f922dfa6 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -0,0 +1,45 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import jakarta.servlet.Servlet; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; + +import static org.junit.Assert.assertThat; + +/** + * @author Christian Tzolov + */ +public class TomcatTestUtil { + + public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet) { + + var tomcat = new Tomcat(); + tomcat.setPort(port); + + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Context context = tomcat.addContext("", baseDir); + Context context = tomcat.addContext(contextPath, baseDir); + + // Add transport servlet to Tomcat + org.apache.catalina.Wrapper wrapper = context.createWrapper(); + wrapper.setName("mcpServlet"); + wrapper.setServlet(servlet); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addChild(wrapper); + context.addServletMappingDecoded("/*", "mcpServlet"); + + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); + + return tomcat; + } + +} From 13c4474b3ea00e75be47b653d315ba9de7125cb3 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 9 Apr 2025 16:05:07 +0200 Subject: [PATCH 009/249] Change the URLs used to test blocking rest calls Signed-off-by: Christian Tzolov --- .../io/modelcontextprotocol/WebFluxSseIntegrationTests.java | 6 +++--- .../server/WebMvcSseIntegrationTests.java | 6 +++--- ...tpServletSseServerTransportProviderIntegrationTests.java | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index dbfad821f..ac487b6f5 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -396,7 +396,7 @@ void testToolCallSuccess(String clientType) { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -436,7 +436,7 @@ void testToolListChangeHandlingSuccess(String clientType) { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -453,7 +453,7 @@ void testToolListChangeHandlingSuccess(String clientType) { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index f9190fd70..420f4b987 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -388,7 +388,7 @@ void testToolCallSuccess() { new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() - .get() + .get()https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() .body(String.class); @@ -424,7 +424,7 @@ void testToolListChangeHandlingSuccess() { McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service - String response = RestClient.create() + String https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md .get() .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() @@ -441,7 +441,7 @@ void testToolListChangeHandlingSuccess() { AtomicReference> rootsRef = new AtomicReference<>(); var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { // perform a blocking call to a remote service - String response = RestClient.create() + String https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md .get() .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") .retrieve() diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index b04940c79..e34baf9d6 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -374,7 +374,7 @@ void testToolCallSuccess() { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -411,7 +411,7 @@ void testToolListChangeHandlingSuccess() { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -428,7 +428,7 @@ void testToolListChangeHandlingSuccess() { // perform a blocking call to a remote service String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); From fbea833384c097a46927624f1f7cbb9562c15e74 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 9 Apr 2025 16:17:16 +0200 Subject: [PATCH 010/249] Fix compilation issue introduced by the previous commit Signed-off-by: Christian Tzolov --- .../server/WebMvcSseIntegrationTests.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 420f4b987..c203e3bd5 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -388,8 +388,8 @@ void testToolCallSuccess() { new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() - .get()https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .get() + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -424,9 +424,9 @@ void testToolListChangeHandlingSuccess() { McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service - String https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md + String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); @@ -441,9 +441,9 @@ void testToolListChangeHandlingSuccess() { AtomicReference> rootsRef = new AtomicReference<>(); var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { // perform a blocking call to a remote service - String https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md + String response = RestClient.create() .get() - .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") .retrieve() .body(String.class); assertThat(response).isNotBlank(); From fab434c088e7e90ad4cbbedd55b28c553536c7de Mon Sep 17 00:00:00 2001 From: "a.darafeyeu" Date: Mon, 7 Apr 2025 14:30:04 +0200 Subject: [PATCH 011/249] refactor(client): enhance HttpClientSseClientTransport with flexible customization API (#117) - Add builder customizeClient() and customizeRequest() methods - Enable HTTP client and request configuration through consumer-based customization - Deprecate direct constructors in favor of the more flexible builder approach - Add test coverage for customization capabilities Co-authored-by: Christian Tzolov Signed-off-by: Christian Tzolov --- .../server/WebMvcSseIntegrationTests.java | 12 +-- .../HttpClientSseClientTransport.java | 92 ++++++++++++++++-- .../client/HttpSseMcpAsyncClientTests.java | 4 +- .../client/HttpSseMcpSyncClientTests.java | 2 +- .../HttpClientSseClientTransportTests.java | 97 ++++++++++++++++++- 5 files changed, 185 insertions(+), 22 deletions(-) diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index c203e3bd5..d5c9f90ff 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -44,7 +44,7 @@ import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; -public class WebMvcSseIntegrationTests { +class WebMvcSseIntegrationTests { private static final int PORT = 8183; @@ -79,13 +79,13 @@ public void before() { try { tomcatServer.tomcat().start(); - assertThat(tomcatServer.tomcat().getServer().getState() == LifecycleState.STARTED); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); } catch (Exception e) { throw new RuntimeException("Failed to start Tomcat", e); } - clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build()); // Get the transport from Spring context mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); @@ -200,8 +200,7 @@ void testCreateMessageSuccess() throws InterruptedException { CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + assertThat(response).isNotNull().isEqualTo(callResponse); mcpClient.close(); mcpServer.close(); @@ -410,8 +409,7 @@ void testToolCallSuccess() { CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + assertThat(response).isNotNull().isEqualTo(callResponse); mcpClient.close(); mcpServer.close(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index a5bdd43e2..632d3844a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -13,6 +13,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; @@ -103,7 +104,10 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** * Creates a new transport instance with default HTTP client and object mapper. * @param baseUri the base URI of the MCP server + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * constructor will be removed in future versions. */ + @Deprecated(forRemoval = true) public HttpClientSseClientTransport(String baseUri) { this(HttpClient.newBuilder(), baseUri, new ObjectMapper()); } @@ -114,7 +118,10 @@ public HttpClientSseClientTransport(String baseUri) { * @param baseUri the base URI of the MCP server * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper or clientBuilder is null + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * constructor will be removed in future versions. */ + @Deprecated(forRemoval = true) public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper) { this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper); } @@ -126,7 +133,10 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String bas * @param sseEndpoint the SSE endpoint path * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper or clientBuilder is null + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * constructor will be removed in future versions. */ + @Deprecated(forRemoval = true) public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper); @@ -141,18 +151,37 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String bas * @param sseEndpoint the SSE endpoint path * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * constructor will be removed in future versions. */ + @Deprecated(forRemoval = true) public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { + this(clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(), requestBuilder, baseUri, sseEndpoint, + objectMapper); + } + + /** + * Creates a new transport instance with custom HTTP client builder, object mapper, + * and headers. + * @param httpClient the HTTP client to use + * @param requestBuilder the HTTP request builder to use + * @param baseUri the base URI of the MCP server + * @param sseEndpoint the SSE endpoint path + * @param objectMapper the object mapper for JSON serialization/deserialization + * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + */ + HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, + String sseEndpoint, ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); - Assert.notNull(clientBuilder, "clientBuilder must not be null"); + Assert.notNull(httpClient, "httpClient must not be null"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); this.baseUri = baseUri; this.sseEndpoint = sseEndpoint; this.objectMapper = objectMapper; - this.httpClient = clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(); + this.httpClient = httpClient; this.requestBuilder = requestBuilder; this.sseClient = new FlowSseClient(this.httpClient, requestBuilder); @@ -164,7 +193,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques * @return a new builder instance */ public static Builder builder(String baseUri) { - return new Builder(baseUri); + return new Builder().baseUri(baseUri); } /** @@ -172,25 +201,50 @@ public static Builder builder(String baseUri) { */ public static class Builder { - private final String baseUri; + private String baseUri; private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - private HttpClient.Builder clientBuilder = HttpClient.newBuilder(); + private HttpClient.Builder clientBuilder = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_1_1) + .connectTimeout(Duration.ofSeconds(10)); private ObjectMapper objectMapper = new ObjectMapper(); - private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() + .header("Content-Type", "application/json"); + + /** + * Creates a new builder instance. + */ + Builder() { + // Default constructor + } /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. + * This constructor is deprecated and will be removed or made {@code protected} or + * {@code private} in a future release. */ + @Deprecated(forRemoval = true) public Builder(String baseUri) { Assert.hasText(baseUri, "baseUri must not be empty"); this.baseUri = baseUri; } + /** + * Sets the base URI. + * @param baseUri the base URI + * @return this builder + */ + Builder baseUri(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + return this; + } + /** * Sets the SSE endpoint path. * @param sseEndpoint the SSE endpoint path @@ -213,6 +267,17 @@ public Builder clientBuilder(HttpClient.Builder clientBuilder) { return this; } + /** + * Customizes the HTTP client builder. + * @param clientCustomizer the consumer to customize the HTTP client builder + * @return this builder + */ + public Builder customizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + return this; + } + /** * Sets the HTTP request builder. * @param requestBuilder the HTTP request builder @@ -224,6 +289,17 @@ public Builder requestBuilder(HttpRequest.Builder requestBuilder) { return this; } + /** + * Customizes the HTTP client builder. + * @param requestCustomizer the consumer to customize the HTTP request builder + * @return this builder + */ + public Builder customizeRequest(final Consumer requestCustomizer) { + Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); + requestCustomizer.accept(requestBuilder); + return this; + } + /** * Sets the object mapper for JSON serialization/deserialization. * @param objectMapper the object mapper @@ -240,7 +316,8 @@ public Builder objectMapper(ObjectMapper objectMapper) { * @return a new transport instance */ public HttpClientSseClientTransport build() { - return new HttpClientSseClientTransport(clientBuilder, requestBuilder, baseUri, sseEndpoint, objectMapper); + return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint, + objectMapper); } } @@ -336,7 +413,6 @@ public Mono sendMessage(JSONRPCMessage message) { try { String jsonText = this.objectMapper.writeValueAsString(message); HttpRequest request = this.requestBuilder.uri(URI.create(this.baseUri + endpoint)) - .header("Content-Type", "application/json") .POST(HttpRequest.BodyPublishers.ofString(jsonText)) .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index 15749d4ff..fdff4b777 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -15,7 +15,7 @@ * * @author Christian Tzolov */ -@Timeout(15) // Giving extra time beyond the client timeout +@Timeout(15) class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { String host = "http://localhost:3004"; @@ -29,7 +29,7 @@ class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { @Override protected McpClientTransport createMcpTransport() { - return new HttpClientSseClientTransport(host); + return HttpClientSseClientTransport.builder(host).build(); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java index 067f92957..204cf2984 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -29,7 +29,7 @@ class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { @Override protected McpClientTransport createMcpTransport() { - return new HttpClientSseClientTransport(host); + return HttpClientSseClientTransport.builder(host).build(); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 294056fbe..e5178c0ee 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -4,9 +4,15 @@ package io.modelcontextprotocol.client.transport; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; import java.time.Duration; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; import io.modelcontextprotocol.spec.McpSchema; @@ -26,6 +32,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import com.fasterxml.jackson.databind.ObjectMapper; + /** * Tests for the {@link HttpClientSseClientTransport} class. * @@ -51,8 +59,8 @@ static class TestHttpClientSseClientTransport extends HttpClientSseClientTranspo private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); - public TestHttpClientSseClientTransport(String baseUri) { - super(baseUri); + public TestHttpClientSseClientTransport(final String baseUri) { + super(HttpClient.newHttpClient(), HttpRequest.newBuilder(), baseUri, "/sse", new ObjectMapper()); } public int getInboundMessageCount() { @@ -191,13 +199,14 @@ void testGracefulShutdown() { StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); // Message count should remain 0 after shutdown - assertThat(transport.getInboundMessageCount()).isEqualTo(0); + assertThat(transport.getInboundMessageCount()).isZero(); } @Test void testRetryBehavior() { // Create a client that simulates connection failures - HttpClientSseClientTransport failingTransport = new HttpClientSseClientTransport("http://non-existent-host"); + HttpClientSseClientTransport failingTransport = HttpClientSseClientTransport.builder("http://non-existent-host") + .build(); // Verify that the transport attempts to reconnect StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete(); @@ -275,4 +284,84 @@ void testMessageOrderPreservation() { assertThat(transport.getInboundMessageCount()).isEqualTo(3); } + @Test + void testCustomizeClient() { + // Create an atomic boolean to verify the customizer was called + AtomicBoolean customizerCalled = new AtomicBoolean(false); + + // Create a transport with the customizer + HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) + .customizeClient(builder -> { + builder.version(HttpClient.Version.HTTP_2); + customizerCalled.set(true); + }) + .build(); + + // Verify the customizer was called + assertThat(customizerCalled.get()).isTrue(); + + // Clean up + customizedTransport.closeGracefully().block(); + } + + @Test + void testCustomizeRequest() { + // Create an atomic boolean to verify the customizer was called + AtomicBoolean customizerCalled = new AtomicBoolean(false); + + // Create a reference to store the custom header value + AtomicReference headerName = new AtomicReference<>(); + AtomicReference headerValue = new AtomicReference<>(); + + // Create a transport with the customizer + HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) + // Create a request customizer that adds a custom header + .customizeRequest(builder -> { + builder.header("X-Custom-Header", "test-value"); + customizerCalled.set(true); + + // Create a new request to verify the header was set + HttpRequest request = builder.uri(URI.create("http://example.com")).build(); + headerName.set("X-Custom-Header"); + headerValue.set(request.headers().firstValue("X-Custom-Header").orElse(null)); + }) + .build(); + + // Verify the customizer was called + assertThat(customizerCalled.get()).isTrue(); + + // Verify the header was set correctly + assertThat(headerName.get()).isEqualTo("X-Custom-Header"); + assertThat(headerValue.get()).isEqualTo("test-value"); + + // Clean up + customizedTransport.closeGracefully().block(); + } + + @Test + void testChainedCustomizations() { + // Create atomic booleans to verify both customizers were called + AtomicBoolean clientCustomizerCalled = new AtomicBoolean(false); + AtomicBoolean requestCustomizerCalled = new AtomicBoolean(false); + + // Create a transport with both customizers chained + HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) + .customizeClient(builder -> { + builder.connectTimeout(Duration.ofSeconds(30)); + clientCustomizerCalled.set(true); + }) + .customizeRequest(builder -> { + builder.header("X-Api-Key", "test-api-key"); + requestCustomizerCalled.set(true); + }) + .build(); + + // Verify both customizers were called + assertThat(clientCustomizerCalled.get()).isTrue(); + assertThat(requestCustomizerCalled.get()).isTrue(); + + // Clean up + customizedTransport.closeGracefully().block(); + } + } From 391ec19fdc346c6d0ebf369f692c370a48339d3d Mon Sep 17 00:00:00 2001 From: Christian Tzolov <1351573+tzolov@users.noreply.github.com> Date: Thu, 10 Apr 2025 12:26:29 +0200 Subject: [PATCH 012/249] refactor: change notification params type from Map to Object (#137) * refactor: change notification params type from Map to Object This change generalizes the parameter type for notification methods across the MCP framework, allowing for more flexible parameter passing. Instead of requiring parameters to be structured as a Map, the API now accepts any Object as parameters. The primary motivation is to simplify client usage by allowing direct passing of strongly-typed objects without requiring conversion to a Map first, as demonstrated in the McpAsyncServer logging notification implementation. Affected components: - McpSession interface and implementations - McpServerTransportProvider interface and implementations - McpSchema JSONRPCNotification record --------- Signed-off-by: Christian Tzolov --- .../transport/WebFluxSseServerTransportProvider.java | 2 +- .../server/transport/WebMvcSseServerTransportProvider.java | 2 +- .../io/modelcontextprotocol/server/McpAsyncServer.java | 7 ++----- .../transport/HttpServletSseServerTransportProvider.java | 2 +- .../server/transport/StdioServerTransportProvider.java | 2 +- .../io/modelcontextprotocol/spec/McpClientSession.java | 2 +- .../main/java/io/modelcontextprotocol/spec/McpSchema.java | 2 +- .../io/modelcontextprotocol/spec/McpServerSession.java | 2 +- .../spec/McpServerTransportProvider.java | 4 ++-- .../main/java/io/modelcontextprotocol/spec/McpSession.java | 4 ++-- .../MockMcpServerTransportProvider.java | 2 +- 11 files changed, 14 insertions(+), 17 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index df8dd0211..be30bd72f 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -188,7 +188,7 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { * errors if any session fails to receive the message */ @Override - public Mono notifyClients(String method, Map params) { + public Mono notifyClients(String method, Object params) { if (sessions.isEmpty()) { logger.debug("No active sessions to broadcast message to"); return Mono.empty(); diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fa2e357f9..7bd1aa6c9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -179,7 +179,7 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { * @return A Mono that completes when the broadcast attempt is finished */ @Override - public Mono notifyClients(String method, Map params) { + public Mono notifyClients(String method, Object params) { if (sessions.isEmpty()) { logger.debug("No active sessions to broadcast message to"); return Mono.empty(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index df9386685..ec2a04c9e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -669,15 +669,12 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN return Mono.error(new McpError("Logging message must not be null")); } - Map params = this.objectMapper.convertValue(loggingMessageNotification, - new TypeReference>() { - }); - if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { return Mono.empty(); } - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, + loggingMessageNotification); } private McpServerSession.RequestHandler setLoggerRequestHandler() { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index e52fc88b7..afdbff472 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -160,7 +160,7 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { * @return A Mono that completes when the broadcast attempt is finished */ @Override - public Mono notifyClients(String method, Map params) { + public Mono notifyClients(String method, Object params) { if (sessions.isEmpty()) { logger.debug("No active sessions to broadcast message to"); return Mono.empty(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index a8b980e90..819da9777 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -99,7 +99,7 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { } @Override - public Mono notifyClients(String method, Map params) { + public Mono notifyClients(String method, Object params) { if (this.session == null) { return Mono.error(new McpError("No session to close")); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 719a78001..0895e02b0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -258,7 +258,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc * @return A Mono that completes when the notification is sent */ @Override - public Mono sendNotification(String method, Map params) { + public Mono sendNotification(String method, Object params) { McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params); return this.transport.sendMessage(jsonrpcNotification); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index e38403c32..4c596b628 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -191,7 +191,7 @@ public record JSONRPCRequest( // @formatter:off public record JSONRPCNotification( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("method") String method, - @JsonProperty("params") Map params) implements JSONRPCMessage { + @JsonProperty("params") Object params) implements JSONRPCMessage { } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index bcdf22486..46014af8d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -132,7 +132,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc } @Override - public Mono sendNotification(String method, Map params) { + public Mono sendNotification(String method, Object params) { McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params); return this.transport.sendMessage(jsonrpcNotification); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java index dba8cc43f..5fdbd7ab6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -42,11 +42,11 @@ public interface McpServerTransportProvider { /** * Sends a notification to all connected clients. * @param method the name of the notification method to be called on the clients - * @param params a map of parameters to be sent with the notification + * @param params parameters to be sent with the notification * @return a Mono that completes when the notification has been broadcast * @see McpSession#sendNotification(String, Map) */ - Mono notifyClients(String method, Map params); + Mono notifyClients(String method, Object params); /** * Immediately closes all the transports with connected clients and releases any diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java index b97c3ccc4..473a860c2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -63,10 +63,10 @@ default Mono sendNotification(String method) { * parameters with the notification. *

* @param method the name of the notification method to be sent to the counterparty - * @param params a map of parameters to be sent with the notification + * @param params parameters to be sent with the notification * @return a Mono that completes when the notification has been sent */ - Mono sendNotification(String method, Map params); + Mono sendNotification(String method, Object params); /** * Closes the session and releases any associated resources asynchronously. diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 3fb19180b..20a8c0cf5 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -47,7 +47,7 @@ public void setSessionFactory(Factory sessionFactory) { } @Override - public Mono notifyClients(String method, Map params) { + public Mono notifyClients(String method, Object params) { return session.sendNotification(method, params); } From 2895d1589ac3c81366eccfc584c6c733d5846127 Mon Sep 17 00:00:00 2001 From: Christian Tzolov <1351573+tzolov@users.noreply.github.com> Date: Thu, 10 Apr 2025 13:18:55 +0200 Subject: [PATCH 013/249] fix: Add null check for session in WebFluxSseServerTransportProvider (#138) Add error handling to return a 404 NOT_FOUND response when a request is made with a non-existent session ID. This prevents potential NullPointerExceptions when processing requests with invalid session IDs. Signed-off-by: Christian Tzolov --- .../server/transport/WebFluxSseServerTransportProvider.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index be30bd72f..eed8a53af 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -306,6 +306,11 @@ private Mono handleMessage(ServerRequest request) { McpServerSession session = sessions.get(request.queryParam("sessionId").get()); + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND) + .bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get())); + } + return request.bodyToMono(String.class).flatMap(body -> { try { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); From c88ac937f3e195c7e767c61e5024737c3417ad72 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 9 Apr 2025 11:15:16 +0200 Subject: [PATCH 014/249] feat(mcp): refactor logging to use exchange for targeted client notifications (#132) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactors the MCP logging system to use the exchange mechanism for sending logging notifications only to specific client sessions rather than broadcasting to all clients. - Move logging notification delivery from server-wide broadcast to per-session exchange - Implement per-session minimum logging level tracking and filtering - Add proper logging level filtering at the exchange level - Change setLoggingLevel from notification to request/response pattern (breaking change) - Deprecate global server.loggingNotification in favor of exchange.loggingNotification - Add SetLevelRequest record to McpSchema - Add integration test demonstrating filtered logging notifications Resolves #131 Signed-off-by: Christian Tzolov Co-authored-by: Dariusz Jędrzejczyk --- .../WebFluxSseIntegrationTests.java | 356 +++++++++++------ .../server/WebMvcSseIntegrationTests.java | 260 +++++++------ .../client/AbstractMcpAsyncClientTests.java | 14 +- .../server/AbstractMcpAsyncServerTests.java | 49 --- .../server/AbstractMcpSyncServerTests.java | 49 --- .../client/McpAsyncClient.java | 7 +- .../client/McpSyncClient.java | 1 - .../server/McpAsyncServer.java | 31 +- .../server/McpAsyncServerExchange.java | 44 +++ .../server/McpSyncServer.java | 13 +- .../server/McpSyncServerExchange.java | 17 +- .../modelcontextprotocol/spec/McpSchema.java | 5 + .../client/AbstractMcpAsyncClientTests.java | 14 +- .../server/AbstractMcpAsyncServerTests.java | 49 --- .../server/AbstractMcpSyncServerTests.java | 49 --- ...ervletSseServerCustomContextPathTests.java | 11 +- ...rverTransportProviderIntegrationTests.java | 365 ++++++++++++------ 17 files changed, 721 insertions(+), 613 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index ac487b6f5..d71fe1ab0 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -111,27 +112,28 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { return Mono.just(mock(CallToolResult.class)); }); - McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - // Create client without sampling capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build();) { - assertThat(client.initialize()).isNotNull(); + assertThat(client.initialize()).isNotNull(); - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } } + server.close(); } @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testCreateMessageSuccess(String clientType) throws InterruptedException { - // Client var clientBuilder = clientBuilders.get(clientType); Function samplingHandler = request -> { @@ -142,13 +144,6 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); @@ -183,15 +178,19 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - mcpClient.close(); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } mcpServer.close(); } @@ -206,41 +205,42 @@ void testRootsSuccess(String clientType) { List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(rootsRef.get()).isNull(); + assertThat(rootsRef.get()).isNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -261,21 +261,21 @@ void testRootsWithoutCapability(String clientType) { var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { }).tools(tool).build(); - // Create client without roots capability - // No roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + try ( + // Create client without roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - assertThat(mcpClient.initialize()).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } } - mcpClient.close(); mcpServer.close(); } @@ -285,30 +285,31 @@ void testRootsNotifciationWithEmptyRootsList(String clientType) { var clientBuilder = clientBuilders.get(clientType); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(List.of()) // Empty roots list - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } - mcpClient.close(); mcpServer.close(); } @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testRootsWithMultipleHandlers(String clientType) { + var clientBuilder = clientBuilders.get(clientType); List roots = List.of(new Root("uri1://", "root1")); @@ -321,21 +322,21 @@ void testRootsWithMultipleHandlers(String clientType) { .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -348,28 +349,26 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { List roots = List.of(new Root("uri1://", "root1")); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } - // Close server while subscription is active mcpServer.close(); - - // Verify client can handle server closure gracefully - mcpClient.close(); } // --------------------------------------- @@ -378,9 +377,9 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { String emptyJsonSchema = """ { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} } """; @@ -408,19 +407,19 @@ void testToolCallSuccess(String clientType) { .tools(tool1) .build(); - var mcpClient = clientBuilder.build(); + try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } - mcpClient.close(); mcpServer.close(); } @@ -443,13 +442,14 @@ void testToolListChangeHandlingSuccess(String clientType) { return callResponse; }); + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -458,39 +458,40 @@ void testToolListChangeHandlingSuccess(String clientType) { .body(String.class); assertThat(response).isNotBlank(); rootsRef.set(toolsUpdate); - }).build(); + }).build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(rootsRef.get()).isNull(); + assertThat(rootsRef.get()).isNull(); - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - mcpServer.notifyToolsListChanged(); + mcpServer.notifyToolsListChanged(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); - // Remove a tool - mcpServer.removeTool("tool1"); + // Remove a tool + mcpServer.removeTool("tool1"); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), + (exchange, request) -> callResponse); - mcpServer.addTool(tool2); + mcpServer.addTool(tool2); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -502,12 +503,115 @@ void testInitialize(String clientType) { var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - var mcpClient = clientBuilder.build(); + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testLoggingNotification(String clientType) { + // Create a list to store received logging notifications + List receivedNotifications = new ArrayList<>(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema), + (exchange, request) -> { + + // Create and send notifications with different levels + + //@formatter:off + return exchange // This should be filtered out (DEBUG < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .then(exchange // This should be sent (NOTICE >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build())) + .then(exchange // This should be sent (ERROR > NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build())) + .then(exchange // This should be filtered out (INFO < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build())) + .then(exchange // This should be sent (ERROR >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build())) + .thenReturn(new CallToolResult("Logging test completed", false)); + //@formatter:on + }); - mcpClient.close(); + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + // Wait for notifications to be processed + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(3); + + // First notification should be NOTICE level + assertThat(receivedNotifications.get(0).level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(receivedNotifications.get(0).logger()).isEqualTo("test-logger"); + assertThat(receivedNotifications.get(0).data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(receivedNotifications.get(1).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(receivedNotifications.get(1).logger()).isEqualTo("test-logger"); + assertThat(receivedNotifications.get(1).data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(receivedNotifications.get(2).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(receivedNotifications.get(2).logger()).isEqualTo("test-logger"); + assertThat(receivedNotifications.get(2).data()).isEqualTo("Another error message"); + }); + } mcpServer.close(); } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index d5c9f90ff..be01365a1 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -125,27 +125,34 @@ void testCreateMessageWithoutSamplingCapabilities() { return Mono.just(mock(CallToolResult.class)); }); - McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + //@formatter:off + var server = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder + .clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) {//@formatter:on + + assertThat(client.initialize()).isNotNull(); - // Create client without sampling capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } } + server.close(); } @Test void testCreateMessageSuccess() throws InterruptedException { - // Client - Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); @@ -154,13 +161,6 @@ void testCreateMessageSuccess() throws InterruptedException { CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); @@ -190,19 +190,25 @@ void testCreateMessageSuccess() throws InterruptedException { return Mono.just(callResponse); }); + //@formatter:off var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try ( + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) {//@formatter:on - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(response).isNotNull().isEqualTo(callResponse); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - mcpClient.close(); + assertThat(response).isNotNull().isEqualTo(callResponse); + } mcpServer.close(); } @@ -214,41 +220,42 @@ void testRootsSuccess() { List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(rootsRef.get()).isNull(); + assertThat(rootsRef.get()).isNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -266,21 +273,22 @@ void testRootsWithoutCapability() { var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { }).tools(tool).build(); - // Create client without roots capability - // No roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + try ( + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - assertThat(mcpClient.initialize()).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } } - mcpClient.close(); mcpServer.close(); } @@ -292,20 +300,20 @@ void testRootsNotifciationWithEmptyRootsList() { .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(List.of()) // Empty roots list - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -321,20 +329,20 @@ void testRootsWithMultipleHandlers() { .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - assertThat(mcpClient.initialize()).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -343,28 +351,26 @@ void testRootsServerCloseWithActiveSubscription() { List roots = List.of(new Root("uri1://", "root1")); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } - // Close server while subscription is active mcpServer.close(); - - // Verify client can handle server closure gracefully - mcpClient.close(); } // --------------------------------------- @@ -400,18 +406,18 @@ void testToolCallSuccess() { .tools(tool1) .build(); - var mcpClient = clientBuilder.build(); + try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - assertThat(response).isNotNull().isEqualTo(callResponse); + assertThat(response).isNotNull().isEqualTo(callResponse); + } - mcpClient.close(); mcpServer.close(); } @@ -431,13 +437,14 @@ void testToolListChangeHandlingSuccess() { return callResponse; }); + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -446,39 +453,40 @@ void testToolListChangeHandlingSuccess() { .body(String.class); assertThat(response).isNotBlank(); rootsRef.set(toolsUpdate); - }).build(); + }).build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(rootsRef.get()).isNull(); + assertThat(rootsRef.get()).isNull(); - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - mcpServer.notifyToolsListChanged(); + mcpServer.notifyToolsListChanged(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); - // Remove a tool - mcpServer.removeTool("tool1"); + // Remove a tool + mcpServer.removeTool("tool1"); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), + (exchange, request) -> callResponse); - mcpServer.addTool(tool2); + mcpServer.addTool(tool2); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -487,12 +495,12 @@ void testInitialize() { var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - var mcpClient = clientBuilder.build(); + try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } - mcpClient.close(); mcpServer.close(); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 713563519..5452c8eac 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -31,6 +31,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -453,15 +454,10 @@ void testLoggingLevelsWithoutInitialization() { @Test void testLoggingLevels() { withClient(createMcpTransport(), mcpAsyncClient -> { - Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { - Mono chain = Mono.empty(); - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); - } - return chain; - })); - - StepVerifier.create(testAllLevels).verifyComplete(); + StepVerifier + .create(mcpAsyncClient.initialize() + .thenMany(Flux.fromArray(McpSchema.LoggingLevel.values()).flatMap(mcpAsyncClient::setLoggingLevel))) + .verifyComplete(); }); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 7bcb9a8b2..a91632c6c 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -416,53 +416,4 @@ void testRootsChangeHandlers() { .doesNotThrowAnyException(); } - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); - } - } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 7846e053b..9a63143c9 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -388,53 +388,4 @@ void testRootsChangeHandlers() { assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); } - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); - } - } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index ce49b0a5e..df099836d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -786,10 +786,9 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) { } return this.withInitializationCheck("setting logging level", initializedResult -> { - String levelName = this.transport.unmarshalFrom(loggingLevel, new TypeReference() { - }); - Map params = Map.of("level", levelName); - return this.mcpSession.sendNotification(McpSchema.METHOD_LOGGING_SET_LEVEL, params); + var params = new McpSchema.SetLevelRequest(loggingLevel); + return this.mcpSession.sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, new TypeReference() { + }).then(); }); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 071d76462..32cf325e9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -6,7 +6,6 @@ import java.time.Duration; -import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index ec2a04c9e..062de13ed 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -21,6 +21,7 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransportProvider; @@ -216,11 +217,17 @@ public Mono notifyPromptsListChanged() { // --------------------------------------- /** - * Send a logging message notification to all connected clients. Messages below the - * current minimum logging level will be filtered out. + * This implementation would, incorrectly, broadcast the logging message to all + * connected clients, using a single minLoggingLevel for all of them. Similar to the + * sampling and roots, the logging level should be set per client session and use the + * ServerExchange to send the logging message to the right client. * @param loggingMessageNotification The logging message to send * @return A Mono that completes when the notification has been sent + * @deprecated Use + * {@link McpAsyncServerExchange#loggingNotification(LoggingMessageNotification)} + * instead. */ + @Deprecated public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { return this.delegate.loggingNotification(loggingMessageNotification); } @@ -257,6 +264,8 @@ private static class AsyncServerImpl extends McpAsyncServer { private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + // FIXME: this field is deprecated and should be remvoed together with the + // broadcasting loggingNotification. private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); @@ -677,12 +686,22 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN loggingMessageNotification); } - private McpServerSession.RequestHandler setLoggerRequestHandler() { + private McpServerSession.RequestHandler setLoggerRequestHandler() { return (exchange, params) -> { - this.minLoggingLevel = objectMapper.convertValue(params, new TypeReference() { - }); + return Mono.defer(() -> { - return Mono.empty(); + SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, + new TypeReference() { + }); + + exchange.setMinLoggingLevel(newMinLoggingLevel.level()); + + // FIXME: this field is deprecated and should be removed together + // with the broadcasting loggingNotification. + this.minLoggingLevel = newMinLoggingLevel.level(); + + return Mono.just(Map.of()); + }); }; } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 658628448..889dc66d0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -1,9 +1,16 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + package io.modelcontextprotocol.server; import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; /** @@ -11,6 +18,7 @@ * exchange provides methods to interact with the client and query its capabilities. * * @author Dariusz Jędrzejczyk + * @author Christian Tzolov */ public class McpAsyncServerExchange { @@ -20,6 +28,8 @@ public class McpAsyncServerExchange { private final McpSchema.Implementation clientInfo; + private volatile LoggingLevel minLoggingLevel = LoggingLevel.INFO; + private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { }; @@ -101,4 +111,38 @@ public Mono listRoots(String cursor) { LIST_ROOTS_RESULT_TYPE_REF); } + /** + * Send a logging message notification to all connected clients. Messages below the + * current minimum logging level will be filtered out. + * @param loggingMessageNotification The logging message to send + * @return A Mono that completes when the notification has been sent + */ + public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); + } + + return Mono.defer(() -> { + if (this.isNotificationForLevelAllowed(loggingMessageNotification.level())) { + return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, loggingMessageNotification); + } + return Mono.empty(); + }); + } + + /** + * Set the minimum logging level for the client. Messages below this level will be + * filtered out. + * @param minLoggingLevel The minimum logging level + */ + void setMinLoggingLevel(LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + private boolean isNotificationForLevelAllowed(LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java index 72eba8b86..bf3104508 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java @@ -4,9 +4,7 @@ package io.modelcontextprotocol.server; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.util.Assert; @@ -151,9 +149,16 @@ public void notifyPromptsListChanged() { } /** - * Send a logging message notification to all clients. - * @param loggingMessageNotification The logging message notification to send + * This implementation would, incorrectly, broadcast the logging message to all + * connected clients, using a single minLoggingLevel for all of them. Similar to the + * sampling and roots, the logging level should be set per client session and use the + * ServerExchange to send the logging message to the right client. + * @param loggingMessageNotification The logging message to send + * @deprecated Use + * {@link McpSyncServerExchange#loggingNotification(LoggingMessageNotification)} + * instead. */ + @Deprecated public void loggingNotification(LoggingMessageNotification loggingMessageNotification) { this.asyncServer.loggingNotification(loggingMessageNotification).block(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index f121db552..52360e54b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -1,13 +1,19 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + package io.modelcontextprotocol.server; -import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; /** * Represents a synchronous exchange with a Model Context Protocol (MCP) client. The * exchange provides methods to interact with the client and query its capabilities. * * @author Dariusz Jędrzejczyk + * @author Christian Tzolov */ public class McpSyncServerExchange { @@ -75,4 +81,13 @@ public McpSchema.ListRootsResult listRoots(String cursor) { return this.exchange.listRoots(cursor).block(); } + /** + * Send a logging message notification to all connected clients. Messages below the + * current minimum logging level will be filtered out. + * @param loggingMessageNotification The logging message to send + */ + public void loggingNotification(LoggingMessageNotification loggingMessageNotification) { + this.exchange.loggingNotification(loggingMessageNotification).block(); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 4c596b628..e621ac19b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -1165,6 +1165,11 @@ public int level() { } // @formatter:on + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record SetLevelRequest(@JsonProperty("level") LoggingLevel level) { + } + // --------------------------- // Autocomplete // --------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index ac7b9e5ec..72b409af9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -31,6 +31,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -454,15 +455,10 @@ void testLoggingLevelsWithoutInitialization() { @Test void testLoggingLevels() { withClient(createMcpTransport(), mcpAsyncClient -> { - Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { - Mono chain = Mono.empty(); - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); - } - return chain; - })); - - StepVerifier.create(testAllLevels).verifyComplete(); + StepVerifier + .create(mcpAsyncClient.initialize() + .thenMany(Flux.fromArray(McpSchema.LoggingLevel.values()).flatMap(mcpAsyncClient::setLoggingLevel))) + .verifyComplete(); }); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 4b4fc434f..c7c69b52b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -415,53 +415,4 @@ void testRootsChangeHandlers() { .doesNotThrowAnyException(); } - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); - } - } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 17feb36e5..8c9328cc7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -387,53 +387,4 @@ void testRootsChangeHandlers() { assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); } - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevels() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - var notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message with level " + level) - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - } - - @Test - void testLoggingWithoutCapability() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().build()) // No logging capability - .build(); - - var notification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Test log message") - .build(); - - assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); - } - - @Test - void testLoggingWithNullNotification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); - } - } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java index 1254e2ad8..212a3c95d 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java @@ -8,7 +8,6 @@ import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.spec.McpSchema; -import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.apache.catalina.startup.Tomcat; @@ -78,9 +77,13 @@ public void after() { @Test void testCustomContextPath() { - McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); - assertThat(client.initialize()).isNotNull(); + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); + try (//@formatter:off + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) .build()) { //@formatter:on + + assertThat(client.initialize()).isNotNull(); + } + server.close(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index e34baf9d6..a7b634824 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server.transport; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; @@ -44,7 +45,7 @@ public class HttpServletSseServerTransportProviderIntegrationTests { - private static final int PORT = 8185; + private static final int PORT = 8189; private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; @@ -110,27 +111,29 @@ void testCreateMessageWithoutSamplingCapabilities() { return Mono.just(mock(CallToolResult.class)); }); - McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - // Create client without sampling capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) { - assertThat(client.initialize()).isNotNull(); + assertThat(client.initialize()).isNotNull(); - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } } + server.close(); } @Test void testCreateMessageSuccess() throws InterruptedException { - // Client - Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); @@ -139,13 +142,6 @@ void testCreateMessageSuccess() throws InterruptedException { CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); @@ -180,15 +176,19 @@ void testCreateMessageSuccess() throws InterruptedException { .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - mcpClient.close(); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } mcpServer.close(); } @@ -200,42 +200,43 @@ void testRootsSuccess() { List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(rootsRef.get()).isNull(); + assertThat(rootsRef.get()).isNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); - mcpClient.close(); - mcpServer.close(); + mcpServer.close(); + } } @Test @@ -252,21 +253,19 @@ void testRootsWithoutCapability() { var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { }).tools(tool).build(); - // Create client without roots capability - // No roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - assertThat(mcpClient.initialize()).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } } - mcpClient.close(); mcpServer.close(); } @@ -278,20 +277,20 @@ void testRootsNotifciationWithEmptyRootsList() { .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(List.of()) // Empty roots list - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -307,20 +306,20 @@ void testRootsWithMultipleHandlers() { .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - assertThat(mcpClient.initialize()).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } - mcpClient.close(); mcpServer.close(); } @@ -329,28 +328,26 @@ void testRootsServerCloseWithActiveSubscription() { List roots = List.of(new Root("uri1://", "root1")); AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) - .build(); + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.rootsListChangedNotification(); + mcpClient.rootsListChangedNotification(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } - // Close server while subscription is active mcpServer.close(); - - // Verify client can handle server closure gracefully - mcpClient.close(); } // --------------------------------------- @@ -386,19 +383,18 @@ void testToolCallSuccess() { .tools(tool1) .build(); - var mcpClient = clientBuilder.build(); + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } - mcpClient.close(); mcpServer.close(); } @@ -418,13 +414,14 @@ void testToolListChangeHandlingSuccess() { return callResponse; }); + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); - AtomicReference> rootsRef = new AtomicReference<>(); - var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -433,53 +430,167 @@ void testToolListChangeHandlingSuccess() { .body(String.class); assertThat(response).isNotBlank(); rootsRef.set(toolsUpdate); - }).build(); + }).build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(rootsRef.get()).isNull(); + assertThat(rootsRef.get()).isNull(); - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - mcpServer.notifyToolsListChanged(); + mcpServer.notifyToolsListChanged(); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); - // Remove a tool - mcpServer.removeTool("tool1"); + // Remove a tool + mcpServer.removeTool("tool1"); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), + (exchange, request) -> callResponse); - mcpServer.addTool(tool2); + mcpServer.addTool(tool2); - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + } - mcpClient.close(); mcpServer.close(); } @Test void testInitialize() { - var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - var mcpClient = clientBuilder.build(); + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + + mcpServer.close(); + } - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + // --------------------------------------- + // Logging Tests + // --------------------------------------- + @Test + void testLoggingNotification() { + // Create a list to store received logging notifications + List receivedNotifications = new ArrayList<>(); - mcpClient.close(); + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema), + (exchange, request) -> { + + // Create and send notifications with different levels + + // This should be filtered out (DEBUG < NOTICE) + exchange + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .block(); + + // This should be sent (NOTICE >= NOTICE) + exchange + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build()) + .block(); + + // This should be sent (ERROR > NOTICE) + exchange + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build()) + .block(); + + // This should be filtered out (INFO < NOTICE) + exchange + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build()) + .block(); + + // This should be sent (ERROR >= NOTICE) + exchange + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build()) + .block(); + + return Mono.just(new CallToolResult("Logging test completed", false)); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().tools(true).build()) + .tools(tool) + .build(); + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + // Wait for notifications to be processed + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + + System.out.println("Received notifications: " + receivedNotifications); + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(3); + + // First notification should be NOTICE level + assertThat(receivedNotifications.get(0).level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(receivedNotifications.get(0).logger()).isEqualTo("test-logger"); + assertThat(receivedNotifications.get(0).data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(receivedNotifications.get(1).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(receivedNotifications.get(1).logger()).isEqualTo("test-logger"); + assertThat(receivedNotifications.get(1).data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(receivedNotifications.get(2).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(receivedNotifications.get(2).logger()).isEqualTo("test-logger"); + assertThat(receivedNotifications.get(2).data()).isEqualTo("Another error message"); + }); + } mcpServer.close(); } From 63724f17a4a7d72f8f28b149a5940122b4a5bf02 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 10 Apr 2025 16:54:24 +0200 Subject: [PATCH 015/249] refactor(tests): improve notification assertions in WebFluxSseIntegrationTests Replace index-based assertions with content-based lookups using a notification map. This change makes the tests more resilient by removing the dependency on notification order, which is important for asynchronous messaging tests. Signed-off-by: Christian Tzolov --- .../WebFluxSseIntegrationTests.java | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index d71fe1ab0..76f908b8a 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -10,6 +10,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import java.util.stream.Collectors; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; @@ -596,20 +597,24 @@ void testLoggingNotification(String clientType) { // Should have received 3 notifications (1 NOTICE and 2 ERROR) assertThat(receivedNotifications).hasSize(3); + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + // First notification should be NOTICE level - assertThat(receivedNotifications.get(0).level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(receivedNotifications.get(0).logger()).isEqualTo("test-logger"); - assertThat(receivedNotifications.get(0).data()).isEqualTo("Notice message"); + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); // Second notification should be ERROR level - assertThat(receivedNotifications.get(1).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(receivedNotifications.get(1).logger()).isEqualTo("test-logger"); - assertThat(receivedNotifications.get(1).data()).isEqualTo("Error message"); + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); // Third notification should be ERROR level - assertThat(receivedNotifications.get(2).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(receivedNotifications.get(2).logger()).isEqualTo("test-logger"); - assertThat(receivedNotifications.get(2).data()).isEqualTo("Another error message"); + assertThat(notificationMap.get("Another error message").level()) + .isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); }); } mcpServer.close(); From 2e953c81aa6d0e173801282fc03b01bfb413ff0f Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 10 Apr 2025 17:16:27 +0200 Subject: [PATCH 016/249] refactor(tests): improve notification assertions in HttpServletSseServerTransportProviderIntegrationTests Replace index-based assertions with content-based lookups using a notification map. This change makes the tests more resilient by removing the dependency on notification order, which is important for asynchronous messaging tests. Signed-off-by: Christian Tzolov --- ...rverTransportProviderIntegrationTests.java | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index a7b634824..135de83fa 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -9,6 +9,7 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import java.util.stream.Collectors; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; @@ -575,20 +576,24 @@ void testLoggingNotification() { // Should have received 3 notifications (1 NOTICE and 2 ERROR) assertThat(receivedNotifications).hasSize(3); + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + // First notification should be NOTICE level - assertThat(receivedNotifications.get(0).level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(receivedNotifications.get(0).logger()).isEqualTo("test-logger"); - assertThat(receivedNotifications.get(0).data()).isEqualTo("Notice message"); + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); // Second notification should be ERROR level - assertThat(receivedNotifications.get(1).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(receivedNotifications.get(1).logger()).isEqualTo("test-logger"); - assertThat(receivedNotifications.get(1).data()).isEqualTo("Error message"); + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); // Third notification should be ERROR level - assertThat(receivedNotifications.get(2).level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(receivedNotifications.get(2).logger()).isEqualTo("test-logger"); - assertThat(receivedNotifications.get(2).data()).isEqualTo("Another error message"); + assertThat(notificationMap.get("Another error message").level()) + .isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); }); } mcpServer.close(); From f348a83e5acef05b6c8807c7000c59098b667d28 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 10 Apr 2025 18:06:46 +0200 Subject: [PATCH 017/249] Next development version Signed-off-by: Christian Tzolov --- mcp-bom/pom.xml | 2 +- mcp-spring/mcp-spring-webflux/pom.xml | 6 +++--- mcp-spring/mcp-spring-webmvc/pom.xml | 6 +++--- mcp-test/pom.xml | 4 ++-- mcp/pom.xml | 2 +- pom.xml | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index 77d55da34..4f24f719f 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -7,7 +7,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT mcp-bom diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 186ade796..63c32a8a8 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT ../../pom.xml mcp-spring-webflux @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT test diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index 67e6b0aee..b59be6a03 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT ../../pom.xml mcp-spring-webmvc @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT test diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index 95f5dc30a..f1484ae77 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT mcp-test jar @@ -24,7 +24,7 @@ io.modelcontextprotocol.sdk mcp - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT diff --git a/mcp/pom.xml b/mcp/pom.xml index edb1c8f07..6b0f4a9fe 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT mcp jar diff --git a/pom.xml b/pom.xml index 8e7cca2a9..ff485b75d 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.9.0-SNAPSHOT + 0.10.0-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk From 8068854227cb378e44dcfe2985c5b002b65626e2 Mon Sep 17 00:00:00 2001 From: James Ward Date: Mon, 14 Apr 2025 23:41:41 -0600 Subject: [PATCH 018/249] add access to server instructions (#148) --- .../client/McpAsyncClient.java | 15 +++++++++++++++ .../client/McpSyncClient.java | 9 +++++++++ 2 files changed, 24 insertions(+) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index df099836d..1a9c39360 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -112,6 +112,11 @@ public class McpAsyncClient { */ private McpSchema.ServerCapabilities serverCapabilities; + /** + * Server instructions. + */ + private String serverInstructions; + /** * Server implementation information. */ @@ -240,6 +245,15 @@ public McpSchema.ServerCapabilities getServerCapabilities() { return this.serverCapabilities; } + /** + * Get the server instructions that provide guidance to the client on how to interact + * with this server. + * @return The server instructions + */ + public String getServerInstructions() { + return this.serverInstructions; + } + /** * Get the server implementation information. * @return The server implementation details @@ -328,6 +342,7 @@ public Mono initialize() { return result.flatMap(initializeResult -> { this.serverCapabilities = initializeResult.capabilities(); + this.serverInstructions = initializeResult.instructions(); this.serverInfo = initializeResult.serverInfo(); logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 32cf325e9..8544c3637 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -79,6 +79,15 @@ public McpSchema.ServerCapabilities getServerCapabilities() { return this.delegate.getServerCapabilities(); } + /** + * Get the server instructions that provide guidance to the client on how to interact + * with this server. + * @return The instructions + */ + public String getServerInstructions() { + return this.delegate.getServerInstructions(); + } + /** * Get the server implementation information. * @return The server implementation details From 263e3741b285f92baf87ff166e8d6dc3eafe9124 Mon Sep 17 00:00:00 2001 From: "a.darafeyeu" Date: Wed, 9 Apr 2025 19:43:29 +0200 Subject: [PATCH 019/249] feat(test): Use dynamic port allocation in integration tests (#133) - Add TestUtil class with findAvailablePort() method to the mcp-test module - Add findAvailablePort() method to TomcatTestUtil classes - Replace hardcoded port numbers with dynamic port allocation Signed-off-by: Christian Tzolov --- .../WebFluxSseIntegrationTests.java | 7 +++-- .../server/WebFluxSseMcpAsyncServerTests.java | 2 +- .../server/WebFluxSseMcpSyncServerTests.java | 2 +- .../server/TomcatTestUtil.java | 10 +++++- .../WebMvcSseAsyncServerTransportTests.java | 3 +- .../WebMvcSseCustomContextPathTests.java | 8 ++--- .../server/WebMvcSseIntegrationTests.java | 6 ++-- .../WebMvcSseSyncServerTransportTests.java | 3 +- .../modelcontextprotocol/server/TestUtil.java | 31 +++++++++++++++++++ ...ervletSseServerCustomContextPathTests.java | 7 +++-- ...rverTransportProviderIntegrationTests.java | 9 +++--- .../server/transport/TomcatTestUtil.java | 28 ++++++++++++++--- 12 files changed, 87 insertions(+), 29 deletions(-) create mode 100644 mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 76f908b8a..214b97f1b 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -18,6 +18,7 @@ import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.TestUtil; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; @@ -50,9 +51,9 @@ import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; -public class WebFluxSseIntegrationTests { +class WebFluxSseIntegrationTests { - private static final int PORT = 8182; + private static final int PORT = TestUtil.findAvailablePort(); private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; @@ -133,7 +134,7 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageSuccess(String clientType) throws InterruptedException { + void testCreateMessageSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index 98844c741..cc33e7b94 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -23,7 +23,7 @@ @Timeout(15) // Giving extra time beyond the client timeout class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - private static final int PORT = 8181; + private static final int PORT = TestUtil.findAvailablePort(); private static final String MESSAGE_ENDPOINT = "/mcp/message"; diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java index 71072855e..2fc104538 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -23,7 +23,7 @@ @Timeout(15) // Giving extra time beyond the client timeout class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { - private static final int PORT = 8182; + private static final int PORT = TestUtil.findAvailablePort(); private static final String MESSAGE_ENDPOINT = "/mcp/message"; diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java index fcd7fb4dc..ccf9e2d77 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java @@ -3,6 +3,10 @@ */ package io.modelcontextprotocol.server; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; + import org.apache.catalina.Context; import org.apache.catalina.startup.Tomcat; @@ -14,10 +18,14 @@ */ public class TomcatTestUtil { + TomcatTestUtil() { + // Prevent instantiation + } + public record TomcatServer(Tomcat tomcat, AnnotationConfigWebApplicationContext appContext) { } - public TomcatServer createTomcatServer(String contextPath, int port, Class componentClass) { + public static TomcatServer createTomcatServer(String contextPath, int port, Class componentClass) { // Set up Tomcat first var tomcat = new Tomcat(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java index 08d5de671..6a6ad17e9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java @@ -25,7 +25,7 @@ class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; - private static final int PORT = 8181; + private static final int PORT = TestUtil.findAvailablePort(); private Tomcat tomcat; @@ -73,7 +73,6 @@ protected McpServerTransportProvider createMcpTransportProvider() { // Create DispatcherServlet with our Spring context DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); // Add servlet to Tomcat and get the wrapper var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java index 0e81104b9..1b5218cc5 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java @@ -22,11 +22,11 @@ import static org.assertj.core.api.Assertions.assertThat; -public class WebMvcSseCustomContextPathTests { +class WebMvcSseCustomContextPathTests { private static final String CUSTOM_CONTEXT_PATH = "/app/1"; - private static final int PORT = 8183; + private static final int PORT = TestUtil.findAvailablePort(); private static final String MESSAGE_ENDPOINT = "/mcp/message"; @@ -39,11 +39,11 @@ public class WebMvcSseCustomContextPathTests { @BeforeEach public void before() { - tomcatServer = new TomcatTestUtil().createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class); + tomcatServer = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class); try { tomcatServer.tomcat().start(); - assertThat(tomcatServer.tomcat().getServer().getState() == LifecycleState.STARTED); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); } catch (Exception e) { throw new RuntimeException("Failed to start Tomcat", e); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index be01365a1..df527f87e 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -46,7 +46,7 @@ class WebMvcSseIntegrationTests { - private static final int PORT = 8183; + private static final int PORT = TestUtil.findAvailablePort(); private static final String MESSAGE_ENDPOINT = "/mcp/message"; @@ -75,7 +75,7 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro @BeforeEach public void before() { - tomcatServer = new TomcatTestUtil().createTomcatServer("", PORT, TestConfig.class); + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); try { tomcatServer.tomcat().start(); @@ -151,7 +151,7 @@ void testCreateMessageWithoutSamplingCapabilities() { } @Test - void testCreateMessageSuccess() throws InterruptedException { + void testCreateMessageSuccess() { Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java index b85bed379..1964703c1 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java @@ -24,7 +24,7 @@ class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; - private static final int PORT = 8181; + private static final int PORT = TestUtil.findAvailablePort(); private Tomcat tomcat; @@ -72,7 +72,6 @@ protected WebMvcSseServerTransportProvider createMcpTransportProvider() { // Create DispatcherServlet with our Spring context DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); // Add servlet to Tomcat and get the wrapper var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java new file mode 100644 index 000000000..0085f31ed --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java @@ -0,0 +1,31 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.server; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; + +public class TestUtil { + + TestUtil() { + // Prevent instantiation + } + + /** + * Finds an available port on the local machine. + * @return an available port number + * @throws IllegalStateException if no available port can be found + */ + public static int findAvailablePort() { + try (final ServerSocket socket = new ServerSocket()) { + socket.bind(new InetSocketAddress(0)); + return socket.getLocalPort(); + } + catch (final IOException e) { + throw new IllegalStateException("Cannot bind to an available port!", e); + } + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java index 212a3c95d..2cd62889a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server.transport; import com.fasterxml.jackson.databind.ObjectMapper; + import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.server.McpServer; @@ -17,9 +18,9 @@ import static org.assertj.core.api.Assertions.assertThat; -public class HttpServletSseServerCustomContextPathTests { +class HttpServletSseServerCustomContextPathTests { - private static final int PORT = 8195; + private static final int PORT = TomcatTestUtil.findAvailablePort(); private static final String CUSTOM_CONTEXT_PATH = "/api/v1"; @@ -48,7 +49,7 @@ public void before() { try { tomcat.start(); - assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); } catch (Exception e) { throw new RuntimeException("Failed to start Tomcat", e); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 135de83fa..f25ce5678 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -12,6 +12,7 @@ import java.util.stream.Collectors; import com.fasterxml.jackson.databind.ObjectMapper; + import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.server.McpServer; @@ -44,9 +45,9 @@ import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; -public class HttpServletSseServerTransportProviderIntegrationTests { +class HttpServletSseServerTransportProviderIntegrationTests { - private static final int PORT = 8189; + private static final int PORT = TomcatTestUtil.findAvailablePort(); private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; @@ -70,7 +71,7 @@ public void before() { tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); try { tomcat.start(); - assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); } catch (Exception e) { throw new RuntimeException("Failed to start Tomcat", e); @@ -133,7 +134,7 @@ void testCreateMessageWithoutSamplingCapabilities() { } @Test - void testCreateMessageSuccess() throws InterruptedException { + void testCreateMessageSuccess() { Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java index 6f922dfa6..f61cdc413 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -3,19 +3,23 @@ */ package io.modelcontextprotocol.server.transport; -import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; + import jakarta.servlet.Servlet; import org.apache.catalina.Context; -import org.apache.catalina.LifecycleState; import org.apache.catalina.startup.Tomcat; -import static org.junit.Assert.assertThat; - /** * @author Christian Tzolov */ public class TomcatTestUtil { + TomcatTestUtil() { + // Prevent instantiation + } + public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet) { var tomcat = new Tomcat(); @@ -24,7 +28,6 @@ public static Tomcat createTomcatServer(String contextPath, int port, Servlet se String baseDir = System.getProperty("java.io.tmpdir"); tomcat.setBaseDir(baseDir); - // Context context = tomcat.addContext("", baseDir); Context context = tomcat.addContext(contextPath, baseDir); // Add transport servlet to Tomcat @@ -42,4 +45,19 @@ public static Tomcat createTomcatServer(String contextPath, int port, Servlet se return tomcat; } + /** + * Finds an available port on the local machine. + * @return an available port number + * @throws IllegalStateException if no available port can be found + */ + public static int findAvailablePort() { + try (final ServerSocket socket = new ServerSocket()) { + socket.bind(new InetSocketAddress(0)); + return socket.getLocalPort(); + } + catch (final IOException e) { + throw new IllegalStateException("Cannot bind to an available port!", e); + } + } + } From 472f07ec6da7a2233d0d73b2be8ff7b6f9a85233 Mon Sep 17 00:00:00 2001 From: mipengcheng3 Date: Tue, 8 Apr 2025 17:53:59 +0800 Subject: [PATCH 020/249] feat(mcp): add configurable request timeout to MCP server (#134) Adds the ability to configure request timeouts for MCP server operations. This enhancement allows setting a custom duration to wait for server responses before timing out requests, which applies to all requests made through the client including tool calls, resource access, and prompt operations. - Add requestTimeout parameter to McpServerSession constructor - Add requestTimeout field and builder method to server classes - Pass timeout configuration through to session creation - Add tests for both success and failure scenarios across different transport implementations - Default timeout is set to 10 seconds if not explicitly configured. --- .../WebFluxSseIntegrationTests.java | 149 ++++++++++++++++++ .../server/WebMvcSseIntegrationTests.java | 145 +++++++++++++++++ .../server/McpAsyncServer.java | 13 +- .../server/McpServer.java | 39 ++++- .../spec/McpServerSession.java | 12 +- ...rverTransportProviderIntegrationTests.java | 145 +++++++++++++++++ 6 files changed, 491 insertions(+), 12 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 214b97f1b..dab54376e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; @@ -48,6 +49,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -196,6 +198,153 @@ void testCreateMessageSuccess(String clientType) { mcpServer.close(); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .requestTimeout(Duration.ofSeconds(4)) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(3); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .requestTimeout(Duration.ofSeconds(1)) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.close(); + mcpServer.close(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index df527f87e..07b36c25b 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -6,6 +6,7 @@ import java.time.Duration; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -41,6 +42,7 @@ import org.springframework.web.servlet.function.ServerResponse; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -212,6 +214,149 @@ void testCreateMessageSuccess() { mcpServer.close(); } + @Test + void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(4)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.close(); + mcpServer.close(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 062de13ed..4f7d0e87e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import java.time.Duration; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -90,8 +91,8 @@ public class McpAsyncServer { * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization */ McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features) { - this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, features); + McpServerFeatures.Async features, Duration requestTimeout) { + this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, requestTimeout, features); } /** @@ -271,7 +272,7 @@ private static class AsyncServerImpl extends McpAsyncServer { private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features) { + Duration requestTimeout, McpServerFeatures.Async features) { this.mcpTransportProvider = mcpTransportProvider; this.objectMapper = objectMapper; this.serverInfo = features.serverInfo(); @@ -330,9 +331,9 @@ private static class AsyncServerImpl extends McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider - .setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(), transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + mcpTransportProvider.setSessionFactory( + transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); } // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index d5427335d..60434a841 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -193,11 +194,28 @@ class AsyncSpecification { private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); + private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + private AsyncSpecification(McpServerTransportProvider transportProvider) { Assert.notNull(transportProvider, "Transport provider must not be null"); this.transportProvider = transportProvider; } + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public AsyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + /** * Sets the server implementation information that will be shared with clients * during connection initialization. This helps with version compatibility, @@ -565,7 +583,7 @@ public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers, this.instructions); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - return new McpAsyncServer(this.transportProvider, mapper, features); + return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout); } } @@ -619,11 +637,28 @@ class SyncSpecification { private final List>> rootsChangeHandlers = new ArrayList<>(); + private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + private SyncSpecification(McpServerTransportProvider transportProvider) { Assert.notNull(transportProvider, "Transport provider must not be null"); this.transportProvider = transportProvider; } + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public SyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + /** * Sets the server implementation information that will be shared with clients * during connection initialization. This helps with version compatibility, @@ -992,7 +1027,7 @@ public McpSyncServer build() { this.instructions); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures); + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout); return new McpSyncServer(asyncServer); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 46014af8d..46c356cdd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -27,6 +27,9 @@ public class McpServerSession implements McpSession { private final String id; + /** Duration to wait for request responses before timing out */ + private final Duration requestTimeout; + private final AtomicLong requestCounter = new AtomicLong(0); private final InitRequestHandler initRequestHandler; @@ -65,10 +68,11 @@ public class McpServerSession implements McpSession { * @param requestHandlers map of request handlers to use * @param notificationHandlers map of notification handlers to use */ - public McpServerSession(String id, McpServerTransport transport, InitRequestHandler initHandler, - InitNotificationHandler initNotificationHandler, Map> requestHandlers, - Map notificationHandlers) { + public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, + InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, + Map> requestHandlers, Map notificationHandlers) { this.id = id; + this.requestTimeout = requestTimeout; this.transport = transport; this.initRequestHandler = initHandler; this.initNotificationHandler = initNotificationHandler; @@ -116,7 +120,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc this.pendingResponses.remove(requestId); sink.error(error); }); - }).timeout(Duration.ofSeconds(10)).handle((jsonRpcResponse, sink) -> { + }).timeout(requestTimeout).handle((jsonRpcResponse, sink) -> { if (jsonRpcResponse.error() != null) { sink.error(new McpError(jsonRpcResponse.error())); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index f25ce5678..b8f040c7a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -7,6 +7,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; @@ -42,6 +43,7 @@ import org.springframework.web.client.RestClient; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -194,6 +196,149 @@ void testCreateMessageSuccess() { mcpServer.close(); } + @Test + void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.close(); + mcpServer.close(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- From f7f8ccd0acb6d39558b65ceb1ae4e5f71619a37c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Mon, 14 Apr 2025 12:08:30 +0200 Subject: [PATCH 021/249] Fix flaky test running blocking code in event loop (#155) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace StepVerifier with assertWith for cleaner test assertions - Add try-with-resources blocks for proper client resource management - Use closeGracefully().block() for proper server shutdown Signed-off-by: Dariusz Jędrzejczyk --- .../WebFluxSseIntegrationTests.java | 135 ++++++++---------- 1 file changed, 62 insertions(+), 73 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index dab54376e..6ba0911ed 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -37,10 +37,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; -import reactor.test.StepVerifier; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; @@ -50,6 +48,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -109,12 +108,9 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }); + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), + (exchange, request) -> exchange.createMessage(mock(CreateMessageRequest.class)) + .thenReturn(mock(CallToolResult.class))); var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); @@ -151,6 +147,8 @@ void testCreateMessageSuccess(String clientType) { CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + AtomicReference samplingResult = new AtomicReference<>(); + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { @@ -165,16 +163,9 @@ void testCreateMessageSuccess(String clientType) { .build()) .build(); - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); + return exchange.createMessage(craeteMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); }); var mcpServer = McpServer.async(mcpServerTransportProvider) @@ -194,8 +185,17 @@ void testCreateMessageSuccess(String clientType) { assertThat(response).isNotNull(); assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); } - mcpServer.close(); + mcpServer.closeGracefully().block(); } @ParameterizedTest(name = "{0} : {displayName} ") @@ -218,16 +218,13 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - // Server CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + AtomicReference samplingResult = new AtomicReference<>(); + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { @@ -242,16 +239,9 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr .build()) .build(); - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); + return exchange.createMessage(craeteMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); }); var mcpServer = McpServer.async(mcpServerTransportProvider) @@ -260,16 +250,30 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - mcpClient.close(); - mcpServer.close(); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + + mcpServer.closeGracefully().block(); } @ParameterizedTest(name = "{0} : {displayName} ") @@ -283,7 +287,7 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt assertThat(request.messages()).hasSize(1); assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); try { - TimeUnit.SECONDS.sleep(3); + TimeUnit.SECONDS.sleep(2); } catch (InterruptedException e) { throw new RuntimeException(e); @@ -292,11 +296,6 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - // Server CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), @@ -308,24 +307,9 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) .build(); - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); + return exchange.createMessage(craeteMessageRequest).thenReturn(callResponse); }); var mcpServer = McpServer.async(mcpServerTransportProvider) @@ -334,15 +318,21 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.close(); - mcpServer.close(); + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + } + + mcpServer.closeGracefully().block(); } // --------------------------------------- @@ -412,9 +402,8 @@ void testRootsWithoutCapability(String clientType) { var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { }).tools(tool).build(); - try ( - // Create client without roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { + // Create client without roots capability + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { assertThat(mcpClient.initialize()).isNotNull(); From 84adde16a539e2e482bf4deb2c6c3b17b3c148fd Mon Sep 17 00:00:00 2001 From: mackey0225 Date: Tue, 15 Apr 2025 10:48:19 +0900 Subject: [PATCH 022/249] fix: correct typos in variable names, method names, and commentsj (#159) --- .../WebFluxSseServerTransportProvider.java | 2 +- .../WebFluxSseIntegrationTests.java | 6 +++--- .../server/WebMvcSseIntegrationTests.java | 6 +++--- .../server/AbstractMcpAsyncServerTests.java | 2 +- .../server/AbstractMcpSyncServerTests.java | 16 ++++++++-------- .../modelcontextprotocol/client/McpClient.java | 4 ++-- .../server/McpAsyncServer.java | 2 +- .../io/modelcontextprotocol/spec/McpSchema.java | 2 +- .../server/AbstractMcpAsyncServerTests.java | 2 +- .../server/AbstractMcpSyncServerTests.java | 16 ++++++++-------- ...eServerTransportProviderIntegrationTests.java | 6 +++--- 11 files changed, 32 insertions(+), 32 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index eed8a53af..62264d9aa 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -141,7 +141,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * Constructs a new WebFlux SSE server transport provider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization * of MCP messages. Must not be null. - * @param baseUrl webflux messag base path + * @param baseUrl webflux message base path * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 6ba0911ed..80a126441 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -152,7 +152,7 @@ void testCreateMessageSuccess(String clientType) { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + var createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")))) .modelPreferences(ModelPreferences.builder() @@ -163,7 +163,7 @@ void testCreateMessageSuccess(String clientType) { .build()) .build(); - return exchange.createMessage(craeteMessageRequest) + return exchange.createMessage(createMessageRequest) .doOnNext(samplingResult::set) .thenReturn(callResponse); }); @@ -421,7 +421,7 @@ void testRootsWithoutCapability(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsNotifciationWithEmptyRootsList(String clientType) { + void testRootsNotificationWithEmptyRootsList(String clientType) { var clientBuilder = clientBuilders.get(clientType); AtomicReference> rootsRef = new AtomicReference<>(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 07b36c25b..b12d68439 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -169,7 +169,7 @@ void testCreateMessageSuccess() { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + var createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")))) .modelPreferences(ModelPreferences.builder() @@ -180,7 +180,7 @@ void testCreateMessageSuccess() { .build()) .build(); - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> { assertThat(result).isNotNull(); assertThat(result.role()).isEqualTo(Role.USER); assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); @@ -438,7 +438,7 @@ void testRootsWithoutCapability() { } @Test - void testRootsNotifciationWithEmptyRootsList() { + void testRootsNotificationWithEmptyRootsList() { AtomicReference> rootsRef = new AtomicReference<>(); var mcpServer = McpServer.sync(mcpServerTransportProvider) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index a91632c6c..cdd43e7ef 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -110,7 +110,7 @@ void testAddTool() { .build(); StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (excnage, args) -> Mono.just(new CallToolResult(List.of(), false))))) + (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 9a63143c9..c81e638c1 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -200,16 +200,16 @@ void testAddResource() { Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.SyncResourceSpecification specificaiton = new McpServerFeatures.SyncResourceSpecification( + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatCode(() -> mcpSyncServer.addResource(specificaiton)).doesNotThrowAnyException(); + assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullSpecifiation() { + void testAddResourceWithNullSpecification() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) @@ -279,11 +279,11 @@ void testAddPromptWithoutCapability() { .build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specificaiton)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); } @@ -300,14 +300,14 @@ void testRemovePromptWithoutCapability() { @Test void testRemovePrompt() { Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(specificaiton) + .prompts(specification) .build(); assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); @@ -340,7 +340,7 @@ void testRootsChangeHandlers() { var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchage, roots) -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index f7b179616..a1dc11685 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -196,7 +196,7 @@ public SyncSpec requestTimeout(Duration requestTimeout) { } /** - * @param initializationTimeout The duration to wait for the initializaiton + * @param initializationTimeout The duration to wait for the initialization * lifecycle step to complete. * @return This builder instance for method chaining * @throws IllegalArgumentException if initializationTimeout is null @@ -435,7 +435,7 @@ public AsyncSpec requestTimeout(Duration requestTimeout) { } /** - * @param initializationTimeout The duration to wait for the initializaiton + * @param initializationTimeout The duration to wait for the initialization * lifecycle step to complete. * @return This builder instance for method chaining * @throws IllegalArgumentException if initializationTimeout is null diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 4f7d0e87e..28b63cecb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -359,7 +359,7 @@ private Mono asyncInitializeRequestHandler( } else { logger.warn( - "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead", + "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", initializeRequest.protocolVersion(), serverProtocolVersion); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index e621ac19b..6eb5159f6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -1103,7 +1103,7 @@ public record ProgressNotification(// @formatter:off * setting minimum log levels, with servers sending notifications containing severity * levels, optional logger names, and arbitrary JSON-serializable data. * - * @param level The severity levels. The mimimum log level is set by the client. + * @param level The severity levels. The minimum log level is set by the client. * @param logger The logger that generated the message. * @param data JSON-serializable logging data. */ diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index c7c69b52b..df0b0c729 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -109,7 +109,7 @@ void testAddTool() { .build(); StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (excnage, args) -> Mono.just(new CallToolResult(List.of(), false))))) + (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 8c9328cc7..0b38da857 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -199,16 +199,16 @@ void testAddResource() { Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.SyncResourceSpecification specificaiton = new McpServerFeatures.SyncResourceSpecification( + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatCode(() -> mcpSyncServer.addResource(specificaiton)).doesNotThrowAnyException(); + assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullSpecifiation() { + void testAddResourceWithNullSpecification() { var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) @@ -278,11 +278,11 @@ void testAddPromptWithoutCapability() { .build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specificaiton)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); } @@ -299,14 +299,14 @@ void testRemovePromptWithoutCapability() { @Test void testRemovePrompt() { Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(specificaiton) + .prompts(specification) .build(); assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); @@ -339,7 +339,7 @@ void testRootsChangeHandlers() { var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchage, roots) -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index b8f040c7a..2ff6325a4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -152,7 +152,7 @@ void testCreateMessageSuccess() { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + var createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message")))) .modelPreferences(ModelPreferences.builder() @@ -163,7 +163,7 @@ void testCreateMessageSuccess() { .build()) .build(); - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> { assertThat(result).isNotNull(); assertThat(result.role()).isEqualTo(Role.USER); assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); @@ -417,7 +417,7 @@ void testRootsWithoutCapability() { } @Test - void testRootsNotifciationWithEmptyRootsList() { + void testRootsNotificationWithEmptyRootsList() { AtomicReference> rootsRef = new AtomicReference<>(); var mcpServer = McpServer.sync(mcpServerTransportProvider) From 7853efefdabeec61f093b4c048dda0eb95263348 Mon Sep 17 00:00:00 2001 From: jitokim Date: Fri, 11 Apr 2025 13:36:37 +0900 Subject: [PATCH 023/249] feat(completion): implement completion API support (#141) Add completion API support to the MCP protocol implementation: - Add CompleteRequest and CompleteResult schema classes - Implement completion capabilities in ServerCapabilities - Add completion handlers in McpAsyncServer and McpServer - Add completion client methods in McpAsyncClient and McpSyncClient - Add CompletionRefKey and completion specifications in McpServerFeatures - Add integration test for completion functionality - Fix isPresent() check to use isEmpty() in WebMvcSseServerTransportProvider - Replace McpServerFeatures.CompletionRefKey by McpSchemaCompleteReference Co-authored-by: Christian Tzolov Signed-off-by: jitokim --- .../WebFluxSseIntegrationTests.java | 65 +++++++++++--- .../WebMvcSseServerTransportProvider.java | 2 +- .../client/McpAsyncClient.java | 22 +++++ .../client/McpSyncClient.java | 11 +++ .../server/McpAsyncServer.java | 84 +++++++++++++++++++ .../server/McpServer.java | 43 +++++++++- .../server/McpServerFeatures.java | 71 +++++++++++++++- .../spec/McpClientSession.java | 1 + .../modelcontextprotocol/spec/McpSchema.java | 74 ++++++++++++---- 9 files changed, 340 insertions(+), 33 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 80a126441..08619bd31 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -10,6 +10,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; @@ -20,19 +21,12 @@ import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.McpSyncServerExchange; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.*; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.CompletionCapabilities; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; @@ -759,4 +753,53 @@ void testLoggingNotification(String clientType) { mcpServer.close(); } -} + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (mcpSyncServerExchange, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions(new CompletionCapabilities()).build()) + .prompts(new McpServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "this is code review prompt", List.of()), + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new McpSchema.PromptReference("ref/prompt", "code_review"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest(new PromptReference("ref/prompt", "code_review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + } + + mcpServer.close(); + } + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 7bd1aa6c9..fc86cfaa0 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -300,7 +300,7 @@ private ServerResponse handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } - if (!request.param("sessionId").isPresent()) { + if (request.param("sessionId").isEmpty()) { return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 1a9c39360..2bc74f258 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -71,6 +71,7 @@ * * @author Dariusz Jędrzejczyk * @author Christian Tzolov + * @author Jihoon Kim * @see McpClient * @see McpSchema * @see McpClientSession @@ -816,4 +817,25 @@ void setProtocolVersions(List protocolVersions) { this.protocolVersions = protocolVersions; } + // -------------------------- + // Completions + // -------------------------- + private static final TypeReference COMPLETION_COMPLETE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Sends a completion/complete request to generate value suggestions based on a given + * reference and argument. This is typically used to provide auto-completion options + * for user input fields. + * @param completeRequest The request containing the prompt or resource reference and + * argument for which to generate completions. + * @return A Mono that completes with the result containing completion suggestions. + * @see McpSchema.CompleteRequest + * @see McpSchema.CompleteResult + */ + public Mono completeCompletion(McpSchema.CompleteRequest completeRequest) { + return this.withInitializationCheck("complete completions", initializedResult -> this.mcpSession + .sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF)); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 8544c3637..c91638a7e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -46,6 +46,7 @@ * * @author Dariusz Jędrzejczyk * @author Christian Tzolov + * @author Jihoon Kim * @see McpClient * @see McpAsyncClient * @see McpSchema @@ -334,4 +335,14 @@ public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) { this.delegate.setLoggingLevel(loggingLevel).block(); } + /** + * Send a completion/complete request. + * @param completeRequest the completion request contains the prompt or resource + * reference and arguments for generating suggestions. + * @return the completion result containing suggested values. + */ + public McpSchema.CompleteResult completeCompletion(McpSchema.CompleteRequest completeRequest) { + return this.delegate.completeCompletion(completeRequest).block(); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 28b63cecb..906cb9a08 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -69,6 +69,7 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @author Jihoon Kim * @see McpServer * @see McpSchema * @see McpClientSession @@ -269,6 +270,8 @@ private static class AsyncServerImpl extends McpAsyncServer { // broadcasting loggingNotification. private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, @@ -282,6 +285,7 @@ private static class AsyncServerImpl extends McpAsyncServer { this.resources.putAll(features.resources()); this.resourceTemplates.addAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); Map> requestHandlers = new HashMap<>(); @@ -314,6 +318,11 @@ private static class AsyncServerImpl extends McpAsyncServer { requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); } + // Add completion API handlers if the completion capability is enabled + if (this.serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); + } + Map notificationHandlers = new HashMap<>(); notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); @@ -706,6 +715,81 @@ private McpServerSession.RequestHandler setLoggerRequestHandler() { }; } + private McpServerSession.RequestHandler completionCompleteRequestHandler() { + return (exchange, params) -> { + McpSchema.CompleteRequest request = parseCompletionParams(params); + + if (request.ref() == null) { + return Mono.error(new McpError("ref must not be null")); + } + + if (request.ref().type() == null) { + return Mono.error(new McpError("type must not be null")); + } + + String type = request.ref().type(); + + // check if the referenced resource exists + if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { + McpServerFeatures.AsyncPromptSpecification prompt = this.prompts.get(promptReference.name()); + if (prompt == null) { + return Mono.error(new McpError("Prompt not found: " + promptReference.name())); + } + } + + if (type.equals("ref/resource") + && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + McpServerFeatures.AsyncResourceSpecification resource = this.resources.get(resourceReference.uri()); + if (resource == null) { + return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); + } + } + + McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); + + if (specification == null) { + return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); + } + + return specification.completionHandler().apply(exchange, request); + }; + } + + /** + * Parses the raw JSON-RPC request parameters into a + * {@link McpSchema.CompleteRequest} object. + *

+ * This method manually extracts the `ref` and `argument` fields from the input + * map, determines the correct reference type (either prompt or resource), and + * constructs a fully-typed {@code CompleteRequest} instance. + * @param object the raw request parameters, expected to be a Map containing "ref" + * and "argument" entries. + * @return a {@link McpSchema.CompleteRequest} representing the structured + * completion request. + * @throws IllegalArgumentException if the "ref" type is not recognized. + */ + @SuppressWarnings("unchecked") + private McpSchema.CompleteRequest parseCompletionParams(Object object) { + Map params = (Map) object; + Map refMap = (Map) params.get("ref"); + Map argMap = (Map) params.get("argument"); + + String refType = (String) refMap.get("type"); + + McpSchema.CompleteReference ref = switch (refType) { + case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name")); + case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + default -> throw new IllegalArgumentException("Invalid ref type: " + refType); + }; + + String argName = (String) argMap.get("name"); + String argValue = (String) argMap.get("value"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument( + argName, argValue); + + return new McpSchema.CompleteRequest(ref, argument); + } + // --------------------------------------- // Sampling // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 60434a841..84089703c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -115,6 +115,7 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @author Jihoon Kim * @see McpAsyncServer * @see McpSyncServer * @see McpServerTransportProvider @@ -192,6 +193,8 @@ class AsyncSpecification { */ private final Map prompts = new HashMap<>(); + private final Map completions = new HashMap<>(); + private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout @@ -581,7 +584,8 @@ public AsyncSpecification objectMapper(ObjectMapper objectMapper) { */ public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, - this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers, this.instructions); + this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, + this.instructions); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout); } @@ -635,6 +639,8 @@ class SyncSpecification { */ private final Map prompts = new HashMap<>(); + private final Map completions = new HashMap<>(); + private final List>> rootsChangeHandlers = new ArrayList<>(); private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout @@ -957,6 +963,37 @@ public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... pr return this; } + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + * @see #completions(McpServerFeatures.SyncCompletionSpecification...) + */ + public SyncSpecification completions(List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.SyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public SyncSpecification completions(McpServerFeatures.SyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.SyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + /** * Registers a consumer that will be notified when the list of roots changes. This * is useful for updating resource availability dynamically, such as when new @@ -1023,8 +1060,8 @@ public SyncSpecification objectMapper(ObjectMapper objectMapper) { */ public McpSyncServer build() { McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers, - this.instructions); + this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index e0f337b78..8311f5d41 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -21,6 +21,7 @@ * MCP server features specification that a particular server can choose to support. * * @author Dariusz Jędrzejczyk + * @author Jihoon Kim */ public class McpServerFeatures { @@ -41,6 +42,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s List tools, Map resources, List resourceTemplates, Map prompts, + Map completions, List, Mono>> rootsChangeConsumers, String instructions) { @@ -60,6 +62,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s List tools, Map resources, List resourceTemplates, Map prompts, + Map completions, List, Mono>> rootsChangeConsumers, String instructions) { @@ -67,7 +70,8 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.serverInfo = serverInfo; this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities - : new McpSchema.ServerCapabilities(null, // experimental + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable // logging // by @@ -81,6 +85,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.resources = (resources != null) ? resources : Map.of(); this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); this.prompts = (prompts != null) ? prompts : Map.of(); + this.completions = (completions != null) ? completions : Map.of(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); this.instructions = instructions; } @@ -109,6 +114,11 @@ static Async fromSync(Sync syncSpec) { prompts.put(key, AsyncPromptSpecification.fromSync(prompt)); }); + Map completions = new HashMap<>(); + syncSpec.completions().forEach((key, completion) -> { + completions.put(key, AsyncCompletionSpecification.fromSync(completion)); + }); + List, Mono>> rootChangeConsumers = new ArrayList<>(); for (var rootChangeConsumer : syncSpec.rootsChangeConsumers()) { @@ -118,7 +128,7 @@ static Async fromSync(Sync syncSpec) { } return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, - syncSpec.resourceTemplates(), prompts, rootChangeConsumers, syncSpec.instructions()); + syncSpec.resourceTemplates(), prompts, completions, rootChangeConsumers, syncSpec.instructions()); } } @@ -140,6 +150,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se Map resources, List resourceTemplates, Map prompts, + Map completions, List>> rootsChangeConsumers, String instructions) { /** @@ -159,6 +170,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se Map resources, List resourceTemplates, Map prompts, + Map completions, List>> rootsChangeConsumers, String instructions) { @@ -166,7 +178,8 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se this.serverInfo = serverInfo; this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities - : new McpSchema.ServerCapabilities(null, // experimental + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable // logging // by @@ -180,6 +193,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se this.resources = (resources != null) ? resources : new HashMap<>(); this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); this.prompts = (prompts != null) ? prompts : new HashMap<>(); + this.completions = (completions != null) ? completions : new HashMap<>(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : new ArrayList<>(); this.instructions = instructions; } @@ -325,6 +339,44 @@ static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt) { } } + /** + * Specification of a completion handler function with asynchronous execution support. + * Completions generate AI model outputs based on prompt or resource references and + * user-provided arguments. This abstraction enables: + *

    + *
  • Customizable response generation logic + *
  • Parameter-driven template expansion + *
  • Dynamic interaction with connected clients + *
+ * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The asynchronous function that processes completion + * requests and returns results. The first argument is an + * {@link McpAsyncServerExchange} used to interact with the client. The second + * argument is a {@link io.modelcontextprotocol.spec.McpSchema.CompleteRequest}. + */ + public record AsyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction> completionHandler) { + + /** + * Converts a synchronous {@link SyncCompletionSpecification} into an + * {@link AsyncCompletionSpecification} by wrapping the handler in a bounded + * elastic scheduler for safe non-blocking execution. + * @param completion the synchronous completion specification + * @return an asynchronous wrapper of the provided sync specification, or + * {@code null} if input is null + */ + static AsyncCompletionSpecification fromSync(SyncCompletionSpecification completion) { + if (completion == null) { + return null; + } + return new AsyncCompletionSpecification(completion.referenceKey(), + (exchange, request) -> Mono.fromCallable( + () -> completion.completionHandler().apply(new McpSyncServerExchange(exchange), request)) + .subscribeOn(Schedulers.boundedElastic())); + } + } + /** * Specification of a tool with its synchronous handler function. Tools are the * primary way for MCP servers to expose functionality to AI models. Each tool @@ -431,4 +483,17 @@ public record SyncPromptSpecification(McpSchema.Prompt prompt, BiFunction promptHandler) { } + /** + * Specification of a completion handler function with synchronous execution support. + * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The synchronous function that processes completion + * requests and returns results. The first argument is an + * {@link McpSyncServerExchange} used to interact with the client. The second argument + * is a {@link io.modelcontextprotocol.spec.McpSchema.CompleteRequest}. + */ + public record SyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction completionHandler) { + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 0895e02b0..c1f42e3fb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -238,6 +238,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc }); }).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { if (jsonRpcResponse.error() != null) { + logger.error("Error handling request: {}", jsonRpcResponse.error()); sink.error(new McpError(jsonRpcResponse.error())); } else { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 6eb5159f6..55fdc1724 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -79,6 +79,8 @@ private McpSchema() { public static final String METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED = "notifications/prompts/list_changed"; + public static final String METHOD_COMPLETION_COMPLETE = "completion/complete"; + // Logging Methods public static final String METHOD_LOGGING_SET_LEVEL = "logging/setLevel"; @@ -314,12 +316,16 @@ public ClientCapabilities build() { @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record ServerCapabilities( // @formatter:off + @JsonProperty("completions") CompletionCapabilities completions, @JsonProperty("experimental") Map experimental, @JsonProperty("logging") LoggingCapabilities logging, @JsonProperty("prompts") PromptCapabilities prompts, @JsonProperty("resources") ResourceCapabilities resources, @JsonProperty("tools") ToolCapabilities tools) { + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record CompletionCapabilities() { + } @JsonInclude(JsonInclude.Include.NON_ABSENT) public record LoggingCapabilities() { @@ -347,12 +353,18 @@ public static Builder builder() { public static class Builder { + private CompletionCapabilities completions; private Map experimental; private LoggingCapabilities logging = new LoggingCapabilities(); private PromptCapabilities prompts; private ResourceCapabilities resources; private ToolCapabilities tools; + public Builder completions(CompletionCapabilities completions) { + this.completions = completions; + return this; + } + public Builder experimental(Map experimental) { this.experimental = experimental; return this; @@ -379,7 +391,7 @@ public Builder tools(Boolean listChanged) { } public ServerCapabilities build() { - return new ServerCapabilities(experimental, logging, prompts, resources, tools); + return new ServerCapabilities(completions, experimental, logging, prompts, resources, tools); } } } // @formatter:on @@ -1173,31 +1185,63 @@ public record SetLevelRequest(@JsonProperty("level") LoggingLevel level) { // --------------------------- // Autocomplete // --------------------------- - public record CompleteRequest(PromptOrResourceReference ref, CompleteArgument argument) implements Request { - public sealed interface PromptOrResourceReference permits PromptReference, ResourceReference { + public sealed interface CompleteReference permits PromptReference, ResourceReference { + + String type(); + + String identifier(); - String type(); + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record PromptReference(// @formatter:off + @JsonProperty("type") String type, + @JsonProperty("name") String name) implements McpSchema.CompleteReference { + public PromptReference(String name) { + this("ref/prompt", name); } - public record PromptReference(// @formatter:off - @JsonProperty("type") String type, - @JsonProperty("name") String name) implements PromptOrResourceReference { - }// @formatter:on + @Override + public String identifier() { + return name(); + } + }// @formatter:on - public record ResourceReference(// @formatter:off - @JsonProperty("type") String type, - @JsonProperty("uri") String uri) implements PromptOrResourceReference { - }// @formatter:on + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ResourceReference(// @formatter:off + @JsonProperty("type") String type, + @JsonProperty("uri") String uri) implements McpSchema.CompleteReference { - public record CompleteArgument(// @formatter:off + public ResourceReference(String uri) { + this("ref/resource", uri); + } + + @Override + public String identifier() { + return uri(); + } + }// @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CompleteRequest(// @formatter:off + @JsonProperty("ref") McpSchema.CompleteReference ref, + @JsonProperty("argument") CompleteArgument argument) implements Request { + + public record CompleteArgument( @JsonProperty("name") String name, @JsonProperty("value") String value) { }// @formatter:on } - public record CompleteResult(CompleteCompletion completion) { - public record CompleteCompletion(// @formatter:off + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CompleteResult(@JsonProperty("values") CompleteCompletion completion) { // @formatter:off + + public record CompleteCompletion( @JsonProperty("values") List values, @JsonProperty("total") Integer total, @JsonProperty("hasMore") Boolean hasMore) { From 734d1732d6d3e74cd427ac5a4dba95a02ba08618 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 17 Apr 2025 09:28:08 +0200 Subject: [PATCH 024/249] Refactor: Simplify ServerCapabilities builder API for completions Signed-off-by: Christian Tzolov --- .../io/modelcontextprotocol/WebFluxSseIntegrationTests.java | 2 +- mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 08619bd31..660f814da 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -774,7 +774,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { }; var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().completions(new CompletionCapabilities()).build()) + .capabilities(ServerCapabilities.builder().completions().build()) .prompts(new McpServerFeatures.SyncPromptSpecification( new Prompt("code_review", "this is code review prompt", List.of()), (mcpSyncServerExchange, getPromptRequest) -> null)) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 55fdc1724..e7e338030 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -360,8 +360,8 @@ public static class Builder { private ResourceCapabilities resources; private ToolCapabilities tools; - public Builder completions(CompletionCapabilities completions) { - this.completions = completions; + public Builder completions() { + this.completions = new CompletionCapabilities(); return this; } From 344f1b838340efbcf0cfcd428358bdb77e9a533c Mon Sep 17 00:00:00 2001 From: E550448 Date: Sun, 13 Apr 2025 12:54:48 +0200 Subject: [PATCH 025/249] feat(mcp): resolve absolute and relative message endpoint URIs (#150) Improve endpoint URI handling by supporting both relative paths and properly validated absolute URIs. - Implement URI resolution in HttpClientSseClientTransport: - Change baseUri field from String to URI type - Add Utils.resolveUri method to handle both absolute and relative URIs - Resolve relative URIs against the base URI - Validate absolute URIs to ensure they match base URI's scheme, authority, and path - Add parameterized tests for various URI resolution scenarios - Add ByteBuddy dependency for HttpClient mocking and update Mockito Signed-off-by: Christian Tzolov --- README.md | 2 +- mcp/pom.xml | 14 +++++ .../HttpClientSseClientTransport.java | 11 ++-- .../io/modelcontextprotocol/util/Utils.java | 56 ++++++++++++++++++- .../HttpClientSseClientTransportTests.java | 29 +++++++++- .../modelcontextprotocol/util/UtilsTests.java | 29 ++++++++++ pom.xml | 5 +- 7 files changed, 136 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index ca87736cd..9fc17306e 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ A set of projects that provide Java SDK integration for the [Model Context Protocol](https://modelcontextprotocol.org/docs/concepts/architecture). This SDK enables Java applications to interact with AI models and tools through a standardized interface, supporting both synchronous and asynchronous communication patterns. -## 📚 Reference Documentation +## 📚 Reference Documentation #### MCP Java SDK documentation For comprehensive guides and SDK API documentation, visit the [MCP Java SDK Reference Documentation](https://modelcontextprotocol.io/sdk/java/mcp-overview). diff --git a/mcp/pom.xml b/mcp/pom.xml index 6b0f4a9fe..17693ab32 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -126,12 +126,26 @@ ${junit.version} test + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + org.mockito mockito-core ${mockito.version} test + + + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + io.projectreactor reactor-test diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 632d3844a..99cf2a625 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -24,6 +24,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -69,7 +70,7 @@ public class HttpClientSseClientTransport implements McpClientTransport { private static final String DEFAULT_SSE_ENDPOINT = "/sse"; /** Base URI for the MCP server */ - private final String baseUri; + private final URI baseUri; /** SSE endpoint path */ private final String sseEndpoint; @@ -178,7 +179,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(httpClient, "httpClient must not be null"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); - this.baseUri = baseUri; + this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; this.objectMapper = objectMapper; this.httpClient = httpClient; @@ -340,7 +341,8 @@ public Mono connect(Function, Mono> h CompletableFuture future = new CompletableFuture<>(); connectionFuture.set(future); - sseClient.subscribe(this.baseUri + this.sseEndpoint, new FlowSseClient.SseEventHandler() { + URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); + sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { @Override public void onEvent(SseEvent event) { if (isClosing) { @@ -412,7 +414,8 @@ public Mono sendMessage(JSONRPCMessage message) { try { String jsonText = this.objectMapper.writeValueAsString(message); - HttpRequest request = this.requestBuilder.uri(URI.create(this.baseUri + endpoint)) + URI requestUri = Utils.resolveUri(baseUri, endpoint); + HttpRequest request = this.requestBuilder.uri(requestUri) .POST(HttpRequest.BodyPublishers.ofString(jsonText)) .build(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java index 0f799ca0f..8e654e596 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -4,11 +4,12 @@ package io.modelcontextprotocol.util; +import reactor.util.annotation.Nullable; + +import java.net.URI; import java.util.Collection; import java.util.Map; -import reactor.util.annotation.Nullable; - /** * Miscellaneous utility methods. * @@ -52,4 +53,55 @@ public static boolean isEmpty(@Nullable Map map) { return (map == null || map.isEmpty()); } + /** + * Resolves the given endpoint URL against the base URL. + *
    + *
  • If the endpoint URL is relative, it will be resolved against the base URL.
  • + *
  • If the endpoint URL is absolute, it will be validated to ensure it matches the + * base URL's scheme, authority, and path prefix.
  • + *
  • If validation fails for an absolute URL, an {@link IllegalArgumentException} is + * thrown.
  • + *
+ * @param baseUrl The base URL (must be absolute) + * @param endpointUrl The endpoint URL (can be relative or absolute) + * @return The resolved endpoint URI + * @throws IllegalArgumentException If the absolute endpoint URL does not match the + * base URL or URI is malformed + */ + public static URI resolveUri(URI baseUrl, String endpointUrl) { + URI endpointUri = URI.create(endpointUrl); + if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) { + throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL."); + } + else { + return baseUrl.resolve(endpointUri); + } + } + + /** + * Checks if the given absolute endpoint URI falls under the base URI. It validates + * the scheme, authority (host and port), and ensures that the base path is a prefix + * of the endpoint path. + * @param baseUri The base URI + * @param endpointUri The endpoint URI to check + * @return true if endpointUri is within baseUri's hierarchy, false otherwise + */ + private static boolean isUnderBaseUri(URI baseUri, URI endpointUri) { + if (!baseUri.getScheme().equals(endpointUri.getScheme()) + || !baseUri.getAuthority().equals(endpointUri.getAuthority())) { + return false; + } + + URI normalizedBase = baseUri.normalize(); + URI normalizedEndpoint = endpointUri.normalize(); + + String basePath = normalizedBase.getPath(); + String endpointPath = normalizedEndpoint.getPath(); + + if (basePath.endsWith("/")) { + basePath = basePath.substring(0, basePath.length() - 1); + } + return endpointPath.startsWith(basePath); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index e5178c0ee..762264de3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -7,12 +7,13 @@ import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; +import java.net.http.HttpResponse; import java.time.Duration; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; import java.util.function.Function; import io.modelcontextprotocol.spec.McpSchema; @@ -21,6 +22,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; import reactor.core.publisher.Mono; @@ -31,6 +34,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; @@ -364,4 +370,25 @@ void testChainedCustomizations() { customizedTransport.closeGracefully().block(); } + @Test + @SuppressWarnings("unchecked") + void testResolvingClientEndpoint() { + HttpClient httpClient = Mockito.mock(HttpClient.class); + HttpResponse httpResponse = Mockito.mock(HttpResponse.class); + CompletableFuture> future = new CompletableFuture<>(); + future.complete(httpResponse); + when(httpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(future); + + HttpClientSseClientTransport transport = new HttpClientSseClientTransport(httpClient, HttpRequest.newBuilder(), + "http://example.com", "http://example.com/sse", new ObjectMapper()); + + transport.connect(Function.identity()); + + ArgumentCaptor httpRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(httpClient).sendAsync(httpRequestCaptor.capture(), any(HttpResponse.BodyHandler.class)); + assertThat(httpRequestCaptor.getValue().uri()).isEqualTo(URI.create("http://example.com/sse")); + + transport.closeGracefully().block(); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java b/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java index aced20cbc..0f2e689b5 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java @@ -6,12 +6,17 @@ import org.junit.jupiter.api.Test; +import java.net.URI; import java.util.Collection; import java.util.List; import java.util.Map; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; class UtilsTests { @@ -37,4 +42,28 @@ void testMapIsEmpty() { assertFalse(Utils.isEmpty(Map.of("key", "value"))); } + @ParameterizedTest + @CsvSource({ + // relative endpoints + "http://localhost:8080/root, /api/v1, http://localhost:8080/api/v1", + "http://localhost:8080/root/, api, http://localhost:8080/root/api", + "http://localhost:8080, /api, http://localhost:8080/api", + // absolute endpoints matching base + "http://localhost:8080/root, http://localhost:8080/root/api/v1, http://localhost:8080/root/api/v1", + "http://localhost:8080/root, http://localhost:8080/root, http://localhost:8080/root" }) + void testValidUriResolution(String baseUrl, String endpoint, String expectedResult) { + URI result = Utils.resolveUri(URI.create(baseUrl), endpoint); + assertThat(result.toString()).isEqualTo(expectedResult); + } + + @ParameterizedTest + @CsvSource({ "http://localhost:8080/root, http://localhost:8080/other/api", + "http://localhost:8080/root, http://otherhost/api", + "http://localhost:8080/root, http://localhost:9090/root/api" }) + void testAbsoluteUriNotMatchingBase(String baseUrl, String endpoint) { + assertThatThrownBy(() -> Utils.resolveUri(URI.create(baseUrl), endpoint)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("does not match the base URL"); + } + } \ No newline at end of file diff --git a/pom.xml b/pom.xml index ff485b75d..9be256ccf 100644 --- a/pom.xml +++ b/pom.xml @@ -60,8 +60,9 @@ 3.26.3 5.10.2 - 5.11.0 + 5.17.0 1.20.4 + 1.17.5 2.0.16 1.5.15 @@ -356,4 +357,4 @@ - \ No newline at end of file + From e4091f458a28e31f87a517f411fe9d18811027a6 Mon Sep 17 00:00:00 2001 From: "jie.bao" Date: Fri, 18 Apr 2025 09:34:55 +0800 Subject: [PATCH 026/249] feat(completion): fix the schema about CompleteResult /** * The server's response to a completion/complete request */ export interface CompleteResult extends Result { completion: { /** * An array of completion values. Must not exceed 100 items. */ values: string[]; /** * The total number of completion options available. This can exceed the number of values actually sent in the response. */ total?: number; /** * Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown. */ hasMore?: boolean; }; } --- mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index e7e338030..e77edb3b7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -1239,7 +1239,7 @@ public record CompleteArgument( @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - public record CompleteResult(@JsonProperty("values") CompleteCompletion completion) { // @formatter:off + public record CompleteResult(@JsonProperty("completion") CompleteCompletion completion) { // @formatter:off public record CompleteCompletion( @JsonProperty("values") List values, From 41c6bd9af09462a87064dc035d5e123d7f1eae58 Mon Sep 17 00:00:00 2001 From: JermaineHua Date: Thu, 17 Apr 2025 22:25:00 +0800 Subject: [PATCH 027/249] Fix method not found error msg for server Signed-off-by: JermaineHua --- .../io/modelcontextprotocol/spec/McpClientSession.java | 2 +- .../io/modelcontextprotocol/spec/McpServerSession.java | 10 ++-------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index c1f42e3fb..9ed0d8edd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -178,7 +178,7 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR record MethodNotFoundError(String method, String message, Object data) { } - public static MethodNotFoundError getMethodNotFoundError(String method) { + private MethodNotFoundError getMethodNotFoundError(String method) { switch (method) { case McpSchema.METHOD_ROOTS_LIST: return new MethodNotFoundError(method, "Roots not supported", diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 46c356cdd..64315095b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -257,14 +257,8 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti record MethodNotFoundError(String method, String message, Object data) { } - static MethodNotFoundError getMethodNotFoundError(String method) { - switch (method) { - case McpSchema.METHOD_ROOTS_LIST: - return new MethodNotFoundError(method, "Roots not supported", - Map.of("reason", "Client does not have roots capability")); - default: - return new MethodNotFoundError(method, "Method not found: " + method, null); - } + private MethodNotFoundError getMethodNotFoundError(String method) { + return new MethodNotFoundError(method, "Method not found: " + method, null); } @Override From 04046ca05b6b90f9a6ec2f40236c69470b878fe6 Mon Sep 17 00:00:00 2001 From: JermaineHua Date: Wed, 16 Apr 2025 23:10:59 +0800 Subject: [PATCH 028/249] Optimize client nested streams in McpClientSession (#33) Signed-off-by: JermaineHua --- .../spec/McpClientSession.java | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 9ed0d8edd..a25f38c5c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -122,7 +122,12 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, // Observation associated with the individual message - it can be used to // create child Observation and emit it together with the message to the // consumer - this.connection = this.transport.connect(mono -> mono.doOnNext(message -> { + this.connection = this.transport.connect(mono -> mono.doOnNext(message -> handle(message).subscribe())) + .subscribe(); + } + + public Mono handle(McpSchema.JSONRPCMessage message) { + return Mono.defer(() -> { if (message instanceof McpSchema.JSONRPCResponse response) { logger.debug("Received Response: {}", response); var sink = pendingResponses.remove(response.id()); @@ -132,23 +137,27 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, else { sink.success(response); } + return Mono.empty(); } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); - handleIncomingRequest(request).subscribe(response -> transport.sendMessage(response).subscribe(), - error -> { - var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError( - McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); - transport.sendMessage(errorResponse).subscribe(); - }); + return handleIncomingRequest(request).onErrorResume(error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)); + return this.transport.sendMessage(errorResponse).then(Mono.empty()); + }).flatMap(this.transport::sendMessage); } else if (message instanceof McpSchema.JSONRPCNotification notification) { logger.debug("Received notification: {}", notification); - handleIncomingNotification(notification).subscribe(null, - error -> logger.error("Error handling notification: {}", error.getMessage())); + return handleIncomingNotification(notification) + .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); } - })).subscribe(); + else { + logger.warn("Received unknown message type: {}", message); + return Mono.empty(); + } + }); } /** From 866732c3833e863ea145c6e1dfa32b9d089211e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 23 Apr 2025 11:06:26 +0200 Subject: [PATCH 029/249] Polish #33 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../spec/McpClientSession.java | 58 +++++++++---------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index a25f38c5c..6eca34757 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -122,42 +122,38 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, // Observation associated with the individual message - it can be used to // create child Observation and emit it together with the message to the // consumer - this.connection = this.transport.connect(mono -> mono.doOnNext(message -> handle(message).subscribe())) - .subscribe(); + this.connection = this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe(); } - public Mono handle(McpSchema.JSONRPCMessage message) { - return Mono.defer(() -> { - if (message instanceof McpSchema.JSONRPCResponse response) { - logger.debug("Received Response: {}", response); - var sink = pendingResponses.remove(response.id()); - if (sink == null) { - logger.warn("Unexpected response for unknown id {}", response.id()); - } - else { - sink.success(response); - } - return Mono.empty(); - } - else if (message instanceof McpSchema.JSONRPCRequest request) { - logger.debug("Received request: {}", request); - return handleIncomingRequest(request).onErrorResume(error -> { - var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)); - return this.transport.sendMessage(errorResponse).then(Mono.empty()); - }).flatMap(this.transport::sendMessage); - } - else if (message instanceof McpSchema.JSONRPCNotification notification) { - logger.debug("Received notification: {}", notification); - return handleIncomingNotification(notification) - .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); + private void handle(McpSchema.JSONRPCMessage message) { + if (message instanceof McpSchema.JSONRPCResponse response) { + logger.debug("Received Response: {}", response); + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); } else { - logger.warn("Received unknown message type: {}", message); - return Mono.empty(); + sink.success(response); } - }); + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + logger.debug("Received request: {}", request); + handleIncomingRequest(request).onErrorResume(error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)); + return this.transport.sendMessage(errorResponse).then(Mono.empty()); + }).flatMap(this.transport::sendMessage).subscribe(); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + logger.debug("Received notification: {}", notification); + handleIncomingNotification(notification) + .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())) + .subscribe(); + } + else { + logger.warn("Received unknown message type: {}", message); + } } /** From 86e3e9048f53b706849a2a58a11aae70c3a1f391 Mon Sep 17 00:00:00 2001 From: jito Date: Wed, 23 Apr 2025 23:27:10 +0900 Subject: [PATCH 030/249] Fix typo in WebFluxSseIntegrationTests (#142) Signed-off-by: jitokim From f70b98b4b4160ea590a0c845ee3e2a7357bdcae9 Mon Sep 17 00:00:00 2001 From: Richie Caputo <43445060+arcaputo3@users.noreply.github.com> Date: Wed, 23 Apr 2025 10:47:48 -0400 Subject: [PATCH 031/249] feat(schema): add support for JSON Schema $defs and definitions (#146) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added support for $defs and definitions properties in JsonSchema record to handle JSON Schema references properly. Added tests to verify both formats work correctly. The JsonSchema test approach uses serialization/deserialization round-trip validation instead of property-by-property assertions. This makes tests more maintainable and less likely to break when new properties are added. 🤖 Generated with [Claude Code](https://claude.ai/code) --------- Co-authored-by: Claude --- .../modelcontextprotocol/spec/McpSchema.java | 4 +- .../spec/McpSchemaTests.java | 129 ++++++++++++++++++ 2 files changed, 132 insertions(+), 1 deletion(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index e77edb3b7..8df8a1584 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -703,7 +703,9 @@ public record JsonSchema( // @formatter:off @JsonProperty("type") String type, @JsonProperty("properties") Map properties, @JsonProperty("required") List required, - @JsonProperty("additionalProperties") Boolean additionalProperties) { + @JsonProperty("additionalProperties") Boolean additionalProperties, + @JsonProperty("$defs") Map defs, + @JsonProperty("definitions") Map definitions) { } // @formatter:on /** diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index a41fc095f..ff78c1bfc 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -9,6 +9,7 @@ import java.util.List; import java.util.Map; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; @@ -449,6 +450,92 @@ void testGetPromptResult() throws Exception { // Tool Tests + @Test + void testJsonSchema() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "address": { + "$ref": "#/$defs/Address" + } + }, + "required": ["name"], + "$defs": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + } + } + """; + + // Deserialize the original string to a JsonSchema object + McpSchema.JsonSchema schema = mapper.readValue(schemaJson, McpSchema.JsonSchema.class); + + // Serialize the object back to a string + String serialized = mapper.writeValueAsString(schema); + + // Deserialize again + McpSchema.JsonSchema deserialized = mapper.readValue(serialized, McpSchema.JsonSchema.class); + + // Serialize one more time and compare with the first serialization + String serializedAgain = mapper.writeValueAsString(deserialized); + + // The two serialized strings should be the same + assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); + } + + @Test + void testJsonSchemaWithDefinitions() throws Exception { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "address": { + "$ref": "#/definitions/Address" + } + }, + "required": ["name"], + "definitions": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + } + } + """; + + // Deserialize the original string to a JsonSchema object + McpSchema.JsonSchema schema = mapper.readValue(schemaJson, McpSchema.JsonSchema.class); + + // Serialize the object back to a string + String serialized = mapper.writeValueAsString(schema); + + // Deserialize again + McpSchema.JsonSchema deserialized = mapper.readValue(serialized, McpSchema.JsonSchema.class); + + // Serialize one more time and compare with the first serialization + String serializedAgain = mapper.writeValueAsString(deserialized); + + // The two serialized strings should be the same + assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); + } + @Test void testTool() throws Exception { String schemaJson = """ @@ -477,6 +564,48 @@ void testTool() throws Exception { {"name":"test-tool","description":"A test tool","inputSchema":{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"number"}},"required":["name"]}}""")); } + @Test + void testToolWithComplexSchema() throws Exception { + String complexSchemaJson = """ + { + "type": "object", + "$defs": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + }, + "properties": { + "name": {"type": "string"}, + "shippingAddress": {"$ref": "#/$defs/Address"} + }, + "required": ["name", "shippingAddress"] + } + """; + + McpSchema.Tool tool = new McpSchema.Tool("addressTool", "Handles addresses", complexSchemaJson); + + // Serialize the tool to a string + String serialized = mapper.writeValueAsString(tool); + + // Deserialize back to a Tool object + McpSchema.Tool deserializedTool = mapper.readValue(serialized, McpSchema.Tool.class); + + // Serialize again and compare with first serialization + String serializedAgain = mapper.writeValueAsString(deserializedTool); + + // The two serialized strings should be the same + assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); + + // Just verify the basic structure was preserved + assertThat(deserializedTool.inputSchema().defs()).isNotNull(); + assertThat(deserializedTool.inputSchema().defs()).containsKey("Address"); + } + @Test void testCallToolRequest() throws Exception { Map arguments = new HashMap<>(); From 9c92a2b8bffe41f4c6df27ca1977bc8ee8343137 Mon Sep 17 00:00:00 2001 From: wangzhi <1277975348@qq.com> Date: Wed, 23 Apr 2025 23:03:10 +0800 Subject: [PATCH 032/249] Fix javadoc references and formatting (#149) --- .../server/AbstractMcpAsyncServerTests.java | 2 +- .../server/AbstractMcpSyncServerTests.java | 2 +- .../java/io/modelcontextprotocol/client/McpAsyncClient.java | 4 ++-- .../java/io/modelcontextprotocol/spec/McpServerSession.java | 3 ++- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index cdd43e7ef..025cfeacf 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -30,7 +30,7 @@ /** * Test suite for the {@link McpAsyncServer} that can be used with different - * {@link McpTransportProvider} implementations. + * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations. * * @author Christian Tzolov */ diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index c81e638c1..e313454bd 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -27,7 +27,7 @@ /** * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpTransportProvider} implementations. + * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations. * * @author Christian Tzolov */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 2bc74f258..e3a997ba3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -317,9 +317,9 @@ public Mono closeGracefully() { * The client MUST initiate this phase by sending an initialize request containing: * The protocol version the client supports, client's capabilities and clients * implementation information. - *

+ *

* The server MUST respond with its own capabilities and information. - *

+ *

* After successful initialization, the client MUST send an initialized notification * to indicate it is ready to begin normal operations. * @return the initialize result. diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 64315095b..86906d859 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -64,7 +64,8 @@ public class McpServerSession implements McpSession { * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the * server * @param initNotificationHandler called when a - * {@link McpSchema.METHOD_NOTIFICATION_INITIALIZED} is received. + * {@link io.modelcontextprotocol.spec.McpSchema#METHOD_NOTIFICATION_INITIALIZED} is + * received. * @param requestHandlers map of request handlers to use * @param notificationHandlers map of notification handlers to use */ From 261554bb7f1cc630aefeb5487434c1740a72b856 Mon Sep 17 00:00:00 2001 From: Francis Hodianto <61911161+FH-30@users.noreply.github.com> Date: Wed, 23 Apr 2025 23:09:41 +0800 Subject: [PATCH 033/249] fix: propagate Reactor Context into client transport chain (#154) --- .../java/io/modelcontextprotocol/spec/McpClientSession.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 6eca34757..f577b493a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -230,18 +230,19 @@ private String generateRequestId() { public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { String requestId = this.generateRequestId(); - return Mono.create(sink -> { + return Mono.deferContextual(ctx -> Mono.create(sink -> { this.pendingResponses.put(requestId, sink); McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, requestId, requestParams); this.transport.sendMessage(jsonrpcRequest) + .contextWrite(ctx) // TODO: It's most efficient to create a dedicated Subscriber here .subscribe(v -> { }, error -> { this.pendingResponses.remove(requestId); sink.error(error); }); - }).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { + })).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { if (jsonRpcResponse.error() != null) { logger.error("Error handling request: {}", jsonRpcResponse.error()); sink.error(new McpError(jsonRpcResponse.error())); From e610d853f922e36ba474b2240f5c6546166e4840 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 20 Apr 2025 10:58:51 +0300 Subject: [PATCH 034/249] feat: Add customizable URI template manager factory to MCP server Implement URI template functionality for MCP resources, allowing dynamic resource URIs with variables in the format {variableName}. - Enable resource URIs with variable placeholders (e.g., "/api/users/{userId}") - Automatic extraction of variable values from request URIs - Validation of template arguments in completions - Matching of request URIs against templates - Add new URI template management interfaces and implementations - Enhanced resource template listing to include templated resources - Updated resource request handling to support template matching - Test coverage for URI template functionality - Adding a configurable uriTemplateManagerFactory field to both AsyncSpecification and SyncSpecification classes - Adding builder methods to allow setting a custom URI template manager factory - Modifying constructors to pass the URI template manager factory to the server implementation - Updating the server implementation to use the provided factory - Add bulk registration methods for async completions Signed-off-by: Christian Tzolov --- .../WebFluxSseIntegrationTests.java | 3 +- .../server/McpAsyncServer.java | 77 +++++++-- .../server/McpServer.java | 68 +++++++- .../DeafaultMcpUriTemplateManagerFactory.java | 23 +++ .../util/DefaultMcpUriTemplateManager.java | 163 ++++++++++++++++++ .../util/McpUriTemplateManager.java | 52 ++++++ .../util/McpUriTemplateManagerFactory.java | 22 +++ .../McpUriTemplateManagerTests.java | 97 +++++++++++ 8 files changed, 489 insertions(+), 16 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 660f814da..2ba047461 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -776,7 +776,8 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().completions().build()) .prompts(new McpServerFeatures.SyncPromptSpecification( - new Prompt("code_review", "this is code review prompt", List.of()), + new Prompt("code_review", "this is code review prompt", + List.of(new PromptArgument("language", "string", false))), (mcpSyncServerExchange, getPromptRequest) -> null)) .completions(new McpServerFeatures.SyncCompletionSpecification( new McpSchema.PromptReference("ref/prompt", "code_review"), completionHandler)) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 906cb9a08..3c112ad76 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.server; import java.time.Duration; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -22,10 +23,13 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -92,8 +96,10 @@ public class McpAsyncServer { * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization */ McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features, Duration requestTimeout) { - this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, requestTimeout, features); + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, requestTimeout, features, + uriTemplateManagerFactory); } /** @@ -274,8 +280,11 @@ private static class AsyncServerImpl extends McpAsyncServer { private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - Duration requestTimeout, McpServerFeatures.Async features) { + Duration requestTimeout, McpServerFeatures.Async features, + McpUriTemplateManagerFactory uriTemplateManagerFactory) { this.mcpTransportProvider = mcpTransportProvider; this.objectMapper = objectMapper; this.serverInfo = features.serverInfo(); @@ -286,6 +295,7 @@ private static class AsyncServerImpl extends McpAsyncServer { this.resourceTemplates.addAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; Map> requestHandlers = new HashMap<>(); @@ -564,8 +574,26 @@ private McpServerSession.RequestHandler resources private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { return (exchange, params) -> Mono - .just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); + .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); + + } + private List getResourceTemplates() { + var list = new ArrayList<>(this.resourceTemplates); + List resourceTemplates = this.resources.keySet() + .stream() + .filter(uri -> uri.contains("{")) + .map(uri -> { + var resource = this.resources.get(uri).resource(); + var template = new McpSchema.ResourceTemplate(resource.uri(), resource.name(), + resource.description(), resource.mimeType(), resource.annotations()); + return template; + }) + .toList(); + + list.addAll(resourceTemplates); + + return list; } private McpServerSession.RequestHandler resourcesReadRequestHandler() { @@ -574,11 +602,16 @@ private McpServerSession.RequestHandler resourcesR new TypeReference() { }); var resourceUri = resourceRequest.uri(); - McpServerFeatures.AsyncResourceSpecification specification = this.resources.get(resourceUri); - if (specification != null) { - return specification.readHandler().apply(exchange, resourceRequest); - } - return Mono.error(new McpError("Resource not found: " + resourceUri)); + + McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() + .stream() + .filter(resourceSpecification -> this.uriTemplateManagerFactory + .create(resourceSpecification.resource().uri()) + .matches(resourceUri)) + .findFirst() + .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + + return specification.readHandler().apply(exchange, resourceRequest); }; } @@ -729,20 +762,38 @@ private McpServerSession.RequestHandler completionComp String type = request.ref().type(); + String argumentName = request.argument().name(); + // check if the referenced resource exists if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { - McpServerFeatures.AsyncPromptSpecification prompt = this.prompts.get(promptReference.name()); - if (prompt == null) { + McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); + if (promptSpec == null) { return Mono.error(new McpError("Prompt not found: " + promptReference.name())); } + if (!promptSpec.prompt() + .arguments() + .stream() + .filter(arg -> arg.name().equals(argumentName)) + .findFirst() + .isPresent()) { + + return Mono.error(new McpError("Argument not found: " + argumentName)); + } } if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpServerFeatures.AsyncResourceSpecification resource = this.resources.get(resourceReference.uri()); - if (resource == null) { + McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources + .get(resourceReference.uri()); + if (resourceSpec == null) { return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); } + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + return Mono.error(new McpError("Argument not found: " + argumentName)); + } + } McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 84089703c..d6ec2cc30 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import reactor.core.publisher.Mono; /** @@ -156,6 +158,8 @@ class AsyncSpecification { private final McpServerTransportProvider transportProvider; + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + private ObjectMapper objectMapper; private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -204,6 +208,19 @@ private AsyncSpecification(McpServerTransportProvider transportProvider) { this.transportProvider = transportProvider; } + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + /** * Sets the duration to wait for server responses before timing out requests. This * timeout applies to all requests made through the client, including tool calls, @@ -517,6 +534,36 @@ public AsyncSpecification prompts(McpServerFeatures.AsyncPromptSpecification... return this; } + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public AsyncSpecification completions(List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public AsyncSpecification completions(McpServerFeatures.AsyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + /** * Registers a consumer that will be notified when the list of roots changes. This * is useful for updating resource availability dynamically, such as when new @@ -587,7 +634,8 @@ public McpAsyncServer build() { this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, this.instructions); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout); + return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, + this.uriTemplateManagerFactory); } } @@ -600,6 +648,8 @@ class SyncSpecification { private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", "1.0.0"); + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + private final McpServerTransportProvider transportProvider; private ObjectMapper objectMapper; @@ -650,6 +700,19 @@ private SyncSpecification(McpServerTransportProvider transportProvider) { this.transportProvider = transportProvider; } + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + /** * Sets the duration to wait for server responses before timing out requests. This * timeout applies to all requests made through the client, including tool calls, @@ -1064,7 +1127,8 @@ public McpSyncServer build() { this.rootsChangeHandlers, this.instructions); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout); + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, + this.uriTemplateManagerFactory); return new McpSyncServer(asyncServer); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java b/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java new file mode 100644 index 000000000..3870b76fc --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java @@ -0,0 +1,23 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.util; + +/** + * @author Christian Tzolov + */ +public class DeafaultMcpUriTemplateManagerFactory implements McpUriTemplateManagerFactory { + + /** + * Creates a new instance of {@link McpUriTemplateManager} with the specified URI + * template. + * @param uriTemplate The URI template to be used for variable extraction + * @return A new instance of {@link McpUriTemplateManager} + * @throws IllegalArgumentException if the URI template is null or empty + */ + @Override + public McpUriTemplateManager create(String uriTemplate) { + return new DefaultMcpUriTemplateManager(uriTemplate); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java b/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java new file mode 100644 index 000000000..b2e9a5285 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java @@ -0,0 +1,163 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Default implementation of the UriTemplateUtils interface. + *

+ * This class provides methods for extracting variables from URI templates and matching + * them against actual URIs. + * + * @author Christian Tzolov + */ +public class DefaultMcpUriTemplateManager implements McpUriTemplateManager { + + /** + * Pattern to match URI variables in the format {variableName}. + */ + private static final Pattern URI_VARIABLE_PATTERN = Pattern.compile("\\{([^/]+?)\\}"); + + private final String uriTemplate; + + /** + * Constructor for DefaultMcpUriTemplateManager. + * @param uriTemplate The URI template to be used for variable extraction + */ + public DefaultMcpUriTemplateManager(String uriTemplate) { + if (uriTemplate == null || uriTemplate.isEmpty()) { + throw new IllegalArgumentException("URI template must not be null or empty"); + } + this.uriTemplate = uriTemplate; + } + + /** + * Extract URI variable names from a URI template. + * @param uriTemplate The URI template containing variables in the format + * {variableName} + * @return A list of variable names extracted from the template + * @throws IllegalArgumentException if duplicate variable names are found + */ + @Override + public List getVariableNames() { + if (uriTemplate == null || uriTemplate.isEmpty()) { + return List.of(); + } + + List variables = new ArrayList<>(); + Matcher matcher = URI_VARIABLE_PATTERN.matcher(this.uriTemplate); + + while (matcher.find()) { + String variableName = matcher.group(1); + if (variables.contains(variableName)) { + throw new IllegalArgumentException("Duplicate URI variable name in template: " + variableName); + } + variables.add(variableName); + } + + return variables; + } + + /** + * Extract URI variable values from the actual request URI. + *

+ * This method converts the URI template into a regex pattern, then uses that pattern + * to extract variable values from the request URI. + * @param requestUri The actual URI from the request + * @return A map of variable names to their values + * @throws IllegalArgumentException if the URI template is invalid or the request URI + * doesn't match the template pattern + */ + @Override + public Map extractVariableValues(String requestUri) { + Map variableValues = new HashMap<>(); + List uriVariables = this.getVariableNames(); + + if (requestUri == null || uriVariables.isEmpty()) { + return variableValues; + } + + try { + // Create a regex pattern by replacing each {variableName} with a capturing + // group + StringBuilder patternBuilder = new StringBuilder("^"); + + // Find all variable placeholders and their positions + Matcher variableMatcher = URI_VARIABLE_PATTERN.matcher(uriTemplate); + int lastEnd = 0; + + while (variableMatcher.find()) { + // Add the text between the last variable and this one, escaped for regex + String textBefore = uriTemplate.substring(lastEnd, variableMatcher.start()); + patternBuilder.append(Pattern.quote(textBefore)); + + // Add a capturing group for the variable + patternBuilder.append("([^/]+)"); + + lastEnd = variableMatcher.end(); + } + + // Add any remaining text after the last variable + if (lastEnd < uriTemplate.length()) { + patternBuilder.append(Pattern.quote(uriTemplate.substring(lastEnd))); + } + + patternBuilder.append("$"); + + // Compile the pattern and match against the request URI + Pattern pattern = Pattern.compile(patternBuilder.toString()); + Matcher matcher = pattern.matcher(requestUri); + + if (matcher.find() && matcher.groupCount() == uriVariables.size()) { + for (int i = 0; i < uriVariables.size(); i++) { + String value = matcher.group(i + 1); + if (value == null || value.isEmpty()) { + throw new IllegalArgumentException( + "Empty value for URI variable '" + uriVariables.get(i) + "' in URI: " + requestUri); + } + variableValues.put(uriVariables.get(i), value); + } + } + } + catch (Exception e) { + throw new IllegalArgumentException("Error parsing URI template: " + uriTemplate + " for URI: " + requestUri, + e); + } + + return variableValues; + } + + /** + * Check if a URI matches the uriTemplate with variables. + * @param uri The URI to check + * @return true if the URI matches the pattern, false otherwise + */ + @Override + public boolean matches(String uri) { + // If the uriTemplate doesn't contain variables, do a direct comparison + if (!this.isUriTemplate(this.uriTemplate)) { + return uri.equals(this.uriTemplate); + } + + // Convert the pattern to a regex + String regex = this.uriTemplate.replaceAll("\\{[^/]+?\\}", "([^/]+?)"); + regex = regex.replace("/", "\\/"); + + // Check if the URI matches the regex + return Pattern.compile(regex).matcher(uri).matches(); + } + + @Override + public boolean isUriTemplate(String uri) { + return URI_VARIABLE_PATTERN.matcher(uri).find(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java new file mode 100644 index 000000000..19569e49f --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java @@ -0,0 +1,52 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.List; +import java.util.Map; + +/** + * Interface for working with URI templates. + *

+ * This interface provides methods for extracting variables from URI templates and + * matching them against actual URIs. + * + * @author Christian Tzolov + */ +public interface McpUriTemplateManager { + + /** + * Extract URI variable names from this URI template. + * @return A list of variable names extracted from the template + * @throws IllegalArgumentException if duplicate variable names are found + */ + List getVariableNames(); + + /** + * Extract URI variable values from the actual request URI. + *

+ * This method converts the URI template into a regex pattern, then uses that pattern + * to extract variable values from the request URI. + * @param uri The actual URI from the request + * @return A map of variable names to their values + * @throws IllegalArgumentException if the URI template is invalid or the request URI + * doesn't match the template pattern + */ + Map extractVariableValues(String uri); + + /** + * Indicate whether the given URI matches this template. + * @param uri the URI to match to + * @return {@code true} if it matches; {@code false} otherwise + */ + boolean matches(String uri); + + /** + * Check if the given URI is a URI template. + * @return Returns true if the URI contains variables in the format {variableName} + */ + public boolean isUriTemplate(String uri); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java new file mode 100644 index 000000000..9644f9a6c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java @@ -0,0 +1,22 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.util; + +/** + * Factory interface for creating instances of {@link McpUriTemplateManager}. + * + * @author Christian Tzolov + */ +public interface McpUriTemplateManagerFactory { + + /** + * Creates a new instance of {@link McpUriTemplateManager} with the specified URI + * template. + * @param uriTemplate The URI template to be used for variable extraction + * @return A new instance of {@link McpUriTemplateManager} + * @throws IllegalArgumentException if the URI template is null or empty + */ + McpUriTemplateManager create(String uriTemplate); + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java b/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java new file mode 100644 index 000000000..6f041daa6 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java @@ -0,0 +1,97 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManager; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests for {@link McpUriTemplateManager} and its implementations. + * + * @author Christian Tzolov + */ +public class McpUriTemplateManagerTests { + + private McpUriTemplateManagerFactory uriTemplateFactory; + + @BeforeEach + void setUp() { + this.uriTemplateFactory = new DeafaultMcpUriTemplateManagerFactory(); + } + + @Test + void shouldExtractVariableNamesFromTemplate() { + List variables = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}") + .getVariableNames(); + assertEquals(2, variables.size()); + assertEquals("userId", variables.get(0)); + assertEquals("postId", variables.get(1)); + } + + @Test + void shouldReturnEmptyListWhenTemplateHasNoVariables() { + List variables = this.uriTemplateFactory.create("/api/users/all").getVariableNames(); + assertEquals(0, variables.size()); + } + + @Test + void shouldThrowExceptionWhenExtractingVariablesFromNullTemplate() { + assertThrows(IllegalArgumentException.class, () -> this.uriTemplateFactory.create(null).getVariableNames()); + } + + @Test + void shouldThrowExceptionWhenExtractingVariablesFromEmptyTemplate() { + assertThrows(IllegalArgumentException.class, () -> this.uriTemplateFactory.create("").getVariableNames()); + } + + @Test + void shouldThrowExceptionWhenTemplateContainsDuplicateVariables() { + assertThrows(IllegalArgumentException.class, + () -> this.uriTemplateFactory.create("/api/users/{userId}/posts/{userId}").getVariableNames()); + } + + @Test + void shouldExtractVariableValuesFromRequestUri() { + Map values = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}") + .extractVariableValues("/api/users/123/posts/456"); + assertEquals(2, values.size()); + assertEquals("123", values.get("userId")); + assertEquals("456", values.get("postId")); + } + + @Test + void shouldReturnEmptyMapWhenTemplateHasNoVariables() { + Map values = this.uriTemplateFactory.create("/api/users/all") + .extractVariableValues("/api/users/all"); + assertEquals(0, values.size()); + } + + @Test + void shouldReturnEmptyMapWhenRequestUriIsNull() { + Map values = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}") + .extractVariableValues(null); + assertEquals(0, values.size()); + } + + @Test + void shouldMatchUriAgainstTemplatePattern() { + var uriTemplateManager = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}"); + + assertTrue(uriTemplateManager.matches("/api/users/123/posts/456")); + assertFalse(uriTemplateManager.matches("/api/users/123/comments/456")); + } + +} From e34babbe56b730514d35191118d5e66bc9c51b9a Mon Sep 17 00:00:00 2001 From: jito Date: Thu, 8 May 2025 18:31:17 +0900 Subject: [PATCH 035/249] Add missing isInitialized method to McpSyncClient (#181) The isInitialized method is present in McpAsyncClient and needs to be mirrored in McpSyncClient. Signed-off-by: jitokim --- .../io/modelcontextprotocol/client/McpSyncClient.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index c91638a7e..a8fb979e1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -97,6 +97,14 @@ public McpSchema.Implementation getServerInfo() { return this.delegate.getServerInfo(); } + /** + * Check if the client-server connection is initialized. + * @return true if the client-server connection is initialized + */ + public boolean isInitialized() { + return this.delegate.isInitialized(); + } + /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities From eae3840e7d44932c60c131cb7a346b5367b788ff Mon Sep 17 00:00:00 2001 From: Dennis Kawurek Date: Fri, 25 Apr 2025 18:47:54 +0200 Subject: [PATCH 036/249] fix: Mockito inline mocking for Java 21+ (#207) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before this fix the execution of the maven surefire plugin with Java 21 logged warnings that mockito should be added as a java agent, because the self-attaching won't be supported in future java releases. In Java 24 the test just broke. This problem is solved by modifying the pom.xml of the parent and doing this changes: * Adding mockito as a java agent. * Removing the surefireArgLine from the properties. This can be added back when it's needed (for example when JaCoCo will be used). Furthermore, the pom.xml in the mcp-spring-* modules now have the byte-buddy dependency included, as the test would otherwise break when trying to mock McpSchema#CreateMessageRequest. Fixes #187 Co-authored-by: Dariusz Jędrzejczyk --- mcp-spring/mcp-spring-webflux/pom.xml | 6 ++++++ mcp-spring/mcp-spring-webmvc/pom.xml | 6 ++++++ pom.xml | 15 +++++++++++++-- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 63c32a8a8..86f46bf95 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -82,6 +82,12 @@ ${mockito.version} test + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + io.projectreactor reactor-test diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index b59be6a03..82fbbf3e6 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -77,6 +77,12 @@ ${mockito.version} test + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + org.testcontainers junit-jupiter diff --git a/pom.xml b/pom.xml index 9be256ccf..638457406 100644 --- a/pom.xml +++ b/pom.xml @@ -57,6 +57,7 @@ 17 17 17 + 3.26.3 5.10.2 @@ -163,13 +164,23 @@ + + org.apache.maven.plugins + maven-dependency-plugin + + + + properties + + + + org.apache.maven.plugins maven-surefire-plugin ${maven-surefire-plugin.version} - ${surefireArgLine} - + ${surefireArgLine} -javaagent:${org.mockito:mockito-core:jar} false false From 0069c977ef88b91162b08899bb8040a0ffcb8653 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Fri, 9 May 2025 12:57:36 +0200 Subject: [PATCH 037/249] Remove temporary delegate impl from McpAsyncServer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../server/McpAsyncServer.java | 1082 ++++++++--------- 1 file changed, 484 insertions(+), 598 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 3c112ad76..1efa13de3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -82,11 +82,33 @@ public class McpAsyncServer { private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); - private final McpAsyncServer delegate; + private final McpServerTransportProvider mcpTransportProvider; - McpAsyncServer() { - this.delegate = null; - } + private final ObjectMapper objectMapper; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final String instructions; + + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + // FIXME: this field is deprecated and should be remvoed together with the + // broadcasting loggingNotification. + private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + + private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); /** * Create a new McpAsyncServer with the given transport provider and capabilities. @@ -98,8 +120,104 @@ public class McpAsyncServer { McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, McpServerFeatures.Async features, Duration requestTimeout, McpUriTemplateManagerFactory uriTemplateManagerFactory) { - this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, requestTimeout, features, - uriTemplateManagerFactory); + this.mcpTransportProvider = mcpTransportProvider; + this.objectMapper = objectMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); + this.tools.addAll(features.tools()); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add logging API handlers if the logging capability is enabled + if (this.serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); + } + + // Add completion API handlers if the completion capability is enabled + if (this.serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); + } + + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + + List, Mono>> rootsChangeConsumers = features + .rootsChangeConsumers(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger + .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + + mcpTransportProvider.setSessionFactory( + transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private Mono asyncInitializeRequestHandler( + McpSchema.InitializeRequest initializeRequest) { + return Mono.defer(() -> { + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, this.instructions)); + }); } /** @@ -107,7 +225,7 @@ public class McpAsyncServer { * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - return this.delegate.getServerCapabilities(); + return this.serverCapabilities; } /** @@ -115,7 +233,7 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - return this.delegate.getServerInfo(); + return this.serverInfo; } /** @@ -123,26 +241,66 @@ public McpSchema.Implementation getServerInfo() { * @return A Mono that completes when the server has been closed */ public Mono closeGracefully() { - return this.delegate.closeGracefully(); + return this.mcpTransportProvider.closeGracefully(); } /** * Close the server immediately. */ public void close() { - this.delegate.close(); + this.mcpTransportProvider.close(); + } + + private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( + List, Mono>> rootsChangeConsumers) { + return (exchange, params) -> exchange.listRoots() + .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); } // --------------------------------------- // Tool Management // --------------------------------------- + /** * Add a new tool specification at runtime. * @param toolSpecification The tool specification to add * @return Mono that completes when clients have been notified of the change */ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { - return this.delegate.addTool(toolSpecification); + if (toolSpecification == null) { + return Mono.error(new McpError("Tool specification must not be null")); + } + if (toolSpecification.tool() == null) { + return Mono.error(new McpError("Tool must not be null")); + } + if (toolSpecification.call() == null) { + return Mono.error(new McpError("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + // Check for duplicate tool names + if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) { + return Mono + .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists")); + } + + this.tools.add(toolSpecification); + logger.debug("Added tool handler: {}", toolSpecification.tool().name()); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); } /** @@ -151,7 +309,25 @@ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecifica * @return Mono that completes when clients have been notified of the change */ public Mono removeTool(String toolName) { - return this.delegate.removeTool(toolName); + if (toolName == null) { + return Mono.error(new McpError("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + boolean removed = this.tools + .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); + if (removed) { + logger.debug("Removed tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + }); } /** @@ -159,19 +335,65 @@ public Mono removeTool(String toolName) { * @return A Mono that completes when all clients have been notified */ public Mono notifyToolsListChanged() { - return this.delegate.notifyToolsListChanged(); + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler toolsListRequestHandler() { + return (exchange, params) -> { + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private McpServerSession.RequestHandler toolsCallRequestHandler() { + return (exchange, params) -> { + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolSpecification.isEmpty()) { + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + } + + return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + }; } // --------------------------------------- // Resource Management // --------------------------------------- + /** * Add a new resource handler at runtime. - * @param resourceHandler The resource handler to add + * @param resourceSpecification The resource handler to add * @return Mono that completes when clients have been notified of the change */ - public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceHandler) { - return this.delegate.addResource(resourceHandler); + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { + if (resourceSpecification == null || resourceSpecification.resource() == null) { + return Mono.error(new McpError("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { + return Mono.error(new McpError( + "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); + } + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); } /** @@ -180,7 +402,24 @@ public Mono addResource(McpServerFeatures.AsyncResourceSpecification resou * @return Mono that completes when clients have been notified of the change */ public Mono removeResource(String resourceUri) { - return this.delegate.removeResource(resourceUri); + if (resourceUri == null) { + return Mono.error(new McpError("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + }); } /** @@ -188,19 +427,97 @@ public Mono removeResource(String resourceUri) { * @return A Mono that completes when all clients have been notified */ public Mono notifyResourcesListChanged() { - return this.delegate.notifyResourcesListChanged(); + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler resourcesListRequestHandler() { + return (exchange, params) -> { + var resourceList = this.resources.values() + .stream() + .map(McpServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { + return (exchange, params) -> Mono + .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); + + } + + private List getResourceTemplates() { + var list = new ArrayList<>(this.resourceTemplates); + List resourceTemplates = this.resources.keySet() + .stream() + .filter(uri -> uri.contains("{")) + .map(uri -> { + var resource = this.resources.get(uri).resource(); + var template = new McpSchema.ResourceTemplate(resource.uri(), resource.name(), resource.description(), + resource.mimeType(), resource.annotations()); + return template; + }) + .toList(); + + list.addAll(resourceTemplates); + + return list; + } + + private McpServerSession.RequestHandler resourcesReadRequestHandler() { + return (exchange, params) -> { + McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + var resourceUri = resourceRequest.uri(); + + McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() + .stream() + .filter(resourceSpecification -> this.uriTemplateManagerFactory + .create(resourceSpecification.resource().uri()) + .matches(resourceUri)) + .findFirst() + .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + + return specification.readHandler().apply(exchange, resourceRequest); + }; } // --------------------------------------- // Prompt Management // --------------------------------------- + /** * Add a new prompt handler at runtime. * @param promptSpecification The prompt handler to add * @return Mono that completes when clients have been notified of the change */ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { - return this.delegate.addPrompt(promptSpecification); + if (promptSpecification == null) { + return Mono.error(new McpError("Prompt specification must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification specification = this.prompts + .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); + if (specification != null) { + return Mono.error( + new McpError("Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); + } + + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); + + // Servers that declared the listChanged capability SHOULD send a + // notification, + // when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return notifyPromptsListChanged(); + } + return Mono.empty(); + }); } /** @@ -209,7 +526,27 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpe * @return Mono that completes when clients have been notified of the change */ public Mono removePrompt(String promptName) { - return this.delegate.removePrompt(promptName); + if (promptName == null) { + return Mono.error(new McpError("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + // Servers that declared the listChanged capability SHOULD send a + // notification, when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + }); } /** @@ -217,7 +554,39 @@ public Mono removePrompt(String promptName) { * @return A Mono that completes when all clients have been notified */ public Mono notifyPromptsListChanged() { - return this.delegate.notifyPromptsListChanged(); + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler promptsListRequestHandler() { + return (exchange, params) -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private McpServerSession.RequestHandler promptsGetRequestHandler() { + return (exchange, params) -> { + McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + // Implement prompt retrieval logic here + McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { + return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + } + + return specification.promptHandler().apply(exchange, promptRequest); + }; } // --------------------------------------- @@ -237,619 +606,136 @@ public Mono notifyPromptsListChanged() { */ @Deprecated public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - return this.delegate.loggingNotification(loggingMessageNotification); - } - - // --------------------------------------- - // Sampling - // --------------------------------------- - /** - * This method is package-private and used for test only. Should not be called by user - * code. - * @param protocolVersions the Client supported protocol versions. - */ - void setProtocolVersions(List protocolVersions) { - this.delegate.setProtocolVersions(protocolVersions); - } - - private static class AsyncServerImpl extends McpAsyncServer { - - private final McpServerTransportProvider mcpTransportProvider; - - private final ObjectMapper objectMapper; - - private final McpSchema.ServerCapabilities serverCapabilities; - - private final McpSchema.Implementation serverInfo; - - private final String instructions; - - private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); - - private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); - - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); - - private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); - - // FIXME: this field is deprecated and should be remvoed together with the - // broadcasting loggingNotification. - private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; - - private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); - - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); - - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); - - AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - Duration requestTimeout, McpServerFeatures.Async features, - McpUriTemplateManagerFactory uriTemplateManagerFactory) { - this.mcpTransportProvider = mcpTransportProvider; - this.objectMapper = objectMapper; - this.serverInfo = features.serverInfo(); - this.serverCapabilities = features.serverCapabilities(); - this.instructions = features.instructions(); - this.tools.addAll(features.tools()); - this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); - this.prompts.putAll(features.prompts()); - this.completions.putAll(features.completions()); - this.uriTemplateManagerFactory = uriTemplateManagerFactory; - - Map> requestHandlers = new HashMap<>(); - - // Initialize request handlers for standard MCP methods - - // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); - - // Add tools API handlers if the tool capability is enabled - if (this.serverCapabilities.tools() != null) { - requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); - } - - // Add resources API handlers if provided - if (this.serverCapabilities.resources() != null) { - requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); - } - - // Add prompts API handlers if provider exists - if (this.serverCapabilities.prompts() != null) { - requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); - } - - // Add logging API handlers if the logging capability is enabled - if (this.serverCapabilities.logging() != null) { - requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); - } - - // Add completion API handlers if the completion capability is enabled - if (this.serverCapabilities.completions() != null) { - requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); - } - - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); - - List, Mono>> rootsChangeConsumers = features - .rootsChangeConsumers(); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((exchange, - roots) -> Mono.fromRunnable(() -> logger.warn( - "Roots list changed notification, but no consumers provided. Roots list changed: {}", - roots))); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); - } - - // --------------------------------------- - // Lifecycle Management - // --------------------------------------- - private Mono asyncInitializeRequestHandler( - McpSchema.InitializeRequest initializeRequest) { - return Mono.defer(() -> { - logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", - initializeRequest.protocolVersion(), initializeRequest.capabilities(), - initializeRequest.clientInfo()); - - // The server MUST respond with the highest protocol version it supports - // if - // it does not support the requested (e.g. Client) version. - String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); - - if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { - // If the server supports the requested protocol version, it MUST - // respond - // with the same version. - serverProtocolVersion = initializeRequest.protocolVersion(); - } - else { - logger.warn( - "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", - initializeRequest.protocolVersion(), serverProtocolVersion); - } - - return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, - this.serverInfo, this.instructions)); - }); - } - - public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; - } - - public McpSchema.Implementation getServerInfo() { - return this.serverInfo; - } - @Override - public Mono closeGracefully() { - return this.mcpTransportProvider.closeGracefully(); + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); } - @Override - public void close() { - this.mcpTransportProvider.close(); + if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { + return Mono.empty(); } - private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( - List, Mono>> rootsChangeConsumers) { - return (exchange, params) -> exchange.listRoots() - .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); - } - - // --------------------------------------- - // Tool Management - // --------------------------------------- - - @Override - public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { - if (toolSpecification == null) { - return Mono.error(new McpError("Tool specification must not be null")); - } - if (toolSpecification.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); - } - if (toolSpecification.call() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) { - return Mono - .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists")); - } - - this.tools.add(toolSpecification); - logger.debug("Added tool handler: {}", toolSpecification.tool().name()); - - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - }); - } - - @Override - public Mono removeTool(String toolName) { - if (toolName == null) { - return Mono.error(new McpError("Tool name must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, + loggingMessageNotification); + } + private McpServerSession.RequestHandler setLoggerRequestHandler() { + return (exchange, params) -> { return Mono.defer(() -> { - boolean removed = this.tools - .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); - if (removed) { - logger.debug("Removed tool handler: {}", toolName); - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); - }); - } - - @Override - public Mono notifyToolsListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); - } - - private McpServerSession.RequestHandler toolsListRequestHandler() { - return (exchange, params) -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); - - return Mono.just(new McpSchema.ListToolsResult(tools, null)); - }; - } - private McpServerSession.RequestHandler toolsCallRequestHandler() { - return (exchange, params) -> { - McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, - new TypeReference() { + SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, + new TypeReference() { }); - Optional toolSpecification = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); + exchange.setMinLoggingLevel(newMinLoggingLevel.level()); - if (toolSpecification.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); - } + // FIXME: this field is deprecated and should be removed together + // with the broadcasting loggingNotification. + this.minLoggingLevel = newMinLoggingLevel.level(); - return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); - }; - } + return Mono.just(Map.of()); + }); + }; + } - // --------------------------------------- - // Resource Management - // --------------------------------------- + private McpServerSession.RequestHandler completionCompleteRequestHandler() { + return (exchange, params) -> { + McpSchema.CompleteRequest request = parseCompletionParams(params); - @Override - public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { - if (resourceSpecification == null || resourceSpecification.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); + if (request.ref() == null) { + return Mono.error(new McpError("ref must not be null")); } - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); + if (request.ref().type() == null) { + return Mono.error(new McpError("type must not be null")); } - return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { - return Mono.error(new McpError( - "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); - } - logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - }); - } + String type = request.ref().type(); - @Override - public Mono removeResource(String resourceUri) { - if (resourceUri == null) { - return Mono.error(new McpError("Resource URI must not be null")); - } - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } + String argumentName = request.argument().name(); - return Mono.defer(() -> { - McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); - if (removed != null) { - logger.debug("Removed resource handler: {}", resourceUri); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); + // check if the referenced resource exists + if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { + McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); + if (promptSpec == null) { + return Mono.error(new McpError("Prompt not found: " + promptReference.name())); } - return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); - }); - } - - @Override - public Mono notifyResourcesListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); - } - - private McpServerSession.RequestHandler resourcesListRequestHandler() { - return (exchange, params) -> { - var resourceList = this.resources.values() - .stream() - .map(McpServerFeatures.AsyncResourceSpecification::resource) - .toList(); - return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); - }; - } - - private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono - .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); - - } - - private List getResourceTemplates() { - var list = new ArrayList<>(this.resourceTemplates); - List resourceTemplates = this.resources.keySet() - .stream() - .filter(uri -> uri.contains("{")) - .map(uri -> { - var resource = this.resources.get(uri).resource(); - var template = new McpSchema.ResourceTemplate(resource.uri(), resource.name(), - resource.description(), resource.mimeType(), resource.annotations()); - return template; - }) - .toList(); - - list.addAll(resourceTemplates); - - return list; - } - - private McpServerSession.RequestHandler resourcesReadRequestHandler() { - return (exchange, params) -> { - McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, - new TypeReference() { - }); - var resourceUri = resourceRequest.uri(); - - McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() + if (!promptSpec.prompt() + .arguments() .stream() - .filter(resourceSpecification -> this.uriTemplateManagerFactory - .create(resourceSpecification.resource().uri()) - .matches(resourceUri)) + .filter(arg -> arg.name().equals(argumentName)) .findFirst() - .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); - - return specification.readHandler().apply(exchange, resourceRequest); - }; - } - - // --------------------------------------- - // Prompt Management - // --------------------------------------- + .isPresent()) { - @Override - public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { - if (promptSpecification == null) { - return Mono.error(new McpError("Prompt specification must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification specification = this.prompts - .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); - if (specification != null) { - return Mono.error(new McpError( - "Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); - } - - logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); - - // Servers that declared the listChanged capability SHOULD send a - // notification, - // when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return notifyPromptsListChanged(); + return Mono.error(new McpError("Argument not found: " + argumentName)); } - return Mono.empty(); - }); - } - - @Override - public Mono removePrompt(String promptName) { - if (promptName == null) { - return Mono.error(new McpError("Prompt name must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); } - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); - - if (removed != null) { - logger.debug("Removed prompt handler: {}", promptName); - // Servers that declared the listChanged capability SHOULD send a - // notification, when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return this.notifyPromptsListChanged(); - } - return Mono.empty(); + if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources.get(resourceReference.uri()); + if (resourceSpec == null) { + return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); } - return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); - }); - } - - @Override - public Mono notifyPromptsListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); - } - - private McpServerSession.RequestHandler promptsListRequestHandler() { - return (exchange, params) -> { - // TODO: Implement pagination - // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, - // new TypeReference() { - // }); - - var promptList = this.prompts.values() - .stream() - .map(McpServerFeatures.AsyncPromptSpecification::prompt) - .toList(); - - return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); - }; - } - - private McpServerSession.RequestHandler promptsGetRequestHandler() { - return (exchange, params) -> { - McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, - new TypeReference() { - }); - - // Implement prompt retrieval logic here - McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); - if (specification == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + return Mono.error(new McpError("Argument not found: " + argumentName)); } - return specification.promptHandler().apply(exchange, promptRequest); - }; - } - - // --------------------------------------- - // Logging Management - // --------------------------------------- - - @Override - public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - - if (loggingMessageNotification == null) { - return Mono.error(new McpError("Logging message must not be null")); - } - - if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { - return Mono.empty(); } - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, - loggingMessageNotification); - } - - private McpServerSession.RequestHandler setLoggerRequestHandler() { - return (exchange, params) -> { - return Mono.defer(() -> { - - SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, - new TypeReference() { - }); - - exchange.setMinLoggingLevel(newMinLoggingLevel.level()); - - // FIXME: this field is deprecated and should be removed together - // with the broadcasting loggingNotification. - this.minLoggingLevel = newMinLoggingLevel.level(); - - return Mono.just(Map.of()); - }); - }; - } - - private McpServerSession.RequestHandler completionCompleteRequestHandler() { - return (exchange, params) -> { - McpSchema.CompleteRequest request = parseCompletionParams(params); - - if (request.ref() == null) { - return Mono.error(new McpError("ref must not be null")); - } - - if (request.ref().type() == null) { - return Mono.error(new McpError("type must not be null")); - } - - String type = request.ref().type(); - - String argumentName = request.argument().name(); - - // check if the referenced resource exists - if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { - McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); - if (promptSpec == null) { - return Mono.error(new McpError("Prompt not found: " + promptReference.name())); - } - if (!promptSpec.prompt() - .arguments() - .stream() - .filter(arg -> arg.name().equals(argumentName)) - .findFirst() - .isPresent()) { - - return Mono.error(new McpError("Argument not found: " + argumentName)); - } - } - - if (type.equals("ref/resource") - && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources - .get(resourceReference.uri()); - if (resourceSpec == null) { - return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); - } - if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) - .getVariableNames() - .contains(argumentName)) { - return Mono.error(new McpError("Argument not found: " + argumentName)); - } + McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); - } - - McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); - - if (specification == null) { - return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); - } - - return specification.completionHandler().apply(exchange, request); - }; - } - - /** - * Parses the raw JSON-RPC request parameters into a - * {@link McpSchema.CompleteRequest} object. - *

- * This method manually extracts the `ref` and `argument` fields from the input - * map, determines the correct reference type (either prompt or resource), and - * constructs a fully-typed {@code CompleteRequest} instance. - * @param object the raw request parameters, expected to be a Map containing "ref" - * and "argument" entries. - * @return a {@link McpSchema.CompleteRequest} representing the structured - * completion request. - * @throws IllegalArgumentException if the "ref" type is not recognized. - */ - @SuppressWarnings("unchecked") - private McpSchema.CompleteRequest parseCompletionParams(Object object) { - Map params = (Map) object; - Map refMap = (Map) params.get("ref"); - Map argMap = (Map) params.get("argument"); - - String refType = (String) refMap.get("type"); - - McpSchema.CompleteReference ref = switch (refType) { - case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name")); - case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); - default -> throw new IllegalArgumentException("Invalid ref type: " + refType); - }; - - String argName = (String) argMap.get("name"); - String argValue = (String) argMap.get("value"); - McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument( - argName, argValue); - - return new McpSchema.CompleteRequest(ref, argument); - } + if (specification == null) { + return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); + } - // --------------------------------------- - // Sampling - // --------------------------------------- + return specification.completionHandler().apply(exchange, request); + }; + } - @Override - void setProtocolVersions(List protocolVersions) { - this.protocolVersions = protocolVersions; - } + /** + * Parses the raw JSON-RPC request parameters into a {@link McpSchema.CompleteRequest} + * object. + *

+ * This method manually extracts the `ref` and `argument` fields from the input map, + * determines the correct reference type (either prompt or resource), and constructs a + * fully-typed {@code CompleteRequest} instance. + * @param object the raw request parameters, expected to be a Map containing "ref" and + * "argument" entries. + * @return a {@link McpSchema.CompleteRequest} representing the structured completion + * request. + * @throws IllegalArgumentException if the "ref" type is not recognized. + */ + @SuppressWarnings("unchecked") + private McpSchema.CompleteRequest parseCompletionParams(Object object) { + Map params = (Map) object; + Map refMap = (Map) params.get("ref"); + Map argMap = (Map) params.get("argument"); + + String refType = (String) refMap.get("type"); + + McpSchema.CompleteReference ref = switch (refType) { + case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name")); + case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + default -> throw new IllegalArgumentException("Invalid ref type: " + refType); + }; + + String argName = (String) argMap.get("name"); + String argValue = (String) argMap.get("value"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(argName, + argValue); + + return new McpSchema.CompleteRequest(ref, argument); + } + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; } } From b2d3e0098e484e172719237b0933fa395cdfdf4b Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Mon, 12 May 2025 15:04:05 +0200 Subject: [PATCH 038/249] Next development version Signed-off-by: Christian Tzolov --- mcp-bom/pom.xml | 2 +- mcp-spring/mcp-spring-webflux/pom.xml | 6 +++--- mcp-spring/mcp-spring-webmvc/pom.xml | 6 +++--- mcp-test/pom.xml | 4 ++-- mcp/pom.xml | 2 +- pom.xml | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index 4f24f719f..7214dacda 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -7,7 +7,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT mcp-bom diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 86f46bf95..a8b92bd09 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT ../../pom.xml mcp-spring-webflux @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT test diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index 82fbbf3e6..48d1c3465 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT ../../pom.xml mcp-spring-webmvc @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT test diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index f1484ae77..a6e5bdb08 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT mcp-test jar @@ -24,7 +24,7 @@ io.modelcontextprotocol.sdk mcp - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT diff --git a/mcp/pom.xml b/mcp/pom.xml index 17693ab32..773432827 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT mcp jar diff --git a/pom.xml b/pom.xml index 638457406..c2327ee8d 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk From f34662555a0ab68d74ac118f1b0220441b2c81b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 14 May 2025 15:38:02 +0200 Subject: [PATCH 039/249] Fix stdio tests - proper server-everything argument (#237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../modelcontextprotocol/client/StdioMcpAsyncClientTests.java | 4 ++-- .../modelcontextprotocol/client/StdioMcpSyncClientTests.java | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index c39080138..8c0069d6d 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -25,12 +25,12 @@ protected McpClientTransport createMcpTransport() { ServerParameters stdioParams; if (System.getProperty("os.name").toLowerCase().contains("win")) { stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "dir") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } else { stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") + .args("-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } return new StdioClientTransport(stdioParams); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 8e75c4a3d..706aa9b2e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -33,12 +33,12 @@ protected McpClientTransport createMcpTransport() { ServerParameters stdioParams; if (System.getProperty("os.name").toLowerCase().contains("win")) { stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "dir") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } else { stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") + .args("-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } return new StdioClientTransport(stdioParams); From 2e13f9f9df8610e0d05cc76b1416fe195e249303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 14 May 2025 22:46:54 +0200 Subject: [PATCH 040/249] Fix flaky WebFluxSse integration test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../WebFluxSseIntegrationTests.java | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 2ba047461..03fbc9962 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -8,6 +8,8 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; @@ -651,9 +653,11 @@ void testInitialize(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testLoggingNotification(String clientType) { + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); // Create a list to store received logging notifications - List receivedNotifications = new ArrayList<>(); + List receivedNotifications = new CopyOnWriteArrayList<>(); var clientBuilder = clientBuilders.get(clientType); @@ -709,6 +713,7 @@ void testLoggingNotification(String clientType) { // Create client with logging notification handler var mcpClient = clientBuilder.loggingConsumer(notification -> { receivedNotifications.add(notification); + latch.countDown(); }).build()) { // Initialize client @@ -724,31 +729,28 @@ void testLoggingNotification(String clientType) { assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); - // Wait for notifications to be processed - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); - // Should have received 3 notifications (1 NOTICE and 2 ERROR) - assertThat(receivedNotifications).hasSize(3); + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.data(), n -> n)); + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); - // First notification should be NOTICE level - assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); - // Second notification should be ERROR level - assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); - // Third notification should be ERROR level - assertThat(notificationMap.get("Another error message").level()) - .isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); - }); + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); } mcpServer.close(); } From 1adfa8a047852c8f9e0188b4e63fe2020e0c66c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 14 May 2025 14:05:39 +0200 Subject: [PATCH 041/249] Add Contributing Guidelines and Code of Conduct MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- CODE_OF_CONDUCT.md | 119 +++++++++++++++++++++++++++++++++++++++++++++ CONTRIBUTING.md | 91 ++++++++++++++++++++++++++++++++++ README.md | 7 +-- SECURITY.md | 21 ++++++++ 4 files changed, 233 insertions(+), 5 deletions(-) create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md create mode 100644 SECURITY.md diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..6009a645f --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,119 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our community a +harassment-free experience for everyone, regardless of age, body size, visible or +invisible disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal appearance, +race, religion, or sexual identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, diverse, +inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our community +include: + +- Demonstrating empathy and kindness toward other people +- Being respectful of differing opinions, viewpoints, and experiences +- Giving and gracefully accepting constructive feedback +- Accepting responsibility and apologizing to those affected by our mistakes, and + learning from the experience +- Focusing on what is best not just for us as individuals, but for the overall community + +Examples of unacceptable behavior include: + +- The use of sexualized language or imagery, and sexual attention or advances of any kind +- Trolling, insulting or derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or email address, without + their explicit permission +- Other conduct which could reasonably be considered inappropriate in a professional + setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in response to +any behavior that they deem inappropriate, threatening, offensive, or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject comments, +commits, code, wiki edits, issues, and other contributions that are not aligned to this +Code of Conduct, and will communicate reasons for moderation decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when an +individual is officially representing the community in public spaces. Examples of +representing our community include using an official e-mail address, posting via an +official social media account, or acting as an appointed representative at an online or +offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to +the community leaders responsible for enforcement at mcp-coc@anthropic.com. All +complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the reporter +of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining the +consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing clarity +around the nature of the violation and an explanation of why the behavior was +inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of actions. + +**Consequence**: A warning with consequences for continued behavior. No interaction with +the people involved, including unsolicited interaction with those enforcing the Code of +Conduct, for a specified period of time. This includes avoiding interactions in community +spaces as well as external channels like social media. Violating these terms may lead to +a temporary or permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including sustained +inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public communication +with the community for a specified period of time. No public or private interaction with +the people involved, including unsolicited interaction with those enforcing the Code of +Conduct, is allowed during this period. Violating these terms may lead to a permanent +ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community standards, +including sustained inappropriate behavior, harassment of an individual, or aggression +toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, +available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..a949dcc09 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,91 @@ +# Contributing to Model Context Protocol Java SDK + +Thank you for your interest in contributing to the Model Context Protocol Java SDK! +This document outlines how to contribute to this project. + +## Prerequisites + +The following software is required to work on the codebase: + +- `Java 17` or above +- `Docker` +- `npx` + +## Getting Started + +1. Fork the repository +2. Clone your fork: + +```bash +git clone https://github.com/YOUR-USERNAME/java-sdk.git +cd java-sdk +``` + +3. Build from source: + +```bash +./mvnw clean install -DskipTests # skip the tests +./mvnw test # run tests +``` + +## Reporting Issues + +Please create an issue in the repository if you discover a bug or would like to +propose an enhancement. Bug reports should have a reproducer in the form of a code +sample or a repository attached that the maintainers or contributors can work with to +address the problem. + +## Making Changes + +1. Create a new branch: + +```bash +git checkout -b feature/your-feature-name +``` + +2. Make your changes +3. Validate your changes: + +```bash +./mvnw clean test +``` + +### Change Proposal Guidelines + +#### Principles of MCP + +1. **Simple + Minimal**: It is much easier to add things to the codebase than it is to + remove them. To maintain simplicity, we keep a high bar for adding new concepts and + primitives as each addition requires maintenance and compatibility consideration. +2. **Concrete**: Code changes need to be based on specific usage and implementation + challenges and not on speculative ideas. Most importantly, the SDK is meant to + implement the MCP specification. + +## Submitting Changes + +1. For non-trivial changes, please clarify with the maintainers in an issue whether + you can contribute the change and the desired scope of the change. +2. For trivial changes (for example a couple of lines or documentation changes) there + is no need to open an issue first. +3. Push your changes to your fork. +4. Submit a pull request to the main repository. +5. Follow the pull request template. +6. Wait for review. + +## Code of Conduct + +This project follows a Code of Conduct. Please review it in +[CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md). + +## Questions + +If you have questions, please create a discussion in the repository. + +## License + +By contributing, you agree that your contributions will be licensed under the MIT +License. + +## Security + +Please review our [Security Policy](SECURITY.md) for reporting security issues. \ No newline at end of file diff --git a/README.md b/README.md index 9fc17306e..0cd3f84a4 100644 --- a/README.md +++ b/README.md @@ -30,11 +30,8 @@ To run the tests you have to pre-install `Docker` and `npx`. ## Contributing -Contributions are welcome! Please: - -1. Fork the repository -2. Create a feature branch -3. Submit a Pull Request +Contributions are welcome! +Please follow the [Contributing Guidelines](CONTRIBUTING.md). ## Team diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..74e9880fd --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,21 @@ +# Security Policy + +Thank you for helping us keep the SDKs and systems they interact with secure. + +## Reporting Security Issues + +This SDK is maintained by [Anthropic](https://www.anthropic.com/) as part of the Model +Context Protocol project. + +The security of our systems and user data is Anthropic’s top priority. We appreciate the +work of security researchers acting in good faith in identifying and reporting potential +vulnerabilities. + +Our security program is managed on HackerOne and we ask that any validated vulnerability +in this functionality be reported through their +[submission form](https://hackerone.com/anthropic-vdp/reports/new?type=team&report_type=vulnerability). + +## Vulnerability Disclosure Program + +Our Vulnerability Program Guidelines are defined on our +[HackerOne program page](https://hackerone.com/anthropic-vdp). \ No newline at end of file From 07e7b8fd6bac47be4527f97451f8cdd95ed31a38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 14 May 2025 18:00:06 +0200 Subject: [PATCH 042/249] Add note about force pushes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- CONTRIBUTING.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a949dcc09..517f32555 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -71,6 +71,9 @@ git checkout -b feature/your-feature-name 4. Submit a pull request to the main repository. 5. Follow the pull request template. 6. Wait for review. +7. For any follow-up work, please add new commits instead of force-pushing. This will + allow the reviewer to focus on incremental changes instead of having to restart the + review process. ## Code of Conduct From 8a5a591d39256ba3947003ec4477e1722363eb35 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 27 May 2025 15:26:44 -0700 Subject: [PATCH 043/249] feat: Add elicitation support to MCP protocol Implement elicitation capabilities allowing servers to request additional information from users through clients during interactions. This feature provides a standardized way for servers to gather necessary information dynamically while clients maintain control over user interactions and data sharing. - Add ElicitRequest and ElicitResult classes to McpSchema - Implement elicitation handlers in client classes - Add elicitation capabilities to server exchange classes - Add tests for elicitation functionality with various scenarios --- .../WebFluxSseIntegrationTests.java | 224 +++++++++++++++++- .../server/WebMvcSseIntegrationTests.java | 213 +++++++++++++++++ .../client/McpAsyncClient.java | 32 +++ .../client/McpClient.java | 40 +++- .../client/McpClientFeatures.java | 31 ++- .../server/McpAsyncServerExchange.java | 28 +++ .../server/McpSyncServerExchange.java | 18 ++ .../modelcontextprotocol/spec/McpSchema.java | 129 ++++++++-- .../client/AbstractMcpAsyncClientTests.java | 22 +- .../McpAsyncClientResponseHandlerTests.java | 150 ++++++++++++ ...rverTransportProviderIntegrationTests.java | 213 +++++++++++++++++ .../spec/McpSchemaTests.java | 34 +++ 12 files changed, 1106 insertions(+), 28 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 03fbc9962..2f85654e8 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol; import java.time.Duration; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -28,11 +27,11 @@ import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.*; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.CompletionCapabilities; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -41,6 +40,7 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -331,6 +331,226 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt mcpServer.closeGracefully().block(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index b12d68439..3f3f7be62 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -357,6 +357,219 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { mcpServer.close(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @Test + void testCreateElicitationWithoutElicitationCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(McpSchema.ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @Test + void testCreateElicitationSuccess() { + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutSuccess() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutFail() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index e3a997ba3..a22ef6b51 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -23,6 +23,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; @@ -141,6 +143,15 @@ public class McpAsyncClient { */ private Function> samplingHandler; + /** + * MCP provides a standardized way for servers to request additional information from + * users through the client during interactions. This flow allows clients to maintain + * control over user interactions and data sharing while enabling servers to gather + * necessary information dynamically. Servers can request structured data from users + * with optional JSON schemas to validate responses. + */ + private Function> elicitationHandler; + /** * Client transport implementation. */ @@ -189,6 +200,15 @@ public class McpAsyncClient { requestHandlers.put(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, samplingCreateMessageHandler()); } + // Elicitation Handler + if (this.clientCapabilities.elicitation() != null) { + if (features.elicitationHandler() == null) { + throw new McpError("Elicitation handler must not be null when client capabilities include elicitation"); + } + this.elicitationHandler = features.elicitationHandler(); + requestHandlers.put(McpSchema.METHOD_ELICITATION_CREATE, elicitationCreateHandler()); + } + // Notification Handlers Map notificationHandlers = new HashMap<>(); @@ -500,6 +520,18 @@ private RequestHandler samplingCreateMessageHandler() { }; } + // -------------------------- + // Elicitation + // -------------------------- + private RequestHandler elicitationCreateHandler() { + return params -> { + ElicitRequest request = transport.unmarshalFrom(params, new TypeReference<>() { + }); + + return this.elicitationHandler.apply(request); + }; + } + // -------------------------- // Tools // -------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index a1dc11685..280906cff 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -18,6 +18,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.Implementation; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.util.Assert; @@ -175,6 +177,8 @@ class SyncSpec { private Function samplingHandler; + private Function elicitationHandler; + private SyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -283,6 +287,21 @@ public SyncSpec sampling(Function sam return this; } + /** + * Sets a custom elicitation handler for processing elicitation message requests. + * The elicitation handler can modify or validate messages before they are sent to + * the server, enabling custom processing logic. + * @param elicitationHandler A function that processes elicitation requests and + * returns results. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if elicitationHandler is null + */ + public SyncSpec elicitation(Function elicitationHandler) { + Assert.notNull(elicitationHandler, "Elicitation handler must not be null"); + this.elicitationHandler = elicitationHandler; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -364,7 +383,7 @@ public SyncSpec loggingConsumers(List> samplingHandler; + private Function> elicitationHandler; + private AsyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -522,6 +543,21 @@ public AsyncSpec sampling(Function> elicitationHandler) { + Assert.notNull(elicitationHandler, "Elicitation handler must not be null"); + this.elicitationHandler = elicitationHandler; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -606,7 +642,7 @@ public McpAsyncClient build() { return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers, - this.loggingConsumers, this.samplingHandler)); + this.loggingConsumers, this.samplingHandler, this.elicitationHandler)); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index 284b93f88..23d7c6a60 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -60,13 +60,15 @@ class McpClientFeatures { * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List, Mono>> toolsChangeConsumers, List, Mono>> resourcesChangeConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, - Function> samplingHandler) { + Function> samplingHandler, + Function> elicitationHandler) { /** * Create an instance and validate the arguments. @@ -77,6 +79,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, @@ -84,14 +87,16 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List, Mono>> resourcesChangeConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, - Function> samplingHandler) { + Function> samplingHandler, + Function> elicitationHandler) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, - samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null); + samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); this.roots = roots != null ? new ConcurrentHashMap<>(roots) : new ConcurrentHashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); @@ -99,6 +104,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); this.samplingHandler = samplingHandler; + this.elicitationHandler = elicitationHandler; } /** @@ -138,9 +144,14 @@ public static Async fromSync(Sync syncSpec) { Function> samplingHandler = r -> Mono .fromCallable(() -> syncSpec.samplingHandler().apply(r)) .subscribeOn(Schedulers.boundedElastic()); + + Function> elicitationHandler = r -> Mono + .fromCallable(() -> syncSpec.elicitationHandler().apply(r)) + .subscribeOn(Schedulers.boundedElastic()); + return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), toolsChangeConsumers, resourcesChangeConsumers, promptsChangeConsumers, loggingConsumers, - samplingHandler); + samplingHandler, elicitationHandler); } } @@ -156,13 +167,15 @@ public static Async fromSync(Sync syncSpec) { * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, List>> resourcesChangeConsumers, List>> promptsChangeConsumers, List> loggingConsumers, - Function samplingHandler) { + Function samplingHandler, + Function elicitationHandler) { /** * Create an instance and validate the arguments. @@ -174,20 +187,23 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, List>> resourcesChangeConsumers, List>> promptsChangeConsumers, List> loggingConsumers, - Function samplingHandler) { + Function samplingHandler, + Function elicitationHandler) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, - samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null); + samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); this.roots = roots != null ? new HashMap<>(roots) : new HashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); @@ -195,6 +211,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); this.samplingHandler = samplingHandler; + this.elicitationHandler = elicitationHandler; } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 889dc66d0..cfb07d26c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -36,6 +36,9 @@ public class McpAsyncServerExchange { private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { }; + private static final TypeReference ELICITATION_RESULT_TYPE_REF = new TypeReference<>() { + }; + /** * Create a new asynchronous exchange with the client. * @param session The server session representing a 1-1 interaction. @@ -93,6 +96,31 @@ public Mono createMessage(McpSchema.CreateMessage CREATE_MESSAGE_RESULT_TYPE_REF); } + /** + * Creates a new elicitation. MCP provides a standardized way for servers to request + * additional information from users through the client during interactions. This flow + * allows clients to maintain control over user interactions and data sharing while + * enabling servers to gather necessary information dynamically. Servers can request + * structured data from users with optional JSON schemas to validate responses. + * @param elicitRequest The request to create a new elicitation + * @return A Mono that completes when the elicitation has been resolved. + * @see McpSchema.ElicitRequest + * @see McpSchema.ElicitResult + * @see Elicitation + * Specification + */ + public Mono createElicitation(McpSchema.ElicitRequest elicitRequest) { + if (this.clientCapabilities == null) { + return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.elicitation() == null) { + return Mono.error(new McpError("Client must be configured with elicitation capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, + ELICITATION_RESULT_TYPE_REF); + } + /** * Retrieves the list of all roots provided by the client. * @return A Mono that emits the list of roots result. diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 52360e54b..084412b96 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -64,6 +64,24 @@ public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageReques return this.exchange.createMessage(createMessageRequest).block(); } + /** + * Creates a new elicitation. MCP provides a standardized way for servers to request + * additional information from users through the client during interactions. This flow + * allows clients to maintain control over user interactions and data sharing while + * enabling servers to gather necessary information dynamically. Servers can request + * structured data from users with optional JSON schemas to validate responses. + * @param elicitRequest The request to create a new elicitation + * @return A result containing the elicitation response. + * @see McpSchema.ElicitRequest + * @see McpSchema.ElicitResult + * @see Elicitation + * Specification + */ + public McpSchema.ElicitResult createElicitation(McpSchema.ElicitRequest elicitRequest) { + return this.exchange.createElicitation(elicitRequest).block(); + } + /** * Retrieves the list of all roots provided by the client. * @return The list of roots result. diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 8df8a1584..9dae08266 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -94,6 +94,9 @@ private McpSchema() { // Sampling Methods public static final String METHOD_SAMPLING_CREATE_MESSAGE = "sampling/createMessage"; + // Elicitation Methods + public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // --------------------------- @@ -131,8 +134,8 @@ public static final class ErrorCodes { } - public sealed interface Request - permits InitializeRequest, CallToolRequest, CreateMessageRequest, CompleteRequest, GetPromptRequest { + public sealed interface Request permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, + CompleteRequest, GetPromptRequest { } @@ -221,7 +224,7 @@ public record JSONRPCError( public record InitializeRequest( // @formatter:off @JsonProperty("protocolVersion") String protocolVersion, @JsonProperty("capabilities") ClientCapabilities capabilities, - @JsonProperty("clientInfo") Implementation clientInfo) implements Request { + @JsonProperty("clientInfo") Implementation clientInfo) implements Request { } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -245,6 +248,8 @@ public record InitializeResult( // @formatter:off * access to. * @param sampling Provides a standardized way for servers to request LLM sampling * (“completions” or “generations”) from language models via clients. + * @param elicitation Provides a standardized way for servers to request additional + * information from users through the client during interactions. * */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -252,7 +257,8 @@ public record InitializeResult( // @formatter:off public record ClientCapabilities( // @formatter:off @JsonProperty("experimental") Map experimental, @JsonProperty("roots") RootCapabilities roots, - @JsonProperty("sampling") Sampling sampling) { + @JsonProperty("sampling") Sampling sampling, + @JsonProperty("elicitation") Elicitation elicitation) { /** * Roots define the boundaries of where servers can operate within the filesystem, @@ -264,7 +270,7 @@ public record ClientCapabilities( // @formatter:off * has changed since the last time the server checked. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) + @JsonIgnoreProperties(ignoreUnknown = true) public record RootCapabilities( @JsonProperty("listChanged") Boolean listChanged) { } @@ -279,10 +285,22 @@ public record RootCapabilities( * image-based interactions and optionally include context * from MCP servers in their prompts. */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record Sampling() { } + /** + * Provides a standardized way for servers to request additional + * information from users through the client during interactions. + * This flow allows clients to maintain control over user + * interactions and data sharing while enabling servers to gather + * necessary information dynamically. Servers can request structured + * data from users with optional JSON schemas to validate responses. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Elicitation() { + } + public static Builder builder() { return new Builder(); } @@ -291,6 +309,7 @@ public static class Builder { private Map experimental; private RootCapabilities roots; private Sampling sampling; + private Elicitation elicitation; public Builder experimental(Map experimental) { this.experimental = experimental; @@ -307,8 +326,13 @@ public Builder sampling() { return this; } + public Builder elicitation() { + this.elicitation = new Elicitation(); + return this; + } + public ClientCapabilities build() { - return new ClientCapabilities(experimental, roots, sampling); + return new ClientCapabilities(experimental, roots, sampling, elicitation); } } }// @formatter:on @@ -326,11 +350,11 @@ public record ServerCapabilities( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) public record CompletionCapabilities() { } - + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record LoggingCapabilities() { } - + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record PromptCapabilities( @JsonProperty("listChanged") Boolean listChanged) { @@ -727,11 +751,11 @@ public record Tool( // @formatter:off @JsonProperty("name") String name, @JsonProperty("description") String description, @JsonProperty("inputSchema") JsonSchema inputSchema) { - + public Tool(String name, String description, String schema) { this(name, description, parseSchema(schema)); } - + } // @formatter:on private static JsonSchema parseSchema(String schema) { @@ -758,7 +782,7 @@ public record CallToolRequest(// @formatter:off @JsonProperty("arguments") Map arguments) implements Request { public CallToolRequest(String name, String jsonArguments) { - this(name, parseJsonArguments(jsonArguments)); + this(name, parseJsonArguments(jsonArguments)); } private static Map parseJsonArguments(String jsonArguments) { @@ -893,7 +917,7 @@ public record ModelPreferences(// @formatter:off @JsonProperty("costPriority") Double costPriority, @JsonProperty("speedPriority") Double speedPriority, @JsonProperty("intelligencePriority") Double intelligencePriority) { - + public static Builder builder() { return new Builder(); } @@ -963,7 +987,7 @@ public record CreateMessageRequest(// @formatter:off @JsonProperty("includeContext") ContextInclusionStrategy includeContext, @JsonProperty("temperature") Double temperature, @JsonProperty("maxTokens") int maxTokens, - @JsonProperty("stopSequences") List stopSequences, + @JsonProperty("stopSequences") List stopSequences, @JsonProperty("metadata") Map metadata) implements Request { public enum ContextInclusionStrategy { @@ -971,7 +995,7 @@ public enum ContextInclusionStrategy { @JsonProperty("thisServer") THIS_SERVER, @JsonProperty("allServers") ALL_SERVERS } - + public static Builder builder() { return new Builder(); } @@ -1040,7 +1064,7 @@ public record CreateMessageResult(// @formatter:off @JsonProperty("content") Content content, @JsonProperty("model") String model, @JsonProperty("stopReason") StopReason stopReason) { - + public enum StopReason { @JsonProperty("endTurn") END_TURN, @JsonProperty("stopSequence") STOP_SEQUENCE, @@ -1088,6 +1112,79 @@ public CreateMessageResult build() { } }// @formatter:on + // Elicitation + /** + * Used by the server to send an elicitation to the client. + * + * @param message The body of the elicitation message. + * @param requestedSchema The elicitation response schema that must be satisfied. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitRequest(// @formatter:off + @JsonProperty("message") String message, + @JsonProperty("requestedSchema") Map requestedSchema) implements Request { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String message; + private Map requestedSchema; + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder requestedSchema(Map requestedSchema) { + this.requestedSchema = requestedSchema; + return this; + } + + public ElicitRequest build() { + return new ElicitRequest(message, requestedSchema); + } + } + }// @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitResult(// @formatter:off + @JsonProperty("action") Action action, + @JsonProperty("content") Map content) { + + public enum Action { + @JsonProperty("accept") ACCEPT, + @JsonProperty("decline") DECLINE, + @JsonProperty("cancel") CANCEL + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private Action action; + private Map content; + + public Builder message(Action action) { + this.action = action; + return this; + } + + public Builder content(Map content) { + this.content = content; + return this; + } + + public ElicitResult build() { + return new ElicitResult(action, content); + } + } + }// @formatter:on + // --------------------------- // Pagination Interfaces // --------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 72b409af9..d1a2581ee 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.Prompt; import io.modelcontextprotocol.spec.McpSchema.Resource; @@ -422,6 +424,20 @@ void testInitializeWithSamplingCapability() { }); } + @Test + void testInitializeWithElicitationCapability() { + ClientCapabilities capabilities = ClientCapabilities.builder().elicitation().build(); + ElicitResult elicitResult = ElicitResult.builder() + .message(ElicitResult.Action.ACCEPT) + .content(Map.of("foo", "bar")) + .build(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).elicitation(request -> Mono.just(elicitResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); + } + @Test void testInitializeWithAllCapabilities() { var capabilities = ClientCapabilities.builder() @@ -433,7 +449,11 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + Function> elicitationHandler = request -> Mono + .just(ElicitResult.builder().message(ElicitResult.Action.ACCEPT).content(Map.of("foo", "bar")).build()); + + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(samplingHandler).elicitation(elicitationHandler), client -> StepVerifier.create(client.initialize()).assertNext(result -> { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index 4510b1529..e6cde8e3b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.Root; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import reactor.core.publisher.Mono; import static io.modelcontextprotocol.spec.McpSchema.METHOD_INITIALIZE; @@ -349,4 +351,152 @@ void testSamplingCreateMessageRequestHandlingWithNullHandler() { .hasMessage("Sampling handler must not be null when client capabilities include sampling"); } + @Test + @SuppressWarnings("unchecked") + void testElicitationCreateRequestHandling() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create a test elicitation handler that echoes back the input + Function> elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isInstanceOf(Map.class); + assertThat(request.requestedSchema().get("type")).isEqualTo("object"); + + var properties = request.requestedSchema().get("properties"); + assertThat(properties).isNotNull(); + assertThat(((Map) properties).get("message")).isInstanceOf(Map.class); + + return Mono.just(McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(Map.of("message", request.message())) + .build()); + }; + + // Create client with elicitation capability and handler + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.error()).isNull(); + + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + }); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content()).isEqualTo(Map.of("message", "Test message")); + + asyncMcpClient.closeGracefully(); + } + + @ParameterizedTest + @EnumSource(value = McpSchema.ElicitResult.Action.class, names = { "DECLINE", "CANCEL" }) + void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create a test elicitation handler to decline the request + Function> elicitationHandler = request -> Mono + .just(McpSchema.ElicitResult.builder().message(action).build()); + + // Create client with elicitation capability and handler + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.error()).isNull(); + + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + }); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(action); + assertThat(result.content()).isNull(); + + asyncMcpClient.closeGracefully(); + } + + @Test + void testElicitationCreateRequestHandlingWithoutCapability() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create client without elicitation capability + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().build()) // No elicitation + // capability + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = new McpSchema.ElicitRequest("test", + Map.of("type", "object", "properties", Map.of("test", Map.of("type", "boolean", "defaultValue", true, + "description", "test-description", "title", "test-title")))); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify error response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.result()).isNull(); + assertThat(response.error()).isNotNull(); + assertThat(response.error().message()).contains("Method not found: elicitation/create"); + + asyncMcpClient.closeGracefully(); + } + + @Test + void testElicitationCreateRequestHandlingWithNullHandler() { + MockMcpClientTransport transport = new MockMcpClientTransport(); + + // Create client with elicitation capability but null handler + assertThatThrownBy(() -> McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .build()).isInstanceOf(McpError.class) + .hasMessage("Elicitation handler must not be null when client capabilities include elicitation"); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 2ff6325a4..dc9d1cfab 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -24,6 +24,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; import io.modelcontextprotocol.spec.McpSchema.Role; @@ -339,6 +341,217 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { mcpServer.close(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @Test + @Disabled + void testCreateElicitationWithoutElicitationCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @Test + void testCreateElicitationSuccess() { + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutSuccess() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutFail() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index ff78c1bfc..99015d8c4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -807,6 +807,40 @@ void testCreateMessageResult() throws Exception { {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"endTurn"}""")); } + // Elicitation Tests + + @Test + void testCreateElicitationRequest() throws Exception { + McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() + .requestedSchema(Map.of("type", "object", "required", List.of("a"), "properties", + Map.of("foo", Map.of("type", "string")))) + .build(); + + String value = mapper.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"requestedSchema":{"properties":{"foo":{"type":"string"}},"required":["a"],"type":"object"}}""")); + } + + @Test + void testCreateElicitationResult() throws Exception { + McpSchema.ElicitResult result = McpSchema.ElicitResult.builder() + .content(Map.of("foo", "bar")) + .message(McpSchema.ElicitResult.Action.ACCEPT) + .build(); + + String value = mapper.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"action":"accept","content":{"foo":"bar"}}""")); + } + // Roots Tests @Test From 2f944349cc77009d020ebddc8b9967b8e1b14baf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Tue, 10 Jun 2025 18:33:40 +0200 Subject: [PATCH 044/249] feat: WebClient Streamable HTTP support (#292) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit An implementation of Streamable HTTP Client with WebFlux WebClient. Aside from implementing the specification, several improvements have been incorporated throughout the client-side of the architecture. The changes cover: - resilience tests using toxiproxy in testcontainers - integration tests using updated everything-server with streamableHttp support - improved logging - session invalidation handling (both transport session and JSON-RPC concept of session) - implicit initialization and burst protection (in case of concurrent `Mcp(Sync|Async)Client` use - more logging, e.g. stdio process lifecycle logs Related #72, #273, #253, #107, #105 Signed-off-by: Dariusz Jędrzejczyk --- mcp-spring/mcp-spring-webflux/pom.xml | 6 + .../WebClientStreamableHttpTransport.java | 520 ++++++++++++++++++ .../transport/WebFluxSseClientTransport.java | 3 + ...eamableHttpAsyncClientResiliencyTests.java | 17 + ...bClientStreamableHttpAsyncClientTests.java | 42 ++ ...ebClientStreamableHttpSyncClientTests.java | 41 ++ .../client/WebFluxSseMcpAsyncClientTests.java | 3 +- .../client/WebFluxSseMcpSyncClientTests.java | 3 +- .../WebFluxSseClientTransportTests.java | 3 +- .../src/test/resources/logback.xml | 8 +- mcp-test/pom.xml | 5 + ...AbstractMcpAsyncClientResiliencyTests.java | 222 ++++++++ .../client/AbstractMcpAsyncClientTests.java | 37 +- .../client/AbstractMcpSyncClientTests.java | 60 +- .../client/McpAsyncClient.java | 291 ++++++---- .../transport/StdioClientTransport.java | 5 + .../spec/DefaultMcpTransportSession.java | 79 +++ .../spec/DefaultMcpTransportStream.java | 74 +++ .../spec/McpClientSession.java | 53 +- .../spec/McpClientTransport.java | 22 +- .../modelcontextprotocol/spec/McpSchema.java | 6 + .../spec/McpTransportSession.java | 60 ++ .../McpTransportSessionNotFoundException.java | 29 + .../spec/McpTransportStream.java | 45 ++ .../client/AbstractMcpAsyncClientTests.java | 37 +- .../client/AbstractMcpSyncClientTests.java | 58 +- .../client/HttpSseMcpAsyncClientTests.java | 3 +- .../client/HttpSseMcpSyncClientTests.java | 3 +- .../client/StdioMcpSyncClientTests.java | 2 +- .../HttpClientSseClientTransportTests.java | 3 +- pom.xml | 3 +- 31 files changed, 1501 insertions(+), 242 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java create mode 100644 mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index a8b92bd09..26452fe95 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -99,6 +99,12 @@ ${testcontainers.version} test + + org.testcontainers + toxiproxy + ${toxiproxy.version} + test + org.awaitility diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java new file mode 100644 index 000000000..e7b7c8ee9 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -0,0 +1,520 @@ +package io.modelcontextprotocol.client.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.DefaultMcpTransportSession; +import io.modelcontextprotocol.spec.DefaultMcpTransportStream; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.McpTransportSession; +import io.modelcontextprotocol.spec.McpTransportStream; +import io.modelcontextprotocol.util.Assert; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.client.WebClientResponseException; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +/** + * An implementation of the Streamable HTTP protocol as defined by the + * 2025-03-26 version of the MCP specification. + * + *

+ * The transport is capable of resumability and reconnects. It reacts to transport-level + * session invalidation and will propagate {@link McpTransportSessionNotFoundException + * appropriate exceptions} to the higher level abstraction layer when needed in order to + * allow proper state management. The implementation handles servers that are stateful and + * provide session meta information, but can also communicate with stateless servers that + * do not provide a session identifier and do not support SSE streams. + *

+ *

+ * This implementation does not handle backwards compatibility with the "HTTP + * with SSE" transport. In order to communicate over the phased-out + * 2024-11-05 protocol, use {@link HttpClientSseClientTransport} or + * {@link WebFluxSseClientTransport}. + *

+ * + * @author Dariusz Jędrzejczyk + * @see Streamable + * HTTP transport specification + */ +public class WebClientStreamableHttpTransport implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class); + + private static final String DEFAULT_ENDPOINT = "/mcp"; + + /** + * Event type for JSON-RPC messages received through the SSE connection. The server + * sends messages with this event type to transmit JSON-RPC protocol data. + */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + private static final ParameterizedTypeReference> PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference<>() { + }; + + private final ObjectMapper objectMapper; + + private final WebClient webClient; + + private final String endpoint; + + private final boolean openConnectionOnStartup; + + private final boolean resumableStreams; + + private final AtomicReference activeSession = new AtomicReference<>(); + + private final AtomicReference, Mono>> handler = new AtomicReference<>(); + + private final AtomicReference> exceptionHandler = new AtomicReference<>(); + + private WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder, + String endpoint, boolean resumableStreams, boolean openConnectionOnStartup) { + this.objectMapper = objectMapper; + this.webClient = webClientBuilder.build(); + this.endpoint = endpoint; + this.resumableStreams = resumableStreams; + this.openConnectionOnStartup = openConnectionOnStartup; + this.activeSession.set(createTransportSession()); + } + + /** + * Create a stateful builder for creating {@link WebClientStreamableHttpTransport} + * instances. + * @param webClientBuilder the {@link WebClient.Builder} to use + * @return a builder which will create an instance of + * {@link WebClientStreamableHttpTransport} once {@link Builder#build()} is called + */ + public static Builder builder(WebClient.Builder webClientBuilder) { + return new Builder(webClientBuilder); + } + + @Override + public Mono connect(Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler.set(handler); + if (openConnectionOnStartup) { + logger.debug("Eagerly opening connection on startup"); + return this.reconnect(null).then(); + } + return Mono.empty(); + }); + } + + private DefaultMcpTransportSession createTransportSession() { + Supplier> onClose = () -> { + DefaultMcpTransportSession transportSession = this.activeSession.get(); + return transportSession.sessionId().isEmpty() ? Mono.empty() + : webClient.delete().uri(this.endpoint).headers(httpHeaders -> { + httpHeaders.add("mcp-session-id", transportSession.sessionId().get()); + }).retrieve().toBodilessEntity().doOnError(e -> logger.info("Got response {}", e)).then(); + }; + return new DefaultMcpTransportSession(onClose); + } + + @Override + public void setExceptionHandler(Consumer handler) { + logger.debug("Exception handler registered"); + this.exceptionHandler.set(handler); + } + + private void handleException(Throwable t) { + logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); + if (t instanceof McpTransportSessionNotFoundException) { + McpTransportSession invalidSession = this.activeSession.getAndSet(createTransportSession()); + logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); + invalidSession.close(); + } + Consumer handler = this.exceptionHandler.get(); + if (handler != null) { + handler.accept(t); + } + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + logger.debug("Graceful close triggered"); + DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(createTransportSession()); + if (currentSession != null) { + return currentSession.closeGracefully(); + } + return Mono.empty(); + }); + } + + private Mono reconnect(McpTransportStream stream) { + return Mono.deferContextual(ctx -> { + if (stream != null) { + logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); + } + else { + logger.debug("Reconnecting with no prior stream"); + } + // Here we attempt to initialize the client. In case the server supports SSE, + // we will establish a long-running + // session here and listen for messages. If it doesn't, that's ok, the server + // is a simple, stateless one. + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + Disposable connection = webClient.get() + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM) + .headers(httpHeaders -> { + transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + if (stream != null) { + stream.lastId().ifPresent(id -> httpHeaders.add("last-event-id", id)); + } + }) + .exchangeToFlux(response -> { + if (isEventStream(response)) { + return eventStream(stream, response); + } + else if (isNotAllowed(response)) { + logger.debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (isNotFound(response)) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + return mcpSessionNotFoundError(sessionIdRepresentation); + } + else { + return response.createError().doOnError(e -> { + logger.info("Opening an SSE stream failed. This can be safely ignored.", e); + }).flux(); + } + }) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(ctx) + .subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + return Mono.just(connection); + }); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.create(sink -> { + logger.debug("Sending message {}", message); + // Here we attempt to initialize the client. + // In case the server supports SSE, we will establish a long-running session + // here and + // listen for messages. + // If it doesn't, nothing actually happens here, that's just the way it is... + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + Disposable connection = webClient.post() + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON) + .headers(httpHeaders -> { + transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + }) + .bodyValue(message) + .exchangeToFlux(response -> { + if (transportSession + .markInitialized(response.headers().asHttpHeaders().getFirst("mcp-session-id"))) { + // Once we have a session, we try to open an async stream for + // the server to send notifications and requests out-of-band. + reconnect(null).contextWrite(sink.contextView()).subscribe(); + } + + String sessionRepresentation = sessionIdOrPlaceholder(transportSession); + + // The spec mentions only ACCEPTED, but the existing SDKs can return + // 200 OK for notifications + if (response.statusCode().is2xxSuccessful()) { + Optional contentType = response.headers().contentType(); + // Existing SDKs consume notifications with no response body nor + // content type + if (contentType.isEmpty()) { + logger.trace("Message was successfully sent via POST for session {}", + sessionRepresentation); + // signal the caller that the message was successfully + // delivered + sink.success(); + // communicate to downstream there is no streamed data coming + return Flux.empty(); + } + else { + MediaType mediaType = contentType.get(); + if (mediaType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { + // communicate to caller that the message was delivered + sink.success(); + // starting a stream + return newEventStream(response, sessionRepresentation); + } + else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { + logger.trace("Received response to POST for session {}", sessionRepresentation); + // communicate to caller the message was delivered + sink.success(); + return responseFlux(response); + } + else { + logger.warn("Unknown media type {} returned for POST in session {}", contentType, + sessionRepresentation); + return Flux.error(new RuntimeException("Unknown media type returned: " + contentType)); + } + } + } + else { + if (isNotFound(response)) { + return mcpSessionNotFoundError(sessionRepresentation); + } + return extractError(response, sessionRepresentation); + } + }) + .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) + .onErrorResume(t -> { + // handle the error first + this.handleException(t); + // inform the caller of sendMessage + sink.error(t); + return Flux.empty(); + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(sink.contextView()) + .subscribe(); + disposableRef.set(connection); + transportSession.addConnection(connection); + }); + } + + private static Flux mcpSessionNotFoundError(String sessionRepresentation) { + logger.warn("Session {} was not found on the MCP server", sessionRepresentation); + // inform the stream/connection subscriber + return Flux.error(new McpTransportSessionNotFoundException(sessionRepresentation)); + } + + private Flux extractError(ClientResponse response, String sessionRepresentation) { + return response.createError().onErrorResume(e -> { + WebClientResponseException responseException = (WebClientResponseException) e; + byte[] body = responseException.getResponseBodyAsByteArray(); + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = null; + Exception toPropagate; + try { + McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body, + McpSchema.JSONRPCResponse.class); + jsonRpcError = jsonRpcResponse.error(); + toPropagate = new McpError(jsonRpcError); + } + catch (IOException ex) { + toPropagate = new RuntimeException("Sending request failed", e); + logger.debug("Received content together with {} HTTP code response: {}", response.statusCode(), body); + } + + // Some implementations can return 400 when presented with a + // session id that it doesn't know about, so we will + // invalidate the session + // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 + if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { + return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate)); + } + return Mono.empty(); + }).flux(); + } + + private Flux eventStream(McpTransportStream stream, ClientResponse response) { + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + var idWithMessages = response.bodyToFlux(PARAMETERIZED_TYPE_REF).map(this::parse); + return Flux.from(sessionStream.consumeSseStream(idWithMessages)); + } + + private static boolean isNotFound(ClientResponse response) { + return response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND); + } + + private static boolean isNotAllowed(ClientResponse response) { + return response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED); + } + + private static boolean isEventStream(ClientResponse response) { + return response.statusCode().is2xxSuccessful() && response.headers().contentType().isPresent() + && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM); + } + + private static String sessionIdOrPlaceholder(McpTransportSession transportSession) { + return transportSession.sessionId().orElse("[missing_session_id]"); + } + + private Flux responseFlux(ClientResponse response) { + return response.bodyToMono(String.class).>handle((responseMessage, s) -> { + try { + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, + responseMessage); + s.next(List.of(jsonRpcResponse)); + } + catch (IOException e) { + s.error(e); + } + }).flatMapIterable(Function.identity()); + } + + private Flux newEventStream(ClientResponse response, String sessionRepresentation) { + McpTransportStream sessionStream = new DefaultMcpTransportStream<>(this.resumableStreams, + this::reconnect); + logger.trace("Sent POST and opened a stream ({}) for session {}", sessionStream.streamId(), + sessionRepresentation); + return eventStream(sessionStream, response); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + + private Tuple2, Iterable> parse(ServerSentEvent event) { + if (MESSAGE_EVENT_TYPE.equals(event.event())) { + try { + // We don't support batching ATM and probably won't since the next version + // considers removing it. + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); + return Tuples.of(Optional.ofNullable(event.id()), List.of(message)); + } + catch (IOException ioException) { + throw new McpError("Error parsing JSON-RPC message: " + event.data()); + } + } + else { + throw new McpError("Received unrecognized SSE event type: " + event.event()); + } + } + + /** + * Builder for {@link WebClientStreamableHttpTransport}. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private WebClient.Builder webClientBuilder; + + private String endpoint = DEFAULT_ENDPOINT; + + private boolean resumableStreams = true; + + private boolean openConnectionOnStartup = false; + + private Builder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + } + + /** + * Configure the {@link ObjectMapper} to use. + * @param objectMapper instance to use + * @return the builder instance + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Configure the {@link WebClient.Builder} to construct the {@link WebClient}. + * @param webClientBuilder instance to use + * @return the builder instance + */ + public Builder webClientBuilder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + return this; + } + + /** + * Configure the endpoint to make HTTP requests against. + * @param endpoint endpoint to use + * @return the builder instance + */ + public Builder endpoint(String endpoint) { + Assert.hasText(endpoint, "endpoint must be a non-empty String"); + this.endpoint = endpoint; + return this; + } + + /** + * Configure whether to use the stream resumability feature by keeping track of + * SSE event ids. + * @param resumableStreams if {@code true} event ids will be tracked and upon + * disconnection, the last seen id will be used upon reconnection as a header to + * resume consuming messages. + * @return the builder instance + */ + public Builder resumableStreams(boolean resumableStreams) { + this.resumableStreams = resumableStreams; + return this; + } + + /** + * Configure whether the client should open an SSE connection upon startup. Not + * all servers support this (although it is in theory possible with the current + * specification), so use with caution. By default, this value is {@code false}. + * @param openConnectionOnStartup if {@code true} the {@link #connect(Function)} + * method call will try to open an SSE connection before sending any JSON-RPC + * request + * @return the builder instance + */ + public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { + this.openConnectionOnStartup = openConnectionOnStartup; + return this; + } + + /** + * Construct a fresh instance of {@link WebClientStreamableHttpTransport} using + * the current builder configuration. + * @return a new instance of {@link WebClientStreamableHttpTransport} + */ + public WebClientStreamableHttpTransport build() { + ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + + return new WebClientStreamableHttpTransport(objectMapper, this.webClientBuilder, endpoint, resumableStreams, + openConnectionOnStartup); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 37abe295b..128cda4c3 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -190,6 +190,9 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMappe */ @Override public Mono connect(Function, Mono> handler) { + // TODO: Avoid eager connection opening and enable resilience + // -> upon disconnects, re-establish connection + // -> allow optimizing for eager connection start using a constructor flag Flux> events = eventStream(); this.inboundSubscription = events.concatMap(event -> Mono.just(event).handle((e, s) -> { if (ENDPOINT_EVENT_TYPE.equals(event.event())) { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..80fc671e2 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java @@ -0,0 +1,17 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; + +@Timeout(15) +public class WebClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java new file mode 100644 index 000000000..4c8032659 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java @@ -0,0 +1,42 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.images.builder.ImageFromDockerfile; + +@Timeout(15) +public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java new file mode 100644 index 000000000..a8cad4898 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java @@ -0,0 +1,41 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +@Timeout(15) +public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index b43c14493..f0533cb49 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -26,7 +26,8 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 66ac8a6dd..9b0959a35 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -26,7 +26,8 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index c757d3da9..42b91d14e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -41,7 +41,8 @@ class WebFluxSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml index 5ad73374a..abc831d13 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml @@ -9,13 +9,13 @@ - + - - + + - + diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index a6e5bdb08..9998569dc 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -68,6 +68,11 @@ junit-jupiter ${testcontainers.version} + + org.testcontainers + toxiproxy + ${toxiproxy.version} + org.awaitility diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..85d6a88e4 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -0,0 +1,222 @@ +package io.modelcontextprotocol.client; + +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransport; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Resiliency test suite for the {@link McpAsyncClient} that can be used with different + * {@link McpTransport} implementations that support Streamable HTTP. + * + * The purpose of these tests is to allow validating the transport layer resiliency + * instead of the functionality offered by the logical layer of MCP concepts such as + * tools, resources, prompts, etc. + * + * @author Dariusz Jędrzejczyk + */ +public abstract class AbstractMcpAsyncClientResiliencyTests { + + private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class); + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + private static void disconnect() { + long start = System.nanoTime(); + try { + // disconnect + // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM", + // ToxicDirection.DOWNSTREAM, 0); + // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM", + // ToxicDirection.UPSTREAM, 0); + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + private static void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + + private static void restartMcpServer() { + container.stop(); + container.start(); + } + + abstract McpClientTransport createMcpTransport(); + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionInvalidation() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + restartMcpServer(); + + // The first try will face the session mismatch exception and the second one + // will go through the re-initialization process. + StepVerifier.create(mcpAsyncClient.ping().retry(1)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + AtomicReference> tools = new AtomicReference<>(); + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + StepVerifier.create(mcpAsyncClient.listTools()) + .consumeNextWith(list -> tools.set(list.tools())) + .verifyComplete(); + + disconnect(); + + String name = tools.get().get(0).name(); + // Assuming this is the echo tool + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(name, Map.of("message", "hello")); + StepVerifier.create(mcpAsyncClient.callTool(request)).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.callTool(request)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionClose() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + // In case of Streamable HTTP this call should issue a HTTP DELETE request + // invalidating the session + StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); + // The next use should immediately re-initialize with no issue and send the + // request without any broken connections. + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 5452c8eac..049bea008 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -110,14 +110,16 @@ void tearDown() { onClose(); } - void verifyInitializationTimeout(Function> operation, String action) { + void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, + String action) { withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); + StepVerifier.create(operation.apply(mcpAsyncClient)).verifyComplete(); + }); + } + + void verifyCallSucceedsWithImplicitInitialization(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(operation.apply(mcpAsyncClient)).expectNextCount(1).verifyComplete(); }); } @@ -133,7 +135,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -153,7 +155,7 @@ void testListTools() { @Test void testPingWithoutInitialization() { - verifyInitializationTimeout(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -168,7 +170,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -202,7 +204,7 @@ void testCallToolWithInvalidTool() { @Test void testListResourcesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -233,7 +235,7 @@ void testMcpAsyncClientState() { @Test void testListPromptsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listPrompts(null), "listing " + "prompts"); } @Test @@ -258,7 +260,7 @@ void testListPrompts() { @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.getPrompt(request), "getting " + "prompts"); } @Test @@ -279,7 +281,7 @@ void testGetPrompt() { @Test void testRootsListChangedWithoutInitialization() { - verifyInitializationTimeout(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -354,7 +356,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(), + "listing resource templates"); } @Test @@ -447,8 +450,8 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 128441f80..3785fd645 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -12,7 +13,6 @@ import java.util.function.Function; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -112,33 +112,18 @@ void tearDown() { static final Object DUMMY_RETURN_VALUE = new Object(); - void verifyNotificationTimesOut(Consumer operation, String action) { - verifyCallTimesOut(client -> { + void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { + verifyCallSucceedsWithImplicitInitialization(client -> { operation.accept(client); return DUMMY_RETURN_VALUE; }, action); } - void verifyCallTimesOut(Function blockingOperation, String action) { + void verifyCallSucceedsWithImplicitInitialization(Function blockingOperation, String action) { withClient(createMcpTransport(), mcpSyncClient -> { - // This scheduler is not replaced by virtual time scheduler - Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); - - StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) - // Offload the blocking call to the real scheduler - .subscribeOn(customScheduler)) - .expectSubscription() - // This works without actually waiting but executes all the - // tasks pending execution on the VirtualTimeScheduler. - // It is possible to execute the blocking code from the operation - // because it is blocked on a dedicated Scheduler and the main - // flow is not blocked and uses the VirtualTimeScheduler. - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); - - customScheduler.dispose(); + StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient))) + .expectNextCount(1) + .verifyComplete(); }); } @@ -154,7 +139,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -175,8 +160,8 @@ void testListTools() { @Test void testCallToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), - "calling tools"); + verifyCallSucceedsWithImplicitInitialization( + client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), "calling tools"); } @Test @@ -200,7 +185,7 @@ void testCallTools() { @Test void testPingWithoutInitialization() { - verifyCallTimesOut(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -214,7 +199,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -243,7 +228,7 @@ void testCallToolWithInvalidTool() { @Test void testRootsListChangedWithoutInitialization() { - verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -257,7 +242,7 @@ void testRootsListChanged() { @Test void testListResourcesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -333,8 +318,14 @@ void testRemoveNonExistentRoot() { @Test void testReadResourceWithoutInitialization() { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); + AtomicReference> resources = new AtomicReference<>(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + resources.set(mcpSyncClient.listResources().resources()); + }); + + verifyCallSucceedsWithImplicitInitialization(client -> client.readResource(resources.get().get(0)), + "reading resources"); } @Test @@ -355,7 +346,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(null), + "listing resource templates"); } @Test @@ -413,8 +405,8 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index a22ef6b51..8f0433eb1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -9,9 +9,9 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import java.util.function.Supplier; import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpClientSession; @@ -32,7 +32,7 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; @@ -77,29 +77,37 @@ * @see McpClient * @see McpSchema * @see McpClientSession + * @see McpClientTransport */ public class McpAsyncClient { private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class); - private static TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { + private static final TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { }; - protected final Sinks.One initializedSink = Sinks.one(); + public static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference PAGINATED_REQUEST_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference INITIALIZE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() { + }; - private AtomicBoolean initialized = new AtomicBoolean(false); + private final AtomicReference initializationRef = new AtomicReference<>(); /** * The max timeout to await for the client-server connection to be initialized. */ private final Duration initializationTimeout; - /** - * The MCP session implementation that manages bidirectional JSON-RPC communication - * between clients and servers. - */ - private final McpClientSession mcpSession; - /** * Client capabilities. */ @@ -110,21 +118,6 @@ public class McpAsyncClient { */ private final McpSchema.Implementation clientInfo; - /** - * Server capabilities. - */ - private McpSchema.ServerCapabilities serverCapabilities; - - /** - * Server instructions. - */ - private String serverInstructions; - - /** - * Server implementation information. - */ - private McpSchema.Implementation serverInfo; - /** * Roots define the boundaries of where servers can operate within the filesystem, * allowing them to understand which directories and files they have access to. @@ -155,13 +148,19 @@ public class McpAsyncClient { /** * Client transport implementation. */ - private final McpTransport transport; + private final McpClientTransport transport; /** * Supported protocol versions. */ private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + /** + * The MCP session supplier that manages bidirectional JSON-RPC communication between + * clients and servers. + */ + private final Supplier sessionSupplier; + /** * Create a new McpAsyncClient with the given transport and session request-response * timeout. @@ -254,8 +253,29 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); - this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); + this.transport.setExceptionHandler(this::handleException); + this.sessionSupplier = () -> new McpClientSession(requestTimeout, transport, requestHandlers, + notificationHandlers); + + } + + private void handleException(Throwable t) { + logger.warn("Handling exception", t); + if (t instanceof McpTransportSessionNotFoundException) { + Initialization previous = this.initializationRef.getAndSet(null); + if (previous != null) { + previous.close(); + } + // Providing an empty operation since we are only interested in triggering the + // implicit initialization step. + withSession("re-initializing", result -> Mono.empty()).subscribe(); + } + } + private McpSchema.InitializeResult currentInitializationResult() { + Initialization current = this.initializationRef.get(); + McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; + return initializeResult; } /** @@ -263,7 +283,8 @@ public class McpAsyncClient { * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); + return initializeResult != null ? initializeResult.capabilities() : null; } /** @@ -272,7 +293,8 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server instructions */ public String getServerInstructions() { - return this.serverInstructions; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); + return initializeResult != null ? initializeResult.instructions() : null; } /** @@ -280,7 +302,8 @@ public String getServerInstructions() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - return this.serverInfo; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); + return initializeResult != null ? initializeResult.serverInfo() : null; } /** @@ -288,7 +311,8 @@ public McpSchema.Implementation getServerInfo() { * @return true if the client-server connection is initialized */ public boolean isInitialized() { - return this.initialized.get(); + Initialization current = this.initializationRef.get(); + return current != null && (current.result.get() != null); } /** @@ -311,7 +335,11 @@ public McpSchema.Implementation getClientInfo() { * Closes the client connection immediately. */ public void close() { - this.mcpSession.close(); + Initialization current = this.initializationRef.getAndSet(null); + if (current != null) { + current.close(); + } + this.transport.close(); } /** @@ -319,14 +347,21 @@ public void close() { * @return A Mono that completes when the connection is closed */ public Mono closeGracefully() { - return this.mcpSession.closeGracefully(); + return Mono.defer(() -> { + Initialization current = this.initializationRef.getAndSet(null); + Mono sessionClose = current != null ? current.closeGracefully() : Mono.empty(); + return sessionClose.then(transport.closeGracefully()); + }); } // -------------------------- // Initialization // -------------------------- /** - * The initialization phase MUST be the first interaction between client and server. + * The initialization phase should be the first interaction between client and server. + * The client will ensure it happens in case it has not been explicitly called and in + * case of transport session invalidation. + *

* During this phase, the client and server: *