From e610d853f922e36ba474b2240f5c6546166e4840 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 20 Apr 2025 10:58:51 +0300 Subject: [PATCH 001/216] feat: Add customizable URI template manager factory to MCP server Implement URI template functionality for MCP resources, allowing dynamic resource URIs with variables in the format {variableName}. - Enable resource URIs with variable placeholders (e.g., "/api/users/{userId}") - Automatic extraction of variable values from request URIs - Validation of template arguments in completions - Matching of request URIs against templates - Add new URI template management interfaces and implementations - Enhanced resource template listing to include templated resources - Updated resource request handling to support template matching - Test coverage for URI template functionality - Adding a configurable uriTemplateManagerFactory field to both AsyncSpecification and SyncSpecification classes - Adding builder methods to allow setting a custom URI template manager factory - Modifying constructors to pass the URI template manager factory to the server implementation - Updating the server implementation to use the provided factory - Add bulk registration methods for async completions Signed-off-by: Christian Tzolov --- .../WebFluxSseIntegrationTests.java | 3 +- .../server/McpAsyncServer.java | 77 +++++++-- .../server/McpServer.java | 68 +++++++- .../DeafaultMcpUriTemplateManagerFactory.java | 23 +++ .../util/DefaultMcpUriTemplateManager.java | 163 ++++++++++++++++++ .../util/McpUriTemplateManager.java | 52 ++++++ .../util/McpUriTemplateManagerFactory.java | 22 +++ .../McpUriTemplateManagerTests.java | 97 +++++++++++ 8 files changed, 489 insertions(+), 16 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 660f814da..2ba047461 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -776,7 +776,8 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().completions().build()) .prompts(new McpServerFeatures.SyncPromptSpecification( - new Prompt("code_review", "this is code review prompt", List.of()), + new Prompt("code_review", "this is code review prompt", + List.of(new PromptArgument("language", "string", false))), (mcpSyncServerExchange, getPromptRequest) -> null)) .completions(new McpServerFeatures.SyncCompletionSpecification( new McpSchema.PromptReference("ref/prompt", "code_review"), completionHandler)) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 906cb9a08..3c112ad76 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.server; import java.time.Duration; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -22,10 +23,13 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -92,8 +96,10 @@ public class McpAsyncServer { * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization */ McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features, Duration requestTimeout) { - this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, requestTimeout, features); + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, requestTimeout, features, + uriTemplateManagerFactory); } /** @@ -274,8 +280,11 @@ private static class AsyncServerImpl extends McpAsyncServer { private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - Duration requestTimeout, McpServerFeatures.Async features) { + Duration requestTimeout, McpServerFeatures.Async features, + McpUriTemplateManagerFactory uriTemplateManagerFactory) { this.mcpTransportProvider = mcpTransportProvider; this.objectMapper = objectMapper; this.serverInfo = features.serverInfo(); @@ -286,6 +295,7 @@ private static class AsyncServerImpl extends McpAsyncServer { this.resourceTemplates.addAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; Map> requestHandlers = new HashMap<>(); @@ -564,8 +574,26 @@ private McpServerSession.RequestHandler resources private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { return (exchange, params) -> Mono - .just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); + .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); + + } + private List getResourceTemplates() { + var list = new ArrayList<>(this.resourceTemplates); + List resourceTemplates = this.resources.keySet() + .stream() + .filter(uri -> uri.contains("{")) + .map(uri -> { + var resource = this.resources.get(uri).resource(); + var template = new McpSchema.ResourceTemplate(resource.uri(), resource.name(), + resource.description(), resource.mimeType(), resource.annotations()); + return template; + }) + .toList(); + + list.addAll(resourceTemplates); + + return list; } private McpServerSession.RequestHandler resourcesReadRequestHandler() { @@ -574,11 +602,16 @@ private McpServerSession.RequestHandler resourcesR new TypeReference() { }); var resourceUri = resourceRequest.uri(); - McpServerFeatures.AsyncResourceSpecification specification = this.resources.get(resourceUri); - if (specification != null) { - return specification.readHandler().apply(exchange, resourceRequest); - } - return Mono.error(new McpError("Resource not found: " + resourceUri)); + + McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() + .stream() + .filter(resourceSpecification -> this.uriTemplateManagerFactory + .create(resourceSpecification.resource().uri()) + .matches(resourceUri)) + .findFirst() + .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + + return specification.readHandler().apply(exchange, resourceRequest); }; } @@ -729,20 +762,38 @@ private McpServerSession.RequestHandler completionComp String type = request.ref().type(); + String argumentName = request.argument().name(); + // check if the referenced resource exists if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { - McpServerFeatures.AsyncPromptSpecification prompt = this.prompts.get(promptReference.name()); - if (prompt == null) { + McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); + if (promptSpec == null) { return Mono.error(new McpError("Prompt not found: " + promptReference.name())); } + if (!promptSpec.prompt() + .arguments() + .stream() + .filter(arg -> arg.name().equals(argumentName)) + .findFirst() + .isPresent()) { + + return Mono.error(new McpError("Argument not found: " + argumentName)); + } } if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpServerFeatures.AsyncResourceSpecification resource = this.resources.get(resourceReference.uri()); - if (resource == null) { + McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources + .get(resourceReference.uri()); + if (resourceSpec == null) { return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); } + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + return Mono.error(new McpError("Argument not found: " + argumentName)); + } + } McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 84089703c..d6ec2cc30 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import reactor.core.publisher.Mono; /** @@ -156,6 +158,8 @@ class AsyncSpecification { private final McpServerTransportProvider transportProvider; + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + private ObjectMapper objectMapper; private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -204,6 +208,19 @@ private AsyncSpecification(McpServerTransportProvider transportProvider) { this.transportProvider = transportProvider; } + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + /** * Sets the duration to wait for server responses before timing out requests. This * timeout applies to all requests made through the client, including tool calls, @@ -517,6 +534,36 @@ public AsyncSpecification prompts(McpServerFeatures.AsyncPromptSpecification... return this; } + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public AsyncSpecification completions(List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public AsyncSpecification completions(McpServerFeatures.AsyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + /** * Registers a consumer that will be notified when the list of roots changes. This * is useful for updating resource availability dynamically, such as when new @@ -587,7 +634,8 @@ public McpAsyncServer build() { this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, this.instructions); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout); + return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, + this.uriTemplateManagerFactory); } } @@ -600,6 +648,8 @@ class SyncSpecification { private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", "1.0.0"); + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + private final McpServerTransportProvider transportProvider; private ObjectMapper objectMapper; @@ -650,6 +700,19 @@ private SyncSpecification(McpServerTransportProvider transportProvider) { this.transportProvider = transportProvider; } + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + /** * Sets the duration to wait for server responses before timing out requests. This * timeout applies to all requests made through the client, including tool calls, @@ -1064,7 +1127,8 @@ public McpSyncServer build() { this.rootsChangeHandlers, this.instructions); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout); + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, + this.uriTemplateManagerFactory); return new McpSyncServer(asyncServer); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java b/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java new file mode 100644 index 000000000..3870b76fc --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java @@ -0,0 +1,23 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.util; + +/** + * @author Christian Tzolov + */ +public class DeafaultMcpUriTemplateManagerFactory implements McpUriTemplateManagerFactory { + + /** + * Creates a new instance of {@link McpUriTemplateManager} with the specified URI + * template. + * @param uriTemplate The URI template to be used for variable extraction + * @return A new instance of {@link McpUriTemplateManager} + * @throws IllegalArgumentException if the URI template is null or empty + */ + @Override + public McpUriTemplateManager create(String uriTemplate) { + return new DefaultMcpUriTemplateManager(uriTemplate); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java b/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java new file mode 100644 index 000000000..b2e9a5285 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java @@ -0,0 +1,163 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Default implementation of the UriTemplateUtils interface. + *

+ * This class provides methods for extracting variables from URI templates and matching + * them against actual URIs. + * + * @author Christian Tzolov + */ +public class DefaultMcpUriTemplateManager implements McpUriTemplateManager { + + /** + * Pattern to match URI variables in the format {variableName}. + */ + private static final Pattern URI_VARIABLE_PATTERN = Pattern.compile("\\{([^/]+?)\\}"); + + private final String uriTemplate; + + /** + * Constructor for DefaultMcpUriTemplateManager. + * @param uriTemplate The URI template to be used for variable extraction + */ + public DefaultMcpUriTemplateManager(String uriTemplate) { + if (uriTemplate == null || uriTemplate.isEmpty()) { + throw new IllegalArgumentException("URI template must not be null or empty"); + } + this.uriTemplate = uriTemplate; + } + + /** + * Extract URI variable names from a URI template. + * @param uriTemplate The URI template containing variables in the format + * {variableName} + * @return A list of variable names extracted from the template + * @throws IllegalArgumentException if duplicate variable names are found + */ + @Override + public List getVariableNames() { + if (uriTemplate == null || uriTemplate.isEmpty()) { + return List.of(); + } + + List variables = new ArrayList<>(); + Matcher matcher = URI_VARIABLE_PATTERN.matcher(this.uriTemplate); + + while (matcher.find()) { + String variableName = matcher.group(1); + if (variables.contains(variableName)) { + throw new IllegalArgumentException("Duplicate URI variable name in template: " + variableName); + } + variables.add(variableName); + } + + return variables; + } + + /** + * Extract URI variable values from the actual request URI. + *

+ * This method converts the URI template into a regex pattern, then uses that pattern + * to extract variable values from the request URI. + * @param requestUri The actual URI from the request + * @return A map of variable names to their values + * @throws IllegalArgumentException if the URI template is invalid or the request URI + * doesn't match the template pattern + */ + @Override + public Map extractVariableValues(String requestUri) { + Map variableValues = new HashMap<>(); + List uriVariables = this.getVariableNames(); + + if (requestUri == null || uriVariables.isEmpty()) { + return variableValues; + } + + try { + // Create a regex pattern by replacing each {variableName} with a capturing + // group + StringBuilder patternBuilder = new StringBuilder("^"); + + // Find all variable placeholders and their positions + Matcher variableMatcher = URI_VARIABLE_PATTERN.matcher(uriTemplate); + int lastEnd = 0; + + while (variableMatcher.find()) { + // Add the text between the last variable and this one, escaped for regex + String textBefore = uriTemplate.substring(lastEnd, variableMatcher.start()); + patternBuilder.append(Pattern.quote(textBefore)); + + // Add a capturing group for the variable + patternBuilder.append("([^/]+)"); + + lastEnd = variableMatcher.end(); + } + + // Add any remaining text after the last variable + if (lastEnd < uriTemplate.length()) { + patternBuilder.append(Pattern.quote(uriTemplate.substring(lastEnd))); + } + + patternBuilder.append("$"); + + // Compile the pattern and match against the request URI + Pattern pattern = Pattern.compile(patternBuilder.toString()); + Matcher matcher = pattern.matcher(requestUri); + + if (matcher.find() && matcher.groupCount() == uriVariables.size()) { + for (int i = 0; i < uriVariables.size(); i++) { + String value = matcher.group(i + 1); + if (value == null || value.isEmpty()) { + throw new IllegalArgumentException( + "Empty value for URI variable '" + uriVariables.get(i) + "' in URI: " + requestUri); + } + variableValues.put(uriVariables.get(i), value); + } + } + } + catch (Exception e) { + throw new IllegalArgumentException("Error parsing URI template: " + uriTemplate + " for URI: " + requestUri, + e); + } + + return variableValues; + } + + /** + * Check if a URI matches the uriTemplate with variables. + * @param uri The URI to check + * @return true if the URI matches the pattern, false otherwise + */ + @Override + public boolean matches(String uri) { + // If the uriTemplate doesn't contain variables, do a direct comparison + if (!this.isUriTemplate(this.uriTemplate)) { + return uri.equals(this.uriTemplate); + } + + // Convert the pattern to a regex + String regex = this.uriTemplate.replaceAll("\\{[^/]+?\\}", "([^/]+?)"); + regex = regex.replace("/", "\\/"); + + // Check if the URI matches the regex + return Pattern.compile(regex).matcher(uri).matches(); + } + + @Override + public boolean isUriTemplate(String uri) { + return URI_VARIABLE_PATTERN.matcher(uri).find(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java new file mode 100644 index 000000000..19569e49f --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java @@ -0,0 +1,52 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.List; +import java.util.Map; + +/** + * Interface for working with URI templates. + *

+ * This interface provides methods for extracting variables from URI templates and + * matching them against actual URIs. + * + * @author Christian Tzolov + */ +public interface McpUriTemplateManager { + + /** + * Extract URI variable names from this URI template. + * @return A list of variable names extracted from the template + * @throws IllegalArgumentException if duplicate variable names are found + */ + List getVariableNames(); + + /** + * Extract URI variable values from the actual request URI. + *

+ * This method converts the URI template into a regex pattern, then uses that pattern + * to extract variable values from the request URI. + * @param uri The actual URI from the request + * @return A map of variable names to their values + * @throws IllegalArgumentException if the URI template is invalid or the request URI + * doesn't match the template pattern + */ + Map extractVariableValues(String uri); + + /** + * Indicate whether the given URI matches this template. + * @param uri the URI to match to + * @return {@code true} if it matches; {@code false} otherwise + */ + boolean matches(String uri); + + /** + * Check if the given URI is a URI template. + * @return Returns true if the URI contains variables in the format {variableName} + */ + public boolean isUriTemplate(String uri); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java new file mode 100644 index 000000000..9644f9a6c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java @@ -0,0 +1,22 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.util; + +/** + * Factory interface for creating instances of {@link McpUriTemplateManager}. + * + * @author Christian Tzolov + */ +public interface McpUriTemplateManagerFactory { + + /** + * Creates a new instance of {@link McpUriTemplateManager} with the specified URI + * template. + * @param uriTemplate The URI template to be used for variable extraction + * @return A new instance of {@link McpUriTemplateManager} + * @throws IllegalArgumentException if the URI template is null or empty + */ + McpUriTemplateManager create(String uriTemplate); + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java b/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java new file mode 100644 index 000000000..6f041daa6 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java @@ -0,0 +1,97 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManager; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests for {@link McpUriTemplateManager} and its implementations. + * + * @author Christian Tzolov + */ +public class McpUriTemplateManagerTests { + + private McpUriTemplateManagerFactory uriTemplateFactory; + + @BeforeEach + void setUp() { + this.uriTemplateFactory = new DeafaultMcpUriTemplateManagerFactory(); + } + + @Test + void shouldExtractVariableNamesFromTemplate() { + List variables = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}") + .getVariableNames(); + assertEquals(2, variables.size()); + assertEquals("userId", variables.get(0)); + assertEquals("postId", variables.get(1)); + } + + @Test + void shouldReturnEmptyListWhenTemplateHasNoVariables() { + List variables = this.uriTemplateFactory.create("/api/users/all").getVariableNames(); + assertEquals(0, variables.size()); + } + + @Test + void shouldThrowExceptionWhenExtractingVariablesFromNullTemplate() { + assertThrows(IllegalArgumentException.class, () -> this.uriTemplateFactory.create(null).getVariableNames()); + } + + @Test + void shouldThrowExceptionWhenExtractingVariablesFromEmptyTemplate() { + assertThrows(IllegalArgumentException.class, () -> this.uriTemplateFactory.create("").getVariableNames()); + } + + @Test + void shouldThrowExceptionWhenTemplateContainsDuplicateVariables() { + assertThrows(IllegalArgumentException.class, + () -> this.uriTemplateFactory.create("/api/users/{userId}/posts/{userId}").getVariableNames()); + } + + @Test + void shouldExtractVariableValuesFromRequestUri() { + Map values = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}") + .extractVariableValues("/api/users/123/posts/456"); + assertEquals(2, values.size()); + assertEquals("123", values.get("userId")); + assertEquals("456", values.get("postId")); + } + + @Test + void shouldReturnEmptyMapWhenTemplateHasNoVariables() { + Map values = this.uriTemplateFactory.create("/api/users/all") + .extractVariableValues("/api/users/all"); + assertEquals(0, values.size()); + } + + @Test + void shouldReturnEmptyMapWhenRequestUriIsNull() { + Map values = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}") + .extractVariableValues(null); + assertEquals(0, values.size()); + } + + @Test + void shouldMatchUriAgainstTemplatePattern() { + var uriTemplateManager = this.uriTemplateFactory.create("/api/users/{userId}/posts/{postId}"); + + assertTrue(uriTemplateManager.matches("/api/users/123/posts/456")); + assertFalse(uriTemplateManager.matches("/api/users/123/comments/456")); + } + +} From e34babbe56b730514d35191118d5e66bc9c51b9a Mon Sep 17 00:00:00 2001 From: jito Date: Thu, 8 May 2025 18:31:17 +0900 Subject: [PATCH 002/216] Add missing isInitialized method to McpSyncClient (#181) The isInitialized method is present in McpAsyncClient and needs to be mirrored in McpSyncClient. Signed-off-by: jitokim --- .../io/modelcontextprotocol/client/McpSyncClient.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index c91638a7e..a8fb979e1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -97,6 +97,14 @@ public McpSchema.Implementation getServerInfo() { return this.delegate.getServerInfo(); } + /** + * Check if the client-server connection is initialized. + * @return true if the client-server connection is initialized + */ + public boolean isInitialized() { + return this.delegate.isInitialized(); + } + /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities From eae3840e7d44932c60c131cb7a346b5367b788ff Mon Sep 17 00:00:00 2001 From: Dennis Kawurek Date: Fri, 25 Apr 2025 18:47:54 +0200 Subject: [PATCH 003/216] fix: Mockito inline mocking for Java 21+ (#207) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before this fix the execution of the maven surefire plugin with Java 21 logged warnings that mockito should be added as a java agent, because the self-attaching won't be supported in future java releases. In Java 24 the test just broke. This problem is solved by modifying the pom.xml of the parent and doing this changes: * Adding mockito as a java agent. * Removing the surefireArgLine from the properties. This can be added back when it's needed (for example when JaCoCo will be used). Furthermore, the pom.xml in the mcp-spring-* modules now have the byte-buddy dependency included, as the test would otherwise break when trying to mock McpSchema#CreateMessageRequest. Fixes #187 Co-authored-by: Dariusz Jędrzejczyk --- mcp-spring/mcp-spring-webflux/pom.xml | 6 ++++++ mcp-spring/mcp-spring-webmvc/pom.xml | 6 ++++++ pom.xml | 15 +++++++++++++-- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 63c32a8a8..86f46bf95 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -82,6 +82,12 @@ ${mockito.version} test + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + io.projectreactor reactor-test diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index b59be6a03..82fbbf3e6 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -77,6 +77,12 @@ ${mockito.version} test + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + org.testcontainers junit-jupiter diff --git a/pom.xml b/pom.xml index 9be256ccf..638457406 100644 --- a/pom.xml +++ b/pom.xml @@ -57,6 +57,7 @@ 17 17 17 + 3.26.3 5.10.2 @@ -163,13 +164,23 @@ + + org.apache.maven.plugins + maven-dependency-plugin + + + + properties + + + + org.apache.maven.plugins maven-surefire-plugin ${maven-surefire-plugin.version} - ${surefireArgLine} - + ${surefireArgLine} -javaagent:${org.mockito:mockito-core:jar} false false From 0069c977ef88b91162b08899bb8040a0ffcb8653 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Fri, 9 May 2025 12:57:36 +0200 Subject: [PATCH 004/216] Remove temporary delegate impl from McpAsyncServer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../server/McpAsyncServer.java | 1082 ++++++++--------- 1 file changed, 484 insertions(+), 598 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 3c112ad76..1efa13de3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -82,11 +82,33 @@ public class McpAsyncServer { private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); - private final McpAsyncServer delegate; + private final McpServerTransportProvider mcpTransportProvider; - McpAsyncServer() { - this.delegate = null; - } + private final ObjectMapper objectMapper; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final String instructions; + + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + // FIXME: this field is deprecated and should be remvoed together with the + // broadcasting loggingNotification. + private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + + private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); /** * Create a new McpAsyncServer with the given transport provider and capabilities. @@ -98,8 +120,104 @@ public class McpAsyncServer { McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, McpServerFeatures.Async features, Duration requestTimeout, McpUriTemplateManagerFactory uriTemplateManagerFactory) { - this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, requestTimeout, features, - uriTemplateManagerFactory); + this.mcpTransportProvider = mcpTransportProvider; + this.objectMapper = objectMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); + this.tools.addAll(features.tools()); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add logging API handlers if the logging capability is enabled + if (this.serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); + } + + // Add completion API handlers if the completion capability is enabled + if (this.serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); + } + + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + + List, Mono>> rootsChangeConsumers = features + .rootsChangeConsumers(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger + .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + + mcpTransportProvider.setSessionFactory( + transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private Mono asyncInitializeRequestHandler( + McpSchema.InitializeRequest initializeRequest) { + return Mono.defer(() -> { + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, this.instructions)); + }); } /** @@ -107,7 +225,7 @@ public class McpAsyncServer { * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - return this.delegate.getServerCapabilities(); + return this.serverCapabilities; } /** @@ -115,7 +233,7 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - return this.delegate.getServerInfo(); + return this.serverInfo; } /** @@ -123,26 +241,66 @@ public McpSchema.Implementation getServerInfo() { * @return A Mono that completes when the server has been closed */ public Mono closeGracefully() { - return this.delegate.closeGracefully(); + return this.mcpTransportProvider.closeGracefully(); } /** * Close the server immediately. */ public void close() { - this.delegate.close(); + this.mcpTransportProvider.close(); + } + + private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( + List, Mono>> rootsChangeConsumers) { + return (exchange, params) -> exchange.listRoots() + .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); } // --------------------------------------- // Tool Management // --------------------------------------- + /** * Add a new tool specification at runtime. * @param toolSpecification The tool specification to add * @return Mono that completes when clients have been notified of the change */ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { - return this.delegate.addTool(toolSpecification); + if (toolSpecification == null) { + return Mono.error(new McpError("Tool specification must not be null")); + } + if (toolSpecification.tool() == null) { + return Mono.error(new McpError("Tool must not be null")); + } + if (toolSpecification.call() == null) { + return Mono.error(new McpError("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + // Check for duplicate tool names + if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) { + return Mono + .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists")); + } + + this.tools.add(toolSpecification); + logger.debug("Added tool handler: {}", toolSpecification.tool().name()); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); } /** @@ -151,7 +309,25 @@ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecifica * @return Mono that completes when clients have been notified of the change */ public Mono removeTool(String toolName) { - return this.delegate.removeTool(toolName); + if (toolName == null) { + return Mono.error(new McpError("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + boolean removed = this.tools + .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); + if (removed) { + logger.debug("Removed tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + }); } /** @@ -159,19 +335,65 @@ public Mono removeTool(String toolName) { * @return A Mono that completes when all clients have been notified */ public Mono notifyToolsListChanged() { - return this.delegate.notifyToolsListChanged(); + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler toolsListRequestHandler() { + return (exchange, params) -> { + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private McpServerSession.RequestHandler toolsCallRequestHandler() { + return (exchange, params) -> { + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolSpecification.isEmpty()) { + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + } + + return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + }; } // --------------------------------------- // Resource Management // --------------------------------------- + /** * Add a new resource handler at runtime. - * @param resourceHandler The resource handler to add + * @param resourceSpecification The resource handler to add * @return Mono that completes when clients have been notified of the change */ - public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceHandler) { - return this.delegate.addResource(resourceHandler); + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { + if (resourceSpecification == null || resourceSpecification.resource() == null) { + return Mono.error(new McpError("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { + return Mono.error(new McpError( + "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); + } + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); } /** @@ -180,7 +402,24 @@ public Mono addResource(McpServerFeatures.AsyncResourceSpecification resou * @return Mono that completes when clients have been notified of the change */ public Mono removeResource(String resourceUri) { - return this.delegate.removeResource(resourceUri); + if (resourceUri == null) { + return Mono.error(new McpError("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + }); } /** @@ -188,19 +427,97 @@ public Mono removeResource(String resourceUri) { * @return A Mono that completes when all clients have been notified */ public Mono notifyResourcesListChanged() { - return this.delegate.notifyResourcesListChanged(); + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler resourcesListRequestHandler() { + return (exchange, params) -> { + var resourceList = this.resources.values() + .stream() + .map(McpServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { + return (exchange, params) -> Mono + .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); + + } + + private List getResourceTemplates() { + var list = new ArrayList<>(this.resourceTemplates); + List resourceTemplates = this.resources.keySet() + .stream() + .filter(uri -> uri.contains("{")) + .map(uri -> { + var resource = this.resources.get(uri).resource(); + var template = new McpSchema.ResourceTemplate(resource.uri(), resource.name(), resource.description(), + resource.mimeType(), resource.annotations()); + return template; + }) + .toList(); + + list.addAll(resourceTemplates); + + return list; + } + + private McpServerSession.RequestHandler resourcesReadRequestHandler() { + return (exchange, params) -> { + McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + var resourceUri = resourceRequest.uri(); + + McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() + .stream() + .filter(resourceSpecification -> this.uriTemplateManagerFactory + .create(resourceSpecification.resource().uri()) + .matches(resourceUri)) + .findFirst() + .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + + return specification.readHandler().apply(exchange, resourceRequest); + }; } // --------------------------------------- // Prompt Management // --------------------------------------- + /** * Add a new prompt handler at runtime. * @param promptSpecification The prompt handler to add * @return Mono that completes when clients have been notified of the change */ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { - return this.delegate.addPrompt(promptSpecification); + if (promptSpecification == null) { + return Mono.error(new McpError("Prompt specification must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification specification = this.prompts + .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); + if (specification != null) { + return Mono.error( + new McpError("Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); + } + + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); + + // Servers that declared the listChanged capability SHOULD send a + // notification, + // when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return notifyPromptsListChanged(); + } + return Mono.empty(); + }); } /** @@ -209,7 +526,27 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpe * @return Mono that completes when clients have been notified of the change */ public Mono removePrompt(String promptName) { - return this.delegate.removePrompt(promptName); + if (promptName == null) { + return Mono.error(new McpError("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + // Servers that declared the listChanged capability SHOULD send a + // notification, when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + }); } /** @@ -217,7 +554,39 @@ public Mono removePrompt(String promptName) { * @return A Mono that completes when all clients have been notified */ public Mono notifyPromptsListChanged() { - return this.delegate.notifyPromptsListChanged(); + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler promptsListRequestHandler() { + return (exchange, params) -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private McpServerSession.RequestHandler promptsGetRequestHandler() { + return (exchange, params) -> { + McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + // Implement prompt retrieval logic here + McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { + return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + } + + return specification.promptHandler().apply(exchange, promptRequest); + }; } // --------------------------------------- @@ -237,619 +606,136 @@ public Mono notifyPromptsListChanged() { */ @Deprecated public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - return this.delegate.loggingNotification(loggingMessageNotification); - } - - // --------------------------------------- - // Sampling - // --------------------------------------- - /** - * This method is package-private and used for test only. Should not be called by user - * code. - * @param protocolVersions the Client supported protocol versions. - */ - void setProtocolVersions(List protocolVersions) { - this.delegate.setProtocolVersions(protocolVersions); - } - - private static class AsyncServerImpl extends McpAsyncServer { - - private final McpServerTransportProvider mcpTransportProvider; - - private final ObjectMapper objectMapper; - - private final McpSchema.ServerCapabilities serverCapabilities; - - private final McpSchema.Implementation serverInfo; - - private final String instructions; - - private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); - - private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); - - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); - - private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); - - // FIXME: this field is deprecated and should be remvoed together with the - // broadcasting loggingNotification. - private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; - - private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); - - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); - - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); - - AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - Duration requestTimeout, McpServerFeatures.Async features, - McpUriTemplateManagerFactory uriTemplateManagerFactory) { - this.mcpTransportProvider = mcpTransportProvider; - this.objectMapper = objectMapper; - this.serverInfo = features.serverInfo(); - this.serverCapabilities = features.serverCapabilities(); - this.instructions = features.instructions(); - this.tools.addAll(features.tools()); - this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); - this.prompts.putAll(features.prompts()); - this.completions.putAll(features.completions()); - this.uriTemplateManagerFactory = uriTemplateManagerFactory; - - Map> requestHandlers = new HashMap<>(); - - // Initialize request handlers for standard MCP methods - - // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); - - // Add tools API handlers if the tool capability is enabled - if (this.serverCapabilities.tools() != null) { - requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); - } - - // Add resources API handlers if provided - if (this.serverCapabilities.resources() != null) { - requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); - } - - // Add prompts API handlers if provider exists - if (this.serverCapabilities.prompts() != null) { - requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); - } - - // Add logging API handlers if the logging capability is enabled - if (this.serverCapabilities.logging() != null) { - requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); - } - - // Add completion API handlers if the completion capability is enabled - if (this.serverCapabilities.completions() != null) { - requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); - } - - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); - - List, Mono>> rootsChangeConsumers = features - .rootsChangeConsumers(); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((exchange, - roots) -> Mono.fromRunnable(() -> logger.warn( - "Roots list changed notification, but no consumers provided. Roots list changed: {}", - roots))); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); - } - - // --------------------------------------- - // Lifecycle Management - // --------------------------------------- - private Mono asyncInitializeRequestHandler( - McpSchema.InitializeRequest initializeRequest) { - return Mono.defer(() -> { - logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", - initializeRequest.protocolVersion(), initializeRequest.capabilities(), - initializeRequest.clientInfo()); - - // The server MUST respond with the highest protocol version it supports - // if - // it does not support the requested (e.g. Client) version. - String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); - - if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { - // If the server supports the requested protocol version, it MUST - // respond - // with the same version. - serverProtocolVersion = initializeRequest.protocolVersion(); - } - else { - logger.warn( - "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", - initializeRequest.protocolVersion(), serverProtocolVersion); - } - - return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, - this.serverInfo, this.instructions)); - }); - } - - public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; - } - - public McpSchema.Implementation getServerInfo() { - return this.serverInfo; - } - @Override - public Mono closeGracefully() { - return this.mcpTransportProvider.closeGracefully(); + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); } - @Override - public void close() { - this.mcpTransportProvider.close(); + if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { + return Mono.empty(); } - private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( - List, Mono>> rootsChangeConsumers) { - return (exchange, params) -> exchange.listRoots() - .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); - } - - // --------------------------------------- - // Tool Management - // --------------------------------------- - - @Override - public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { - if (toolSpecification == null) { - return Mono.error(new McpError("Tool specification must not be null")); - } - if (toolSpecification.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); - } - if (toolSpecification.call() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) { - return Mono - .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists")); - } - - this.tools.add(toolSpecification); - logger.debug("Added tool handler: {}", toolSpecification.tool().name()); - - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - }); - } - - @Override - public Mono removeTool(String toolName) { - if (toolName == null) { - return Mono.error(new McpError("Tool name must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, + loggingMessageNotification); + } + private McpServerSession.RequestHandler setLoggerRequestHandler() { + return (exchange, params) -> { return Mono.defer(() -> { - boolean removed = this.tools - .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); - if (removed) { - logger.debug("Removed tool handler: {}", toolName); - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); - }); - } - - @Override - public Mono notifyToolsListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); - } - - private McpServerSession.RequestHandler toolsListRequestHandler() { - return (exchange, params) -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); - - return Mono.just(new McpSchema.ListToolsResult(tools, null)); - }; - } - private McpServerSession.RequestHandler toolsCallRequestHandler() { - return (exchange, params) -> { - McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, - new TypeReference() { + SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, + new TypeReference() { }); - Optional toolSpecification = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); + exchange.setMinLoggingLevel(newMinLoggingLevel.level()); - if (toolSpecification.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); - } + // FIXME: this field is deprecated and should be removed together + // with the broadcasting loggingNotification. + this.minLoggingLevel = newMinLoggingLevel.level(); - return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); - }; - } + return Mono.just(Map.of()); + }); + }; + } - // --------------------------------------- - // Resource Management - // --------------------------------------- + private McpServerSession.RequestHandler completionCompleteRequestHandler() { + return (exchange, params) -> { + McpSchema.CompleteRequest request = parseCompletionParams(params); - @Override - public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { - if (resourceSpecification == null || resourceSpecification.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); + if (request.ref() == null) { + return Mono.error(new McpError("ref must not be null")); } - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); + if (request.ref().type() == null) { + return Mono.error(new McpError("type must not be null")); } - return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { - return Mono.error(new McpError( - "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); - } - logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - }); - } + String type = request.ref().type(); - @Override - public Mono removeResource(String resourceUri) { - if (resourceUri == null) { - return Mono.error(new McpError("Resource URI must not be null")); - } - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } + String argumentName = request.argument().name(); - return Mono.defer(() -> { - McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); - if (removed != null) { - logger.debug("Removed resource handler: {}", resourceUri); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); + // check if the referenced resource exists + if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { + McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); + if (promptSpec == null) { + return Mono.error(new McpError("Prompt not found: " + promptReference.name())); } - return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); - }); - } - - @Override - public Mono notifyResourcesListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); - } - - private McpServerSession.RequestHandler resourcesListRequestHandler() { - return (exchange, params) -> { - var resourceList = this.resources.values() - .stream() - .map(McpServerFeatures.AsyncResourceSpecification::resource) - .toList(); - return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); - }; - } - - private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono - .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); - - } - - private List getResourceTemplates() { - var list = new ArrayList<>(this.resourceTemplates); - List resourceTemplates = this.resources.keySet() - .stream() - .filter(uri -> uri.contains("{")) - .map(uri -> { - var resource = this.resources.get(uri).resource(); - var template = new McpSchema.ResourceTemplate(resource.uri(), resource.name(), - resource.description(), resource.mimeType(), resource.annotations()); - return template; - }) - .toList(); - - list.addAll(resourceTemplates); - - return list; - } - - private McpServerSession.RequestHandler resourcesReadRequestHandler() { - return (exchange, params) -> { - McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, - new TypeReference() { - }); - var resourceUri = resourceRequest.uri(); - - McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() + if (!promptSpec.prompt() + .arguments() .stream() - .filter(resourceSpecification -> this.uriTemplateManagerFactory - .create(resourceSpecification.resource().uri()) - .matches(resourceUri)) + .filter(arg -> arg.name().equals(argumentName)) .findFirst() - .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); - - return specification.readHandler().apply(exchange, resourceRequest); - }; - } - - // --------------------------------------- - // Prompt Management - // --------------------------------------- + .isPresent()) { - @Override - public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { - if (promptSpecification == null) { - return Mono.error(new McpError("Prompt specification must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification specification = this.prompts - .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); - if (specification != null) { - return Mono.error(new McpError( - "Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); - } - - logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); - - // Servers that declared the listChanged capability SHOULD send a - // notification, - // when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return notifyPromptsListChanged(); + return Mono.error(new McpError("Argument not found: " + argumentName)); } - return Mono.empty(); - }); - } - - @Override - public Mono removePrompt(String promptName) { - if (promptName == null) { - return Mono.error(new McpError("Prompt name must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); } - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); - - if (removed != null) { - logger.debug("Removed prompt handler: {}", promptName); - // Servers that declared the listChanged capability SHOULD send a - // notification, when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return this.notifyPromptsListChanged(); - } - return Mono.empty(); + if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources.get(resourceReference.uri()); + if (resourceSpec == null) { + return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); } - return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); - }); - } - - @Override - public Mono notifyPromptsListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); - } - - private McpServerSession.RequestHandler promptsListRequestHandler() { - return (exchange, params) -> { - // TODO: Implement pagination - // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, - // new TypeReference() { - // }); - - var promptList = this.prompts.values() - .stream() - .map(McpServerFeatures.AsyncPromptSpecification::prompt) - .toList(); - - return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); - }; - } - - private McpServerSession.RequestHandler promptsGetRequestHandler() { - return (exchange, params) -> { - McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, - new TypeReference() { - }); - - // Implement prompt retrieval logic here - McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); - if (specification == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + return Mono.error(new McpError("Argument not found: " + argumentName)); } - return specification.promptHandler().apply(exchange, promptRequest); - }; - } - - // --------------------------------------- - // Logging Management - // --------------------------------------- - - @Override - public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - - if (loggingMessageNotification == null) { - return Mono.error(new McpError("Logging message must not be null")); - } - - if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { - return Mono.empty(); } - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, - loggingMessageNotification); - } - - private McpServerSession.RequestHandler setLoggerRequestHandler() { - return (exchange, params) -> { - return Mono.defer(() -> { - - SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, - new TypeReference() { - }); - - exchange.setMinLoggingLevel(newMinLoggingLevel.level()); - - // FIXME: this field is deprecated and should be removed together - // with the broadcasting loggingNotification. - this.minLoggingLevel = newMinLoggingLevel.level(); - - return Mono.just(Map.of()); - }); - }; - } - - private McpServerSession.RequestHandler completionCompleteRequestHandler() { - return (exchange, params) -> { - McpSchema.CompleteRequest request = parseCompletionParams(params); - - if (request.ref() == null) { - return Mono.error(new McpError("ref must not be null")); - } - - if (request.ref().type() == null) { - return Mono.error(new McpError("type must not be null")); - } - - String type = request.ref().type(); - - String argumentName = request.argument().name(); - - // check if the referenced resource exists - if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { - McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); - if (promptSpec == null) { - return Mono.error(new McpError("Prompt not found: " + promptReference.name())); - } - if (!promptSpec.prompt() - .arguments() - .stream() - .filter(arg -> arg.name().equals(argumentName)) - .findFirst() - .isPresent()) { - - return Mono.error(new McpError("Argument not found: " + argumentName)); - } - } - - if (type.equals("ref/resource") - && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources - .get(resourceReference.uri()); - if (resourceSpec == null) { - return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); - } - if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) - .getVariableNames() - .contains(argumentName)) { - return Mono.error(new McpError("Argument not found: " + argumentName)); - } + McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); - } - - McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); - - if (specification == null) { - return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); - } - - return specification.completionHandler().apply(exchange, request); - }; - } - - /** - * Parses the raw JSON-RPC request parameters into a - * {@link McpSchema.CompleteRequest} object. - *

- * This method manually extracts the `ref` and `argument` fields from the input - * map, determines the correct reference type (either prompt or resource), and - * constructs a fully-typed {@code CompleteRequest} instance. - * @param object the raw request parameters, expected to be a Map containing "ref" - * and "argument" entries. - * @return a {@link McpSchema.CompleteRequest} representing the structured - * completion request. - * @throws IllegalArgumentException if the "ref" type is not recognized. - */ - @SuppressWarnings("unchecked") - private McpSchema.CompleteRequest parseCompletionParams(Object object) { - Map params = (Map) object; - Map refMap = (Map) params.get("ref"); - Map argMap = (Map) params.get("argument"); - - String refType = (String) refMap.get("type"); - - McpSchema.CompleteReference ref = switch (refType) { - case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name")); - case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); - default -> throw new IllegalArgumentException("Invalid ref type: " + refType); - }; - - String argName = (String) argMap.get("name"); - String argValue = (String) argMap.get("value"); - McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument( - argName, argValue); - - return new McpSchema.CompleteRequest(ref, argument); - } + if (specification == null) { + return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); + } - // --------------------------------------- - // Sampling - // --------------------------------------- + return specification.completionHandler().apply(exchange, request); + }; + } - @Override - void setProtocolVersions(List protocolVersions) { - this.protocolVersions = protocolVersions; - } + /** + * Parses the raw JSON-RPC request parameters into a {@link McpSchema.CompleteRequest} + * object. + *

+ * This method manually extracts the `ref` and `argument` fields from the input map, + * determines the correct reference type (either prompt or resource), and constructs a + * fully-typed {@code CompleteRequest} instance. + * @param object the raw request parameters, expected to be a Map containing "ref" and + * "argument" entries. + * @return a {@link McpSchema.CompleteRequest} representing the structured completion + * request. + * @throws IllegalArgumentException if the "ref" type is not recognized. + */ + @SuppressWarnings("unchecked") + private McpSchema.CompleteRequest parseCompletionParams(Object object) { + Map params = (Map) object; + Map refMap = (Map) params.get("ref"); + Map argMap = (Map) params.get("argument"); + + String refType = (String) refMap.get("type"); + + McpSchema.CompleteReference ref = switch (refType) { + case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name")); + case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + default -> throw new IllegalArgumentException("Invalid ref type: " + refType); + }; + + String argName = (String) argMap.get("name"); + String argValue = (String) argMap.get("value"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(argName, + argValue); + + return new McpSchema.CompleteRequest(ref, argument); + } + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; } } From b2d3e0098e484e172719237b0933fa395cdfdf4b Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Mon, 12 May 2025 15:04:05 +0200 Subject: [PATCH 005/216] Next development version Signed-off-by: Christian Tzolov --- mcp-bom/pom.xml | 2 +- mcp-spring/mcp-spring-webflux/pom.xml | 6 +++--- mcp-spring/mcp-spring-webmvc/pom.xml | 6 +++--- mcp-test/pom.xml | 4 ++-- mcp/pom.xml | 2 +- pom.xml | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index 4f24f719f..7214dacda 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -7,7 +7,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT mcp-bom diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 86f46bf95..a8b92bd09 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT ../../pom.xml mcp-spring-webflux @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT test diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index 82fbbf3e6..48d1c3465 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT ../../pom.xml mcp-spring-webmvc @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT test diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index f1484ae77..a6e5bdb08 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT mcp-test jar @@ -24,7 +24,7 @@ io.modelcontextprotocol.sdk mcp - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT diff --git a/mcp/pom.xml b/mcp/pom.xml index 17693ab32..773432827 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT mcp jar diff --git a/pom.xml b/pom.xml index 638457406..c2327ee8d 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.10.0-SNAPSHOT + 0.11.0-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk From f34662555a0ab68d74ac118f1b0220441b2c81b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 14 May 2025 15:38:02 +0200 Subject: [PATCH 006/216] Fix stdio tests - proper server-everything argument (#237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../modelcontextprotocol/client/StdioMcpAsyncClientTests.java | 4 ++-- .../modelcontextprotocol/client/StdioMcpSyncClientTests.java | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index c39080138..8c0069d6d 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -25,12 +25,12 @@ protected McpClientTransport createMcpTransport() { ServerParameters stdioParams; if (System.getProperty("os.name").toLowerCase().contains("win")) { stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "dir") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } else { stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") + .args("-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } return new StdioClientTransport(stdioParams); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 8e75c4a3d..706aa9b2e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -33,12 +33,12 @@ protected McpClientTransport createMcpTransport() { ServerParameters stdioParams; if (System.getProperty("os.name").toLowerCase().contains("win")) { stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "dir") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } else { stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "dir") + .args("-y", "@modelcontextprotocol/server-everything", "stdio") .build(); } return new StdioClientTransport(stdioParams); From 2e13f9f9df8610e0d05cc76b1416fe195e249303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 14 May 2025 22:46:54 +0200 Subject: [PATCH 007/216] Fix flaky WebFluxSse integration test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../WebFluxSseIntegrationTests.java | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 2ba047461..03fbc9962 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -8,6 +8,8 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; @@ -651,9 +653,11 @@ void testInitialize(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testLoggingNotification(String clientType) { + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); // Create a list to store received logging notifications - List receivedNotifications = new ArrayList<>(); + List receivedNotifications = new CopyOnWriteArrayList<>(); var clientBuilder = clientBuilders.get(clientType); @@ -709,6 +713,7 @@ void testLoggingNotification(String clientType) { // Create client with logging notification handler var mcpClient = clientBuilder.loggingConsumer(notification -> { receivedNotifications.add(notification); + latch.countDown(); }).build()) { // Initialize client @@ -724,31 +729,28 @@ void testLoggingNotification(String clientType) { assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); - // Wait for notifications to be processed - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); - // Should have received 3 notifications (1 NOTICE and 2 ERROR) - assertThat(receivedNotifications).hasSize(3); + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.data(), n -> n)); + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); - // First notification should be NOTICE level - assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); - // Second notification should be ERROR level - assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); - // Third notification should be ERROR level - assertThat(notificationMap.get("Another error message").level()) - .isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); - }); + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); } mcpServer.close(); } From 1adfa8a047852c8f9e0188b4e63fe2020e0c66c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 14 May 2025 14:05:39 +0200 Subject: [PATCH 008/216] Add Contributing Guidelines and Code of Conduct MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- CODE_OF_CONDUCT.md | 119 +++++++++++++++++++++++++++++++++++++++++++++ CONTRIBUTING.md | 91 ++++++++++++++++++++++++++++++++++ README.md | 7 +-- SECURITY.md | 21 ++++++++ 4 files changed, 233 insertions(+), 5 deletions(-) create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md create mode 100644 SECURITY.md diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..6009a645f --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,119 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our community a +harassment-free experience for everyone, regardless of age, body size, visible or +invisible disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal appearance, +race, religion, or sexual identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, diverse, +inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our community +include: + +- Demonstrating empathy and kindness toward other people +- Being respectful of differing opinions, viewpoints, and experiences +- Giving and gracefully accepting constructive feedback +- Accepting responsibility and apologizing to those affected by our mistakes, and + learning from the experience +- Focusing on what is best not just for us as individuals, but for the overall community + +Examples of unacceptable behavior include: + +- The use of sexualized language or imagery, and sexual attention or advances of any kind +- Trolling, insulting or derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or email address, without + their explicit permission +- Other conduct which could reasonably be considered inappropriate in a professional + setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in response to +any behavior that they deem inappropriate, threatening, offensive, or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject comments, +commits, code, wiki edits, issues, and other contributions that are not aligned to this +Code of Conduct, and will communicate reasons for moderation decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when an +individual is officially representing the community in public spaces. Examples of +representing our community include using an official e-mail address, posting via an +official social media account, or acting as an appointed representative at an online or +offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to +the community leaders responsible for enforcement at mcp-coc@anthropic.com. All +complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the reporter +of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining the +consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing clarity +around the nature of the violation and an explanation of why the behavior was +inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of actions. + +**Consequence**: A warning with consequences for continued behavior. No interaction with +the people involved, including unsolicited interaction with those enforcing the Code of +Conduct, for a specified period of time. This includes avoiding interactions in community +spaces as well as external channels like social media. Violating these terms may lead to +a temporary or permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including sustained +inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public communication +with the community for a specified period of time. No public or private interaction with +the people involved, including unsolicited interaction with those enforcing the Code of +Conduct, is allowed during this period. Violating these terms may lead to a permanent +ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community standards, +including sustained inappropriate behavior, harassment of an individual, or aggression +toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, +available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..a949dcc09 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,91 @@ +# Contributing to Model Context Protocol Java SDK + +Thank you for your interest in contributing to the Model Context Protocol Java SDK! +This document outlines how to contribute to this project. + +## Prerequisites + +The following software is required to work on the codebase: + +- `Java 17` or above +- `Docker` +- `npx` + +## Getting Started + +1. Fork the repository +2. Clone your fork: + +```bash +git clone https://github.com/YOUR-USERNAME/java-sdk.git +cd java-sdk +``` + +3. Build from source: + +```bash +./mvnw clean install -DskipTests # skip the tests +./mvnw test # run tests +``` + +## Reporting Issues + +Please create an issue in the repository if you discover a bug or would like to +propose an enhancement. Bug reports should have a reproducer in the form of a code +sample or a repository attached that the maintainers or contributors can work with to +address the problem. + +## Making Changes + +1. Create a new branch: + +```bash +git checkout -b feature/your-feature-name +``` + +2. Make your changes +3. Validate your changes: + +```bash +./mvnw clean test +``` + +### Change Proposal Guidelines + +#### Principles of MCP + +1. **Simple + Minimal**: It is much easier to add things to the codebase than it is to + remove them. To maintain simplicity, we keep a high bar for adding new concepts and + primitives as each addition requires maintenance and compatibility consideration. +2. **Concrete**: Code changes need to be based on specific usage and implementation + challenges and not on speculative ideas. Most importantly, the SDK is meant to + implement the MCP specification. + +## Submitting Changes + +1. For non-trivial changes, please clarify with the maintainers in an issue whether + you can contribute the change and the desired scope of the change. +2. For trivial changes (for example a couple of lines or documentation changes) there + is no need to open an issue first. +3. Push your changes to your fork. +4. Submit a pull request to the main repository. +5. Follow the pull request template. +6. Wait for review. + +## Code of Conduct + +This project follows a Code of Conduct. Please review it in +[CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md). + +## Questions + +If you have questions, please create a discussion in the repository. + +## License + +By contributing, you agree that your contributions will be licensed under the MIT +License. + +## Security + +Please review our [Security Policy](SECURITY.md) for reporting security issues. \ No newline at end of file diff --git a/README.md b/README.md index 9fc17306e..0cd3f84a4 100644 --- a/README.md +++ b/README.md @@ -30,11 +30,8 @@ To run the tests you have to pre-install `Docker` and `npx`. ## Contributing -Contributions are welcome! Please: - -1. Fork the repository -2. Create a feature branch -3. Submit a Pull Request +Contributions are welcome! +Please follow the [Contributing Guidelines](CONTRIBUTING.md). ## Team diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..74e9880fd --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,21 @@ +# Security Policy + +Thank you for helping us keep the SDKs and systems they interact with secure. + +## Reporting Security Issues + +This SDK is maintained by [Anthropic](https://www.anthropic.com/) as part of the Model +Context Protocol project. + +The security of our systems and user data is Anthropic’s top priority. We appreciate the +work of security researchers acting in good faith in identifying and reporting potential +vulnerabilities. + +Our security program is managed on HackerOne and we ask that any validated vulnerability +in this functionality be reported through their +[submission form](https://hackerone.com/anthropic-vdp/reports/new?type=team&report_type=vulnerability). + +## Vulnerability Disclosure Program + +Our Vulnerability Program Guidelines are defined on our +[HackerOne program page](https://hackerone.com/anthropic-vdp). \ No newline at end of file From 07e7b8fd6bac47be4527f97451f8cdd95ed31a38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 14 May 2025 18:00:06 +0200 Subject: [PATCH 009/216] Add note about force pushes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- CONTRIBUTING.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a949dcc09..517f32555 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -71,6 +71,9 @@ git checkout -b feature/your-feature-name 4. Submit a pull request to the main repository. 5. Follow the pull request template. 6. Wait for review. +7. For any follow-up work, please add new commits instead of force-pushing. This will + allow the reviewer to focus on incremental changes instead of having to restart the + review process. ## Code of Conduct From 8a5a591d39256ba3947003ec4477e1722363eb35 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 27 May 2025 15:26:44 -0700 Subject: [PATCH 010/216] feat: Add elicitation support to MCP protocol Implement elicitation capabilities allowing servers to request additional information from users through clients during interactions. This feature provides a standardized way for servers to gather necessary information dynamically while clients maintain control over user interactions and data sharing. - Add ElicitRequest and ElicitResult classes to McpSchema - Implement elicitation handlers in client classes - Add elicitation capabilities to server exchange classes - Add tests for elicitation functionality with various scenarios --- .../WebFluxSseIntegrationTests.java | 224 +++++++++++++++++- .../server/WebMvcSseIntegrationTests.java | 213 +++++++++++++++++ .../client/McpAsyncClient.java | 32 +++ .../client/McpClient.java | 40 +++- .../client/McpClientFeatures.java | 31 ++- .../server/McpAsyncServerExchange.java | 28 +++ .../server/McpSyncServerExchange.java | 18 ++ .../modelcontextprotocol/spec/McpSchema.java | 129 ++++++++-- .../client/AbstractMcpAsyncClientTests.java | 22 +- .../McpAsyncClientResponseHandlerTests.java | 150 ++++++++++++ ...rverTransportProviderIntegrationTests.java | 213 +++++++++++++++++ .../spec/McpSchemaTests.java | 34 +++ 12 files changed, 1106 insertions(+), 28 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 03fbc9962..2f85654e8 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol; import java.time.Duration; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -28,11 +27,11 @@ import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.*; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.CompletionCapabilities; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -41,6 +40,7 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -331,6 +331,226 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt mcpServer.closeGracefully().block(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index b12d68439..3f3f7be62 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -357,6 +357,219 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { mcpServer.close(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @Test + void testCreateElicitationWithoutElicitationCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(McpSchema.ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @Test + void testCreateElicitationSuccess() { + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutSuccess() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutFail() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index e3a997ba3..a22ef6b51 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -23,6 +23,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; @@ -141,6 +143,15 @@ public class McpAsyncClient { */ private Function> samplingHandler; + /** + * MCP provides a standardized way for servers to request additional information from + * users through the client during interactions. This flow allows clients to maintain + * control over user interactions and data sharing while enabling servers to gather + * necessary information dynamically. Servers can request structured data from users + * with optional JSON schemas to validate responses. + */ + private Function> elicitationHandler; + /** * Client transport implementation. */ @@ -189,6 +200,15 @@ public class McpAsyncClient { requestHandlers.put(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, samplingCreateMessageHandler()); } + // Elicitation Handler + if (this.clientCapabilities.elicitation() != null) { + if (features.elicitationHandler() == null) { + throw new McpError("Elicitation handler must not be null when client capabilities include elicitation"); + } + this.elicitationHandler = features.elicitationHandler(); + requestHandlers.put(McpSchema.METHOD_ELICITATION_CREATE, elicitationCreateHandler()); + } + // Notification Handlers Map notificationHandlers = new HashMap<>(); @@ -500,6 +520,18 @@ private RequestHandler samplingCreateMessageHandler() { }; } + // -------------------------- + // Elicitation + // -------------------------- + private RequestHandler elicitationCreateHandler() { + return params -> { + ElicitRequest request = transport.unmarshalFrom(params, new TypeReference<>() { + }); + + return this.elicitationHandler.apply(request); + }; + } + // -------------------------- // Tools // -------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index a1dc11685..280906cff 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -18,6 +18,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.Implementation; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.util.Assert; @@ -175,6 +177,8 @@ class SyncSpec { private Function samplingHandler; + private Function elicitationHandler; + private SyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -283,6 +287,21 @@ public SyncSpec sampling(Function sam return this; } + /** + * Sets a custom elicitation handler for processing elicitation message requests. + * The elicitation handler can modify or validate messages before they are sent to + * the server, enabling custom processing logic. + * @param elicitationHandler A function that processes elicitation requests and + * returns results. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if elicitationHandler is null + */ + public SyncSpec elicitation(Function elicitationHandler) { + Assert.notNull(elicitationHandler, "Elicitation handler must not be null"); + this.elicitationHandler = elicitationHandler; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -364,7 +383,7 @@ public SyncSpec loggingConsumers(List> samplingHandler; + private Function> elicitationHandler; + private AsyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -522,6 +543,21 @@ public AsyncSpec sampling(Function> elicitationHandler) { + Assert.notNull(elicitationHandler, "Elicitation handler must not be null"); + this.elicitationHandler = elicitationHandler; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -606,7 +642,7 @@ public McpAsyncClient build() { return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers, - this.loggingConsumers, this.samplingHandler)); + this.loggingConsumers, this.samplingHandler, this.elicitationHandler)); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index 284b93f88..23d7c6a60 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -60,13 +60,15 @@ class McpClientFeatures { * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List, Mono>> toolsChangeConsumers, List, Mono>> resourcesChangeConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, - Function> samplingHandler) { + Function> samplingHandler, + Function> elicitationHandler) { /** * Create an instance and validate the arguments. @@ -77,6 +79,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, @@ -84,14 +87,16 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List, Mono>> resourcesChangeConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, - Function> samplingHandler) { + Function> samplingHandler, + Function> elicitationHandler) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, - samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null); + samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); this.roots = roots != null ? new ConcurrentHashMap<>(roots) : new ConcurrentHashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); @@ -99,6 +104,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); this.samplingHandler = samplingHandler; + this.elicitationHandler = elicitationHandler; } /** @@ -138,9 +144,14 @@ public static Async fromSync(Sync syncSpec) { Function> samplingHandler = r -> Mono .fromCallable(() -> syncSpec.samplingHandler().apply(r)) .subscribeOn(Schedulers.boundedElastic()); + + Function> elicitationHandler = r -> Mono + .fromCallable(() -> syncSpec.elicitationHandler().apply(r)) + .subscribeOn(Schedulers.boundedElastic()); + return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), toolsChangeConsumers, resourcesChangeConsumers, promptsChangeConsumers, loggingConsumers, - samplingHandler); + samplingHandler, elicitationHandler); } } @@ -156,13 +167,15 @@ public static Async fromSync(Sync syncSpec) { * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, List>> resourcesChangeConsumers, List>> promptsChangeConsumers, List> loggingConsumers, - Function samplingHandler) { + Function samplingHandler, + Function elicitationHandler) { /** * Create an instance and validate the arguments. @@ -174,20 +187,23 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. * @param samplingHandler the sampling handler. + * @param elicitationHandler the elicitation handler. */ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, List>> resourcesChangeConsumers, List>> promptsChangeConsumers, List> loggingConsumers, - Function samplingHandler) { + Function samplingHandler, + Function elicitationHandler) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; this.clientCapabilities = (clientCapabilities != null) ? clientCapabilities : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, - samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null); + samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); this.roots = roots != null ? new HashMap<>(roots) : new HashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); @@ -195,6 +211,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); this.samplingHandler = samplingHandler; + this.elicitationHandler = elicitationHandler; } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 889dc66d0..cfb07d26c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -36,6 +36,9 @@ public class McpAsyncServerExchange { private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { }; + private static final TypeReference ELICITATION_RESULT_TYPE_REF = new TypeReference<>() { + }; + /** * Create a new asynchronous exchange with the client. * @param session The server session representing a 1-1 interaction. @@ -93,6 +96,31 @@ public Mono createMessage(McpSchema.CreateMessage CREATE_MESSAGE_RESULT_TYPE_REF); } + /** + * Creates a new elicitation. MCP provides a standardized way for servers to request + * additional information from users through the client during interactions. This flow + * allows clients to maintain control over user interactions and data sharing while + * enabling servers to gather necessary information dynamically. Servers can request + * structured data from users with optional JSON schemas to validate responses. + * @param elicitRequest The request to create a new elicitation + * @return A Mono that completes when the elicitation has been resolved. + * @see McpSchema.ElicitRequest + * @see McpSchema.ElicitResult + * @see Elicitation + * Specification + */ + public Mono createElicitation(McpSchema.ElicitRequest elicitRequest) { + if (this.clientCapabilities == null) { + return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.elicitation() == null) { + return Mono.error(new McpError("Client must be configured with elicitation capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, + ELICITATION_RESULT_TYPE_REF); + } + /** * Retrieves the list of all roots provided by the client. * @return A Mono that emits the list of roots result. diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 52360e54b..084412b96 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -64,6 +64,24 @@ public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageReques return this.exchange.createMessage(createMessageRequest).block(); } + /** + * Creates a new elicitation. MCP provides a standardized way for servers to request + * additional information from users through the client during interactions. This flow + * allows clients to maintain control over user interactions and data sharing while + * enabling servers to gather necessary information dynamically. Servers can request + * structured data from users with optional JSON schemas to validate responses. + * @param elicitRequest The request to create a new elicitation + * @return A result containing the elicitation response. + * @see McpSchema.ElicitRequest + * @see McpSchema.ElicitResult + * @see Elicitation + * Specification + */ + public McpSchema.ElicitResult createElicitation(McpSchema.ElicitRequest elicitRequest) { + return this.exchange.createElicitation(elicitRequest).block(); + } + /** * Retrieves the list of all roots provided by the client. * @return The list of roots result. diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 8df8a1584..9dae08266 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -94,6 +94,9 @@ private McpSchema() { // Sampling Methods public static final String METHOD_SAMPLING_CREATE_MESSAGE = "sampling/createMessage"; + // Elicitation Methods + public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // --------------------------- @@ -131,8 +134,8 @@ public static final class ErrorCodes { } - public sealed interface Request - permits InitializeRequest, CallToolRequest, CreateMessageRequest, CompleteRequest, GetPromptRequest { + public sealed interface Request permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, + CompleteRequest, GetPromptRequest { } @@ -221,7 +224,7 @@ public record JSONRPCError( public record InitializeRequest( // @formatter:off @JsonProperty("protocolVersion") String protocolVersion, @JsonProperty("capabilities") ClientCapabilities capabilities, - @JsonProperty("clientInfo") Implementation clientInfo) implements Request { + @JsonProperty("clientInfo") Implementation clientInfo) implements Request { } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -245,6 +248,8 @@ public record InitializeResult( // @formatter:off * access to. * @param sampling Provides a standardized way for servers to request LLM sampling * (“completions” or “generations”) from language models via clients. + * @param elicitation Provides a standardized way for servers to request additional + * information from users through the client during interactions. * */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -252,7 +257,8 @@ public record InitializeResult( // @formatter:off public record ClientCapabilities( // @formatter:off @JsonProperty("experimental") Map experimental, @JsonProperty("roots") RootCapabilities roots, - @JsonProperty("sampling") Sampling sampling) { + @JsonProperty("sampling") Sampling sampling, + @JsonProperty("elicitation") Elicitation elicitation) { /** * Roots define the boundaries of where servers can operate within the filesystem, @@ -264,7 +270,7 @@ public record ClientCapabilities( // @formatter:off * has changed since the last time the server checked. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) + @JsonIgnoreProperties(ignoreUnknown = true) public record RootCapabilities( @JsonProperty("listChanged") Boolean listChanged) { } @@ -279,10 +285,22 @@ public record RootCapabilities( * image-based interactions and optionally include context * from MCP servers in their prompts. */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record Sampling() { } + /** + * Provides a standardized way for servers to request additional + * information from users through the client during interactions. + * This flow allows clients to maintain control over user + * interactions and data sharing while enabling servers to gather + * necessary information dynamically. Servers can request structured + * data from users with optional JSON schemas to validate responses. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Elicitation() { + } + public static Builder builder() { return new Builder(); } @@ -291,6 +309,7 @@ public static class Builder { private Map experimental; private RootCapabilities roots; private Sampling sampling; + private Elicitation elicitation; public Builder experimental(Map experimental) { this.experimental = experimental; @@ -307,8 +326,13 @@ public Builder sampling() { return this; } + public Builder elicitation() { + this.elicitation = new Elicitation(); + return this; + } + public ClientCapabilities build() { - return new ClientCapabilities(experimental, roots, sampling); + return new ClientCapabilities(experimental, roots, sampling, elicitation); } } }// @formatter:on @@ -326,11 +350,11 @@ public record ServerCapabilities( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) public record CompletionCapabilities() { } - + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record LoggingCapabilities() { } - + @JsonInclude(JsonInclude.Include.NON_ABSENT) public record PromptCapabilities( @JsonProperty("listChanged") Boolean listChanged) { @@ -727,11 +751,11 @@ public record Tool( // @formatter:off @JsonProperty("name") String name, @JsonProperty("description") String description, @JsonProperty("inputSchema") JsonSchema inputSchema) { - + public Tool(String name, String description, String schema) { this(name, description, parseSchema(schema)); } - + } // @formatter:on private static JsonSchema parseSchema(String schema) { @@ -758,7 +782,7 @@ public record CallToolRequest(// @formatter:off @JsonProperty("arguments") Map arguments) implements Request { public CallToolRequest(String name, String jsonArguments) { - this(name, parseJsonArguments(jsonArguments)); + this(name, parseJsonArguments(jsonArguments)); } private static Map parseJsonArguments(String jsonArguments) { @@ -893,7 +917,7 @@ public record ModelPreferences(// @formatter:off @JsonProperty("costPriority") Double costPriority, @JsonProperty("speedPriority") Double speedPriority, @JsonProperty("intelligencePriority") Double intelligencePriority) { - + public static Builder builder() { return new Builder(); } @@ -963,7 +987,7 @@ public record CreateMessageRequest(// @formatter:off @JsonProperty("includeContext") ContextInclusionStrategy includeContext, @JsonProperty("temperature") Double temperature, @JsonProperty("maxTokens") int maxTokens, - @JsonProperty("stopSequences") List stopSequences, + @JsonProperty("stopSequences") List stopSequences, @JsonProperty("metadata") Map metadata) implements Request { public enum ContextInclusionStrategy { @@ -971,7 +995,7 @@ public enum ContextInclusionStrategy { @JsonProperty("thisServer") THIS_SERVER, @JsonProperty("allServers") ALL_SERVERS } - + public static Builder builder() { return new Builder(); } @@ -1040,7 +1064,7 @@ public record CreateMessageResult(// @formatter:off @JsonProperty("content") Content content, @JsonProperty("model") String model, @JsonProperty("stopReason") StopReason stopReason) { - + public enum StopReason { @JsonProperty("endTurn") END_TURN, @JsonProperty("stopSequence") STOP_SEQUENCE, @@ -1088,6 +1112,79 @@ public CreateMessageResult build() { } }// @formatter:on + // Elicitation + /** + * Used by the server to send an elicitation to the client. + * + * @param message The body of the elicitation message. + * @param requestedSchema The elicitation response schema that must be satisfied. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitRequest(// @formatter:off + @JsonProperty("message") String message, + @JsonProperty("requestedSchema") Map requestedSchema) implements Request { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String message; + private Map requestedSchema; + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder requestedSchema(Map requestedSchema) { + this.requestedSchema = requestedSchema; + return this; + } + + public ElicitRequest build() { + return new ElicitRequest(message, requestedSchema); + } + } + }// @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitResult(// @formatter:off + @JsonProperty("action") Action action, + @JsonProperty("content") Map content) { + + public enum Action { + @JsonProperty("accept") ACCEPT, + @JsonProperty("decline") DECLINE, + @JsonProperty("cancel") CANCEL + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private Action action; + private Map content; + + public Builder message(Action action) { + this.action = action; + return this; + } + + public Builder content(Map content) { + this.content = content; + return this; + } + + public ElicitResult build() { + return new ElicitResult(action, content); + } + } + }// @formatter:on + // --------------------------- // Pagination Interfaces // --------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 72b409af9..d1a2581ee 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.Prompt; import io.modelcontextprotocol.spec.McpSchema.Resource; @@ -422,6 +424,20 @@ void testInitializeWithSamplingCapability() { }); } + @Test + void testInitializeWithElicitationCapability() { + ClientCapabilities capabilities = ClientCapabilities.builder().elicitation().build(); + ElicitResult elicitResult = ElicitResult.builder() + .message(ElicitResult.Action.ACCEPT) + .content(Map.of("foo", "bar")) + .build(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).elicitation(request -> Mono.just(elicitResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); + } + @Test void testInitializeWithAllCapabilities() { var capabilities = ClientCapabilities.builder() @@ -433,7 +449,11 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + Function> elicitationHandler = request -> Mono + .just(ElicitResult.builder().message(ElicitResult.Action.ACCEPT).content(Map.of("foo", "bar")).build()); + + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(samplingHandler).elicitation(elicitationHandler), client -> StepVerifier.create(client.initialize()).assertNext(result -> { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index 4510b1529..e6cde8e3b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.Root; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import reactor.core.publisher.Mono; import static io.modelcontextprotocol.spec.McpSchema.METHOD_INITIALIZE; @@ -349,4 +351,152 @@ void testSamplingCreateMessageRequestHandlingWithNullHandler() { .hasMessage("Sampling handler must not be null when client capabilities include sampling"); } + @Test + @SuppressWarnings("unchecked") + void testElicitationCreateRequestHandling() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create a test elicitation handler that echoes back the input + Function> elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isInstanceOf(Map.class); + assertThat(request.requestedSchema().get("type")).isEqualTo("object"); + + var properties = request.requestedSchema().get("properties"); + assertThat(properties).isNotNull(); + assertThat(((Map) properties).get("message")).isInstanceOf(Map.class); + + return Mono.just(McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(Map.of("message", request.message())) + .build()); + }; + + // Create client with elicitation capability and handler + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.error()).isNull(); + + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + }); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content()).isEqualTo(Map.of("message", "Test message")); + + asyncMcpClient.closeGracefully(); + } + + @ParameterizedTest + @EnumSource(value = McpSchema.ElicitResult.Action.class, names = { "DECLINE", "CANCEL" }) + void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create a test elicitation handler to decline the request + Function> elicitationHandler = request -> Mono + .just(McpSchema.ElicitResult.builder().message(action).build()); + + // Create client with elicitation capability and handler + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.error()).isNull(); + + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + }); + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(action); + assertThat(result.content()).isNull(); + + asyncMcpClient.closeGracefully(); + } + + @Test + void testElicitationCreateRequestHandlingWithoutCapability() { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create client without elicitation capability + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .capabilities(ClientCapabilities.builder().build()) // No elicitation + // capability + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Create a mock elicitation + var elicitRequest = new McpSchema.ElicitRequest("test", + Map.of("type", "object", "properties", Map.of("test", Map.of("type", "boolean", "defaultValue", true, + "description", "test-description", "title", "test-title")))); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + transport.simulateIncomingMessage(request); + + // Verify error response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.result()).isNull(); + assertThat(response.error()).isNotNull(); + assertThat(response.error().message()).contains("Method not found: elicitation/create"); + + asyncMcpClient.closeGracefully(); + } + + @Test + void testElicitationCreateRequestHandlingWithNullHandler() { + MockMcpClientTransport transport = new MockMcpClientTransport(); + + // Create client with elicitation capability but null handler + assertThatThrownBy(() -> McpClient.async(transport) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .build()).isInstanceOf(McpError.class) + .hasMessage("Elicitation handler must not be null when client capabilities include elicitation"); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 2ff6325a4..dc9d1cfab 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -24,6 +24,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; import io.modelcontextprotocol.spec.McpSchema.Role; @@ -339,6 +341,217 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { mcpServer.close(); } + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @Test + @Disabled + void testCreateElicitationWithoutElicitationCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createElicitation(mock(ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @Test + void testCreateElicitationSuccess() { + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutSuccess() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @Test + void testCreateElicitationWithRequestTimeoutFail() { + + // Client + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index ff78c1bfc..99015d8c4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -807,6 +807,40 @@ void testCreateMessageResult() throws Exception { {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"endTurn"}""")); } + // Elicitation Tests + + @Test + void testCreateElicitationRequest() throws Exception { + McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() + .requestedSchema(Map.of("type", "object", "required", List.of("a"), "properties", + Map.of("foo", Map.of("type", "string")))) + .build(); + + String value = mapper.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"requestedSchema":{"properties":{"foo":{"type":"string"}},"required":["a"],"type":"object"}}""")); + } + + @Test + void testCreateElicitationResult() throws Exception { + McpSchema.ElicitResult result = McpSchema.ElicitResult.builder() + .content(Map.of("foo", "bar")) + .message(McpSchema.ElicitResult.Action.ACCEPT) + .build(); + + String value = mapper.writeValueAsString(result); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"action":"accept","content":{"foo":"bar"}}""")); + } + // Roots Tests @Test From 2f944349cc77009d020ebddc8b9967b8e1b14baf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Tue, 10 Jun 2025 18:33:40 +0200 Subject: [PATCH 011/216] feat: WebClient Streamable HTTP support (#292) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit An implementation of Streamable HTTP Client with WebFlux WebClient. Aside from implementing the specification, several improvements have been incorporated throughout the client-side of the architecture. The changes cover: - resilience tests using toxiproxy in testcontainers - integration tests using updated everything-server with streamableHttp support - improved logging - session invalidation handling (both transport session and JSON-RPC concept of session) - implicit initialization and burst protection (in case of concurrent `Mcp(Sync|Async)Client` use - more logging, e.g. stdio process lifecycle logs Related #72, #273, #253, #107, #105 Signed-off-by: Dariusz Jędrzejczyk --- mcp-spring/mcp-spring-webflux/pom.xml | 6 + .../WebClientStreamableHttpTransport.java | 520 ++++++++++++++++++ .../transport/WebFluxSseClientTransport.java | 3 + ...eamableHttpAsyncClientResiliencyTests.java | 17 + ...bClientStreamableHttpAsyncClientTests.java | 42 ++ ...ebClientStreamableHttpSyncClientTests.java | 41 ++ .../client/WebFluxSseMcpAsyncClientTests.java | 3 +- .../client/WebFluxSseMcpSyncClientTests.java | 3 +- .../WebFluxSseClientTransportTests.java | 3 +- .../src/test/resources/logback.xml | 8 +- mcp-test/pom.xml | 5 + ...AbstractMcpAsyncClientResiliencyTests.java | 222 ++++++++ .../client/AbstractMcpAsyncClientTests.java | 37 +- .../client/AbstractMcpSyncClientTests.java | 60 +- .../client/McpAsyncClient.java | 291 ++++++---- .../transport/StdioClientTransport.java | 5 + .../spec/DefaultMcpTransportSession.java | 79 +++ .../spec/DefaultMcpTransportStream.java | 74 +++ .../spec/McpClientSession.java | 53 +- .../spec/McpClientTransport.java | 22 +- .../modelcontextprotocol/spec/McpSchema.java | 6 + .../spec/McpTransportSession.java | 60 ++ .../McpTransportSessionNotFoundException.java | 29 + .../spec/McpTransportStream.java | 45 ++ .../client/AbstractMcpAsyncClientTests.java | 37 +- .../client/AbstractMcpSyncClientTests.java | 58 +- .../client/HttpSseMcpAsyncClientTests.java | 3 +- .../client/HttpSseMcpSyncClientTests.java | 3 +- .../client/StdioMcpSyncClientTests.java | 2 +- .../HttpClientSseClientTransportTests.java | 3 +- pom.xml | 3 +- 31 files changed, 1501 insertions(+), 242 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java create mode 100644 mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index a8b92bd09..26452fe95 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -99,6 +99,12 @@ ${testcontainers.version} test + + org.testcontainers + toxiproxy + ${toxiproxy.version} + test + org.awaitility diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java new file mode 100644 index 000000000..e7b7c8ee9 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -0,0 +1,520 @@ +package io.modelcontextprotocol.client.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.DefaultMcpTransportSession; +import io.modelcontextprotocol.spec.DefaultMcpTransportStream; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.McpTransportSession; +import io.modelcontextprotocol.spec.McpTransportStream; +import io.modelcontextprotocol.util.Assert; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.client.WebClientResponseException; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +/** + * An implementation of the Streamable HTTP protocol as defined by the + * 2025-03-26 version of the MCP specification. + * + *

+ * The transport is capable of resumability and reconnects. It reacts to transport-level + * session invalidation and will propagate {@link McpTransportSessionNotFoundException + * appropriate exceptions} to the higher level abstraction layer when needed in order to + * allow proper state management. The implementation handles servers that are stateful and + * provide session meta information, but can also communicate with stateless servers that + * do not provide a session identifier and do not support SSE streams. + *

+ *

+ * This implementation does not handle backwards compatibility with the "HTTP + * with SSE" transport. In order to communicate over the phased-out + * 2024-11-05 protocol, use {@link HttpClientSseClientTransport} or + * {@link WebFluxSseClientTransport}. + *

+ * + * @author Dariusz Jędrzejczyk + * @see Streamable + * HTTP transport specification + */ +public class WebClientStreamableHttpTransport implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class); + + private static final String DEFAULT_ENDPOINT = "/mcp"; + + /** + * Event type for JSON-RPC messages received through the SSE connection. The server + * sends messages with this event type to transmit JSON-RPC protocol data. + */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + private static final ParameterizedTypeReference> PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference<>() { + }; + + private final ObjectMapper objectMapper; + + private final WebClient webClient; + + private final String endpoint; + + private final boolean openConnectionOnStartup; + + private final boolean resumableStreams; + + private final AtomicReference activeSession = new AtomicReference<>(); + + private final AtomicReference, Mono>> handler = new AtomicReference<>(); + + private final AtomicReference> exceptionHandler = new AtomicReference<>(); + + private WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder, + String endpoint, boolean resumableStreams, boolean openConnectionOnStartup) { + this.objectMapper = objectMapper; + this.webClient = webClientBuilder.build(); + this.endpoint = endpoint; + this.resumableStreams = resumableStreams; + this.openConnectionOnStartup = openConnectionOnStartup; + this.activeSession.set(createTransportSession()); + } + + /** + * Create a stateful builder for creating {@link WebClientStreamableHttpTransport} + * instances. + * @param webClientBuilder the {@link WebClient.Builder} to use + * @return a builder which will create an instance of + * {@link WebClientStreamableHttpTransport} once {@link Builder#build()} is called + */ + public static Builder builder(WebClient.Builder webClientBuilder) { + return new Builder(webClientBuilder); + } + + @Override + public Mono connect(Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler.set(handler); + if (openConnectionOnStartup) { + logger.debug("Eagerly opening connection on startup"); + return this.reconnect(null).then(); + } + return Mono.empty(); + }); + } + + private DefaultMcpTransportSession createTransportSession() { + Supplier> onClose = () -> { + DefaultMcpTransportSession transportSession = this.activeSession.get(); + return transportSession.sessionId().isEmpty() ? Mono.empty() + : webClient.delete().uri(this.endpoint).headers(httpHeaders -> { + httpHeaders.add("mcp-session-id", transportSession.sessionId().get()); + }).retrieve().toBodilessEntity().doOnError(e -> logger.info("Got response {}", e)).then(); + }; + return new DefaultMcpTransportSession(onClose); + } + + @Override + public void setExceptionHandler(Consumer handler) { + logger.debug("Exception handler registered"); + this.exceptionHandler.set(handler); + } + + private void handleException(Throwable t) { + logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); + if (t instanceof McpTransportSessionNotFoundException) { + McpTransportSession invalidSession = this.activeSession.getAndSet(createTransportSession()); + logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); + invalidSession.close(); + } + Consumer handler = this.exceptionHandler.get(); + if (handler != null) { + handler.accept(t); + } + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + logger.debug("Graceful close triggered"); + DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(createTransportSession()); + if (currentSession != null) { + return currentSession.closeGracefully(); + } + return Mono.empty(); + }); + } + + private Mono reconnect(McpTransportStream stream) { + return Mono.deferContextual(ctx -> { + if (stream != null) { + logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); + } + else { + logger.debug("Reconnecting with no prior stream"); + } + // Here we attempt to initialize the client. In case the server supports SSE, + // we will establish a long-running + // session here and listen for messages. If it doesn't, that's ok, the server + // is a simple, stateless one. + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + Disposable connection = webClient.get() + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM) + .headers(httpHeaders -> { + transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + if (stream != null) { + stream.lastId().ifPresent(id -> httpHeaders.add("last-event-id", id)); + } + }) + .exchangeToFlux(response -> { + if (isEventStream(response)) { + return eventStream(stream, response); + } + else if (isNotAllowed(response)) { + logger.debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (isNotFound(response)) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + return mcpSessionNotFoundError(sessionIdRepresentation); + } + else { + return response.createError().doOnError(e -> { + logger.info("Opening an SSE stream failed. This can be safely ignored.", e); + }).flux(); + } + }) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(ctx) + .subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + return Mono.just(connection); + }); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.create(sink -> { + logger.debug("Sending message {}", message); + // Here we attempt to initialize the client. + // In case the server supports SSE, we will establish a long-running session + // here and + // listen for messages. + // If it doesn't, nothing actually happens here, that's just the way it is... + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + Disposable connection = webClient.post() + .uri(this.endpoint) + .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON) + .headers(httpHeaders -> { + transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + }) + .bodyValue(message) + .exchangeToFlux(response -> { + if (transportSession + .markInitialized(response.headers().asHttpHeaders().getFirst("mcp-session-id"))) { + // Once we have a session, we try to open an async stream for + // the server to send notifications and requests out-of-band. + reconnect(null).contextWrite(sink.contextView()).subscribe(); + } + + String sessionRepresentation = sessionIdOrPlaceholder(transportSession); + + // The spec mentions only ACCEPTED, but the existing SDKs can return + // 200 OK for notifications + if (response.statusCode().is2xxSuccessful()) { + Optional contentType = response.headers().contentType(); + // Existing SDKs consume notifications with no response body nor + // content type + if (contentType.isEmpty()) { + logger.trace("Message was successfully sent via POST for session {}", + sessionRepresentation); + // signal the caller that the message was successfully + // delivered + sink.success(); + // communicate to downstream there is no streamed data coming + return Flux.empty(); + } + else { + MediaType mediaType = contentType.get(); + if (mediaType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { + // communicate to caller that the message was delivered + sink.success(); + // starting a stream + return newEventStream(response, sessionRepresentation); + } + else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { + logger.trace("Received response to POST for session {}", sessionRepresentation); + // communicate to caller the message was delivered + sink.success(); + return responseFlux(response); + } + else { + logger.warn("Unknown media type {} returned for POST in session {}", contentType, + sessionRepresentation); + return Flux.error(new RuntimeException("Unknown media type returned: " + contentType)); + } + } + } + else { + if (isNotFound(response)) { + return mcpSessionNotFoundError(sessionRepresentation); + } + return extractError(response, sessionRepresentation); + } + }) + .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) + .onErrorResume(t -> { + // handle the error first + this.handleException(t); + // inform the caller of sendMessage + sink.error(t); + return Flux.empty(); + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(sink.contextView()) + .subscribe(); + disposableRef.set(connection); + transportSession.addConnection(connection); + }); + } + + private static Flux mcpSessionNotFoundError(String sessionRepresentation) { + logger.warn("Session {} was not found on the MCP server", sessionRepresentation); + // inform the stream/connection subscriber + return Flux.error(new McpTransportSessionNotFoundException(sessionRepresentation)); + } + + private Flux extractError(ClientResponse response, String sessionRepresentation) { + return response.createError().onErrorResume(e -> { + WebClientResponseException responseException = (WebClientResponseException) e; + byte[] body = responseException.getResponseBodyAsByteArray(); + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = null; + Exception toPropagate; + try { + McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body, + McpSchema.JSONRPCResponse.class); + jsonRpcError = jsonRpcResponse.error(); + toPropagate = new McpError(jsonRpcError); + } + catch (IOException ex) { + toPropagate = new RuntimeException("Sending request failed", e); + logger.debug("Received content together with {} HTTP code response: {}", response.statusCode(), body); + } + + // Some implementations can return 400 when presented with a + // session id that it doesn't know about, so we will + // invalidate the session + // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 + if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { + return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate)); + } + return Mono.empty(); + }).flux(); + } + + private Flux eventStream(McpTransportStream stream, ClientResponse response) { + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + var idWithMessages = response.bodyToFlux(PARAMETERIZED_TYPE_REF).map(this::parse); + return Flux.from(sessionStream.consumeSseStream(idWithMessages)); + } + + private static boolean isNotFound(ClientResponse response) { + return response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND); + } + + private static boolean isNotAllowed(ClientResponse response) { + return response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED); + } + + private static boolean isEventStream(ClientResponse response) { + return response.statusCode().is2xxSuccessful() && response.headers().contentType().isPresent() + && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM); + } + + private static String sessionIdOrPlaceholder(McpTransportSession transportSession) { + return transportSession.sessionId().orElse("[missing_session_id]"); + } + + private Flux responseFlux(ClientResponse response) { + return response.bodyToMono(String.class).>handle((responseMessage, s) -> { + try { + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, + responseMessage); + s.next(List.of(jsonRpcResponse)); + } + catch (IOException e) { + s.error(e); + } + }).flatMapIterable(Function.identity()); + } + + private Flux newEventStream(ClientResponse response, String sessionRepresentation) { + McpTransportStream sessionStream = new DefaultMcpTransportStream<>(this.resumableStreams, + this::reconnect); + logger.trace("Sent POST and opened a stream ({}) for session {}", sessionStream.streamId(), + sessionRepresentation); + return eventStream(sessionStream, response); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + + private Tuple2, Iterable> parse(ServerSentEvent event) { + if (MESSAGE_EVENT_TYPE.equals(event.event())) { + try { + // We don't support batching ATM and probably won't since the next version + // considers removing it. + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); + return Tuples.of(Optional.ofNullable(event.id()), List.of(message)); + } + catch (IOException ioException) { + throw new McpError("Error parsing JSON-RPC message: " + event.data()); + } + } + else { + throw new McpError("Received unrecognized SSE event type: " + event.event()); + } + } + + /** + * Builder for {@link WebClientStreamableHttpTransport}. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private WebClient.Builder webClientBuilder; + + private String endpoint = DEFAULT_ENDPOINT; + + private boolean resumableStreams = true; + + private boolean openConnectionOnStartup = false; + + private Builder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + } + + /** + * Configure the {@link ObjectMapper} to use. + * @param objectMapper instance to use + * @return the builder instance + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Configure the {@link WebClient.Builder} to construct the {@link WebClient}. + * @param webClientBuilder instance to use + * @return the builder instance + */ + public Builder webClientBuilder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); + this.webClientBuilder = webClientBuilder; + return this; + } + + /** + * Configure the endpoint to make HTTP requests against. + * @param endpoint endpoint to use + * @return the builder instance + */ + public Builder endpoint(String endpoint) { + Assert.hasText(endpoint, "endpoint must be a non-empty String"); + this.endpoint = endpoint; + return this; + } + + /** + * Configure whether to use the stream resumability feature by keeping track of + * SSE event ids. + * @param resumableStreams if {@code true} event ids will be tracked and upon + * disconnection, the last seen id will be used upon reconnection as a header to + * resume consuming messages. + * @return the builder instance + */ + public Builder resumableStreams(boolean resumableStreams) { + this.resumableStreams = resumableStreams; + return this; + } + + /** + * Configure whether the client should open an SSE connection upon startup. Not + * all servers support this (although it is in theory possible with the current + * specification), so use with caution. By default, this value is {@code false}. + * @param openConnectionOnStartup if {@code true} the {@link #connect(Function)} + * method call will try to open an SSE connection before sending any JSON-RPC + * request + * @return the builder instance + */ + public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { + this.openConnectionOnStartup = openConnectionOnStartup; + return this; + } + + /** + * Construct a fresh instance of {@link WebClientStreamableHttpTransport} using + * the current builder configuration. + * @return a new instance of {@link WebClientStreamableHttpTransport} + */ + public WebClientStreamableHttpTransport build() { + ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + + return new WebClientStreamableHttpTransport(objectMapper, this.webClientBuilder, endpoint, resumableStreams, + openConnectionOnStartup); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 37abe295b..128cda4c3 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -190,6 +190,9 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMappe */ @Override public Mono connect(Function, Mono> handler) { + // TODO: Avoid eager connection opening and enable resilience + // -> upon disconnects, re-establish connection + // -> allow optimizing for eager connection start using a constructor flag Flux> events = eventStream(); this.inboundSubscription = events.concatMap(event -> Mono.just(event).handle((e, s) -> { if (ENDPOINT_EVENT_TYPE.equals(event.event())) { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..80fc671e2 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java @@ -0,0 +1,17 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; + +@Timeout(15) +public class WebClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java new file mode 100644 index 000000000..4c8032659 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java @@ -0,0 +1,42 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.images.builder.ImageFromDockerfile; + +@Timeout(15) +public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java new file mode 100644 index 000000000..a8cad4898 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java @@ -0,0 +1,41 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +@Timeout(15) +public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index b43c14493..f0533cb49 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -26,7 +26,8 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 66ac8a6dd..9b0959a35 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -26,7 +26,8 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index c757d3da9..42b91d14e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -41,7 +41,8 @@ class WebFluxSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); diff --git a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml index 5ad73374a..abc831d13 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml @@ -9,13 +9,13 @@ - + - - + + - + diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index a6e5bdb08..9998569dc 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -68,6 +68,11 @@ junit-jupiter ${testcontainers.version} + + org.testcontainers + toxiproxy + ${toxiproxy.version} + org.awaitility diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..85d6a88e4 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -0,0 +1,222 @@ +package io.modelcontextprotocol.client; + +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransport; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Resiliency test suite for the {@link McpAsyncClient} that can be used with different + * {@link McpTransport} implementations that support Streamable HTTP. + * + * The purpose of these tests is to allow validating the transport layer resiliency + * instead of the functionality offered by the logical layer of MCP concepts such as + * tools, resources, prompts, etc. + * + * @author Dariusz Jędrzejczyk + */ +public abstract class AbstractMcpAsyncClientResiliencyTests { + + private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class); + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + private static void disconnect() { + long start = System.nanoTime(); + try { + // disconnect + // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM", + // ToxicDirection.DOWNSTREAM, 0); + // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM", + // ToxicDirection.UPSTREAM, 0); + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + private static void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + + private static void restartMcpServer() { + container.stop(); + container.start(); + } + + abstract McpClientTransport createMcpTransport(); + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionInvalidation() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + restartMcpServer(); + + // The first try will face the session mismatch exception and the second one + // will go through the re-initialization process. + StepVerifier.create(mcpAsyncClient.ping().retry(1)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + AtomicReference> tools = new AtomicReference<>(); + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + StepVerifier.create(mcpAsyncClient.listTools()) + .consumeNextWith(list -> tools.set(list.tools())) + .verifyComplete(); + + disconnect(); + + String name = tools.get().get(0).name(); + // Assuming this is the echo tool + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(name, Map.of("message", "hello")); + StepVerifier.create(mcpAsyncClient.callTool(request)).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.callTool(request)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionClose() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + // In case of Streamable HTTP this call should issue a HTTP DELETE request + // invalidating the session + StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); + // The next use should immediately re-initialize with no issue and send the + // request without any broken connections. + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 5452c8eac..049bea008 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -110,14 +110,16 @@ void tearDown() { onClose(); } - void verifyInitializationTimeout(Function> operation, String action) { + void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, + String action) { withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) - .expectSubscription() - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); + StepVerifier.create(operation.apply(mcpAsyncClient)).verifyComplete(); + }); + } + + void verifyCallSucceedsWithImplicitInitialization(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(operation.apply(mcpAsyncClient)).expectNextCount(1).verifyComplete(); }); } @@ -133,7 +135,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -153,7 +155,7 @@ void testListTools() { @Test void testPingWithoutInitialization() { - verifyInitializationTimeout(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -168,7 +170,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -202,7 +204,7 @@ void testCallToolWithInvalidTool() { @Test void testListResourcesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -233,7 +235,7 @@ void testMcpAsyncClientState() { @Test void testListPromptsWithoutInitialization() { - verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listPrompts(null), "listing " + "prompts"); } @Test @@ -258,7 +260,7 @@ void testListPrompts() { @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts"); + verifyCallSucceedsWithImplicitInitialization(client -> client.getPrompt(request), "getting " + "prompts"); } @Test @@ -279,7 +281,7 @@ void testGetPrompt() { @Test void testRootsListChangedWithoutInitialization() { - verifyInitializationTimeout(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -354,7 +356,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(), + "listing resource templates"); } @Test @@ -447,8 +450,8 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 128441f80..3785fd645 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -12,7 +13,6 @@ import java.util.function.Function; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -112,33 +112,18 @@ void tearDown() { static final Object DUMMY_RETURN_VALUE = new Object(); - void verifyNotificationTimesOut(Consumer operation, String action) { - verifyCallTimesOut(client -> { + void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { + verifyCallSucceedsWithImplicitInitialization(client -> { operation.accept(client); return DUMMY_RETURN_VALUE; }, action); } - void verifyCallTimesOut(Function blockingOperation, String action) { + void verifyCallSucceedsWithImplicitInitialization(Function blockingOperation, String action) { withClient(createMcpTransport(), mcpSyncClient -> { - // This scheduler is not replaced by virtual time scheduler - Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); - - StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) - // Offload the blocking call to the real scheduler - .subscribeOn(customScheduler)) - .expectSubscription() - // This works without actually waiting but executes all the - // tasks pending execution on the VirtualTimeScheduler. - // It is possible to execute the blocking code from the operation - // because it is blocked on a dedicated Scheduler and the main - // flow is not blocked and uses the VirtualTimeScheduler. - .thenAwait(getInitializationTimeout()) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before " + action)) - .verify(); - - customScheduler.dispose(); + StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient))) + .expectNextCount(1) + .verifyComplete(); }); } @@ -154,7 +139,7 @@ void testConstructorWithInvalidArguments() { @Test void testListToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.listTools(null), "listing tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(null), "listing tools"); } @Test @@ -175,8 +160,8 @@ void testListTools() { @Test void testCallToolsWithoutInitialization() { - verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), - "calling tools"); + verifyCallSucceedsWithImplicitInitialization( + client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), "calling tools"); } @Test @@ -200,7 +185,7 @@ void testCallTools() { @Test void testPingWithoutInitialization() { - verifyCallTimesOut(client -> client.ping(), "pinging the server"); + verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); } @Test @@ -214,7 +199,7 @@ void testPing() { @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); + verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); } @Test @@ -243,7 +228,7 @@ void testCallToolWithInvalidTool() { @Test void testRootsListChangedWithoutInitialization() { - verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), "sending roots list changed notification"); } @@ -257,7 +242,7 @@ void testRootsListChanged() { @Test void testListResourcesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResources(null), "listing resources"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(null), "listing resources"); } @Test @@ -333,8 +318,14 @@ void testRemoveNonExistentRoot() { @Test void testReadResourceWithoutInitialization() { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); + AtomicReference> resources = new AtomicReference<>(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + resources.set(mcpSyncClient.listResources().resources()); + }); + + verifyCallSucceedsWithImplicitInitialization(client -> client.readResource(resources.get().get(0)), + "reading resources"); } @Test @@ -355,7 +346,8 @@ void testReadResource() { @Test void testListResourceTemplatesWithoutInitialization() { - verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); + verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(null), + "listing resource templates"); } @Test @@ -413,8 +405,8 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), - "setting logging level"); + verifyNotificationSucceedsWithImplicitInitialization( + client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); } @Test diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index a22ef6b51..8f0433eb1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -9,9 +9,9 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import java.util.function.Supplier; import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpClientSession; @@ -32,7 +32,7 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; @@ -77,29 +77,37 @@ * @see McpClient * @see McpSchema * @see McpClientSession + * @see McpClientTransport */ public class McpAsyncClient { private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class); - private static TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { + private static final TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { }; - protected final Sinks.One initializedSink = Sinks.one(); + public static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference PAGINATED_REQUEST_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference INITIALIZE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeReference<>() { + }; + + public static final TypeReference LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() { + }; - private AtomicBoolean initialized = new AtomicBoolean(false); + private final AtomicReference initializationRef = new AtomicReference<>(); /** * The max timeout to await for the client-server connection to be initialized. */ private final Duration initializationTimeout; - /** - * The MCP session implementation that manages bidirectional JSON-RPC communication - * between clients and servers. - */ - private final McpClientSession mcpSession; - /** * Client capabilities. */ @@ -110,21 +118,6 @@ public class McpAsyncClient { */ private final McpSchema.Implementation clientInfo; - /** - * Server capabilities. - */ - private McpSchema.ServerCapabilities serverCapabilities; - - /** - * Server instructions. - */ - private String serverInstructions; - - /** - * Server implementation information. - */ - private McpSchema.Implementation serverInfo; - /** * Roots define the boundaries of where servers can operate within the filesystem, * allowing them to understand which directories and files they have access to. @@ -155,13 +148,19 @@ public class McpAsyncClient { /** * Client transport implementation. */ - private final McpTransport transport; + private final McpClientTransport transport; /** * Supported protocol versions. */ private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + /** + * The MCP session supplier that manages bidirectional JSON-RPC communication between + * clients and servers. + */ + private final Supplier sessionSupplier; + /** * Create a new McpAsyncClient with the given transport and session request-response * timeout. @@ -254,8 +253,29 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); - this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); + this.transport.setExceptionHandler(this::handleException); + this.sessionSupplier = () -> new McpClientSession(requestTimeout, transport, requestHandlers, + notificationHandlers); + + } + + private void handleException(Throwable t) { + logger.warn("Handling exception", t); + if (t instanceof McpTransportSessionNotFoundException) { + Initialization previous = this.initializationRef.getAndSet(null); + if (previous != null) { + previous.close(); + } + // Providing an empty operation since we are only interested in triggering the + // implicit initialization step. + withSession("re-initializing", result -> Mono.empty()).subscribe(); + } + } + private McpSchema.InitializeResult currentInitializationResult() { + Initialization current = this.initializationRef.get(); + McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; + return initializeResult; } /** @@ -263,7 +283,8 @@ public class McpAsyncClient { * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); + return initializeResult != null ? initializeResult.capabilities() : null; } /** @@ -272,7 +293,8 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server instructions */ public String getServerInstructions() { - return this.serverInstructions; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); + return initializeResult != null ? initializeResult.instructions() : null; } /** @@ -280,7 +302,8 @@ public String getServerInstructions() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - return this.serverInfo; + McpSchema.InitializeResult initializeResult = currentInitializationResult(); + return initializeResult != null ? initializeResult.serverInfo() : null; } /** @@ -288,7 +311,8 @@ public McpSchema.Implementation getServerInfo() { * @return true if the client-server connection is initialized */ public boolean isInitialized() { - return this.initialized.get(); + Initialization current = this.initializationRef.get(); + return current != null && (current.result.get() != null); } /** @@ -311,7 +335,11 @@ public McpSchema.Implementation getClientInfo() { * Closes the client connection immediately. */ public void close() { - this.mcpSession.close(); + Initialization current = this.initializationRef.getAndSet(null); + if (current != null) { + current.close(); + } + this.transport.close(); } /** @@ -319,14 +347,21 @@ public void close() { * @return A Mono that completes when the connection is closed */ public Mono closeGracefully() { - return this.mcpSession.closeGracefully(); + return Mono.defer(() -> { + Initialization current = this.initializationRef.getAndSet(null); + Mono sessionClose = current != null ? current.closeGracefully() : Mono.empty(); + return sessionClose.then(transport.closeGracefully()); + }); } // -------------------------- // Initialization // -------------------------- /** - * The initialization phase MUST be the first interaction between client and server. + * The initialization phase should be the first interaction between client and server. + * The client will ensure it happens in case it has not been explicitly called and in + * case of transport session invalidation. + *

* During this phase, the client and server: *