diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 8658ac785..ac967ff9f 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -2,7 +2,7 @@ name: Bug report about: Create a bug report to help us improve the project title: '' -labels: 'type: bug, status: waiting-for-triage' +labels: status/waiting for triage assignees: '' --- diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index a07b6a840..c903204cd 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,5 +1,5 @@ blank_issues_enabled: false contact_links: - name: Questions and Community Support - url: https://stackoverflow.com/questions/tagged/spring-ai-mcp - about: Please ask and answer questions on StackOverflow with the spring-ai tag + url: https://stackoverflow.com/questions/tagged/mcp-java-sdk + about: Please ask and answer questions on StackOverflow with the mcp-java-sdk tag diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index aba7d39de..16ba64eef 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -2,7 +2,7 @@ name: Feature request about: Suggest an idea for this project title: '' -labels: 'status: waiting-for-triage, type: feature' +labels: status/waiting for triage assignees: '' --- diff --git a/.github/ISSUE_TEMPLATE/miscellaneous.md b/.github/ISSUE_TEMPLATE/miscellaneous.md index d77c625c3..1db42e3b9 100644 --- a/.github/ISSUE_TEMPLATE/miscellaneous.md +++ b/.github/ISSUE_TEMPLATE/miscellaneous.md @@ -2,7 +2,7 @@ name: Miscellaneous about: Suggest an improvement for this project title: '' -labels: 'status: waiting-for-triage' +labels: status/waiting for triage assignees: '' --- diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..c25de745b --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,22 @@ +version: 2 +updates: + - package-ecosystem: 'github-actions' + directory: '/' + schedule: + interval: monthly + - package-ecosystem: 'maven' + directory: '/' + schedule: + interval: monthly + open-pull-requests-limit: 10 + ignore: + # Freeze production dependencies of mcp-core + - dependency-name: 'org.slf4j:slf4j-api' + - dependency-name: 'com.fasterxml.jackson.core:jackson-annotations' + - dependency-name: 'tools.jackson.core:jackson-databind' + - dependency-name: 'io.projectreactor:reactor-bom' + - dependency-name: 'io.projectreactor:reactor-core' + - dependency-name: 'jakarta.servlet:jakarta.servlet-api' + # mcp-json-jackson2 and mcp-json-jackson3 dependencies + - dependency-name: 'com.fasterxml.jackson.core:jackson-databind' + - dependency-name: 'com.networknt:json-schema-validator' \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7c73d9f38..0c79351a6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,7 +5,7 @@ on: jobs: build: - name: Build branch + name: Build and Test runs-on: ubuntu-latest steps: - name: Checkout source code @@ -20,3 +20,20 @@ jobs: - name: Build run: mvn verify + + jackson2-tests: + name: Jackson 2 Integration Tests + runs-on: ubuntu-latest + steps: + - name: Checkout source code + uses: actions/checkout@v4 + + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'maven' + + - name: Jackson 2 Integration Tests + run: mvn -pl mcp-test -am -Pjackson2 test diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml new file mode 100644 index 000000000..2e655d6ce --- /dev/null +++ b/.github/workflows/conformance.yml @@ -0,0 +1,104 @@ +name: Conformance Tests + +on: + pull_request: {} + push: + branches: [main] + workflow_dispatch: + +jobs: + server: + name: Server Conformance + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'maven' + + - name: Build and start server + run: | + mvn clean install -DskipTests + mvn exec:java -pl conformance-tests/server-servlet -Dexec.mainClass="io.modelcontextprotocol.conformance.server.ConformanceServlet" & + timeout 30 bash -c 'until curl -s http://localhost:8080/mcp > /dev/null 2>&1; do sleep 0.5; done' + + - name: Run conformance tests + uses: modelcontextprotocol/conformance@v0.1.11 + with: + mode: server + url: http://localhost:8080/mcp + suite: active + expected-failures: ./conformance-tests/conformance-baseline.yml + + client: + name: Client Conformance + runs-on: ubuntu-latest + strategy: + matrix: + scenario: [initialize, tools_call, elicitation-sep1034-client-defaults, sse-retry] + steps: + - uses: actions/checkout@v4 + + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'maven' + + - name: Build client + run: mvn clean install -DskipTests + + - name: Run conformance test + uses: modelcontextprotocol/conformance@v0.1.11 + with: + mode: client + command: 'java -jar conformance-tests/client-jdk-http-client/target/client-jdk-http-client-1.0.0-SNAPSHOT.jar' + scenario: ${{ matrix.scenario }} + expected-failures: ./conformance-tests/conformance-baseline.yml + + auth: + name: Auth Conformance + runs-on: ubuntu-latest + strategy: + matrix: + scenario: + - auth/metadata-default + - auth/metadata-var1 + - auth/metadata-var2 + - auth/metadata-var3 + - auth/basic-cimd + - auth/scope-from-www-authenticate + - auth/scope-from-scopes-supported + - auth/scope-omitted-when-undefined + - auth/scope-step-up + - auth/scope-retry-limit + - auth/token-endpoint-auth-basic + - auth/token-endpoint-auth-post + - auth/token-endpoint-auth-none + - auth/pre-registration + steps: + - uses: actions/checkout@v4 + + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: 'maven' + + - name: Build client + run: mvn clean install -DskipTests + + - name: Run conformance test + uses: modelcontextprotocol/conformance@v0.1.15 + with: + node-version: '22' # see https://github.com/modelcontextprotocol/conformance/pull/162 + mode: client + command: 'java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-1.0.0-SNAPSHOT.jar' + scenario: ${{ matrix.scenario }} + expected-failures: ./conformance-tests/conformance-baseline.yml \ No newline at end of file diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 000000000..56b5a1207 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,54 @@ +name: Deploy Documentation + +on: + push: + branches: + - main + paths: + - 'docs/**' + - 'mkdocs.yml' + release: + types: + - published + workflow_dispatch: + +permissions: + contents: write + +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: actions/setup-python@v5 + with: + python-version: 3.x + + - run: pip install mkdocs-material mike + + - name: Configure git user + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Deploy docs (push to main) + if: github.event_name == 'push' || github.event_name == 'workflow_dispatch' + run: | + PROJECT_VERSION=$(mvn help:evaluate -Dexpression=project.version --quiet -DforceStdout) + if [[ "${PROJECT_VERSION}" == *-SNAPSHOT ]]; then + ALIAS="latest-snapshot" + else + ALIAS="latest" + fi + mike deploy --push --update-aliases "${PROJECT_VERSION}" "${ALIAS}" + mike set-default latest --push + + - name: Deploy versioned docs (release) + if: github.event_name == 'release' + run: | + VERSION=${GITHUB_REF_NAME} + mike deploy --push --update-aliases "${VERSION}" latest + mike set-default latest --push diff --git a/.github/workflows/maven-central-release.yml b/.github/workflows/maven-central-release.yml index c6c9d3ab6..8df337ec8 100644 --- a/.github/workflows/maven-central-release.yml +++ b/.github/workflows/maven-central-release.yml @@ -25,7 +25,10 @@ jobs: uses: actions/setup-node@v4 with: node-version: '20' - + + - name: Jackson 2 Integration Tests + run: mvn -pl mcp-test -am -Pjackson2 test + - name: Build and Test run: mvn clean verify diff --git a/.github/workflows/publish-snapshot.yml b/.github/workflows/publish-snapshot.yml index 5d9b4aa39..1a61d336c 100644 --- a/.github/workflows/publish-snapshot.yml +++ b/.github/workflows/publish-snapshot.yml @@ -32,6 +32,9 @@ jobs: - name: Generate Java docs run: mvn -Pjavadoc -B javadoc:aggregate + - name: Jackson 2 Integration Tests + run: mvn -pl mcp-test -am -Pjackson2 test + - name: Build with Maven and deploy to Sonatype snapshot repository env: MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }} diff --git a/.gitignore b/.gitignore index b80dac20d..1fc975c0a 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ build/ out /.gradletasknamecache **/*.flattened-pom.xml +**/dependency-reduced-pom.xml ### IDE - Eclipse/STS ### .apt_generated @@ -56,6 +57,9 @@ node_modules/ package-lock.json package.json +### MkDocs ### +site/ + ### Other ### .antlr/ .profiler/ diff --git a/DEPENDENCY_POLICY.md b/DEPENDENCY_POLICY.md new file mode 100644 index 000000000..5714a6b57 --- /dev/null +++ b/DEPENDENCY_POLICY.md @@ -0,0 +1,26 @@ +# Dependency Policy + +As a library consumed by downstream projects, the MCP Java SDK takes a conservative approach to dependency updates. Dependencies are kept stable unless there is a specific reason to update, such as a security vulnerability, a bug fix, or a need for new functionality. + +## Update Triggers + +Dependencies are updated when: + +- A **security vulnerability** is disclosed (via GitHub security alerts). +- A bug in a dependency directly affects the SDK. +- A new dependency feature is needed for SDK development. +- A dependency drops support for a Java version the SDK still targets. + +Routine version bumps without a clear motivation are avoided to minimize churn for downstream consumers. + +## What We Don't Do + +The SDK does not run scheduled version bumps for production Maven dependencies. Updating a dependency can force downstream consumers to adopt that update transitively, which can be disruptive for projects with strict dependency policies. + +Dependencies are only updated when there is a concrete reason, not simply because a newer version is available. + +## Automated Tooling + +- **GitHub security updates** are enabled at the repository level and automatically open pull requests for Maven packages with known vulnerabilities. This is a GitHub repo setting, separate from the `dependabot.yml` configuration. +- **GitHub Actions versions** are kept up to date via Dependabot on a monthly schedule (see `.github/dependabot.yml`). +- **Maven dependencies** are monitored via Dependabot on a monthly schedule for non-production updates only (see `.github/dependabot.yml`). diff --git a/README.md b/README.md index 436104c63..34133a796 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,11 @@ # MCP Java SDK +[![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/license/MIT) [![Build Status](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/publish-snapshot.yml/badge.svg)](https://github.com/modelcontextprotocol/java-sdk/actions/workflows/publish-snapshot.yml) +[![Maven Central](https://img.shields.io/maven-central/v/io.modelcontextprotocol.sdk/mcp.svg?label=Maven%20Central)](https://central.sonatype.com/artifact/io.modelcontextprotocol.sdk/mcp) +[![Java Version](https://img.shields.io/badge/Java-17%2B-orange)](https://www.oracle.com/java/technologies/javase/jdk17-archive-downloads.html) -A set of projects that provide Java SDK integration for the [Model Context Protocol](https://modelcontextprotocol.org/docs/concepts/architecture). + +A set of projects that provide Java SDK integration for the [Model Context Protocol](https://modelcontextprotocol.io/docs/concepts/architecture). This SDK enables Java applications to interact with AI models and tools through a standardized interface, supporting both synchronous and asynchronous communication patterns. ## πŸ“š Reference Documentation @@ -9,14 +13,17 @@ This SDK enables Java applications to interact with AI models and tools through #### MCP Java SDK documentation For comprehensive guides and SDK API documentation -- [Features](https://modelcontextprotocol.io/sdk/java/mcp-overview#features) - Overview the features provided by the Java MCP SDK -- [Architecture](https://modelcontextprotocol.io/sdk/java/mcp-overview#architecture) - Java MCP SDK architecture overview. -- [Java Dependencies / BOM](https://modelcontextprotocol.io/sdk/java/mcp-overview#dependencies) - Java dependencies and BOM. -- [Java MCP Client](https://modelcontextprotocol.io/sdk/java/mcp-client) - Learn how to use the MCP client to interact with MCP servers. -- [Java MCP Server](https://modelcontextprotocol.io/sdk/java/mcp-server) - Learn how to implement and configure a MCP servers. +- [Features](https://modelcontextprotocol.github.io/java-sdk/#features) - Overview the features provided by the Java MCP SDK +- [Architecture](https://modelcontextprotocol.github.io/java-sdk/#architecture) - Java MCP SDK architecture overview. +- [Java Dependencies / BOM](https://modelcontextprotocol.github.io/java-sdk/quickstart/#dependencies) - Java dependencies and BOM. +- [Java MCP Client](https://modelcontextprotocol.github.io/java-sdk/client/) - Learn how to use the MCP client to interact with MCP servers. +- [Java MCP Server](https://modelcontextprotocol.github.io/java-sdk/server/) - Learn how to implement and configure a MCP servers. #### Spring AI MCP documentation -[Spring AI MCP](https://docs.spring.io/spring-ai/reference/api/mcp/mcp-overview.html) extends the MCP Java SDK with Spring Boot integration, providing both [client](https://docs.spring.io/spring-ai/reference/api/mcp/mcp-client-boot-starter-docs.html) and [server](https://docs.spring.io/spring-ai/reference/api/mcp/mcp-server-boot-starter-docs.html) starters. Bootstrap your AI applications with MCP support using [Spring Initializer](https://start.spring.io). +[Spring AI MCP](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) extends the MCP Java SDK with Spring Boot integration, providing both [client](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-client-boot-starter-docs.html) and [server](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-server-boot-starter-docs.html) starters. +The [MCP Annotations](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-annotations-overview.html) - provides annotation-based method handling for MCP servers and clients in Java. +The [MCP Security](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-security.html) - provides comprehensive OAuth 2.0 and API key-based security support for Model Context Protocol implementations in Spring AI. +Bootstrap your AI applications with MCP support using [Spring Initializer](https://start.spring.io). ## Development @@ -43,6 +50,7 @@ Please follow the [Contributing Guidelines](CONTRIBUTING.md). - Christian Tzolov - Dariusz JΔ™drzejczyk +- Daniel Garnier-Moiroux ## Links @@ -50,6 +58,145 @@ Please follow the [Contributing Guidelines](CONTRIBUTING.md). - [Issue Tracker](https://github.com/modelcontextprotocol/java-sdk/issues) - [CI/CD](https://github.com/modelcontextprotocol/java-sdk/actions) +## Architecture and Design Decisions + +### Introduction + +Building a general-purpose MCP Java SDK requires making technology decisions in areas where the JDK provides limited or no support. The Java ecosystem is powerful but fragmented: multiple valid approaches exist, each with strong communities. +Our goal is not to prescribe "the one true way," but to provide a reference implementation of the MCP specification that is: + +* **Pragmatic** – makes developers productive quickly +* **Interoperable** – aligns with widely used libraries and practices +* **Pluggable** – allows alternatives where projects prefer different stacks +* **Grounded in team familiarity** – we chose technologies the team can be productive with today, while remaining open to community contributions that broaden the SDK + +### Key Choices and Considerations + +The SDK had to make decisions in the following areas: + +1. **JSON serialization** – mapping between JSON and Java types + +2. **Programming model** – supporting asynchronous processing, cancellation, and streaming while staying simple for blocking use cases + +3. **Observability** – logging and enabling integration with metrics/tracing + +4. **Remote clients and servers** – supporting both consuming MCP servers (client transport) and exposing MCP endpoints (server transport with authorization) + +The following sections explain what we chose, why it made sense, and how the choices align with the SDK's goals. + +### 1. JSON Serialization + +* **SDK Choice**: Jackson for JSON serialization and deserialization, behind an SDK abstraction (package `io.modelcontextprotocol.json` in `mcp-core`) + +* **Why**: Jackson is widely adopted across the Java ecosystem, provides strong performance and a mature annotation model, and is familiar to the SDK team and many potential contributors. + +* **How we expose it**: Public APIs use a bundled abstraction. Jackson is shipped as the default implementation (`mcp-json-jackson3`), but alternatives can be plugged in. + +* **How it fits the SDK**: This offers a pragmatic default while keeping flexibility for projects that prefer different JSON libraries. + +### 2. Programming Model + +* **SDK Choice**: Reactive Streams for public APIs, with Project Reactor as the internal implementation and a synchronous facade for blocking use cases + +* **Why**: MCP builds on JSON-RPC's asynchronous nature and defines a bidirectional protocol on top of it, enabling asynchronous and streaming interactions. MCP explicitly supports: + + * Multiple in-flight requests and responses + * Notifications that do not expect a reply + * STDIO transports for inter-process communication using pipes + * Streaming transports such as Server-Sent Events and Streamable HTTP + + These requirements call for a programming model more powerful than single-result futures like `CompletableFuture`. + + * **Reactive Streams: the Community Standard** + + Reactive Streams is a small Java specification that standardizes asynchronous stream processing with backpressure. It defines four minimal interfaces (Publisher, Subscriber, Subscription, and Processor). These interfaces are widely recognized as the standard contract for async, non-blocking pipelines in Java. + + * **Reactive Streams Implementation** + + The SDK uses Project Reactor as its implementation of the Reactive Streams specification. Reactor is mature, widely adopted, provides rich operators, and integrates well with observability through context propagation. Team familiarity also allowed us to deliver a solid foundation quickly. + We plan to convert the public API to only expose Reactive Streams interfaces. By defining the public API in terms of Reactive Streams interfaces and using Reactor internally, the SDK stays standards-based while benefiting from a practical, production-ready implementation. + + * **Synchronous Facade in the SDK** + + Not all MCP use cases require streaming pipelines. Many scenarios are as simple as "send a request and block until I get the result." + To support this, the SDK provides a synchronous facade layered on top of the reactive core. Developers can stay in a blocking model when it's enough, while still having access to asynchronous streaming when needed. + +* **How it fits the SDK**: This design balances scalability, approachability, and future evolution such as Virtual Threads and Structured Concurrency in upcoming JDKs. + +### 3. Observability + +* **SDK Choice**: SLF4J for logging; Reactor Context for observability propagation + +* **Why**: SLF4J is the de facto logging facade in Java, with broad compatibility. Reactor Context enables propagation of observability data such as correlation IDs and tracing state across async boundaries. This ensures interoperability with modern observability frameworks. + +* **How we expose it**: Public APIs log through SLF4J only, with no backend included. Observability metadata flows through Reactor pipelines. The SDK itself does not ship metrics or tracing implementations. + +* **How it fits the SDK**: This provides reliable logging by default and seamless integration with Micrometer, OpenTelemetry, or similar systems for metrics and tracing. + +### 4. Remote MCP Clients and Servers + +MCP supports both clients (applications consuming MCP servers) and servers (applications exposing MCP endpoints). The SDK provides support for both sides. + +#### Client Transport in the SDK + +* **SDK Choice**: JDK HttpClient (Java 11+) as the default client + +* **Why**: The JDK HttpClient is built-in, portable, and supports streaming responses. This keeps the default lightweight with no extra dependencies. + +* **How we expose it**: MCP Client APIs are transport-agnostic. The core module ships with JDK HttpClient transport. Spring WebClient-based transport is available in [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+. + +* **How it fits the SDK**: This ensures all applications can talk to MCP servers out of the box, while allowing richer integration in Spring and other environments. + +#### Server Transport in the SDK + +* **SDK Choice**: Jakarta Servlet implementation in core + +* **Why**: Servlet is the most widely deployed Java server API, providing broad reach across blocking and non-blocking models without additional dependencies. + +* **How we expose it**: Server APIs are transport-agnostic. Core includes Servlet support. Spring WebFlux and WebMVC server transports are available in [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+. + +* **How it fits the SDK**: This allows developers to expose MCP servers in the most common Java environments today, while enabling other transport implementations such as Netty, Vert.x, or Helidon. + +#### Authorization in the SDK + +* **SDK Choice**: Pluggable authorization hooks for MCP servers; no built-in implementation + +* **Why**: MCP servers must restrict access to authenticated and authorized clients. Authorization needs differ across environments such as Spring Security, MicroProfile JWT, or custom solutions. Providing hooks avoids lock-in and leverages proven libraries. + +* **How we expose it**: Authorization is integrated into the server transport layer. The SDK does not include its own authorization system. + +* **How it fits the SDK**: This keeps server-side security ecosystem-neutral, while ensuring applications can plug in their preferred authorization strategy. + +### Project Structure of the SDK + +The SDK is organized into modules to separate concerns and allow adopters to bring in only what they need: +* `mcp-bom` – Dependency versions +* `mcp-core` – Reference implementation (STDIO, JDK HttpClient, Servlet), JSON binding interface definitions +* `mcp-json-jackson2` – Jackson 2 implementation of JSON binding +* `mcp-json-jackson3` – Jackson 3 implementation of JSON binding +* `mcp` – Convenience bundle (core + Jackson 3) +* `mcp-test` – Shared testing utilities + +Spring integrations (WebClient, WebFlux, WebMVC) are now part of [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ (group `org.springframework.ai`). + +For example, a minimal adopter may depend only on `mcp` (core + Jackson), while a Spring-based application can use the Spring AI `mcp-spring-webflux` or `mcp-spring-webmvc` artifacts for deeper framework integration. + +Additionally, `mcp-test` contains integration tests for `mcp-core`. +`mcp-core` needs a JSON implementation to run full integration tests. +Implementations such as `mcp-json-jackson3`, depend on `mcp-core`, and therefore cannot be imported in `mcp-core` for tests. +Instead, all integration tests that need a JSON implementation are now in `mcp-test`, and use `jackson3` by default. +A `jackson2` maven profile allows to run integration tests with Jackson 2, like so: + + +```bash +./mvnw -pl mcp-test -am -Pjackson2 test +``` + +### Future Directions + +The SDK is designed to evolve with the Java ecosystem. Areas we are actively watching include: +Concurrency in the JDK – Virtual Threads and Structured Concurrency may simplify the synchronous API story + ## License This project is licensed under the [MIT License](LICENSE). diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 000000000..b5b7dc4d7 --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,45 @@ +# Roadmap + +## Spec Implementation Tracking + +The SDK tracks implementation of MCP spec components via GitHub Projects, with a dedicated project board for each spec revision. For example, see the [2025-11-25 spec revision board](https://github.com/orgs/modelcontextprotocol/projects/26/views/1). + +## Current Focus Areas + +### 2025-11-25 Spec Implementation + +The Java SDK is actively implementing the [2025-11-25 MCP specification revision](https://github.com/orgs/modelcontextprotocol/projects/26/views/1). + +Key features in this revision include: + +- **Tasks**: Experimental support for tracking durable requests with polling and deferred result retrieval +- **Tool calling in sampling**: Support for `tools` and `toolChoice` parameters +- **URL mode elicitation**: Client-side URL elicitation requests +- **Icons metadata**: Servers can expose icons for tools, resources, resource templates, and prompts +- **Enhanced schemas**: JSON Schema 2020-12 as default, improved enum support, default values for elicitation +- **Security improvements**: Updated security best practices, enhanced authorization flows, enabling OAuth integrations + +See the full [changelog](https://modelcontextprotocol.io/specification/2025-11-25/changelog) for details. + +### Tier 1 SDK Support + +Once we catch up on the most recent MCP specification revision we aim to fully support all the upcoming specification features on the day of its release. + +### v1.x Development + +The Java SDK is currently in active development as v1.x, following a recent stable 1.0.0 release. The SDK provides: + +- MCP protocol implementation +- Synchronous and asynchronous programming models +- Multiple transport options (STDIO, HTTP/SSE, Servlet) +- Pluggable JSON serialization (Jackson 2 and Jackson 3) + +Development is tracked via [GitHub Issues](https://github.com/modelcontextprotocol/java-sdk/issues) and [GitHub Projects](https://github.com/orgs/modelcontextprotocol/projects). + +### Future Versions + +Major version updates will align with MCP specification changes and breaking API changes as needed. The SDK is designed to evolve with the Java ecosystem, including: + +- Virtual Threads and Structured Concurrency support +- Additional transport implementations +- Performance optimizations diff --git a/SECURITY.md b/SECURITY.md index 74e9880fd..502924200 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,21 +1,21 @@ # Security Policy -Thank you for helping us keep the SDKs and systems they interact with secure. +Thank you for helping keep the Model Context Protocol and its ecosystem secure. ## Reporting Security Issues -This SDK is maintained by [Anthropic](https://www.anthropic.com/) as part of the Model -Context Protocol project. +If you discover a security vulnerability in this repository, please report it through +the [GitHub Security Advisory process](https://docs.github.com/en/code-security/security-advisories/guidance-on-reporting-and-writing-information-about-vulnerabilities/privately-reporting-a-security-vulnerability) +for this repository. -The security of our systems and user data is Anthropic’s top priority. We appreciate the -work of security researchers acting in good faith in identifying and reporting potential -vulnerabilities. +Please **do not** report security vulnerabilities through public GitHub issues, discussions, +or pull requests. -Our security program is managed on HackerOne and we ask that any validated vulnerability -in this functionality be reported through their -[submission form](https://hackerone.com/anthropic-vdp/reports/new?type=team&report_type=vulnerability). +## What to Include -## Vulnerability Disclosure Program +To help us triage and respond quickly, please include: -Our Vulnerability Program Guidelines are defined on our -[HackerOne program page](https://hackerone.com/anthropic-vdp). \ No newline at end of file +- A description of the vulnerability +- Steps to reproduce the issue +- The potential impact +- Any suggested fixes (optional) diff --git a/VERSIONING.md b/VERSIONING.md new file mode 100644 index 000000000..331c6d05e --- /dev/null +++ b/VERSIONING.md @@ -0,0 +1,46 @@ +# Versioning Policy + +The MCP Java SDK (`io.modelcontextprotocol.sdk`) follows [Semantic Versioning 2.0.0](https://semver.org/). + +## Version Format + +`MAJOR.MINOR.PATCH` + +- **MAJOR**: Incremented for breaking changes (see below). +- **MINOR**: Incremented for new features that are backward-compatible. +- **PATCH**: Incremented for backward-compatible bug fixes. + +## What Constitutes a Breaking Change + +The following changes are considered breaking and require a major version bump: + +- Removing or renaming a public API (class, interface, method, or constant). +- Changing the signature of a public method in a way that breaks existing callers (removing parameters, changing required/optional status, changing types). +- Removing or renaming a public interface method or field. +- Changing the behavior of an existing API in a way that breaks documented contracts. +- Dropping support for a Java LTS version. +- Removing support for a transport type. +- Changes to the MCP protocol version that require client/server code changes. +- Removing a module from the SDK. + +The following are **not** considered breaking: + +- Adding new methods with default implementations to interfaces. +- Adding new public APIs, classes, interfaces, or methods. +- Adding new optional parameters to existing methods (through method overloading). +- Bug fixes that correct behavior to match documented intent. +- Internal refactoring that does not affect the public API. +- Adding support for new MCP spec features. +- Changes to test dependencies or build tooling. +- Adding new modules to the SDK. + +## How Breaking Changes Are Communicated + +1. **Changelog**: All breaking changes are documented in the GitHub release notes with migration instructions. +2. **Deprecation**: When feasible, APIs are deprecated for at least one minor release before removal using `@Deprecated` annotations, which surface warnings through Java tooling and IDEs. +3. **Migration guide**: Major version releases include a migration guide describing what changed and how to update. +4. **PR labels**: Pull requests containing breaking changes are labeled with `breaking change`. + +## Maven Coordinates + +All SDK modules share the same version number and are released together. The BOM (`mcp-bom`) provides dependency management for all SDK modules to ensure version consistency. diff --git a/conformance-tests/VALIDATION_RESULTS.md b/conformance-tests/VALIDATION_RESULTS.md new file mode 100644 index 000000000..19e74330c --- /dev/null +++ b/conformance-tests/VALIDATION_RESULTS.md @@ -0,0 +1,124 @@ +# MCP Java SDK Conformance Test Validation Results + +## Summary + +**Server Tests:** 37/40 passed (92.5%) +**Client Tests:** 3/4 scenarios passed (9/10 checks passed) +**Auth Tests:** 12/14 scenarios fully passing (178 passed, 1 failed, 1 warning, 85.7% scenarios, 98.9% checks) + +## Server Test Results + +### Passing (37/40) + +- **Lifecycle & Utilities (4/4):** initialize, ping, logging-set-level, completion-complete +- **Tools (11/11):** All scenarios including progress notifications ✨ +- **Elicitation (10/10):** SEP-1034 defaults (5 checks), SEP-1330 enums (5 checks) +- **Resources (4/6):** list, read-text, read-binary, templates-read +- **Prompts (4/4):** list, simple, with-args, embedded-resource, with-image +- **SSE Transport (2/2):** Multiple streams +- **Security (2/2):** Localhost validation passes, DNS rebinding protection + +### Failing (3/40) + +1. **resources-subscribe** - Not implemented in SDK +2. **resources-unsubscribe** - Not implemented in SDK + +## Client Test Results + +### Passing (3/4 scenarios, 9/10 checks) + +- **initialize (1/1):** Protocol negotiation, clientInfo, capabilities +- **tools_call (1/1):** Tool discovery and invocation +- **elicitation-sep1034-client-defaults (5/5):** Default values for string, integer, number, enum, boolean + +### Partially Passing (1/4 scenarios, 1/2 checks) + +- **sse-retry (1/2 + 1 warning):** + - βœ… Reconnects after stream closure + - ❌ Does not respect retry timing + - ⚠️ Does not send Last-Event-ID header (SHOULD requirement) + +**Issue:** Client treats `retry:` SSE field as invalid instead of parsing it for reconnection timing. + +## Auth Test Results (Spring HTTP Client) + +**Status: 178 passed, 1 failed, 1 warning across 14 scenarios** + +Uses the `client-spring-http-client` module with Spring Security OAuth2 and the [mcp-client-security](https://github.com/springaicommunity/mcp-client-security) library. + +### Fully Passing (12/14 scenarios) + +- **auth/metadata-default (12/12):** Default metadata discovery +- **auth/metadata-var1 (12/12):** Metadata discovery variant 1 +- **auth/metadata-var2 (12/12):** Metadata discovery variant 2 +- **auth/metadata-var3 (12/12):** Metadata discovery variant 3 +- **auth/scope-from-www-authenticate (13/13):** Scope extraction from WWW-Authenticate header +- **auth/scope-from-scopes-supported (13/13):** Scope extraction from scopes_supported +- **auth/scope-omitted-when-undefined (13/13):** Scope omitted when not defined +- **auth/scope-retry-limit (11/11):** Scope retry limit handling +- **auth/token-endpoint-auth-basic (17/17):** Token endpoint with HTTP Basic auth +- **auth/token-endpoint-auth-post (17/17):** Token endpoint with POST body auth +- **auth/token-endpoint-auth-none (17/17):** Token endpoint with no client auth +- **auth/pre-registration (6/6):** Pre-registered client credentials flow + +### Partially Passing (2/14 scenarios) + +- **auth/basic-cimd (12/12 + 1 warning):** Basic Client-Initiated Metadata Discovery β€” all checks pass, minor warning +- **auth/scope-step-up (11/12):** Scope step-up challenge β€” 1 failure, client does not fully handle scope escalation after initial authorization + +## Known Limitations + +1. **Resource Subscriptions:** SDK doesn't implement `resources/subscribe` and `resources/unsubscribe` handlers +2. **Client SSE Retry:** Client doesn't parse or respect the `retry:` field, reconnects immediately, and doesn't send Last-Event-ID header +3. **Auth Scope Step-Up:** Client does not fully handle scope step-up challenges where the server requests additional scopes after initial authorization +4. **Auth Basic CIMD:** Minor conformance warning in the basic Client-Initiated Metadata Discovery flow + +## Running Tests + +### Server +```bash +# Start server +cd conformance-tests/server-servlet +../../mvnw compile exec:java -Dexec.mainClass="io.modelcontextprotocol.conformance.server.ConformanceServlet" + +# Run tests (in another terminal) +npx @modelcontextprotocol/conformance server --url http://localhost:8080/mcp --suite active +``` + +### Client +```bash +# Build +cd conformance-tests/client-jdk-http-client +../../mvnw clean package -DskipTests + +# Run all scenarios +for scenario in initialize tools_call elicitation-sep1034-client-defaults sse-retry; do + npx @modelcontextprotocol/conformance client \ + --command "java -jar target/client-jdk-http-client-1.0.0-SNAPSHOT.jar" \ + --scenario $scenario +done +``` + +### Auth (Spring HTTP Client) + +Ensure you run with the conformance testing suite `0.1.15` or higher. + +```bash +# Build +cd conformance-tests/client-spring-http-client +../../mvnw clean package -DskipTests + +# Run auth suite +npx @modelcontextprotocol/conformance@0.1.15 client \ + --spec-version 2025-11-25 \ + --command "java -jar target/client-spring-http-client-0.18.0-SNAPSHOT.jar" \ + --suite auth +``` + +## Recommendations + +### High Priority +1. Fix client SSE retry field handling in `HttpClientStreamableHttpTransport` +2. Implement resource subscription handlers in `McpStatelessAsyncServer` +3. Implement CIMD +4. Implement scope step up diff --git a/conformance-tests/client-jdk-http-client/README.md b/conformance-tests/client-jdk-http-client/README.md new file mode 100644 index 000000000..44eccedf0 --- /dev/null +++ b/conformance-tests/client-jdk-http-client/README.md @@ -0,0 +1,135 @@ +# MCP Conformance Tests - JDK HTTP Client + +This module provides a conformance test client implementation for the Java MCP SDK using the JDK HTTP Client with Streamable HTTP transport. + +## Overview + +The conformance test client is designed to work with the [MCP Conformance Test Framework](https://github.com/modelcontextprotocol/conformance). It validates that the Java MCP SDK client properly implements the MCP specification. + +## Architecture + +The client reads test scenarios from environment variables and accepts the server URL as a command-line argument, following the conformance framework's conventions: + +- **MCP_CONFORMANCE_SCENARIO**: Environment variable specifying which test scenario to run +- **Server URL**: Passed as the last command-line argument + +## Supported Scenarios + +Currently implemented scenarios: + +- **initialize**: Tests the MCP client initialization handshake only + - βœ… Validates protocol version negotiation + - βœ… Validates clientInfo (name and version) + - βœ… Validates proper handling of server capabilities + - Does NOT call any tools or perform additional operations + +- **tools_call**: Tests tool discovery and invocation + - βœ… Initializes the client + - βœ… Lists available tools from the server + - βœ… Calls the `add_numbers` tool with test arguments (a=5, b=3) + - βœ… Validates the tool result + +- **elicitation-sep1034-client-defaults**: Tests client applies default values for omitted elicitation fields (SEP-1034) + - βœ… Initializes the client + - βœ… Lists available tools from the server + - βœ… Calls the `test_client_elicitation_defaults` tool + - βœ… Validates that the client properly applies default values from JSON schema to elicitation responses (5/5 checks pass) + +- **sse-retry**: Tests client respects SSE retry field timing and reconnects properly (SEP-1699) + - ⚠️ Initializes the client + - ⚠️ Lists available tools from the server + - ⚠️ Calls the `test_reconnection` tool which triggers SSE stream closure + - βœ… Client reconnects after stream closure (PASSING) + - ❌ Client does not respect retry timing (FAILING) + - ⚠️ Client does not send Last-Event-ID header (WARNING - SHOULD requirement) + +## Building + +Build the executable JAR: + +```bash +cd conformance-tests/client-jdk-http-client +../../mvnw clean package -DskipTests +``` + +This creates an executable JAR at: +``` +target/client-jdk-http-client-1.0.0-SNAPSHOT.jar +``` + +## Running Tests + +### Using the Conformance Framework + +Run a single scenario: + +```bash +npx @modelcontextprotocol/conformance client \ + --command "java -jar conformance-tests/client-jdk-http-client/target/client-jdk-http-client-1.0.0-SNAPSHOT.jar" \ + --scenario initialize + +npx @modelcontextprotocol/conformance client \ + --command "java -jar conformance-tests/client-jdk-http-client/target/client-jdk-http-client-1.0.0-SNAPSHOT.jar" \ + --scenario tools_call + +npx @modelcontextprotocol/conformance client \ + --command "java -jar conformance-tests/client-jdk-http-client/target/client-jdk-http-client-1.0.0-SNAPSHOT.jar" \ + --scenario elicitation-sep1034-client-defaults + +npx @modelcontextprotocol/conformance client \ + --command "java -jar conformance-tests/client-jdk-http-client/target/client-jdk-http-client-1.0.0-SNAPSHOT.jar" \ + --scenario sse-retry +``` + +Run with verbose output: + +```bash +npx @modelcontextprotocol/conformance client \ + --command "java -jar conformance-tests/client-jdk-http-client/target/client-jdk-http-client-1.0.0-SNAPSHOT.jar" \ + --scenario initialize \ + --verbose +``` + +### Manual Testing + +You can also run the client manually if you have a test server: + +```bash +export MCP_CONFORMANCE_SCENARIO=initialize +java -jar conformance-tests/client-jdk-http-client/target/client-jdk-http-client-1.0.0-SNAPSHOT.jar http://localhost:3000/mcp +``` + +## Test Results + +The conformance framework generates test results showing: + +**Current Status (3/4 scenarios passing):** +- βœ… initialize: 1/1 checks passed +- βœ… tools_call: 1/1 checks passed +- βœ… elicitation-sep1034-client-defaults: 5/5 checks passed +- ⚠️ sse-retry: 1/2 checks passed, 1 warning + +Test result files are generated in `results/-/`: +- `checks.json`: Array of conformance check results with pass/fail status +- `stdout.txt`: Client stdout output +- `stderr.txt`: Client stderr output + +### Known Issue: SSE Retry Handling + +The `sse-retry` scenario currently fails because: +1. The client treats the SSE `retry:` field as invalid instead of parsing it +2. The client does not implement retry timing (reconnects immediately) +3. The client does not send the Last-Event-ID header on reconnection + +This is a known limitation in the `HttpClientStreamableHttpTransport` implementation. + +## Next Steps + +Future enhancements: + +- Fix SSE retry field handling (SEP-1699) to properly parse and respect retry timing +- Implement Last-Event-ID header on reconnection for resumability +- Add auth scenarios (currently excluded as per requirements) +- Implement a comprehensive "everything-client" pattern +- Add to CI/CD pipeline +- Create expected-failures baseline for known issues diff --git a/conformance-tests/client-jdk-http-client/pom.xml b/conformance-tests/client-jdk-http-client/pom.xml new file mode 100644 index 000000000..f30361438 --- /dev/null +++ b/conformance-tests/client-jdk-http-client/pom.xml @@ -0,0 +1,82 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + conformance-tests + 1.1.0-SNAPSHOT + + client-jdk-http-client + jar + MCP Conformance Tests - JDK HTTP Client + JDK HTTP Client conformance tests for the Java MCP SDK + https://github.com/modelcontextprotocol/java-sdk + + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + true + + + + + io.modelcontextprotocol.sdk + mcp + 1.1.0-SNAPSHOT + + + + + ch.qos.logback + logback-classic + ${logback.version} + runtime + + + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.5.1 + + + package + + shade + + + + + io.modelcontextprotocol.conformance.client.ConformanceJdkClientMcpClient + + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + + + + diff --git a/conformance-tests/client-jdk-http-client/src/main/java/io/modelcontextprotocol/conformance/client/ConformanceJdkClientMcpClient.java b/conformance-tests/client-jdk-http-client/src/main/java/io/modelcontextprotocol/conformance/client/ConformanceJdkClientMcpClient.java new file mode 100644 index 000000000..570c4614e --- /dev/null +++ b/conformance-tests/client-jdk-http-client/src/main/java/io/modelcontextprotocol/conformance/client/ConformanceJdkClientMcpClient.java @@ -0,0 +1,286 @@ +package io.modelcontextprotocol.conformance.client; + +import java.time.Duration; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpSchema; + +/** + * MCP Conformance Test Client - JDK HTTP Client Implementation + * + *

+ * This client is designed to work with the MCP conformance test framework. It reads the + * test scenario from the MCP_CONFORMANCE_SCENARIO environment variable and the server URL + * from command-line arguments. + * + *

+ * Usage: ConformanceJdkClientMcpClient <server-url> + * + * @see MCP Conformance + * Test Framework + */ +public class ConformanceJdkClientMcpClient { + + public static void main(String[] args) { + if (args.length == 0) { + System.err.println("Usage: ConformanceJdkClientMcpClient "); + System.err.println("The server URL must be provided as the last command-line argument."); + System.err.println("The MCP_CONFORMANCE_SCENARIO environment variable must be set."); + System.exit(1); + } + + String scenario = System.getenv("MCP_CONFORMANCE_SCENARIO"); + if (scenario == null || scenario.isEmpty()) { + System.err.println("Error: MCP_CONFORMANCE_SCENARIO environment variable is not set"); + System.exit(1); + } + + String serverUrl = args[args.length - 1]; + + try { + switch (scenario) { + case "initialize": + runInitializeScenario(serverUrl); + break; + case "tools_call": + runToolsCallScenario(serverUrl); + break; + case "elicitation-sep1034-client-defaults": + runElicitationDefaultsScenario(serverUrl); + break; + case "sse-retry": + runSSERetryScenario(serverUrl); + break; + default: + System.err.println("Unknown scenario: " + scenario); + System.err.println("Available scenarios:"); + System.err.println(" - initialize"); + System.err.println(" - tools_call"); + System.err.println(" - elicitation-sep1034-client-defaults"); + System.err.println(" - sse-retry"); + System.exit(1); + } + System.exit(0); + } + catch (Exception e) { + System.err.println("Error: " + e.getMessage()); + e.printStackTrace(); + System.exit(1); + } + } + + /** + * Helper method to create and configure an MCP client with transport. + * @param serverUrl the URL of the MCP server + * @return configured McpSyncClient instance + */ + private static McpSyncClient createClient(String serverUrl) { + HttpClientStreamableHttpTransport transport = HttpClientStreamableHttpTransport.builder(serverUrl).build(); + + return McpClient.sync(transport) + .clientInfo(new McpSchema.Implementation("test-client", "1.0.0")) + .requestTimeout(Duration.ofSeconds(30)) + .build(); + } + + /** + * Helper method to create and configure an MCP client with elicitation support. + * @param serverUrl the URL of the MCP server + * @return configured McpSyncClient instance with elicitation handler + */ + private static McpSyncClient createClientWithElicitation(String serverUrl) { + HttpClientStreamableHttpTransport transport = HttpClientStreamableHttpTransport.builder(serverUrl).build(); + + // Build client capabilities with elicitation support + var capabilities = McpSchema.ClientCapabilities.builder().elicitation().build(); + + return McpClient.sync(transport) + .clientInfo(new McpSchema.Implementation("test-client", "1.0.0")) + .requestTimeout(Duration.ofSeconds(30)) + .capabilities(capabilities) + .elicitation(request -> { + // Apply default values from the schema to create the content + var content = new java.util.HashMap(); + var schema = request.requestedSchema(); + + if (schema != null && schema.containsKey("properties")) { + @SuppressWarnings("unchecked") + var properties = (java.util.Map) schema.get("properties"); + + // Apply defaults for each property + properties.forEach((key, propDef) -> { + @SuppressWarnings("unchecked") + var propMap = (java.util.Map) propDef; + if (propMap.containsKey("default")) { + content.put(key, propMap.get("default")); + } + }); + } + + // Return accept action with the defaults applied + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, content, null); + }) + .build(); + } + + /** + * Initialize scenario: Tests MCP client initialization handshake. + * @param serverUrl the URL of the MCP server + * @throws Exception if any error occurs during execution + */ + private static void runInitializeScenario(String serverUrl) throws Exception { + McpSyncClient client = createClient(serverUrl); + + try { + // Initialize client + client.initialize(); + + System.out.println("Successfully connected to MCP server"); + } + finally { + // Close the client (which will close the transport) + client.close(); + System.out.println("Connection closed successfully"); + } + } + + /** + * Tools call scenario: Tests tool listing and invocation functionality. + * @param serverUrl the URL of the MCP server + * @throws Exception if any error occurs during execution + */ + private static void runToolsCallScenario(String serverUrl) throws Exception { + McpSyncClient client = createClient(serverUrl); + + try { + // Initialize client + client.initialize(); + + System.out.println("Successfully connected to MCP server"); + + // List available tools + McpSchema.ListToolsResult toolsResult = client.listTools(); + System.out.println("Successfully listed tools"); + + // Call the add_numbers tool if it exists + if (toolsResult != null && toolsResult.tools() != null) { + for (McpSchema.Tool tool : toolsResult.tools()) { + if ("add_numbers".equals(tool.name())) { + // Call the add_numbers tool with test arguments + var arguments = new java.util.HashMap(); + arguments.put("a", 5); + arguments.put("b", 3); + + McpSchema.CallToolResult result = client + .callTool(new McpSchema.CallToolRequest("add_numbers", arguments)); + + System.out.println("Successfully called add_numbers tool"); + if (result != null && result.content() != null) { + System.out.println("Tool result: " + result.content()); + } + break; + } + } + } + } + finally { + // Close the client (which will close the transport) + client.close(); + System.out.println("Connection closed successfully"); + } + } + + /** + * Elicitation defaults scenario: Tests client applies default values for omitted + * elicitation fields (SEP-1034). + * @param serverUrl the URL of the MCP server + * @throws Exception if any error occurs during execution + */ + private static void runElicitationDefaultsScenario(String serverUrl) throws Exception { + McpSyncClient client = createClientWithElicitation(serverUrl); + + try { + // Initialize client + client.initialize(); + + System.out.println("Successfully connected to MCP server"); + + // List available tools + McpSchema.ListToolsResult toolsResult = client.listTools(); + System.out.println("Successfully listed tools"); + + // Call the test_client_elicitation_defaults tool if it exists + if (toolsResult != null && toolsResult.tools() != null) { + for (McpSchema.Tool tool : toolsResult.tools()) { + if ("test_client_elicitation_defaults".equals(tool.name())) { + // Call the tool which will trigger an elicitation request + var arguments = new java.util.HashMap(); + + McpSchema.CallToolResult result = client + .callTool(new McpSchema.CallToolRequest("test_client_elicitation_defaults", arguments)); + + System.out.println("Successfully called test_client_elicitation_defaults tool"); + if (result != null && result.content() != null) { + System.out.println("Tool result: " + result.content()); + } + break; + } + } + } + } + finally { + // Close the client (which will close the transport) + client.close(); + System.out.println("Connection closed successfully"); + } + } + + /** + * SSE retry scenario: Tests client respects SSE retry field timing and reconnects + * properly (SEP-1699). + * @param serverUrl the URL of the MCP server + * @throws Exception if any error occurs during execution + */ + private static void runSSERetryScenario(String serverUrl) throws Exception { + McpSyncClient client = createClient(serverUrl); + + try { + // Initialize client + client.initialize(); + + System.out.println("Successfully connected to MCP server"); + + // List available tools + McpSchema.ListToolsResult toolsResult = client.listTools(); + System.out.println("Successfully listed tools"); + + // Call the test_reconnection tool if it exists + if (toolsResult != null && toolsResult.tools() != null) { + for (McpSchema.Tool tool : toolsResult.tools()) { + if ("test_reconnection".equals(tool.name())) { + // Call the tool which will trigger SSE stream closure and + // reconnection + var arguments = new java.util.HashMap(); + + McpSchema.CallToolResult result = client + .callTool(new McpSchema.CallToolRequest("test_reconnection", arguments)); + + System.out.println("Successfully called test_reconnection tool"); + if (result != null && result.content() != null) { + System.out.println("Tool result: " + result.content()); + } + break; + } + } + } + } + finally { + // Close the client (which will close the transport) + client.close(); + System.out.println("Connection closed successfully"); + } + } + +} diff --git a/conformance-tests/client-jdk-http-client/src/main/resources/logback.xml b/conformance-tests/client-jdk-http-client/src/main/resources/logback.xml new file mode 100644 index 000000000..bb8e3795d --- /dev/null +++ b/conformance-tests/client-jdk-http-client/src/main/resources/logback.xml @@ -0,0 +1,16 @@ + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + diff --git a/conformance-tests/client-spring-http-client/README.md b/conformance-tests/client-spring-http-client/README.md new file mode 100644 index 000000000..876a86e1d --- /dev/null +++ b/conformance-tests/client-spring-http-client/README.md @@ -0,0 +1,124 @@ +# MCP Conformance Tests - Spring HTTP Client (Auth Suite) + +This module provides a conformance test client implementation for the Java MCP SDK's **auth** suite. + +OAuth2 support is not implemented in the SDK itself, but we provide hooks to implement the Authorization section of the specification. One such implementation is done in Spring, with Sprign AI and the [mcp-client-security](https://github.com/springaicommunity/mcp-client-security) library. + +This is a Spring web application, we interact with it through a normal HTTP-client that follows redirects and performs OAuth2 authorization flows. + +## Overview + +The conformance test client is designed to work with the [MCP Conformance Test Framework](https://github.com/modelcontextprotocol/conformance). It validates that the Java MCP SDK client, combined with Spring Security's OAuth2 support, properly implements the MCP authorization specification. + +Test with @modelcontextprotocol/conformance@0.1.15. + +## Conformance Test Results + +**Status: 178 passed, 1 failed, 1 warning across 14 scenarios** + +| Scenario | Result | Details | +|---|---|---| +| auth/metadata-default | βœ… Pass | 12/12 | +| auth/metadata-var1 | βœ… Pass | 12/12 | +| auth/metadata-var2 | βœ… Pass | 12/12 | +| auth/metadata-var3 | βœ… Pass | 12/12 | +| auth/basic-cimd | ⚠️ Warning | 12/12 passed, 1 warning | +| auth/scope-from-www-authenticate | βœ… Pass | 13/13 | +| auth/scope-from-scopes-supported | βœ… Pass | 13/13 | +| auth/scope-omitted-when-undefined | βœ… Pass | 13/13 | +| auth/scope-step-up | ❌ Fail | 11/12 (1 failed) | +| auth/scope-retry-limit | βœ… Pass | 11/11 | +| auth/token-endpoint-auth-basic | βœ… Pass | 17/17 | +| auth/token-endpoint-auth-post | βœ… Pass | 17/17 | +| auth/token-endpoint-auth-none | βœ… Pass | 17/17 | +| auth/pre-registration | βœ… Pass | 6/6 | + +See [VALIDATION_RESULTS.md](../VALIDATION_RESULTS.md) for the full project validation results. + +## Architecture + +The client is a Spring Boot application that reads test scenarios from environment variables and accepts the server URL as a command-line argument, following the conformance framework's conventions: + +- **MCP_CONFORMANCE_SCENARIO**: Environment variable specifying which test scenario to run +- **MCP_CONFORMANCE_CONTEXT**: Environment variable with JSON context (used by `auth/pre-registration`) +- **Server URL**: Passed as the last command-line argument + +### Scenario Routing + +The application uses Spring's conditional configuration to select the appropriate scenario at startup: + +- **`DefaultConfiguration`** β€” Activated for all scenarios except `auth/pre-registration`. Uses the OAuth2 Authorization Code flow with dynamic client registration via `McpClientOAuth2Configurer`. +- **`PreRegistrationConfiguration`** β€” Activated only for `auth/pre-registration`. Uses the Client Credentials flow with pre-registered client credentials read from `MCP_CONFORMANCE_CONTEXT`. + +### Key Dependencies + +- **Spring Boot 4.0** with Spring Security OAuth2 Client +- **Spring AI MCP Client** (`spring-ai-starter-mcp-client`) +- **mcp-client-security** β€” Community library providing MCP-specific OAuth2 integration (metadata discovery, dynamic client registration, transport context) + +## Building + +Build the executable JAR: + +```bash +cd conformance-tests/client-spring-http-client +../../mvnw clean package -DskipTests +``` + +This creates an executable JAR at: +``` +target/client-spring-http-client-0.18.0-SNAPSHOT.jar +``` + +## Running Tests + +### Using the Conformance Framework + +Run the full auth suite: + +```bash +npx @modelcontextprotocol/conformance@0.1.15 client \ + --spec-version 2025-11-25 \ + --command "java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-0.18.0-SNAPSHOT.jar" \ + --suite auth +``` + +Run a single scenario: + +```bash +npx @modelcontextprotocol/conformance@0.1.15 client \ + --spec-version 2025-11-25 \ + --command "java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-0.18.0-SNAPSHOT.jar" \ + --scenario auth/metadata-default +``` + +Run with verbose output: + +```bash +npx @modelcontextprotocol/conformance@0.1.15 client \ + --spec-version 2025-11-25 \ + --command "java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-0.18.0-SNAPSHOT.jar" \ + --scenario auth/metadata-default \ + --verbose +``` + +### Manual Testing + +You can also run the client manually if you have a test server: + +```bash +export MCP_CONFORMANCE_SCENARIO=auth/metadata-default +java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-0.18.0-SNAPSHOT.jar http://localhost:3000/mcp +``` + +## Known Issues + +1. **auth/scope-step-up** (1 failure) β€” The client does not fully handle scope step-up challenges where the server requests additional scopes after initial authorization. +2. **auth/basic-cimd** (1 warning) β€” Minor conformance warning in the basic Client-Initiated Metadata Discovery flow. + +## References + +- [MCP Specification β€” Authorization](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization) +- [MCP Conformance Tests](https://github.com/modelcontextprotocol/conformance) +- [mcp-client-security Library](https://github.com/springaicommunity/mcp-client-security) +- [SDK Integration Guide](https://github.com/modelcontextprotocol/conformance/blob/main/SDK_INTEGRATION.md) diff --git a/conformance-tests/client-spring-http-client/pom.xml b/conformance-tests/client-spring-http-client/pom.xml new file mode 100644 index 000000000..94923fb5c --- /dev/null +++ b/conformance-tests/client-spring-http-client/pom.xml @@ -0,0 +1,91 @@ + + + 4.0.0 + + org.springframework.boot + spring-boot-starter-parent + 4.0.2 + + + io.modelcontextprotocol.sdk + client-spring-http-client + 1.0.0-SNAPSHOT + jar + MCP Conformance Tests - Spring HTTP Client + Spring HTTP Client conformance tests for the Java MCP SDK + https://github.com/modelcontextprotocol/java-sdk + + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + 17 + 2.0.0-M2 + true + + + + + org.springframework.boot + spring-boot-starter-webmvc + + + + org.springframework.boot + spring-boot-starter-restclient + + + + org.springframework.ai + spring-ai-starter-mcp-client + ${spring-ai.version} + + + + org.springframework.boot + spring-boot-starter-oauth2-client + + + + org.springaicommunity + mcp-client-security + 0.1.2 + + + + + + + org.springframework.boot + spring-boot-maven-plugin + + + + + + + maven-central + https://repo.maven.apache.org/maven2/ + + false + + + true + + + + spring-milestones + Spring Milestones + https://repo.spring.io/milestone + + false + + + + + diff --git a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/ConformanceSpringClientApplication.java b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/ConformanceSpringClientApplication.java new file mode 100644 index 000000000..00582c9f2 --- /dev/null +++ b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/ConformanceSpringClientApplication.java @@ -0,0 +1,99 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.conformance.client; + +import java.util.Optional; + +import io.modelcontextprotocol.conformance.client.scenario.Scenario; +import org.springaicommunity.mcp.security.client.sync.oauth2.metadata.McpMetadataDiscoveryService; +import org.springaicommunity.mcp.security.client.sync.oauth2.registration.DynamicClientRegistrationService; +import org.springaicommunity.mcp.security.client.sync.oauth2.registration.InMemoryMcpClientRegistrationRepository; + +import org.springframework.boot.ApplicationArguments; +import org.springframework.boot.ApplicationRunner; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Bean; + +/** + * MCP Conformance Test Client - Spring HTTP Client Implementation. + * + *

+ * This client is designed to work with the MCP conformance test framework. It reads the + * test scenario from the MCP_CONFORMANCE_SCENARIO environment variable and the server URL + * from command-line arguments. + * + *

+ * It specifically tests the {@code auth} conformance suite. It requires Spring to work. + * + *

+ * Usage: java -jar client-spring-http-client.jar <server-url> + * + * @see MCP Conformance + * Test Framework + */ +@SpringBootApplication +public class ConformanceSpringClientApplication { + + public static final String REGISTRATION_ID = "default_registration"; + + public static void main(String[] args) { + SpringApplication.run(ConformanceSpringClientApplication.class, args); + } + + @Bean + McpMetadataDiscoveryService discovery() { + return new McpMetadataDiscoveryService(); + } + + @Bean + InMemoryMcpClientRegistrationRepository clientRegistrationRepository(McpMetadataDiscoveryService discovery) { + return new InMemoryMcpClientRegistrationRepository(new DynamicClientRegistrationService(), discovery); + } + + @Bean + ApplicationRunner conformanceRunner(Optional scenario, ServerUrl serverUrl) { + return args -> { + String scenarioName = System.getenv("MCP_CONFORMANCE_SCENARIO"); + if (scenarioName == null || scenarioName.isEmpty()) { + System.err.println("Error: MCP_CONFORMANCE_SCENARIO environment variable is not set"); + System.exit(1); + } + + if (scenario.isEmpty()) { + System.err.println("Unsupported scenario type"); + System.exit(1); + } + + try { + System.out.println("Executing " + scenarioName); + scenario.get().execute(serverUrl.value()); + System.exit(0); + } + catch (Exception e) { + System.err.println("Error: " + e.getMessage()); + e.printStackTrace(); + System.exit(1); + } + }; + } + + public record ServerUrl(String value) { + } + + @Bean + ServerUrl serverUrl(ApplicationArguments args) { + var nonOptionArgs = args.getNonOptionArgs(); + if (nonOptionArgs.isEmpty()) { + System.err.println("Usage: ConformanceSpringClientApplication "); + System.err.println("The server URL must be provided as a command-line argument."); + System.err.println("The MCP_CONFORMANCE_SCENARIO environment variable must be set."); + System.exit(1); + } + + return new ServerUrl(nonOptionArgs.get(nonOptionArgs.size() - 1)); + } + +} diff --git a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/McpClientController.java b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/McpClientController.java new file mode 100644 index 000000000..e02cfd416 --- /dev/null +++ b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/McpClientController.java @@ -0,0 +1,30 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.conformance.client; + +import io.modelcontextprotocol.conformance.client.scenario.Scenario; + +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; + +/** + * Expose MCP client in a web environment. + */ +@RestController +class McpClientController { + + private final Scenario scenario; + + McpClientController(Scenario scenario) { + this.scenario = scenario; + } + + @GetMapping("/initialize-mcp-client") + public String execute() { + this.scenario.getMcpClient().initialize(); + return "OK"; + } + +} diff --git a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/configuration/DefaultConfiguration.java b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/configuration/DefaultConfiguration.java new file mode 100644 index 000000000..acf26d94e --- /dev/null +++ b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/configuration/DefaultConfiguration.java @@ -0,0 +1,40 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.conformance.client.configuration; + +import io.modelcontextprotocol.conformance.client.ConformanceSpringClientApplication; +import io.modelcontextprotocol.conformance.client.scenario.DefaultScenario; +import org.springaicommunity.mcp.security.client.sync.config.McpClientOAuth2Configurer; +import org.springaicommunity.mcp.security.client.sync.oauth2.registration.McpClientRegistrationRepository; + +import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression; +import org.springframework.boot.web.server.servlet.context.ServletWebServerApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.web.SecurityFilterChain; +import static io.modelcontextprotocol.conformance.client.ConformanceSpringClientApplication.REGISTRATION_ID; + +@Configuration +@ConditionalOnExpression("#{environment['MCP_CONFORMANCE_SCENARIO'] != 'auth/pre-registration'}") +public class DefaultConfiguration { + + @Bean + DefaultScenario defaultScenario(McpClientRegistrationRepository clientRegistrationRepository, + ServletWebServerApplicationContext serverCtx, + OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository) { + return new DefaultScenario(clientRegistrationRepository, serverCtx, oAuth2AuthorizedClientRepository); + } + + @Bean + SecurityFilterChain securityFilterChain(HttpSecurity http, ConformanceSpringClientApplication.ServerUrl serverUrl) { + return http.authorizeHttpRequests(authz -> authz.anyRequest().permitAll()) + .with(new McpClientOAuth2Configurer(), + mcp -> mcp.registerMcpOAuth2Client(REGISTRATION_ID, serverUrl.value())) + .build(); + } + +} diff --git a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/configuration/PreRegistrationConfiguration.java b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/configuration/PreRegistrationConfiguration.java new file mode 100644 index 000000000..afe03f85a --- /dev/null +++ b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/configuration/PreRegistrationConfiguration.java @@ -0,0 +1,39 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.conformance.client.configuration; + +import io.modelcontextprotocol.conformance.client.scenario.PreRegistrationScenario; +import org.springaicommunity.mcp.security.client.sync.config.McpClientOAuth2Configurer; +import org.springaicommunity.mcp.security.client.sync.oauth2.metadata.McpMetadataDiscoveryService; +import org.springaicommunity.mcp.security.client.sync.oauth2.registration.McpClientRegistrationRepository; + +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.security.config.Customizer; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.web.SecurityFilterChain; + +@Configuration +@ConditionalOnProperty(name = "mcp.conformance.scenario", havingValue = "auth/pre-registration") +public class PreRegistrationConfiguration { + + @Bean + PreRegistrationScenario defaultScenario(McpClientRegistrationRepository clientRegistrationRepository, + McpMetadataDiscoveryService mcpMetadataDiscovery, + OAuth2AuthorizedClientService oAuth2AuthorizedClientService) { + return new PreRegistrationScenario(clientRegistrationRepository, mcpMetadataDiscovery, + oAuth2AuthorizedClientService); + } + + @Bean + SecurityFilterChain securityFilterChain(HttpSecurity http) { + return http.authorizeHttpRequests(authz -> authz.anyRequest().permitAll()) + .with(new McpClientOAuth2Configurer(), Customizer.withDefaults()) + .build(); + } + +} diff --git a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/DefaultScenario.java b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/DefaultScenario.java new file mode 100644 index 000000000..d82637de9 --- /dev/null +++ b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/DefaultScenario.java @@ -0,0 +1,100 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.conformance.client.scenario; + +import java.net.CookieManager; +import java.net.CookiePolicy; +import java.net.http.HttpClient; +import java.time.Duration; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpSchema; +import org.jspecify.annotations.NonNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.security.client.sync.AuthenticationMcpTransportContextProvider; +import org.springaicommunity.mcp.security.client.sync.oauth2.http.client.OAuth2AuthorizationCodeSyncHttpRequestCustomizer; +import org.springaicommunity.mcp.security.client.sync.oauth2.registration.McpClientRegistrationRepository; + +import org.springframework.boot.web.server.servlet.context.ServletWebServerApplicationContext; +import org.springframework.http.client.JdkClientHttpRequestFactory; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.web.client.RestClient; +import static io.modelcontextprotocol.conformance.client.ConformanceSpringClientApplication.REGISTRATION_ID; + +public class DefaultScenario implements Scenario { + + private static final Logger log = LoggerFactory + .getLogger(DefaultScenario.class); + + private final ServletWebServerApplicationContext serverCtx; + + private final DefaultOAuth2AuthorizedClientManager authorizedClientManager; + + private McpSyncClient client; + + public DefaultScenario(McpClientRegistrationRepository clientRegistrationRepository, + ServletWebServerApplicationContext serverCtx, + OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository) { + this.serverCtx = serverCtx; + this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(clientRegistrationRepository, + oAuth2AuthorizedClientRepository); + } + + @Override + public void execute(String serverUrl) { + log.info("Executing DefaultScenario"); + var testServerUrl = "http://localhost:" + serverCtx.getWebServer().getPort(); + var testClient = buildTestClient(testServerUrl); + + var customizer = new OAuth2AuthorizationCodeSyncHttpRequestCustomizer(authorizedClientManager, REGISTRATION_ID); + HttpClientStreamableHttpTransport transport = HttpClientStreamableHttpTransport.builder(serverUrl) + .httpRequestCustomizer(customizer) + .build(); + + this.client = McpClient.sync(transport) + .transportContextProvider(new AuthenticationMcpTransportContextProvider()) + .clientInfo(new McpSchema.Implementation("test-client", "1.0.0")) + .requestTimeout(Duration.ofSeconds(30)) + .build(); + + try { + testClient.get().uri("/initialize-mcp-client").retrieve().toBodilessEntity(); + } + finally { + // Close the client (which will close the transport) + this.client.close(); + + System.out.println("Connection closed successfully"); + } + } + + private static @NonNull RestClient buildTestClient(String testServerUrl) { + var cookieManager = new CookieManager(); + cookieManager.setCookiePolicy(CookiePolicy.ACCEPT_ALL); + var httpClient = HttpClient.newBuilder() + .cookieHandler(cookieManager) + .followRedirects(HttpClient.Redirect.ALWAYS) + .build(); + var testClient = RestClient.builder() + .baseUrl(testServerUrl) + .requestFactory(new JdkClientHttpRequestFactory(httpClient)) + .build(); + return testClient; + } + + @Override + public McpSyncClient getMcpClient() { + if (this.client == null) { + return Scenario.super.getMcpClient(); + } + + return this.client; + } + +} diff --git a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/PreRegistrationScenario.java b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/PreRegistrationScenario.java new file mode 100644 index 000000000..8e6bbe228 --- /dev/null +++ b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/PreRegistrationScenario.java @@ -0,0 +1,110 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.conformance.client.scenario; + +import java.time.Duration; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.security.client.sync.AuthenticationMcpTransportContextProvider; +import org.springaicommunity.mcp.security.client.sync.oauth2.http.client.OAuth2ClientCredentialsSyncHttpRequestCustomizer; +import org.springaicommunity.mcp.security.client.sync.oauth2.metadata.McpMetadataDiscoveryService; +import org.springaicommunity.mcp.security.client.sync.oauth2.registration.McpClientRegistrationRepository; +import tools.jackson.databind.PropertyNamingStrategies; +import tools.jackson.databind.annotation.JsonNaming; +import tools.jackson.databind.json.JsonMapper; + +import org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.registration.ClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import static io.modelcontextprotocol.conformance.client.ConformanceSpringClientApplication.REGISTRATION_ID; + +public class PreRegistrationScenario implements Scenario { + + private static final Logger log = LoggerFactory.getLogger(PreRegistrationScenario.class); + + private final JsonMapper mapper; + + private final McpClientRegistrationRepository clientRegistrationRepository; + + private final AuthorizedClientServiceOAuth2AuthorizedClientManager authorizedClientManager; + + private final McpMetadataDiscoveryService mcpMetadataDiscovery; + + public PreRegistrationScenario(McpClientRegistrationRepository clientRegistrationRepository, + McpMetadataDiscoveryService mcpMetadataDiscovery, OAuth2AuthorizedClientService authorizedClientService) { + this.mapper = JsonMapper.shared(); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientManager = new AuthorizedClientServiceOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientService); + this.mcpMetadataDiscovery = mcpMetadataDiscovery; + } + + @Override + public void execute(String serverUrl) { + log.info("Executing PreRegistrationScenario"); + + var oauthCredentials = extractCredentialsFromContext(); + setClientRegistration(serverUrl, oauthCredentials); + + var customizer = new OAuth2ClientCredentialsSyncHttpRequestCustomizer(authorizedClientManager, REGISTRATION_ID); + HttpClientStreamableHttpTransport transport = HttpClientStreamableHttpTransport.builder(serverUrl) + .httpRequestCustomizer(customizer) + .build(); + + var client = McpClient.sync(transport) + .transportContextProvider(new AuthenticationMcpTransportContextProvider()) + .clientInfo(new McpSchema.Implementation("test-client", "1.0.0")) + .requestTimeout(Duration.ofSeconds(30)) + .build(); + + try { + // Initialize client + client.initialize(); + + System.out.println("Successfully connected to MCP server"); + } + finally { + // Close the client (which will close the transport) + client.close(); + + System.out.println("Connection closed successfully"); + } + } + + private void setClientRegistration(String mcpServerUrl, PreRegistrationContext oauthCredentials) { + var metadata = this.mcpMetadataDiscovery.getMcpMetadata(mcpServerUrl); + var registration = ClientRegistrations + .fromIssuerLocation(metadata.protectedResourceMetadata().authorizationServers().get(0)) + .registrationId(REGISTRATION_ID) + .clientId(oauthCredentials.clientId()) + .clientSecret(oauthCredentials.clientSecret()) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .build(); + clientRegistrationRepository.addPreRegisteredClient(registration, + metadata.protectedResourceMetadata().resource()); + } + + private PreRegistrationContext extractCredentialsFromContext() { + String contextEnv = System.getenv("MCP_CONFORMANCE_CONTEXT"); + if (contextEnv == null || contextEnv.isEmpty()) { + var errorMessage = "Error: MCP_CONFORMANCE_CONTEXT environment variable is not set"; + System.err.println(errorMessage); + throw new RuntimeException(errorMessage); + } + + return mapper.readValue(contextEnv, PreRegistrationContext.class); + } + + @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) + private record PreRegistrationContext(String clientId, String clientSecret) { + + } + +} diff --git a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/Scenario.java b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/Scenario.java new file mode 100644 index 000000000..9054db83b --- /dev/null +++ b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/Scenario.java @@ -0,0 +1,17 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.conformance.client.scenario; + +import io.modelcontextprotocol.client.McpSyncClient; + +public interface Scenario { + + default McpSyncClient getMcpClient() { + throw new IllegalStateException("Client not set"); + } + + void execute(String serverUrl); + +} diff --git a/conformance-tests/client-spring-http-client/src/main/resources/application.properties b/conformance-tests/client-spring-http-client/src/main/resources/application.properties new file mode 100644 index 000000000..0c4a77438 --- /dev/null +++ b/conformance-tests/client-spring-http-client/src/main/resources/application.properties @@ -0,0 +1,4 @@ +# Server runs on random port +server.port=0 +# Disable Spring AI MCP client auto-configuration (we configure the client manually) +spring.ai.mcp.client.enabled=false diff --git a/conformance-tests/conformance-baseline.yml b/conformance-tests/conformance-baseline.yml new file mode 100644 index 000000000..4ab144063 --- /dev/null +++ b/conformance-tests/conformance-baseline.yml @@ -0,0 +1,18 @@ +# MCP Java SDK Conformance Test Baseline +# This file lists known failing scenarios that are expected to fail until fixed. +# See: https://github.com/modelcontextprotocol/conformance/blob/main/SDK_INTEGRATION.md + +server: + # Resource subscription not implemented in SDK + - resources-subscribe + - resources-unsubscribe + +client: + # SSE retry field handling not implemented + # - Client does not parse or respect retry: field timing + # - Client does not send Last-Event-ID header + - sse-retry + # CIMD not implemented yet + - auth/basic-cimd + # Scope step up beyond initial authorization request not implemented + - auth/scope-step-up diff --git a/conformance-tests/pom.xml b/conformance-tests/pom.xml new file mode 100644 index 000000000..d1bef2a24 --- /dev/null +++ b/conformance-tests/pom.xml @@ -0,0 +1,33 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + mcp-parent + 1.1.0-SNAPSHOT + + conformance-tests + pom + MCP Conformance Tests + Conformance tests for the Java MCP SDK + https://github.com/modelcontextprotocol/java-sdk + + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + true + + + + client-jdk-http-client + client-spring-http-client + server-servlet + + + diff --git a/conformance-tests/server-servlet/README.md b/conformance-tests/server-servlet/README.md new file mode 100644 index 000000000..bd86636b6 --- /dev/null +++ b/conformance-tests/server-servlet/README.md @@ -0,0 +1,205 @@ +# MCP Conformance Tests - Servlet Server + +This module contains a comprehensive MCP (Model Context Protocol) server implementation for conformance testing using the servlet stack with an embedded Tomcat server and streamable HTTP transport. + +## Conformance Test Results + +**Status: 37 out of 40 tests passing (92.5%)** + +The server has been validated against the official [MCP conformance test suite](https://github.com/modelcontextprotocol/conformance). See [VALIDATION_RESULTS.md](../VALIDATION_RESULTS.md) for detailed results. + +### What's Implemented + +βœ… **Lifecycle & Utilities** (4/4) +- Server initialization, ping, logging, completion + +βœ… **Tools** (11/11) +- Text, image, audio, embedded resources, mixed content +- Logging, error handling, sampling, elicitation +- Progress notifications + +βœ… **Elicitation** (10/10) +- SEP-1034: Default values for all primitive types +- SEP-1330: All enum schema variants + +βœ… **Resources** (4/6) +- List, read text/binary, templates +- ⚠️ Subscribe/unsubscribe (SDK limitation) + +βœ… **Prompts** (4/4) +- Simple, parameterized, embedded resources, images + +βœ… **SSE Transport** (2/2) +- Multiple streams support + +βœ… **Security** (2/2) +- βœ… DNS rebinding protection + +## Features + +- Embedded Tomcat servlet container +- MCP server using HttpServletStreamableServerTransportProvider +- Comprehensive test coverage with 15+ tools +- Streamable HTTP transport with SSE on `/mcp` endpoint +- Support for all MCP content types (text, image, audio, resources) +- Advanced features: sampling, elicitation, progress (partial), completion + +## Running the Server + +To run the conformance server: + +```bash +cd conformance-tests/server-servlet +../../mvnw compile exec:java -Dexec.mainClass="io.modelcontextprotocol.conformance.server.ConformanceServlet" +``` + +Or from the root directory: + +```bash +./mvnw compile exec:java -pl conformance-tests/server-servlet -Dexec.mainClass="io.modelcontextprotocol.conformance.server.ConformanceServlet" +``` + +The server will start on port 8080 with the MCP endpoint at `/mcp`. + +## Running Conformance Tests + +Once the server is running, you can validate it against the official MCP conformance test suite using `npx`: + +### Run Full Active Test Suite + +```bash +npx @modelcontextprotocol/conformance server --url http://localhost:8080/mcp --suite active +``` + +### Run Specific Scenarios + +```bash +# Test tools +npx @modelcontextprotocol/conformance server --url http://localhost:8080/mcp --scenario tools-list --verbose + +# Test prompts +npx @modelcontextprotocol/conformance server --url http://localhost:8080/mcp --scenario prompts-list --verbose + +# Test resources +npx @modelcontextprotocol/conformance server --url http://localhost:8080/mcp --scenario resources-read-text --verbose + +# Test elicitation with defaults +npx @modelcontextprotocol/conformance server --url http://localhost:8080/mcp --scenario elicitation-sep1034-defaults --verbose +``` + +### Available Test Suites + +- `active` (default) - All active/stable tests (30 scenarios) +- `all` - All tests including pending/experimental +- `pending` - Only pending/experimental tests + +### Common Scenarios + +**Lifecycle & Utilities:** +- `server-initialize` - Server initialization +- `ping` - Ping utility +- `logging-set-level` - Logging configuration +- `completion-complete` - Argument completion + +**Tools:** +- `tools-list` - List available tools +- `tools-call-simple-text` - Simple text response +- `tools-call-image` - Image content +- `tools-call-audio` - Audio content +- `tools-call-with-logging` - Logging during execution +- `tools-call-with-progress` - Progress notifications +- `tools-call-sampling` - LLM sampling +- `tools-call-elicitation` - User input requests + +**Resources:** +- `resources-list` - List resources +- `resources-read-text` - Read text resource +- `resources-read-binary` - Read binary resource +- `resources-templates-read` - Resource templates +- `resources-subscribe` - Subscribe to resource updates +- `resources-unsubscribe` - Unsubscribe from updates + +**Prompts:** +- `prompts-list` - List prompts +- `prompts-get-simple` - Simple prompt +- `prompts-get-with-args` - Parameterized prompt +- `prompts-get-embedded-resource` - Prompt with resource +- `prompts-get-with-image` - Prompt with image + +**Elicitation:** +- `elicitation-sep1034-defaults` - Default values (SEP-1034) +- `elicitation-sep1330-enums` - Enum schemas (SEP-1330) + +## Testing with curl + +You can also test the endpoint manually: + +```bash +# Check endpoint (will show SSE requirement) +curl -X GET http://localhost:8080/mcp + +# Initialize session with proper headers +curl -X POST http://localhost:8080/mcp \ + -H "Content-Type: application/json" \ + -H "Accept: text/event-stream" \ + -H "mcp-session-id: test-session-123" \ + -d '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}' +``` + +## Architecture + +- **Transport**: HttpServletStreamableServerTransportProvider (streamable HTTP with SSE) +- **Container**: Embedded Apache Tomcat +- **Protocol**: Streamable HTTP with Server-Sent Events +- **Port**: 8080 (default) +- **Endpoint**: `/mcp` +- **Request Timeout**: 30 seconds + +## Implemented Tools + +### Content Type Tools +- `test_simple_text` - Returns simple text content +- `test_image_content` - Returns a minimal PNG image (1x1 red pixel) +- `test_audio_content` - Returns a minimal WAV audio file +- `test_embedded_resource` - Returns embedded resource content +- `test_multiple_content_types` - Returns mixed text, image, and resource content + +### Behavior Tools +- `test_tool_with_logging` - Sends log notifications during execution +- `test_error_handling` - Intentionally returns an error for testing +- `test_tool_with_progress` - Reports progress notifications (⚠️ SDK issue) + +### Interactive Tools +- `test_sampling` - Requests LLM sampling from client +- `test_elicitation` - Requests user input from client +- `test_elicitation_sep1034_defaults` - Elicitation with default values (SEP-1034) +- `test_elicitation_sep1330_enums` - Elicitation with enum schemas (SEP-1330) + +## Implemented Prompts + +- `test_simple_prompt` - Simple prompt without arguments +- `test_prompt_with_arguments` - Prompt with required arguments (arg1, arg2) +- `test_prompt_with_embedded_resource` - Prompt with embedded resource content +- `test_prompt_with_image` - Prompt with image content + +## Implemented Resources + +- `test://static-text` - Static text resource +- `test://static-binary` - Static binary resource (PNG image) +- `test://watched-resource` - Resource that can be subscribed to +- `test://template/{id}/data` - Resource template with parameter substitution + +## Known Limitations + +See [VALIDATION_RESULTS.md](../VALIDATION_RESULTS.md) for details on: + +1. **Resource Subscriptions** - Not implemented in Java SDK +2. **DNS Rebinding Protection** - Missing Host/Origin validation + +These are SDK-level limitations that require fixes in the core framework. + +## References + +- [MCP Specification](https://modelcontextprotocol.io/specification/) +- [MCP Conformance Tests](https://github.com/modelcontextprotocol/conformance) +- [SDK Integration Guide](https://github.com/modelcontextprotocol/conformance/blob/main/SDK_INTEGRATION.md) diff --git a/conformance-tests/server-servlet/pom.xml b/conformance-tests/server-servlet/pom.xml new file mode 100644 index 000000000..68da42158 --- /dev/null +++ b/conformance-tests/server-servlet/pom.xml @@ -0,0 +1,73 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + conformance-tests + 1.1.0-SNAPSHOT + + server-servlet + jar + MCP Conformance Tests - Servlet Server + Servlet Server conformance tests for the Java MCP SDK + https://github.com/modelcontextprotocol/java-sdk + + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + true + + + + + io.modelcontextprotocol.sdk + mcp + 1.1.0-SNAPSHOT + + + + org.slf4j + slf4j-api + ${slf4j-api.version} + + + + ch.qos.logback + logback-classic + ${logback.version} + + + + jakarta.servlet + jakarta.servlet-api + ${jakarta.servlet.version} + provided + + + + org.apache.tomcat.embed + tomcat-embed-core + ${tomcat.version} + + + + + + + org.codehaus.mojo + exec-maven-plugin + 3.1.0 + + io.modelcontextprotocol.conformance.server.ConformanceServlet + + + + + + \ No newline at end of file diff --git a/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java new file mode 100644 index 000000000..3d162a5de --- /dev/null +++ b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java @@ -0,0 +1,596 @@ +package io.modelcontextprotocol.conformance.server; + +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.transport.DefaultServerTransportSecurityValidator; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema.AudioContent; +import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.EmbeddedResource; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.ImageContent; +import io.modelcontextprotocol.spec.McpSchema.JsonSchema; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ConformanceServlet { + + private static final Logger logger = LoggerFactory.getLogger(ConformanceServlet.class); + + private static final int PORT = 8080; + + private static final String MCP_ENDPOINT = "/mcp"; + + private static final JsonSchema EMPTY_JSON_SCHEMA = new JsonSchema("object", Collections.emptyMap(), null, null, + null, null); + + // Minimal 1x1 red pixel PNG (base64 encoded) + private static final String RED_PIXEL_PNG = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="; + + // Minimal WAV file (base64 encoded) - 1 sample at 8kHz + private static final String MINIMAL_WAV = "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAAB9AAACABAAZGF0YQAAAAA="; + + public static void main(String[] args) throws Exception { + logger.info("Starting MCP Conformance Tests - Servlet Server"); + + HttpServletStreamableServerTransportProvider transportProvider = HttpServletStreamableServerTransportProvider + .builder() + .mcpEndpoint(MCP_ENDPOINT) + .keepAliveInterval(Duration.ofSeconds(30)) + .securityValidator(DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build()) + .build(); + + // Build server with all conformance test features + var mcpServer = McpServer.sync(transportProvider) + .serverInfo("mcp-conformance-server", "1.0.0") + .capabilities(ServerCapabilities.builder() + .completions() + .resources(true, false) + .tools(false) + .prompts(false) + .build()) + .tools(createToolSpecs()) + .prompts(createPromptSpecs()) + .resources(createResourceSpecs()) + .resourceTemplates(createResourceTemplateSpecs()) + .completions(createCompletionSpecs()) + .requestTimeout(Duration.ofSeconds(30)) + .build(); + + // Set up embedded Tomcat + Tomcat tomcat = createEmbeddedTomcat(transportProvider); + + try { + tomcat.start(); + logger.info("Conformance MCP Servlet Server started on port {} with endpoint {}", PORT, MCP_ENDPOINT); + logger.info("Server URL: http://localhost:{}{}", PORT, MCP_ENDPOINT); + + // Keep the server running + tomcat.getServer().await(); + } + catch (LifecycleException e) { + logger.error("Failed to start Tomcat server", e); + throw e; + } + finally { + logger.info("Shutting down MCP server..."); + mcpServer.closeGracefully(); + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + logger.error("Error during Tomcat shutdown", e); + } + } + } + + private static Tomcat createEmbeddedTomcat(HttpServletStreamableServerTransportProvider transportProvider) { + Tomcat tomcat = new Tomcat(); + tomcat.setPort(PORT); + + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + Context context = tomcat.addContext("", baseDir); + + // Add the MCP servlet to Tomcat + org.apache.catalina.Wrapper wrapper = context.createWrapper(); + wrapper.setName("mcpServlet"); + wrapper.setServlet(transportProvider); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addChild(wrapper); + context.addServletMappingDecoded("/*", "mcpServlet"); + + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(30000); + return tomcat; + } + + private static List createToolSpecs() { + return List.of( + // test_simple_text - Returns simple text content + McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("test_simple_text") + .description("Returns simple text content for testing") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + logger.info("Tool 'test_simple_text' called"); + return CallToolResult.builder() + .content(List.of(new TextContent("This is a simple text response for testing."))) + .isError(false) + .build(); + }) + .build(), + + // test_image_content - Returns image content + McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("test_image_content") + .description("Returns image content for testing") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + logger.info("Tool 'test_image_content' called"); + return CallToolResult.builder() + .content(List.of(new ImageContent(null, RED_PIXEL_PNG, "image/png"))) + .isError(false) + .build(); + }) + .build(), + + // test_audio_content - Returns audio content + McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("test_audio_content") + .description("Returns audio content for testing") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + logger.info("Tool 'test_audio_content' called"); + return CallToolResult.builder() + .content(List.of(new AudioContent(null, MINIMAL_WAV, "audio/wav"))) + .isError(false) + .build(); + }) + .build(), + + // test_embedded_resource - Returns embedded resource content + McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("test_embedded_resource") + .description("Returns embedded resource content for testing") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + logger.info("Tool 'test_embedded_resource' called"); + TextResourceContents resourceContents = new TextResourceContents("test://embedded-resource", + "text/plain", "This is an embedded resource content."); + EmbeddedResource embeddedResource = new EmbeddedResource(null, resourceContents); + return CallToolResult.builder().content(List.of(embeddedResource)).isError(false).build(); + }) + .build(), + + // test_multiple_content_types - Returns multiple content types + McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("test_multiple_content_types") + .description("Returns multiple content types for testing") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + logger.info("Tool 'test_multiple_content_types' called"); + TextResourceContents resourceContents = new TextResourceContents( + "test://mixed-content-resource", "application/json", + "{\"test\":\"data\",\"value\":123}"); + EmbeddedResource embeddedResource = new EmbeddedResource(null, resourceContents); + return CallToolResult.builder() + .content(List.of(new TextContent("Multiple content types test:"), + new ImageContent(null, RED_PIXEL_PNG, "image/png"), embeddedResource)) + .isError(false) + .build(); + }) + .build(), + + // test_tool_with_logging - Tool that sends log messages during execution + McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("test_tool_with_logging") + .description("Tool that sends log messages during execution") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + logger.info("Tool 'test_tool_with_logging' called"); + // Send log notifications + exchange.loggingNotification(LoggingMessageNotification.builder() + .level(LoggingLevel.INFO) + .data("Tool execution started") + .build()); + exchange.loggingNotification(LoggingMessageNotification.builder() + .level(LoggingLevel.INFO) + .data("Tool processing data") + .build()); + exchange.loggingNotification(LoggingMessageNotification.builder() + .level(LoggingLevel.INFO) + .data("Tool execution completed") + .build()); + return CallToolResult.builder() + .content(List.of(new TextContent("Tool execution completed with logging"))) + .isError(false) + .build(); + }) + .build(), + + // test_error_handling - Tool that always returns an error + McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("test_error_handling") + .description("Tool that returns an error for testing error handling") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + logger.info("Tool 'test_error_handling' called"); + return CallToolResult.builder() + .content(List.of(new TextContent("This tool intentionally returns an error for testing"))) + .isError(true) + .build(); + }) + .build(), + + // test_tool_with_progress - Tool that reports progress + McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("test_tool_with_progress") + .description("Tool that reports progress notifications") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + logger.info("Tool 'test_tool_with_progress' called"); + Object progressToken = request.meta().get("progressToken"); + if (progressToken != null) { + // Send progress notifications sequentially + exchange.progressNotification(new ProgressNotification(progressToken, 0.0, 100.0, null)); + // try { + // Thread.sleep(50); + // } + // catch (InterruptedException e) { + // Thread.currentThread().interrupt(); + // } + exchange.progressNotification(new ProgressNotification(progressToken, 50.0, 100.0, null)); + // try { + // Thread.sleep(50); + // } + // catch (InterruptedException e) { + // Thread.currentThread().interrupt(); + // } + exchange.progressNotification(new ProgressNotification(progressToken, 100.0, 100.0, null)); + return CallToolResult.builder() + .content(List.of(new TextContent("Tool execution completed with progress"))) + .isError(false) + .build(); + } + else { + // No progress token, just execute with delays + // try { + // Thread.sleep(100); + // } + // catch (InterruptedException e) { + // Thread.currentThread().interrupt(); + // } + return CallToolResult.builder() + .content(List.of(new TextContent("Tool execution completed without progress"))) + .isError(false) + .build(); + } + }) + .build(), + + // test_sampling - Tool that requests LLM sampling from client + McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("test_sampling") + .description("Tool that requests LLM sampling from client") + .inputSchema(new JsonSchema("object", + Map.of("prompt", + Map.of("type", "string", "description", "The prompt to send to the LLM")), + List.of("prompt"), null, null, null)) + .build()) + .callHandler((exchange, request) -> { + logger.info("Tool 'test_sampling' called"); + String prompt = (String) request.arguments().get("prompt"); + + // Request sampling from client + CreateMessageRequest samplingRequest = CreateMessageRequest.builder() + .messages(List.of(new SamplingMessage(Role.USER, new TextContent(prompt)))) + .maxTokens(100) + .build(); + + CreateMessageResult response = exchange.createMessage(samplingRequest); + String responseText = "LLM response: " + ((TextContent) response.content()).text(); + return CallToolResult.builder() + .content(List.of(new TextContent(responseText))) + .isError(false) + .build(); + }) + .build(), + + // test_elicitation - Tool that requests user input from client + McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("test_elicitation") + .description("Tool that requests user input from client") + .inputSchema(new JsonSchema("object", + Map.of("message", + Map.of("type", "string", "description", "The message to show the user")), + List.of("message"), null, null, null)) + .build()) + .callHandler((exchange, request) -> { + logger.info("Tool 'test_elicitation' called"); + String message = (String) request.arguments().get("message"); + + // Request elicitation from client + Map requestedSchema = Map.of("type", "object", "properties", + Map.of("username", Map.of("type", "string", "description", "User's response"), "email", + Map.of("type", "string", "description", "User's email address")), + "required", List.of("username", "email")); + + ElicitRequest elicitRequest = new ElicitRequest(message, requestedSchema); + + ElicitResult response = exchange.createElicitation(elicitRequest); + String responseText = "User response: action=" + response.action() + ", content=" + + response.content(); + return CallToolResult.builder() + .content(List.of(new TextContent(responseText))) + .isError(false) + .build(); + }) + .build(), + + // test_elicitation_sep1034_defaults - Tool with default values for all + // primitive types + McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("test_elicitation_sep1034_defaults") + .description("Tool that requests elicitation with default values for all primitive types") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + logger.info("Tool 'test_elicitation_sep1034_defaults' called"); + + // Create schema with default values for all primitive types + Map requestedSchema = Map.of("type", "object", "properties", + Map.of("name", Map.of("type", "string", "default", "John Doe"), "age", + Map.of("type", "integer", "default", 30), "score", + Map.of("type", "number", "default", 95.5), "status", + Map.of("type", "string", "enum", List.of("active", "inactive", "pending"), + "default", "active"), + "verified", Map.of("type", "boolean", "default", true)), + "required", List.of("name", "age", "score", "status", "verified")); + + ElicitRequest elicitRequest = new ElicitRequest("Please provide your information with defaults", + requestedSchema); + + ElicitResult response = exchange.createElicitation(elicitRequest); + String responseText = "Elicitation completed: action=" + response.action() + ", content=" + + response.content(); + return CallToolResult.builder() + .content(List.of(new TextContent(responseText))) + .isError(false) + .build(); + }) + .build(), + + // test_elicitation_sep1330_enums - Tool with enum schema improvements + McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("test_elicitation_sep1330_enums") + .description("Tool that requests elicitation with enum schema improvements") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + logger.info("Tool 'test_elicitation_sep1330_enums' called"); + + // Create schema with all 5 enum variants + Map requestedSchema = Map.of("type", "object", "properties", Map.of( + // 1. Untitled single-select + "untitledSingle", + Map.of("type", "string", "enum", List.of("option1", "option2", "option3")), + // 2. Titled single-select using oneOf with const/title + "titledSingle", + Map.of("type", "string", "oneOf", + List.of(Map.of("const", "value1", "title", "First Option"), + Map.of("const", "value2", "title", "Second Option"), + Map.of("const", "value3", "title", "Third Option"))), + // 3. Legacy titled using enumNames (deprecated) + "legacyEnum", + Map.of("type", "string", "enum", List.of("opt1", "opt2", "opt3"), "enumNames", + List.of("Option One", "Option Two", "Option Three")), + // 4. Untitled multi-select + "untitledMulti", + Map.of("type", "array", "items", + Map.of("type", "string", "enum", List.of("option1", "option2", "option3"))), + // 5. Titled multi-select using items.anyOf with + // const/title + "titledMulti", + Map.of("type", "array", "items", + Map.of("anyOf", + List.of(Map.of("const", "value1", "title", "First Choice"), + Map.of("const", "value2", "title", "Second Choice"), + Map.of("const", "value3", "title", "Third Choice"))))), + "required", List.of("untitledSingle", "titledSingle", "legacyEnum", "untitledMulti", + "titledMulti")); + + ElicitRequest elicitRequest = new ElicitRequest("Select your preferences", requestedSchema); + + ElicitResult response = exchange.createElicitation(elicitRequest); + String responseText = "Elicitation completed: action=" + response.action() + ", content=" + + response.content(); + return CallToolResult.builder() + .content(List.of(new TextContent(responseText))) + .isError(false) + .build(); + }) + .build()); + } + + private static List createPromptSpecs() { + return List.of( + // test_simple_prompt - Simple prompt without arguments + new McpServerFeatures.SyncPromptSpecification( + new Prompt("test_simple_prompt", null, "A simple prompt for testing", List.of()), + (exchange, request) -> { + logger.info("Prompt 'test_simple_prompt' requested"); + return new GetPromptResult(null, List.of(new PromptMessage(Role.USER, + new TextContent("This is a simple prompt for testing.")))); + }), + + // test_prompt_with_arguments - Prompt with arguments + new McpServerFeatures.SyncPromptSpecification( + new Prompt("test_prompt_with_arguments", null, "A prompt with arguments for testing", + List.of(new PromptArgument("arg1", "First test argument", true), + new PromptArgument("arg2", "Second test argument", true))), + (exchange, request) -> { + logger.info("Prompt 'test_prompt_with_arguments' requested"); + String arg1 = (String) request.arguments().get("arg1"); + String arg2 = (String) request.arguments().get("arg2"); + String text = String.format("Prompt with arguments: arg1='%s', arg2='%s'", arg1, arg2); + return new GetPromptResult(null, + List.of(new PromptMessage(Role.USER, new TextContent(text)))); + }), + + // test_prompt_with_embedded_resource - Prompt with embedded resource + new McpServerFeatures.SyncPromptSpecification( + new Prompt("test_prompt_with_embedded_resource", null, + "A prompt with embedded resource for testing", + List.of(new PromptArgument("resourceUri", "URI of the resource to embed", true))), + (exchange, request) -> { + logger.info("Prompt 'test_prompt_with_embedded_resource' requested"); + String resourceUri = (String) request.arguments().get("resourceUri"); + TextResourceContents resourceContents = new TextResourceContents(resourceUri, "text/plain", + "Embedded resource content for testing."); + EmbeddedResource embeddedResource = new EmbeddedResource(null, resourceContents); + return new GetPromptResult(null, + List.of(new PromptMessage(Role.USER, embeddedResource), new PromptMessage(Role.USER, + new TextContent("Please process the embedded resource above.")))); + }), + + // test_prompt_with_image - Prompt with image content + new McpServerFeatures.SyncPromptSpecification(new Prompt("test_prompt_with_image", null, + "A prompt with image content for testing", List.of()), (exchange, request) -> { + logger.info("Prompt 'test_prompt_with_image' requested"); + return new GetPromptResult(null, List.of( + new PromptMessage(Role.USER, new ImageContent(null, RED_PIXEL_PNG, "image/png")), + new PromptMessage(Role.USER, new TextContent("Please analyze the image above.")))); + })); + } + + private static List createResourceSpecs() { + return List.of( + // test://static-text - Static text resource + new McpServerFeatures.SyncResourceSpecification(Resource.builder() + .uri("test://static-text") + .name("Static Text Resource") + .description("A static text resource for testing") + .mimeType("text/plain") + .build(), (exchange, request) -> { + logger.info("Resource 'test://static-text' requested"); + return new ReadResourceResult(List.of(new TextResourceContents("test://static-text", + "text/plain", "This is the content of the static text resource."))); + }), + + // test://static-binary - Static binary resource (image) + new McpServerFeatures.SyncResourceSpecification(Resource.builder() + .uri("test://static-binary") + .name("Static Binary Resource") + .description("A static binary resource for testing") + .mimeType("image/png") + .build(), (exchange, request) -> { + logger.info("Resource 'test://static-binary' requested"); + return new ReadResourceResult( + List.of(new BlobResourceContents("test://static-binary", "image/png", RED_PIXEL_PNG))); + }), + + // test://watched-resource - Resource that can be subscribed to + new McpServerFeatures.SyncResourceSpecification(Resource.builder() + .uri("test://watched-resource") + .name("Watched Resource") + .description("A resource that can be subscribed to for updates") + .mimeType("text/plain") + .build(), (exchange, request) -> { + logger.info("Resource 'test://watched-resource' requested"); + return new ReadResourceResult(List.of(new TextResourceContents("test://watched-resource", + "text/plain", "This is a watched resource content."))); + })); + } + + private static List createResourceTemplateSpecs() { + return List.of( + // test://template/{id}/data - Resource template with parameter + // substitution + new McpServerFeatures.SyncResourceTemplateSpecification(ResourceTemplate.builder() + .uriTemplate("test://template/{id}/data") + .name("Template Resource") + .description("A resource template for testing parameter substitution") + .mimeType("application/json") + .build(), (exchange, request) -> { + logger.info("Resource template 'test://template/{{id}}/data' requested for URI: {}", + request.uri()); + // Extract id from URI + String uri = request.uri(); + String id = uri.replaceAll("test://template/(.+)/data", "$1"); + String jsonContent = String + .format("{\"id\":\"%s\",\"templateTest\":true,\"data\":\"Data for ID: %s\"}", id, id); + return new ReadResourceResult( + List.of(new TextResourceContents(uri, "application/json", jsonContent))); + })); + } + + private static List createCompletionSpecs() { + return List.of( + // Completion for test_prompt_with_arguments + new McpServerFeatures.SyncCompletionSpecification(new PromptReference("test_prompt_with_arguments"), + (exchange, request) -> { + logger.info("Completion requested for prompt 'test_prompt_with_arguments', argument: {}", + request.argument().name()); + // Return minimal completion with required fields + return new CompleteResult(new CompleteResult.CompleteCompletion(List.of(), 0, false)); + })); + } + +} diff --git a/conformance-tests/server-servlet/src/main/resources/logback.xml b/conformance-tests/server-servlet/src/main/resources/logback.xml new file mode 100644 index 000000000..af69ac902 --- /dev/null +++ b/conformance-tests/server-servlet/src/main/resources/logback.xml @@ -0,0 +1,14 @@ + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + diff --git a/docs/blog/.authors.yml b/docs/blog/.authors.yml new file mode 100644 index 000000000..7b255c403 --- /dev/null +++ b/docs/blog/.authors.yml @@ -0,0 +1,5 @@ +authors: + mcp-team: + name: MCP Java SDK Team + description: Maintainers of the MCP Java SDK + avatar: https://github.com/modelcontextprotocol.png diff --git a/docs/blog/index.md b/docs/blog/index.md new file mode 100644 index 000000000..e61459078 --- /dev/null +++ b/docs/blog/index.md @@ -0,0 +1 @@ +# News diff --git a/docs/blog/posts/mcp-server-performance-benchmark.md b/docs/blog/posts/mcp-server-performance-benchmark.md new file mode 100644 index 000000000..a08b807b6 --- /dev/null +++ b/docs/blog/posts/mcp-server-performance-benchmark.md @@ -0,0 +1,72 @@ +--- +date: 2026-02-15 +authors: + - mcp-team +categories: + - Performance + - Benchmarks +--- + +# Java Leads MCP Server Performance Benchmarks with Sub-Millisecond Latency + +A comprehensive independent benchmark of MCP server implementations across four major languages puts Java at the top of the performance charts β€” delivering sub-millisecond latency, the highest throughput, and the best CPU efficiency of all tested platforms. + + + +## The Benchmark + +[TM Dev Lab](https://www.tmdevlab.com/mcp-server-performance-benchmark.html) published a rigorous performance comparison of MCP server implementations spanning **3.9 million total requests** across three independent test rounds. The benchmark evaluated four implementations under identical conditions: + +- **Java** β€” Spring Boot 4.0.0 + Spring AI 2.0.0-M2 on Java 21 +- **Go** β€” Official MCP SDK v1.2.0 +- **Node.js** β€” @modelcontextprotocol/sdk v1.26.0 +- **Python** β€” FastMCP 2.12.0+ with FastAPI 0.109.0+ + +Each server was tested with 50 concurrent virtual users over 5-minute sustained runs in Docker containers (1-core CPU, 1GB memory) on Ubuntu 24.04.3 LTS. Four standardized benchmark tools measured CPU-intensive, I/O-intensive, data transformation, and latency-handling scenarios β€” all with a **0% error rate** across every implementation. + +## Java's Performance Highlights + +The results speak for themselves: + +| Server | Avg Latency | Throughput (RPS) | CPU Efficiency (RPS/CPU%) | +|------------|-------------|------------------|---------------------------| +| **Java** | **0.835 ms** | **1,624** | **57.2** | +| Go | 0.855 ms | 1,624 | 50.4 | +| Node.js | 10.66 ms | 559 | 5.7 | +| Python | 26.45 ms | 292 | 3.2 | + +```mermaid +--- +config: + xyChart: + width: 700 + height: 400 + themeVariables: + xyChart: + backgroundColor: transparent +--- +xychart-beta + title "Average Latency Comparison (milliseconds)" + x-axis [Java, Go, "Node.js", Python] + y-axis "Latency (ms)" 0 --> 30 + bar [0.84, 0.86, 10.66, 26.45] +``` + +Java achieved the **lowest average latency** at 0.835 ms β€” edging out Go's 0.855 ms β€” while matching its throughput at 1,624 requests per second. Where Java truly stands out is **CPU efficiency**: at 57.2 RPS per CPU%, it extracts more performance per compute cycle than any other implementation, including Go (50.4). + +In CPU-bound workloads like Fibonacci calculation, Java excelled with a **0.369 ms** response time, showcasing the JVM's highly optimized just-in-time compilation. + +## A Clear Performance Tier + +The benchmark reveals two distinct performance tiers: + +- **High-performance tier**: Java and Go deliver sub-millisecond latencies and 1,600+ RPS +- **Standard tier**: Node.js (12x slower) and Python (31x slower) trail significantly + +Java's throughput is **2.9x higher than Node.js** and **5.6x higher than Python**. For latency-sensitive MCP deployments, the difference is even more pronounced β€” Java responds **12.8x faster than Node.js** and **31.7x faster than Python**. + +## What This Means for MCP Developers + +For teams building production MCP servers that need to handle high concurrency and low-latency tool interactions, Java with Spring Boot and Spring AI provides a battle-tested, high-performance foundation. The JVM's mature ecosystem, strong typing, and proven scalability make it an excellent choice for enterprise MCP deployments where performance and reliability are paramount. + +The full benchmark details, methodology, and raw data are available at [TM Dev Lab](https://www.tmdevlab.com/mcp-server-performance-benchmark.html). diff --git a/docs/client.md b/docs/client.md new file mode 100644 index 000000000..6a99928c5 --- /dev/null +++ b/docs/client.md @@ -0,0 +1,439 @@ +--- +title: MCP Client +description: Learn how to use the Model Context Protocol (MCP) client to interact with MCP servers +--- + +# MCP Client + +The MCP Client is a key component in the Model Context Protocol (MCP) architecture, responsible for establishing and managing connections with MCP servers. It implements the client-side of the protocol, handling: + +- Protocol version negotiation to ensure compatibility with servers +- Capability negotiation to determine available features +- Message transport and JSON-RPC communication +- Tool discovery and execution with optional schema validation +- Resource access and management +- Prompt system interactions +- Optional features like roots management, sampling, and elicitation support +- Progress tracking for long-running operations + +!!! tip + The core `io.modelcontextprotocol.sdk:mcp` module provides STDIO, SSE, and Streamable HTTP client transport implementations without requiring external web frameworks. + + The Spring-specific WebFlux transport (`mcp-spring-webflux`) is now part of [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ (group `org.springframework.ai`) and is no longer shipped by this SDK. + See the [MCP Client Boot Starter](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-client-boot-starter-docs.html) documentation for Spring-based client setup. + +The client provides both synchronous and asynchronous APIs for flexibility in different application contexts. + +=== "Sync API" + + ```java + // Create a sync client with custom configuration + McpSyncClient client = McpClient.sync(transport) + .requestTimeout(Duration.ofSeconds(10)) + .capabilities(ClientCapabilities.builder() + .roots(true) // Enable roots capability + .sampling() // Enable sampling capability + .elicitation() // Enable elicitation capability + .build()) + .sampling(request -> new CreateMessageResult(response)) + .elicitation(request -> new ElicitResult(ElicitResult.Action.ACCEPT, content)) + .build(); + + // Initialize connection + client.initialize(); + + // List available tools + ListToolsResult tools = client.listTools(); + + // Call a tool + CallToolResult result = client.callTool( + new CallToolRequest("calculator", + Map.of("operation", "add", "a", 2, "b", 3)) + ); + + // List and read resources + ListResourcesResult resources = client.listResources(); + ReadResourceResult resource = client.readResource( + new ReadResourceRequest("resource://uri") + ); + + // List and use prompts + ListPromptsResult prompts = client.listPrompts(); + GetPromptResult prompt = client.getPrompt( + new GetPromptRequest("greeting", Map.of("name", "Spring")) + ); + + // Add/remove roots + client.addRoot(new Root("file:///path", "description")); + client.removeRoot("file:///path"); + + // Close client + client.closeGracefully(); + ``` + +=== "Async API" + + ```java + // Create an async client with custom configuration + McpAsyncClient client = McpClient.async(transport) + .requestTimeout(Duration.ofSeconds(10)) + .capabilities(ClientCapabilities.builder() + .roots(true) // Enable roots capability + .sampling() // Enable sampling capability + .elicitation() // Enable elicitation capability + .build()) + .sampling(request -> Mono.just(new CreateMessageResult(response))) + .elicitation(request -> Mono.just(new ElicitResult(ElicitResult.Action.ACCEPT, content))) + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> { + logger.info("Tools updated: {}", tools); + })) + .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> { + logger.info("Resources updated: {}", resources); + })) + .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> { + logger.info("Prompts updated: {}", prompts); + })) + .progressConsumer(progress -> Mono.fromRunnable(() -> { + logger.info("Progress: {}", progress); + })) + .build(); + + // Initialize connection and use features + client.initialize() + .flatMap(initResult -> client.listTools()) + .flatMap(tools -> { + return client.callTool(new CallToolRequest( + "calculator", + Map.of("operation", "add", "a", 2, "b", 3) + )); + }) + .flatMap(result -> { + return client.listResources() + .flatMap(resources -> + client.readResource(new ReadResourceRequest("resource://uri")) + ); + }) + .flatMap(resource -> { + return client.listPrompts() + .flatMap(prompts -> + client.getPrompt(new GetPromptRequest( + "greeting", + Map.of("name", "Spring") + )) + ); + }) + .flatMap(prompt -> { + return client.addRoot(new Root("file:///path", "description")) + .then(client.removeRoot("file:///path")); + }) + .doFinally(signalType -> { + client.closeGracefully().subscribe(); + }) + .subscribe(); + ``` + +## Client Transport + +The transport layer handles the communication between MCP clients and servers, providing different implementations for various use cases. The client transport manages message serialization, connection establishment, and protocol-specific communication patterns. + +### STDIO + +Creates transport for process-based communication using stdin/stdout: + +```java +ServerParameters params = ServerParameters.builder("npx") + .args("-y", "@modelcontextprotocol/server-everything", "dir") + .build(); +McpTransport transport = new StdioClientTransport(params); +``` + +### Streamable HTTP + +=== "Streamable HttpClient" + + Creates a Streamable HTTP client transport for efficient bidirectional communication. Included in the core `mcp` module: + + ```java + McpTransport transport = HttpClientStreamableHttpTransport + .builder("http://your-mcp-server") + .endpoint("/mcp") + .build(); + ``` + + The Streamable HTTP transport supports: + + - Resumable streams for connection recovery + - Configurable connect timeout + - Custom HTTP request customization + - Multiple protocol version negotiation + +=== "Streamable WebClient (external)" + + Creates Streamable HTTP WebClient-based client transport. Requires the `mcp-spring-webflux` dependency from [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ (group `org.springframework.ai`): + + ```java + McpTransport transport = WebFluxSseClientTransport + .builder(WebClient.builder().baseUrl("http://your-mcp-server")) + .build(); + ``` + +### SSE HTTP (Legacy) + +=== "SSE HttpClient" + + Creates a framework-agnostic (pure Java API) SSE client transport. Included in the core `mcp` module: + + ```java + McpTransport transport = new HttpClientSseClientTransport("http://your-mcp-server"); + ``` +=== "SSE WebClient (external)" + + Creates WebFlux-based SSE client transport. Requires the `mcp-spring-webflux` dependency from [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ (group `org.springframework.ai`): + + ```java + WebClient.Builder webClientBuilder = WebClient.builder() + .baseUrl("http://your-mcp-server"); + McpTransport transport = new WebFluxSseClientTransport(webClientBuilder); + ``` + + +## Client Capabilities + +The client can be configured with various capabilities: + +```java +var capabilities = ClientCapabilities.builder() + .roots(true) // Enable filesystem roots support with list changes notifications + .sampling() // Enable LLM sampling support + .elicitation() // Enable elicitation support (form and URL modes) + .build(); +``` + +You can also configure elicitation with specific mode support: + +```java +var capabilities = ClientCapabilities.builder() + .elicitation(true, false) // Enable form-based elicitation, disable URL-based + .build(); +``` + +### Roots Support + +Roots define the boundaries of where servers can operate within the filesystem: + +```java +// Add a root dynamically +client.addRoot(new Root("file:///path", "description")); + +// Remove a root +client.removeRoot("file:///path"); + +// Notify server of roots changes +client.rootsListChangedNotification(); +``` + +The roots capability allows servers to: + +- Request the list of accessible filesystem roots +- Receive notifications when the roots list changes +- Understand which directories and files they have access to + +### Sampling Support + +Sampling enables servers to request LLM interactions ("completions" or "generations") through the client: + +```java +// Configure sampling handler +Function samplingHandler = request -> { + // Sampling implementation that interfaces with LLM + return new CreateMessageResult(response); +}; + +// Create client with sampling support +var client = McpClient.sync(transport) + .capabilities(ClientCapabilities.builder() + .sampling() + .build()) + .sampling(samplingHandler) + .build(); +``` + +This capability allows: + +- Servers to leverage AI capabilities without requiring API keys +- Clients to maintain control over model access and permissions +- Support for both text and image-based interactions +- Optional inclusion of MCP server context in prompts + +### Elicitation Support + +Elicitation enables servers to request additional information or user input through the client. This is useful when a server needs clarification or confirmation during an operation: + +```java +// Configure elicitation handler +Function elicitationHandler = request -> { + // Present the request to the user and collect their response + // The request contains a message and a schema describing the expected input + Map userResponse = collectUserInput(request.message(), request.requestedSchema()); + return new ElicitResult(ElicitResult.Action.ACCEPT, userResponse); +}; + +// Create client with elicitation support +var client = McpClient.sync(transport) + .capabilities(ClientCapabilities.builder() + .elicitation() + .build()) + .elicitation(elicitationHandler) + .build(); +``` + +The `ElicitResult` supports three actions: + +- `ACCEPT` - The user accepted and provided the requested information +- `DECLINE` - The user declined to provide the information +- `CANCEL` - The operation was cancelled + +### Logging Support + +The client can register a logging consumer to receive log messages from the server and set the minimum logging level to filter messages: + +```java +var mcpClient = McpClient.sync(transport) + .loggingConsumer(notification -> { + System.out.println("Received log message: " + notification.data()); + }) + .build(); + +mcpClient.initialize(); + +mcpClient.setLoggingLevel(McpSchema.LoggingLevel.INFO); + +// Call the tool that sends logging notifications +CallToolResult result = mcpClient.callTool(new CallToolRequest("logging-test", Map.of())); +``` + +Clients can control the minimum logging level they receive through the `mcpClient.setLoggingLevel(level)` request. Messages below the set level will be filtered out. +Supported logging levels (in order of increasing severity): DEBUG (0), INFO (1), NOTICE (2), WARNING (3), ERROR (4), CRITICAL (5), ALERT (6), EMERGENCY (7) + +### Progress Notifications + +The client can register a progress consumer to track the progress of long-running operations: + +```java +var mcpClient = McpClient.sync(transport) + .progressConsumer(progress -> { + System.out.println("Progress: " + progress.progress() + "/" + progress.total()); + }) + .build(); +``` + +## Using MCP Clients + +### Tool Execution + +Tools are server-side functions that clients can discover and execute. The MCP client provides methods to list available tools and execute them with specific parameters. Each tool has a unique name and accepts a map of parameters. + +=== "Sync API" + + ```java + // List available tools + ListToolsResult tools = client.listTools(); + + // Call a tool with a CallToolRequest + CallToolResult result = client.callTool( + new CallToolRequest("calculator", Map.of( + "operation", "add", + "a", 1, + "b", 2 + )) + ); + ``` + +=== "Async API" + + ```java + // List available tools asynchronously + client.listTools() + .doOnNext(tools -> tools.tools().forEach(tool -> + System.out.println(tool.name()))) + .subscribe(); + + // Call a tool asynchronously + client.callTool(new CallToolRequest("calculator", Map.of( + "operation", "add", + "a", 1, + "b", 2 + ))) + .subscribe(); + ``` + +### Tool Schema Validation and Caching + +The client supports optional JSON schema validation for tool call results and automatic schema caching: + +```java +var client = McpClient.sync(transport) + .jsonSchemaValidator(myValidator) // Enable schema validation + .enableCallToolSchemaCaching(true) // Cache tool schemas + .build(); +``` + +### Resource Access + +Resources represent server-side data sources that clients can access using URI templates. The MCP client provides methods to discover available resources and retrieve their contents through a standardized interface. + +=== "Sync API" + + ```java + // List available resources + ListResourcesResult resources = client.listResources(); + + // Read a resource + ReadResourceResult resource = client.readResource( + new ReadResourceRequest("resource://uri") + ); + ``` + +=== "Async API" + + ```java + // List available resources asynchronously + client.listResources() + .doOnNext(resources -> resources.resources().forEach(resource -> + System.out.println(resource.name()))) + .subscribe(); + + // Read a resource asynchronously + client.readResource(new ReadResourceRequest("resource://uri")) + .subscribe(); + ``` + +### Prompt System + +The prompt system enables interaction with server-side prompt templates. These templates can be discovered and executed with custom parameters, allowing for dynamic text generation based on predefined patterns. + +=== "Sync API" + + ```java + // List available prompt templates + ListPromptsResult prompts = client.listPrompts(); + + // Get a prompt with parameters + GetPromptResult prompt = client.getPrompt( + new GetPromptRequest("greeting", Map.of("name", "World")) + ); + ``` + +=== "Async API" + + ```java + // List available prompt templates asynchronously + client.listPrompts() + .doOnNext(prompts -> prompts.prompts().forEach(prompt -> + System.out.println(prompt.name()))) + .subscribe(); + + // Get a prompt asynchronously + client.getPrompt(new GetPromptRequest("greeting", Map.of("name", "World"))) + .subscribe(); + ``` diff --git a/docs/contribute.md b/docs/contribute.md new file mode 100644 index 000000000..3199dd51f --- /dev/null +++ b/docs/contribute.md @@ -0,0 +1,106 @@ +--- +title: Contributing +description: How to contribute to the MCP Java SDK +--- + +# Contributing + +Thank you for your interest in contributing to the Model Context Protocol Java SDK! +This guide outlines how to contribute to this project. + +## Prerequisites + +!!! info "Required Software" + - **Java 17** or above + - **Docker** + - **npx** + +## Getting Started + +1. Fork the repository +2. Clone your fork: + + ```bash + git clone https://github.com/YOUR-USERNAME/java-sdk.git + cd java-sdk + ``` + +3. Build from source: + + ```bash + ./mvnw clean install -DskipTests # skip the tests + ./mvnw test # run tests + ``` + +## Reporting Issues + +Please create an issue in the repository if you discover a bug or would like to +propose an enhancement. Bug reports should have a reproducer in the form of a code +sample or a repository attached that the maintainers or contributors can work with to +address the problem. + +## Making Changes + +1. Create a new branch: + + ```bash + git checkout -b feature/your-feature-name + ``` + +2. Make your changes. + +3. Validate your changes: + + ```bash + ./mvnw clean test + ``` + +### Change Proposal Guidelines + +#### Principles of MCP + +1. **Simple + Minimal**: It is much easier to add things to the codebase than it is to + remove them. To maintain simplicity, we keep a high bar for adding new concepts and + primitives as each addition requires maintenance and compatibility consideration. +2. **Concrete**: Code changes need to be based on specific usage and implementation + challenges and not on speculative ideas. Most importantly, the SDK is meant to + implement the MCP specification. + +## Submitting Changes + +1. For non-trivial changes, please clarify with the maintainers in an issue whether + you can contribute the change and the desired scope of the change. +2. For trivial changes (for example a couple of lines or documentation changes) there + is no need to open an issue first. +3. Push your changes to your fork. +4. Submit a pull request to the main repository. +5. Follow the pull request template. +6. Wait for review. +7. For any follow-up work, please add new commits instead of force-pushing. This will + allow the reviewer to focus on incremental changes instead of having to restart the + review process. + +## Code of Conduct + +This project follows a Code of Conduct. Please review it in +[CODE_OF_CONDUCT.md](https://github.com/modelcontextprotocol/java-sdk/blob/main/CODE_OF_CONDUCT.md). + +## Questions + +If you have questions, please create a discussion in the repository. + +## License + +By contributing, you agree that your contributions will be licensed under the MIT +License. + +## Security + +This SDK is maintained by [Anthropic](https://www.anthropic.com/) as part of the Model Context Protocol project. + +The security of our systems and user data is Anthropic's top priority. We appreciate the work of security researchers acting in good faith in identifying and reporting potential vulnerabilities. + +!!! warning "Reporting Security Vulnerabilities" + Do **not** report security vulnerabilities through public GitHub issues. Instead, report them through our HackerOne [submission form](https://hackerone.com/anthropic-vdp/reports/new?type=team&report_type=vulnerability). + +Our Vulnerability Disclosure Program guidelines are defined on our [HackerOne program page](https://hackerone.com/anthropic-vdp). diff --git a/docs/development.md b/docs/development.md new file mode 100644 index 000000000..e00c7268b --- /dev/null +++ b/docs/development.md @@ -0,0 +1,75 @@ +--- +title: Documentation +description: How to contribute to the MCP Java SDK documentation +--- + +# Documentation Development + +This guide covers how to set up and preview the MCP Java SDK documentation locally. + +!!! info "Prerequisites" + - Python 3.x + - pip (Python package manager) + +## Setup + +Install mkdocs-material: + +```bash +pip install mkdocs-material +``` + +## Preview Locally + +From the project root directory, run: + +```bash +mkdocs serve +``` + +A local preview of the documentation will be available at `http://localhost:8000`. + +### Custom Ports + +By default, mkdocs uses port 8000. You can customize the port with the `-a` flag: + +```bash +mkdocs serve -a localhost:3333 +``` + +## Building + +To build the static site for deployment: + +```bash +mkdocs build +``` + +The built site will be output to the `site/` directory. + +## Project Structure + +``` +docs/ +β”œβ”€β”€ index.md # Overview page +β”œβ”€β”€ quickstart.md # Quickstart guide +β”œβ”€β”€ client.md # MCP Client documentation +β”œβ”€β”€ server.md # MCP Server documentation +β”œβ”€β”€ contributing.md # Contributing guide +β”œβ”€β”€ development.md # This page +β”œβ”€β”€ images/ # Images and diagrams +└── stylesheets/ # Custom CSS +mkdocs.yml # MkDocs configuration +``` + +## Writing Guidelines + +- Documentation pages use standard Markdown with [mkdocs-material extensions](https://squidfunk.github.io/mkdocs-material/reference/) +- Use content tabs (`=== "Tab Label"`) for Maven/Gradle or Sync/Async code examples +- Use admonitions (`!!! tip`, `!!! info`, `!!! warning`) for callouts +- All code blocks should specify a language for syntax highlighting +- Images go in the `docs/images/` directory + +## IDE Support + +We suggest using extensions on your IDE to recognize and format Markdown. If you're a VSCode user, consider the [Markdown All in One](https://marketplace.visualstudio.com/items?itemName=yzhang.markdown-all-in-one) extension for enhanced Markdown support, and [Prettier](https://marketplace.visualstudio.com/items?itemName=esbenp.prettier-vscode) for code formatting. diff --git a/docs/images/favicon.svg b/docs/images/favicon.svg new file mode 100644 index 000000000..fe5edb725 --- /dev/null +++ b/docs/images/favicon.svg @@ -0,0 +1,69 @@ + + + + + + + + + + + + + + diff --git a/docs/images/java-mcp-client-architecture.jpg b/docs/images/java-mcp-client-architecture.jpg new file mode 100644 index 000000000..688a2b4ad Binary files /dev/null and b/docs/images/java-mcp-client-architecture.jpg differ diff --git a/docs/images/java-mcp-server-architecture.jpg b/docs/images/java-mcp-server-architecture.jpg new file mode 100644 index 000000000..4b05ca139 Binary files /dev/null and b/docs/images/java-mcp-server-architecture.jpg differ diff --git a/docs/images/java-mcp-uml-classdiagram.svg b/docs/images/java-mcp-uml-classdiagram.svg new file mode 100644 index 000000000..f83a586e7 --- /dev/null +++ b/docs/images/java-mcp-uml-classdiagram.svg @@ -0,0 +1 @@ +McpTransportMono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler)Mono<Void> sendMessage(JSONRPCMessage message)void close()Mono<Void> closeGracefully()<T> T unmarshalFrom(Object data, TypeReference<T> typeRef)McpSession<T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef)Mono<Void> sendNotification(String method, Map<String, Object> params)Mono<Void> closeGracefully()void close()DefaultMcpSessioninterface RequestHandlerinterface NotificationHandlerMcpClientBuilder using(ClientMcpTransport transport)McpAsyncClientMono<InitializeResult> initialize()ServerCapabilities getServerCapabilities()Implementation getServerInfo()ClientCapabilities getClientCapabilities()Implementation getClientInfo()void close()Mono<Void> closeGracefully()Mono<Object> ping()Mono<Void> addRoot(Root root)Mono<Void> removeRoot(String rootUri)Mono<Void> rootsListChangedNotification()Mono<CallToolResult> callTool(CallToolRequest request)Mono<ListToolsResult> listTools()Mono<ListResourcesResult> listResources()Mono<ReadResourceResult> readResource(ReadResourceRequest request)Mono<ListResourceTemplatesResult> listResourceTemplates()Mono<Void> subscribeResource(SubscribeRequest request)Mono<Void> unsubscribeResource(UnsubscribeRequest request)Mono<ListPromptsResult> listPrompts()Mono<GetPromptResult> getPrompt(GetPromptRequest request)Mono<Void> setLoggingLevel(LoggingLevel level)McpSyncClientInitializeResult initialize()ServerCapabilities getServerCapabilities()Implementation getServerInfo()ClientCapabilities getClientCapabilities()Implementation getClientInfo()void close()boolean closeGracefully()Object ping()void addRoot(Root root)void removeRoot(String rootUri)void rootsListChangedNotification()CallToolResult callTool(CallToolRequest request)ListToolsResult listTools()ListResourcesResult listResources()ReadResourceResult readResource(ReadResourceRequest request)ListResourceTemplatesResult listResourceTemplates()void subscribeResource(SubscribeRequest request)void unsubscribeResource(UnsubscribeRequest request)ListPromptsResult listPrompts()GetPromptResult getPrompt(GetPromptRequest request)void setLoggingLevel(LoggingLevel level)McpServerBuilder using(ServerMcpTransport transport)McpAsyncServerServerCapabilities getServerCapabilities()Implementation getServerInfo()ClientCapabilities getClientCapabilities()Implementation getClientInfo()void close()Mono<Void> closeGracefully() Mono<Void> addTool(ToolRegistration toolRegistration)Mono<Void> removeTool(String toolName)Mono<Void> notifyToolsListChanged() Mono<Void> addResource(ResourceRegistration resourceHandler)Mono<Void> removeResource(String resourceUri)Mono<Void> notifyResourcesListChanged() Mono<Void> addPrompt(PromptRegistration promptRegistration)Mono<Void> removePrompt(String promptName)Mono<Void> notifyPromptsListChanged() Mono<Void> loggingNotification(LoggingMessageNotification notification) Mono<CreateMessageResult> createMessage(CreateMessageRequest request)McpSyncServerMcpAsyncServer getAsyncServer() ServerCapabilities getServerCapabilities()Implementation getServerInfo()ClientCapabilities getClientCapabilities()Implementation getClientInfo()void close()void closeGracefully() void addTool(ToolRegistration toolHandler)void removeTool(String toolName)void notifyToolsListChanged() void addResource(ResourceRegistration resourceHandler)void removeResource(String resourceUri)void notifyResourcesListChanged() void addPrompt(PromptRegistration promptRegistration)void removePrompt(String promptName)void notifyPromptsListChanged() void loggingNotification(LoggingMessageNotification notification) CreateMessageResult createMessage(CreateMessageRequest request)StdioClientTransportvoid setErrorHandler(Consumer<String> errorHandler)Sinks.Many<String> getErrorSink()ClientMcpTransportStdioServerTransportServerMcpTransportHttpServletSseServerTransportHttpClientSseClientTransportWebFluxSseClientTransportWebFluxSseServerTransportRouterFunction<?> getRouterFunction()WebMvcSseServerTransportRouterFunction<?> getRouterFunction()McpSchemaclass ErrorCodesinterface Requestinterface JSONRPCMessageinterface ResourceContentsinterface Contentinterface ServerCapabilitiesJSONRPCMessage deserializeJsonRpcMessage()McpErrorcreatescreatesdelegates tocreatescreatesusesthrows \ No newline at end of file diff --git a/docs/images/logo-dark.svg b/docs/images/logo-dark.svg new file mode 100644 index 000000000..03d9f85d3 --- /dev/null +++ b/docs/images/logo-dark.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/docs/images/logo-light.svg b/docs/images/logo-light.svg new file mode 100644 index 000000000..fe5edb725 --- /dev/null +++ b/docs/images/logo-light.svg @@ -0,0 +1,69 @@ + + + + + + + + + + + + + + diff --git a/docs/images/mcp-stack.svg b/docs/images/mcp-stack.svg new file mode 100644 index 000000000..3847eaa8d --- /dev/null +++ b/docs/images/mcp-stack.svg @@ -0,0 +1,197 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 000000000..e6062b5ff --- /dev/null +++ b/docs/index.md @@ -0,0 +1,84 @@ +--- +title: Index +description: Introduction to the Model Context Protocol (MCP) Java SDK +--- + +# MCP Java SDK + +Java SDK for the [Model Context Protocol](https://modelcontextprotocol.io/docs/concepts/architecture) +enables standardized integration between AI models and tools. + +## Features + +- MCP Client and MCP Server implementations supporting: + - Protocol [version compatibility negotiation](https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle#initialization) with multiple protocol versions + - [Tools](https://modelcontextprotocol.io/specification/2025-11-25/server/tools) discovery, execution, list change notifications, and structured output with schema validation + - [Resources](https://modelcontextprotocol.io/specification/2025-11-25/server/resources) management with URI templates + - [Roots](https://modelcontextprotocol.io/specification/2025-11-25/client/roots) list management and notifications + - [Prompts](https://modelcontextprotocol.io/specification/2025-11-25/server/prompts) handling and management + - [Sampling](https://modelcontextprotocol.io/specification/2025-11-25/client/sampling) support for AI model interactions + - [Elicitation](https://modelcontextprotocol.io/specification/2025-11-25/client/elicitation) support for requesting user input from servers + - [Completions](https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/completion) for argument autocompletion suggestions + - [Progress](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/progress) - progress notifications for tracking long-running operations + - [Logging](https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/logging) - structured logging with configurable severity levels +- Multiple transport implementations: + - Default transports (included in core `mcp` module, no external web frameworks required): + - [STDIO](https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#stdio)-based transport for process-based communication + - Java HttpClient-based SSE client transport for HTTP SSE Client-side streaming + - Servlet-based SSE server transport for HTTP SSE Server streaming + - [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#streamable-http) transport for efficient bidirectional communication (client and server) + - Optional Spring-based transports (available in [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+, no longer part of this SDK): + - WebFlux SSE client and server transports for reactive HTTP streaming + - WebFlux Streamable HTTP server transport + - WebMVC SSE server transport for servlet-based HTTP streaming + - WebMVC Streamable HTTP server transport + - WebMVC Stateless server transport +- Supports Synchronous and Asynchronous programming paradigms +- Pluggable JSON serialization (Jackson 2.x and Jackson 3.x) +- Pluggable authorization hooks for server security +- DNS rebinding protection with Host/Origin header validation + +!!! tip + The core `io.modelcontextprotocol.sdk:mcp` module provides default STDIO, SSE, and Streamable HTTP client and server transport implementations without requiring external web frameworks. + + Spring-specific transports (WebFlux, WebMVC) are now part of [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ and are no longer shipped by this SDK. + Use the [MCP Client Boot Starter](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-client-boot-starter-docs.html) and [MCP Server Boot Starter](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-server-boot-starter-docs.html) from Spring AI. + Also consider the [MCP Annotations](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-annotations-overview.html) and [MCP Security](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-security.html). + +## Next Steps + +

+ +- :rocket:{ .lg .middle } **Quickstart** + + --- + + Get started with dependencies and BOM configuration. + + [:octicons-arrow-right-24: Quickstart](quickstart.md) + +- :material-monitor:{ .lg .middle } **MCP Client** + + --- + + Learn how to create and configure MCP clients. + + [:octicons-arrow-right-24: Client](client.md) + +- :material-server:{ .lg .middle } **MCP Server** + + --- + + Learn how to implement and configure MCP servers. + + [:octicons-arrow-right-24: Server](server.md) + +- :fontawesome-brands-github:{ .lg .middle } **GitHub** + + --- + + View the source code and contribute. + + [:octicons-arrow-right-24: Repository](https://github.com/modelcontextprotocol/java-sdk) + +
diff --git a/docs/overview.md b/docs/overview.md new file mode 100644 index 000000000..9084b6a6a --- /dev/null +++ b/docs/overview.md @@ -0,0 +1,93 @@ +--- +title: Overview +description: Introduction to the Model Context Protocol (MCP) Java SDK +--- + +# Overview + +## Architecture + +The SDK follows a layered architecture with clear separation of concerns: + +![MCP Stack Architecture](images/mcp-stack.svg) + +- **Client/Server Layer (McpClient/McpServer)**: Both use McpSession for sync/async operations, + with McpClient handling client-side protocol operations and McpServer managing server-side protocol operations. +- **Session Layer (McpSession)**: Manages communication patterns and state. +- **Transport Layer (McpTransport)**: Handles JSON-RPC message serialization/deserialization via: + - StdioTransport (stdin/stdout) in the core module + - HTTP SSE transports in dedicated transport modules (Java HttpClient, Servlet) + - Streamable HTTP transports for efficient bidirectional communication + - Spring WebFlux and Spring WebMVC transports (available in [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+) + +The MCP Client is a key component in the Model Context Protocol (MCP) architecture, responsible for establishing and managing connections with MCP servers. +It implements the client-side of the protocol. + +![Java MCP Client Architecture](images/java-mcp-client-architecture.jpg) + +The MCP Server is a foundational component in the Model Context Protocol (MCP) architecture that provides tools, resources, and capabilities to clients. +It implements the server-side of the protocol. + +![Java MCP Server Architecture](images/java-mcp-server-architecture.jpg) + +Key Interactions: + +- **Client/Server Initialization**: Transport setup, protocol compatibility check, capability negotiation, and implementation details exchange. +- **Message Flow**: JSON-RPC message handling with validation, type-safe response processing, and error handling. +- **Resource Management**: Resource discovery, URI template-based access, subscription system, and content retrieval. + +## Module Structure + +The SDK is organized into modules to separate concerns and allow adopters to bring in only what they need: + +| Module | Artifact ID | Group | Purpose | +|--------|------------|-------|---------| +| `mcp-bom` | `mcp-bom` | `io.modelcontextprotocol.sdk` | Bill of Materials for dependency management | +| `mcp-core` | `mcp-core` | `io.modelcontextprotocol.sdk` | Core reference implementation (STDIO, JDK HttpClient, Servlet, Streamable HTTP) | +| `mcp-json-jackson2` | `mcp-json-jackson2` | `io.modelcontextprotocol.sdk` | Jackson 2.x JSON serialization implementation | +| `mcp-json-jackson3` | `mcp-json-jackson3` | `io.modelcontextprotocol.sdk` | Jackson 3.x JSON serialization implementation | +| `mcp` | `mcp` | `io.modelcontextprotocol.sdk` | Convenience bundle (`mcp-core` + `mcp-json-jackson3`) | +| `mcp-test` | `mcp-test` | `io.modelcontextprotocol.sdk` | Shared testing utilities and integration tests | +| `mcp-spring-webflux` _(external)_ | `mcp-spring-webflux` | `org.springframework.ai` | Spring WebFlux integration β€” part of [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ | +| `mcp-spring-webmvc` _(external)_ | `mcp-spring-webmvc` | `org.springframework.ai` | Spring WebMVC integration β€” part of [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ | + +!!! tip + A minimal adopter may depend only on `mcp` (core + Jackson 3). Spring-based applications should use the `mcp-spring-webflux` or `mcp-spring-webmvc` artifacts from [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ (group `org.springframework.ai`), no longer part of this SDK. + +## Next Steps + +
+ +- :rocket:{ .lg .middle } **Quickstart** + + --- + + Get started with dependencies and BOM configuration. + + [:octicons-arrow-right-24: Quickstart](quickstart.md) + +- :material-monitor:{ .lg .middle } **MCP Client** + + --- + + Learn how to create and configure MCP clients. + + [:octicons-arrow-right-24: Client](client.md) + +- :material-server:{ .lg .middle } **MCP Server** + + --- + + Learn how to implement and configure MCP servers. + + [:octicons-arrow-right-24: Server](server.md) + +- :fontawesome-brands-github:{ .lg .middle } **GitHub** + + --- + + View the source code and contribute. + + [:octicons-arrow-right-24: Repository](https://github.com/modelcontextprotocol/java-sdk) + +
diff --git a/docs/quickstart.md b/docs/quickstart.md new file mode 100644 index 000000000..e7e76bc88 --- /dev/null +++ b/docs/quickstart.md @@ -0,0 +1,163 @@ +--- +title: Quickstart +description: Get started with the MCP Java SDK dependencies and configuration +--- + +# Quickstart + +## Dependencies + +Add the following dependency to your project: + +=== "Maven" + + The convenience `mcp` module bundles `mcp-core` with Jackson 3.x JSON serialization: + + ```xml + + io.modelcontextprotocol.sdk + mcp + + ``` + + This includes default STDIO, SSE, and Streamable HTTP transport implementations without requiring external web frameworks. + + If you need only the core module without a JSON implementation (e.g., to bring your own): + + ```xml + + io.modelcontextprotocol.sdk + mcp-core + + ``` + + For Jackson 2.x instead of Jackson 3.x: + + ```xml + + io.modelcontextprotocol.sdk + mcp-core + + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + + ``` + + If you're using Spring Framework, the Spring-specific transport implementations are now part of [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ (group `org.springframework.ai`): + + ```xml + + + org.springframework.ai + mcp-spring-webflux + + + + + org.springframework.ai + mcp-spring-webmvc + + ``` + + !!! note + When using the `spring-ai-bom` or Spring AI starter dependencies (`spring-ai-starter-mcp-server-webflux`, `spring-ai-starter-mcp-server-webmvc`, `spring-ai-starter-mcp-client-webflux`) no explicit version is needed β€” the BOM manages it automatically. + +=== "Gradle" + + The convenience `mcp` module bundles `mcp-core` with Jackson 3.x JSON serialization: + + ```groovy + dependencies { + implementation "io.modelcontextprotocol.sdk:mcp" + } + ``` + + This includes default STDIO, SSE, and Streamable HTTP transport implementations without requiring external web frameworks. + + If you need only the core module without a JSON implementation (e.g., to bring your own): + + ```groovy + dependencies { + implementation "io.modelcontextprotocol.sdk:mcp-core" + } + ``` + + For Jackson 2.x instead of Jackson 3.x: + + ```groovy + dependencies { + implementation "io.modelcontextprotocol.sdk:mcp-core" + implementation "io.modelcontextprotocol.sdk:mcp-json-jackson2" + } + ``` + + If you're using Spring Framework, the Spring-specific transport implementations are now part of [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ (group `org.springframework.ai`): + + ```groovy + // Optional: Spring WebFlux-based SSE and Streamable HTTP client and server transport (Spring AI 2.0+) + dependencies { + implementation "org.springframework.ai:mcp-spring-webflux" + } + + // Optional: Spring WebMVC-based SSE and Streamable HTTP server transport (Spring AI 2.0+) + dependencies { + implementation "org.springframework.ai:mcp-spring-webmvc" + } + ``` + +## Bill of Materials (BOM) + +The Bill of Materials (BOM) declares the recommended versions of all the dependencies used by a given release. +Using the BOM from your application's build script avoids the need for you to specify and maintain the dependency versions yourself. +Instead, the version of the BOM you're using determines the utilized dependency versions. +It also ensures that you're using supported and tested versions of the dependencies by default, unless you choose to override them. + +Add the BOM to your project: + +=== "Maven" + + ```xml + + + + io.modelcontextprotocol.sdk + mcp-bom + 1.0.0 + pom + import + + + + ``` + +=== "Gradle" + + ```groovy + dependencies { + implementation platform("io.modelcontextprotocol.sdk:mcp-bom:1.0.0") + //... + } + ``` + + Gradle users can also leverage Gradle (5.0+) native support for declaring dependency constraints using a Maven BOM. + This is implemented by adding a 'platform' dependency handler method to the dependencies section of your Gradle build script. + As shown in the snippet above this can then be followed by version-less declarations of the dependencies. + +Replace the version number with the latest version from [Maven Central](https://central.sonatype.com/artifact/io.modelcontextprotocol.sdk/mcp). + +## Available Dependencies + +The following dependencies are available and managed by the BOM: + +- **Core Dependencies** + - `io.modelcontextprotocol.sdk:mcp-core` - Core MCP library providing the base functionality, APIs, and default transport implementations (STDIO, SSE, Streamable HTTP). JSON binding is abstracted for pluggability. + - `io.modelcontextprotocol.sdk:mcp` - Convenience bundle that combines `mcp-core` with `mcp-json-jackson3` for out-of-the-box usage. +- **JSON Serialization** + - `io.modelcontextprotocol.sdk:mcp-json-jackson3` - Jackson 3.x JSON serialization implementation (included in `mcp` bundle). + - `io.modelcontextprotocol.sdk:mcp-json-jackson2` - Jackson 2.x JSON serialization implementation for projects that require Jackson 2.x compatibility. +- **Optional Spring Transport Dependencies** (part of [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+, group `org.springframework.ai`) + - `org.springframework.ai:mcp-spring-webflux` - WebFlux-based SSE and Streamable HTTP transport implementation for reactive applications. + - `org.springframework.ai:mcp-spring-webmvc` - WebMVC-based SSE and Streamable HTTP transport implementation for servlet-based applications. +- **Testing Dependencies** + - `io.modelcontextprotocol.sdk:mcp-test` - Testing utilities and support for MCP-based applications. diff --git a/docs/server.md b/docs/server.md new file mode 100644 index 000000000..0753726e2 --- /dev/null +++ b/docs/server.md @@ -0,0 +1,761 @@ +--- +title: MCP Server +description: Learn how to implement and configure a Model Context Protocol (MCP) server +--- + +# MCP Server + +## Overview + +The MCP Server is a foundational component in the Model Context Protocol (MCP) architecture that provides tools, resources, and capabilities to clients. It implements the server-side of the protocol, responsible for: + +- Exposing tools that clients can discover and execute +- Managing resources with URI-based access patterns and resource templates +- Providing prompt templates and handling prompt requests +- Supporting capability negotiation with clients +- Providing argument autocompletion suggestions (completions) +- Implementing server-side protocol operations +- Managing concurrent client connections +- Providing structured logging and notifications + +!!! tip + The core `io.modelcontextprotocol.sdk:mcp` module provides STDIO, SSE, and Streamable HTTP server transport implementations without requiring external web frameworks. + + Spring-specific transport implementations (`mcp-spring-webflux`, `mcp-spring-webmvc`) are now part of [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ (group `org.springframework.ai`) and are no longer shipped by this SDK. + See the [MCP Server Boot Starter](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-server-boot-starter-docs.html) documentation for Spring-based server setup. + +The server supports both synchronous and asynchronous APIs, allowing for flexible integration in different application contexts. + +=== "Sync API" + + ```java + // Create a server with custom configuration + McpSyncServer syncServer = McpServer.sync(transportProvider) + .serverInfo("my-server", "1.0.0") + .capabilities(ServerCapabilities.builder() + .resources(false, true) // Enable resource support with list changes + .tools(true) // Enable tool support with list changes + .prompts(true) // Enable prompt support with list changes + .completions() // Enable completions support + .logging() // Enable logging support + .build()) + .build(); + + // Register tools, resources, and prompts + syncServer.addTool(syncToolSpecification); + syncServer.addResource(syncResourceSpecification); + syncServer.addPrompt(syncPromptSpecification); + + // Close the server when done + syncServer.close(); + ``` + +=== "Async API" + + ```java + // Create an async server with custom configuration + McpAsyncServer asyncServer = McpServer.async(transportProvider) + .serverInfo("my-server", "1.0.0") + .capabilities(ServerCapabilities.builder() + .resources(false, true) // Enable resource support with list changes + .tools(true) // Enable tool support with list changes + .prompts(true) // Enable prompt support with list changes + .completions() // Enable completions support + .logging() // Enable logging support + .build()) + .build(); + + // Register tools, resources, and prompts + asyncServer.addTool(asyncToolSpecification) + .doOnSuccess(v -> logger.info("Tool registered")) + .subscribe(); + + asyncServer.addResource(asyncResourceSpecification) + .doOnSuccess(v -> logger.info("Resource registered")) + .subscribe(); + + asyncServer.addPrompt(asyncPromptSpecification) + .doOnSuccess(v -> logger.info("Prompt registered")) + .subscribe(); + + // Close the server when done + asyncServer.close() + .doOnSuccess(v -> logger.info("Server closed")) + .subscribe(); + ``` + +### Server Types + +The SDK supports multiple server creation patterns depending on your transport requirements: + +```java +// Single-session server with SSE transport provider +McpSyncServer server = McpServer.sync(sseTransportProvider).build(); + +// Streamable HTTP server +McpSyncServer server = McpServer.sync(streamableTransportProvider).build(); + +// Stateless server (no session management) +McpSyncServer server = McpServer.sync(statelessTransport).build(); +``` + +## Server Transport Providers + +The transport layer in the MCP SDK is responsible for handling the communication between clients and servers. +It provides different implementations to support various communication protocols and patterns. +The SDK includes several built-in transport provider implementations: + +### STDIO + +Create process-based transport using stdin/stdout: + +```java +StdioServerTransportProvider transportProvider = + new StdioServerTransportProvider(new ObjectMapper()); +``` + +Provides bidirectional JSON-RPC message handling over standard input/output streams with non-blocking message processing, serialization/deserialization, and graceful shutdown support. + +Key features: + +- Bidirectional communication through stdin/stdout +- Process-based integration support +- Simple setup and configuration +- Lightweight implementation + +### Streamable HTTP + +=== "Streamable HTTP Servlet" + + Creates a Servlet-based Streamable HTTP server transport. Included in the core `mcp` module: + + ```java + HttpServletStreamableServerTransportProvider transportProvider = + HttpServletStreamableServerTransportProvider.builder() + .jsonMapper(jsonMapper) + .mcpEndpoint("/mcp") + .build(); + ``` + + To use with a Spring Web application, register it as a Servlet bean: + + ```java + @Configuration + @EnableWebMvc + public class McpServerConfig implements WebMvcConfigurer { + + @Bean + public HttpServletStreamableServerTransportProvider transportProvider(McpJsonMapper jsonMapper) { + return HttpServletStreamableServerTransportProvider.builder() + .jsonMapper(jsonMapper) + .mcpEndpoint("/mcp") + .build(); + } + + @Bean + public ServletRegistrationBean mcpServlet( + HttpServletStreamableServerTransportProvider transportProvider) { + return new ServletRegistrationBean<>(transportProvider); + } + } + ``` + + Key features: + + - Efficient bidirectional HTTP communication + - Session management for multiple client connections + - Configurable keep-alive intervals + - Security validation support + - Graceful shutdown support + +=== "Streamable HTTP WebFlux (external)" + + Creates WebFlux-based Streamable HTTP server transport. Requires the `mcp-spring-webflux` dependency from [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ (group `org.springframework.ai`): + + ```java + @Configuration + class McpConfig { + @Bean + WebFluxStreamableServerTransportProvider transportProvider(McpJsonMapper jsonMapper) { + return WebFluxStreamableServerTransportProvider.builder() + .jsonMapper(jsonMapper) + .messageEndpoint("/mcp") + .build(); + } + + @Bean + RouterFunction mcpRouterFunction( + WebFluxStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + } + ``` + + Key features: + + - Reactive HTTP streaming with WebFlux + - Concurrent client connections + - Configurable keep-alive intervals + - Security validation support + +=== "Streamable HTTP WebMvc (external)" + + Creates WebMvc-based Streamable HTTP server transport. Requires the `mcp-spring-webmvc` dependency from [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ (group `org.springframework.ai`): + + ```java + @Configuration + @EnableWebMvc + class McpConfig { + @Bean + WebMvcStreamableServerTransportProvider transportProvider(McpJsonMapper jsonMapper) { + return WebMvcStreamableServerTransportProvider.builder() + .jsonMapper(jsonMapper) + .mcpEndpoint("/mcp") + .build(); + } + + @Bean + RouterFunction mcpRouterFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + } + ``` + +### SSE HTTP (Legacy) + +=== "SSE Servlet" + + Creates a Servlet-based SSE server transport. Included in the core `mcp` module. + The `HttpServletSseServerTransportProvider` can be used with any Servlet container. + To use it with a Spring Web application, you can register it as a Servlet bean: + + ```java + @Configuration + @EnableWebMvc + public class McpServerConfig implements WebMvcConfigurer { + + @Bean + public HttpServletSseServerTransportProvider servletSseServerTransportProvider() { + return new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/message"); + } + + @Bean + public ServletRegistrationBean customServletBean( + HttpServletSseServerTransportProvider transportProvider) { + return new ServletRegistrationBean<>(transportProvider); + } + } + ``` + + Implements the MCP HTTP with SSE transport specification using the traditional Servlet API, providing: + + - Asynchronous message handling using Servlet 6.0 async support + - Session management for multiple client connections + - Two types of endpoints: + - SSE endpoint (`/sse`) for server-to-client events + - Message endpoint (configurable) for client-to-server requests + - Error handling and response formatting + - Graceful shutdown support + +=== "SSE WebFlux (external)" + + Creates WebFlux-based SSE server transport. Requires the `mcp-spring-webflux` dependency from [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ (group `org.springframework.ai`): + + ```java + @Configuration + class McpConfig { + @Bean + WebFluxSseServerTransportProvider webFluxSseServerTransportProvider(ObjectMapper mapper) { + return new WebFluxSseServerTransportProvider(mapper, "/mcp/message"); + } + + @Bean + RouterFunction mcpRouterFunction(WebFluxSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + } + ``` + + Implements the MCP HTTP with SSE transport specification, providing: + + - Reactive HTTP streaming with WebFlux + - Concurrent client connections through SSE endpoints + - Message routing and session management + - Graceful shutdown capabilities + +=== "SSE WebMvc (external)" + + Creates WebMvc-based SSE server transport. Requires the `mcp-spring-webmvc` dependency from [Spring AI](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) 2.0+ (group `org.springframework.ai`): + + ```java + @Configuration + @EnableWebMvc + class McpConfig { + @Bean + WebMvcSseServerTransportProvider webMvcSseServerTransportProvider(ObjectMapper mapper) { + return new WebMvcSseServerTransportProvider(mapper, "/mcp/message"); + } + + @Bean + RouterFunction mcpRouterFunction( + WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + } + ``` + + Implements the MCP HTTP with SSE transport specification, providing: + + - Server-side event streaming + - Integration with Spring WebMVC + - Support for traditional web applications + - Synchronous operation handling + + +## Server Capabilities + +The server can be configured with various capabilities: + +```java +var capabilities = ServerCapabilities.builder() + .resources(false, true) // Resource support (subscribe, listChanged) + .tools(true) // Tool support with list changes notifications + .prompts(true) // Prompt support with list changes notifications + .completions() // Enable completions support + .logging() // Enable logging support + .build(); +``` + +### Tool Specification + +The Model Context Protocol allows servers to [expose tools](https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/) that can be invoked by language models. +The Java SDK allows implementing Tool Specifications with their handler functions. +Tools enable AI models to perform calculations, access external APIs, query databases, and manipulate files. + +The recommended approach is to use the builder pattern and `CallToolRequest` as the handler parameter: + +=== "Sync" + + ```java + // Sync tool specification using builder + var syncToolSpecification = SyncToolSpecification.builder() + .tool(Tool.builder() + .name("calculator") + .description("Basic calculator") + .inputSchema(schema) + .build()) + .callHandler((exchange, request) -> { + // Access arguments via request.arguments() + String operation = (String) request.arguments().get("operation"); + int a = (int) request.arguments().get("a"); + int b = (int) request.arguments().get("b"); + // Tool implementation + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Result: " + result))) + .build(); + }) + .build(); + ``` + +=== "Async" + + ```java + // Async tool specification using builder + var asyncToolSpecification = AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("calculator") + .description("Basic calculator") + .inputSchema(schema) + .build()) + .callHandler((exchange, request) -> { + // Access arguments via request.arguments() + String operation = (String) request.arguments().get("operation"); + int a = (int) request.arguments().get("a"); + int b = (int) request.arguments().get("b"); + // Tool implementation + return Mono.just(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Result: " + result))) + .build()); + }) + .build(); + ``` + +The Tool specification includes a Tool definition with `name`, `description`, and `inputSchema` followed by a call handler that implements the tool's logic. +The handler receives `McpSyncServerExchange`/`McpAsyncServerExchange` for client interaction and a `CallToolRequest` containing the tool arguments. + +You can also register tools directly on the server builder using the `toolCall` convenience method: + +```java +var server = McpServer.sync(transportProvider) + .toolCall( + Tool.builder().name("echo").description("Echoes input").inputSchema(schema).build(), + (exchange, request) -> CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(request.arguments().get("text").toString()))) + .build() + ) + .build(); +``` + +### Resource Specification + +Specification of a resource with its handler function. +Resources provide context to AI models by exposing data such as: File contents, Database records, API responses, System information, Application state. + +=== "Sync" + + ```java + // Sync resource specification + var syncResourceSpecification = new McpServerFeatures.SyncResourceSpecification( + Resource.builder() + .uri("custom://resource") + .name("name") + .description("description") + .mimeType("text/plain") + .build(), + (exchange, request) -> { + // Resource read implementation + return new ReadResourceResult(contents); + } + ); + ``` + +=== "Async" + + ```java + // Async resource specification + var asyncResourceSpecification = new McpServerFeatures.AsyncResourceSpecification( + Resource.builder() + .uri("custom://resource") + .name("name") + .description("description") + .mimeType("text/plain") + .build(), + (exchange, request) -> { + // Resource read implementation + return Mono.just(new ReadResourceResult(contents)); + } + ); + ``` + +### Resource Template Specification + +Resource templates allow servers to expose parameterized resources using URI templates: + +```java +// Resource template specification +var resourceTemplateSpec = new McpServerFeatures.SyncResourceTemplateSpecification( + ResourceTemplate.builder() + .uriTemplate("file://{path}") + .name("File Resource") + .description("Access files by path") + .mimeType("application/octet-stream") + .build(), + (exchange, request) -> { + // Read the file at the requested URI + return new ReadResourceResult(contents); + } +); +``` + +### Prompt Specification + +As part of the [Prompting capabilities](https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/), MCP provides a standardized way for servers to expose prompt templates to clients. +The Prompt Specification is a structured template for AI model interactions that enables consistent message formatting, parameter substitution, context injection, response formatting, and instruction templating. + +=== "Sync" + + ```java + // Sync prompt specification + var syncPromptSpecification = new McpServerFeatures.SyncPromptSpecification( + new Prompt("greeting", "description", List.of( + new PromptArgument("name", "description", true) + )), + (exchange, request) -> { + // Prompt implementation + return new GetPromptResult(description, messages); + } + ); + ``` + +=== "Async" + + ```java + // Async prompt specification + var asyncPromptSpecification = new McpServerFeatures.AsyncPromptSpecification( + new Prompt("greeting", "description", List.of( + new PromptArgument("name", "description", true) + )), + (exchange, request) -> { + // Prompt implementation + return Mono.just(new GetPromptResult(description, messages)); + } + ); + ``` + +The prompt definition includes name (identifier for the prompt), description (purpose of the prompt), and list of arguments (parameters for templating). +The handler function processes requests and returns formatted templates. +The first argument is `McpSyncServerExchange`/`McpAsyncServerExchange` for client interaction, and the second argument is a `GetPromptRequest` instance. + +### Completion Specification + +Completions allow servers to provide argument autocompletion suggestions for prompts and resources: + +=== "Sync" + + ```java + // Sync completion specification + var syncCompletionSpec = new McpServerFeatures.SyncCompletionSpecification( + new McpSchema.PromptReference("greeting"), // Reference to a prompt + (exchange, request) -> { + String argName = request.argument().name(); + String partial = request.argument().value(); + // Return matching suggestions + List suggestions = findMatches(partial); + return new McpSchema.CompleteResult( + new McpSchema.CompleteResult.CompleteCompletion(suggestions, suggestions.size(), false) + ); + } + ); + ``` + +=== "Async" + + ```java + // Async completion specification + var asyncCompletionSpec = new McpServerFeatures.AsyncCompletionSpecification( + new McpSchema.PromptReference("greeting"), + (exchange, request) -> { + String argName = request.argument().name(); + String partial = request.argument().value(); + List suggestions = findMatches(partial); + return Mono.just(new McpSchema.CompleteResult( + new McpSchema.CompleteResult.CompleteCompletion(suggestions, suggestions.size(), false) + )); + } + ); + ``` + +Completions can be registered for both `PromptReference` and `ResourceReference` types. + +### Using Sampling from a Server + +To use [Sampling capabilities](https://spec.modelcontextprotocol.io/specification/2024-11-05/client/sampling/), connect to a client that supports sampling. +No special server configuration is needed, but verify client sampling support before making requests. +Learn about [client sampling support](client.md#sampling-support). + +Once connected to a compatible client, the server can request language model generations: + +=== "Sync API" + + ```java + // Create a server + McpSyncServer server = McpServer.sync(transportProvider) + .serverInfo("my-server", "1.0.0") + .build(); + + // Define a tool that uses sampling + var calculatorTool = SyncToolSpecification.builder() + .tool(Tool.builder() + .name("ai-calculator") + .description("Performs calculations using AI") + .inputSchema(schema) + .build()) + .callHandler((exchange, request) -> { + // Check if client supports sampling + if (exchange.getClientCapabilities().sampling() == null) { + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Client does not support AI capabilities"))) + .build(); + } + + // Create a sampling request + CreateMessageRequest samplingRequest = CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Calculate: " + request.arguments().get("expression"))))) + .modelPreferences(McpSchema.ModelPreferences.builder() + .hints(List.of( + McpSchema.ModelHint.of("claude-3-sonnet"), + McpSchema.ModelHint.of("claude") + )) + .intelligencePriority(0.8) + .speedPriority(0.5) + .build()) + .systemPrompt("You are a helpful calculator assistant. Provide only the numerical answer.") + .maxTokens(100) + .build(); + + // Request sampling from the client + CreateMessageResult result = exchange.createMessage(samplingRequest); + + // Process the result + String answer = ((McpSchema.TextContent) result.content()).text(); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(answer))) + .build(); + }) + .build(); + + // Add the tool to the server + server.addTool(calculatorTool); + ``` + +=== "Async API" + + ```java + // Create a server + McpAsyncServer server = McpServer.async(transportProvider) + .serverInfo("my-server", "1.0.0") + .build(); + + // Define a tool that uses sampling + var calculatorTool = AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("ai-calculator") + .description("Performs calculations using AI") + .inputSchema(schema) + .build()) + .callHandler((exchange, request) -> { + // Check if client supports sampling + if (exchange.getClientCapabilities().sampling() == null) { + return Mono.just(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Client does not support AI capabilities"))) + .build()); + } + + // Create a sampling request + CreateMessageRequest samplingRequest = CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Calculate: " + request.arguments().get("expression"))))) + .modelPreferences(McpSchema.ModelPreferences.builder() + .hints(List.of( + McpSchema.ModelHint.of("claude-3-sonnet"), + McpSchema.ModelHint.of("claude") + )) + .intelligencePriority(0.8) + .speedPriority(0.5) + .build()) + .systemPrompt("You are a helpful calculator assistant. Provide only the numerical answer.") + .maxTokens(100) + .build(); + + // Request sampling from the client + return exchange.createMessage(samplingRequest) + .map(result -> { + String answer = ((McpSchema.TextContent) result.content()).text(); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(answer))) + .build(); + }); + }) + .build(); + + // Add the tool to the server + server.addTool(calculatorTool) + .subscribe(); + ``` + +The `CreateMessageRequest` object allows you to specify: `Content` - the input text or image for the model, +`Model Preferences` - hints and priorities for model selection, `System Prompt` - instructions for the model's behavior and +`Max Tokens` - maximum length of the generated response. + +### Using Elicitation from a Server + +Servers can request user input from connected clients that support elicitation: + +```java +var tool = SyncToolSpecification.builder() + .tool(Tool.builder() + .name("confirm-action") + .description("Confirms an action with the user") + .inputSchema(schema) + .build()) + .callHandler((exchange, request) -> { + // Check if client supports elicitation + if (exchange.getClientCapabilities().elicitation() == null) { + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Client does not support elicitation"))) + .build(); + } + + // Request user confirmation + ElicitRequest elicitRequest = ElicitRequest.builder() + .message("Do you want to proceed with this action?") + .requestedSchema(Map.of( + "type", "object", + "properties", Map.of("confirmed", Map.of("type", "boolean")) + )) + .build(); + + ElicitResult result = exchange.elicit(elicitRequest); + + if (result.action() == ElicitResult.Action.ACCEPT) { + // User accepted + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Action confirmed"))) + .build(); + } else { + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Action declined"))) + .build(); + } + }) + .build(); +``` + +### Logging Support + +The server provides structured logging capabilities that allow sending log messages to clients with different severity levels. +Log notifications can only be sent from within an existing client session, such as tools, resources, and prompts calls. + +The server can send log messages using the `McpAsyncServerExchange`/`McpSyncServerExchange` object in the tool/resource/prompt handler function: + +```java +var tool = new McpServerFeatures.AsyncToolSpecification( + Tool.builder().name("logging-test").description("Test logging notifications").inputSchema(emptyJsonSchema).build(), + null, + (exchange, request) -> { + + exchange.loggingNotification( // Use the exchange to send log messages + McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .block(); + + return Mono.just(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Logging test completed"))) + .build()); + }); + +var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities( + ServerCapabilities.builder() + .logging() // Enable logging support + .tools(true) + .build()) + .tools(tool) + .build(); +``` + +On the client side, you can register a logging consumer to receive log messages from the server: + +```java +var mcpClient = McpClient.sync(transport) + .loggingConsumer(notification -> { + System.out.println("Received log message: " + notification.data()); + }) + .build(); + +mcpClient.initialize(); +mcpClient.setLoggingLevel(McpSchema.LoggingLevel.INFO); +``` + +Clients can control the minimum logging level they receive through the `mcpClient.setLoggingLevel(level)` request. Messages below the set level will be filtered out. +Supported logging levels (in order of increasing severity): DEBUG (0), INFO (1), NOTICE (2), WARNING (3), ERROR (4), CRITICAL (5), ALERT (6), EMERGENCY (7) + +## Error Handling + +The SDK provides comprehensive error handling through the McpError class, covering protocol compatibility, transport communication, JSON-RPC messaging, tool execution, resource management, prompt handling, timeouts, and connection issues. This unified error handling approach ensures consistent and reliable error management across both synchronous and asynchronous operations. diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index 83d8bc510..fb6f3a32a 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -7,7 +7,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.12.0-SNAPSHOT + 1.1.0-SNAPSHOT mcp-bom @@ -29,28 +29,28 @@ io.modelcontextprotocol.sdk - mcp + mcp-core ${project.version} - + io.modelcontextprotocol.sdk - mcp-test + mcp ${project.version} - + io.modelcontextprotocol.sdk - mcp-spring-webflux + mcp-json-jackson2 ${project.version} - + io.modelcontextprotocol.sdk - mcp-spring-webmvc + mcp-test ${project.version} diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-core/pom.xml similarity index 50% rename from mcp-spring/mcp-spring-webflux/pom.xml rename to mcp-core/pom.xml index 300d518e7..4de0fba2b 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-core/pom.xml @@ -6,13 +6,12 @@ io.modelcontextprotocol.sdk mcp-parent - 0.12.0-SNAPSHOT - ../../pom.xml + 1.1.0-SNAPSHOT - mcp-spring-webflux + mcp-core jar - WebFlux transports - WebFlux implementation for the SSE and Streamable Http Client and Server transports + Java MCP SDK Core + Core classes of the Java SDK implementation of the Model Context Protocol, enabling seamless integration with language models and AI tools https://github.com/modelcontextprotocol/java-sdk @@ -21,47 +20,76 @@ git@github.com/modelcontextprotocol/java-sdk.git + + + + biz.aQute.bnd + bnd-maven-plugin + ${bnd-maven-plugin.version} + + + bnd-process + + bnd-process + + + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + ${project.build.outputDirectory}/META-INF/MANIFEST.MF + + + + + + - - io.modelcontextprotocol.sdk - mcp - 0.12.0-SNAPSHOT - - io.modelcontextprotocol.sdk - mcp-test - 0.12.0-SNAPSHOT - test + org.slf4j + slf4j-api + ${slf4j-api.version} - org.springframework - spring-webflux - ${springframework.version} + com.fasterxml.jackson.core + jackson-annotations + ${jackson-annotations.version} - io.projectreactor.netty - reactor-netty-http - test + io.projectreactor + reactor-core - - - org.springframework - spring-context - ${springframework.version} - test - + - org.springframework - spring-test - ${springframework.version} - test + jakarta.servlet + jakarta.servlet-api + ${jakarta.servlet.version} + provided @@ -76,12 +104,20 @@ ${junit.version} test + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + org.mockito mockito-core ${mockito.version} test + + net.bytebuddy byte-buddy @@ -99,12 +135,6 @@ ${testcontainers.version} test - - org.testcontainers - toxiproxy - ${toxiproxy.version} - test - org.awaitility @@ -121,19 +151,26 @@ - org.junit.jupiter - junit-jupiter-params - ${junit-jupiter.version} + net.javacrumbs.json-unit + json-unit-assertj + ${json-unit-assertj.version} test - net.javacrumbs.json-unit - json-unit-assertj - ${json-unit-assertj.version} + org.testcontainers + toxiproxy + ${toxiproxy.version} test + + + com.google.code.gson + gson + 2.10.1 + test + diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java similarity index 89% rename from mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java index 2cc1c5dba..07d86f40e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java @@ -11,14 +11,13 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; import reactor.util.context.ContextView; @@ -99,21 +98,30 @@ class LifecycleInitializer { */ private final Duration initializationTimeout; + /** + * Post-initialization hook to perform additional operations after every successful + * initialization. + */ + private final Function> postInitializationHook; + public LifecycleInitializer(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, List protocolVersions, Duration initializationTimeout, - Function sessionSupplier) { + Function sessionSupplier, + Function> postInitializationHook) { Assert.notNull(sessionSupplier, "Session supplier must not be null"); Assert.notNull(clientCapabilities, "Client capabilities must not be null"); Assert.notNull(clientInfo, "Client info must not be null"); Assert.notEmpty(protocolVersions, "Protocol versions must not be empty"); Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + Assert.notNull(postInitializationHook, "Post-initialization hook must not be null"); this.sessionSupplier = sessionSupplier; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; this.protocolVersions = Collections.unmodifiableList(new ArrayList<>(protocolVersions)); this.initializationTimeout = initializationTimeout; + this.postInitializationHook = postInitializationHook; } /** @@ -148,10 +156,6 @@ interface Initialization { } - /** - * Default implementation of the {@link Initialization} interface that manages the MCP - * client initialization process. - */ private static class DefaultInitialization implements Initialization { /** @@ -199,29 +203,20 @@ private void setMcpClientSession(McpClientSession mcpClientSession) { this.mcpClientSession.set(mcpClientSession); } - /** - * Returns a Mono that completes when the MCP client initialization is complete. - * This allows subscribers to wait for the initialization to finish before - * proceeding with further operations. - * @return A Mono that emits the result of the MCP initialization process - */ private Mono await() { return this.initSink.asMono(); } - /** - * Completes the initialization process with the given result. It caches the - * result and emits it to all subscribers waiting for the initialization to - * complete. - * @param initializeResult The result of the MCP initialization process - */ private void complete(McpSchema.InitializeResult initializeResult) { - // first ensure the result is cached - this.result.set(initializeResult); // inform all the subscribers waiting for the initialization this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST); } + private void cacheResult(McpSchema.InitializeResult initializeResult) { + // first ensure the result is cached + this.result.set(initializeResult); + } + private void error(Throwable t) { this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); } @@ -263,7 +258,7 @@ public void handleException(Throwable t) { } // Providing an empty operation since we are only interested in triggering // the implicit initialization step. - withIntitialization("re-initializing", result -> Mono.empty()).subscribe(); + this.withInitialization("re-initializing", result -> Mono.empty()).subscribe(); } } @@ -275,7 +270,7 @@ public void handleException(Throwable t) { * @param operation The operation to execute when the client is initialized * @return A Mono that completes with the result of the operation */ - public Mono withIntitialization(String actionName, Function> operation) { + public Mono withInitialization(String actionName, Function> operation) { return Mono.deferContextual(ctx -> { DefaultInitialization newInit = new DefaultInitialization(); DefaultInitialization previous = this.initializationRef.compareAndExchange(null, newInit); @@ -283,19 +278,24 @@ public Mono withIntitialization(String actionName, Function initializationJob = needsToInitialize ? doInitialize(newInit, ctx) - : previous.await(); + Mono initializationJob = needsToInitialize + ? this.doInitialize(newInit, this.postInitializationHook, ctx) : previous.await(); return initializationJob.map(initializeResult -> this.initializationRef.get()) .timeout(this.initializationTimeout) .onErrorResume(ex -> { + this.initializationRef.compareAndSet(newInit, null); return Mono.error(new RuntimeException("Client failed to initialize " + actionName, ex)); }) - .flatMap(operation); + .flatMap(res -> operation.apply(res) + .contextWrite(c -> c.put(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + res.initializeResult().protocolVersion()))); }); } - private Mono doInitialize(DefaultInitialization initialization, ContextView ctx) { + private Mono doInitialize(DefaultInitialization initialization, + Function> postInitOperation, ContextView ctx) { + initialization.setMcpClientSession(this.sessionSupplier.apply(ctx)); McpClientSession mcpClientSession = initialization.mcpSession(); @@ -321,7 +321,12 @@ private Mono doInitialize(DefaultInitialization init } return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) + .contextWrite( + c -> c.put(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, initializeResult.protocolVersion())) .thenReturn(initializeResult); + }).flatMap(initializeResult -> { + initialization.cacheResult(initializeResult); + return postInitOperation.apply(initialization).thenReturn(initializeResult); }).doOnNext(initialization::complete).onErrorResume(ex -> { initialization.error(ex); return Mono.error(ex); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java similarity index 81% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index eb6d42f68..93fcc332a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -15,14 +15,13 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.fasterxml.jackson.core.type.TypeReference; - +import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; +import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; @@ -31,15 +30,16 @@ import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.util.ToolNameValidator; import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; -import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -76,6 +76,7 @@ * @author Dariusz JΔ™drzejczyk * @author Christian Tzolov * @author Jihoon Kim + * @author Anurag Pant * @see McpClient * @see McpSchema * @see McpClientSession @@ -85,27 +86,29 @@ public class McpAsyncClient { private static final Logger logger = LoggerFactory.getLogger(McpAsyncClient.class); - private static final TypeReference VOID_TYPE_REFERENCE = new TypeReference<>() { + private static final TypeRef VOID_TYPE_REFERENCE = new TypeRef<>() { }; - public static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + public static final TypeRef OBJECT_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference PAGINATED_REQUEST_TYPE_REF = new TypeReference<>() { + public static final TypeRef PAGINATED_REQUEST_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference INITIALIZE_RESULT_TYPE_REF = new TypeReference<>() { + public static final TypeRef INITIALIZE_RESULT_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeReference<>() { + public static final TypeRef CREATE_MESSAGE_REQUEST_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() { + public static final TypeRef LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference PROGRESS_NOTIFICATION_TYPE_REF = new TypeReference<>() { + public static final TypeRef PROGRESS_NOTIFICATION_TYPE_REF = new TypeRef<>() { }; + public static final String NEGOTIATED_PROTOCOL_VERSION = "io.modelcontextprotocol.client.negotiated-protocol-version"; + /** * Client capabilities. */ @@ -153,16 +156,33 @@ public class McpAsyncClient { */ private final LifecycleInitializer initializer; + /** + * JSON schema validator to use for validating tool responses against output schemas. + */ + private final JsonSchemaValidator jsonSchemaValidator; + + /** + * Cached tool output schemas. + */ + private final ConcurrentHashMap> toolsOutputSchemaCache; + + /** + * Whether to enable automatic schema caching during callTool operations. + */ + private final boolean enableCallToolSchemaCaching; + /** * Create a new McpAsyncClient with the given transport and session request-response * timeout. * @param transport the transport to use. * @param requestTimeout the session request-response timeout. * @param initializationTimeout the max timeout to await for the client-server - * @param features the MCP Client supported features. + * @param jsonSchemaValidator the JSON schema validator to use for validating tool + * @param features the MCP Client supported features. responses against output + * schemas. */ McpAsyncClient(McpClientTransport transport, Duration requestTimeout, Duration initializationTimeout, - McpClientFeatures.Async features) { + JsonSchemaValidator jsonSchemaValidator, McpClientFeatures.Async features) { Assert.notNull(transport, "Transport must not be null"); Assert.notNull(requestTimeout, "Request timeout must not be null"); @@ -172,6 +192,9 @@ public class McpAsyncClient { this.clientCapabilities = features.clientCapabilities(); this.transport = transport; this.roots = new ConcurrentHashMap<>(features.roots()); + this.jsonSchemaValidator = jsonSchemaValidator; + this.toolsOutputSchemaCache = new ConcurrentHashMap<>(); + this.enableCallToolSchemaCaching = features.enableCallToolSchemaCaching(); // Request Handlers Map> requestHandlers = new HashMap<>(); @@ -274,9 +297,30 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS, asyncProgressNotificationHandler(progressConsumersFinal)); + Function> postInitializationHook = init -> { + + if (init.initializeResult().capabilities().tools() == null || !enableCallToolSchemaCaching) { + return Mono.empty(); + } + + return this.listToolsInternal(init, McpSchema.FIRST_PAGE).doOnNext(listToolsResult -> { + listToolsResult.tools() + .forEach(tool -> logger.debug("Tool {} schema: {}", tool.name(), tool.outputSchema())); + if (enableCallToolSchemaCaching && listToolsResult.tools() != null) { + // Cache tools output schema + listToolsResult.tools() + .stream() + .filter(tool -> tool.outputSchema() != null) + .forEach(tool -> this.toolsOutputSchemaCache.put(tool.name(), tool.outputSchema())); + } + }).then(); + }; + this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo, transport.protocolVersions(), initializationTimeout, ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, - notificationHandlers, con -> con.contextWrite(ctx))); + notificationHandlers, con -> con.contextWrite(ctx)), + postInitializationHook); + this.transport.setExceptionHandler(this.initializer::handleException); } @@ -361,6 +405,7 @@ public Mono closeGracefully() { // -------------------------- // Initialization // -------------------------- + /** * The initialization phase should be the first interaction between client and server. * The client will ensure it happens in case it has not been explicitly called and in @@ -388,7 +433,7 @@ public Mono closeGracefully() { *

*/ public Mono initialize() { - return this.initializer.withIntitialization("by explicit API call", init -> Mono.just(init.initializeResult())); + return this.initializer.withInitialization("by explicit API call", init -> Mono.just(init.initializeResult())); } // -------------------------- @@ -400,13 +445,14 @@ public Mono initialize() { * @return A Mono that completes with the server's ping response */ public Mono ping() { - return this.initializer.withIntitialization("pinging the server", + return this.initializer.withInitialization("pinging the server", init -> init.mcpSession().sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF)); } // -------------------------- // Roots // -------------------------- + /** * Adds a new root to the client's root list. * @param root The root to add. @@ -481,7 +527,7 @@ public Mono removeRoot(String rootUri) { * @return A Mono that completes when the notification is sent. */ public Mono rootsListChangedNotification() { - return this.initializer.withIntitialization("sending roots list changed notification", + return this.initializer.withInitialization("sending roots list changed notification", init -> init.mcpSession().sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); } @@ -512,7 +558,7 @@ private RequestHandler samplingCreateMessageHandler() { // -------------------------- private RequestHandler elicitationCreateHandler() { return params -> { - ElicitRequest request = transport.unmarshalFrom(params, new TypeReference<>() { + ElicitRequest request = transport.unmarshalFrom(params, new TypeRef<>() { }); return this.elicitationHandler.apply(request); @@ -522,10 +568,10 @@ private RequestHandler elicitationCreateHandler() { // -------------------------- // Tools // -------------------------- - private static final TypeReference CALL_TOOL_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef CALL_TOOL_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference LIST_TOOLS_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_TOOLS_RESULT_TYPE_REF = new TypeRef<>() { }; /** @@ -540,27 +586,57 @@ private RequestHandler elicitationCreateHandler() { * @see #listTools() */ public Mono callTool(McpSchema.CallToolRequest callToolRequest) { - return this.initializer.withIntitialization("calling tools", init -> { + return this.initializer.withInitialization("calling tool", init -> { if (init.initializeResult().capabilities().tools() == null) { return Mono.error(new IllegalStateException("Server does not provide tools capability")); } + return init.mcpSession() - .sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); + .sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF) + .flatMap(result -> Mono.just(validateToolResult(callToolRequest.name(), result))); }); } + private McpSchema.CallToolResult validateToolResult(String toolName, McpSchema.CallToolResult result) { + + if (!this.enableCallToolSchemaCaching || result == null || result.isError() == Boolean.TRUE) { + // if tool schema caching is disabled or tool call resulted in an error - skip + // validation and return the result as it is + return result; + } + + Map optOutputSchema = this.toolsOutputSchemaCache.get(toolName); + + if (optOutputSchema == null) { + logger.warn( + "Calling a tool with no outputSchema is not expected to return result with structured content, but got: {}", + result.structuredContent()); + return result; + } + + // Validate the tool output against the cached output schema + var validation = this.jsonSchemaValidator.validate(optOutputSchema, result.structuredContent()); + + if (!validation.valid()) { + logger.warn("Tool call result validation failed: {}", validation.errorMessage()); + throw new IllegalArgumentException("Tool call result validation failed: " + validation.errorMessage()); + } + + return result; + } + /** * Retrieves the list of all tools provided by the server. * @return A Mono that emits the list of all tools result */ public Mono listTools() { - return this.listTools(McpSchema.FIRST_PAGE) - .expand(result -> (result.nextCursor() != null) ? this.listTools(result.nextCursor()) : Mono.empty()) - .reduce(new McpSchema.ListToolsResult(new ArrayList<>(), null), (allToolsResult, result) -> { - allToolsResult.tools().addAll(result.tools()); - return allToolsResult; - }) - .map(result -> new McpSchema.ListToolsResult(Collections.unmodifiableList(result.tools()), null)); + return this.listTools(McpSchema.FIRST_PAGE).expand(result -> { + String next = result.nextCursor(); + return (next != null && !next.isEmpty()) ? this.listTools(next) : Mono.empty(); + }).reduce(new McpSchema.ListToolsResult(new ArrayList<>(), null), (allToolsResult, result) -> { + allToolsResult.tools().addAll(result.tools()); + return allToolsResult; + }).map(result -> new McpSchema.ListToolsResult(Collections.unmodifiableList(result.tools()), null)); } /** @@ -569,14 +645,30 @@ public Mono listTools() { * @return A Mono that emits the list of tools result */ public Mono listTools(String cursor) { - return this.initializer.withIntitialization("listing tools", init -> { - if (init.initializeResult().capabilities().tools() == null) { - return Mono.error(new IllegalStateException("Server does not provide tools capability")); - } - return init.mcpSession() - .sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_TOOLS_RESULT_TYPE_REF); - }); + return this.initializer.withInitialization("listing tools", init -> this.listToolsInternal(init, cursor)); + } + + private Mono listToolsInternal(Initialization init, String cursor) { + + if (init.initializeResult().capabilities().tools() == null) { + return Mono.error(new IllegalStateException("Server does not provide tools capability")); + } + return init.mcpSession() + .sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_TOOLS_RESULT_TYPE_REF) + .doOnNext(result -> { + // Validate tool names (warn only) + if (result.tools() != null) { + result.tools().forEach(tool -> ToolNameValidator.validate(tool.name(), false)); + } + if (this.enableCallToolSchemaCaching && result.tools() != null) { + // Cache tools output schema + result.tools() + .stream() + .filter(tool -> tool.outputSchema() != null) + .forEach(tool -> this.toolsOutputSchemaCache.put(tool.name(), tool.outputSchema())); + } + }); } private NotificationHandler asyncToolsChangeNotificationHandler( @@ -596,13 +688,13 @@ private NotificationHandler asyncToolsChangeNotificationHandler( // Resources // -------------------------- - private static final TypeReference LIST_RESOURCES_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_RESOURCES_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference READ_RESOURCE_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef READ_RESOURCE_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF = new TypeRef<>() { }; /** @@ -633,7 +725,7 @@ public Mono listResources() { * @see #readResource(McpSchema.Resource) */ public Mono listResources(String cursor) { - return this.initializer.withIntitialization("listing resources", init -> { + return this.initializer.withInitialization("listing resources", init -> { if (init.initializeResult().capabilities().resources() == null) { return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } @@ -665,7 +757,7 @@ public Mono readResource(McpSchema.Resource resour * @see McpSchema.ReadResourceResult */ public Mono readResource(McpSchema.ReadResourceRequest readResourceRequest) { - return this.initializer.withIntitialization("reading resources", init -> { + return this.initializer.withInitialization("reading resources", init -> { if (init.initializeResult().capabilities().resources() == null) { return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } @@ -703,7 +795,7 @@ public Mono listResourceTemplates() { * @see McpSchema.ListResourceTemplatesResult */ public Mono listResourceTemplates(String cursor) { - return this.initializer.withIntitialization("listing resource templates", init -> { + return this.initializer.withInitialization("listing resource templates", init -> { if (init.initializeResult().capabilities().resources() == null) { return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } @@ -723,7 +815,7 @@ public Mono listResourceTemplates(String * @see #unsubscribeResource(McpSchema.UnsubscribeRequest) */ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - return this.initializer.withIntitialization("subscribing to resources", init -> init.mcpSession() + return this.initializer.withInitialization("subscribing to resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE)); } @@ -737,7 +829,7 @@ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) * @see #subscribeResource(McpSchema.SubscribeRequest) */ public Mono unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - return this.initializer.withIntitialization("unsubscribing from resources", init -> init.mcpSession() + return this.initializer.withInitialization("unsubscribing from resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest, VOID_TYPE_REFERENCE)); } @@ -756,7 +848,7 @@ private NotificationHandler asyncResourcesUpdatedNotificationHandler( List, Mono>> resourcesUpdateConsumers) { return params -> { McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification = transport.unmarshalFrom(params, - new TypeReference<>() { + new TypeRef<>() { }); return readResource(new McpSchema.ReadResourceRequest(resourcesUpdatedNotification.uri())) @@ -773,10 +865,10 @@ private NotificationHandler asyncResourcesUpdatedNotificationHandler( // -------------------------- // Prompts // -------------------------- - private static final TypeReference LIST_PROMPTS_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_PROMPTS_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference GET_PROMPT_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef GET_PROMPT_RESULT_TYPE_REF = new TypeRef<>() { }; /** @@ -803,7 +895,7 @@ public Mono listPrompts() { * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts(String cursor) { - return this.initializer.withIntitialization("listing prompts", init -> init.mcpSession() + return this.initializer.withInitialization("listing prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF)); } @@ -817,7 +909,7 @@ public Mono listPrompts(String cursor) { * @see #listPrompts() */ public Mono getPrompt(GetPromptRequest getPromptRequest) { - return this.initializer.withIntitialization("getting prompts", init -> init.mcpSession() + return this.initializer.withInitialization("getting prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF)); } @@ -835,14 +927,6 @@ private NotificationHandler asyncPromptsChangeNotificationHandler( // -------------------------- // Logging // -------------------------- - /** - * Create a notification handler for logging notifications from the server. This - * handler automatically distributes logging messages to all registered consumers. - * @param loggingConsumers List of consumers that will be notified when a logging - * message is received. Each consumer receives the logging message notification. - * @return A NotificationHandler that processes log notifications by distributing the - * message to all registered consumers - */ private NotificationHandler asyncLoggingNotificationHandler( List>> loggingConsumers) { @@ -868,7 +952,7 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) { return Mono.error(new IllegalArgumentException("Logging level must not be null")); } - return this.initializer.withIntitialization("setting logging level", init -> { + return this.initializer.withInitialization("setting logging level", init -> { if (init.initializeResult().capabilities().logging() == null) { return Mono.error(new IllegalStateException("Server's Logging capabilities are not enabled!")); } @@ -877,15 +961,6 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) { }); } - /** - * Create a notification handler for progress notifications from the server. This - * handler automatically distributes progress notifications to all registered - * consumers. - * @param progressConsumers List of consumers that will be notified when a progress - * message is received. Each consumer receives the progress notification. - * @return A NotificationHandler that processes progress notifications by distributing - * the message to all registered consumers - */ private NotificationHandler asyncProgressNotificationHandler( List>> progressConsumers) { @@ -911,7 +986,7 @@ void setProtocolVersions(List protocolVersions) { // -------------------------- // Completions // -------------------------- - private static final TypeReference COMPLETION_COMPLETE_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef COMPLETION_COMPLETE_RESULT_TYPE_REF = new TypeRef<>() { }; /** @@ -925,7 +1000,7 @@ void setProtocolVersions(List protocolVersions) { * @see McpSchema.CompleteResult */ public Mono completeCompletion(McpSchema.CompleteRequest completeRequest) { - return this.initializer.withIntitialization("complete completions", init -> init.mcpSession() + return this.initializer.withInitialization("complete completions", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF)); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java similarity index 85% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java index c8af28ac1..12f34e60a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -4,17 +4,11 @@ package io.modelcontextprotocol.client; -import java.time.Duration; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Function; - +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; @@ -22,9 +16,19 @@ import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.Implementation; import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + /** * Factory class for creating Model Context Protocol (MCP) clients. MCP is a protocol that * enables AI models to interact with external tools and resources through a standardized @@ -72,6 +76,7 @@ * .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> System.out.println("Resources updated: " + resources))) * .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> System.out.println("Prompts updated: " + prompts))) * .loggingConsumer(message -> Mono.fromRunnable(() -> System.out.println("Log message: " + message))) + * .resourcesUpdateConsumer(resourceContents -> Mono.fromRunnable(() -> System.out.println("Resources contents updated: " + resourceContents))) * .build(); * } * @@ -97,6 +102,7 @@ * * @author Christian Tzolov * @author Dariusz JΔ™drzejczyk + * @author Anurag Pant * @see McpAsyncClient * @see McpSyncClient * @see McpTransport @@ -163,7 +169,7 @@ class SyncSpec { private ClientCapabilities capabilities; - private Implementation clientInfo = new Implementation("Java SDK MCP Client", "1.0.0"); + private Implementation clientInfo = new Implementation("Java SDK MCP Client", "0.15.0"); private final Map roots = new HashMap<>(); @@ -183,6 +189,12 @@ class SyncSpec { private Function elicitationHandler; + private Supplier contextProvider = () -> McpTransportContext.EMPTY; + + private JsonSchemaValidator jsonSchemaValidator; + + private boolean enableCallToolSchemaCaching = false; // Default to false + private SyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -336,6 +348,22 @@ public SyncSpec resourcesChangeConsumer(Consumer> resou return this; } + /** + * Adds a consumer to be notified when a specific resource is updated. This allows + * the client to react to changes in individual resources, such as updates to + * their content or metadata. + * @param resourcesUpdateConsumer A consumer function that processes the updated + * resource and returns a Mono indicating the completion of the processing. Must + * not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException If the resourcesUpdateConsumer is null. + */ + public SyncSpec resourcesUpdateConsumer(Consumer> resourcesUpdateConsumer) { + Assert.notNull(resourcesUpdateConsumer, "Resources update consumer must not be null"); + this.resourcesUpdateConsumers.add(resourcesUpdateConsumer); + return this; + } + /** * Adds a consumer to be notified when the available prompts change. This allows * the client to react to changes in the server's prompt templates, such as new @@ -409,6 +437,48 @@ public SyncSpec progressConsumers(List> return this; } + /** + * Add a provider of {@link McpTransportContext}, providing a context before + * calling any client operation. This allows to extract thread-locals and hand + * them over to the underlying transport. + *

+ * There is no direct equivalent in {@link AsyncSpec}. To achieve the same result, + * append {@code contextWrite(McpTransportContext.KEY, context)} to any + * {@link McpAsyncClient} call. + * @param contextProvider A supplier to create a context + * @return This builder for method chaining + */ + public SyncSpec transportContextProvider(Supplier contextProvider) { + this.contextProvider = contextProvider; + return this; + } + + /** + * Add a {@link JsonSchemaValidator} to validate the JSON structure of the + * structured output. + * @param jsonSchemaValidator A validator to validate the JSON structure of the + * structured output. Must not be null. + * @return This builder for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public SyncSpec jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enables automatic schema caching during callTool operations. When a tool's + * output schema is not found in the cache, callTool will automatically fetch and + * cache all tool schemas via listTools. + * @param enableCallToolSchemaCaching true to enable, false to disable + * @return This builder instance for method chaining + */ + public SyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching) { + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + return this; + } + /** * Create an instance of {@link McpSyncClient} with the provided configurations or * sensible defaults. @@ -418,12 +488,13 @@ public McpSyncClient build() { McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, this.samplingHandler, - this.elicitationHandler); + this.elicitationHandler, this.enableCallToolSchemaCaching); McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures); - return new McpSyncClient( - new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures)); + return new McpSyncClient(new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, + jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(), + asyncFeatures), this.contextProvider); } } @@ -454,7 +525,7 @@ class AsyncSpec { private ClientCapabilities capabilities; - private Implementation clientInfo = new Implementation("Spring AI MCP Client", "0.3.1"); + private Implementation clientInfo = new Implementation("Java SDK MCP Client", "0.15.0"); private final Map roots = new HashMap<>(); @@ -474,6 +545,10 @@ class AsyncSpec { private Function> elicitationHandler; + private JsonSchemaValidator jsonSchemaValidator; + + private boolean enableCallToolSchemaCaching = false; // Default to false + private AsyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -720,17 +795,45 @@ public AsyncSpec progressConsumers( return this; } + /** + * Sets the JSON schema validator to use for validating tool responses against + * output schemas. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public AsyncSpec jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enables automatic schema caching during callTool operations. When a tool's + * output schema is not found in the cache, callTool will automatically fetch and + * cache all tool schemas via listTools. + * @param enableCallToolSchemaCaching true to enable, false to disable + * @return This builder instance for method chaining + */ + public AsyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching) { + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + return this; + } + /** * Create an instance of {@link McpAsyncClient} with the provided configurations * or sensible defaults. * @return a new instance of {@link McpAsyncClient}. */ public McpAsyncClient build() { + var jsonSchemaValidator = (this.jsonSchemaValidator != null) ? this.jsonSchemaValidator + : McpJsonDefaults.getSchemaValidator(); return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, + jsonSchemaValidator, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, - this.samplingHandler, this.elicitationHandler)); + this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching)); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java similarity index 94% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index 3b6550765..127d53337 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -62,6 +62,7 @@ class McpClientFeatures { * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List, Mono>> toolsChangeConsumers, @@ -71,7 +72,8 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List>> loggingConsumers, List>> progressConsumers, Function> samplingHandler, - Function> elicitationHandler) { + Function> elicitationHandler, + boolean enableCallToolSchemaCaching) { /** * Create an instance and validate the arguments. @@ -84,6 +86,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, @@ -94,7 +97,8 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List>> loggingConsumers, List>> progressConsumers, Function> samplingHandler, - Function> elicitationHandler) { + Function> elicitationHandler, + boolean enableCallToolSchemaCaching) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; @@ -113,6 +117,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; } /** @@ -129,7 +134,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c Function> elicitationHandler) { this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, - elicitationHandler); + elicitationHandler, false); } /** @@ -187,7 +192,8 @@ public static Async fromSync(Sync syncSpec) { return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, - loggingConsumers, progressConsumers, samplingHandler, elicitationHandler); + loggingConsumers, progressConsumers, samplingHandler, elicitationHandler, + syncSpec.enableCallToolSchemaCaching); } } @@ -205,6 +211,7 @@ public static Async fromSync(Sync syncSpec) { * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, @@ -214,7 +221,8 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili List> loggingConsumers, List> progressConsumers, Function samplingHandler, - Function elicitationHandler) { + Function elicitationHandler, + boolean enableCallToolSchemaCaching) { /** * Create an instance and validate the arguments. @@ -229,6 +237,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, Map roots, List>> toolsChangeConsumers, @@ -238,7 +247,8 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl List> loggingConsumers, List> progressConsumers, Function samplingHandler, - Function elicitationHandler) { + Function elicitationHandler, + boolean enableCallToolSchemaCaching) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; @@ -257,6 +267,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; + this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; } /** @@ -272,7 +283,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl Function elicitationHandler) { this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, - elicitationHandler); + elicitationHandler, false); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java similarity index 82% rename from mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 33784adcd..7fdaa8941 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -5,16 +5,19 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; /** * A synchronous client implementation for the Model Context Protocol (MCP) that wraps an @@ -63,14 +66,20 @@ public class McpSyncClient implements AutoCloseable { private final McpAsyncClient delegate; + private final Supplier contextProvider; + /** * Create a new McpSyncClient with the given delegate. * @param delegate the asynchronous kernel on top of which this synchronous client * provides a blocking API. + * @param contextProvider the supplier of context before calling any non-blocking + * operation on underlying delegate */ - McpSyncClient(McpAsyncClient delegate) { + McpSyncClient(McpAsyncClient delegate, Supplier contextProvider) { Assert.notNull(delegate, "The delegate can not be null"); + Assert.notNull(contextProvider, "The contextProvider can not be null"); this.delegate = delegate; + this.contextProvider = contextProvider; } /** @@ -177,14 +186,14 @@ public boolean closeGracefully() { public McpSchema.InitializeResult initialize() { // TODO: block takes no argument here as we assume the async client is // configured with a requestTimeout at all times - return this.delegate.initialize().block(); + return withProvidedContext(this.delegate.initialize()).block(); } /** * Send a roots/list_changed notification. */ public void rootsListChangedNotification() { - this.delegate.rootsListChangedNotification().block(); + withProvidedContext(this.delegate.rootsListChangedNotification()).block(); } /** @@ -206,7 +215,7 @@ public void removeRoot(String rootUri) { * @return */ public Object ping() { - return this.delegate.ping().block(); + return withProvidedContext(this.delegate.ping()).block(); } // -------------------------- @@ -224,7 +233,8 @@ public Object ping() { * Boolean indicating if the execution failed (true) or succeeded (false/absent) */ public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolRequest) { - return this.delegate.callTool(callToolRequest).block(); + return withProvidedContext(this.delegate.callTool(callToolRequest)).block(); + } /** @@ -234,7 +244,7 @@ public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolReque * pagination if more tools are available */ public McpSchema.ListToolsResult listTools() { - return this.delegate.listTools().block(); + return withProvidedContext(this.delegate.listTools()).block(); } /** @@ -245,7 +255,8 @@ public McpSchema.ListToolsResult listTools() { * pagination if more tools are available */ public McpSchema.ListToolsResult listTools(String cursor) { - return this.delegate.listTools(cursor).block(); + return withProvidedContext(this.delegate.listTools(cursor)).block(); + } // -------------------------- @@ -257,7 +268,8 @@ public McpSchema.ListToolsResult listTools(String cursor) { * @return The list of all resources result */ public McpSchema.ListResourcesResult listResources() { - return this.delegate.listResources().block(); + return withProvidedContext(this.delegate.listResources()).block(); + } /** @@ -266,7 +278,8 @@ public McpSchema.ListResourcesResult listResources() { * @return The list of resources result */ public McpSchema.ListResourcesResult listResources(String cursor) { - return this.delegate.listResources(cursor).block(); + return withProvidedContext(this.delegate.listResources(cursor)).block(); + } /** @@ -275,7 +288,8 @@ public McpSchema.ListResourcesResult listResources(String cursor) { * @return the resource content. */ public McpSchema.ReadResourceResult readResource(McpSchema.Resource resource) { - return this.delegate.readResource(resource).block(); + return withProvidedContext(this.delegate.readResource(resource)).block(); + } /** @@ -284,7 +298,8 @@ public McpSchema.ReadResourceResult readResource(McpSchema.Resource resource) { * @return the resource content. */ public McpSchema.ReadResourceResult readResource(McpSchema.ReadResourceRequest readResourceRequest) { - return this.delegate.readResource(readResourceRequest).block(); + return withProvidedContext(this.delegate.readResource(readResourceRequest)).block(); + } /** @@ -292,7 +307,8 @@ public McpSchema.ReadResourceResult readResource(McpSchema.ReadResourceRequest r * @return The list of all resource templates result. */ public McpSchema.ListResourceTemplatesResult listResourceTemplates() { - return this.delegate.listResourceTemplates().block(); + return withProvidedContext(this.delegate.listResourceTemplates()).block(); + } /** @@ -304,7 +320,8 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates() { * @return The list of resource templates result. */ public McpSchema.ListResourceTemplatesResult listResourceTemplates(String cursor) { - return this.delegate.listResourceTemplates(cursor).block(); + return withProvidedContext(this.delegate.listResourceTemplates(cursor)).block(); + } /** @@ -317,7 +334,8 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates(String cursor * subscribe to. */ public void subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - this.delegate.subscribeResource(subscribeRequest).block(); + withProvidedContext(this.delegate.subscribeResource(subscribeRequest)).block(); + } /** @@ -326,7 +344,8 @@ public void subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { * to unsubscribe from. */ public void unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - this.delegate.unsubscribeResource(unsubscribeRequest).block(); + withProvidedContext(this.delegate.unsubscribeResource(unsubscribeRequest)).block(); + } // -------------------------- @@ -338,7 +357,7 @@ public void unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) * @return The list of all prompts result. */ public ListPromptsResult listPrompts() { - return this.delegate.listPrompts().block(); + return withProvidedContext(this.delegate.listPrompts()).block(); } /** @@ -347,11 +366,12 @@ public ListPromptsResult listPrompts() { * @return The list of prompts result. */ public ListPromptsResult listPrompts(String cursor) { - return this.delegate.listPrompts(cursor).block(); + return withProvidedContext(this.delegate.listPrompts(cursor)).block(); + } public GetPromptResult getPrompt(GetPromptRequest getPromptRequest) { - return this.delegate.getPrompt(getPromptRequest).block(); + return withProvidedContext(this.delegate.getPrompt(getPromptRequest)).block(); } /** @@ -359,7 +379,8 @@ public GetPromptResult getPrompt(GetPromptRequest getPromptRequest) { * @param loggingLevel the min logging level */ public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) { - this.delegate.setLoggingLevel(loggingLevel).block(); + withProvidedContext(this.delegate.setLoggingLevel(loggingLevel)).block(); + } /** @@ -369,7 +390,18 @@ public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) { * @return the completion result containing suggested values. */ public McpSchema.CompleteResult completeCompletion(McpSchema.CompleteRequest completeRequest) { - return this.delegate.completeCompletion(completeRequest).block(); + return withProvidedContext(this.delegate.completeCompletion(completeRequest)).block(); + + } + + /** + * For a given action, on assembly, capture the "context" via the + * {@link #contextProvider} and store it in the Reactor context. + * @param action the action to perform + * @return the result of the action + */ + private Mono withProvidedContext(Mono action) { + return action.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, this.contextProvider.get())); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java similarity index 70% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index a92f26c4e..be4e4cf97 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -18,16 +18,19 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.spec.McpTransportException; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import reactor.core.Disposable; @@ -94,8 +97,8 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** HTTP request builder for building requests to send messages to the server */ private final HttpRequest.Builder requestBuilder; - /** JSON object mapper for message serialization/deserialization */ - protected ObjectMapper objectMapper; + /** JSON mapper for message serialization/deserialization */ + protected McpJsonMapper jsonMapper; /** Flag indicating if the transport is in closing state */ private volatile boolean isClosing = false; @@ -112,66 +115,7 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** * Customizer to modify requests before they are executed. */ - private final AsyncHttpRequestCustomizer httpRequestCustomizer; - - /** - * Creates a new transport instance with default HTTP client and object mapper. - * @param baseUri the base URI of the MCP server - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(String baseUri) { - this(HttpClient.newBuilder(), baseUri, new ObjectMapper()); - } - - /** - * Creates a new transport instance with custom HTTP client builder and object mapper. - * @param clientBuilder the HTTP client builder to use - * @param baseUri the base URI of the MCP server - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper or clientBuilder is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper) { - this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper); - } - - /** - * Creates a new transport instance with custom HTTP client builder and object mapper. - * @param clientBuilder the HTTP client builder to use - * @param baseUri the base URI of the MCP server - * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper or clientBuilder is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint, - ObjectMapper objectMapper) { - this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper); - } - - /** - * Creates a new transport instance with custom HTTP client builder, object mapper, - * and headers. - * @param clientBuilder the HTTP client builder to use - * @param requestBuilder the HTTP request builder to use - * @param baseUri the base URI of the MCP server - * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This - * constructor will be removed in future versions. - */ - @Deprecated(forRemoval = true) - public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder, - String baseUri, String sseEndpoint, ObjectMapper objectMapper) { - this(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint, objectMapper); - } + private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; /** * Creates a new transport instance with custom HTTP client builder, object mapper, @@ -180,30 +124,14 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques * @param requestBuilder the HTTP request builder to use * @param baseUri the base URI of the MCP server * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization - * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null - */ - @Deprecated(forRemoval = true) - HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, - String sseEndpoint, ObjectMapper objectMapper) { - this(httpClient, requestBuilder, baseUri, sseEndpoint, objectMapper, AsyncHttpRequestCustomizer.NOOP); - } - - /** - * Creates a new transport instance with custom HTTP client builder, object mapper, - * and headers. - * @param httpClient the HTTP client to use - * @param requestBuilder the HTTP request builder to use - * @param baseUri the base URI of the MCP server - * @param sseEndpoint the SSE endpoint path - * @param objectMapper the object mapper for JSON serialization/deserialization + * @param jsonMapper the object mapper for JSON serialization/deserialization * @param httpRequestCustomizer customizer for the requestBuilder before executing * requests * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null */ HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, - String sseEndpoint, ObjectMapper objectMapper, AsyncHttpRequestCustomizer httpRequestCustomizer) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(httpClient, "httpClient must not be null"); @@ -211,7 +139,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques Assert.notNull(httpRequestCustomizer, "httpRequestCustomizer must not be null"); this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; this.httpRequestCustomizer = httpRequestCustomizer; @@ -242,11 +170,11 @@ public static class Builder { private HttpClient.Builder clientBuilder = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1); - private ObjectMapper objectMapper = new ObjectMapper(); + private McpJsonMapper jsonMapper; private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); - private AsyncHttpRequestCustomizer httpRequestCustomizer = AsyncHttpRequestCustomizer.NOOP; + private McpAsyncHttpClientRequestCustomizer httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.NOOP; private Duration connectTimeout = Duration.ofSeconds(10); @@ -257,19 +185,6 @@ public static class Builder { // Default constructor } - /** - * Creates a new builder with the specified base URI. - * @param baseUri the base URI of the MCP server - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. - * This constructor is deprecated and will be removed or made {@code protected} or - * {@code private} in a future release. - */ - @Deprecated(forRemoval = true) - public Builder(String baseUri) { - Assert.hasText(baseUri, "baseUri must not be empty"); - this.baseUri = baseUri; - } - /** * Sets the base URI. * @param baseUri the base URI @@ -337,13 +252,13 @@ public Builder customizeRequest(final Consumer requestCusto } /** - * Sets the object mapper for JSON serialization/deserialization. - * @param objectMapper the object mapper + * Sets the JSON mapper implementation to use for serialization/deserialization. + * @param jsonMapper the JSON mapper * @return this builder */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "objectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -352,16 +267,17 @@ public Builder objectMapper(ObjectMapper objectMapper) { * executing them. *

* This overrides the customizer from - * {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)}. + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)}. *

- * Do NOT use a blocking {@link SyncHttpRequestCustomizer} in a non-blocking - * context. Use {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)} + * Do NOT use a blocking {@link McpSyncHttpClientRequestCustomizer} in a + * non-blocking context. Use + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)} * instead. * @param syncHttpRequestCustomizer the request customizer * @return this builder */ - public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCustomizer) { - this.httpRequestCustomizer = AsyncHttpRequestCustomizer.fromSync(syncHttpRequestCustomizer); + public Builder httpRequestCustomizer(McpSyncHttpClientRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.fromSync(syncHttpRequestCustomizer); return this; } @@ -370,13 +286,13 @@ public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCu * executing them. *

* This overrides the customizer from - * {@link #httpRequestCustomizer(SyncHttpRequestCustomizer)}. + * {@link #httpRequestCustomizer(McpSyncHttpClientRequestCustomizer)}. *

* Do NOT use a blocking implementation in a non-blocking context. * @param asyncHttpRequestCustomizer the request customizer * @return this builder */ - public Builder asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer asyncHttpRequestCustomizer) { + public Builder asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer asyncHttpRequestCustomizer) { this.httpRequestCustomizer = asyncHttpRequestCustomizer; return this; } @@ -398,8 +314,8 @@ public Builder connectTimeout(Duration connectTimeout) { */ public HttpClientSseClientTransport build() { HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); - return new HttpClientSseClientTransport(httpClient, requestBuilder, baseUri, sseEndpoint, objectMapper, - httpRequestCustomizer); + return new HttpClientSseClientTransport(httpClient, requestBuilder, baseUri, sseEndpoint, + jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, httpRequestCustomizer); } } @@ -408,14 +324,15 @@ public HttpClientSseClientTransport build() { public Mono connect(Function, Mono> handler) { var uri = Utils.resolveUri(this.baseUri, this.sseEndpoint); - return Mono.defer(() -> { + return Mono.deferContextual(ctx -> { var builder = requestBuilder.copy() .uri(uri) .header("Accept", "text/event-stream") .header("Cache-Control", "no-cache") .header(MCP_PROTOCOL_VERSION_HEADER_NAME, MCP_PROTOCOL_VERSION) .GET(); - return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null)); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null, transportContext)); }).flatMap(requestBuilder -> Mono.create(sink -> { Disposable connection = Flux.create(sseSink -> this.httpClient .sendAsync(requestBuilder.build(), @@ -445,7 +362,7 @@ public Mono connect(Function, Mono> h } } else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, responseEvent.sseEvent().data()); sink.success(); return Flux.just(message); @@ -526,7 +443,7 @@ public Mono sendMessage(JSONRPCMessage message) { private Mono serializeMessage(final JSONRPCMessage message) { return Mono.defer(() -> { try { - return Mono.just(objectMapper.writeValueAsString(message)); + return Mono.just(jsonMapper.writeValueAsString(message)); } catch (IOException e) { return Mono.error(new McpTransportException("Failed to serialize message", e)); @@ -536,13 +453,14 @@ private Mono serializeMessage(final JSONRPCMessage message) { private Mono> sendHttpPost(final String endpoint, final String body) { final URI requestUri = Utils.resolveUri(baseUri, endpoint); - return Mono.defer(() -> { + return Mono.deferContextual(ctx -> { var builder = this.requestBuilder.copy() .uri(requestUri) - .header("Content-Type", "application/json") + .header(HttpHeaders.CONTENT_TYPE, "application/json") .header(MCP_PROTOCOL_VERSION_HEADER_NAME, MCP_PROTOCOL_VERSION) .POST(HttpRequest.BodyPublishers.ofString(body)); - return Mono.from(this.httpRequestCustomizer.customize(builder, "POST", requestUri, body)); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "POST", requestUri, body, transportContext)); }).flatMap(customizedBuilder -> { var request = customizedBuilder.build(); return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); @@ -576,8 +494,8 @@ public Mono closeGracefully() { * @return the unmarshalled object */ @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java similarity index 75% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index d8c49ae2f..d6b01e17f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -11,6 +11,8 @@ import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; import java.time.Duration; +import java.util.Collections; +import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.concurrent.CompletionException; @@ -18,14 +20,15 @@ import java.util.function.Consumer; import java.util.function.Function; -import org.reactivestreams.Publisher; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - +import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.ClosedMcpTransportSession; import io.modelcontextprotocol.spec.DefaultMcpTransportSession; import io.modelcontextprotocol.spec.DefaultMcpTransportStream; import io.modelcontextprotocol.spec.HttpHeaders; @@ -38,6 +41,9 @@ import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; @@ -74,8 +80,6 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(HttpClientStreamableHttpTransport.class); - private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_03_26; - private static final String DEFAULT_ENDPOINT = "/mcp"; /** @@ -103,7 +107,7 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { public static int BAD_REQUEST = 400; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final URI baseUri; @@ -113,18 +117,23 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private final boolean resumableStreams; - private final AsyncHttpRequestCustomizer httpRequestCustomizer; + private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; - private final AtomicReference activeSession = new AtomicReference<>(); + private final AtomicReference> activeSession = new AtomicReference<>(); private final AtomicReference, Mono>> handler = new AtomicReference<>(); private final AtomicReference> exceptionHandler = new AtomicReference<>(); - private HttpClientStreamableHttpTransport(ObjectMapper objectMapper, HttpClient httpClient, + private final List supportedProtocolVersions; + + private final String latestSupportedProtocolVersion; + + private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams, - boolean openConnectionOnStartup, AsyncHttpRequestCustomizer httpRequestCustomizer) { - this.objectMapper = objectMapper; + boolean openConnectionOnStartup, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer, + List supportedProtocolVersions) { + this.jsonMapper = jsonMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; this.baseUri = URI.create(baseUri); @@ -133,11 +142,16 @@ private HttpClientStreamableHttpTransport(ObjectMapper objectMapper, HttpClient this.openConnectionOnStartup = openConnectionOnStartup; this.activeSession.set(createTransportSession()); this.httpRequestCustomizer = httpRequestCustomizer; + this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); + this.latestSupportedProtocolVersion = this.supportedProtocolVersions.stream() + .sorted(Comparator.reverseOrder()) + .findFirst() + .get(); } @Override public List protocolVersions() { - return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); + return supportedProtocolVersions; } public static Builder builder(String baseUri) { @@ -159,23 +173,34 @@ public Mono connect(Function, Mono createTransportSession() { Function> onClose = sessionId -> sessionId == null ? Mono.empty() : createDelete(sessionId); return new DefaultMcpTransportSession(onClose); } + private McpTransportSession createClosedSession(McpTransportSession existingSession) { + var existingSessionId = Optional.ofNullable(existingSession) + .filter(session -> !(session instanceof ClosedMcpTransportSession)) + .flatMap(McpTransportSession::sessionId) + .orElse(null); + return new ClosedMcpTransportSession<>(existingSessionId); + } + private Publisher createDelete(String sessionId) { var uri = Utils.resolveUri(this.baseUri, this.endpoint); - return Mono.defer(() -> { + return Mono.deferContextual(ctx -> { var builder = this.requestBuilder.copy() .uri(uri) .header("Cache-Control", "no-cache") .header(HttpHeaders.MCP_SESSION_ID, sessionId) - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) .DELETE(); - return Mono.from(this.httpRequestCustomizer.customize(builder, "DELETE", uri, null)); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "DELETE", uri, null, transportContext)); }).flatMap(requestBuilder -> { var request = requestBuilder.build(); return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); @@ -205,9 +230,9 @@ private void handleException(Throwable t) { public Mono closeGracefully() { return Mono.defer(() -> { logger.debug("Graceful close triggered"); - DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(createTransportSession()); + McpTransportSession currentSession = this.activeSession.getAndUpdate(this::createClosedSession); if (currentSession != null) { - return currentSession.closeGracefully(); + return Mono.from(currentSession.closeGracefully()); } return Mono.empty(); }); @@ -228,7 +253,7 @@ private Mono reconnect(McpTransportStream stream) { final McpTransportSession transportSession = this.activeSession.get(); var uri = Utils.resolveUri(this.baseUri, this.endpoint); - Disposable connection = Mono.defer(() -> { + Disposable connection = Mono.deferContextual(connectionCtx -> { HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); if (transportSession != null && transportSession.sessionId().isPresent()) { @@ -241,11 +266,14 @@ private Mono reconnect(McpTransportStream stream) { } var builder = requestBuilder.uri(uri) - .header("Accept", TEXT_EVENT_STREAM) + .header(HttpHeaders.ACCEPT, TEXT_EVENT_STREAM) .header("Cache-Control", "no-cache") - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) + .header(HttpHeaders.PROTOCOL_VERSION, + connectionCtx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) .GET(); - return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null)); + var transportContext = connectionCtx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null, transportContext)); }) .flatMapMany( requestBuilder -> Flux.create( @@ -268,12 +296,23 @@ private Mono reconnect(McpTransportStream stream) { if (statusCode >= 200 && statusCode < 300) { if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + String data = responseEvent.sseEvent().data(); + // Per 2025-11-25 spec (SEP-1699), servers may + // send SSE events + // with empty data to prime the client for + // reconnection. + // Skip these events as they contain no JSON-RPC + // message. + if (data == null || data.isBlank()) { + logger.debug("Skipping SSE event with empty data (stream primer)"); + return Flux.empty(); + } try { // We don't support batching ATM and probably // won't since the next version considers // removing it. - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage( - this.objectMapper, responseEvent.sseEvent().data()); + McpSchema.JSONRPCMessage message = McpSchema + .deserializeJsonRpcMessage(this.jsonMapper, data); Tuple2, Iterable> idWithMessages = Tuples .of(Optional.ofNullable(responseEvent.sseEvent().id()), @@ -365,7 +404,7 @@ private BodyHandler toSendMessageBodySubscriber(FluxSink si BodyHandler responseBodyHandler = responseInfo -> { - String contentType = responseInfo.headers().firstValue("Content-Type").orElse("").toLowerCase(); + String contentType = responseInfo.headers().firstValue(HttpHeaders.CONTENT_TYPE).orElse("").toLowerCase(); if (contentType.contains(TEXT_EVENT_STREAM)) { // For SSE streams, use line subscriber that returns Void @@ -388,7 +427,7 @@ else if (contentType.contains(APPLICATION_JSON)) { public String toString(McpSchema.JSONRPCMessage message) { try { - return this.objectMapper.writeValueAsString(message); + return this.jsonMapper.writeValueAsString(message); } catch (IOException e) { throw new RuntimeException("Failed to serialize JSON-RPC message", e); @@ -405,7 +444,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { var uri = Utils.resolveUri(this.baseUri, this.endpoint); String jsonBody = this.toString(sentMessage); - Disposable connection = Mono.defer(() -> { + Disposable connection = Mono.deferContextual(ctx -> { HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); if (transportSession != null && transportSession.sessionId().isPresent()) { @@ -414,12 +453,16 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { } var builder = requestBuilder.uri(uri) - .header("Accept", APPLICATION_JSON + ", " + TEXT_EVENT_STREAM) - .header("Content-Type", APPLICATION_JSON) - .header("Cache-Control", "no-cache") - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) + .header(HttpHeaders.ACCEPT, APPLICATION_JSON + ", " + TEXT_EVENT_STREAM) + .header(HttpHeaders.CONTENT_TYPE, APPLICATION_JSON) + .header(HttpHeaders.CACHE_CONTROL, "no-cache") + .header(HttpHeaders.PROTOCOL_VERSION, + ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, + this.latestSupportedProtocolVersion)) .POST(HttpRequest.BodyPublishers.ofString(jsonBody)); - return Mono.from(this.httpRequestCustomizer.customize(builder, "POST", uri, jsonBody)); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono + .from(this.httpRequestCustomizer.customize(builder, "POST", uri, jsonBody, transportContext)); }).flatMapMany(requestBuilder -> Flux.create(responseEventSink -> { // Create the async request with proper body subscriber selection @@ -451,28 +494,43 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { String contentType = responseEvent.responseInfo() .headers() - .firstValue("Content-Type") + .firstValue(HttpHeaders.CONTENT_TYPE) .orElse("") .toLowerCase(); - if (contentType.isBlank()) { - logger.debug("No content type returned for POST in session {}", sessionRepresentation); + String contentLength = responseEvent.responseInfo() + .headers() + .firstValue(HttpHeaders.CONTENT_LENGTH) + .orElse(null); + + // For empty content or HTTP code 202 (ACCEPTED), assume success + if (contentType.isBlank() || "0".equals(contentLength) || statusCode == 202) { + // if (contentType.isBlank() || "0".equals(contentLength)) { + logger.debug("No body returned for POST in session {}", sessionRepresentation); // No content type means no response body, so we can just - // return - // an empty stream + // return an empty stream deliveredSink.success(); return Flux.empty(); } else if (contentType.contains(TEXT_EVENT_STREAM)) { return Flux.just(((ResponseSubscribers.SseResponseEvent) responseEvent).sseEvent()) .flatMap(sseEvent -> { + String data = sseEvent.data(); + // Per 2025-11-25 spec (SEP-1699), servers may send SSE + // events + // with empty data to prime the client for reconnection. + // Skip these events as they contain no JSON-RPC message. + if (data == null || data.isBlank()) { + logger.debug("Skipping SSE event with empty data (stream primer)"); + return Flux.empty(); + } try { // We don't support batching ATM and probably // won't // since the // next version considers removing it. McpSchema.JSONRPCMessage message = McpSchema - .deserializeJsonRpcMessage(this.objectMapper, sseEvent.data()); + .deserializeJsonRpcMessage(this.jsonMapper, data); Tuple2, Iterable> idWithMessages = Tuples .of(Optional.ofNullable(sseEvent.id()), List.of(message)); @@ -495,13 +553,14 @@ else if (contentType.contains(TEXT_EVENT_STREAM)) { else if (contentType.contains(APPLICATION_JSON)) { deliveredSink.success(); String data = ((ResponseSubscribers.AggregateResponseEvent) responseEvent).data(); - if (sentMessage instanceof McpSchema.JSONRPCNotification && Utils.hasText(data)) { - logger.warn("Notification: {} received non-compliant response: {}", sentMessage, data); + if (sentMessage instanceof McpSchema.JSONRPCNotification) { + logger.warn("Notification: {} received non-compliant response: {}", sentMessage, + Utils.hasText(data) ? data : "[empty]"); return Mono.empty(); } try { - return Mono.just(McpSchema.deserializeJsonRpcMessage(objectMapper, data)); + return Mono.just(McpSchema.deserializeJsonRpcMessage(jsonMapper, data)); } catch (IOException e) { return Mono.error(new McpTransportException( @@ -575,8 +634,8 @@ private static String sessionIdOrPlaceholder(McpTransportSession transportSes } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); } /** @@ -586,7 +645,7 @@ public static class Builder { private final String baseUri; - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; private HttpClient.Builder clientBuilder = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1); @@ -598,10 +657,13 @@ public static class Builder { private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); - private AsyncHttpRequestCustomizer httpRequestCustomizer = AsyncHttpRequestCustomizer.NOOP; + private McpAsyncHttpClientRequestCustomizer httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.NOOP; private Duration connectTimeout = Duration.ofSeconds(10); + private List supportedProtocolVersions = List.of(ProtocolVersions.MCP_2024_11_05, + ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18, ProtocolVersions.MCP_2025_11_25); + /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server @@ -656,13 +718,13 @@ public Builder customizeRequest(final Consumer requestCusto } /** - * Configure the {@link ObjectMapper} to use. - * @param objectMapper instance to use + * Configure a custom {@link McpJsonMapper} implementation to use. + * @param jsonMapper instance to use * @return the builder instance */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -709,16 +771,17 @@ public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { * executing them. *

* This overrides the customizer from - * {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)}. + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)}. *

- * Do NOT use a blocking {@link SyncHttpRequestCustomizer} in a non-blocking - * context. Use {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)} + * Do NOT use a blocking {@link McpSyncHttpClientRequestCustomizer} in a + * non-blocking context. Use + * {@link #asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer)} * instead. * @param syncHttpRequestCustomizer the request customizer * @return this builder */ - public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCustomizer) { - this.httpRequestCustomizer = AsyncHttpRequestCustomizer.fromSync(syncHttpRequestCustomizer); + public Builder httpRequestCustomizer(McpSyncHttpClientRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = McpAsyncHttpClientRequestCustomizer.fromSync(syncHttpRequestCustomizer); return this; } @@ -727,13 +790,13 @@ public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCu * executing them. *

* This overrides the customizer from - * {@link #httpRequestCustomizer(SyncHttpRequestCustomizer)}. + * {@link #httpRequestCustomizer(McpSyncHttpClientRequestCustomizer)}. *

* Do NOT use a blocking implementation in a non-blocking context. * @param asyncHttpRequestCustomizer the request customizer * @return this builder */ - public Builder asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer asyncHttpRequestCustomizer) { + public Builder asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer asyncHttpRequestCustomizer) { this.httpRequestCustomizer = asyncHttpRequestCustomizer; return this; } @@ -749,18 +812,40 @@ public Builder connectTimeout(Duration connectTimeout) { return this; } + /** + * Sets the list of supported protocol versions used in version negotiation. By + * default, the client will send the latest of those versions in the + * {@code MCP-Protocol-Version} header. + *

+ * Setting this value only updates the values used in version negotiation, and + * does NOT impact the actual capabilities of the transport. It should only be + * used for compatibility with servers having strict requirements around the + * {@code MCP-Protocol-Version} header. + * @param supportedProtocolVersions protocol versions supported by this transport + * @return this builder + * @see version + * negotiation specification + * @see Protocol + * Version Header + */ + public Builder supportedProtocolVersions(List supportedProtocolVersions) { + Assert.notEmpty(supportedProtocolVersions, "supportedProtocolVersions must not be empty"); + this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); + return this; + } + /** * Construct a fresh instance of {@link HttpClientStreamableHttpTransport} using * the current builder configuration. * @return a new instance of {@link HttpClientStreamableHttpTransport} */ public HttpClientStreamableHttpTransport build() { - ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); - - return new HttpClientStreamableHttpTransport(objectMapper, httpClient, requestBuilder, baseUri, endpoint, - resumableStreams, openConnectionOnStartup, httpRequestCustomizer); + return new HttpClientStreamableHttpTransport(jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, + httpClient, requestBuilder, baseUri, endpoint, resumableStreams, openConnectionOnStartup, + httpRequestCustomizer, supportedProtocolVersions); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/ServerParameters.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java similarity index 92% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 009d415e0..1b4eaca97 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -15,8 +15,8 @@ import java.util.function.Consumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; @@ -48,7 +48,7 @@ public class StdioClientTransport implements McpClientTransport { /** The server process being communicated with */ private Process process; - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; /** Scheduler for handling inbound messages from the server process */ private Scheduler inboundScheduler; @@ -70,29 +70,20 @@ public class StdioClientTransport implements McpClientTransport { private Consumer stdErrorHandler = error -> logger.info("STDERR Message received: {}", error); /** - * Creates a new StdioClientTransport with the specified parameters and default - * ObjectMapper. + * Creates a new StdioClientTransport with the specified parameters and JsonMapper. * @param params The parameters for configuring the server process + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization */ - public StdioClientTransport(ServerParameters params) { - this(params, new ObjectMapper()); - } - - /** - * Creates a new StdioClientTransport with the specified parameters and ObjectMapper. - * @param params The parameters for configuring the server process - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - */ - public StdioClientTransport(ServerParameters params, ObjectMapper objectMapper) { + public StdioClientTransport(ServerParameters params, McpJsonMapper jsonMapper) { Assert.notNull(params, "The params can not be null"); - Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + Assert.notNull(jsonMapper, "The JsonMapper can not be null"); this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); this.params = params; - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.errorSink = Sinks.many().unicast().onBackpressureBuffer(); @@ -259,7 +250,7 @@ private void startInboundProcessing() { String line; while (!isClosing && (line = processReader.readLine()) != null) { try { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, line); + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, line); if (!this.inboundSink.tryEmitNext(message).isSuccess()) { if (!isClosing) { logger.error("Failed to enqueue inbound message: {}", message); @@ -300,7 +291,7 @@ private void startOutboundProcessing() { .handle((message, s) -> { if (message != null && !isClosing) { try { - String jsonMessage = objectMapper.writeValueAsString(message); + String jsonMessage = jsonMapper.writeValueAsString(message); // Escape any embedded newlines in the JSON message as per spec: // https://spec.modelcontextprotocol.io/specification/basic/transports/#stdio // - Messages are delimited by newlines, and MUST NOT contain @@ -392,8 +383,8 @@ public Sinks.Many getErrorSink() { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return this.jsonMapper.convertValue(data, typeRef); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..2492efe18 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizer.java @@ -0,0 +1,42 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; + +import org.reactivestreams.Publisher; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.util.Assert; + +import reactor.core.publisher.Mono; + +/** + * Composable {@link McpAsyncHttpClientRequestCustomizer} that applies multiple + * customizers, in order. + * + * @author Daniel Garnier-Moiroux + */ +public class DelegatingMcpAsyncHttpClientRequestCustomizer implements McpAsyncHttpClientRequestCustomizer { + + private final List customizers; + + public DelegatingMcpAsyncHttpClientRequestCustomizer(List customizers) { + Assert.notNull(customizers, "Customizers must not be null"); + this.customizers = customizers; + } + + @Override + public Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + String body, McpTransportContext context) { + var result = Mono.just(builder); + for (var customizer : this.customizers) { + result = result.flatMap(b -> Mono.from(customizer.customize(b, method, endpoint, body, context))); + } + return result; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..e627e7e69 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizer.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.util.Assert; + +/** + * Composable {@link McpSyncHttpClientRequestCustomizer} that applies multiple + * customizers, in order. + * + * @author Daniel Garnier-Moiroux + */ +public class DelegatingMcpSyncHttpClientRequestCustomizer implements McpSyncHttpClientRequestCustomizer { + + private final List delegates; + + public DelegatingMcpSyncHttpClientRequestCustomizer(List customizers) { + Assert.notNull(customizers, "Customizers must not be null"); + this.delegates = customizers; + } + + @Override + public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body, + McpTransportContext context) { + this.delegates.forEach(delegate -> delegate.customize(builder, method, endpoint, body, context)); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpAsyncHttpClientRequestCustomizer.java similarity index 62% rename from mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpAsyncHttpClientRequestCustomizer.java index dee026d96..756b39c35 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpAsyncHttpClientRequestCustomizer.java @@ -2,15 +2,18 @@ * Copyright 2024-2025 the original author or authors. */ -package io.modelcontextprotocol.client.transport; +package io.modelcontextprotocol.client.transport.customizer; import java.net.URI; import java.net.http.HttpRequest; + import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import reactor.util.annotation.Nullable; +import io.modelcontextprotocol.common.McpTransportContext; + /** * Customize {@link HttpRequest.Builder} before executing the request, in either SSE or * Streamable HTTP transport. @@ -19,12 +22,12 @@ * * @author Daniel Garnier-Moiroux */ -public interface AsyncHttpRequestCustomizer { +public interface McpAsyncHttpClientRequestCustomizer { Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, - @Nullable String body); + @Nullable String body, McpTransportContext context); - AsyncHttpRequestCustomizer NOOP = new Noop(); + McpAsyncHttpClientRequestCustomizer NOOP = new Noop(); /** * Wrap a sync implementation in an async wrapper. @@ -32,18 +35,18 @@ Publisher customize(HttpRequest.Builder builder, String met * Do NOT wrap a blocking implementation for use in a non-blocking context. For a * blocking implementation, consider using {@link Schedulers#boundedElastic()}. */ - static AsyncHttpRequestCustomizer fromSync(SyncHttpRequestCustomizer customizer) { - return (builder, method, uri, body) -> Mono.fromSupplier(() -> { - customizer.customize(builder, method, uri, body); + static McpAsyncHttpClientRequestCustomizer fromSync(McpSyncHttpClientRequestCustomizer customizer) { + return (builder, method, uri, body, context) -> Mono.fromSupplier(() -> { + customizer.customize(builder, method, uri, body, context); return builder; }); } - class Noop implements AsyncHttpRequestCustomizer { + class Noop implements McpAsyncHttpClientRequestCustomizer { @Override public Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, - String body) { + String body, McpTransportContext context) { return Mono.just(builder); } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpSyncHttpClientRequestCustomizer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpSyncHttpClientRequestCustomizer.java new file mode 100644 index 000000000..e22e3aa62 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpSyncHttpClientRequestCustomizer.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; + +import reactor.util.annotation.Nullable; + +import io.modelcontextprotocol.client.McpClient.SyncSpec; +import io.modelcontextprotocol.common.McpTransportContext; + +/** + * Customize {@link HttpRequest.Builder} before executing the request, either in SSE or + * Streamable HTTP transport. Do not rely on thread-locals in this implementation, instead + * use {@link SyncSpec#transportContextProvider} to extract context, and then consume it + * through {@link McpTransportContext}. + * + * @author Daniel Garnier-Moiroux + */ +public interface McpSyncHttpClientRequestCustomizer { + + void customize(HttpRequest.Builder builder, String method, URI endpoint, @Nullable String body, + McpTransportContext context); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/common/DefaultMcpTransportContext.java b/mcp-core/src/main/java/io/modelcontextprotocol/common/DefaultMcpTransportContext.java new file mode 100644 index 000000000..cde637b15 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/common/DefaultMcpTransportContext.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; + +import io.modelcontextprotocol.util.Assert; + +/** + * Default implementation for {@link McpTransportContext} which uses a map as storage. + * + * @author Dariusz JΔ™drzejczyk + * @author Daniel Garnier-Moiroux + */ +class DefaultMcpTransportContext implements McpTransportContext { + + private final Map metadata; + + DefaultMcpTransportContext(Map metadata) { + Assert.notNull(metadata, "The metadata cannot be null"); + this.metadata = metadata; + } + + @Override + public Object get(String key) { + return this.metadata.get(key); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) + return false; + + DefaultMcpTransportContext that = (DefaultMcpTransportContext) o; + return this.metadata.equals(that.metadata); + } + + @Override + public int hashCode() { + return this.metadata.hashCode(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java b/mcp-core/src/main/java/io/modelcontextprotocol/common/McpTransportContext.java similarity index 68% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java rename to mcp-core/src/main/java/io/modelcontextprotocol/common/McpTransportContext.java index 1cd540f72..46a2ccf84 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/common/McpTransportContext.java @@ -2,9 +2,10 @@ * Copyright 2024-2025 the original author or authors. */ -package io.modelcontextprotocol.server; +package io.modelcontextprotocol.common; import java.util.Collections; +import java.util.Map; /** * Context associated with the transport layer. It allows to add transport-level metadata @@ -26,6 +27,15 @@ public interface McpTransportContext { @SuppressWarnings("unchecked") McpTransportContext EMPTY = new DefaultMcpTransportContext(Collections.EMPTY_MAP); + /** + * Create an unmodifiable context containing the given metadata. + * @param metadata the transport metadata + * @return the context containing the metadata + */ + static McpTransportContext create(Map metadata) { + return new DefaultMcpTransportContext(metadata); + } + /** * Extract a value from the context. * @param key the key under the data is expected @@ -33,18 +43,4 @@ public interface McpTransportContext { */ Object get(String key); - /** - * Inserts a value for a given key. - * @param key a String representing the key - * @param value the value to store - */ - void put(String key, Object value); - - /** - * Copies the contents of the context to allow further modifications without affecting - * the initial object. - * @return a new instance with the underlying storage copied. - */ - McpTransportContext copy(); - } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/json/McpJsonDefaults.java b/mcp-core/src/main/java/io/modelcontextprotocol/json/McpJsonDefaults.java new file mode 100644 index 000000000..11b370ed8 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/json/McpJsonDefaults.java @@ -0,0 +1,76 @@ +/** + * Copyright 2026 - 2026 the original author or authors. + */ +package io.modelcontextprotocol.json; + +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier; +import io.modelcontextprotocol.util.McpServiceLoader; + +/** + * This class is to be used to provide access to the default {@link McpJsonMapper} and to + * the default {@link JsonSchemaValidator} instances via the static methods: + * {@link #getMapper()} and {@link #getSchemaValidator()}. + *

+ * The initialization of (singleton) instances of this class is different in non-OSGi + * environments and OSGi environments. Specifically, in non-OSGi environments the + * {@code McpJsonDefaults} class will be loaded by whatever classloader is used to call + * one of the existing static get methods for the first time. For servers, this will + * usually be in response to the creation of the first {@code McpServer} instance. At that + * first time, the {@code mcpMapperServiceLoader} and {@code mcpValidatorServiceLoader} + * will be null, and the {@code McpJsonDefaults} constructor will be called, + * creating/initializing the {@code mcpMapperServiceLoader} and the + * {@code mcpValidatorServiceLoader}...which will then be used to call the + * {@code ServiceLoader.load} method. + *

+ * In OSGi environments, upon bundle activation SCR will create a new (singleton) instance + * of {@code McpJsonDefaults} (via the constructor), and then inject suppliers via the + * {@code setMcpJsonMapperSupplier} and {@code setJsonSchemaValidatorSupplier} methods + * with the SCR-discovered instances of those services. This does depend upon the + * jars/bundles providing those suppliers to be started/activated. This SCR behavior is + * dictated by xml files in {@code OSGi-INF} directory of {@code mcp-core} (this + * project/jar/bundle), and the jsonmapper and jsonschemavalidator provider jars/bundles + * (e.g. {@code mcp-json-jackson2}, {@code mcp-json-jackson3}, or others). + */ +public class McpJsonDefaults { + + protected static McpServiceLoader mcpMapperServiceLoader; + + protected static McpServiceLoader mcpValidatorServiceLoader; + + public McpJsonDefaults() { + mcpMapperServiceLoader = new McpServiceLoader<>(McpJsonMapperSupplier.class); + mcpValidatorServiceLoader = new McpServiceLoader<>(JsonSchemaValidatorSupplier.class); + } + + void setMcpJsonMapperSupplier(McpJsonMapperSupplier supplier) { + mcpMapperServiceLoader.setSupplier(supplier); + } + + void unsetMcpJsonMapperSupplier(McpJsonMapperSupplier supplier) { + mcpMapperServiceLoader.unsetSupplier(supplier); + } + + public synchronized static McpJsonMapper getMapper() { + if (mcpMapperServiceLoader == null) { + new McpJsonDefaults(); + } + return mcpMapperServiceLoader.getDefault(); + } + + void setJsonSchemaValidatorSupplier(JsonSchemaValidatorSupplier supplier) { + mcpValidatorServiceLoader.setSupplier(supplier); + } + + void unsetJsonSchemaValidatorSupplier(JsonSchemaValidatorSupplier supplier) { + mcpValidatorServiceLoader.unsetSupplier(supplier); + } + + public synchronized static JsonSchemaValidator getSchemaValidator() { + if (mcpValidatorServiceLoader == null) { + new McpJsonDefaults(); + } + return mcpValidatorServiceLoader.getDefault(); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/json/McpJsonMapper.java b/mcp-core/src/main/java/io/modelcontextprotocol/json/McpJsonMapper.java new file mode 100644 index 000000000..8481d1703 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/json/McpJsonMapper.java @@ -0,0 +1,90 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.io.IOException; + +/** + * Abstraction for JSON serialization/deserialization to decouple the SDK from any + * specific JSON library. A default implementation backed by Jackson is provided in + * io.modelcontextprotocol.spec.json.jackson.JacksonJsonMapper. + */ +public interface McpJsonMapper { + + /** + * Deserialize JSON string into a target type. + * @param content JSON as String + * @param type target class + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(String content, Class type) throws IOException; + + /** + * Deserialize JSON bytes into a target type. + * @param content JSON as bytes + * @param type target class + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(byte[] content, Class type) throws IOException; + + /** + * Deserialize JSON string into a parameterized target type. + * @param content JSON as String + * @param type parameterized type reference + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(String content, TypeRef type) throws IOException; + + /** + * Deserialize JSON bytes into a parameterized target type. + * @param content JSON as bytes + * @param type parameterized type reference + * @return deserialized instance + * @param generic type + * @throws IOException on parse errors + */ + T readValue(byte[] content, TypeRef type) throws IOException; + + /** + * Convert a value to a given type, useful for mapping nested JSON structures. + * @param fromValue source value + * @param type target class + * @return converted value + * @param generic type + */ + T convertValue(Object fromValue, Class type); + + /** + * Convert a value to a given parameterized type. + * @param fromValue source value + * @param type target type reference + * @return converted value + * @param generic type + */ + T convertValue(Object fromValue, TypeRef type); + + /** + * Serialize an object to JSON string. + * @param value object to serialize + * @return JSON as String + * @throws IOException on serialization errors + */ + String writeValueAsString(Object value) throws IOException; + + /** + * Serialize an object to JSON bytes. + * @param value object to serialize + * @return JSON as bytes + * @throws IOException on serialization errors + */ + byte[] writeValueAsBytes(Object value) throws IOException; + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/json/McpJsonMapperSupplier.java b/mcp-core/src/main/java/io/modelcontextprotocol/json/McpJsonMapperSupplier.java new file mode 100644 index 000000000..619f96040 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/json/McpJsonMapperSupplier.java @@ -0,0 +1,14 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.util.function.Supplier; + +/** + * Strategy interface for resolving a {@link McpJsonMapper}. + */ +public interface McpJsonMapperSupplier extends Supplier { + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/json/TypeRef.java b/mcp-core/src/main/java/io/modelcontextprotocol/json/TypeRef.java new file mode 100644 index 000000000..725513c66 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/json/TypeRef.java @@ -0,0 +1,44 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; + +/** + * Captures generic type information at runtime for parameterized JSON (de)serialization. + * Usage: TypeRef> ref = new TypeRef<>(){}; + */ +public abstract class TypeRef { + + private final Type type; + + /** + * Constructs a new TypeRef instance, capturing the generic type information of the + * subclass. This constructor should be called from an anonymous subclass to capture + * the actual type arguments. For example:

+	 * TypeRef<List<Foo>> ref = new TypeRef<>(){};
+	 * 
+ * @throws IllegalStateException if TypeRef is not subclassed with actual type + * information + */ + protected TypeRef() { + Type superClass = getClass().getGenericSuperclass(); + if (superClass instanceof Class) { + throw new IllegalStateException("TypeRef constructed without actual type information"); + } + this.type = ((ParameterizedType) superClass).getActualTypeArguments()[0]; + } + + /** + * Returns the captured type information. + * @return the Type representing the actual type argument captured by this TypeRef + * instance + */ + public Type getType() { + return type; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidator.java new file mode 100644 index 000000000..09fe604f4 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidator.java @@ -0,0 +1,44 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ +package io.modelcontextprotocol.json.schema; + +import java.util.Map; + +/** + * Interface for validating structured content against a JSON schema. This interface + * defines a method to validate structured content based on the provided output schema. + * + * @author Christian Tzolov + */ +public interface JsonSchemaValidator { + + /** + * Represents the result of a validation operation. + * + * @param valid Indicates whether the validation was successful. + * @param errorMessage An error message if the validation failed, otherwise null. + * @param jsonStructuredOutput The text structured content in JSON format if the + * validation was successful, otherwise null. + */ + record ValidationResponse(boolean valid, String errorMessage, String jsonStructuredOutput) { + + public static ValidationResponse asValid(String jsonStructuredOutput) { + return new ValidationResponse(true, null, jsonStructuredOutput); + } + + public static ValidationResponse asInvalid(String message) { + return new ValidationResponse(false, message, null); + } + } + + /** + * Validates the structured content against the provided JSON schema. + * @param schema The JSON schema to validate against. + * @param structuredContent The structured content to validate. + * @return A ValidationResponse indicating whether the validation was successful or + * not. + */ + ValidationResponse validate(Map schema, Object structuredContent); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorSupplier.java b/mcp-core/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorSupplier.java new file mode 100644 index 000000000..6f69169a0 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorSupplier.java @@ -0,0 +1,19 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema; + +import java.util.function.Supplier; + +/** + * A supplier interface that provides a {@link JsonSchemaValidator} instance. + * Implementations of this interface are expected to return a new or cached instance of + * {@link JsonSchemaValidator} when {@link #get()} is invoked. + * + * @see JsonSchemaValidator + * @see Supplier + */ +public interface JsonSchemaValidatorSupplier extends Supplier { + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java similarity index 91% rename from mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java index 2df3514b6..660a15e6a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import org.slf4j.Logger; @@ -31,7 +32,9 @@ public Mono handleRequest(McpTransportContext transpo McpSchema.JSONRPCRequest request) { McpStatelessRequestHandler requestHandler = this.requestHandlers.get(request.method()); if (requestHandler == null) { - return Mono.error(new McpError("Missing handler for request type: " + request.method())); + return Mono.error(McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND) + .message("Missing handler for request type: " + request.method()) + .build()); } return requestHandler.handle(transportContext, request.params()) .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java similarity index 67% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index a51c2e36c..32256987a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -5,7 +5,6 @@ package io.modelcontextprotocol.server; import java.time.Duration; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -15,34 +14,37 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.BiFunction; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; import io.modelcontextprotocol.spec.DefaultMcpStreamableServerSessionFactory; -import io.modelcontextprotocol.spec.McpServerTransportProviderBase; -import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.spec.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; -import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProviderBase; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import static io.modelcontextprotocol.spec.McpError.RESOURCE_NOT_FOUND; + /** * The Model Context Protocol (MCP) server implementation that provides asynchronous * communication using Project Reactor's Mono and Flux types. @@ -91,7 +93,7 @@ public class McpAsyncServer { private final McpServerTransportProviderBase mcpTransportProvider; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final JsonSchemaValidator jsonSchemaValidator; @@ -103,10 +105,10 @@ public class McpAsyncServer { private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); - private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + private final ConcurrentHashMap resourceTemplates = new ConcurrentHashMap<>(); + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); // FIXME: this field is deprecated and should be remvoed together with the @@ -117,26 +119,26 @@ public class McpAsyncServer { private List protocolVersions; - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); /** * Create a new McpAsyncServer with the given transport provider and capabilities. * @param mcpTransportProvider The transport layer implementation for MCP * communication. * @param features The MCP server supported features. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization */ - McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpAsyncServer(McpServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, McpServerFeatures.Async features, Duration requestTimeout, McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { this.mcpTransportProvider = mcpTransportProvider; - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); this.serverCapabilities = features.serverCapabilities().mutate().logging().build(); this.instructions = features.instructions(); this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); + this.resourceTemplates.putAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; @@ -151,17 +153,17 @@ public class McpAsyncServer { requestTimeout, transport, this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); } - McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, McpServerFeatures.Async features, Duration requestTimeout, McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { this.mcpTransportProvider = mcpTransportProvider; - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); this.serverCapabilities = features.serverCapabilities().mutate().logging().build(); this.instructions = features.instructions(); this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); + this.resourceTemplates.putAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; @@ -319,25 +321,24 @@ private McpNotificationHandler asyncRootsListChangedNotificationHandler( */ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { if (toolSpecification == null) { - return Mono.error(new McpError("Tool specification must not be null")); + return Mono.error(new IllegalArgumentException("Tool specification must not be null")); } if (toolSpecification.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); + return Mono.error(new IllegalArgumentException("Tool must not be null")); } - if (toolSpecification.call() == null && toolSpecification.callHandler() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); + if (toolSpecification.callHandler() == null) { + return Mono.error(new IllegalArgumentException("Tool call handler must not be null")); } if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); } var wrappedToolSpecification = withStructuredOutputHandling(this.jsonSchemaValidator, toolSpecification); return Mono.defer(() -> { - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { - return Mono.error( - new McpError("Tool with name '" + wrappedToolSpecification.tool().name() + "' already exists")); + // Remove tools with duplicate tool names first + if (this.tools.removeIf(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { + logger.warn("Replace existing Tool with name '{}'", wrappedToolSpecification.tool().name()); } this.tools.add(wrappedToolSpecification); @@ -376,6 +377,11 @@ public Mono apply(McpAsyncServerExchange exchange, McpSchema.Cal return this.delegateCallToolResult.apply(exchange, request).map(result -> { + if (Boolean.TRUE.equals(result.isError())) { + // If the tool call resulted in an error, skip further validation + return result; + } + if (outputSchema == null) { if (result.structuredContent() != null) { logger.warn( @@ -391,11 +397,12 @@ public Mono apply(McpAsyncServerExchange exchange, McpSchema.Cal // results that conform to this schema. // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema if (result.structuredContent() == null) { - logger.warn( - "Response missing structured content which is expected when calling tool with non-empty outputSchema"); - return new CallToolResult( - "Response missing structured content which is expected when calling tool with non-empty outputSchema", - true); + String content = "Response missing structured content which is expected when calling tool with non-empty outputSchema"; + logger.warn(content); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(content))) + .isError(true) + .build(); } // Validate the result against the output schema @@ -403,7 +410,10 @@ public Mono apply(McpAsyncServerExchange exchange, McpSchema.Cal if (!validation.valid()) { logger.warn("Tool call result validation failed: {}", validation.errorMessage()); - return new CallToolResult(validation.errorMessage(), true); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.errorMessage()))) + .isError(true) + .build(); } if (Utils.isEmpty(result.content())) { @@ -413,8 +423,11 @@ public Mono apply(McpAsyncServerExchange exchange, McpSchema.Cal // TextContent block.) // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content - return new CallToolResult(List.of(new McpSchema.TextContent(validation.jsonStructuredOutput())), - result.isError(), result.structuredContent()); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.jsonStructuredOutput()))) + .isError(result.isError()) + .structuredContent(result.structuredContent()) + .build(); } return result; @@ -453,6 +466,14 @@ private static McpServerFeatures.AsyncToolSpecification withStructuredOutputHand .build(); } + /** + * List all registered tools. + * @return A Flux stream of all registered tools + */ + public Flux listTools() { + return Flux.fromIterable(this.tools).map(McpServerFeatures.AsyncToolSpecification::tool); + } + /** * Remove a tool handler at runtime. * @param toolName The name of the tool handler to remove @@ -460,23 +481,25 @@ private static McpServerFeatures.AsyncToolSpecification withStructuredOutputHand */ public Mono removeTool(String toolName) { if (toolName == null) { - return Mono.error(new McpError("Tool name must not be null")); + return Mono.error(new IllegalArgumentException("Tool name must not be null")); } if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); } return Mono.defer(() -> { - boolean removed = this.tools - .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); - if (removed) { + if (this.tools.removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName))) { + logger.debug("Removed tool handler: {}", toolName); if (this.serverCapabilities.tools().listChanged()) { return notifyToolsListChanged(); } - return Mono.empty(); } - return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + else { + logger.warn("Ignore as a Tool with name '{}' not found", toolName); + } + + return Mono.empty(); }); } @@ -498,8 +521,8 @@ private McpRequestHandler toolsListRequestHandler() { private McpRequestHandler toolsCallRequestHandler() { return (exchange, params) -> { - McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, - new TypeReference() { + McpSchema.CallToolRequest callToolRequest = jsonMapper.convertValue(params, + new TypeRef() { }); Optional toolSpecification = this.tools.stream() @@ -507,11 +530,13 @@ private McpRequestHandler toolsCallRequestHandler() { .findAny(); if (toolSpecification.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + return Mono.error(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Unknown tool: invalid_tool_name") + .data("Tool not found: " + callToolRequest.name()) + .build()); } - return toolSpecification.map(tool -> Mono.defer(() -> tool.callHandler().apply(exchange, callToolRequest))) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + return toolSpecification.get().callHandler().apply(exchange, callToolRequest); }; } @@ -526,19 +551,22 @@ private McpRequestHandler toolsCallRequestHandler() { */ public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { if (resourceSpecification == null || resourceSpecification.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); + return Mono.error(new IllegalArgumentException("Resource must not be null")); } if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow adding resources")); } return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { - return Mono.error(new McpError( - "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); + var previous = this.resources.put(resourceSpecification.resource().uri(), resourceSpecification); + if (previous != null) { + logger.warn("Replace existing Resource with URI '{}'", resourceSpecification.resource().uri()); + } + else { + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); } - logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); if (this.serverCapabilities.resources().listChanged()) { return notifyResourcesListChanged(); } @@ -546,6 +574,14 @@ public Mono addResource(McpServerFeatures.AsyncResourceSpecification resou }); } + /** + * List all registered resources. + * @return A Flux stream of all registered resources + */ + public Flux listResources() { + return Flux.fromIterable(this.resources.values()).map(McpServerFeatures.AsyncResourceSpecification::resource); + } + /** * Remove a resource handler at runtime. * @param resourceUri The URI of the resource handler to remove @@ -553,10 +589,11 @@ public Mono addResource(McpServerFeatures.AsyncResourceSpecification resou */ public Mono removeResource(String resourceUri) { if (resourceUri == null) { - return Mono.error(new McpError("Resource URI must not be null")); + return Mono.error(new IllegalArgumentException("Resource URI must not be null")); } if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow removing resources")); } return Mono.defer(() -> { @@ -568,7 +605,74 @@ public Mono removeResource(String resourceUri) { } return Mono.empty(); } - return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + else { + logger.warn("Ignore as a Resource with URI '{}' not found", resourceUri); + } + return Mono.empty(); + }); + } + + /** + * Add a new resource template at runtime. + * @param resourceTemplateSpecification The resource template to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResourceTemplate( + McpServerFeatures.AsyncResourceTemplateSpecification resourceTemplateSpecification) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow adding resource templates")); + } + + return Mono.defer(() -> { + var previous = this.resourceTemplates.put(resourceTemplateSpecification.resourceTemplate().uriTemplate(), + resourceTemplateSpecification); + if (previous != null) { + logger.warn("Replace existing Resource Template with URI '{}'", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + else { + logger.debug("Added resource template handler: {}", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); + } + + /** + * List all registered resource templates. + * @return A Flux stream of all registered resource templates + */ + public Flux listResourceTemplates() { + return Flux.fromIterable(this.resourceTemplates.values()) + .map(McpServerFeatures.AsyncResourceTemplateSpecification::resourceTemplate); + } + + /** + * Remove a resource template at runtime. + * @param uriTemplate The URI template of the resource template to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResourceTemplate(String uriTemplate) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow removing resource templates")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceTemplateSpecification removed = this.resourceTemplates.remove(uriTemplate); + if (removed != null) { + logger.debug("Removed resource template: {}", uriTemplate); + } + else { + logger.warn("Ignore as a Resource Template with URI '{}' not found", uriTemplate); + } + return Mono.empty(); }); } @@ -600,46 +704,50 @@ private McpRequestHandler resourcesListRequestHan } private McpRequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono - .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); - + return (exchange, params) -> { + var resourceList = this.resourceTemplates.values() + .stream() + .map(McpServerFeatures.AsyncResourceTemplateSpecification::resourceTemplate) + .toList(); + return Mono.just(new McpSchema.ListResourceTemplatesResult(resourceList, null)); + }; } - private List getResourceTemplates() { - var list = new ArrayList<>(this.resourceTemplates); - List resourceTemplates = this.resources.keySet() - .stream() - .filter(uri -> uri.contains("{")) - .map(uri -> { - var resource = this.resources.get(uri).resource(); - var template = new McpSchema.ResourceTemplate(resource.uri(), resource.name(), resource.title(), - resource.description(), resource.mimeType(), resource.annotations()); - return template; - }) - .toList(); + private McpRequestHandler resourcesReadRequestHandler() { + return (ex, params) -> { + McpSchema.ReadResourceRequest resourceRequest = jsonMapper.convertValue(params, new TypeRef<>() { + }); - list.addAll(resourceTemplates); + var resourceUri = resourceRequest.uri(); - return list; + // First try to find a static resource specification + // Static resources have exact URIs + return this.findResourceSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ex, resourceRequest)) + .orElseGet(() -> { + // If not found, try to find a dynamic resource specification + // Dynamic resources have URI templates + return this.findResourceTemplateSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ex, resourceRequest)) + .orElseGet(() -> Mono.error(RESOURCE_NOT_FOUND.apply(resourceUri))); + }); + }; } - private McpRequestHandler resourcesReadRequestHandler() { - return (exchange, params) -> { - McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, - new TypeReference() { - }); - var resourceUri = resourceRequest.uri(); - - McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() - .stream() - .filter(resourceSpecification -> this.uriTemplateManagerFactory - .create(resourceSpecification.resource().uri()) - .matches(resourceUri)) - .findFirst() - .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + private Optional findResourceSpecification(String uri) { + var result = this.resources.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resource().uri()).matches(uri)) + .findFirst(); + return result; + } - return Mono.defer(() -> specification.readHandler().apply(exchange, resourceRequest)); - }; + private Optional findResourceTemplateSpecification( + String uri) { + return this.resourceTemplates.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resourceTemplate().uriTemplate()).matches(uri)) + .findFirst(); } // --------------------------------------- @@ -653,32 +761,36 @@ private McpRequestHandler resourcesReadRequestHand */ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { if (promptSpecification == null) { - return Mono.error(new McpError("Prompt specification must not be null")); + return Mono.error(new IllegalArgumentException("Prompt specification must not be null")); } if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); } return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification specification = this.prompts - .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); - if (specification != null) { - return Mono.error( - new McpError("Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); + var previous = this.prompts.put(promptSpecification.prompt().name(), promptSpecification); + if (previous != null) { + logger.warn("Replace existing Prompt with name '{}'", promptSpecification.prompt().name()); + } + else { + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); } - - logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); - - // Servers that declared the listChanged capability SHOULD send a - // notification, - // when the list of available prompts changes if (this.serverCapabilities.prompts().listChanged()) { - return notifyPromptsListChanged(); + return this.notifyPromptsListChanged(); } + return Mono.empty(); }); } + /** + * List all registered prompts. + * @return A Flux stream of all registered prompts + */ + public Flux listPrompts() { + return Flux.fromIterable(this.prompts.values()).map(McpServerFeatures.AsyncPromptSpecification::prompt); + } + /** * Remove a prompt handler at runtime. * @param promptName The name of the prompt handler to remove @@ -686,10 +798,10 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpe */ public Mono removePrompt(String promptName) { if (promptName == null) { - return Mono.error(new McpError("Prompt name must not be null")); + return Mono.error(new IllegalArgumentException("Prompt name must not be null")); } if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); } return Mono.defer(() -> { @@ -697,14 +809,15 @@ public Mono removePrompt(String promptName) { if (removed != null) { logger.debug("Removed prompt handler: {}", promptName); - // Servers that declared the listChanged capability SHOULD send a - // notification, when the list of available prompts changes if (this.serverCapabilities.prompts().listChanged()) { return this.notifyPromptsListChanged(); } return Mono.empty(); } - return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + else { + logger.warn("Ignore as a Prompt with name '{}' not found", promptName); + } + return Mono.empty(); }); } @@ -734,14 +847,18 @@ private McpRequestHandler promptsListRequestHandler private McpRequestHandler promptsGetRequestHandler() { return (exchange, params) -> { - McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, - new TypeReference() { + McpSchema.GetPromptRequest promptRequest = jsonMapper.convertValue(params, + new TypeRef() { }); // Implement prompt retrieval logic here McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Invalid prompt name") + .data("Prompt not found: " + promptRequest.name()) + .build()); } return Mono.defer(() -> specification.promptHandler().apply(exchange, promptRequest)); @@ -752,39 +869,12 @@ private McpRequestHandler promptsGetRequestHandler() // Logging Management // --------------------------------------- - /** - * This implementation would, incorrectly, broadcast the logging message to all - * connected clients, using a single minLoggingLevel for all of them. Similar to the - * sampling and roots, the logging level should be set per client session and use the - * ServerExchange to send the logging message to the right client. - * @param loggingMessageNotification The logging message to send - * @return A Mono that completes when the notification has been sent - * @deprecated Use - * {@link McpAsyncServerExchange#loggingNotification(LoggingMessageNotification)} - * instead. - */ - @Deprecated - public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - - if (loggingMessageNotification == null) { - return Mono.error(new McpError("Logging message must not be null")); - } - - if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { - return Mono.empty(); - } - - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, - loggingMessageNotification); - } - private McpRequestHandler setLoggerRequestHandler() { return (exchange, params) -> { return Mono.defer(() -> { - SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, - new TypeReference() { - }); + SetLevelRequest newMinLoggingLevel = jsonMapper.convertValue(params, new TypeRef() { + }); exchange.setMinLoggingLevel(newMinLoggingLevel.level()); @@ -797,27 +887,38 @@ private McpRequestHandler setLoggerRequestHandler() { }; } + private static final Mono EMPTY_COMPLETION_RESULT = Mono + .just(new McpSchema.CompleteResult(new CompleteCompletion(List.of(), 0, false))); + private McpRequestHandler completionCompleteRequestHandler() { return (exchange, params) -> { + McpSchema.CompleteRequest request = parseCompletionParams(params); if (request.ref() == null) { - return Mono.error(new McpError("ref must not be null")); + return Mono.error( + McpError.builder(ErrorCodes.INVALID_PARAMS).message("Completion ref must not be null").build()); } if (request.ref().type() == null) { - return Mono.error(new McpError("type must not be null")); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Completion ref type must not be null") + .build()); } String type = request.ref().type(); String argumentName = request.argument().name(); - // check if the referenced resource exists - if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { + // Check if valid a Prompt exists for this completion request + if (type.equals(PromptReference.TYPE) + && request.ref() instanceof McpSchema.PromptReference promptReference) { + McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); if (promptSpec == null) { - return Mono.error(new McpError("Prompt not found: " + promptReference.name())); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Prompt not found: " + promptReference.name()) + .build()); } if (!promptSpec.prompt() .arguments() @@ -826,27 +927,67 @@ private McpRequestHandler completionCompleteRequestHan .findFirst() .isPresent()) { - return Mono.error(new McpError("Argument not found: " + argumentName)); + logger.warn("Argument not found: {} in prompt: {}", argumentName, promptReference.name()); + + return EMPTY_COMPLETION_RESULT; } } - if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources.get(resourceReference.uri()); - if (resourceSpec == null) { - return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); - } - if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) - .getVariableNames() - .contains(argumentName)) { - return Mono.error(new McpError("Argument not found: " + argumentName)); + // Check if valid Resource or ResourceTemplate exists for this completion + // request + if (type.equals(ResourceReference.TYPE) + && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + + var uriTemplateManager = uriTemplateManagerFactory.create(resourceReference.uri()); + + if (!uriTemplateManager.isUriTemplate(resourceReference.uri())) { + // Attempting to autocomplete a fixed resource URI is not an error in + // the spec (but probably should be). + return EMPTY_COMPLETION_RESULT; } + McpServerFeatures.AsyncResourceSpecification resourceSpec = this + .findResourceSpecification(resourceReference.uri()) + .orElse(null); + + if (resourceSpec != null) { + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource: " + resourceReference.uri()) + .build()); + } + } + else { + var templateSpec = this.findResourceTemplateSpecification(resourceReference.uri()).orElse(null); + if (templateSpec != null) { + + if (!uriTemplateManagerFactory.create(templateSpec.resourceTemplate().uriTemplate()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource template: " + + resourceReference.uri()) + .build()); + } + } + else { + return Mono.error(RESOURCE_NOT_FOUND.apply(resourceReference.uri())); + } + } } + // Handle the completion request using the registered handler + // for the given reference. McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); if (specification == null) { - return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("AsyncCompletionSpecification not found: " + request.ref()) + .build()); } return Mono.defer(() -> specification.completionHandler().apply(exchange, request)); @@ -877,9 +1018,9 @@ private McpSchema.CompleteRequest parseCompletionParams(Object object) { String refType = (String) refMap.get("type"); McpSchema.CompleteReference ref = switch (refType) { - case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), + case PromptReference.TYPE -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), refMap.get("title") != null ? (String) refMap.get("title") : null); - case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + case ResourceReference.TYPE -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); default -> throw new IllegalArgumentException("Invalid ref type: " + refType); }; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java similarity index 79% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 61d60bacc..40a76045b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -4,10 +4,11 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import java.util.ArrayList; import java.util.Collections; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpLoggableSession; import io.modelcontextprotocol.spec.McpSchema; @@ -36,40 +37,18 @@ public class McpAsyncServerExchange { private final McpTransportContext transportContext; - private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef CREATE_MESSAGE_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef LIST_ROOTS_RESULT_TYPE_REF = new TypeRef<>() { }; - private static final TypeReference ELICITATION_RESULT_TYPE_REF = new TypeReference<>() { + private static final TypeRef ELICITATION_RESULT_TYPE_REF = new TypeRef<>() { }; - public static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + public static final TypeRef OBJECT_TYPE_REF = new TypeRef<>() { }; - /** - * Create a new asynchronous exchange with the client. - * @param session The server session representing a 1-1 interaction. - * @param clientCapabilities The client capabilities that define the supported - * features and functionality. - * @param clientInfo The client implementation information. - * @deprecated Use - * {@link #McpAsyncServerExchange(String, McpLoggableSession, McpSchema.ClientCapabilities, McpSchema.Implementation, McpTransportContext)} - */ - @Deprecated - public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, - McpSchema.Implementation clientInfo) { - this.sessionId = null; - if (!(session instanceof McpLoggableSession)) { - throw new IllegalArgumentException("Expecting session to be a McpLoggableSession instance"); - } - this.session = (McpLoggableSession) session; - this.clientCapabilities = clientCapabilities; - this.clientInfo = clientInfo; - this.transportContext = McpTransportContext.EMPTY; - } - /** * Create a new asynchronous exchange with the client. * @param session The server session representing a 1-1 interaction. @@ -141,10 +120,11 @@ public String sessionId() { */ public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { if (this.clientCapabilities == null) { - return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + return Mono + .error(new IllegalStateException("Client must be initialized. Call the initialize method first!")); } if (this.clientCapabilities.sampling() == null) { - return Mono.error(new McpError("Client must be configured with sampling capabilities")); + return Mono.error(new IllegalStateException("Client must be configured with sampling capabilities")); } return this.session.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, CREATE_MESSAGE_RESULT_TYPE_REF); @@ -166,10 +146,11 @@ public Mono createMessage(McpSchema.CreateMessage */ public Mono createElicitation(McpSchema.ElicitRequest elicitRequest) { if (this.clientCapabilities == null) { - return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + return Mono + .error(new IllegalStateException("Client must be initialized. Call the initialize method first!")); } if (this.clientCapabilities.elicitation() == null) { - return Mono.error(new McpError("Client must be configured with elicitation capabilities")); + return Mono.error(new IllegalStateException("Client must be configured with elicitation capabilities")); } return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, ELICITATION_RESULT_TYPE_REF); @@ -214,7 +195,7 @@ public Mono listRoots(String cursor) { public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { if (loggingMessageNotification == null) { - return Mono.error(new McpError("Logging message must not be null")); + return Mono.error(new IllegalStateException("Logging message must not be null")); } return Mono.defer(() -> { @@ -233,7 +214,7 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN */ public Mono progressNotification(McpSchema.ProgressNotification progressNotification) { if (progressNotification == null) { - return Mono.error(new McpError("Progress notification must not be null")); + return Mono.error(new IllegalStateException("Progress notification must not be null")); } return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_PROGRESS, progressNotification); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java similarity index 87% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java index f5dfffffb..360eb607d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -13,19 +13,19 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.spec.DefaultJsonSchemaValidator; -import io.modelcontextprotocol.spec.JsonSchemaValidator; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.McpStatelessServerTransport; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.ToolNameValidator; import reactor.core.publisher.Mono; /** @@ -66,17 +66,23 @@ * Example of creating a basic synchronous server:
{@code
  * McpServer.sync(transportProvider)
  *     .serverInfo("my-server", "1.0.0")
- *     .tool(new Tool("calculator", "Performs calculations", schema),
- *           (exchange, args) -> new CallToolResult("Result: " + calculate(args)))
+ *     .toolCall(Tool.builder().name("calculator").title("Performs calculations").inputSchema(schema).build(),
+ *           (exchange, request) -> CallToolResult.builder()
+ *                   .content(List.of(new McpSchema.TextContent("Result: " + calculate(request.arguments()))))
+ *                   .isError(false)
+ *                   .build())
  *     .build();
  * }
* * Example of creating a basic asynchronous server:
{@code
  * McpServer.async(transportProvider)
  *     .serverInfo("my-server", "1.0.0")
- *     .tool(new Tool("calculator", "Performs calculations", schema),
- *           (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
- *               .map(result -> new CallToolResult("Result: " + result)))
+ *     .toolCall(Tool.builder().name("calculator").title("Performs calculations").inputSchema(schema).build(),
+ *           (exchange, request) -> Mono.fromSupplier(() -> calculate(request.arguments()))
+ *               .map(result -> CallToolResult.builder()
+ *                   .content(List.of(new McpSchema.TextContent("Result: " + result)))
+ *                   .isError(false)
+ *                   .build()))
  *     .build();
  * }
* @@ -90,12 +96,18 @@ * McpServerFeatures.AsyncToolSpecification.builder() * .tool(calculatorTool) * .callTool((exchange, args) -> Mono.fromSupplier(() -> calculate(args.arguments())) - * .map(result -> new CallToolResult("Result: " + result)))) + * .map(result -> CallToolResult.builder() + * .content(List.of(new McpSchema.TextContent("Result: " + result))) + * .isError(false) + * .build())) *. .build(), * McpServerFeatures.AsyncToolSpecification.builder() * .tool((weatherTool) * .callTool((exchange, args) -> Mono.fromSupplier(() -> getWeather(args.arguments())) - * .map(result -> new CallToolResult("Weather: " + result)))) + * .map(result -> CallToolResult.builder() + * .content(List.of(new McpSchema.TextContent("Weather: " + result))) + * .isError(false) + * .build())) * .build() * ) * // Register resources @@ -133,7 +145,7 @@ */ public interface McpServer { - McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", "1.0.0"); + McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("Java SDK MCP Server", "0.15.0"); /** * Starts building a synchronous MCP server that provides blocking operations. @@ -226,11 +238,12 @@ public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, this.instructions); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); - return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + + var jsonSchemaValidator = (this.jsonSchemaValidator != null) ? this.jsonSchemaValidator + : McpJsonDefaults.getSchemaValidator(); + + return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); } } @@ -253,11 +266,10 @@ public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, this.instructions); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); - return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + : McpJsonDefaults.getSchemaValidator(); + return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); } } @@ -267,9 +279,9 @@ public McpAsyncServer build() { */ abstract class AsyncSpecification> { - McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); - ObjectMapper objectMapper; + McpJsonMapper jsonMapper; McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -279,6 +291,8 @@ abstract class AsyncSpecification> { String instructions; + boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault(); + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -297,7 +311,14 @@ abstract class AsyncSpecification> { */ final Map resources = new HashMap<>(); - final List resourceTemplates = new ArrayList<>(); + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -388,6 +409,18 @@ public AsyncSpecification instructions(String instructions) { return this; } + /** + * Sets whether to use strict tool name validation for this server. When set, this + * takes priority over the system property + * {@code io.modelcontextprotocol.strictToolNameValidation}. + * @param strict true to throw exception on invalid names and false to warn only + * @return This builder instance for method chaining + */ + public AsyncSpecification strictToolNameValidation(boolean strict) { + this.strictToolNameValidation = strict; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -408,42 +441,6 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCap return this; } - /** - * Adds a single tool with its implementation handler to the server. This is a - * convenience method for registering individual tools without creating a - * {@link McpServerFeatures.AsyncToolSpecification} explicitly. - * - *

- * Example usage:

{@code
-		 * .tool(
-		 *     new Tool("calculator", "Performs calculations", schema),
-		 *     (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
-		 *         .map(result -> new CallToolResult("Result: " + result))
-		 * )
-		 * }
- * @param tool The tool definition including name, description, and schema. Must - * not be null. - * @param handler The function that implements the tool's logic. Must not be null. - * The function's first argument is an {@link McpAsyncServerExchange} upon which - * the server can interact with the connected client. The second argument is the - * map of arguments passed to the tool. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if tool or handler is null - * @deprecated Use {@link #toolCall(McpSchema.Tool, BiFunction)} instead for tool - * calls that require a request object. - */ - @Deprecated - public AsyncSpecification tool(McpSchema.Tool tool, - BiFunction, Mono> handler) { - Assert.notNull(tool, "Tool must not be null"); - Assert.notNull(handler, "Handler must not be null"); - assertNoDuplicateTool(tool.name()); - - this.tools.add(new McpServerFeatures.AsyncToolSpecification(tool, handler)); - - return this; - } - /** * Adds a single tool with its implementation handler to the server. This is a * convenience method for registering individual tools without creating a @@ -462,6 +459,7 @@ public AsyncSpecification toolCall(McpSchema.Tool tool, Assert.notNull(tool, "Tool must not be null"); Assert.notNull(callHandler, "Handler must not be null"); + validateToolName(tool.name()); assertNoDuplicateTool(tool.name()); this.tools @@ -484,6 +482,7 @@ public AsyncSpecification tools(List tools(McpServerFeatures.AsyncToolSpecification... t Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (McpServerFeatures.AsyncToolSpecification tool : toolSpecifications) { + validateToolName(tool.tool().name()); assertNoDuplicateTool(tool.tool().name()); this.tools.add(tool); } return this; } + private void validateToolName(String toolName) { + ToolNameValidator.validate(toolName, this.strictToolNameValidation); + } + private void assertNoDuplicateTool(String toolName) { if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); @@ -584,40 +588,38 @@ public AsyncSpecification resources(McpServerFeatures.AsyncResourceSpecificat } /** - * Sets the resource templates that define patterns for dynamic resource access. - * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

- * Example usage:

{@code
-		 * .resourceTemplates(
-		 *     new ResourceTemplate("file://{path}", "Access files by path"),
-		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
-		 * )
-		 * }
- * @param resourceTemplates List of resource templates. If null, clears existing - * templates. + * Registers multiple resource templates with their specifications using a List. + * This method is useful when resource templates need to be added in bulk from a + * collection. + * @param resourceTemplates Map of template URI to specification. Must not be + * null. * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceTemplates is null. - * @see #resourceTemplates(ResourceTemplate...) */ - public AsyncSpecification resourceTemplates(List resourceTemplates) { + public AsyncSpecification resourceTemplates( + List resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - this.resourceTemplates.addAll(resourceTemplates); + for (var resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } return this; } /** - * Sets the resource templates using varargs for convenience. This is an - * alternative to {@link #resourceTemplates(List)}. - * @param resourceTemplates The resource templates to set. + * Registers multiple resource templates with their specifications using a List. + * This method is useful when resource templates need to be added in bulk from a + * collection. + * @param resourceTemplates List of template URI to specification. Must not be + * null. * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(List) */ - public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public AsyncSpecification resourceTemplates( + McpServerFeatures.AsyncResourceTemplateSpecification... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); + for (McpServerFeatures.AsyncResourceTemplateSpecification resource : resourceTemplates) { + this.resourceTemplates.put(resource.resourceTemplate().uriTemplate(), resource); } return this; } @@ -764,14 +766,14 @@ public AsyncSpecification rootsChangeHandlers( } /** - * Sets the object mapper to use for serializing and deserializing JSON messages. - * @param objectMapper the instance to use. Must not be null. + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. * @return This builder instance for method chaining. - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public AsyncSpecification objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public AsyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -812,13 +814,11 @@ public McpSyncServer build() { this.rootsChangeHandlers, this.instructions); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); - - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + var asyncServer = new McpAsyncServer(transportProvider, + jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, asyncFeatures, requestTimeout, + uriTemplateManagerFactory, + jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator()); return new McpSyncServer(asyncServer, this.immediateExecution); } @@ -845,13 +845,11 @@ public McpSyncServer build() { this.rootsChangeHandlers, this.instructions); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); - - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, + : McpJsonDefaults.getSchemaValidator(); + var asyncServer = new McpAsyncServer(transportProvider, + jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, asyncFeatures, this.requestTimeout, this.uriTemplateManagerFactory, jsonSchemaValidator); - return new McpSyncServer(asyncServer, this.immediateExecution); } @@ -862,9 +860,9 @@ public McpSyncServer build() { */ abstract class SyncSpecification> { - McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); - ObjectMapper objectMapper; + McpJsonMapper jsonMapper; McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -872,6 +870,8 @@ abstract class SyncSpecification> { String instructions; + boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault(); + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -890,7 +890,14 @@ abstract class SyncSpecification> { */ final Map resources = new HashMap<>(); - final List resourceTemplates = new ArrayList<>(); + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); JsonSchemaValidator jsonSchemaValidator; @@ -985,6 +992,18 @@ public SyncSpecification instructions(String instructions) { return this; } + /** + * Sets whether to use strict tool name validation for this server. When set, this + * takes priority over the system property + * {@code io.modelcontextprotocol.strictToolNameValidation}. + * @param strict true to throw exception on invalid names, false to warn only + * @return This builder instance for method chaining + */ + public SyncSpecification strictToolNameValidation(boolean strict) { + this.strictToolNameValidation = strict; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -1005,41 +1024,6 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapa return this; } - /** - * Adds a single tool with its implementation handler to the server. This is a - * convenience method for registering individual tools without creating a - * {@link McpServerFeatures.SyncToolSpecification} explicitly. - * - *

- * Example usage:

{@code
-		 * .tool(
-		 *     new Tool("calculator", "Performs calculations", schema),
-		 *     (exchange, args) -> new CallToolResult("Result: " + calculate(args))
-		 * )
-		 * }
- * @param tool The tool definition including name, description, and schema. Must - * not be null. - * @param handler The function that implements the tool's logic. Must not be null. - * The function's first argument is an {@link McpSyncServerExchange} upon which - * the server can interact with the connected client. The second argument is the - * list of arguments passed to the tool. - * @return This builder instance for method chaining - * @throws IllegalArgumentException if tool or handler is null - * @deprecated Use {@link #toolCall(McpSchema.Tool, BiFunction)} instead for tool - * calls that require a request object. - */ - @Deprecated - public SyncSpecification tool(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> handler) { - Assert.notNull(tool, "Tool must not be null"); - Assert.notNull(handler, "Handler must not be null"); - assertNoDuplicateTool(tool.name()); - - this.tools.add(new McpServerFeatures.SyncToolSpecification(tool, handler)); - - return this; - } - /** * Adds a single tool with its implementation handler to the server. This is a * convenience method for registering individual tools without creating a @@ -1057,9 +1041,10 @@ public SyncSpecification toolCall(McpSchema.Tool tool, BiFunction handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); + validateToolName(tool.name()); assertNoDuplicateTool(tool.name()); - this.tools.add(new McpServerFeatures.SyncToolSpecification(tool, null, handler)); + this.tools.add(new McpServerFeatures.SyncToolSpecification(tool, handler)); return this; } @@ -1079,7 +1064,8 @@ public SyncSpecification tools(List for (var tool : toolSpecifications) { String toolName = tool.tool().name(); - assertNoDuplicateTool(toolName); // Check against existing tools + validateToolName(toolName); + assertNoDuplicateTool(toolName); this.tools.add(tool); } @@ -1107,12 +1093,17 @@ public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... too Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (McpServerFeatures.SyncToolSpecification tool : toolSpecifications) { + validateToolName(tool.tool().name()); assertNoDuplicateTool(tool.tool().name()); this.tools.add(tool); } return this; } + private void validateToolName(String toolName) { + ToolNameValidator.validate(toolName, this.strictToolNameValidation); + } + private void assertNoDuplicateTool(String toolName) { if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); @@ -1182,23 +1173,17 @@ public SyncSpecification resources(McpServerFeatures.SyncResourceSpecificatio /** * Sets the resource templates that define patterns for dynamic resource access. * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

- * Example usage:

{@code
-		 * .resourceTemplates(
-		 *     new ResourceTemplate("file://{path}", "Access files by path"),
-		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
-		 * )
-		 * }
- * @param resourceTemplates List of resource templates. If null, clears existing - * templates. + * @param resourceTemplates List of resource template specifications. Must not be + * null. * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceTemplates is null. - * @see #resourceTemplates(ResourceTemplate...) */ - public SyncSpecification resourceTemplates(List resourceTemplates) { + public SyncSpecification resourceTemplates( + List resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - this.resourceTemplates.addAll(resourceTemplates); + for (McpServerFeatures.SyncResourceTemplateSpecification resource : resourceTemplates) { + this.resourceTemplates.put(resource.resourceTemplate().uriTemplate(), resource); + } return this; } @@ -1210,10 +1195,11 @@ public SyncSpecification resourceTemplates(List resourceTem * @throws IllegalArgumentException if resourceTemplates is null * @see #resourceTemplates(List) */ - public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public SyncSpecification resourceTemplates( + McpServerFeatures.SyncResourceTemplateSpecification... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); + for (McpServerFeatures.SyncResourceTemplateSpecification resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); } return this; } @@ -1362,14 +1348,14 @@ public SyncSpecification rootsChangeHandlers( } /** - * Sets the object mapper to use for serializing and deserializing JSON messages. - * @param objectMapper the instance to use. Must not be null. + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. * @return This builder instance for method chaining. - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public SyncSpecification objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public SyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -1401,9 +1387,9 @@ class StatelessAsyncSpecification { private final McpStatelessServerTransport transport; - McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); - ObjectMapper objectMapper; + McpJsonMapper jsonMapper; McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -1413,6 +1399,8 @@ class StatelessAsyncSpecification { String instructions; + boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault(); + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -1431,7 +1419,14 @@ class StatelessAsyncSpecification { */ final Map resources = new HashMap<>(); - final List resourceTemplates = new ArrayList<>(); + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -1523,6 +1518,18 @@ public StatelessAsyncSpecification instructions(String instructions) { return this; } + /** + * Sets whether to use strict tool name validation for this server. When set, this + * takes priority over the system property + * {@code io.modelcontextprotocol.strictToolNameValidation}. + * @param strict true to throw exception on invalid names, false to warn only + * @return This builder instance for method chaining + */ + public StatelessAsyncSpecification strictToolNameValidation(boolean strict) { + this.strictToolNameValidation = strict; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -1561,6 +1568,7 @@ public StatelessAsyncSpecification toolCall(McpSchema.Tool tool, Assert.notNull(tool, "Tool must not be null"); Assert.notNull(callHandler, "Handler must not be null"); + validateToolName(tool.name()); assertNoDuplicateTool(tool.name()); this.tools.add(new McpStatelessServerFeatures.AsyncToolSpecification(tool, callHandler)); @@ -1583,6 +1591,7 @@ public StatelessAsyncSpecification tools( Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (var tool : toolSpecifications) { + validateToolName(tool.tool().name()); assertNoDuplicateTool(tool.tool().name()); this.tools.add(tool); } @@ -1611,12 +1620,17 @@ public StatelessAsyncSpecification tools( Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (var tool : toolSpecifications) { + validateToolName(tool.tool().name()); assertNoDuplicateTool(tool.tool().name()); this.tools.add(tool); } return this; } + private void validateToolName(String toolName) { + ToolNameValidator.validate(toolName, this.strictToolNameValidation); + } + private void assertNoDuplicateTool(String toolName) { if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); @@ -1687,23 +1701,17 @@ public StatelessAsyncSpecification resources( /** * Sets the resource templates that define patterns for dynamic resource access. * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

- * Example usage:

{@code
-		 * .resourceTemplates(
-		 *     new ResourceTemplate("file://{path}", "Access files by path"),
-		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
-		 * )
-		 * }
* @param resourceTemplates List of resource templates. If null, clears existing * templates. * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceTemplates is null. - * @see #resourceTemplates(ResourceTemplate...) */ - public StatelessAsyncSpecification resourceTemplates(List resourceTemplates) { + public StatelessAsyncSpecification resourceTemplates( + List resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - this.resourceTemplates.addAll(resourceTemplates); + for (var resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } return this; } @@ -1715,10 +1723,11 @@ public StatelessAsyncSpecification resourceTemplates(List reso * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(List) */ - public StatelessAsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public StatelessAsyncSpecification resourceTemplates( + McpStatelessServerFeatures.AsyncResourceTemplateSpecification... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); + for (McpStatelessServerFeatures.AsyncResourceTemplateSpecification resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); } return this; } @@ -1820,14 +1829,14 @@ public StatelessAsyncSpecification completions( } /** - * Sets the object mapper to use for serializing and deserializing JSON messages. - * @param objectMapper the instance to use. Must not be null. + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. * @return This builder instance for method chaining. - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public StatelessAsyncSpecification objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public StatelessAsyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -1848,11 +1857,9 @@ public StatelessAsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonS public McpStatelessAsyncServer build() { var features = new McpStatelessServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); - return new McpStatelessAsyncServer(this.transport, mapper, features, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + return new McpStatelessAsyncServer(transport, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, + features, requestTimeout, uriTemplateManagerFactory, + jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator()); } } @@ -1863,9 +1870,9 @@ class StatelessSyncSpecification { boolean immediateExecution = false; - McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); - ObjectMapper objectMapper; + McpJsonMapper jsonMapper; McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -1875,6 +1882,8 @@ class StatelessSyncSpecification { String instructions; + boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault(); + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -1893,7 +1902,14 @@ class StatelessSyncSpecification { */ final Map resources = new HashMap<>(); - final List resourceTemplates = new ArrayList<>(); + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resource templates to clients. Resource templates allow servers to + * define parameterized URIs that clients can use to access dynamic resources. + * Each resource template includes variables that clients can fill in to form + * concrete resource URIs. + */ + final Map resourceTemplates = new HashMap<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -1985,6 +2001,18 @@ public StatelessSyncSpecification instructions(String instructions) { return this; } + /** + * Sets whether to use strict tool name validation for this server. When set, this + * takes priority over the system property + * {@code io.modelcontextprotocol.strictToolNameValidation}. + * @param strict true to throw exception on invalid names, false to warn only + * @return This builder instance for method chaining + */ + public StatelessSyncSpecification strictToolNameValidation(boolean strict) { + this.strictToolNameValidation = strict; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -2023,6 +2051,7 @@ public StatelessSyncSpecification toolCall(McpSchema.Tool tool, Assert.notNull(tool, "Tool must not be null"); Assert.notNull(callHandler, "Handler must not be null"); + validateToolName(tool.name()); assertNoDuplicateTool(tool.name()); this.tools.add(new McpStatelessServerFeatures.SyncToolSpecification(tool, callHandler)); @@ -2045,6 +2074,7 @@ public StatelessSyncSpecification tools( Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (var tool : toolSpecifications) { + validateToolName(tool.tool().name()); assertNoDuplicateTool(tool.tool().name()); this.tools.add(tool); } @@ -2073,12 +2103,17 @@ public StatelessSyncSpecification tools( Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (var tool : toolSpecifications) { + validateToolName(tool.tool().name()); assertNoDuplicateTool(tool.tool().name()); this.tools.add(tool); } return this; } + private void validateToolName(String toolName) { + ToolNameValidator.validate(toolName, this.strictToolNameValidation); + } + private void assertNoDuplicateTool(String toolName) { if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); @@ -2149,23 +2184,17 @@ public StatelessSyncSpecification resources( /** * Sets the resource templates that define patterns for dynamic resource access. * Templates use URI patterns with placeholders that can be filled at runtime. - * - *

- * Example usage:

{@code
-		 * .resourceTemplates(
-		 *     new ResourceTemplate("file://{path}", "Access files by path"),
-		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
-		 * )
-		 * }
- * @param resourceTemplates List of resource templates. If null, clears existing - * templates. + * @param resourceTemplatesSpec List of resource templates. If null, clears + * existing templates. * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceTemplates is null. - * @see #resourceTemplates(ResourceTemplate...) */ - public StatelessSyncSpecification resourceTemplates(List resourceTemplates) { - Assert.notNull(resourceTemplates, "Resource templates must not be null"); - this.resourceTemplates.addAll(resourceTemplates); + public StatelessSyncSpecification resourceTemplates( + List resourceTemplatesSpec) { + Assert.notNull(resourceTemplatesSpec, "Resource templates must not be null"); + for (var resourceTemplate : resourceTemplatesSpec) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); + } return this; } @@ -2177,10 +2206,11 @@ public StatelessSyncSpecification resourceTemplates(List resou * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(List) */ - public StatelessSyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public StatelessSyncSpecification resourceTemplates( + McpStatelessServerFeatures.SyncResourceTemplateSpecification... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); - for (ResourceTemplate resourceTemplate : resourceTemplates) { - this.resourceTemplates.add(resourceTemplate); + for (McpStatelessServerFeatures.SyncResourceTemplateSpecification resourceTemplate : resourceTemplates) { + this.resourceTemplates.put(resourceTemplate.resourceTemplate().uriTemplate(), resourceTemplate); } return this; } @@ -2282,14 +2312,14 @@ public StatelessSyncSpecification completions( } /** - * Sets the object mapper to use for serializing and deserializing JSON messages. - * @param objectMapper the instance to use. Must not be null. + * Sets the JsonMapper to use for serializing and deserializing JSON messages. + * @param jsonMapper the mapper to use. Must not be null. * @return This builder instance for method chaining. - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public StatelessSyncSpecification objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public StatelessSyncSpecification jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -2324,31 +2354,13 @@ public StatelessSyncSpecification immediateExecution(boolean immediateExecution) } public McpStatelessSyncServer build() { - /* - * McpServerFeatures.Sync syncFeatures = new - * McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - * this.tools, this.resources, this.resourceTemplates, this.prompts, - * this.completions, this.rootsChangeHandlers, this.instructions); - * McpServerFeatures.Async asyncFeatures = - * McpServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); - * var mapper = this.objectMapper != null ? this.objectMapper : new - * ObjectMapper(); var jsonSchemaValidator = this.jsonSchemaValidator != null - * ? this.jsonSchemaValidator : new DefaultJsonSchemaValidator(mapper); - * - * var asyncServer = new McpAsyncServer(this.transportProvider, mapper, - * asyncFeatures, this.requestTimeout, this.uriTemplateManagerFactory, - * jsonSchemaValidator); - * - * return new McpSyncServer(asyncServer, this.immediateExecution); - */ var syncFeatures = new McpStatelessServerFeatures.Sync(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); var asyncFeatures = McpStatelessServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); - var asyncServer = new McpStatelessAsyncServer(this.transport, mapper, asyncFeatures, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + var asyncServer = new McpStatelessAsyncServer(transport, + jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, asyncFeatures, requestTimeout, + uriTemplateManagerFactory, + this.jsonSchemaValidator != null ? this.jsonSchemaValidator : McpJsonDefaults.getSchemaValidator()); return new McpStatelessSyncServer(asyncServer, this.immediateExecution); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java similarity index 81% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 12edfb341..a0cbae0f2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -41,7 +41,7 @@ public class McpServerFeatures { */ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, List, Mono>> rootsChangeConsumers, @@ -53,7 +53,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s * @param serverCapabilities The server capabilities * @param tools The list of tool specifications * @param resources The map of resource specifications - * @param resourceTemplates The list of resource templates + * @param resourceTemplates The map of resource templates * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when * the roots list changes @@ -61,7 +61,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s */ Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, List, Mono>> rootsChangeConsumers, @@ -84,7 +84,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.tools = (tools != null) ? tools : List.of(); this.resources = (resources != null) ? resources : Map.of(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); this.prompts = (prompts != null) ? prompts : Map.of(); this.completions = (completions != null) ? completions : Map.of(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); @@ -112,6 +112,11 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { resources.put(key, AsyncResourceSpecification.fromSync(resource, immediateExecution)); }); + Map resourceTemplates = new HashMap<>(); + syncSpec.resourceTemplates().forEach((key, resource) -> { + resourceTemplates.put(key, AsyncResourceTemplateSpecification.fromSync(resource, immediateExecution)); + }); + Map prompts = new HashMap<>(); syncSpec.prompts().forEach((key, prompt) -> { prompts.put(key, AsyncPromptSpecification.fromSync(prompt, immediateExecution)); @@ -130,8 +135,8 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { .subscribeOn(Schedulers.boundedElastic())); } - return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, - syncSpec.resourceTemplates(), prompts, completions, rootChangeConsumers, syncSpec.instructions()); + return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, resourceTemplates, + prompts, completions, rootChangeConsumers, syncSpec.instructions()); } } @@ -151,7 +156,7 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, List>> rootsChangeConsumers, String instructions) { @@ -171,7 +176,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, List>> rootsChangeConsumers, @@ -194,7 +199,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se this.tools = (tools != null) ? tools : new ArrayList<>(); this.resources = (resources != null) ? resources : new HashMap<>(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); this.prompts = (prompts != null) ? prompts : new HashMap<>(); this.completions = (completions != null) ? completions : new HashMap<>(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : new ArrayList<>(); @@ -218,19 +223,8 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * map of tool arguments. */ public record AsyncToolSpecification(McpSchema.Tool tool, - @Deprecated BiFunction, Mono> call, BiFunction> callHandler) { - /** - * @deprecated Use {@link AsyncToolSpecification(McpSchema.Tool, null, - * BiFunction)} instead. - **/ - @Deprecated - public AsyncToolSpecification(McpSchema.Tool tool, - BiFunction, Mono> call) { - this(tool, call, (exchange, toolReq) -> call.apply(exchange, toolReq.arguments())); - } - static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec) { return fromSync(syncToolSpec, false); } @@ -242,13 +236,6 @@ static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec, boole return null; } - BiFunction, Mono> deprecatedCall = (syncToolSpec - .call() != null) ? (exchange, map) -> { - var toolResult = Mono - .fromCallable(() -> syncToolSpec.call().apply(new McpSyncServerExchange(exchange), map)); - return immediate ? toolResult : toolResult.subscribeOn(Schedulers.boundedElastic()); - } : null; - BiFunction> callHandler = ( exchange, req) -> { var toolResult = Mono @@ -256,7 +243,7 @@ static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec, boole return immediate ? toolResult : toolResult.subscribeOn(Schedulers.boundedElastic()); }; - return new AsyncToolSpecification(syncToolSpec.tool(), deprecatedCall, callHandler); + return new AsyncToolSpecification(syncToolSpec.tool(), callHandler); } /** @@ -299,7 +286,7 @@ public AsyncToolSpecification build() { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(callHandler, "Call handler function must not be null"); - return new AsyncToolSpecification(tool, null, callHandler); + return new AsyncToolSpecification(tool, callHandler); } } @@ -329,7 +316,13 @@ public static Builder builder() { * *
{@code
 	 * new McpServerFeatures.AsyncResourceSpecification(
-	 * 		new Resource("docs", "Documentation files", "text/markdown"),
+	 *     Resource.builder()
+	 *         .uri("docs")
+	 *         .name("Documentation files")
+	 * 		   .title("Documentation files")
+	 * 		   .mimeType("text/markdown")
+	 * 		   .description("Markdown documentation files")
+	 * 		   .build(),
 	 * 		(exchange, request) -> Mono.fromSupplier(() -> readFile(request.getPath()))
 	 * 				.map(ReadResourceResult::new))
 	 * }
@@ -356,6 +349,47 @@ static AsyncResourceSpecification fromSync(SyncResourceSpecification resource, b } } + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
    + *
  • Parameterized resource definitions + *
  • Dynamic content generation + *
  • Consistent resource formatting + *
  • Contextual data injection + *
+ * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpSyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record AsyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction> readHandler) { + + static AsyncResourceTemplateSpecification fromSync(SyncResourceTemplateSpecification resource, + boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceTemplateSpecification(resource.resourceTemplate(), (exchange, req) -> { + var resourceResult = Mono + .fromCallable(() -> resource.readHandler().apply(new McpSyncServerExchange(exchange), req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + /** * Specification of a prompt template with its asynchronous handler function. Prompts * provide structured templates for AI model interactions, supporting: @@ -453,40 +487,34 @@ static AsyncCompletionSpecification fromSync(SyncCompletionSpecification complet * *
{@code
 	 * McpServerFeatures.SyncToolSpecification.builder()
-	 * 		.tool(new Tool(
-	 * 				"calculator",
-	 * 				"Performs mathematical calculations",
-	 * 				new JsonSchemaObject()
+	 * 		.tool(Tool.builder()
+	 * 				.name("calculator")
+	 * 				.title("Performs mathematical calculations")
+	 * 				.inputSchema(new JsonSchemaObject()
 	 * 						.required("expression")
-	 * 						.property("expression", JsonSchemaType.STRING)))
+	 * 						.property("expression", JsonSchemaType.STRING))
+	 * 				.build()
 	 * 		.toolHandler((exchange, req) -> {
 	 * 			String expr = (String) req.arguments().get("expression");
-	 * 			return new CallToolResult("Result: " + evaluate(expr));
+	 * 			return CallToolResult.builder()
+	 *                   .content(List.of(new McpSchema.TextContent("Result: " + evaluate(expr))))
+	 *                   .isError(false)
+	 *                   .build();
 	 * 		}))
 	 *      .build();
 	 * }
* * @param tool The tool definition including name, description, and parameter schema - * @param call (Deprected) The function that implements the tool's logic, receiving - * arguments and returning results. The function's first argument is an - * {@link McpSyncServerExchange} upon which the server can interact with the connected * @param callHandler The function that implements the tool's logic, receiving a * {@link McpSyncServerExchange} and a * {@link io.modelcontextprotocol.spec.McpSchema.CallToolRequest} and returning * results. The function's first argument is an {@link McpSyncServerExchange} upon - * which the server can interact with the client. The second arguments is a map of - * arguments passed to the tool. + * which the server can interact with the client. The second argument is a request + * object containing the arguments passed to the tool. */ public record SyncToolSpecification(McpSchema.Tool tool, - @Deprecated BiFunction, McpSchema.CallToolResult> call, BiFunction callHandler) { - @Deprecated - public SyncToolSpecification(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> call) { - this(tool, call, (exchange, toolReq) -> call.apply(exchange, toolReq.arguments())); - } - /** * Builder for creating SyncToolSpecification instances. */ @@ -527,7 +555,7 @@ public SyncToolSpecification build() { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(callHandler, "CallTool function must not be null"); - return new SyncToolSpecification(tool, null, callHandler); + return new SyncToolSpecification(tool, callHandler); } } @@ -557,7 +585,13 @@ public static Builder builder() { * *
{@code
 	 * new McpServerFeatures.SyncResourceSpecification(
-	 * 		new Resource("docs", "Documentation files", "text/markdown"),
+	 *     Resource.builder()
+	 *         .uri("docs")
+	 *         .name("Documentation files")
+	 * 		   .title("Documentation files")
+	 * 		   .mimeType("text/markdown")
+	 * 		   .description("Markdown documentation files")
+	 * 		   .build(),
 	 * 		(exchange, request) -> {
 	 * 			String content = readFile(request.getPath());
 	 * 			return new ReadResourceResult(content);
@@ -574,6 +608,34 @@ public record SyncResourceSpecification(McpSchema.Resource resource,
 			BiFunction readHandler) {
 	}
 
+	/**
+	 * Specification of a resource template with its synchronous handler function.
+	 * Resource templates allow servers to expose parameterized resources using URI
+	 * templates:  URI
+	 * templates.. Arguments may be auto-completed through the
+	 * completion API.
+	 *
+	 * Templates support:
+	 * 
    + *
  • Parameterized resource definitions + *
  • Dynamic content generation + *
  • Consistent resource formatting + *
  • Contextual data injection + *
+ * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpSyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record SyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction readHandler) { + } + /** * Specification of a prompt template with its synchronous handler function. Prompts * provide structured templates for AI model interactions, supporting: diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java similarity index 60% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java index 41e0e9588..c7a1fd0d7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java @@ -4,21 +4,27 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.JsonSchemaValidator; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.server.McpStatelessServerFeatures.AsyncResourceTemplateSpecification; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpStatelessServerTransport; import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import java.time.Duration; @@ -31,6 +37,8 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.BiFunction; +import static io.modelcontextprotocol.spec.McpError.RESOURCE_NOT_FOUND; + /** * A stateless MCP server implementation for use with Streamable HTTP transport types. It * allows simple horizontal scalability since it does not maintain a session and does not @@ -45,7 +53,7 @@ public class McpStatelessAsyncServer { private final McpStatelessServerTransport mcpTransportProvider; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final McpSchema.ServerCapabilities serverCapabilities; @@ -55,7 +63,7 @@ public class McpStatelessAsyncServer { private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); - private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + private final ConcurrentHashMap resourceTemplates = new ConcurrentHashMap<>(); private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); @@ -65,21 +73,21 @@ public class McpStatelessAsyncServer { private List protocolVersions; - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); private final JsonSchemaValidator jsonSchemaValidator; - McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, ObjectMapper objectMapper, + McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, McpJsonMapper jsonMapper, McpStatelessServerFeatures.Async features, Duration requestTimeout, McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { this.mcpTransportProvider = mcpTransport; - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); this.serverCapabilities = features.serverCapabilities(); this.instructions = features.instructions(); this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); + this.resourceTemplates.putAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; @@ -129,7 +137,7 @@ public class McpStatelessAsyncServer { // --------------------------------------- private McpStatelessRequestHandler asyncInitializeRequestHandler() { return (ctx, req) -> Mono.defer(() -> { - McpSchema.InitializeRequest initializeRequest = this.objectMapper.convertValue(req, + McpSchema.InitializeRequest initializeRequest = this.jsonMapper.convertValue(req, McpSchema.InitializeRequest.class); logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", @@ -248,6 +256,11 @@ public Mono apply(McpTransportContext transportContext, McpSchem return this.delegateHandler.apply(transportContext, request).map(result -> { + if (Boolean.TRUE.equals(result.isError())) { + // If the tool call resulted in an error, skip further validation + return result; + } + if (outputSchema == null) { if (result.structuredContent() != null) { logger.warn( @@ -263,11 +276,12 @@ public Mono apply(McpTransportContext transportContext, McpSchem // results that conform to this schema. // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema if (result.structuredContent() == null) { - logger.warn( - "Response missing structured content which is expected when calling tool with non-empty outputSchema"); - return new CallToolResult( - "Response missing structured content which is expected when calling tool with non-empty outputSchema", - true); + String content = "Response missing structured content which is expected when calling tool with non-empty outputSchema"; + logger.warn(content); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(content))) + .isError(true) + .build(); } // Validate the result against the output schema @@ -275,7 +289,10 @@ public Mono apply(McpTransportContext transportContext, McpSchem if (!validation.valid()) { logger.warn("Tool call result validation failed: {}", validation.errorMessage()); - return new CallToolResult(validation.errorMessage(), true); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.errorMessage()))) + .isError(true) + .build(); } if (Utils.isEmpty(result.content())) { @@ -285,8 +302,11 @@ public Mono apply(McpTransportContext transportContext, McpSchem // TextContent block.) // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content - return new CallToolResult(List.of(new McpSchema.TextContent(validation.jsonStructuredOutput())), - result.isError(), result.structuredContent()); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.jsonStructuredOutput()))) + .isError(result.isError()) + .structuredContent(result.structuredContent()) + .build(); } return result; @@ -302,25 +322,24 @@ public Mono apply(McpTransportContext transportContext, McpSchem */ public Mono addTool(McpStatelessServerFeatures.AsyncToolSpecification toolSpecification) { if (toolSpecification == null) { - return Mono.error(new McpError("Tool specification must not be null")); + return Mono.error(new IllegalArgumentException("Tool specification must not be null")); } if (toolSpecification.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); + return Mono.error(new IllegalArgumentException("Tool must not be null")); } if (toolSpecification.callHandler() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); + return Mono.error(new IllegalArgumentException("Tool call handler must not be null")); } if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); } var wrappedToolSpecification = withStructuredOutputHandling(this.jsonSchemaValidator, toolSpecification); return Mono.defer(() -> { - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { - return Mono.error( - new McpError("Tool with name '" + wrappedToolSpecification.tool().name() + "' already exists")); + // Remove tools with duplicate tool names first + if (this.tools.removeIf(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { + logger.warn("Replace existing Tool with name '{}'", wrappedToolSpecification.tool().name()); } this.tools.add(wrappedToolSpecification); @@ -330,6 +349,14 @@ public Mono addTool(McpStatelessServerFeatures.AsyncToolSpecification tool }); } + /** + * List all registered tools. + * @return A Flux stream of all registered tools + */ + public Flux listTools() { + return Flux.fromIterable(this.tools).map(McpStatelessServerFeatures.AsyncToolSpecification::tool); + } + /** * Remove a tool handler at runtime. * @param toolName The name of the tool handler to remove @@ -337,20 +364,22 @@ public Mono addTool(McpStatelessServerFeatures.AsyncToolSpecification tool */ public Mono removeTool(String toolName) { if (toolName == null) { - return Mono.error(new McpError("Tool name must not be null")); + return Mono.error(new IllegalArgumentException("Tool name must not be null")); } if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); } return Mono.defer(() -> { - boolean removed = this.tools - .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); - if (removed) { + if (this.tools.removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName))) { + logger.debug("Removed tool handler: {}", toolName); - return Mono.empty(); } - return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + else { + logger.warn("Ignore as a Tool with name '{}' not found", toolName); + } + + return Mono.empty(); }); } @@ -365,8 +394,8 @@ private McpStatelessRequestHandler toolsListRequestHa private McpStatelessRequestHandler toolsCallRequestHandler() { return (ctx, params) -> { - McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, - new TypeReference() { + McpSchema.CallToolRequest callToolRequest = jsonMapper.convertValue(params, + new TypeRef() { }); Optional toolSpecification = this.tools.stream() @@ -374,11 +403,13 @@ private McpStatelessRequestHandler toolsCallRequestHandler() { .findAny(); if (toolSpecification.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + return Mono.error(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Unknown tool: invalid_tool_name") + .data("Tool not found: " + callToolRequest.name()) + .build()); } - return toolSpecification.map(tool -> tool.callHandler().apply(ctx, callToolRequest)) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + return toolSpecification.get().callHandler().apply(ctx, callToolRequest); }; } @@ -393,23 +424,34 @@ private McpStatelessRequestHandler toolsCallRequestHandler() { */ public Mono addResource(McpStatelessServerFeatures.AsyncResourceSpecification resourceSpecification) { if (resourceSpecification == null || resourceSpecification.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); + return Mono.error(new IllegalArgumentException("Resource must not be null")); } if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with resource capabilities")); } return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { - return Mono.error(new McpError( - "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); + var previous = this.resources.put(resourceSpecification.resource().uri(), resourceSpecification); + if (previous != null) { + logger.warn("Replace existing Resource with URI '{}'", resourceSpecification.resource().uri()); + } + else { + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); } - logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); return Mono.empty(); }); } + /** + * List all registered resources. + * @return A Flux stream of all registered resources + */ + public Flux listResources() { + return Flux.fromIterable(this.resources.values()) + .map(McpStatelessServerFeatures.AsyncResourceSpecification::resource); + } + /** * Remove a resource handler at runtime. * @param resourceUri The URI of the resource handler to remove @@ -417,19 +459,83 @@ public Mono addResource(McpStatelessServerFeatures.AsyncResourceSpecificat */ public Mono removeResource(String resourceUri) { if (resourceUri == null) { - return Mono.error(new McpError("Resource URI must not be null")); + return Mono.error(new IllegalArgumentException("Resource URI must not be null")); } if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with resource capabilities")); } return Mono.defer(() -> { McpStatelessServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); if (removed != null) { logger.debug("Removed resource handler: {}", resourceUri); - return Mono.empty(); } - return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + else { + logger.warn("Resource with URI '{}' not found", resourceUri); + } + return Mono.empty(); + }); + } + + /** + * Add a new resource template at runtime. + * @param resourceTemplateSpecification The resource template to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResourceTemplate( + McpStatelessServerFeatures.AsyncResourceTemplateSpecification resourceTemplateSpecification) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow adding resource templates")); + } + + return Mono.defer(() -> { + var previous = this.resourceTemplates.put(resourceTemplateSpecification.resourceTemplate().uriTemplate(), + resourceTemplateSpecification); + if (previous != null) { + logger.warn("Replace existing Resource Template with URI '{}'", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + else { + logger.debug("Added resource template handler: {}", + resourceTemplateSpecification.resourceTemplate().uriTemplate()); + } + return Mono.empty(); + }); + } + + /** + * List all registered resource templates. + * @return A Flux stream of all registered resource templates + */ + public Flux listResourceTemplates() { + return Flux.fromIterable(this.resourceTemplates.values()) + .map(McpStatelessServerFeatures.AsyncResourceTemplateSpecification::resourceTemplate); + } + + /** + * Remove a resource template at runtime. + * @param uriTemplate The URI template of the resource template to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResourceTemplate(String uriTemplate) { + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new IllegalStateException( + "Server must be configured with resource capabilities to allow removing resource templates")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncResourceTemplateSpecification removed = this.resourceTemplates + .remove(uriTemplate); + if (removed != null) { + logger.debug("Removed resource template: {}", uriTemplate); + } + else { + logger.warn("Ignore as a Resource Template with URI '{}' not found", uriTemplate); + } + return Mono.empty(); }); } @@ -444,47 +550,52 @@ private McpStatelessRequestHandler resourcesListR } private McpStatelessRequestHandler resourceTemplateListRequestHandler() { - return (ctx, params) -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); - - } - - private List getResourceTemplates() { - var list = new ArrayList<>(this.resourceTemplates); - List resourceTemplates = this.resources.keySet() - .stream() - .filter(uri -> uri.contains("{")) - .map(uri -> { - var resource = this.resources.get(uri).resource(); - var template = new ResourceTemplate(resource.uri(), resource.name(), resource.title(), - resource.description(), resource.mimeType(), resource.annotations()); - return template; - }) - .toList(); - - list.addAll(resourceTemplates); - - return list; + return (exchange, params) -> { + var resourceList = this.resourceTemplates.values() + .stream() + .map(AsyncResourceTemplateSpecification::resourceTemplate) + .toList(); + return Mono.just(new McpSchema.ListResourceTemplatesResult(resourceList, null)); + }; } private McpStatelessRequestHandler resourcesReadRequestHandler() { return (ctx, params) -> { - McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, - new TypeReference() { - }); + McpSchema.ReadResourceRequest resourceRequest = jsonMapper.convertValue(params, new TypeRef<>() { + }); var resourceUri = resourceRequest.uri(); - McpStatelessServerFeatures.AsyncResourceSpecification specification = this.resources.values() - .stream() - .filter(resourceSpecification -> this.uriTemplateManagerFactory - .create(resourceSpecification.resource().uri()) - .matches(resourceUri)) - .findFirst() - .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + // First try to find a static resource specification + // Static resources have exact URIs + return this.findResourceSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ctx, resourceRequest)) + .orElseGet(() -> { + // If not found, try to find a dynamic resource specification + // Dynamic resources have URI templates + return this.findResourceTemplateSpecification(resourceUri) + .map(spec -> spec.readHandler().apply(ctx, resourceRequest)) + .orElseGet(() -> Mono.error(RESOURCE_NOT_FOUND.apply(resourceUri))); + }); - return specification.readHandler().apply(ctx, resourceRequest); }; } + private Optional findResourceSpecification(String uri) { + var result = this.resources.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resource().uri()).matches(uri)) + .findFirst(); + return result; + } + + private Optional findResourceTemplateSpecification( + String uri) { + return this.resourceTemplates.values() + .stream() + .filter(spec -> this.uriTemplateManagerFactory.create(spec.resourceTemplate().uriTemplate()).matches(uri)) + .findFirst(); + } + // --------------------------------------- // Prompt Management // --------------------------------------- @@ -496,26 +607,34 @@ private McpStatelessRequestHandler resourcesReadRe */ public Mono addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification promptSpecification) { if (promptSpecification == null) { - return Mono.error(new McpError("Prompt specification must not be null")); + return Mono.error(new IllegalArgumentException("Prompt specification must not be null")); } if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); } return Mono.defer(() -> { - McpStatelessServerFeatures.AsyncPromptSpecification specification = this.prompts - .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); - if (specification != null) { - return Mono.error( - new McpError("Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); + var previous = this.prompts.put(promptSpecification.prompt().name(), promptSpecification); + if (previous != null) { + logger.warn("Replace existing Prompt with name '{}'", promptSpecification.prompt().name()); + } + else { + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); } - - logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); return Mono.empty(); }); } + /** + * List all registered prompts. + * @return A Flux stream of all registered prompts + */ + public Flux listPrompts() { + return Flux.fromIterable(this.prompts.values()) + .map(McpStatelessServerFeatures.AsyncPromptSpecification::prompt); + } + /** * Remove a prompt handler at runtime. * @param promptName The name of the prompt handler to remove @@ -523,10 +642,10 @@ public Mono addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification */ public Mono removePrompt(String promptName) { if (promptName == null) { - return Mono.error(new McpError("Prompt name must not be null")); + return Mono.error(new IllegalArgumentException("Prompt name must not be null")); } if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); + return Mono.error(new IllegalStateException("Server must be configured with prompt capabilities")); } return Mono.defer(() -> { @@ -536,7 +655,11 @@ public Mono removePrompt(String promptName) { logger.debug("Removed prompt handler: {}", promptName); return Mono.empty(); } - return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + else { + logger.warn("Ignore as a Prompt with name '{}' not found", promptName); + } + + return Mono.empty(); }); } @@ -558,67 +681,122 @@ private McpStatelessRequestHandler promptsListReque private McpStatelessRequestHandler promptsGetRequestHandler() { return (ctx, params) -> { - McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, - new TypeReference() { + McpSchema.GetPromptRequest promptRequest = jsonMapper.convertValue(params, + new TypeRef() { }); // Implement prompt retrieval logic here McpStatelessServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); if (specification == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Invalid prompt name") + .data("Prompt not found: " + promptRequest.name()) + .build()); } return specification.promptHandler().apply(ctx, promptRequest); }; } + private static final Mono EMPTY_COMPLETION_RESULT = Mono + .just(new McpSchema.CompleteResult(new CompleteCompletion(List.of(), 0, false))); + private McpStatelessRequestHandler completionCompleteRequestHandler() { return (ctx, params) -> { McpSchema.CompleteRequest request = parseCompletionParams(params); if (request.ref() == null) { - return Mono.error(new McpError("ref must not be null")); + return Mono.error( + McpError.builder(ErrorCodes.INVALID_PARAMS).message("Completion ref must not be null").build()); } if (request.ref().type() == null) { - return Mono.error(new McpError("type must not be null")); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Completion ref type must not be null") + .build()); } String type = request.ref().type(); String argumentName = request.argument().name(); - // check if the referenced resource exists - if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { + // Check if valid a Prompt exists for this completion request + if (type.equals(PromptReference.TYPE) + && request.ref() instanceof McpSchema.PromptReference promptReference) { + McpStatelessServerFeatures.AsyncPromptSpecification promptSpec = this.prompts .get(promptReference.name()); if (promptSpec == null) { - return Mono.error(new McpError("Prompt not found: " + promptReference.name())); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Prompt not found: " + promptReference.name()) + .build()); } - if (promptSpec.prompt().arguments().stream().noneMatch(arg -> arg.name().equals(argumentName))) { + if (!promptSpec.prompt() + .arguments() + .stream() + .filter(arg -> arg.name().equals(argumentName)) + .findFirst() + .isPresent()) { - return Mono.error(new McpError("Argument not found: " + argumentName)); + logger.warn("Argument not found: {} in prompt: {}", argumentName, promptReference.name()); + + return EMPTY_COMPLETION_RESULT; } } - if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpStatelessServerFeatures.AsyncResourceSpecification resourceSpec = this.resources - .get(resourceReference.uri()); - if (resourceSpec == null) { - return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); - } - if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) - .getVariableNames() - .contains(argumentName)) { - return Mono.error(new McpError("Argument not found: " + argumentName)); + // Check if valid Resource or ResourceTemplate exists for this completion + // request + if (type.equals(ResourceReference.TYPE) + && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + + var uriTemplateManager = uriTemplateManagerFactory.create(resourceReference.uri()); + + if (!uriTemplateManager.isUriTemplate(resourceReference.uri())) { + // Attempting to autocomplete a fixed resource URI is not an error in + // the spec (but probably should be). + return EMPTY_COMPLETION_RESULT; } + McpStatelessServerFeatures.AsyncResourceSpecification resourceSpec = this + .findResourceSpecification(resourceReference.uri()) + .orElse(null); + + if (resourceSpec != null) { + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource: " + resourceReference.uri()) + .build()); + } + } + else { + var templateSpec = this.findResourceTemplateSpecification(resourceReference.uri()).orElse(null); + if (templateSpec != null) { + + if (!uriTemplateManagerFactory.create(templateSpec.resourceTemplate().uriTemplate()) + .getVariableNames() + .contains(argumentName)) { + + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Argument not found: " + argumentName + " in resource template: " + + resourceReference.uri()) + .build()); + } + } + else { + return Mono.error(RESOURCE_NOT_FOUND.apply(resourceReference.uri())); + } + } } McpStatelessServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); if (specification == null) { - return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); + return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("AsyncCompletionSpecification not found: " + request.ref()) + .build()); } return specification.completionHandler().apply(ctx, request); @@ -647,9 +825,9 @@ private McpSchema.CompleteRequest parseCompletionParams(Object object) { String refType = (String) refMap.get("type"); McpSchema.CompleteReference ref = switch (refType) { - case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), + case PromptReference.TYPE -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), refMap.get("title") != null ? (String) refMap.get("title") : null); - case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + case ResourceReference.TYPE -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); default -> throw new IllegalArgumentException("Invalid ref type: " + refType); }; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java similarity index 91% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java index 6db79a62c..a2fabb283 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import reactor.core.publisher.Mono; /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java similarity index 91% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java index e5c9e7c09..37cd3c096 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import reactor.core.publisher.Mono; /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java similarity index 80% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java index 60c1dbb65..a15681ba5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java @@ -4,6 +4,13 @@ package io.modelcontextprotocol.server; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.util.Assert; @@ -11,12 +18,6 @@ import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.BiFunction; - /** * MCP stateless server features specification that a particular server can choose to * support. @@ -33,13 +34,14 @@ public class McpStatelessServerFeatures { * @param serverCapabilities The server capabilities * @param tools The list of tool specifications * @param resources The map of resource specifications - * @param resourceTemplates The list of resource templates + * @param resourceTemplates The map of resource templates * @param prompts The map of prompt specifications * @param instructions The server instructions text */ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, - Map resources, List resourceTemplates, + Map resources, + Map resourceTemplates, Map prompts, Map completions, String instructions) { @@ -50,13 +52,14 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s * @param serverCapabilities The server capabilities * @param tools The list of tool specifications * @param resources The map of resource specifications - * @param resourceTemplates The list of resource templates + * @param resourceTemplates The map of resource templates * @param prompts The map of prompt specifications * @param instructions The server instructions text */ Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, - Map resources, List resourceTemplates, + Map resources, + Map resourceTemplates, Map prompts, Map completions, String instructions) { @@ -75,7 +78,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.tools = (tools != null) ? tools : List.of(); this.resources = (resources != null) ? resources : Map.of(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); this.prompts = (prompts != null) ? prompts : Map.of(); this.completions = (completions != null) ? completions : Map.of(); this.instructions = instructions; @@ -102,6 +105,11 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { resources.put(key, AsyncResourceSpecification.fromSync(resource, immediateExecution)); }); + Map resourceTemplates = new HashMap<>(); + syncSpec.resourceTemplates().forEach((key, resource) -> { + resourceTemplates.put(key, AsyncResourceTemplateSpecification.fromSync(resource, immediateExecution)); + }); + Map prompts = new HashMap<>(); syncSpec.prompts().forEach((key, prompt) -> { prompts.put(key, AsyncPromptSpecification.fromSync(prompt, immediateExecution)); @@ -112,8 +120,8 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { completions.put(key, AsyncCompletionSpecification.fromSync(completion, immediateExecution)); }); - return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, - syncSpec.resourceTemplates(), prompts, completions, syncSpec.instructions()); + return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, resourceTemplates, + prompts, completions, syncSpec.instructions()); } } @@ -124,14 +132,14 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { * @param serverCapabilities The server capabilities * @param tools The list of tool specifications * @param resources The map of resource specifications - * @param resourceTemplates The list of resource templates + * @param resourceTemplates The map of resource templates * @param prompts The map of prompt specifications * @param instructions The server instructions text */ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, String instructions) { @@ -142,14 +150,14 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * @param serverCapabilities The server capabilities * @param tools The list of tool specifications * @param resources The map of resource specifications - * @param resourceTemplates The list of resource templates + * @param resourceTemplates The map of resource templates * @param prompts The map of prompt specifications * @param instructions The server instructions text */ Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, - List resourceTemplates, + Map resourceTemplates, Map prompts, Map completions, String instructions) { @@ -171,7 +179,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se this.tools = (tools != null) ? tools : new ArrayList<>(); this.resources = (resources != null) ? resources : new HashMap<>(); - this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); this.prompts = (prompts != null) ? prompts : new HashMap<>(); this.completions = (completions != null) ? completions : new HashMap<>(); this.instructions = instructions; @@ -295,6 +303,46 @@ static AsyncResourceSpecification fromSync(SyncResourceSpecification resource, b } } + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
    + *
  • Parameterized resource definitions + *
  • Dynamic content generation + *
  • Consistent resource formatting + *
  • Contextual data injection + *
+ * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpTransportContext} upon which the server can interact + * with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record AsyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction> readHandler) { + + static AsyncResourceTemplateSpecification fromSync(SyncResourceTemplateSpecification resource, + boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceTemplateSpecification(resource.resourceTemplate(), (ctx, req) -> { + var resourceResult = Mono.fromCallable(() -> resource.readHandler().apply(ctx, req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + /** * Specification of a prompt template with its asynchronous handler function. Prompts * provide structured templates for AI model interactions, supporting: @@ -445,6 +493,34 @@ public record SyncResourceSpecification(McpSchema.Resource resource, BiFunction readHandler) { } + /** + * Specification of a resource template with its synchronous handler function. + * Resource templates allow servers to expose parameterized resources using URI + * templates: URI + * templates.. Arguments may be auto-completed through the + * completion API. + * + * Templates support: + *
    + *
  • Parameterized resource definitions + *
  • Dynamic content generation + *
  • Consistent resource formatting + *
  • Contextual data injection + *
+ * + * @param resourceTemplate The resource template definition including name, + * description, and parameter schema + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpTransportContext} upon which the server can interact + * with the connected client. The second arguments is a + * {@link McpSchema.ReadResourceRequest}. {@link McpSchema.ResourceTemplate} + * {@link McpSchema.ReadResourceResult} + */ + public record SyncResourceTemplateSpecification(McpSchema.ResourceTemplate resourceTemplate, + BiFunction readHandler) { + } + /** * Specification of a prompt template with its synchronous handler function. Prompts * provide structured templates for AI model interactions, supporting: diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java similarity index 94% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java index 7c4e23cfc..cbae58bfd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import reactor.core.publisher.Mono; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java similarity index 71% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java index 0151a754b..6849eb8ed 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java @@ -74,6 +74,14 @@ public void addTool(McpStatelessServerFeatures.SyncToolSpecification toolSpecifi .block(); } + /** + * List all registered tools. + * @return A list of all registered tools + */ + public List listTools() { + return this.asyncServer.listTools().collectList().block(); + } + /** * Remove a tool handler at runtime. * @param toolName The name of the tool handler to remove @@ -93,6 +101,14 @@ public void addResource(McpStatelessServerFeatures.SyncResourceSpecification res .block(); } + /** + * List all registered resources. + * @return A list of all registered resources + */ + public List listResources() { + return this.asyncServer.listResources().collectList().block(); + } + /** * Remove a resource handler at runtime. * @param resourceUri The URI of the resource handler to remove @@ -101,6 +117,34 @@ public void removeResource(String resourceUri) { this.asyncServer.removeResource(resourceUri).block(); } + /** + * Add a new resource template. + * @param resourceTemplateSpecification The resource template specification to add + */ + public void addResourceTemplate( + McpStatelessServerFeatures.SyncResourceTemplateSpecification resourceTemplateSpecification) { + this.asyncServer + .addResourceTemplate(McpStatelessServerFeatures.AsyncResourceTemplateSpecification + .fromSync(resourceTemplateSpecification, this.immediateExecution)) + .block(); + } + + /** + * List all registered resource templates. + * @return A list of all registered resource templates + */ + public List listResourceTemplates() { + return this.asyncServer.listResourceTemplates().collectList().block(); + } + + /** + * Remove a resource template. + * @param uriTemplate The URI template of the resource template to remove + */ + public void removeResourceTemplate(String uriTemplate) { + this.asyncServer.removeResourceTemplate(uriTemplate).block(); + } + /** * Add a new prompt handler at runtime. * @param promptSpecification The prompt handler to add @@ -112,6 +156,14 @@ public void addPrompt(McpStatelessServerFeatures.SyncPromptSpecification promptS .block(); } + /** + * List all registered prompts. + * @return A list of all registered prompts + */ + public List listPrompts() { + return this.asyncServer.listPrompts().collectList().block(); + } + /** * Remove a prompt handler at runtime. * @param promptName The name of the prompt handler to remove diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java similarity index 76% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java index 5adda1a74..d33299d02 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.server; +import java.util.List; + import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.util.Assert; @@ -87,6 +89,14 @@ public void addTool(McpServerFeatures.SyncToolSpecification toolHandler) { .block(); } + /** + * List all registered tools. + * @return A list of all registered tools + */ + public List listTools() { + return this.asyncServer.listTools().collectList().block(); + } + /** * Remove a tool handler. * @param toolName The name of the tool handler to remove @@ -97,15 +107,23 @@ public void removeTool(String toolName) { /** * Add a new resource handler. - * @param resourceHandler The resource handler to add + * @param resourceSpecification The resource specification to add */ - public void addResource(McpServerFeatures.SyncResourceSpecification resourceHandler) { + public void addResource(McpServerFeatures.SyncResourceSpecification resourceSpecification) { this.asyncServer - .addResource( - McpServerFeatures.AsyncResourceSpecification.fromSync(resourceHandler, this.immediateExecution)) + .addResource(McpServerFeatures.AsyncResourceSpecification.fromSync(resourceSpecification, + this.immediateExecution)) .block(); } + /** + * List all registered resources. + * @return A list of all registered resources + */ + public List listResources() { + return this.asyncServer.listResources().collectList().block(); + } + /** * Remove a resource handler. * @param resourceUri The URI of the resource handler to remove @@ -114,6 +132,33 @@ public void removeResource(String resourceUri) { this.asyncServer.removeResource(resourceUri).block(); } + /** + * Add a new resource template. + * @param resourceTemplateSpecification The resource template specification to add + */ + public void addResourceTemplate(McpServerFeatures.SyncResourceTemplateSpecification resourceTemplateSpecification) { + this.asyncServer + .addResourceTemplate(McpServerFeatures.AsyncResourceTemplateSpecification + .fromSync(resourceTemplateSpecification, this.immediateExecution)) + .block(); + } + + /** + * List all registered resource templates. + * @return A list of all registered resource templates + */ + public List listResourceTemplates() { + return this.asyncServer.listResourceTemplates().collectList().block(); + } + + /** + * Remove a resource template. + * @param uriTemplate The URI template of the resource template to remove + */ + public void removeResourceTemplate(String uriTemplate) { + this.asyncServer.removeResourceTemplate(uriTemplate).block(); + } + /** * Add a new prompt handler. * @param promptSpecification The prompt specification to add @@ -125,6 +170,14 @@ public void addPrompt(McpServerFeatures.SyncPromptSpecification promptSpecificat .block(); } + /** + * List all registered prompts. + * @return A list of all registered prompts + */ + public List listPrompts() { + return this.asyncServer.listPrompts().collectList().block(); + } + /** * Remove a prompt handler. * @param promptName The name of the prompt handler to remove @@ -177,21 +230,6 @@ public void notifyPromptsListChanged() { this.asyncServer.notifyPromptsListChanged().block(); } - /** - * This implementation would, incorrectly, broadcast the logging message to all - * connected clients, using a single minLoggingLevel for all of them. Similar to the - * sampling and roots, the logging level should be set per client session and use the - * ServerExchange to send the logging message to the right client. - * @param loggingMessageNotification The logging message to send - * @deprecated Use - * {@link McpSyncServerExchange#loggingNotification(LoggingMessageNotification)} - * instead. - */ - @Deprecated - public void loggingNotification(LoggingMessageNotification loggingMessageNotification) { - this.asyncServer.loggingNotification(loggingMessageNotification).block(); - } - /** * Close the server gracefully. */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java similarity index 98% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 5f22df5e9..0b9115b79 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java similarity index 59% rename from mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java index 97fcecf0d..ea9f05a4f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; + /** * The contract for extracting metadata from a generic transport request of type * {@link T}. @@ -15,14 +17,11 @@ public interface McpTransportContextExtractor { /** - * Given an empty context, provides the means to fill it with transport-specific - * metadata extracted from the request. + * Extract transport-specific metadata from the request into an McpTransportContext. * @param request the generic representation for the request in the context of a * specific transport implementation - * @param transportContext the mutable context which can be filled in with metadata - * @return the context filled in with metadata. It can be the same instance as - * provided or a new one. + * @return the context containing the metadata */ - McpTransportContext extract(T request, McpTransportContext transportContext); + McpTransportContext extract(T request); } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java new file mode 100644 index 000000000..e96403e48 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java @@ -0,0 +1,204 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.util.Assert; + +/** + * Default implementation of {@link ServerTransportSecurityValidator} that validates the + * Origin and Host headers against lists of allowed values. + * + *

+ * Supports exact matches and wildcard port patterns (e.g., "http://example.com:*" for + * origins, "example.com:*" for hosts). + * + * @author Daniel Garnier-Moiroux + * @see ServerTransportSecurityValidator + * @see ServerTransportSecurityException + */ +public final class DefaultServerTransportSecurityValidator implements ServerTransportSecurityValidator { + + private static final String ORIGIN_HEADER = "Origin"; + + private static final String HOST_HEADER = "Host"; + + private final List allowedOrigins; + + private final List allowedHosts; + + /** + * Creates a new validator with the specified allowed origins and hosts. + * @param allowedOrigins List of allowed origin patterns. Supports exact matches + * (e.g., "http://example.com:8080") and wildcard ports (e.g., "http://example.com:*") + * @param allowedHosts List of allowed host patterns. Supports exact matches (e.g., + * "example.com:8080") and wildcard ports (e.g., "example.com:*") + */ + private DefaultServerTransportSecurityValidator(List allowedOrigins, List allowedHosts) { + Assert.notNull(allowedOrigins, "allowedOrigins must not be null"); + Assert.notNull(allowedHosts, "allowedHosts must not be null"); + this.allowedOrigins = allowedOrigins; + this.allowedHosts = allowedHosts; + } + + @Override + public void validateHeaders(Map> headers) throws ServerTransportSecurityException { + boolean missingHost = true; + for (Map.Entry> entry : headers.entrySet()) { + if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) { + List values = entry.getValue(); + if (values == null || values.isEmpty()) { + throw new ServerTransportSecurityException(403, "Invalid Origin header"); + } + validateOrigin(values.get(0)); + } + else if (HOST_HEADER.equalsIgnoreCase(entry.getKey())) { + missingHost = false; + List values = entry.getValue(); + if (values == null || values.isEmpty()) { + throw new ServerTransportSecurityException(421, "Invalid Host header"); + } + validateHost(values.get(0)); + } + } + if (!allowedHosts.isEmpty() && missingHost) { + throw new ServerTransportSecurityException(421, "Invalid Host header"); + } + } + + /** + * Validates a single origin value against the allowed origins. Subclasses can + * override this method to customize origin validation logic. + * @param origin The origin header value, or null if not present + * @throws ServerTransportSecurityException if the origin is not allowed + */ + protected void validateOrigin(String origin) throws ServerTransportSecurityException { + // Origin absent = no validation needed (same-origin request) + if (origin == null || origin.isBlank()) { + return; + } + + for (String allowed : allowedOrigins) { + if (allowed.equals(origin)) { + return; + } + else if (allowed.endsWith(":*")) { + // Wildcard port pattern: "http://example.com:*" + String baseOrigin = allowed.substring(0, allowed.length() - 2); + if (origin.equals(baseOrigin) || origin.startsWith(baseOrigin + ":")) { + return; + } + } + + } + + throw new ServerTransportSecurityException(403, "Invalid Origin header"); + } + + /** + * Validates a single host value against the allowed hosts. + * @param host The host header value, or null if not present + * @throws ServerTransportSecurityException if the host is not allowed + */ + private void validateHost(String host) throws ServerTransportSecurityException { + if (allowedHosts.isEmpty()) { + return; + } + + // Host is required + if (host == null || host.isBlank()) { + throw new ServerTransportSecurityException(421, "Invalid Host header"); + } + + for (String allowed : allowedHosts) { + if (allowed.equals(host)) { + return; + } + else if (allowed.endsWith(":*")) { + // Wildcard port pattern: "example.com:*" + String baseHost = allowed.substring(0, allowed.length() - 2); + if (host.equals(baseHost) || host.startsWith(baseHost + ":")) { + return; + } + } + } + + throw new ServerTransportSecurityException(421, "Invalid Host header"); + } + + /** + * Creates a new builder for constructing a DefaultServerTransportSecurityValidator. + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link DefaultServerTransportSecurityValidator}. + */ + public static class Builder { + + private final List allowedOrigins = new ArrayList<>(); + + private final List allowedHosts = new ArrayList<>(); + + /** + * Adds an allowed origin pattern. + * @param origin The origin to allow (e.g., "http://localhost:8080" or + * "http://example.com:*") + * @return this builder instance + */ + public Builder allowedOrigin(String origin) { + this.allowedOrigins.add(origin); + return this; + } + + /** + * Adds multiple allowed origin patterns. + * @param origins The origins to allow + * @return this builder instance + */ + public Builder allowedOrigins(List origins) { + Assert.notNull(origins, "origins must not be null"); + this.allowedOrigins.addAll(origins); + return this; + } + + /** + * Adds an allowed host pattern. + * @param host The host to allow (e.g., "localhost:8080" or "example.com:*") + * @return this builder instance + */ + public Builder allowedHost(String host) { + this.allowedHosts.add(host); + return this; + } + + /** + * Adds multiple allowed host patterns. + * @param hosts The hosts to allow + * @return this builder instance + */ + public Builder allowedHosts(List hosts) { + Assert.notNull(hosts, "hosts must not be null"); + this.allowedHosts.addAll(hosts); + return this; + } + + /** + * Builds the validator instance. + * @return A new DefaultServerTransportSecurityValidator + */ + public DefaultServerTransportSecurityValidator build() { + return new DefaultServerTransportSecurityValidator(allowedOrigins, allowedHosts); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java new file mode 100644 index 000000000..32246948c --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java @@ -0,0 +1,40 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import jakarta.servlet.http.HttpServletRequest; + +/** + * Utility methods for working with {@link HttpServletRequest}. For internal use only. + * + * @author Daniel Garnier-Moiroux + */ +final class HttpServletRequestUtils { + + private HttpServletRequestUtils() { + } + + /** + * Extracts all headers from the HTTP request into a map. + * @param request The HTTP servlet request + * @return A map of header names to their values + */ + static Map> extractHeaders(HttpServletRequest request) { + Map> headers = new HashMap<>(); + Enumeration names = request.getHeaderNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + headers.put(name, Collections.list(request.getHeaders(name))); + } + return headers; + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java similarity index 76% rename from mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index 582120e3f..7037ff293 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2024 - 2026 the original author or authors. */ package io.modelcontextprotocol.server.transport; @@ -14,10 +14,10 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.DefaultMcpTransportContext; -import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; @@ -70,7 +70,9 @@ @WebServlet(asyncSupported = true) public class HttpServletSseServerTransportProvider extends HttpServlet implements McpServerTransportProvider { - /** Logger for this class */ + /** + * Logger for this class + */ private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransportProvider.class); public static final String UTF_8 = "UTF-8"; @@ -79,38 +81,60 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; - /** Default endpoint path for SSE connections */ + /** + * Default endpoint path for SSE connections + */ public static final String DEFAULT_SSE_ENDPOINT = "/sse"; - /** Event type for regular messages */ + /** + * Event type for regular messages + */ public static final String MESSAGE_EVENT_TYPE = "message"; - /** Event type for endpoint information */ + /** + * Event type for endpoint information + */ public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + public static final String SESSION_ID = "sessionId"; + public static final String DEFAULT_BASE_URL = ""; - /** JSON object mapper for serialization/deserialization */ - private final ObjectMapper objectMapper; + /** + * JSON mapper for serialization/deserialization + */ + private final McpJsonMapper jsonMapper; - /** Base URL for the server transport */ + /** + * Base URL for the server transport + */ private final String baseUrl; - /** The endpoint path for handling client messages */ + /** + * The endpoint path for handling client messages + */ private final String messageEndpoint; - /** The endpoint path for handling SSE connections */ + /** + * The endpoint path for handling SSE connections + */ private final String sseEndpoint; - /** Map of active client sessions, keyed by session ID */ + /** + * Map of active client sessions, keyed by session ID + */ private final Map sessions = new ConcurrentHashMap<>(); private McpTransportContextExtractor contextExtractor; - /** Flag indicating if the transport is in the process of shutting down */ + /** + * Flag indicating if the transport is in the process of shutting down + */ private final AtomicBoolean isClosing = new AtomicBoolean(false); - /** Session factory for creating new sessions */ + /** + * Session factory for creating new sessions + */ private McpServerSession.Factory sessionFactory; /** @@ -120,62 +144,14 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement private KeepAliveScheduler keepAliveScheduler; /** - * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE - * endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param messageEndpoint The endpoint path where clients will send their messages - * @param sseEndpoint The endpoint path where clients will establish SSE connections - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, - String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); - } - - /** - * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE - * endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param baseUrl The base URL for the server transport - * @param messageEndpoint The endpoint path where clients will send their messages - * @param sseEndpoint The endpoint path where clients will establish SSE connections - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint) { - this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null, (serverRequest, context) -> context); - } - - /** - * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE - * endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param baseUrl The base URL for the server transport - * @param messageEndpoint The endpoint path where clients will send their messages - * @param sseEndpoint The endpoint path where clients will establish SSE connections - * @param keepAliveInterval The interval for keep-alive pings, or null to disable - * keep-alive functionality - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. + * Security validator for validating HTTP requests. */ - @Deprecated - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint, Duration keepAliveInterval) { - this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, - (serverRequest, context) -> context); - } + private final ServerTransportSecurityValidator securityValidator; /** * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE * endpoint. - * @param objectMapper The JSON object mapper to use for message + * @param jsonMapper The JSON object mapper to use for message * serialization/deserialization * @param baseUrl The base URL for the server transport * @param messageEndpoint The endpoint path where clients will send their messages @@ -183,23 +159,25 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String b * @param keepAliveInterval The interval for keep-alive pings, or null to disable * keep-alive functionality * @param contextExtractor The extractor for transport context from the request. - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. + * @param securityValidator The security validator for validating HTTP requests. */ - private HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + private HttpServletSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, String sseEndpoint, Duration keepAliveInterval, - McpTransportContextExtractor contextExtractor) { + McpTransportContextExtractor contextExtractor, + ServerTransportSecurityValidator securityValidator) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(jsonMapper, "JsonMapper must not be null"); Assert.notNull(messageEndpoint, "messageEndpoint must not be null"); Assert.notNull(sseEndpoint, "sseEndpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); + Assert.notNull(securityValidator, "Security validator must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.contextExtractor = contextExtractor; + this.securityValidator = securityValidator; if (keepAliveInterval != null) { @@ -218,17 +196,6 @@ public List protocolVersions() { return List.of(ProtocolVersions.MCP_2024_11_05); } - /** - * Creates a new HttpServletSseServerTransportProvider instance with the default SSE - * endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param messageEndpoint The endpoint path where clients will send their messages - */ - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - /** * Sets the session factory for creating new sessions. * @param sessionFactory The session factory to use @@ -287,6 +254,15 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) return; } + try { + Map> headers = HttpServletRequestUtils.extractHeaders(request); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + response.sendError(e.getStatusCode(), e.getMessage()); + return; + } + response.setContentType("text/event-stream"); response.setCharacterEncoding(UTF_8); response.setHeader("Cache-Control", "no-cache"); @@ -308,7 +284,22 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) this.sessions.put(sessionId, session); // Send initial endpoint event - this.sendEvent(writer, ENDPOINT_EVENT_TYPE, this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); + this.sendEvent(writer, ENDPOINT_EVENT_TYPE, buildEndpointUrl(sessionId)); + } + + /** + * Constructs the full message endpoint URL by combining the base URL, message path, + * and the required session_id query parameter. + * @param sessionId the unique session identifier + * @return the fully qualified endpoint URL as a string + */ + private String buildEndpointUrl(String sessionId) { + // for WebMVC compatibility + if (this.baseUrl.endsWith("/")) { + return this.baseUrl.substring(0, this.baseUrl.length() - 1) + this.messageEndpoint + "?sessionId=" + + sessionId; + } + return this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId; } /** @@ -337,13 +328,24 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } + try { + Map> headers = HttpServletRequestUtils.extractHeaders(request); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + response.sendError(e.getStatusCode(), e.getMessage()); + return; + } + // Get the session ID from the request parameter String sessionId = request.getParameter("sessionId"); if (sessionId == null) { response.setContentType(APPLICATION_JSON); response.setCharacterEncoding(UTF_8); response.setStatus(HttpServletResponse.SC_BAD_REQUEST); - String jsonError = objectMapper.writeValueAsString(new McpError("Session ID missing in message endpoint")); + String jsonError = jsonMapper.writeValueAsString(McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND) + .message("Session ID missing in message endpoint") + .build()); PrintWriter writer = response.getWriter(); writer.write(jsonError); writer.flush(); @@ -356,7 +358,9 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) response.setContentType(APPLICATION_JSON); response.setCharacterEncoding(UTF_8); response.setStatus(HttpServletResponse.SC_NOT_FOUND); - String jsonError = objectMapper.writeValueAsString(new McpError("Session not found: " + sessionId)); + String jsonError = jsonMapper.writeValueAsString(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Session not found: " + sessionId) + .build()); PrintWriter writer = response.getWriter(); writer.write(jsonError); writer.flush(); @@ -371,9 +375,8 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) body.append(line); } - final McpTransportContext transportContext = this.contextExtractor.extract(request, - new DefaultMcpTransportContext()); - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + final McpTransportContext transportContext = this.contextExtractor.extract(request); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body.toString()); // Process the message through the session's handle method // Block for Servlet compatibility @@ -384,11 +387,13 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) catch (Exception e) { logger.error("Error processing message: {}", e.getMessage()); try { - McpError mcpError = new McpError(e.getMessage()); + McpError mcpError = McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message(e.getMessage()) + .build(); response.setContentType(APPLICATION_JSON); response.setCharacterEncoding(UTF_8); response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); - String jsonError = objectMapper.writeValueAsString(mcpError); + String jsonError = jsonMapper.writeValueAsString(mcpError); PrintWriter writer = response.getWriter(); writer.write(jsonError); writer.flush(); @@ -484,7 +489,7 @@ private class HttpServletMcpSessionTransport implements McpServerTransport { public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.fromRunnable(() -> { try { - String jsonText = objectMapper.writeValueAsString(message); + String jsonText = jsonMapper.writeValueAsString(message); sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText); logger.debug("Message sent to session {}", sessionId); } @@ -497,15 +502,15 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { } /** - * Converts data from one type to another using the configured ObjectMapper. + * Converts data from one type to another using the configured JsonMapper. * @param data The source data object to convert * @param typeRef The target type reference - * @return The converted object of type T * @param The target type + * @return The converted object of type T */ @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } /** @@ -561,7 +566,7 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper = new ObjectMapper(); + private McpJsonMapper jsonMapper; private String baseUrl = DEFAULT_BASE_URL; @@ -569,18 +574,23 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; private Duration keepAliveInterval; + private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP; + /** - * Sets the JSON object mapper to use for message serialization/deserialization. - * @param objectMapper The object mapper to use + * Sets the JsonMapper implementation to use for serialization/deserialization. If + * not specified, a JacksonJsonMapper will be created from the configured + * ObjectMapper. + * @param jsonMapper The JsonMapper to use * @return This builder instance for method chaining */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -645,21 +655,31 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { return this; } + /** + * Sets the security validator for validating HTTP requests. + * @param securityValidator The security validator to use. Must not be null. + * @return This builder instance + * @throws IllegalArgumentException if securityValidator is null + */ + public Builder securityValidator(ServerTransportSecurityValidator securityValidator) { + Assert.notNull(securityValidator, "Security validator must not be null"); + this.securityValidator = securityValidator; + return this; + } + /** * Builds a new instance of HttpServletSseServerTransportProvider with the * configured settings. * @return A new HttpServletSseServerTransportProvider instance - * @throws IllegalStateException if objectMapper or messageEndpoint is not set + * @throws IllegalStateException if jsonMapper or messageEndpoint is not set */ public HttpServletSseServerTransportProvider build() { - if (objectMapper == null) { - throw new IllegalStateException("ObjectMapper must be set"); - } if (messageEndpoint == null) { throw new IllegalStateException("MessageEndpoint must be set"); } - return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, - keepAliveInterval, contextExtractor); + return new HttpServletSseServerTransportProvider( + jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, baseUrl, messageEndpoint, + sseEndpoint, keepAliveInterval, contextExtractor, securityValidator); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java similarity index 72% rename from mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java index 25b003564..047aeebe8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2026 the original author or authors. */ package io.modelcontextprotocol.server.transport; @@ -7,15 +7,17 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.PrintWriter; +import java.util.List; +import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.McpJsonMapper; -import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpStatelessServerHandler; -import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; @@ -49,7 +51,7 @@ public class HttpServletStatelessServerTransport extends HttpServlet implements public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final String mcpEndpoint; @@ -59,15 +61,23 @@ public class HttpServletStatelessServerTransport extends HttpServlet implements private volatile boolean isClosing = false; - private HttpServletStatelessServerTransport(ObjectMapper objectMapper, String mcpEndpoint, - McpTransportContextExtractor contextExtractor) { - Assert.notNull(objectMapper, "objectMapper must not be null"); + /** + * Security validator for validating HTTP requests. + */ + private final ServerTransportSecurityValidator securityValidator; + + private HttpServletStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor, + ServerTransportSecurityValidator securityValidator) { + Assert.notNull(jsonMapper, "jsonMapper must not be null"); Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); Assert.notNull(contextExtractor, "contextExtractor must not be null"); + Assert.notNull(securityValidator, "Security validator must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.contextExtractor = contextExtractor; + this.securityValidator = securityValidator; } @Override @@ -123,12 +133,23 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + try { + Map> headers = HttpServletRequestUtils.extractHeaders(request); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + response.sendError(e.getStatusCode(), e.getMessage()); + return; + } + + McpTransportContext transportContext = this.contextExtractor.extract(request); String accept = request.getHeader(ACCEPT); if (accept == null || !(accept.contains(APPLICATION_JSON) && accept.contains(TEXT_EVENT_STREAM))) { this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, - new McpError("Both application/json and text/event-stream required in Accept header")); + McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND) + .message("Both application/json and text/event-stream required in Accept header") + .build()); return; } @@ -140,7 +161,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) body.append(line); } - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body.toString()); if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { try { @@ -153,7 +174,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) response.setCharacterEncoding(UTF_8); response.setStatus(HttpServletResponse.SC_OK); - String jsonResponseText = objectMapper.writeValueAsString(jsonrpcResponse); + String jsonResponseText = jsonMapper.writeValueAsString(jsonrpcResponse); PrintWriter writer = response.getWriter(); writer.write(jsonResponseText); writer.flush(); @@ -161,7 +182,9 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) catch (Exception e) { logger.error("Failed to handle request: {}", e.getMessage()); this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, - new McpError("Failed to handle request: " + e.getMessage())); + McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Failed to handle request: " + e.getMessage()) + .build()); } } else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { @@ -174,22 +197,29 @@ else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { catch (Exception e) { logger.error("Failed to handle notification: {}", e.getMessage()); this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, - new McpError("Failed to handle notification: " + e.getMessage())); + McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Failed to handle notification: " + e.getMessage()) + .build()); } } else { this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, - new McpError("The server accepts either requests or notifications")); + McpError.builder(McpSchema.ErrorCodes.INVALID_REQUEST) + .message("The server accepts either requests or notifications") + .build()); } } catch (IllegalArgumentException | IOException e) { logger.error("Failed to deserialize message: {}", e.getMessage()); - this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError("Invalid message format")); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + McpError.builder(McpSchema.ErrorCodes.INVALID_REQUEST).message("Invalid message format").build()); } catch (Exception e) { logger.error("Unexpected error handling message: {}", e.getMessage()); this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, - new McpError("Unexpected error: " + e.getMessage())); + McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Unexpected error: " + e.getMessage()) + .build()); } } @@ -204,7 +234,7 @@ private void responseError(HttpServletResponse response, int httpCode, McpError response.setContentType(APPLICATION_JSON); response.setCharacterEncoding(UTF_8); response.setStatus(httpCode); - String jsonError = objectMapper.writeValueAsString(mcpError); + String jsonError = jsonMapper.writeValueAsString(mcpError); PrintWriter writer = response.getWriter(); writer.write(jsonError); writer.flush(); @@ -237,26 +267,29 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; private String mcpEndpoint = "/mcp"; - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; + + private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP; private Builder() { // used by a static method } /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * Sets the JsonMapper to use for JSON serialization/deserialization of MCP * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. + * @param jsonMapper The JsonMapper instance. Must not be null. * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if jsonMapper is null */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -288,6 +321,18 @@ public Builder contextExtractor(McpTransportContextExtractor return this; } + /** + * Sets the security validator for validating HTTP requests. + * @param securityValidator The security validator to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if securityValidator is null + */ + public Builder securityValidator(ServerTransportSecurityValidator securityValidator) { + Assert.notNull(securityValidator, "Security validator must not be null"); + this.securityValidator = securityValidator; + return this; + } + /** * Builds a new instance of {@link HttpServletStatelessServerTransport} with the * configured settings. @@ -295,10 +340,10 @@ public Builder contextExtractor(McpTransportContextExtractor * @throws IllegalStateException if required parameters are not set */ public HttpServletStatelessServerTransport build() { - Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(mcpEndpoint, "Message endpoint must be set"); - - return new HttpServletStatelessServerTransport(objectMapper, mcpEndpoint, contextExtractor); + return new HttpServletStatelessServerTransport( + jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, mcpEndpoint, contextExtractor, + securityValidator); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java similarity index 83% rename from mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 8b95ec607..d7561188c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2026 the original author or authors. */ package io.modelcontextprotocol.server.transport; @@ -10,17 +10,16 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.TypeRef; -import io.modelcontextprotocol.server.DefaultMcpTransportContext; -import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpError; @@ -30,6 +29,8 @@ import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.util.KeepAliveScheduler; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; @@ -98,7 +99,7 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet */ private final boolean disallowDelete; - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private McpStreamableServerSession.Factory sessionFactory; @@ -120,27 +121,37 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet */ private KeepAliveScheduler keepAliveScheduler; + /** + * Security validator for validating HTTP requests. + */ + private final ServerTransportSecurityValidator securityValidator; + /** * Constructs a new HttpServletStreamableServerTransportProvider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization of + * messages. * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests. * @param disallowDelete Whether to disallow DELETE requests on the endpoint. * @param contextExtractor The extractor for transport context from the request. + * @param keepAliveInterval The interval for keep-alive pings. If null, no keep-alive + * will be scheduled. + * @param securityValidator The security validator for validating HTTP requests. * @throws IllegalArgumentException if any parameter is null */ - private HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, + private HttpServletStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, boolean disallowDelete, McpTransportContextExtractor contextExtractor, - Duration keepAliveInterval) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); + Assert.notNull(securityValidator, "Security validator must not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.disallowDelete = disallowDelete; this.contextExtractor = contextExtractor; + this.securityValidator = securityValidator; if (keepAliveInterval != null) { @@ -157,7 +168,8 @@ private HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper, @Override public List protocolVersions() { - return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, + ProtocolVersions.MCP_2025_06_18, ProtocolVersions.MCP_2025_11_25); } @Override @@ -246,6 +258,15 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) return; } + try { + Map> headers = HttpServletRequestUtils.extractHeaders(request); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + response.sendError(e.getStatusCode(), e.getMessage()); + return; + } + List badRequestErrors = new ArrayList<>(); String accept = request.getHeader(ACCEPT); @@ -261,7 +282,8 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) if (!badRequestErrors.isEmpty()) { String combinedMessage = String.join("; ", badRequestErrors); - this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage)); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND).message(combinedMessage).build()); return; } @@ -274,7 +296,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) logger.debug("Handling GET request for session: {}", sessionId); - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); try { response.setContentType(TEXT_EVENT_STREAM); @@ -373,6 +395,15 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } + try { + Map> headers = HttpServletRequestUtils.extractHeaders(request); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + response.sendError(e.getStatusCode(), e.getMessage()); + return; + } + List badRequestErrors = new ArrayList<>(); String accept = request.getHeader(ACCEPT); @@ -383,7 +414,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) badRequestErrors.add("application/json required in Accept header"); } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); try { BufferedReader reader = request.getReader(); @@ -393,19 +424,20 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) body.append(line); } - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, body.toString()); // Handle initialization request if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { if (!badRequestErrors.isEmpty()) { String combinedMessage = String.join("; ", badRequestErrors); - this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage)); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND).message(combinedMessage).build()); return; } - McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), - new TypeReference() { + McpSchema.InitializeRequest initializeRequest = jsonMapper.convertValue(jsonrpcRequest.params(), + new TypeRef() { }); McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory .startSession(initializeRequest); @@ -419,7 +451,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) response.setHeader(HttpHeaders.MCP_SESSION_ID, init.session().getId()); response.setStatus(HttpServletResponse.SC_OK); - String jsonResponse = objectMapper.writeValueAsString(new McpSchema.JSONRPCResponse( + String jsonResponse = jsonMapper.writeValueAsString(new McpSchema.JSONRPCResponse( McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initResult, null)); PrintWriter writer = response.getWriter(); @@ -430,7 +462,9 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) catch (Exception e) { logger.error("Failed to initialize session: {}", e.getMessage()); this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, - new McpError("Failed to initialize session: " + e.getMessage())); + McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Failed to initialize session: " + e.getMessage()) + .build()); return; } } @@ -443,7 +477,8 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) if (!badRequestErrors.isEmpty()) { String combinedMessage = String.join("; ", badRequestErrors); - this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage)); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND).message(combinedMessage).build()); return; } @@ -451,7 +486,9 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) if (session == null) { this.responseError(response, HttpServletResponse.SC_NOT_FOUND, - new McpError("Session not found: " + sessionId)); + McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Session not found: " + sessionId) + .build()); return; } @@ -493,19 +530,23 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { } else { this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, - new McpError("Unknown message type")); + McpError.builder(McpSchema.ErrorCodes.INVALID_REQUEST).message("Unknown message type").build()); } } catch (IllegalArgumentException | IOException e) { logger.error("Failed to deserialize message: {}", e.getMessage()); this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, - new McpError("Invalid message format: " + e.getMessage())); + McpError.builder(McpSchema.ErrorCodes.INVALID_REQUEST) + .message("Invalid message format: " + e.getMessage()) + .build()); } catch (Exception e) { logger.error("Error handling message: {}", e.getMessage()); try { this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, - new McpError("Error processing message: " + e.getMessage())); + McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Error processing message: " + e.getMessage()) + .build()); } catch (IOException ex) { logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); @@ -536,16 +577,27 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response return; } + try { + Map> headers = HttpServletRequestUtils.extractHeaders(request); + this.securityValidator.validateHeaders(headers); + } + catch (ServerTransportSecurityException e) { + response.sendError(e.getStatusCode(), e.getMessage()); + return; + } + if (this.disallowDelete) { response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); return; } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + McpTransportContext transportContext = this.contextExtractor.extract(request); if (request.getHeader(HttpHeaders.MCP_SESSION_ID) == null) { this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, - new McpError("Session ID required in mcp-session-id header")); + McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND) + .message("Session ID required in mcp-session-id header") + .build()); return; } @@ -566,7 +618,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response logger.error("Failed to delete session {}: {}", sessionId, e.getMessage()); try { this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, - new McpError(e.getMessage())); + McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR).message(e.getMessage()).build()); } catch (IOException ex) { logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); @@ -579,7 +631,7 @@ public void responseError(HttpServletResponse response, int httpCode, McpError m response.setContentType(APPLICATION_JSON); response.setCharacterEncoding(UTF_8); response.setStatus(httpCode); - String jsonError = objectMapper.writeValueAsString(mcpError); + String jsonError = jsonMapper.writeValueAsString(mcpError); PrintWriter writer = response.getWriter(); writer.write(jsonError); writer.flush(); @@ -686,7 +738,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId return; } - String jsonText = objectMapper.writeValueAsString(message); + String jsonText = jsonMapper.writeValueAsString(message); HttpServletStreamableServerTransportProvider.this.sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText, messageId != null ? messageId : this.sessionId); logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId); @@ -703,15 +755,15 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId } /** - * Converts data from one type to another using the configured ObjectMapper. + * Converts data from one type to another using the configured JsonMapper. * @param data The source data object to convert * @param typeRef The target type reference * @return The converted object of type T * @param The target type */ @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } /** @@ -763,26 +815,29 @@ public static Builder builder() { */ public static class Builder { - private ObjectMapper objectMapper; + private McpJsonMapper jsonMapper; private String mcpEndpoint = "/mcp"; private boolean disallowDelete = false; - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + private McpTransportContextExtractor contextExtractor = ( + serverRequest) -> McpTransportContext.EMPTY; private Duration keepAliveInterval; + private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP; + /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * Sets the JsonMapper to use for JSON serialization/deserialization of MCP * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. + * @param jsonMapper The JsonMapper instance. Must not be null. * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null + * @throws IllegalArgumentException if JsonMapper is null */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; + public Builder jsonMapper(McpJsonMapper jsonMapper) { + Assert.notNull(jsonMapper, "JsonMapper must not be null"); + this.jsonMapper = jsonMapper; return this; } @@ -832,6 +887,18 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { return this; } + /** + * Sets the security validator for validating HTTP requests. + * @param securityValidator The security validator to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if securityValidator is null + */ + public Builder securityValidator(ServerTransportSecurityValidator securityValidator) { + Assert.notNull(securityValidator, "Security validator must not be null"); + this.securityValidator = securityValidator; + return this; + } + /** * Builds a new instance of {@link HttpServletStreamableServerTransportProvider} * with the configured settings. @@ -839,11 +906,10 @@ public Builder keepAliveInterval(Duration keepAliveInterval) { * @throws IllegalStateException if required parameters are not set */ public HttpServletStreamableServerTransportProvider build() { - Assert.notNull(this.objectMapper, "ObjectMapper must be set"); Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); - - return new HttpServletStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint, - this.disallowDelete, this.contextExtractor, this.keepAliveInterval); + return new HttpServletStreamableServerTransportProvider( + jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, mcpEndpoint, disallowDelete, + contextExtractor, keepAliveInterval, securityValidator); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityException.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityException.java new file mode 100644 index 000000000..96a06d3bd --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityException.java @@ -0,0 +1,48 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +/** + * Exception thrown when security validation fails for an HTTP request. Contains HTTP + * status code and message. + * + * @author Daniel Garnier-Moiroux + * @see ServerTransportSecurityValidator + */ +public class ServerTransportSecurityException extends Exception { + + private final int statusCode; + + /** + * Creates a new ServerTransportSecurityException with the specified HTTP status code + * and message. + */ + public ServerTransportSecurityException(int statusCode, String message) { + super(message); + this.statusCode = statusCode; + } + + public int getStatusCode() { + return statusCode; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + ServerTransportSecurityException that = (ServerTransportSecurityException) obj; + return statusCode == that.statusCode && java.util.Objects.equals(getMessage(), that.getMessage()); + } + + @Override + public int hashCode() { + return java.util.Objects.hash(statusCode, getMessage()); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java new file mode 100644 index 000000000..ce805931f --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java @@ -0,0 +1,36 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.util.List; +import java.util.Map; + +/** + * Interface for validating HTTP requests in server transports. Implementations can + * validate Origin headers, Host headers, or any other security-related headers according + * to the MCP specification. + * + * @author Daniel Garnier-Moiroux + * @see DefaultServerTransportSecurityValidator + * @see ServerTransportSecurityException + */ +@FunctionalInterface +public interface ServerTransportSecurityValidator { + + /** + * A no-op validator that accepts all requests without validation. + */ + ServerTransportSecurityValidator NOOP = headers -> { + }; + + /** + * Validates the HTTP headers from an incoming request. + * @param headers A map of header names to their values (multi-valued headers + * supported) + * @throws ServerTransportSecurityException if validation fails + */ + void validateHeaders(Map> headers) throws ServerTransportSecurityException; + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java similarity index 88% rename from mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java rename to mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index af602f610..d288ea3d6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -15,8 +15,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; @@ -25,6 +24,7 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.json.McpJsonMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -44,7 +44,7 @@ public class StdioServerTransportProvider implements McpServerTransportProvider private static final Logger logger = LoggerFactory.getLogger(StdioServerTransportProvider.class); - private final ObjectMapper objectMapper; + private final McpJsonMapper jsonMapper; private final InputStream inputStream; @@ -56,36 +56,28 @@ public class StdioServerTransportProvider implements McpServerTransportProvider private final Sinks.One inboundReady = Sinks.one(); - /** - * Creates a new StdioServerTransportProvider with a default ObjectMapper and System - * streams. - */ - public StdioServerTransportProvider() { - this(new ObjectMapper()); - } - /** * Creates a new StdioServerTransportProvider with the specified ObjectMapper and * System streams. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization */ - public StdioServerTransportProvider(ObjectMapper objectMapper) { - this(objectMapper, System.in, System.out); + public StdioServerTransportProvider(McpJsonMapper jsonMapper) { + this(jsonMapper, System.in, System.out); } /** * Creates a new StdioServerTransportProvider with the specified ObjectMapper and * streams. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param jsonMapper The JsonMapper to use for JSON serialization/deserialization * @param inputStream The input stream to read from * @param outputStream The output stream to write to */ - public StdioServerTransportProvider(ObjectMapper objectMapper, InputStream inputStream, OutputStream outputStream) { - Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + public StdioServerTransportProvider(McpJsonMapper jsonMapper, InputStream inputStream, OutputStream outputStream) { + Assert.notNull(jsonMapper, "The JsonMapper can not be null"); Assert.notNull(inputStream, "The InputStream can not be null"); Assert.notNull(outputStream, "The OutputStream can not be null"); - this.objectMapper = objectMapper; + this.jsonMapper = jsonMapper; this.inputStream = inputStream; this.outputStream = outputStream; } @@ -106,7 +98,7 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { @Override public Mono notifyClients(String method, Object params) { if (this.session == null) { - return Mono.error(new McpError("No session to close")); + return Mono.error(new IllegalStateException("No session to close")); } return this.session.sendNotification(method, params) .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); @@ -165,8 +157,8 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return jsonMapper.convertValue(data, typeRef); } @Override @@ -219,7 +211,7 @@ private void startInboundProcessing() { logger.debug("Received JSON message: {}", line); try { - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(jsonMapper, line); if (!this.inboundSink.tryEmitNext(message).isSuccess()) { // logIfNotClosing("Failed to enqueue message"); @@ -263,7 +255,7 @@ private void startOutboundProcessing() { .handle((message, sink) -> { if (message != null && !isClosing.get()) { try { - String jsonMessage = objectMapper.writeValueAsString(message); + String jsonMessage = jsonMapper.writeValueAsString(message); // Escape any embedded newlines in the JSON message as per spec jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n"); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/ClosedMcpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ClosedMcpTransportSession.java new file mode 100644 index 000000000..b18364abb --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ClosedMcpTransportSession.java @@ -0,0 +1,58 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ +package io.modelcontextprotocol.spec; + +import java.util.Optional; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +/** + * Represents a closed MCP session, which may not be reused. All calls will throw a + * {@link McpTransportSessionClosedException}. + * + * @param the resource representing the connection that the transport + * manages. + * @author Daniel Garnier-Moiroux + */ +public class ClosedMcpTransportSession implements McpTransportSession { + + private final String sessionId; + + public ClosedMcpTransportSession(@Nullable String sessionId) { + this.sessionId = sessionId; + } + + @Override + public Optional sessionId() { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public boolean markInitialized(String sessionId) { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public void addConnection(CONNECTION connection) { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public void removeConnection(CONNECTION connection) { + throw new McpTransportSessionClosedException(sessionId); + } + + @Override + public void close() { + + } + + @Override + public Publisher closeGracefully() { + return Mono.empty(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java new file mode 100644 index 000000000..6afc2c119 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java @@ -0,0 +1,56 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +/** + * Names of HTTP headers in use by MCP HTTP transports. + * + * @author Dariusz JΔ™drzejczyk + */ +public interface HttpHeaders { + + /** + * Identifies individual MCP sessions. + */ + String MCP_SESSION_ID = "Mcp-Session-Id"; + + /** + * Identifies events within an SSE Stream. + */ + String LAST_EVENT_ID = "Last-Event-ID"; + + /** + * Identifies the MCP protocol version. + */ + String PROTOCOL_VERSION = "MCP-Protocol-Version"; + + /** + * The HTTP Content-Length header. + * @see RFC9110 + */ + String CONTENT_LENGTH = "Content-Length"; + + /** + * The HTTP Content-Type header. + * @see RFC9110 + */ + String CONTENT_TYPE = "Content-Type"; + + /** + * The HTTP Accept header. + * @see RFC9110 + */ + String ACCEPT = "Accept"; + + /** + * The HTTP Cache-Control header. + * @see RFC9111 + */ + String CACHE_CONTROL = "Cache-Control"; + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java similarity index 93% rename from mcp/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java index 572d7c043..4a42c9ff3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java @@ -40,6 +40,6 @@ public static ValidationResponse asInvalid(String message) { * @return A ValidationResponse indicating whether the validation was successful or * not. */ - ValidationResponse validate(Map schema, Map structuredContent); + ValidationResponse validate(Map schema, Object structuredContent); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java similarity index 89% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index f7db3d7aa..80b5ae246 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -4,7 +4,7 @@ package io.modelcontextprotocol.spec; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.util.Assert; import org.reactivestreams.Publisher; import org.slf4j.Logger; @@ -35,6 +35,7 @@ * * @author Christian Tzolov * @author Dariusz JΔ™drzejczyk + * @author Yanming Zhou */ public class McpClientSession implements McpSession { @@ -95,21 +96,6 @@ public interface NotificationHandler { } - /** - * Creates a new McpClientSession with the specified configuration and handlers. - * @param requestTimeout Duration to wait for responses - * @param transport Transport implementation for message exchange - * @param requestHandlers Map of method names to request handlers - * @param notificationHandlers Map of method names to notification handlers - * @deprecated Use - * {@link #McpClientSession(Duration, McpClientTransport, Map, Map, Function)} - */ - @Deprecated - public McpClientSession(Duration requestTimeout, McpClientTransport transport, - Map> requestHandlers, Map notificationHandlers) { - this(requestTimeout, transport, requestHandlers, notificationHandlers, Function.identity()); - } - /** * Creates a new McpClientSession with the specified configuration and handlers. * @param requestTimeout Duration to wait for responses @@ -146,21 +132,34 @@ private void dismissPendingResponses() { private void handle(McpSchema.JSONRPCMessage message) { if (message instanceof McpSchema.JSONRPCResponse response) { - logger.debug("Received Response: {}", response); - var sink = pendingResponses.remove(response.id()); - if (sink == null) { - logger.warn("Unexpected response for unknown id {}", response.id()); + logger.debug("Received response: {}", response); + if (response.id() != null) { + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); + } + else { + sink.success(response); + } } else { - sink.success(response); + logger.error("Discarded MCP request response without session id. " + + "This is an indication of a bug in the request sender code that can lead to memory " + + "leaks as pending requests will never be completed."); } } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); handleIncomingRequest(request).onErrorResume(error -> { + + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (error instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + // TODO: add error message through the data field + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), McpError.aggregateExceptionMessages(error)); + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)); + jsonRpcError); return Mono.just(errorResponse); }).flatMap(this.transport::sendMessage).onErrorComplete(t -> { logger.warn("Issue sending response to the client, ", t); @@ -246,7 +245,7 @@ private String generateRequestId() { * @return A Mono containing the response */ @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { String requestId = this.generateRequestId(); return Mono.deferContextual(ctx -> Mono.create(pendingResponseSink -> { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpError.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpError.java new file mode 100644 index 000000000..a3e7890e6 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpError.java @@ -0,0 +1,111 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ + +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse.JSONRPCError; +import io.modelcontextprotocol.util.Assert; + +import java.util.Map; +import java.util.function.Function; + +public class McpError extends RuntimeException { + + /** + * Resource + * Error Handling + */ + public static final Function RESOURCE_NOT_FOUND = resourceUri -> new McpError(new JSONRPCError( + McpSchema.ErrorCodes.RESOURCE_NOT_FOUND, "Resource not found", Map.of("uri", resourceUri))); + + private JSONRPCError jsonRpcError; + + public McpError(JSONRPCError jsonRpcError) { + super(jsonRpcError.message()); + this.jsonRpcError = jsonRpcError; + } + + public JSONRPCError getJsonRpcError() { + return jsonRpcError; + } + + @Override + public String toString() { + var builder = new StringBuilder(super.toString()); + if (jsonRpcError != null) { + builder.append("\n"); + builder.append(jsonRpcError.toString()); + } + return builder.toString(); + } + + public static Builder builder(int errorCode) { + return new Builder(errorCode); + } + + public static class Builder { + + private final int code; + + private String message; + + private Object data; + + private Builder(int code) { + this.code = code; + } + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder data(Object data) { + this.data = data; + return this; + } + + public McpError build() { + Assert.hasText(message, "message must not be empty"); + return new McpError(new JSONRPCError(code, message, data)); + } + + } + + public static Throwable findRootCause(Throwable throwable) { + Assert.notNull(throwable, "throwable must not be null"); + Throwable rootCause = throwable; + while (rootCause.getCause() != null && rootCause.getCause() != rootCause) { + rootCause = rootCause.getCause(); + } + return rootCause; + } + + public static String aggregateExceptionMessages(Throwable throwable) { + Assert.notNull(throwable, "throwable must not be null"); + + StringBuilder messages = new StringBuilder(); + Throwable current = throwable; + + while (current != null) { + if (messages.length() > 0) { + messages.append("\n Caused by: "); + } + + messages.append(current.getClass().getSimpleName()); + if (current.getMessage() != null) { + messages.append(": ").append(current.getMessage()); + } + + if (current.getCause() == current) { + break; + } + current = current.getCause(); + } + + return messages.toString(); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java similarity index 87% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 3f8150271..bb9cead7e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -11,20 +11,17 @@ import java.util.List; import java.util.Map; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.annotation.JsonTypeInfo.As; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Based on the JSON-RPC 2.0 @@ -44,9 +41,6 @@ public final class McpSchema { private McpSchema() { } - @Deprecated - public static final String LATEST_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_03_26; - public static final String JSONRPC_VERSION = "2.0"; public static final String FIRST_PAGE = null; @@ -111,8 +105,6 @@ private McpSchema() { // Elicitation Methods public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - // --------------------------- // JSON-RPC Error Codes // --------------------------- @@ -146,44 +138,58 @@ public static final class ErrorCodes { */ public static final int INTERNAL_ERROR = -32603; + /** + * Resource not found. + */ + public static final int RESOURCE_NOT_FOUND = -32002; + } - public sealed interface Request - permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, CompleteRequest, - GetPromptRequest, ReadResourceRequest, SubscribeRequest, UnsubscribeRequest, PaginatedRequest { + /** + * Base interface for MCP objects that include optional metadata in the `_meta` field. + */ + public interface Meta { + /** + * @see Specification + * for notes on _meta usage + * @return additional metadata related to this resource. + */ Map meta(); - default String progressToken() { + } + + public sealed interface Request extends Meta + permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, CompleteRequest, + GetPromptRequest, ReadResourceRequest, SubscribeRequest, UnsubscribeRequest, PaginatedRequest { + + default Object progressToken() { if (meta() != null && meta().containsKey("progressToken")) { - return meta().get("progressToken").toString(); + return meta().get("progressToken"); } return null; } } - public sealed interface Result permits InitializeResult, ListResourcesResult, ListResourceTemplatesResult, - ReadResourceResult, ListPromptsResult, GetPromptResult, ListToolsResult, CallToolResult, - CreateMessageResult, ElicitResult, CompleteResult, ListRootsResult { - - Map meta(); + public sealed interface Result extends Meta permits InitializeResult, ListResourcesResult, + ListResourceTemplatesResult, ReadResourceResult, ListPromptsResult, GetPromptResult, ListToolsResult, + CallToolResult, CreateMessageResult, ElicitResult, CompleteResult, ListRootsResult { } - public sealed interface Notification + public sealed interface Notification extends Meta permits ProgressNotification, LoggingMessageNotification, ResourcesUpdatedNotification { - Map meta(); - } - private static final TypeReference> MAP_TYPE_REF = new TypeReference<>() { + private static final TypeRef> MAP_TYPE_REF = new TypeRef<>() { }; /** * Deserializes a JSON string into a JSONRPCMessage object. - * @param objectMapper The ObjectMapper instance to use for deserialization + * @param jsonMapper The JsonMapper instance to use for deserialization * @param jsonText The JSON string to deserialize * @return A JSONRPCMessage instance using either the {@link JSONRPCRequest}, * {@link JSONRPCNotification}, or {@link JSONRPCResponse} classes. @@ -191,22 +197,22 @@ public sealed interface Notification * @throws IllegalArgumentException If the JSON structure doesn't match any known * message type */ - public static JSONRPCMessage deserializeJsonRpcMessage(ObjectMapper objectMapper, String jsonText) + public static JSONRPCMessage deserializeJsonRpcMessage(McpJsonMapper jsonMapper, String jsonText) throws IOException { logger.debug("Received JSON message: {}", jsonText); - var map = objectMapper.readValue(jsonText, MAP_TYPE_REF); + var map = jsonMapper.readValue(jsonText, MAP_TYPE_REF); // Determine message type based on specific JSON structure if (map.containsKey("method") && map.containsKey("id")) { - return objectMapper.convertValue(map, JSONRPCRequest.class); + return jsonMapper.convertValue(map, JSONRPCRequest.class); } else if (map.containsKey("method") && !map.containsKey("id")) { - return objectMapper.convertValue(map, JSONRPCNotification.class); + return jsonMapper.convertValue(map, JSONRPCNotification.class); } else if (map.containsKey("result") || map.containsKey("error")) { - return objectMapper.convertValue(map, JSONRPCResponse.class); + return jsonMapper.convertValue(map, JSONRPCResponse.class); } throw new IllegalArgumentException("Cannot deserialize JSONRPCMessage: " + jsonText); @@ -268,12 +274,12 @@ public record JSONRPCNotification( // @formatter:off } /** - * A successful (non-error) response to a request. + * A response to a request (successful, or error). * * @param jsonrpc The JSON-RPC version (must be "2.0") * @param id The request identifier that this response corresponds to - * @param result The result of the successful request - * @param error Error information if the request failed + * @param result The result of the successful request; null if error + * @param error Error information if the request failed; null if has result */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) @@ -297,7 +303,7 @@ public record JSONRPCResponse( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record JSONRPCError( // @formatter:off - @JsonProperty("code") int code, + @JsonProperty("code") Integer code, @JsonProperty("message") String message, @JsonProperty("data") Object data) { // @formatter:on } @@ -407,9 +413,47 @@ public record Sampling() { * maintain control over user interactions and data sharing while enabling servers * to gather necessary information dynamically. Servers can request structured * data from users with optional JSON schemas to validate responses. + * + *

+ * Per the 2025-11-25 spec, clients can declare support for specific elicitation + * modes: + *

    + *
  • {@code form} - In-band structured data collection with optional schema + * validation
  • + *
  • {@code url} - Out-of-band interaction via URL navigation
  • + *
+ * + *

+ * For backward compatibility, an empty elicitation object {@code {}} is + * equivalent to declaring support for form mode only. + * + * @param form support for in-band form-based elicitation + * @param url support for out-of-band URL-based elicitation */ @JsonInclude(JsonInclude.Include.NON_ABSENT) - public record Elicitation() { + public record Elicitation(@JsonProperty("form") Form form, @JsonProperty("url") Url url) { + + /** + * Marker record indicating support for form-based elicitation mode. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Form() { + } + + /** + * Marker record indicating support for URL-based elicitation mode. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public record Url() { + } + + /** + * Creates an Elicitation with default settings (backward compatible, produces + * empty JSON object). + */ + public Elicitation() { + this(null, null); + } } public static Builder builder() { @@ -441,11 +485,28 @@ public Builder sampling() { return this; } + /** + * Enables elicitation capability with default settings (backward compatible, + * produces empty JSON object). + * @return this builder + */ public Builder elicitation() { this.elicitation = new Elicitation(); return this; } + /** + * Enables elicitation capability with explicit form and/or url mode support. + * @param form whether to support form-based elicitation + * @param url whether to support URL-based elicitation + * @return this builder + */ + public Builder elicitation(boolean form, boolean url) { + this.elicitation = new Elicitation(form ? new Elicitation.Form() : null, + url ? new Elicitation.Url() : null); + return this; + } + public ClientCapabilities build() { return new ClientCapabilities(experimental, roots, sampling, elicitation); } @@ -607,7 +668,7 @@ public ServerCapabilities build() { public record Implementation( // @formatter:off @JsonProperty("name") String name, @JsonProperty("title") String title, - @JsonProperty("version") String version) implements BaseMetadata { // @formatter:on + @JsonProperty("version") String version) implements Identifier { // @formatter:on public Implementation(String name, String version) { this(name, null, version); @@ -651,7 +712,13 @@ public interface Annotated { @JsonIgnoreProperties(ignoreUnknown = true) public record Annotations( // @formatter:off @JsonProperty("audience") List audience, - @JsonProperty("priority") Double priority) { // @formatter:on + @JsonProperty("priority") Double priority, + @JsonProperty("lastModified") String lastModified + ) { // @formatter:on + + public Annotations(List audience, Double priority) { + this(audience, priority, null); + } } /** @@ -660,7 +727,9 @@ public record Annotations( // @formatter:off * interface is implemented by both {@link Resource} and {@link ResourceLink} to * provide a consistent way to access resource metadata. */ - public interface ResourceContent extends BaseMetadata { + public interface ResourceContent extends Identifier, Annotated, Meta { + + // name & title from Identifier String uri(); @@ -670,15 +739,15 @@ public interface ResourceContent extends BaseMetadata { Long size(); - Annotations annotations(); + // annotations from Annotated + // meta from Meta } /** - * Base interface for metadata with name (identifier) and title (display name) - * properties. + * Base interface with name (identifier) and title (display name) properties. */ - public interface BaseMetadata { + public interface Identifier { /** * Intended for programmatic or logical use, but used as a display name in past @@ -724,36 +793,7 @@ public record Resource( // @formatter:off @JsonProperty("mimeType") String mimeType, @JsonProperty("size") Long size, @JsonProperty("annotations") Annotations annotations, - @JsonProperty("_meta") Map meta) implements Annotated, ResourceContent { // @formatter:on - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Resource#builder()} instead. - */ - @Deprecated - public Resource(String uri, String name, String title, String description, String mimeType, Long size, - Annotations annotations) { - this(uri, name, title, description, mimeType, size, annotations, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Resource#builder()} instead. - */ - @Deprecated - public Resource(String uri, String name, String description, String mimeType, Long size, - Annotations annotations) { - this(uri, name, null, description, mimeType, size, annotations, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Resource#builder()} instead. - */ - @Deprecated - public Resource(String uri, String name, String description, String mimeType, Annotations annotations) { - this(uri, name, null, description, mimeType, null, annotations, null); - } + @JsonProperty("_meta") Map meta) implements ResourceContent { // @formatter:on public static Builder builder() { return new Builder(); @@ -854,7 +894,7 @@ public record ResourceTemplate( // @formatter:off @JsonProperty("description") String description, @JsonProperty("mimeType") String mimeType, @JsonProperty("annotations") Annotations annotations, - @JsonProperty("_meta") Map meta) implements Annotated, BaseMetadata { // @formatter:on + @JsonProperty("_meta") Map meta) implements Annotated, Identifier, Meta { // @formatter:on public ResourceTemplate(String uriTemplate, String name, String title, String description, String mimeType, Annotations annotations) { @@ -865,6 +905,70 @@ public ResourceTemplate(String uriTemplate, String name, String description, Str Annotations annotations) { this(uriTemplate, name, null, description, mimeType, annotations); } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String uriTemplate; + + private String name; + + private String title; + + private String description; + + private String mimeType; + + private Annotations annotations; + + private Map meta; + + public Builder uriTemplate(String uri) { + this.uriTemplate = uri; + return this; + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder mimeType(String mimeType) { + this.mimeType = mimeType; + return this; + } + + public Builder annotations(Annotations annotations) { + this.annotations = annotations; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public ResourceTemplate build() { + Assert.hasText(uriTemplate, "uri must not be empty"); + Assert.hasText(name, "name must not be empty"); + + return new ResourceTemplate(uriTemplate, name, title, description, mimeType, annotations, meta); + } + + } } /** @@ -982,10 +1086,10 @@ public UnsubscribeRequest(String uri) { /** * The contents of a specific resource or sub-resource. */ - @JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION, include = As.PROPERTY) - @JsonSubTypes({ @JsonSubTypes.Type(value = TextResourceContents.class, name = "text"), - @JsonSubTypes.Type(value = BlobResourceContents.class, name = "blob") }) - public sealed interface ResourceContents permits TextResourceContents, BlobResourceContents { + @JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION) + @JsonSubTypes({ @JsonSubTypes.Type(value = TextResourceContents.class), + @JsonSubTypes.Type(value = BlobResourceContents.class) }) + public sealed interface ResourceContents extends Meta permits TextResourceContents, BlobResourceContents { /** * The URI of this resource. @@ -999,14 +1103,6 @@ public sealed interface ResourceContents permits TextResourceContents, BlobResou */ String mimeType(); - /** - * @see Specification - * for notes on _meta usage - * @return additional metadata related to this resource. - */ - Map meta(); - } /** @@ -1073,7 +1169,7 @@ public record Prompt( // @formatter:off @JsonProperty("title") String title, @JsonProperty("description") String description, @JsonProperty("arguments") List arguments, - @JsonProperty("_meta") Map meta) implements BaseMetadata { // @formatter:on + @JsonProperty("_meta") Map meta) implements Identifier { // @formatter:on public Prompt(String name, String description, List arguments) { this(name, null, description, arguments != null ? arguments : new ArrayList<>()); @@ -1098,7 +1194,7 @@ public record PromptArgument( // @formatter:off @JsonProperty("name") String name, @JsonProperty("title") String title, @JsonProperty("description") String description, - @JsonProperty("required") Boolean required) implements BaseMetadata { // @formatter:on + @JsonProperty("required") Boolean required) implements Identifier { // @formatter:on public PromptArgument(String name, String description, Boolean required) { this(name, null, description, required); @@ -1272,53 +1368,6 @@ public record Tool( // @formatter:off @JsonProperty("annotations") ToolAnnotations annotations, @JsonProperty("_meta") Map meta) { // @formatter:on - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Tool#builder()} instead. - */ - @Deprecated - public Tool(String name, String description, JsonSchema inputSchema, ToolAnnotations annotations) { - this(name, null, description, inputSchema, null, annotations, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Tool#builder()} instead. - */ - @Deprecated - public Tool(String name, String description, String inputSchema) { - this(name, null, description, parseSchema(inputSchema), null, null, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Tool#builder()} instead. - */ - @Deprecated - public Tool(String name, String description, String schema, ToolAnnotations annotations) { - this(name, null, description, parseSchema(schema), null, annotations, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Tool#builder()} instead. - */ - @Deprecated - public Tool(String name, String description, String inputSchema, String outputSchema, - ToolAnnotations annotations) { - this(name, null, description, parseSchema(inputSchema), schemaToMap(outputSchema), annotations, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link Tool#builder()} instead. - */ - @Deprecated - public Tool(String name, String title, String description, String inputSchema, String outputSchema, - ToolAnnotations annotations) { - this(name, title, description, parseSchema(inputSchema), schemaToMap(outputSchema), annotations, null); - } - public static Builder builder() { return new Builder(); } @@ -1359,8 +1408,8 @@ public Builder inputSchema(JsonSchema inputSchema) { return this; } - public Builder inputSchema(String inputSchema) { - this.inputSchema = parseSchema(inputSchema); + public Builder inputSchema(McpJsonMapper jsonMapper, String inputSchema) { + this.inputSchema = parseSchema(jsonMapper, inputSchema); return this; } @@ -1369,8 +1418,8 @@ public Builder outputSchema(Map outputSchema) { return this; } - public Builder outputSchema(String outputSchema) { - this.outputSchema = schemaToMap(outputSchema); + public Builder outputSchema(McpJsonMapper jsonMapper, String outputSchema) { + this.outputSchema = schemaToMap(jsonMapper, outputSchema); return this; } @@ -1392,18 +1441,18 @@ public Tool build() { } } - private static Map schemaToMap(String schema) { + private static Map schemaToMap(McpJsonMapper jsonMapper, String schema) { try { - return OBJECT_MAPPER.readValue(schema, MAP_TYPE_REF); + return jsonMapper.readValue(schema, MAP_TYPE_REF); } catch (IOException e) { throw new IllegalArgumentException("Invalid schema: " + schema, e); } } - private static JsonSchema parseSchema(String schema) { + private static JsonSchema parseSchema(McpJsonMapper jsonMapper, String schema) { try { - return OBJECT_MAPPER.readValue(schema, JsonSchema.class); + return jsonMapper.readValue(schema, JsonSchema.class); } catch (IOException e) { throw new IllegalArgumentException("Invalid schema: " + schema, e); @@ -1427,17 +1476,17 @@ public record CallToolRequest( // @formatter:off @JsonProperty("arguments") Map arguments, @JsonProperty("_meta") Map meta) implements Request { // @formatter:on - public CallToolRequest(String name, String jsonArguments) { - this(name, parseJsonArguments(jsonArguments), null); + public CallToolRequest(McpJsonMapper jsonMapper, String name, String jsonArguments) { + this(name, parseJsonArguments(jsonMapper, jsonArguments), null); } public CallToolRequest(String name, Map arguments) { this(name, arguments, null); } - private static Map parseJsonArguments(String jsonArguments) { + private static Map parseJsonArguments(McpJsonMapper jsonMapper, String jsonArguments) { try { - return OBJECT_MAPPER.readValue(jsonArguments, MAP_TYPE_REF); + return jsonMapper.readValue(jsonArguments, MAP_TYPE_REF); } catch (IOException e) { throw new IllegalArgumentException("Invalid arguments: " + jsonArguments, e); @@ -1466,8 +1515,8 @@ public Builder arguments(Map arguments) { return this; } - public Builder arguments(String jsonArguments) { - this.arguments = parseJsonArguments(jsonArguments); + public Builder arguments(McpJsonMapper jsonMapper, String jsonArguments) { + this.arguments = parseJsonArguments(jsonMapper, jsonArguments); return this; } @@ -1476,7 +1525,7 @@ public Builder meta(Map meta) { return this; } - public Builder progressToken(String progressToken) { + public Builder progressToken(Object progressToken) { if (this.meta == null) { this.meta = new HashMap<>(); } @@ -1508,32 +1557,9 @@ public CallToolRequest build() { public record CallToolResult( // @formatter:off @JsonProperty("content") List content, @JsonProperty("isError") Boolean isError, - @JsonProperty("structuredContent") Map structuredContent, + @JsonProperty("structuredContent") Object structuredContent, @JsonProperty("_meta") Map meta) implements Result { // @formatter:on - // backwards compatibility constructor - public CallToolResult(List content, Boolean isError) { - this(content, isError, null, null); - } - - // backwards compatibility constructor - public CallToolResult(List content, Boolean isError, Map structuredContent) { - this(content, isError, structuredContent, null); - } - - /** - * Creates a new instance of {@link CallToolResult} with a string containing the - * tool result. - * @param content The content of the tool result. This will be mapped to a - * one-sized list with a {@link TextContent} element. - * @param isError If true, indicates that the tool execution failed and the - * content contains error information. If false or absent, indicates successful - * execution. - */ - public CallToolResult(String content, Boolean isError) { - this(List.of(new TextContent(content)), isError, null); - } - /** * Creates a builder for {@link CallToolResult}. * @return a new builder instance @@ -1551,7 +1577,7 @@ public static class Builder { private Boolean isError = false; - private Map structuredContent; + private Object structuredContent; private Map meta; @@ -1566,16 +1592,16 @@ public Builder content(List content) { return this; } - public Builder structuredContent(Map structuredContent) { + public Builder structuredContent(Object structuredContent) { Assert.notNull(structuredContent, "structuredContent must not be null"); this.structuredContent = structuredContent; return this; } - public Builder structuredContent(String structuredContent) { + public Builder structuredContent(McpJsonMapper jsonMapper, String structuredContent) { Assert.hasText(structuredContent, "structuredContent must not be empty"); try { - this.structuredContent = OBJECT_MAPPER.readValue(structuredContent, MAP_TYPE_REF); + this.structuredContent = jsonMapper.readValue(structuredContent, MAP_TYPE_REF); } catch (IOException e) { throw new IllegalArgumentException("Invalid structured content: " + structuredContent, e); @@ -1790,14 +1816,14 @@ public record CreateMessageRequest( // @formatter:off @JsonProperty("systemPrompt") String systemPrompt, @JsonProperty("includeContext") ContextInclusionStrategy includeContext, @JsonProperty("temperature") Double temperature, - @JsonProperty("maxTokens") int maxTokens, + @JsonProperty("maxTokens") Integer maxTokens, @JsonProperty("stopSequences") List stopSequences, @JsonProperty("metadata") Map metadata, @JsonProperty("_meta") Map meta) implements Request { // @formatter:on // backwards compatibility constructor public CreateMessageRequest(List messages, ModelPreferences modelPreferences, - String systemPrompt, ContextInclusionStrategy includeContext, Double temperature, int maxTokens, + String systemPrompt, ContextInclusionStrategy includeContext, Double temperature, Integer maxTokens, List stopSequences, Map metadata) { this(messages, modelPreferences, systemPrompt, includeContext, temperature, maxTokens, stopSequences, metadata, null); @@ -1827,7 +1853,7 @@ public static class Builder { private Double temperature; - private int maxTokens; + private Integer maxTokens; private List stopSequences; @@ -1880,7 +1906,7 @@ public Builder meta(Map meta) { return this; } - public Builder progressToken(String progressToken) { + public Builder progressToken(Object progressToken) { if (this.meta == null) { this.meta = new HashMap<>(); } @@ -2048,7 +2074,7 @@ public Builder meta(Map meta) { return this; } - public Builder progressToken(String progressToken) { + public Builder progressToken(Object progressToken) { if (this.meta == null) { this.meta = new HashMap<>(); } @@ -2185,13 +2211,13 @@ public record PaginatedResult(@JsonProperty("nextCursor") String nextCursor) { @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record ProgressNotification( // @formatter:off - @JsonProperty("progressToken") String progressToken, + @JsonProperty("progressToken") Object progressToken, @JsonProperty("progress") Double progress, @JsonProperty("total") Double total, @JsonProperty("message") String message, @JsonProperty("_meta") Map meta) implements Notification { // @formatter:on - public ProgressNotification(String progressToken, double progress, Double total, String message) { + public ProgressNotification(Object progressToken, double progress, Double total, String message) { this(progressToken, progress, total, message, null); } } @@ -2203,6 +2229,7 @@ public ProgressNotification(String progressToken, double progress, Double total, * @param uri The updated resource uri. * @param meta See specification for notes on _meta usage */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record ResourcesUpdatedNotification(// @formatter:off @JsonProperty("uri") String uri, @@ -2224,6 +2251,7 @@ public ResourcesUpdatedNotification(String uri) { * @param data JSON-serializable logging data. * @param meta See specification for notes on _meta usage */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record LoggingMessageNotification( // @formatter:off @JsonProperty("level") LoggingLevel level, @@ -2337,14 +2365,16 @@ public sealed interface CompleteReference permits PromptReference, ResourceRefer public record PromptReference( // @formatter:off @JsonProperty("type") String type, @JsonProperty("name") String name, - @JsonProperty("title") String title ) implements McpSchema.CompleteReference, BaseMetadata { // @formatter:on + @JsonProperty("title") String title ) implements McpSchema.CompleteReference, Identifier { // @formatter:on + + public static final String TYPE = "ref/prompt"; public PromptReference(String type, String name) { this(type, name, null); } public PromptReference(String name) { - this("ref/prompt", name, null); + this(TYPE, name, null); } @Override @@ -2381,8 +2411,10 @@ public record ResourceReference( // @formatter:off @JsonProperty("type") String type, @JsonProperty("uri") String uri) implements McpSchema.CompleteReference { // @formatter:on + public static final String TYPE = "ref/resource"; + public ResourceReference(String uri) { - this("ref/resource", uri); + this(TYPE, uri); } @Override @@ -2445,8 +2477,9 @@ public record CompleteContext(@JsonProperty("arguments") Map arg */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - public record CompleteResult(@JsonProperty("completion") CompleteCompletion completion, - @JsonProperty("_meta") Map meta) implements Result { + public record CompleteResult(// @formatter:off + @JsonProperty("completion") CompleteCompletion completion, + @JsonProperty("_meta") Map meta) implements Result { // @formatter:on // backwards compatibility constructor public CompleteResult(CompleteCompletion completion) { @@ -2462,6 +2495,7 @@ public CompleteResult(CompleteCompletion completion) { * @param hasMore Indicates whether there are additional completion options beyond * those provided in the current response, even if the exact total is unknown */ + @JsonInclude(JsonInclude.Include.ALWAYS) public record CompleteCompletion( // @formatter:off @JsonProperty("values") List values, @JsonProperty("total") Integer total, @@ -2478,9 +2512,8 @@ public record CompleteCompletion( // @formatter:off @JsonSubTypes.Type(value = AudioContent.class, name = "audio"), @JsonSubTypes.Type(value = EmbeddedResource.class, name = "resource"), @JsonSubTypes.Type(value = ResourceLink.class, name = "resource_link") }) - public sealed interface Content permits TextContent, ImageContent, AudioContent, EmbeddedResource, ResourceLink { - - Map meta(); + public sealed interface Content extends Meta + permits TextContent, ImageContent, AudioContent, EmbeddedResource, ResourceLink { default String type() { if (this instanceof TextContent) { @@ -2524,33 +2557,6 @@ public TextContent(Annotations annotations, String text) { public TextContent(String content) { this(null, content, null); } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link TextContent#TextContent(Annotations, String)} instead. - */ - @Deprecated - public TextContent(List audience, Double priority, String content) { - this(audience != null || priority != null ? new Annotations(audience, priority) : null, content, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link TextContent#annotations()} instead. - */ - @Deprecated - public List audience() { - return annotations == null ? null : annotations.audience(); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link TextContent#annotations()} instead. - */ - @Deprecated - public Double priority() { - return annotations == null ? null : annotations.priority(); - } } /** @@ -2573,34 +2579,6 @@ public record ImageContent( // @formatter:off public ImageContent(Annotations annotations, String data, String mimeType) { this(annotations, data, mimeType, null); } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link ImageContent#ImageContent(Annotations, String, String)} instead. - */ - @Deprecated - public ImageContent(List audience, Double priority, String data, String mimeType) { - this(audience != null || priority != null ? new Annotations(audience, priority) : null, data, mimeType, - null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link ImageContent#annotations()} instead. - */ - @Deprecated - public List audience() { - return annotations == null ? null : annotations.audience(); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link ImageContent#annotations()} instead. - */ - @Deprecated - public Double priority() { - return annotations == null ? null : annotations.priority(); - } } /** @@ -2647,34 +2625,6 @@ public record EmbeddedResource( // @formatter:off public EmbeddedResource(Annotations annotations, ResourceContents resource) { this(annotations, resource, null); } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link EmbeddedResource#EmbeddedResource(Annotations, ResourceContents)} - * instead. - */ - @Deprecated - public EmbeddedResource(List audience, Double priority, ResourceContents resource) { - this(audience != null || priority != null ? new Annotations(audience, priority) : null, resource, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link EmbeddedResource#annotations()} instead. - */ - @Deprecated - public List audience() { - return annotations == null ? null : annotations.audience(); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link EmbeddedResource#annotations()} instead. - */ - @Deprecated - public Double priority() { - return annotations == null ? null : annotations.priority(); - } } /** @@ -2705,29 +2655,7 @@ public record ResourceLink( // @formatter:off @JsonProperty("mimeType") String mimeType, @JsonProperty("size") Long size, @JsonProperty("annotations") Annotations annotations, - @JsonProperty("_meta") Map meta) implements Annotated, Content, ResourceContent { // @formatter:on - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link ResourceLink#ResourceLink(String, String, String, String, String, Long, Annotations)} - * instead. - */ - @Deprecated - public ResourceLink(String name, String title, String uri, String description, String mimeType, Long size, - Annotations annotations) { - this(name, title, uri, description, mimeType, size, annotations, null); - } - - /** - * @deprecated Only exists for backwards-compatibility purposes. Use - * {@link ResourceLink#ResourceLink(String, String, String, String, String, Long, Annotations)} - * instead. - */ - @Deprecated - public ResourceLink(String name, String uri, String description, String mimeType, Long size, - Annotations annotations) { - this(name, null, uri, description, mimeType, size, annotations); - } + @JsonProperty("_meta") Map meta) implements Content, ResourceContent { // @formatter:on public static Builder builder() { return new Builder(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java similarity index 76% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index e562ca012..ecb1dafd8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -11,12 +11,12 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.server.McpInitRequestHandler; import io.modelcontextprotocol.server.McpNotificationHandler; import io.modelcontextprotocol.server.McpRequestHandler; -import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -86,34 +86,6 @@ public McpServerSession(String id, Duration requestTimeout, McpServerTransport t this.notificationHandlers = notificationHandlers; } - /** - * Creates a new server session with the given parameters and the transport to use. - * @param id session id - * @param transport the transport to use - * @param initHandler called when a - * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the - * server - * @param initNotificationHandler called when a - * {@link io.modelcontextprotocol.spec.McpSchema#METHOD_NOTIFICATION_INITIALIZED} is - * received. - * @param requestHandlers map of request handlers to use - * @param notificationHandlers map of notification handlers to use - * @deprecated Use - * {@link #McpServerSession(String, Duration, McpServerTransport, McpInitRequestHandler, Map, Map)} - */ - @Deprecated - public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, - McpInitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, - Map> requestHandlers, - Map notificationHandlers) { - this.id = id; - this.requestTimeout = requestTimeout; - this.transport = transport; - this.initRequestHandler = initHandler; - this.requestHandlers = requestHandlers; - this.notificationHandlers = notificationHandlers; - } - /** * Retrieve the session id. * @return session id @@ -153,7 +125,7 @@ public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel } @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { String requestId = this.generateRequestId(); return Mono.create(sink -> { @@ -204,22 +176,32 @@ public Mono handle(McpSchema.JSONRPCMessage message) { // TODO handle errors for communication to without initialization happening // first if (message instanceof McpSchema.JSONRPCResponse response) { - logger.debug("Received Response: {}", response); - var sink = pendingResponses.remove(response.id()); - if (sink == null) { - logger.warn("Unexpected response for unknown id {}", response.id()); + logger.debug("Received response: {}", response); + if (response.id() != null) { + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); + } + else { + sink.success(response); + } } else { - sink.success(response); + logger.error("Discarded MCP request response without session id. " + + "This is an indication of a bug in the request sender code that can lead to memory " + + "leaks as pending requests will never be completed."); } return Mono.empty(); } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); return handleIncomingRequest(request, transportContext).onErrorResume(error -> { + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (error instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), McpError.aggregateExceptionMessages(error)); var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)); + jsonRpcError); // TODO: Should the error go to SSE or back as POST return? return this.transport.sendMessage(errorResponse).then(Mono.empty()); }).flatMap(this.transport::sendMessage); @@ -252,7 +234,7 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { // TODO handle situation where already initialized! McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(), - new TypeReference() { + new TypeRef() { }); this.state.lazySet(STATE_INITIALIZING); @@ -275,10 +257,15 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)))); // TODO: add error message - // through the data field + .onErrorResume(error -> { + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (error instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + // TODO: add error message through the data field + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), McpError.aggregateExceptionMessages(error)); + return Mono.just( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, jsonRpcError)); + }); }); } @@ -340,23 +327,6 @@ public void close() { this.transport.close(); } - /** - * Request handler for the initialization request. - * - * @deprecated Use {@link McpInitRequestHandler} - */ - @Deprecated - public interface InitRequestHandler { - - /** - * Handles the initialization request. - * @param initializeRequest the initialization request by the client - * @return a Mono that will emit the result of the initialization - */ - Mono handle(McpSchema.InitializeRequest initializeRequest); - - } - /** * Notification handler for the initialization notification from the client. */ @@ -370,46 +340,6 @@ public interface InitNotificationHandler { } - /** - * A handler for client-initiated notifications. - * - * @deprecated Use {@link McpNotificationHandler} - */ - @Deprecated - public interface NotificationHandler { - - /** - * Handles a notification from the client. - * @param exchange the exchange associated with the client that allows calling - * back to the connected client or inspecting its capabilities. - * @param params the parameters of the notification. - * @return a Mono that completes once the notification is handled. - */ - Mono handle(McpAsyncServerExchange exchange, Object params); - - } - - /** - * A handler for client-initiated requests. - * - * @param the type of the response that is expected as a result of handling the - * request. - * @deprecated Use {@link McpRequestHandler} - */ - @Deprecated - public interface RequestHandler { - - /** - * Handles a request from the client. - * @param exchange the exchange associated with the client that allows calling - * back to the connected client or inspecting its capabilities. - * @param params the parameters of the request. - * @return a Mono that will emit the response to the request. - */ - Mono handle(McpAsyncServerExchange exchange, Object params); - - } - /** * Factory for creating server sessions which delegate to a provided 1:1 transport * with a connected client. diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java similarity index 97% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java index 3473a4da8..767ed673e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -4,7 +4,7 @@ package io.modelcontextprotocol.spec; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import reactor.core.publisher.Mono; /** @@ -37,7 +37,7 @@ public interface McpSession { * @param typeRef the TypeReference describing the expected response type * @return a Mono that will emit the response when received */ - Mono sendRequest(String method, Object requestParams, TypeReference typeRef); + Mono sendRequest(String method, Object requestParams, TypeRef typeRef); /** * Sends a notification to the model client or server without parameters. diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java similarity index 87% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java index c1234b130..ee28f5ff8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java @@ -29,7 +29,8 @@ default void close() { Mono closeGracefully(); default List protocolVersions() { - return List.of(ProtocolVersions.MCP_2025_03_26); + return List.of(ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18, + ProtocolVersions.MCP_2025_11_25); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java similarity index 90% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index ef7967c1e..95f8959f5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -15,12 +15,13 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.server.McpNotificationHandler; import io.modelcontextprotocol.server.McpRequestHandler; -import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -33,6 +34,7 @@ * capability without the insight into the transport-specific details of HTTP handling. * * @author Dariusz JΔ™drzejczyk + * @author Yanming Zhou */ public class McpStreamableServerSession implements McpLoggableSession { @@ -108,7 +110,7 @@ private String generateRequestId() { } @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { return Mono.defer(() -> { McpLoggableSession listeningStream = this.listeningStreamRef.get(); return listeningStream.sendRequest(method, requestParams, typeRef); @@ -177,9 +179,13 @@ public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStr .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, null)) .onErrorResume(e -> { + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = (e instanceof McpError mcpError + && mcpError.getJsonRpcError() != null) ? mcpError.getJsonRpcError() + : new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + e.getMessage(), McpError.aggregateExceptionMessages(e)); + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - e.getMessage(), null)); + null, jsonRpcError); return Mono.just(errorResponse); }) .flatMap(transport::sendMessage) @@ -214,19 +220,30 @@ public Mono accept(McpSchema.JSONRPCNotification notification) { */ public Mono accept(McpSchema.JSONRPCResponse response) { return Mono.defer(() -> { - var stream = this.requestIdToStream.get(response.id()); - if (stream == null) { - return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO - // JSONize - } - // TODO: encapsulate this inside the stream itself - var sink = stream.pendingResponses.remove(response.id()); - if (sink == null) { - return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO - // JSONize + logger.debug("Received response: {}", response); + + if (response.id() != null) { + var stream = this.requestIdToStream.get(response.id()); + if (stream == null) { + return Mono.error(McpError.builder(ErrorCodes.INTERNAL_ERROR) + .message("Unexpected response for unknown id " + response.id()) + .build()); + } + // TODO: encapsulate this inside the stream itself + var sink = stream.pendingResponses.remove(response.id()); + if (sink == null) { + return Mono.error(McpError.builder(ErrorCodes.INTERNAL_ERROR) + .message("Unexpected response for unknown id " + response.id()) + .build()); + } + else { + sink.success(response); + } } else { - sink.success(response); + logger.error("Discarded MCP request response without session id. " + + "This is an indication of a bug in the request sender code that can lead to memory " + + "leaks as pending requests will never be completed."); } return Mono.empty(); }); @@ -334,7 +351,7 @@ public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel } @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { String requestId = McpStreamableServerSession.this.generateRequestId(); McpStreamableServerSession.this.requestIdToStream.put(requestId, this); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransport.java similarity index 95% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransport.java index 1922548a6..0a732bab6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransport.java @@ -6,8 +6,8 @@ import java.util.List; -import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.json.TypeRef; import reactor.core.publisher.Mono; /** @@ -77,7 +77,7 @@ default void close() { * @param typeRef the type reference for the object to unmarshal * @return the unmarshalled object */ - T unmarshalFrom(Object data, TypeReference typeRef); + T unmarshalFrom(Object data, TypeRef typeRef); default List protocolVersions() { return List.of(ProtocolVersions.MCP_2024_11_05); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportException.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java index 716ff0d16..68f0fc5bb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java @@ -4,10 +4,10 @@ package io.modelcontextprotocol.spec; -import org.reactivestreams.Publisher; - import java.util.Optional; +import org.reactivestreams.Publisher; + /** * An abstraction of the session as perceived from the MCP transport layer. Not to be * confused with the {@link McpSession} type that operates at the level of the JSON-RPC diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionClosedException.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionClosedException.java new file mode 100644 index 000000000..60e2850b9 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionClosedException.java @@ -0,0 +1,23 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import reactor.util.annotation.Nullable; + +/** + * Exception thrown when trying to use an {@link McpTransportSession} that has been + * closed. + * + * @see ClosedMcpTransportSession + * @author Daniel Garnier-Moiroux + */ +public class McpTransportSessionClosedException extends RuntimeException { + + public McpTransportSessionClosedException(@Nullable String sessionId) { + super(sessionId != null ? "MCP session with ID %s has been closed".formatted(sessionId) + : "MCP session has been closed"); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java similarity index 95% rename from mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java index aa33a8167..0bf70d5b8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java @@ -4,7 +4,7 @@ package io.modelcontextprotocol.spec; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -31,7 +31,7 @@ public MissingMcpTransportSession(String sessionId) { } @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { return Mono.error(new IllegalStateException("Stream unavailable for session " + this.sessionId)); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java similarity index 77% rename from mcp/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java rename to mcp-core/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java index d8cb913a5..d3d34db62 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java @@ -20,4 +20,10 @@ public interface ProtocolVersions { */ String MCP_2025_06_18 = "2025-06-18"; + /** + * MCP protocol version for 2025-11-25. + * https://modelcontextprotocol.io/specification/2025-11-25 + */ + String MCP_2025_11_25 = "2025-11-25"; + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Assert.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/Assert.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/util/Assert.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/Assert.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java similarity index 81% rename from mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java index b2e9a5285..c3b922edf 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManager.java @@ -33,9 +33,7 @@ public class DefaultMcpUriTemplateManager implements McpUriTemplateManager { * @param uriTemplate The URI template to be used for variable extraction */ public DefaultMcpUriTemplateManager(String uriTemplate) { - if (uriTemplate == null || uriTemplate.isEmpty()) { - throw new IllegalArgumentException("URI template must not be null or empty"); - } + Assert.hasText(uriTemplate, "URI template must not be null or empty"); this.uriTemplate = uriTemplate; } @@ -48,10 +46,6 @@ public DefaultMcpUriTemplateManager(String uriTemplate) { */ @Override public List getVariableNames() { - if (uriTemplate == null || uriTemplate.isEmpty()) { - return List.of(); - } - List variables = new ArrayList<>(); Matcher matcher = URI_VARIABLE_PATTERN.matcher(this.uriTemplate); @@ -81,7 +75,7 @@ public Map extractVariableValues(String requestUri) { Map variableValues = new HashMap<>(); List uriVariables = this.getVariableNames(); - if (requestUri == null || uriVariables.isEmpty()) { + if (!Utils.hasText(requestUri) || uriVariables.isEmpty()) { return variableValues; } @@ -147,12 +141,30 @@ public boolean matches(String uri) { return uri.equals(this.uriTemplate); } - // Convert the pattern to a regex - String regex = this.uriTemplate.replaceAll("\\{[^/]+?\\}", "([^/]+?)"); - regex = regex.replace("/", "\\/"); + // Convert the URI template into a robust regex pattern that escapes special + // characters like '?'. + StringBuilder patternBuilder = new StringBuilder("^"); + Matcher variableMatcher = URI_VARIABLE_PATTERN.matcher(this.uriTemplate); + int lastEnd = 0; + + while (variableMatcher.find()) { + // Append the literal part of the template, safely quoted + String textBefore = this.uriTemplate.substring(lastEnd, variableMatcher.start()); + patternBuilder.append(Pattern.quote(textBefore)); + // Append a capturing group for the variable itself + patternBuilder.append("([^/]+?)"); + lastEnd = variableMatcher.end(); + } + + // Append any remaining literal text after the last variable + if (lastEnd < this.uriTemplate.length()) { + patternBuilder.append(Pattern.quote(this.uriTemplate.substring(lastEnd))); + } + + patternBuilder.append("$"); // Check if the URI matches the regex - return Pattern.compile(regex).matcher(uri).matches(); + return Pattern.compile(patternBuilder.toString()).matcher(uri).matches(); } @Override diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManagerFactory.java similarity index 86% rename from mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManagerFactory.java index 44ea31690..fd1a3bd71 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/DefaultMcpUriTemplateManagerFactory.java @@ -7,7 +7,7 @@ /** * @author Christian Tzolov */ -public class DeafaultMcpUriTemplateManagerFactory implements McpUriTemplateManagerFactory { +public class DefaultMcpUriTemplateManagerFactory implements McpUriTemplateManagerFactory { /** * Creates a new instance of {@link McpUriTemplateManager} with the specified URI diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java similarity index 97% rename from mcp/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java index 9d411cd41..6d53ed516 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java @@ -11,7 +11,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSession; @@ -33,7 +33,7 @@ public class KeepAliveScheduler { private static final Logger logger = LoggerFactory.getLogger(KeepAliveScheduler.class); - private static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + private static final TypeRef OBJECT_TYPE_REF = new TypeRef<>() { }; /** Initial delay before the first keepAlive call */ diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/util/McpServiceLoader.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/McpServiceLoader.java new file mode 100644 index 000000000..f1c73a07a --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/McpServiceLoader.java @@ -0,0 +1,68 @@ +/** + * Copyright 2026 - 2026 the original author or authors. + */ +package io.modelcontextprotocol.util; + +import java.util.Optional; +import java.util.ServiceConfigurationError; +import java.util.ServiceLoader; +import java.util.function.Supplier; + +/** + * Instance of this class are intended to be used differently in OSGi and non-OSGi + * environments. In all non-OSGi environments the supplier member will be + * null and the serviceLoad method will be called to use the + * ServiceLoader.load to find the first instance of the supplier (assuming one is present + * in the runtime), cache it, and call the supplier's get method. + *

+ * In OSGi environments, the Service component runtime (scr) will call the setSupplier + * method upon bundle activation (assuming one is present in the runtime), and subsequent + * calls will use the given supplier instance rather than the ServiceLoader.load. + * + * @param the type of the supplier + * @param the type of the supplier result/returned value + */ +public class McpServiceLoader, R> { + + private Class supplierType; + + private S supplier; + + private R supplierResult; + + public void setSupplier(S supplier) { + this.supplier = supplier; + this.supplierResult = null; + } + + public void unsetSupplier(S supplier) { + this.supplier = null; + this.supplierResult = null; + } + + public McpServiceLoader(Class supplierType) { + this.supplierType = supplierType; + } + + protected Optional serviceLoad(Class type) { + return ServiceLoader.load(type).findFirst(); + } + + @SuppressWarnings("unchecked") + public synchronized R getDefault() { + if (this.supplierResult == null) { + if (this.supplier == null) { + // Use serviceloader + Optional sl = serviceLoad(this.supplierType); + if (sl.isEmpty()) { + throw new ServiceConfigurationError( + "No %s available for creating McpJsonMapper".formatted(this.supplierType.getSimpleName())); + } + this.supplier = (S) sl.get(); + } + this.supplierResult = this.supplier.get(); + } + return supplierResult; + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManager.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/util/ToolNameValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/ToolNameValidator.java new file mode 100644 index 000000000..d7ac18705 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/ToolNameValidator.java @@ -0,0 +1,83 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.regex.Pattern; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Validates tool names according to the MCP specification. + * + *

+ * Tool names must conform to the following rules: + *

    + *
  • Must be between 1 and 128 characters in length
  • + *
  • May only contain: A-Z, a-z, 0-9, underscore (_), hyphen (-), and dot (.)
  • + *
  • Must not contain spaces, commas, or other special characters
  • + *
+ * + * @see MCP + * Specification - Tool Names + * @author Andrei Shakirin + */ +public final class ToolNameValidator { + + private static final Logger logger = LoggerFactory.getLogger(ToolNameValidator.class); + + private static final int MAX_LENGTH = 128; + + private static final Pattern VALID_NAME_PATTERN = Pattern.compile("^[A-Za-z0-9_\\-.]+$"); + + /** + * System property for strict tool name validation. Set to "false" to warn only + * instead of throwing exceptions. Default is true (strict). + */ + public static final String STRICT_VALIDATION_PROPERTY = "io.modelcontextprotocol.strictToolNameValidation"; + + private ToolNameValidator() { + } + + /** + * Returns the default strict validation setting from system property. + * @return true if strict validation is enabled (default), false if disabled via + * system property + */ + public static boolean isStrictByDefault() { + return !"false".equalsIgnoreCase(System.getProperty(STRICT_VALIDATION_PROPERTY)); + } + + /** + * Validates a tool name according to MCP specification. + * @param name the tool name to validate + * @param strict if true, throws exception on invalid name; if false, logs warning + * only + * @throws IllegalArgumentException if validation fails and strict is true + */ + public static void validate(String name, boolean strict) { + if (name == null || name.isEmpty()) { + handleError("Tool name must not be null or empty", name, strict); + } + else if (name.length() > MAX_LENGTH) { + handleError("Tool name must not exceed 128 characters", name, strict); + } + else if (!VALID_NAME_PATTERN.matcher(name).matches()) { + handleError("Tool name contains invalid characters (allowed: A-Z, a-z, 0-9, _, -, .)", name, strict); + } + } + + private static void handleError(String message, String name, boolean strict) { + String fullMessage = message + ": '" + name + "'"; + if (strict) { + throw new IllegalArgumentException(fullMessage); + } + else { + logger.warn("{}. Processing continues, but tool name should be fixed.", fullMessage); + } + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/Utils.java similarity index 100% rename from mcp/src/main/java/io/modelcontextprotocol/util/Utils.java rename to mcp-core/src/main/java/io/modelcontextprotocol/util/Utils.java index 039b0d68e..cd420100c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -4,12 +4,12 @@ package io.modelcontextprotocol.util; -import reactor.util.annotation.Nullable; - import java.net.URI; import java.util.Collection; import java.util.Map; +import reactor.util.annotation.Nullable; + /** * Miscellaneous utility methods. * diff --git a/mcp-core/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.McpJsonDefaults.xml b/mcp-core/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.McpJsonDefaults.xml new file mode 100644 index 000000000..1a10fdfb3 --- /dev/null +++ b/mcp-core/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.McpJsonDefaults.xml @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java similarity index 86% rename from mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java index 6f041daa6..8f68f0d6e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/McpUriTemplateManagerTests.java @@ -12,7 +12,7 @@ import java.util.List; import java.util.Map; -import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManager; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import org.junit.jupiter.api.BeforeEach; @@ -29,7 +29,7 @@ public class McpUriTemplateManagerTests { @BeforeEach void setUp() { - this.uriTemplateFactory = new DeafaultMcpUriTemplateManagerFactory(); + this.uriTemplateFactory = new DefaultMcpUriTemplateManagerFactory(); } @Test @@ -94,4 +94,13 @@ void shouldMatchUriAgainstTemplatePattern() { assertFalse(uriTemplateManager.matches("/api/users/123/comments/456")); } + @Test + void shouldMatchUriWithQueryParameters() { + String templateWithQuery = "file://name/search?={search}"; + var uriTemplateManager = this.uriTemplateFactory.create(templateWithQuery); + + assertTrue(uriTemplateManager.matches("file://name/search?=abcd"), + "Should correctly match a URI containing query parameters."); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java similarity index 90% rename from mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java rename to mcp-core/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java index b1113a6d0..061a95e69 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java @@ -9,10 +9,10 @@ import java.util.function.BiConsumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; import reactor.core.publisher.Mono; @@ -29,7 +29,7 @@ public class MockMcpClientTransport implements McpClientTransport { private final BiConsumer interceptor; - private String protocolVersion = McpSchema.LATEST_PROTOCOL_VERSION; + private String protocolVersion = ProtocolVersions.MCP_2025_11_25; public MockMcpClientTransport() { this((t, msg) -> { @@ -99,8 +99,8 @@ public Mono closeGracefully() { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return new ObjectMapper().convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return (T) data; } } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java new file mode 100644 index 000000000..6f7390f19 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java @@ -0,0 +1,280 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.util.context.ContextView; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link LifecycleInitializer} postInitializationHook functionality. + * + * @author Christian Tzolov + */ +class LifecycleInitializerPostInitializationHookTests { + + private static final Duration INITIALIZATION_TIMEOUT = Duration.ofSeconds(5); + + private static final McpSchema.ClientCapabilities CLIENT_CAPABILITIES = McpSchema.ClientCapabilities.builder() + .build(); + + private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0"); + + private static final List PROTOCOL_VERSIONS = List.of("1.0.0", "2.0.0"); + + private static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult("2.0.0", + McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), + "Test instructions"); + + @Mock + private McpClientSession mockClientSession; + + @Mock + private Function mockSessionSupplier; + + @Mock + private Function> mockPostInitializationHook; + + private LifecycleInitializer initializer; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenReturn(Mono.empty()); + when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockClientSession); + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.just(MOCK_INIT_RESULT)); + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.empty()); + when(mockClientSession.closeGracefully()).thenReturn(Mono.empty()); + + initializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook); + } + + @Test + void shouldInvokePostInitializationHook() { + AtomicReference capturedInit = new AtomicReference<>(); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + capturedInit.set(invocation.getArgument(0)); + return Mono.empty(); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify hook was called + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + + // Verify the hook received correct initialization data + assertThat(capturedInit.get()).isNotNull(); + assertThat(capturedInit.get().mcpSession()).isEqualTo(mockClientSession); + assertThat(capturedInit.get().initializeResult()).isEqualTo(MOCK_INIT_RESULT); + } + + @Test + void shouldInvokePostInitializationHookOnlyOnce() { + // First initialization + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + // Second call should reuse initialization and NOT call hook again + StepVerifier.create(initializer.withInitialization("test2", init -> Mono.just("result2"))) + .expectNext("result2") + .verifyComplete(); + + // Hook should only be called once + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + } + + @Test + void shouldInvokePostInitializationHookOnlyOnceWithConcurrentRequests() { + AtomicInteger hookInvocationCount = new AtomicInteger(0); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + hookInvocationCount.incrementAndGet(); + return Mono.empty(); + }); + + // Start multiple concurrent initializations + Mono init1 = initializer.withInitialization("test1", init -> Mono.just("result1")) + .subscribeOn(Schedulers.parallel()); + Mono init2 = initializer.withInitialization("test2", init -> Mono.just("result2")) + .subscribeOn(Schedulers.parallel()); + Mono init3 = initializer.withInitialization("test3", init -> Mono.just("result3")) + .subscribeOn(Schedulers.parallel()); + + // TODO: can we assume the order of results? + StepVerifier.create(Mono.zip(init1, init2, init3)).assertNext(tuple -> { + assertThat(tuple.getT1()).isEqualTo("result1"); + assertThat(tuple.getT2()).isEqualTo("result2"); + assertThat(tuple.getT3()).isEqualTo("result3"); + }).verifyComplete(); + + // Hook should only be called once despite concurrent requests + assertThat(hookInvocationCount.get()).isEqualTo(1); + } + + @Test + void shouldFailInitializationWhenPostInitializationHookFails() { + RuntimeException hookError = new RuntimeException("Post-initialization hook failed"); + when(mockPostInitializationHook.apply(any(Initialization.class))).thenReturn(Mono.error(hookError)); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectErrorMatches(ex -> ex instanceof RuntimeException && ex.getCause() == hookError) + .verify(); + + // Verify initialization was not completed + assertThat(initializer.isInitialized()).isFalse(); + assertThat(initializer.currentInitializationResult()).isNull(); + + // Verify the hook was called + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + } + + @Test + void shouldNotInvokePostInitializationHookWhenInitializationFails() { + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.error(new RuntimeException("Initialization failed"))); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + // Hook should NOT be called when initialization fails + verify(mockPostInitializationHook, never()).apply(any(Initialization.class)); + } + + @Test + void shouldNotInvokePostInitializationHookWhenNotificationFails() { + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.error(new RuntimeException("Notification failed"))); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + // Hook should NOT be called when notification fails + verify(mockPostInitializationHook, never()).apply(any(Initialization.class)); + } + + @Test + void shouldInvokePostInitializationHookAgainAfterReinitialization() { + AtomicInteger hookInvocationCount = new AtomicInteger(0); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + hookInvocationCount.incrementAndGet(); + return Mono.empty(); + }); + + // First initialization + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + assertThat(hookInvocationCount.get()).isEqualTo(1); + + // Simulate transport session exception to trigger re-initialization + initializer.handleException(new McpTransportSessionNotFoundException("Session lost")); + + // Hook should be called twice (once for each initialization) + assertThat(hookInvocationCount.get()).isEqualTo(2); + } + + @Test + void shouldAllowPostInitializationHookToPerformAsyncOperations() { + AtomicInteger operationCount = new AtomicInteger(0); + + when(mockPostInitializationHook.apply(any(Initialization.class))) + .thenReturn(Mono.fromRunnable(() -> operationCount.incrementAndGet()).then()); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify the async operation was executed + assertThat(operationCount.get()).isEqualTo(1); + verify(mockPostInitializationHook, times(1)).apply(any(Initialization.class)); + } + + @Test + void shouldProvideCorrectInitializationDataToHook() { + AtomicReference capturedSession = new AtomicReference<>(); + AtomicReference capturedResult = new AtomicReference<>(); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + Initialization init = invocation.getArgument(0); + capturedSession.set(init.mcpSession()); + capturedResult.set(init.initializeResult()); + return Mono.empty(); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify the hook received the correct session and result + assertThat(capturedSession.get()).isEqualTo(mockClientSession); + assertThat(capturedResult.get()).isEqualTo(MOCK_INIT_RESULT); + assertThat(capturedResult.get().protocolVersion()).isEqualTo("2.0.0"); + assertThat(capturedResult.get().serverInfo().name()).isEqualTo("test-server"); + } + + @Test + void shouldInvokePostInitializationHookAfterSuccessfulInitialization() { + AtomicReference notificationSent = new AtomicReference<>(false); + AtomicReference hookCalledAfterNotification = new AtomicReference<>(false); + + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenAnswer(invocation -> { + notificationSent.set(true); + return Mono.empty(); + }); + + when(mockPostInitializationHook.apply(any(Initialization.class))).thenAnswer(invocation -> { + // Due to flatMap chaining in doInitialize, if the hook is called, + // the notification must have been sent first + hookCalledAfterNotification.set(notificationSent.get()); + return Mono.empty(); + }); + + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Verify the hook was called and notification was already sent at that point + assertThat(hookCalledAfterNotification.get()).isTrue(); + verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()); + verify(mockPostInitializationHook).apply(any(Initialization.class)); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java similarity index 80% rename from mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java index 02021edbf..787ee9480 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java @@ -10,14 +10,14 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; - -import io.modelcontextprotocol.spec.McpClientSession; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; @@ -58,12 +58,16 @@ class LifecycleInitializerTests { @Mock private Function mockSessionSupplier; + @Mock + private Function> mockPostInitializationHook; + private LifecycleInitializer initializer; @BeforeEach void setUp() { MockitoAnnotations.openMocks(this); + when(mockPostInitializationHook.apply(any(Initialization.class))).thenReturn(Mono.empty()); when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockClientSession); when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) .thenReturn(Mono.just(MOCK_INIT_RESULT)); @@ -72,45 +76,45 @@ void setUp() { when(mockClientSession.closeGracefully()).thenReturn(Mono.empty()); initializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, - INITIALIZATION_TIMEOUT, mockSessionSupplier); + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook); } @Test void constructorShouldValidateParameters() { assertThatThrownBy(() -> new LifecycleInitializer(null, CLIENT_INFO, PROTOCOL_VERSIONS, INITIALIZATION_TIMEOUT, - mockSessionSupplier)) + mockSessionSupplier, mockPostInitializationHook)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Client capabilities must not be null"); assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, null, PROTOCOL_VERSIONS, - INITIALIZATION_TIMEOUT, mockSessionSupplier)) + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Client info must not be null"); assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, null, - INITIALIZATION_TIMEOUT, mockSessionSupplier)) + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Protocol versions must not be empty"); assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, List.of(), - INITIALIZATION_TIMEOUT, mockSessionSupplier)) + INITIALIZATION_TIMEOUT, mockSessionSupplier, mockPostInitializationHook)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Protocol versions must not be empty"); assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, null, - mockSessionSupplier)) + mockSessionSupplier, mockPostInitializationHook)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Initialization timeout must not be null"); assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, - INITIALIZATION_TIMEOUT, null)) + INITIALIZATION_TIMEOUT, null, mockPostInitializationHook)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Session supplier must not be null"); } @Test void shouldInitializeSuccessfully() { - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .assertNext(result -> { assertThat(result).isEqualTo(MOCK_INIT_RESULT); assertThat(initializer.isInitialized()).isTrue(); @@ -132,7 +136,7 @@ void shouldUseLatestProtocolVersionInInitializeRequest() { return Mono.just(MOCK_INIT_RESULT); }); - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .assertNext(result -> { assertThat(capturedRequest.get().protocolVersion()).isEqualTo("2.0.0"); // Latest // version @@ -152,7 +156,7 @@ void shouldFailForUnsupportedProtocolVersion() { when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) .thenReturn(Mono.just(unsupportedResult)); - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectError(RuntimeException.class) .verify(); @@ -167,13 +171,13 @@ void shouldTimeoutOnSlowInitialization() { Duration SLOW_RESPONSE_DELAY = Duration.ofSeconds(5); LifecycleInitializer shortTimeoutInitializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, - PROTOCOL_VERSIONS, INITIALIZE_TIMEOUT, mockSessionSupplier); + PROTOCOL_VERSIONS, INITIALIZE_TIMEOUT, mockSessionSupplier, mockPostInitializationHook); when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) .thenReturn(Mono.just(MOCK_INIT_RESULT).delayElement(SLOW_RESPONSE_DELAY, virtualTimeScheduler)); StepVerifier - .withVirtualTime(() -> shortTimeoutInitializer.withIntitialization("test", + .withVirtualTime(() -> shortTimeoutInitializer.withInitialization("test", init -> Mono.just(init.initializeResult())), () -> virtualTimeScheduler, Long.MAX_VALUE) .expectSubscription() .expectNoEvent(INITIALIZE_TIMEOUT) @@ -184,12 +188,12 @@ void shouldTimeoutOnSlowInitialization() { @Test void shouldReuseExistingInitialization() { // First initialization - StepVerifier.create(initializer.withIntitialization("test1", init -> Mono.just("result1"))) + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) .expectNext("result1") .verifyComplete(); // Second call should reuse the same initialization - StepVerifier.create(initializer.withIntitialization("test2", init -> Mono.just("result2"))) + StepVerifier.create(initializer.withInitialization("test2", init -> Mono.just("result2"))) .expectNext("result2") .verifyComplete(); @@ -209,11 +213,11 @@ void shouldHandleConcurrentInitializationRequests() { // Start multiple concurrent initializations using subscribeOn with parallel // scheduler - Mono init1 = initializer.withIntitialization("test1", init -> Mono.just("result1")) + Mono init1 = initializer.withInitialization("test1", init -> Mono.just("result1")) .subscribeOn(Schedulers.parallel()); - Mono init2 = initializer.withIntitialization("test2", init -> Mono.just("result2")) + Mono init2 = initializer.withInitialization("test2", init -> Mono.just("result2")) .subscribeOn(Schedulers.parallel()); - Mono init3 = initializer.withIntitialization("test3", init -> Mono.just("result3")) + Mono init3 = initializer.withInitialization("test3", init -> Mono.just("result3")) .subscribeOn(Schedulers.parallel()); StepVerifier.create(Mono.zip(init1, init2, init3)).assertNext(tuple -> { @@ -230,20 +234,32 @@ void shouldHandleConcurrentInitializationRequests() { @Test void shouldHandleInitializationFailure() { when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) - .thenReturn(Mono.error(new RuntimeException("Connection failed"))); + // fail once + .thenReturn(Mono.error(new RuntimeException("Connection failed"))) + // succeeds on the second call + .thenReturn(Mono.just(MOCK_INIT_RESULT)); - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectError(RuntimeException.class) .verify(); assertThat(initializer.isInitialized()).isFalse(); assertThat(initializer.currentInitializationResult()).isNull(); + + // The initializer can recover from previous errors + StepVerifier + .create(initializer.withInitialization("successful init", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(initializer.isInitialized()).isTrue(); + assertThat(initializer.currentInitializationResult()).isEqualTo(MOCK_INIT_RESULT); } @Test void shouldHandleTransportSessionNotFoundException() { // successful initialization first - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectNext(MOCK_INIT_RESULT) .verifyComplete(); @@ -265,7 +281,7 @@ void shouldHandleTransportSessionNotFoundException() { @Test void shouldHandleOtherExceptions() { // Simulate a successful initialization first - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectNext(MOCK_INIT_RESULT) .verifyComplete(); @@ -283,7 +299,7 @@ void shouldHandleOtherExceptions() { @Test void shouldCloseGracefully() { - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectNext(MOCK_INIT_RESULT) .verifyComplete(); @@ -295,7 +311,7 @@ void shouldCloseGracefully() { @Test void shouldCloseImmediately() { - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectNext(MOCK_INIT_RESULT) .verifyComplete(); @@ -330,7 +346,7 @@ void shouldSetProtocolVersionsForTesting() { new McpSchema.Implementation("test-server", "1.0.0"), "Test instructions")); }); - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .assertNext(result -> { // Latest from new versions assertThat(capturedRequest.get().protocolVersion()).isEqualTo("4.0.0"); @@ -351,7 +367,7 @@ void shouldPassContextToSessionSupplier() { }); StepVerifier - .create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())) + .create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult())) .contextWrite(Context.of(contextKey, contextValue))) .expectNext(MOCK_INIT_RESULT) .verifyComplete(); @@ -362,7 +378,7 @@ void shouldPassContextToSessionSupplier() { @Test void shouldProvideAccessToMcpSessionAndInitializeResult() { - StepVerifier.create(initializer.withIntitialization("test", init -> { + StepVerifier.create(initializer.withInitialization("test", init -> { assertThat(init.mcpSession()).isEqualTo(mockClientSession); assertThat(init.initializeResult()).isEqualTo(MOCK_INIT_RESULT); return Mono.just("success"); @@ -374,7 +390,7 @@ void shouldHandleNotificationFailure() { when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) .thenReturn(Mono.error(new RuntimeException("Notification failed"))); - StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectError(RuntimeException.class) .verify(); @@ -391,7 +407,7 @@ void shouldReturnNullWhenNotInitialized() { @Test void shouldReinitializeAfterTransportSessionException() { // First initialization - StepVerifier.create(initializer.withIntitialization("test1", init -> Mono.just("result1"))) + StepVerifier.create(initializer.withInitialization("test1", init -> Mono.just("result1"))) .expectNext("result1") .verifyComplete(); @@ -399,7 +415,7 @@ void shouldReinitializeAfterTransportSessionException() { initializer.handleException(new McpTransportSessionNotFoundException("Session lost")); // Should be able to initialize again - StepVerifier.create(initializer.withIntitialization("test2", init -> Mono.just("result2"))) + StepVerifier.create(initializer.withInitialization("test2", init -> Mono.just("result2"))) .expectNext("result2") .verifyComplete(); diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizerTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizerTest.java new file mode 100644 index 000000000..a04787aa3 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpAsyncHttpClientRequestCustomizerTest.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import io.modelcontextprotocol.common.McpTransportContext; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DelegatingMcpAsyncHttpClientRequestCustomizer}. + * + * @author Daniel Garnier-Moiroux + */ +class DelegatingMcpAsyncHttpClientRequestCustomizerTest { + + private static final URI TEST_URI = URI.create("https://example.com"); + + private final HttpRequest.Builder TEST_BUILDER = HttpRequest.newBuilder(TEST_URI); + + @Test + void delegates() { + var mockCustomizer = mock(McpAsyncHttpClientRequestCustomizer.class); + when(mockCustomizer.customize(any(), any(), any(), any(), any())) + .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); + var customizer = new DelegatingMcpAsyncHttpClientRequestCustomizer(List.of(mockCustomizer)); + + var context = McpTransportContext.EMPTY; + StepVerifier + .create(customizer.customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context)) + .expectNext(TEST_BUILDER) + .verifyComplete(); + + verify(mockCustomizer).customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context); + } + + @Test + void delegatesInOrder() { + var customizer = new DelegatingMcpAsyncHttpClientRequestCustomizer( + List.of((builder, method, uri, body, ctx) -> Mono.just(builder.copy().header("x-test", "one")), + (builder, method, uri, body, ctx) -> Mono.just(builder.copy().header("x-test", "two")))); + + var headers = Mono + .from(customizer.customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", + McpTransportContext.EMPTY)) + .map(HttpRequest.Builder::build) + .map(HttpRequest::headers) + .flatMapIterable(h -> h.allValues("x-test")); + + StepVerifier.create(headers).expectNext("one").expectNext("two").verifyComplete(); + } + + @Test + void constructorRequiresNonNull() { + assertThatThrownBy(() -> new DelegatingMcpAsyncHttpClientRequestCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Customizers must not be null"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizerTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizerTest.java new file mode 100644 index 000000000..6c51a3d12 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/DelegatingMcpSyncHttpClientRequestCustomizerTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import io.modelcontextprotocol.common.McpTransportContext; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link DelegatingMcpSyncHttpClientRequestCustomizer}. + * + * @author Daniel Garnier-Moiroux + */ +class DelegatingMcpSyncHttpClientRequestCustomizerTest { + + private static final URI TEST_URI = URI.create("https://example.com"); + + private final HttpRequest.Builder TEST_BUILDER = HttpRequest.newBuilder(TEST_URI); + + @Test + void delegates() { + var mockCustomizer = Mockito.mock(McpSyncHttpClientRequestCustomizer.class); + var customizer = new DelegatingMcpSyncHttpClientRequestCustomizer(List.of(mockCustomizer)); + + var context = McpTransportContext.EMPTY; + customizer.customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context); + + verify(mockCustomizer).customize(TEST_BUILDER, "GET", TEST_URI, "{\"everybody\": \"needs somebody\"}", context); + } + + @Test + void delegatesInOrder() { + var testHeaderName = "x-test"; + var customizer = new DelegatingMcpSyncHttpClientRequestCustomizer( + List.of((builder, method, uri, body, ctx) -> builder.header(testHeaderName, "one"), + (builder, method, uri, body, ctx) -> builder.header(testHeaderName, "two"))); + + customizer.customize(TEST_BUILDER, "GET", TEST_URI, null, McpTransportContext.EMPTY); + var request = TEST_BUILDER.build(); + + assertThat(request.headers().allValues(testHeaderName)).containsExactly("one", "two"); + } + + @Test + void constructorRequiresNonNull() { + assertThatThrownBy(() -> new DelegatingMcpAsyncHttpClientRequestCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Customizers must not be null"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java new file mode 100644 index 000000000..897ae2ccc --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java @@ -0,0 +1,236 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.List; +import java.util.Map; + +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.ToolNameValidator; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link McpServerFeatures.AsyncToolSpecification.Builder}. + * + * @author Christian Tzolov + */ +class AsyncToolSpecificationBuilderTest { + + @Test + void builderShouldCreateValidAsyncToolSpecification() { + + Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .title("A test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + McpServerFeatures.AsyncToolSpecification specification = McpServerFeatures.AsyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of(new TextContent("Test result"))).isError(false).build())) + .build(); + + assertThat(specification).isNotNull(); + assertThat(specification.tool()).isEqualTo(tool); + assertThat(specification.callHandler()).isNotNull(); + } + + @Test + void builderShouldThrowExceptionWhenToolIsNull() { + assertThatThrownBy(() -> McpServerFeatures.AsyncToolSpecification.builder() + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("Tool must not be null"); + } + + @Test + void builderShouldThrowExceptionWhenCallToolIsNull() { + Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .title("A test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> McpServerFeatures.AsyncToolSpecification.builder().tool(tool).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Call handler function must not be null"); + } + + @Test + void builderShouldAllowMethodChaining() { + Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .title("A test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + McpServerFeatures.AsyncToolSpecification.Builder builder = McpServerFeatures.AsyncToolSpecification.builder(); + + // Then - verify method chaining returns the same builder instance + assertThat(builder.tool(tool)).isSameAs(builder); + assertThat(builder.callHandler( + (exchange, request) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build()))) + .isSameAs(builder); + } + + @Test + void builtSpecificationShouldExecuteCallToolCorrectly() { + Tool tool = McpSchema.Tool.builder() + .name("calculator") + .title("Simple calculator") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + String expectedResult = "42"; + + McpServerFeatures.AsyncToolSpecification specification = McpServerFeatures.AsyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> Mono.just( + CallToolResult.builder().content(List.of(new TextContent(expectedResult))).isError(false).build())) + .build(); + + CallToolRequest request = new CallToolRequest("calculator", Map.of()); + Mono resultMono = specification.callHandler().apply(null, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + } + + @Test + void fromSyncShouldConvertSyncToolSpecificationCorrectly() { + Tool tool = McpSchema.Tool.builder() + .name("sync-tool") + .title("A sync tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + String expectedResult = "sync result"; + + // Create a sync tool specification + McpServerFeatures.SyncToolSpecification syncSpec = McpServerFeatures.SyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build()) + .build(); + + // Convert to async using fromSync + McpServerFeatures.AsyncToolSpecification asyncSpec = McpServerFeatures.AsyncToolSpecification + .fromSync(syncSpec); + + assertThat(asyncSpec).isNotNull(); + assertThat(asyncSpec.tool()).isEqualTo(tool); + assertThat(asyncSpec.callHandler()).isNotNull(); + + // Test that the converted async specification works correctly + CallToolRequest request = new CallToolRequest("sync-tool", Map.of("param", "value")); + Mono resultMono = asyncSpec.callHandler().apply(null, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + }).verifyComplete(); + } + + @Test + void fromSyncShouldReturnNullWhenSyncSpecIsNull() { + assertThat(McpServerFeatures.AsyncToolSpecification.fromSync(null)).isNull(); + } + + @Nested + class ToolNameValidation { + + private McpServerTransportProvider transportProvider; + + private final Logger logger = (Logger) LoggerFactory.getLogger(ToolNameValidator.class); + + private final ListAppender logAppender = new ListAppender<>(); + + @BeforeEach + void setUp() { + transportProvider = mock(McpServerTransportProvider.class); + System.clearProperty(ToolNameValidator.STRICT_VALIDATION_PROPERTY); + logAppender.start(); + logger.addAppender(logAppender); + } + + @AfterEach + void tearDown() { + System.clearProperty(ToolNameValidator.STRICT_VALIDATION_PROPERTY); + logger.detachAppender(logAppender); + logAppender.stop(); + } + + @Test + void defaultShouldThrowOnInvalidName() { + Tool invalidTool = Tool.builder().name("invalid tool name").build(); + + assertThatThrownBy( + () -> McpServer.async(transportProvider).toolCall(invalidTool, (exchange, request) -> null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("invalid characters"); + } + + @Test + void lenientDefaultShouldLogOnInvalidName() { + System.setProperty(ToolNameValidator.STRICT_VALIDATION_PROPERTY, "false"); + Tool invalidTool = Tool.builder().name("invalid tool name").build(); + + assertThatCode(() -> McpServer.async(transportProvider).toolCall(invalidTool, (exchange, request) -> null)) + .doesNotThrowAnyException(); + assertThat(logAppender.list).hasSize(1); + } + + @Test + void lenientConfigurationShouldLogOnInvalidName() { + Tool invalidTool = Tool.builder().name("invalid tool name").build(); + + assertThatCode(() -> McpServer.async(transportProvider) + .strictToolNameValidation(false) + .toolCall(invalidTool, (exchange, request) -> null)).doesNotThrowAnyException(); + assertThat(logAppender.list).hasSize(1); + } + + @Test + void serverConfigurationShouldOverrideDefault() { + System.setProperty(ToolNameValidator.STRICT_VALIDATION_PROPERTY, "false"); + Tool invalidTool = Tool.builder().name("invalid tool name").build(); + + assertThatThrownBy(() -> McpServer.async(transportProvider) + .strictToolNameValidation(true) + .toolCall(invalidTool, (exchange, request) -> null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("invalid characters"); + } + + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java similarity index 90% rename from mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java index 987c43663..e6161a59f 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java @@ -4,15 +4,16 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.common.McpTransportContext; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; -import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.json.TypeRef; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; @@ -54,7 +55,7 @@ void setUp() { clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); exchange = new McpAsyncServerExchange("testSessionId", mockSession, clientCapabilities, clientInfo, - new DefaultMcpTransportContext()); + McpTransportContext.EMPTY); } @Test @@ -65,7 +66,7 @@ void testListRootsWithSinglePage() { McpSchema.ListRootsResult singlePageResult = new McpSchema.ListRootsResult(roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(singlePageResult)); StepVerifier.create(exchange.listRoots()).assertNext(result -> { @@ -93,11 +94,11 @@ void testListRootsWithMultiplePages() { McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page1Result)); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page2Result)); StepVerifier.create(exchange.listRoots()).assertNext(result -> { @@ -119,7 +120,7 @@ void testListRootsWithEmptyResult() { McpSchema.ListRootsResult emptyResult = new McpSchema.ListRootsResult(new ArrayList<>(), null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(emptyResult)); StepVerifier.create(exchange.listRoots()).assertNext(result -> { @@ -139,7 +140,7 @@ void testListRootsWithSpecificCursor() { McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(roots, "nextCursor"); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("someCursor")), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(result)); StepVerifier.create(exchange.listRoots("someCursor")).assertNext(listResult -> { @@ -153,7 +154,7 @@ void testListRootsWithSpecificCursor() { void testListRootsWithError() { when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.error(new RuntimeException("Network error"))); // When & Then @@ -174,11 +175,11 @@ void testListRootsUnmodifiabilityAfterAccumulation() { McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page1Result)); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page2Result)); StepVerifier.create(exchange.listRoots()).assertNext(result -> { @@ -214,7 +215,7 @@ void testGetClientInfo() { @Test void testLoggingNotificationWithNullMessage() { StepVerifier.create(exchange.loggingNotification(null)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Logging message must not be null"); + assertThat(error).isInstanceOf(IllegalStateException.class).hasMessage("Logging message must not be null"); }); } @@ -300,7 +301,8 @@ void testLoggingNotificationWithSessionError() { @Test void testCreateElicitationWithNullCapabilities() { // Given - Create exchange with null capabilities - McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, clientInfo); + McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange("testSessionId", mockSession, + null, clientInfo, McpTransportContext.EMPTY); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() .message("Please provide your name") @@ -308,13 +310,12 @@ void testCreateElicitationWithNullCapabilities() { StepVerifier.create(exchangeWithNullCapabilities.createElicitation(elicitRequest)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Client must be initialized. Call the initialize method first!"); }); // Verify that sendRequest was never called due to null capabilities - verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), - any(TypeReference.class)); + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); } @Test @@ -324,22 +325,21 @@ void testCreateElicitationWithoutElicitationCapabilities() { .roots(true) .build(); - McpAsyncServerExchange exchangeWithoutElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithoutElicitation, clientInfo); + McpAsyncServerExchange exchangeWithoutElicitation = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithoutElicitation, clientInfo, McpTransportContext.EMPTY); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() .message("Please provide your name") .build(); StepVerifier.create(exchangeWithoutElicitation.createElicitation(elicitRequest)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Client must be configured with elicitation capabilities"); }); // Verify that sendRequest was never called due to missing elicitation // capabilities - verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), - any(TypeReference.class)); + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); } @Test @@ -349,8 +349,8 @@ void testCreateElicitationWithComplexRequest() { .elicitation() .build(); - McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithElicitation, clientInfo, McpTransportContext.EMPTY); // Create a complex elicit request with schema java.util.Map requestedSchema = new java.util.HashMap<>(); @@ -373,8 +373,7 @@ void testCreateElicitationWithComplexRequest() { .content(responseContent) .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { @@ -393,8 +392,8 @@ void testCreateElicitationWithDeclineAction() { .elicitation() .build(); - McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithElicitation, clientInfo, McpTransportContext.EMPTY); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() .message("Please provide sensitive information") @@ -404,8 +403,7 @@ void testCreateElicitationWithDeclineAction() { .message(McpSchema.ElicitResult.Action.DECLINE) .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { @@ -421,8 +419,8 @@ void testCreateElicitationWithCancelAction() { .elicitation() .build(); - McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithElicitation, clientInfo, McpTransportContext.EMPTY); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() .message("Please provide your information") @@ -432,8 +430,7 @@ void testCreateElicitationWithCancelAction() { .message(McpSchema.ElicitResult.Action.CANCEL) .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { @@ -449,15 +446,14 @@ void testCreateElicitationWithSessionError() { .elicitation() .build(); - McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithElicitation, clientInfo, McpTransportContext.EMPTY); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() .message("Please provide your name") .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.error(new RuntimeException("Session communication error"))); StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).verifyErrorSatisfies(error -> { @@ -472,7 +468,8 @@ void testCreateElicitationWithSessionError() { @Test void testCreateMessageWithNullCapabilities() { - McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, clientInfo); + McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange("testSessionId", mockSession, + null, clientInfo, McpTransportContext.EMPTY); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(Arrays @@ -481,13 +478,13 @@ void testCreateMessageWithNullCapabilities() { StepVerifier.create(exchangeWithNullCapabilities.createMessage(createMessageRequest)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Client must be initialized. Call the initialize method first!"); }); // Verify that sendRequest was never called due to null capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), - any(TypeReference.class)); + any(TypeRef.class)); } @Test @@ -497,8 +494,8 @@ void testCreateMessageWithoutSamplingCapabilities() { .roots(true) .build(); - McpAsyncServerExchange exchangeWithoutSampling = new McpAsyncServerExchange(mockSession, - capabilitiesWithoutSampling, clientInfo); + McpAsyncServerExchange exchangeWithoutSampling = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithoutSampling, clientInfo, McpTransportContext.EMPTY); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(Arrays @@ -506,13 +503,13 @@ void testCreateMessageWithoutSamplingCapabilities() { .build(); StepVerifier.create(exchangeWithoutSampling.createMessage(createMessageRequest)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Client must be configured with sampling capabilities"); }); // Verify that sendRequest was never called due to missing sampling capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), - any(TypeReference.class)); + any(TypeRef.class)); } @Test @@ -522,8 +519,8 @@ void testCreateMessageWithBasicRequest() { .sampling() .build(); - McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, - clientInfo); + McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithSampling, clientInfo, McpTransportContext.EMPTY); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(Arrays @@ -538,7 +535,7 @@ void testCreateMessageWithBasicRequest() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { @@ -558,8 +555,8 @@ void testCreateMessageWithImageContent() { .sampling() .build(); - McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, - clientInfo); + McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithSampling, clientInfo, McpTransportContext.EMPTY); // Create request with image content McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -576,7 +573,7 @@ void testCreateMessageWithImageContent() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { @@ -593,8 +590,8 @@ void testCreateMessageWithSessionError() { .sampling() .build(); - McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, - clientInfo); + McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithSampling, clientInfo, McpTransportContext.EMPTY); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(Arrays @@ -602,7 +599,7 @@ void testCreateMessageWithSessionError() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.error(new RuntimeException("Session communication error"))); StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).verifyErrorSatisfies(error -> { @@ -617,8 +614,8 @@ void testCreateMessageWithIncludeContext() { .sampling() .build(); - McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange(mockSession, capabilitiesWithSampling, - clientInfo); + McpAsyncServerExchange exchangeWithSampling = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithSampling, clientInfo, McpTransportContext.EMPTY); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() .messages(Arrays.asList(new McpSchema.SamplingMessage(McpSchema.Role.USER, @@ -634,7 +631,7 @@ void testCreateMessageWithIncludeContext() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); StepVerifier.create(exchangeWithSampling.createMessage(createMessageRequest)).assertNext(result -> { @@ -652,7 +649,7 @@ void testPingWithSuccessfulResponse() { java.util.Map expectedResponse = java.util.Map.of(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) .thenReturn(Mono.just(expectedResponse)); StepVerifier.create(exchange.ping()).assertNext(result -> { @@ -661,14 +658,14 @@ void testPingWithSuccessfulResponse() { }).verifyComplete(); // Verify that sendRequest was called with correct parameters - verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } @Test void testPingWithMcpError() { // Given - Mock an MCP-specific error during ping - McpError mcpError = new McpError("Server unavailable"); - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + McpError mcpError = McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR).message("Server unavailable").build(); + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) .thenReturn(Mono.error(mcpError)); // When & Then @@ -676,13 +673,13 @@ void testPingWithMcpError() { assertThat(error).isInstanceOf(McpError.class).hasMessage("Server unavailable"); }); - verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } @Test void testPingMultipleCalls() { - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) .thenReturn(Mono.just(Map.of())) .thenReturn(Mono.just(Map.of())); @@ -697,7 +694,7 @@ void testPingMultipleCalls() { }).verifyComplete(); // Verify that sendRequest was called twice - verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java similarity index 90% rename from mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java index a73ec7209..fba733c9a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java @@ -9,10 +9,11 @@ import java.util.List; import java.util.Map; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.json.TypeRef; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; @@ -54,7 +55,8 @@ void setUp() { clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); - asyncExchange = new McpAsyncServerExchange(mockSession, clientCapabilities, clientInfo); + asyncExchange = new McpAsyncServerExchange("testSessionId", mockSession, clientCapabilities, clientInfo, + McpTransportContext.EMPTY); exchange = new McpSyncServerExchange(asyncExchange); } @@ -66,7 +68,7 @@ void testListRootsWithSinglePage() { McpSchema.ListRootsResult singlePageResult = new McpSchema.ListRootsResult(roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(singlePageResult)); McpSchema.ListRootsResult result = exchange.listRoots(); @@ -94,11 +96,11 @@ void testListRootsWithMultiplePages() { McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page1Result)); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page2Result)); McpSchema.ListRootsResult result = exchange.listRoots(); @@ -120,7 +122,7 @@ void testListRootsWithEmptyResult() { McpSchema.ListRootsResult emptyResult = new McpSchema.ListRootsResult(new ArrayList<>(), null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(emptyResult)); McpSchema.ListRootsResult result = exchange.listRoots(); @@ -140,7 +142,7 @@ void testListRootsWithSpecificCursor() { McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(roots, "nextCursor"); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("someCursor")), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(result)); McpSchema.ListRootsResult listResult = exchange.listRoots("someCursor"); @@ -154,7 +156,7 @@ void testListRootsWithSpecificCursor() { void testListRootsWithError() { when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), any(McpSchema.PaginatedRequest.class), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.error(new RuntimeException("Network error"))); // When & Then @@ -173,11 +175,11 @@ void testListRootsUnmodifiabilityAfterAccumulation() { McpSchema.ListRootsResult page2Result = new McpSchema.ListRootsResult(page2Roots, null); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest(null)), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page1Result)); when(mockSession.sendRequest(eq(McpSchema.METHOD_ROOTS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(page2Result)); McpSchema.ListRootsResult result = exchange.listRoots(); @@ -212,7 +214,7 @@ void testGetClientInfo() { @Test void testLoggingNotificationWithNullMessage() { - assertThatThrownBy(() -> exchange.loggingNotification(null)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> exchange.loggingNotification(null)).isInstanceOf(IllegalStateException.class) .hasMessage("Logging message must not be null"); } @@ -294,8 +296,8 @@ void testLoggingNotificationWithSessionError() { @Test void testCreateElicitationWithNullCapabilities() { // Given - Create exchange with null capabilities - McpAsyncServerExchange asyncExchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, - clientInfo); + McpAsyncServerExchange asyncExchangeWithNullCapabilities = new McpAsyncServerExchange("testSessionId", + mockSession, null, clientInfo, McpTransportContext.EMPTY); McpSyncServerExchange exchangeWithNullCapabilities = new McpSyncServerExchange( asyncExchangeWithNullCapabilities); @@ -304,12 +306,11 @@ void testCreateElicitationWithNullCapabilities() { .build(); assertThatThrownBy(() -> exchangeWithNullCapabilities.createElicitation(elicitRequest)) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalStateException.class) .hasMessage("Client must be initialized. Call the initialize method first!"); // Verify that sendRequest was never called due to null capabilities - verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), - any(TypeReference.class)); + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); } @Test @@ -319,8 +320,8 @@ void testCreateElicitationWithoutElicitationCapabilities() { .roots(true) .build(); - McpAsyncServerExchange asyncExchangeWithoutElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithoutElicitation, clientInfo); + McpAsyncServerExchange asyncExchangeWithoutElicitation = new McpAsyncServerExchange("testSessionId", + mockSession, capabilitiesWithoutElicitation, clientInfo, McpTransportContext.EMPTY); McpSyncServerExchange exchangeWithoutElicitation = new McpSyncServerExchange(asyncExchangeWithoutElicitation); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() @@ -328,13 +329,12 @@ void testCreateElicitationWithoutElicitationCapabilities() { .build(); assertThatThrownBy(() -> exchangeWithoutElicitation.createElicitation(elicitRequest)) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalStateException.class) .hasMessage("Client must be configured with elicitation capabilities"); // Verify that sendRequest was never called due to missing elicitation // capabilities - verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), - any(TypeReference.class)); + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); } @Test @@ -344,8 +344,8 @@ void testCreateElicitationWithComplexRequest() { .elicitation() .build(); - McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithElicitation, clientInfo, McpTransportContext.EMPTY); McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); // Create a complex elicit request with schema @@ -369,8 +369,7 @@ void testCreateElicitationWithComplexRequest() { .content(responseContent) .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); @@ -389,8 +388,8 @@ void testCreateElicitationWithDeclineAction() { .elicitation() .build(); - McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithElicitation, clientInfo, McpTransportContext.EMPTY); McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() @@ -401,8 +400,7 @@ void testCreateElicitationWithDeclineAction() { .message(McpSchema.ElicitResult.Action.DECLINE) .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); @@ -418,8 +416,8 @@ void testCreateElicitationWithCancelAction() { .elicitation() .build(); - McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithElicitation, clientInfo, McpTransportContext.EMPTY); McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() @@ -430,8 +428,7 @@ void testCreateElicitationWithCancelAction() { .message(McpSchema.ElicitResult.Action.CANCEL) .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); McpSchema.ElicitResult result = exchangeWithElicitation.createElicitation(elicitRequest); @@ -447,16 +444,15 @@ void testCreateElicitationWithSessionError() { .elicitation() .build(); - McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange(mockSession, - capabilitiesWithElicitation, clientInfo); + McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithElicitation, clientInfo, McpTransportContext.EMPTY); McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() .message("Please provide your name") .build(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), - any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) .thenReturn(Mono.error(new RuntimeException("Session communication error"))); assertThatThrownBy(() -> exchangeWithElicitation.createElicitation(elicitRequest)) @@ -471,8 +467,8 @@ void testCreateElicitationWithSessionError() { @Test void testCreateMessageWithNullCapabilities() { - McpAsyncServerExchange asyncExchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, - clientInfo); + McpAsyncServerExchange asyncExchangeWithNullCapabilities = new McpAsyncServerExchange("testSessionId", + mockSession, null, clientInfo, McpTransportContext.EMPTY); McpSyncServerExchange exchangeWithNullCapabilities = new McpSyncServerExchange( asyncExchangeWithNullCapabilities); @@ -482,12 +478,12 @@ void testCreateMessageWithNullCapabilities() { .build(); assertThatThrownBy(() -> exchangeWithNullCapabilities.createMessage(createMessageRequest)) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalStateException.class) .hasMessage("Client must be initialized. Call the initialize method first!"); // Verify that sendRequest was never called due to null capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), - any(TypeReference.class)); + any(TypeRef.class)); } @Test @@ -497,8 +493,8 @@ void testCreateMessageWithoutSamplingCapabilities() { .roots(true) .build(); - McpAsyncServerExchange asyncExchangeWithoutSampling = new McpAsyncServerExchange(mockSession, - capabilitiesWithoutSampling, clientInfo); + McpAsyncServerExchange asyncExchangeWithoutSampling = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithoutSampling, clientInfo, McpTransportContext.EMPTY); McpSyncServerExchange exchangeWithoutSampling = new McpSyncServerExchange(asyncExchangeWithoutSampling); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -507,12 +503,12 @@ void testCreateMessageWithoutSamplingCapabilities() { .build(); assertThatThrownBy(() -> exchangeWithoutSampling.createMessage(createMessageRequest)) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalStateException.class) .hasMessage("Client must be configured with sampling capabilities"); // Verify that sendRequest was never called due to missing sampling capabilities verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), any(), - any(TypeReference.class)); + any(TypeRef.class)); } @Test @@ -522,8 +518,8 @@ void testCreateMessageWithBasicRequest() { .sampling() .build(); - McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, - capabilitiesWithSampling, clientInfo); + McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithSampling, clientInfo, McpTransportContext.EMPTY); McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -539,7 +535,7 @@ void testCreateMessageWithBasicRequest() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); @@ -559,8 +555,8 @@ void testCreateMessageWithImageContent() { .sampling() .build(); - McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, - capabilitiesWithSampling, clientInfo); + McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithSampling, clientInfo, McpTransportContext.EMPTY); McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); // Create request with image content @@ -578,7 +574,7 @@ void testCreateMessageWithImageContent() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); @@ -595,8 +591,8 @@ void testCreateMessageWithSessionError() { .sampling() .build(); - McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, - capabilitiesWithSampling, clientInfo); + McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithSampling, clientInfo, McpTransportContext.EMPTY); McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -605,7 +601,7 @@ void testCreateMessageWithSessionError() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.error(new RuntimeException("Session communication error"))); assertThatThrownBy(() -> exchangeWithSampling.createMessage(createMessageRequest)) @@ -620,8 +616,8 @@ void testCreateMessageWithIncludeContext() { .sampling() .build(); - McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange(mockSession, - capabilitiesWithSampling, clientInfo); + McpAsyncServerExchange asyncExchangeWithSampling = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithSampling, clientInfo, McpTransportContext.EMPTY); McpSyncServerExchange exchangeWithSampling = new McpSyncServerExchange(asyncExchangeWithSampling); McpSchema.CreateMessageRequest createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -638,7 +634,7 @@ void testCreateMessageWithIncludeContext() { .build(); when(mockSession.sendRequest(eq(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE), eq(createMessageRequest), - any(TypeReference.class))) + any(TypeRef.class))) .thenReturn(Mono.just(expectedResult)); McpSchema.CreateMessageResult result = exchangeWithSampling.createMessage(createMessageRequest); @@ -656,32 +652,32 @@ void testPingWithSuccessfulResponse() { java.util.Map expectedResponse = java.util.Map.of(); - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) .thenReturn(Mono.just(expectedResponse)); exchange.ping(); // Verify that sendRequest was called with correct parameters - verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } @Test void testPingWithMcpError() { // Given - Mock an MCP-specific error during ping - McpError mcpError = new McpError("Server unavailable"); - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + McpError mcpError = McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR).message("Server unavailable").build(); + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) .thenReturn(Mono.error(mcpError)); // When & Then assertThatThrownBy(() -> exchange.ping()).isInstanceOf(McpError.class).hasMessage("Server unavailable"); - verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } @Test void testPingMultipleCalls() { - when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class))) + when(mockSession.sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class))) .thenReturn(Mono.just(Map.of())) .thenReturn(Mono.just(Map.of())); @@ -692,7 +688,7 @@ void testPingMultipleCalls() { exchange.ping(); // Verify that sendRequest was called twice - verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeReference.class)); + verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateListingTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateListingTest.java new file mode 100644 index 000000000..993ca717e --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/ResourceTemplateListingTest.java @@ -0,0 +1,123 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test to verify the separation of regular resources and resource templates. Regular + * resources (without template parameters) should only appear in resources/list. Template + * resources (containing {}) should only appear in resources/templates/list. + */ +public class ResourceTemplateListingTest { + + @Test + void testTemplateResourcesFilteredFromRegularListing() { + // The change we made filters resources containing "{" from the regular listing + // This test verifies that behavior is working correctly + + // Given a string with template parameter + String templateUri = "file:///test/{userId}/profile.txt"; + assertThat(templateUri.contains("{")).isTrue(); + + // And a regular URI + String regularUri = "file:///test/regular.txt"; + assertThat(regularUri.contains("{")).isFalse(); + + // The filter should exclude template URIs + assertThat(!templateUri.contains("{")).isFalse(); + assertThat(!regularUri.contains("{")).isTrue(); + } + + @Test + void testResourceListingWithMixedResources() { + // Create resource list with both regular and template resources + List allResources = List.of( + McpSchema.Resource.builder() + .uri("file:///test/doc1.txt") + .name("Document 1") + .mimeType("text/plain") + .build(), + McpSchema.Resource.builder() + .uri("file:///test/doc2.txt") + .name("Document 2") + .mimeType("text/plain") + .build(), + McpSchema.Resource.builder() + .uri("file:///test/{type}/document.txt") + .name("Typed Document") + .mimeType("text/plain") + .build(), + McpSchema.Resource.builder() + .uri("file:///users/{userId}/files/{fileId}") + .name("User File") + .mimeType("text/plain") + .build()); + + // Apply the filter logic from McpAsyncServer line 438 + List filteredResources = allResources.stream() + .filter(resource -> !resource.uri().contains("{")) + .collect(Collectors.toList()); + + // Verify only regular resources are included + assertThat(filteredResources).hasSize(2); + assertThat(filteredResources).extracting(McpSchema.Resource::uri) + .containsExactlyInAnyOrder("file:///test/doc1.txt", "file:///test/doc2.txt"); + } + + @Test + void testResourceTemplatesListedSeparately() { + // Create mixed resources + List resources = List.of( + McpSchema.Resource.builder() + .uri("file:///test/regular.txt") + .name("Regular Resource") + .mimeType("text/plain") + .build(), + McpSchema.Resource.builder() + .uri("file:///test/user/{userId}/profile.txt") + .name("User Profile") + .mimeType("text/plain") + .build()); + + // Create explicit resource template + McpSchema.ResourceTemplate explicitTemplate = new McpSchema.ResourceTemplate( + "file:///test/document/{docId}/content.txt", "Document Template", null, "text/plain", null); + + // Filter regular resources (those without template parameters) + List regularResources = resources.stream() + .filter(resource -> !resource.uri().contains("{")) + .collect(Collectors.toList()); + + // Extract template resources (those with template parameters) + List templateResources = resources.stream() + .filter(resource -> resource.uri().contains("{")) + .map(resource -> new McpSchema.ResourceTemplate(resource.uri(), resource.name(), resource.description(), + resource.mimeType(), resource.annotations())) + .collect(Collectors.toList()); + + // Verify regular resources list + assertThat(regularResources).hasSize(1); + assertThat(regularResources.get(0).uri()).isEqualTo("file:///test/regular.txt"); + + // Verify template resources list includes both extracted and explicit templates + assertThat(templateResources).hasSize(1); + assertThat(templateResources.get(0).uriTemplate()).isEqualTo("file:///test/user/{userId}/profile.txt"); + + // In the actual implementation, both would be combined + List allTemplates = List.of(templateResources.get(0), explicitTemplate); + assertThat(allTemplates).hasSize(2); + assertThat(allTemplates).extracting(McpSchema.ResourceTemplate::uriTemplate) + .containsExactlyInAnyOrder("file:///test/user/{userId}/profile.txt", + "file:///test/document/{docId}/content.txt"); + } + +} \ No newline at end of file diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java new file mode 100644 index 000000000..54c45e561 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java @@ -0,0 +1,181 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.List; +import java.util.Map; + +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.ToolNameValidator; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.slf4j.LoggerFactory; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link McpServerFeatures.SyncToolSpecification.Builder}. + * + * @author Christian Tzolov + */ +class SyncToolSpecificationBuilderTest { + + @Test + void builderShouldCreateValidSyncToolSpecification() { + + Tool tool = Tool.builder().name("test-tool").title("A test tool").inputSchema(EMPTY_JSON_SCHEMA).build(); + + McpServerFeatures.SyncToolSpecification specification = McpServerFeatures.SyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> CallToolResult.builder() + .content(List.of(new TextContent("Test result"))) + .isError(false) + .build()) + .build(); + + assertThat(specification).isNotNull(); + assertThat(specification.tool()).isEqualTo(tool); + assertThat(specification.callHandler()).isNotNull(); + } + + @Test + void builderShouldThrowExceptionWhenToolIsNull() { + assertThatThrownBy(() -> McpServerFeatures.SyncToolSpecification.builder() + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("Tool must not be null"); + } + + @Test + void builderShouldThrowExceptionWhenCallToolIsNull() { + Tool tool = Tool.builder().name("test-tool").description("A test tool").inputSchema(EMPTY_JSON_SCHEMA).build(); + + assertThatThrownBy(() -> McpServerFeatures.SyncToolSpecification.builder().tool(tool).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("CallTool function must not be null"); + } + + @Test + void builderShouldAllowMethodChaining() { + Tool tool = Tool.builder().name("test-tool").description("A test tool").inputSchema(EMPTY_JSON_SCHEMA).build(); + McpServerFeatures.SyncToolSpecification.Builder builder = McpServerFeatures.SyncToolSpecification.builder(); + + // Then - verify method chaining returns the same builder instance + assertThat(builder.tool(tool)).isSameAs(builder); + assertThat(builder + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build())) + .isSameAs(builder); + } + + @Test + void builtSpecificationShouldExecuteCallToolCorrectly() { + Tool tool = Tool.builder() + .name("calculator") + .description("Simple calculator") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + String expectedResult = "42"; + + McpServerFeatures.SyncToolSpecification specification = McpServerFeatures.SyncToolSpecification.builder() + .tool(tool) + .callHandler((exchange, request) -> { + // Simple test implementation + return CallToolResult.builder() + .content(List.of(new TextContent(expectedResult))) + .isError(false) + .build(); + }) + .build(); + + CallToolRequest request = new CallToolRequest("calculator", Map.of()); + CallToolResult result = specification.callHandler().apply(null, request); + + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); + assertThat(result.isError()).isFalse(); + } + + @Nested + class ToolNameValidation { + + private McpServerTransportProvider transportProvider; + + private final Logger logger = (Logger) LoggerFactory.getLogger(ToolNameValidator.class); + + private final ListAppender logAppender = new ListAppender<>(); + + @BeforeEach + void setUp() { + transportProvider = mock(McpServerTransportProvider.class); + System.clearProperty(ToolNameValidator.STRICT_VALIDATION_PROPERTY); + logAppender.start(); + logger.addAppender(logAppender); + } + + @AfterEach + void tearDown() { + System.clearProperty(ToolNameValidator.STRICT_VALIDATION_PROPERTY); + logger.detachAppender(logAppender); + logAppender.stop(); + } + + @Test + void defaultShouldThrowOnInvalidName() { + Tool invalidTool = Tool.builder().name("invalid tool name").build(); + + assertThatThrownBy( + () -> McpServer.sync(transportProvider).toolCall(invalidTool, (exchange, request) -> null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("invalid characters"); + } + + @Test + void lenientDefaultShouldLogOnInvalidName() { + System.setProperty(ToolNameValidator.STRICT_VALIDATION_PROPERTY, "false"); + Tool invalidTool = Tool.builder().name("invalid tool name").build(); + + assertThatCode(() -> McpServer.sync(transportProvider).toolCall(invalidTool, (exchange, request) -> null)) + .doesNotThrowAnyException(); + assertThat(logAppender.list).hasSize(1); + } + + @Test + void lenientConfigurationShouldLogOnInvalidName() { + Tool invalidTool = Tool.builder().name("invalid tool name").build(); + + assertThatCode(() -> McpServer.sync(transportProvider) + .strictToolNameValidation(false) + .toolCall(invalidTool, (exchange, request) -> null)).doesNotThrowAnyException(); + assertThat(logAppender.list).hasSize(1); + } + + @Test + void serverConfigurationShouldOverrideDefault() { + System.setProperty(ToolNameValidator.STRICT_VALIDATION_PROPERTY, "false"); + Tool invalidTool = Tool.builder().name("invalid tool name").build(); + + assertThatThrownBy(() -> McpServer.sync(transportProvider) + .strictToolNameValidation(true) + .toolCall(invalidTool, (exchange, request) -> null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("invalid characters"); + } + + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java new file mode 100644 index 000000000..d4cf8582d --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java @@ -0,0 +1,424 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * @author Daniel Garnier-Moiroux + */ +class DefaultServerTransportSecurityValidatorTests { + + private static final ServerTransportSecurityException INVALID_ORIGIN = new ServerTransportSecurityException(403, + "Invalid Origin header"); + + private static final ServerTransportSecurityException INVALID_HOST = new ServerTransportSecurityException(421, + "Invalid Host header"); + + private final DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:8080") + .build(); + + @Test + void builder() { + assertThatCode(() -> DefaultServerTransportSecurityValidator.builder().build()).doesNotThrowAnyException(); + assertThatThrownBy(() -> DefaultServerTransportSecurityValidator.builder().allowedOrigins(null).build()) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> DefaultServerTransportSecurityValidator.builder().allowedHosts(null).build()) + .isInstanceOf(IllegalArgumentException.class); + } + + @Nested + class OriginHeader { + + @Test + void originHeaderMissing() { + assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); + } + + @Test + void originHeaderListEmpty() { + assertThatThrownBy(() -> validator.validateHeaders(Map.of("Origin", List.of()))).isEqualTo(INVALID_ORIGIN); + } + + @Test + void caseInsensitive() { + var headers = Map.of("origin", List.of("http://localhost:8080")); + + assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void exactMatch() { + var headers = originHeader("http://localhost:8080"); + + assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void differentPort() { + + var headers = originHeader("http://localhost:3000"); + + assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + @Test + void differentHost() { + + var headers = originHeader("http://example.com:8080"); + + assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + @Test + void differentScheme() { + + var headers = originHeader("https://localhost:8080"); + + assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + @Nested + class WildcardPort { + + private final DefaultServerTransportSecurityValidator wildcardValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedOrigin("http://localhost:*") + .build(); + + @Test + void anyPortWithWildcard() { + var headers = originHeader("http://localhost:3000"); + + assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void noPortWithWildcard() { + var headers = originHeader("http://localhost"); + + assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void differentPortWithWildcard() { + var headers = originHeader("http://localhost:8080"); + + assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void differentHostWithWildcard() { + var headers = originHeader("http://example.com:3000"); + + assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + @Test + void differentSchemeWithWildcard() { + var headers = originHeader("https://localhost:3000"); + + assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + } + + @Nested + class MultipleOrigins { + + DefaultServerTransportSecurityValidator multipleOriginsValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedOrigin("http://localhost:8080") + .allowedOrigin("http://example.com:3000") + .allowedOrigin("http://myapp.example.com:*") + .build(); + + @Test + void matchingOneOfMultiple() { + var headers = originHeader("http://example.com:3000"); + + assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void matchingWildcardInMultiple() { + var headers = originHeader("http://myapp.example.com:9999"); + + assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void notMatchingAny() { + var headers = originHeader("http://malicious.example.com:1234"); + + assertThatThrownBy(() -> multipleOriginsValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + } + + @Nested + class BuilderTests { + + @Test + void shouldAddMultipleOriginsWithAllowedOriginsMethod() { + DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() + .allowedOrigins(List.of("http://localhost:8080", "http://example.com:*")) + .build(); + + var headers = originHeader("http://example.com:3000"); + + assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void shouldCombineAllowedOriginMethods() { + DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:8080") + .allowedOrigins(List.of("http://example.com:*", "http://test.com:3000")) + .build(); + + assertThatCode(() -> validator.validateHeaders(originHeader("http://localhost:8080"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(originHeader("http://example.com:9999"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(originHeader("http://test.com:3000"))) + .doesNotThrowAnyException(); + } + + } + + } + + @Nested + class HostHeader { + + private final DefaultServerTransportSecurityValidator hostValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedHost("localhost:8080") + .build(); + + @Test + void notConfigured() { + assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); + } + + @Test + void missing() { + assertThatThrownBy(() -> hostValidator.validateHeaders(new HashMap<>())).isEqualTo(INVALID_HOST); + } + + @Test + void listEmpty() { + assertThatThrownBy(() -> hostValidator.validateHeaders(Map.of("Host", List.of()))).isEqualTo(INVALID_HOST); + } + + @Test + void caseInsensitive() { + var headers = Map.of("host", List.of("localhost:8080")); + + assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void exactMatch() { + var headers = hostHeader("localhost:8080"); + + assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void differentPort() { + var headers = hostHeader("localhost:3000"); + + assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + } + + @Test + void differentHost() { + var headers = hostHeader("example.com:8080"); + + assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + } + + @Nested + class HostWildcardPort { + + private final DefaultServerTransportSecurityValidator wildcardHostValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedHost("localhost:*") + .build(); + + @Test + void anyPort() { + var headers = hostHeader("localhost:3000"); + + assertThatCode(() -> wildcardHostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void noPort() { + var headers = hostHeader("localhost"); + + assertThatCode(() -> wildcardHostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void differentHost() { + var headers = hostHeader("example.com:3000"); + + assertThatThrownBy(() -> wildcardHostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + } + + } + + @Nested + class MultipleHosts { + + DefaultServerTransportSecurityValidator multipleHostsValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedHost("example.com:3000") + .allowedHost("myapp.example.com:*") + .build(); + + @Test + void exactMatch() { + var headers = hostHeader("example.com:3000"); + + assertThatCode(() -> multipleHostsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void wildcard() { + var headers = hostHeader("myapp.example.com:9999"); + + assertThatCode(() -> multipleHostsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void differentHost() { + var headers = hostHeader("malicious.example.com:3000"); + + assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + } + + @Test + void differentPort() { + var headers = hostHeader("localhost:8080"); + + assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + } + + } + + @Nested + class HostBuilderTests { + + @Test + void multipleHosts() { + DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() + .allowedHosts(List.of("localhost:8080", "example.com:*")) + .build(); + + assertThatCode(() -> validator.validateHeaders(hostHeader("example.com:3000"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(hostHeader("localhost:8080"))) + .doesNotThrowAnyException(); + } + + @Test + void combined() { + DefaultServerTransportSecurityValidator validator = DefaultServerTransportSecurityValidator.builder() + .allowedHost("localhost:8080") + .allowedHosts(List.of("example.com:*", "test.com:3000")) + .build(); + + assertThatCode(() -> validator.validateHeaders(hostHeader("localhost:8080"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(hostHeader("example.com:9999"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(hostHeader("test.com:3000"))).doesNotThrowAnyException(); + } + + } + + } + + @Nested + class CombinedOriginAndHostValidation { + + private final DefaultServerTransportSecurityValidator combinedValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build(); + + @Test + void bothValid() { + var header = headers("http://localhost:8080", "localhost:8080"); + + assertThatCode(() -> combinedValidator.validateHeaders(header)).doesNotThrowAnyException(); + } + + @Test + void originValidHostInvalid() { + var header = headers("http://localhost:8080", "malicious.example.com:8080"); + + assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_HOST); + } + + @Test + void originInvalidHostValid() { + var header = headers("http://malicious.example.com:8080", "localhost:8080"); + + assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_ORIGIN); + } + + @Test + void originMissingHostValid() { + // Origin missing is OK (same-origin request) + var header = headers(null, "localhost:8080"); + + assertThatCode(() -> combinedValidator.validateHeaders(header)).doesNotThrowAnyException(); + } + + @Test + void originValidHostMissing() { + // Host missing is NOT OK when allowedHosts is configured + var header = headers("http://localhost:8080", null); + + assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_HOST); + } + + } + + private static Map> originHeader(String origin) { + return Map.of("Origin", List.of(origin)); + } + + private static Map> hostHeader(String host) { + return Map.of("Host", List.of(host)); + } + + private static Map> headers(String origin, String host) { + var map = new HashMap>(); + if (origin != null) { + map.put("Origin", List.of(origin)); + } + if (host != null) { + map.put("Host", List.of(host)); + } + return map; + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java rename to mcp-core/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java similarity index 91% rename from mcp/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java rename to mcp-core/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java index d03a6926d..fbe17d464 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java @@ -5,7 +5,10 @@ package io.modelcontextprotocol.spec; import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * Tests for MCP-specific validation of JSONRPCRequest ID requirements. diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java new file mode 100644 index 000000000..3de06f503 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -0,0 +1,313 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.function.Function; + +import io.modelcontextprotocol.MockMcpClientTransport; +import io.modelcontextprotocol.json.TypeRef; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test suite for {@link McpClientSession} that verifies its JSON-RPC message handling, + * request-response correlation, and notification processing. + * + * @author Christian Tzolov + */ +class McpClientSessionTests { + + private static final Logger logger = LoggerFactory.getLogger(McpClientSessionTests.class); + + private static final Duration TIMEOUT = Duration.ofSeconds(5); + + private static final String TEST_METHOD = "test.method"; + + private static final String TEST_NOTIFICATION = "test.notification"; + + private static final String ECHO_METHOD = "echo"; + + TypeRef responseType = new TypeRef<>() { + }; + + @Test + void testSendRequest() { + String testParam = "test parameter"; + String responseData = "test response"; + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + // Create a Mono that will emit the response after the request is sent + Mono responseMono = session.sendRequest(TEST_METHOD, testParam, responseType); + // Verify response handling + StepVerifier.create(responseMono).then(() -> { + McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); + transport.simulateIncomingMessage( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), responseData, null)); + }).consumeNextWith(response -> { + // Verify the request was sent + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessageAsRequest(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCRequest.class); + McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) sentMessage; + assertThat(request.method()).isEqualTo(TEST_METHOD); + assertThat(request.params()).isEqualTo(testParam); + assertThat(response).isEqualTo(responseData); + }).verifyComplete(); + + session.close(); + } + + @Test + void testSendRequestWithError() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); + + // Verify error handling + StepVerifier.create(responseMono).then(() -> { + McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); + // Simulate error response + McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.METHOD_NOT_FOUND, "Method not found", null); + transport.simulateIncomingMessage( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, error)); + }).expectError(McpError.class).verify(); + + session.close(); + } + + @Test + void testRequestTimeout() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); + + // Verify timeout + StepVerifier.create(responseMono) + .expectError(java.util.concurrent.TimeoutException.class) + .verify(TIMEOUT.plusSeconds(1)); + + session.close(); + } + + @Test + void testSendNotification() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + Map params = Map.of("key", "value"); + Mono notificationMono = session.sendNotification(TEST_NOTIFICATION, params); + + // Verify notification was sent + StepVerifier.create(notificationMono).consumeSubscriptionWith(response -> { + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCNotification.class); + McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) sentMessage; + assertThat(notification.method()).isEqualTo(TEST_NOTIFICATION); + assertThat(notification.params()).isEqualTo(params); + }).verifyComplete(); + + session.close(); + } + + @Test + void testRequestHandling() { + String echoMessage = "Hello MCP!"; + Map> requestHandlers = Map.of(ECHO_METHOD, + params -> Mono.just(params)); + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of(), Function.identity()); + + // Simulate incoming request + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, + "test-id", echoMessage); + transport.simulateIncomingMessage(request); + + // Verify response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.result()).isEqualTo(echoMessage); + assertThat(response.error()).isNull(); + + session.close(); + } + + @Test + void testNotificationHandling() { + Sinks.One receivedParams = Sinks.one(); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params))), + Function.identity()); + + // Simulate incoming notification from the server + Map notificationParams = Map.of("status", "ready"); + + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + TEST_NOTIFICATION, notificationParams); + + transport.simulateIncomingMessage(notification); + + // Verify handler was called + assertThat(receivedParams.asMono().block(Duration.ofSeconds(1))).isEqualTo(notificationParams); + + session.close(); + } + + @Test + void testUnknownMethodHandling() { + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + // Simulate incoming request for unknown method + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "unknown.method", + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify error response + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND); + + session.close(); + } + + @Test + void testRequestHandlerThrowsMcpErrorWithJsonRpcError() { + // Setup: Create a request handler that throws McpError with custom error code and + // data + String testMethod = "test.customError"; + Map errorData = Map.of("customField", "customValue"); + McpClientSession.RequestHandler failingHandler = params -> Mono + .error(McpError.builder(123).message("Custom error message").data(errorData).build()); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain the custom error from McpError + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(123); + assertThat(response.error().message()).isEqualTo("Custom error message"); + assertThat(response.error().data()).isEqualTo(errorData); + + session.close(); + } + + @Test + void testRequestHandlerThrowsGenericException() { + // Setup: Create a request handler that throws a generic RuntimeException + String testMethod = "test.genericError"; + RuntimeException exception = new RuntimeException("Something went wrong"); + McpClientSession.RequestHandler failingHandler = params -> Mono.error(exception); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain INTERNAL_ERROR with aggregated exception + // messages in data field + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.INTERNAL_ERROR); + assertThat(response.error().message()).isEqualTo("Something went wrong"); + // Verify data field contains aggregated exception messages + assertThat(response.error().data()).isNotNull(); + assertThat(response.error().data().toString()).contains("RuntimeException"); + assertThat(response.error().data().toString()).contains("Something went wrong"); + + session.close(); + } + + @Test + void testRequestHandlerThrowsExceptionWithCause() { + // Setup: Create a request handler that throws an exception with a cause chain + String testMethod = "test.chainedError"; + RuntimeException rootCause = new IllegalArgumentException("Root cause message"); + RuntimeException middleCause = new IllegalStateException("Middle cause message", rootCause); + RuntimeException topException = new RuntimeException("Top level message", middleCause); + McpClientSession.RequestHandler failingHandler = params -> Mono.error(topException); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(testMethod, failingHandler), Map.of(), + Function.identity()); + + // Simulate incoming request that will trigger the error + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, testMethod, + "test-id", null); + transport.simulateIncomingMessage(request); + + // Verify: The response should contain INTERNAL_ERROR with full exception chain + // in data field + McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); + assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; + assertThat(response.error()).isNotNull(); + assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.INTERNAL_ERROR); + assertThat(response.error().message()).isEqualTo("Top level message"); + // Verify data field contains the full exception chain + String dataString = response.error().data().toString(); + assertThat(dataString).contains("RuntimeException"); + assertThat(dataString).contains("Top level message"); + assertThat(dataString).contains("IllegalStateException"); + assertThat(dataString).contains("Middle cause message"); + assertThat(dataString).contains("IllegalArgumentException"); + assertThat(dataString).contains("Root cause message"); + + session.close(); + } + + @Test + void testGracefulShutdown() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), + Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params))), + Function.identity()); + + StepVerifier.create(session.closeGracefully()).verifyComplete(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpErrorTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpErrorTest.java new file mode 100644 index 000000000..0978ffe0b --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpErrorTest.java @@ -0,0 +1,22 @@ +package io.modelcontextprotocol.spec; + +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class McpErrorTest { + + @Test + void testNotFound() { + String uri = "file:///nonexistent.txt"; + McpError mcpError = McpError.RESOURCE_NOT_FOUND.apply(uri); + assertNotNull(mcpError.getJsonRpcError()); + assertEquals(-32002, mcpError.getJsonRpcError().code()); + assertEquals("Resource not found", mcpError.getJsonRpcError().message()); + assertEquals(Map.of("uri", uri), mcpError.getJsonRpcError().data()); + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java similarity index 74% rename from mcp/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java rename to mcp-core/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java index 25e22f968..1d7be0b51 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/PromptReferenceEqualsTest.java @@ -4,9 +4,12 @@ package io.modelcontextprotocol.spec; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; /** * Test class to verify the equals method implementation for PromptReference. @@ -15,8 +18,10 @@ class PromptReferenceEqualsTest { @Test void testEqualsWithSameIdentifierAndType() { - McpSchema.PromptReference ref1 = new McpSchema.PromptReference("ref/prompt", "test-prompt", "Test Title"); - McpSchema.PromptReference ref2 = new McpSchema.PromptReference("ref/prompt", "test-prompt", "Different Title"); + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Different Title"); assertTrue(ref1.equals(ref2), "PromptReferences with same identifier and type should be equal"); assertEquals(ref1.hashCode(), ref2.hashCode(), "Equal objects should have same hash code"); @@ -24,15 +29,18 @@ void testEqualsWithSameIdentifierAndType() { @Test void testEqualsWithDifferentIdentifier() { - McpSchema.PromptReference ref1 = new McpSchema.PromptReference("ref/prompt", "test-prompt-1", "Test Title"); - McpSchema.PromptReference ref2 = new McpSchema.PromptReference("ref/prompt", "test-prompt-2", "Test Title"); + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt-1", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt-2", + "Test Title"); assertFalse(ref1.equals(ref2), "PromptReferences with different identifiers should not be equal"); } @Test void testEqualsWithDifferentType() { - McpSchema.PromptReference ref1 = new McpSchema.PromptReference("ref/prompt", "test-prompt", "Test Title"); + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); McpSchema.PromptReference ref2 = new McpSchema.PromptReference("ref/other", "test-prompt", "Test Title"); assertFalse(ref1.equals(ref2), "PromptReferences with different types should not be equal"); @@ -40,14 +48,16 @@ void testEqualsWithDifferentType() { @Test void testEqualsWithNull() { - McpSchema.PromptReference ref1 = new McpSchema.PromptReference("ref/prompt", "test-prompt", "Test Title"); + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); assertFalse(ref1.equals(null), "PromptReference should not be equal to null"); } @Test void testEqualsWithDifferentClass() { - McpSchema.PromptReference ref1 = new McpSchema.PromptReference("ref/prompt", "test-prompt", "Test Title"); + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); String other = "not a PromptReference"; assertFalse(ref1.equals(other), "PromptReference should not be equal to different class"); @@ -55,16 +65,17 @@ void testEqualsWithDifferentClass() { @Test void testEqualsWithSameInstance() { - McpSchema.PromptReference ref1 = new McpSchema.PromptReference("ref/prompt", "test-prompt", "Test Title"); + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); assertTrue(ref1.equals(ref1), "PromptReference should be equal to itself"); } @Test void testEqualsIgnoresTitle() { - McpSchema.PromptReference ref1 = new McpSchema.PromptReference("ref/prompt", "test-prompt", "Title 1"); - McpSchema.PromptReference ref2 = new McpSchema.PromptReference("ref/prompt", "test-prompt", "Title 2"); - McpSchema.PromptReference ref3 = new McpSchema.PromptReference("ref/prompt", "test-prompt", null); + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", "Title 1"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", "Title 2"); + McpSchema.PromptReference ref3 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", null); assertTrue(ref1.equals(ref2), "PromptReferences should be equal regardless of title"); assertTrue(ref1.equals(ref3), "PromptReferences should be equal even when one has null title"); @@ -73,8 +84,10 @@ void testEqualsIgnoresTitle() { @Test void testHashCodeConsistency() { - McpSchema.PromptReference ref1 = new McpSchema.PromptReference("ref/prompt", "test-prompt", "Test Title"); - McpSchema.PromptReference ref2 = new McpSchema.PromptReference("ref/prompt", "test-prompt", "Different Title"); + McpSchema.PromptReference ref1 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Test Title"); + McpSchema.PromptReference ref2 = new McpSchema.PromptReference(PromptReference.TYPE, "test-prompt", + "Different Title"); assertEquals(ref1.hashCode(), ref2.hashCode(), "Objects that are equal should have the same hash code"); diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapper.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapper.java new file mode 100644 index 000000000..ef7cd2737 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapper.java @@ -0,0 +1,97 @@ +package io.modelcontextprotocol.spec.json.gson; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.ToNumberPolicy; +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +/** + * Test-only Gson-based implementation of McpJsonMapper. This lives under src/test/java so + * it doesn't affect production code or dependencies. + */ +public final class GsonMcpJsonMapper implements McpJsonMapper { + + private final Gson gson; + + public GsonMcpJsonMapper() { + this(new GsonBuilder().serializeNulls() + // Ensure numeric values in untyped (Object) fields preserve integral numbers + // as Long + .setObjectToNumberStrategy(ToNumberPolicy.LONG_OR_DOUBLE) + .setNumberToNumberStrategy(ToNumberPolicy.LONG_OR_DOUBLE) + .create()); + } + + public GsonMcpJsonMapper(Gson gson) { + if (gson == null) { + throw new IllegalArgumentException("Gson must not be null"); + } + this.gson = gson; + } + + public Gson getGson() { + return gson; + } + + @Override + public T readValue(String content, Class type) throws IOException { + try { + return gson.fromJson(content, type); + } + catch (Exception e) { + throw new IOException("Failed to deserialize JSON", e); + } + } + + @Override + public T readValue(byte[] content, Class type) throws IOException { + return readValue(new String(content, StandardCharsets.UTF_8), type); + } + + @Override + public T readValue(String content, TypeRef type) throws IOException { + try { + return gson.fromJson(content, type.getType()); + } + catch (Exception e) { + throw new IOException("Failed to deserialize JSON", e); + } + } + + @Override + public T readValue(byte[] content, TypeRef type) throws IOException { + return readValue(new String(content, StandardCharsets.UTF_8), type); + } + + @Override + public T convertValue(Object fromValue, Class type) { + String json = gson.toJson(fromValue); + return gson.fromJson(json, type); + } + + @Override + public T convertValue(Object fromValue, TypeRef type) { + String json = gson.toJson(fromValue); + return gson.fromJson(json, type.getType()); + } + + @Override + public String writeValueAsString(Object value) throws IOException { + try { + return gson.toJson(value); + } + catch (Exception e) { + throw new IOException("Failed to serialize to JSON", e); + } + } + + @Override + public byte[] writeValueAsBytes(Object value) throws IOException { + return writeValueAsString(value).getBytes(StandardCharsets.UTF_8); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapperTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapperTests.java new file mode 100644 index 000000000..498194d17 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/json/gson/GsonMcpJsonMapperTests.java @@ -0,0 +1,135 @@ +package io.modelcontextprotocol.spec.json.gson; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.json.TypeRef; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +class GsonMcpJsonMapperTests { + + record Person(String name, int age) { + } + + @Test + void roundTripSimplePojo() throws IOException { + var mapper = new GsonMcpJsonMapper(); + + var input = new Person("Alice", 30); + String json = mapper.writeValueAsString(input); + assertNotNull(json); + assertTrue(json.contains("\"Alice\"")); + assertTrue(json.contains("\"age\"")); + + var decoded = mapper.readValue(json, Person.class); + assertEquals(input, decoded); + + byte[] bytes = mapper.writeValueAsBytes(input); + assertNotNull(bytes); + var decodedFromBytes = mapper.readValue(bytes, Person.class); + assertEquals(input, decodedFromBytes); + } + + @Test + void readWriteParameterizedTypeWithTypeRef() throws IOException { + var mapper = new GsonMcpJsonMapper(); + String json = "[\"a\", \"b\", \"c\"]"; + + List list = mapper.readValue(json, new TypeRef>() { + }); + assertEquals(List.of("a", "b", "c"), list); + + String encoded = mapper.writeValueAsString(list); + assertTrue(encoded.startsWith("[")); + assertTrue(encoded.contains("\"a\"")); + } + + @Test + void convertValueMapToRecordAndParameterized() { + var mapper = new GsonMcpJsonMapper(); + Map src = Map.of("name", "Bob", "age", 42); + + // Convert to simple record + Person person = mapper.convertValue(src, Person.class); + assertEquals(new Person("Bob", 42), person); + + // Convert to parameterized Map + Map toMap = mapper.convertValue(person, new TypeRef>() { + }); + assertEquals("Bob", toMap.get("name")); + assertEquals(42.0, ((Number) toMap.get("age")).doubleValue(), 0.0); // Gson may + // emit double + // for + // primitives + } + + @Test + void deserializeJsonRpcMessageRequestUsingCustomMapper() throws IOException { + var mapper = new GsonMcpJsonMapper(); + + String json = """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "ping", + "params": { "x": 1, "y": "z" } + } + """; + + var msg = McpSchema.deserializeJsonRpcMessage(mapper, json); + assertTrue(msg instanceof McpSchema.JSONRPCRequest); + + var req = (McpSchema.JSONRPCRequest) msg; + assertEquals("2.0", req.jsonrpc()); + assertEquals("ping", req.method()); + assertNotNull(req.id()); + assertEquals("1", req.id().toString()); + + assertNotNull(req.params()); + assertInstanceOf(Map.class, req.params()); + @SuppressWarnings("unchecked") + var params = (Map) req.params(); + assertEquals(1.0, ((Number) params.get("x")).doubleValue(), 0.0); + assertEquals("z", params.get("y")); + } + + @Test + void integrateWithMcpSchemaStaticMapperForStringParsing() { + var gsonMapper = new GsonMcpJsonMapper(); + + // Tool builder parsing of input/output schema strings + var tool = McpSchema.Tool.builder().name("echo").description("Echo tool").inputSchema(gsonMapper, """ + { + "type": "object", + "properties": { "x": { "type": "integer" } }, + "required": ["x"] + } + """).outputSchema(gsonMapper, """ + { + "type": "object", + "properties": { "y": { "type": "string" } } + } + """).build(); + + assertNotNull(tool.inputSchema()); + assertNotNull(tool.outputSchema()); + assertTrue(tool.outputSchema().containsKey("properties")); + + // CallToolRequest builder parsing of JSON arguments string + var call = McpSchema.CallToolRequest.builder().name("echo").arguments(gsonMapper, "{\"x\": 123}").build(); + + assertEquals("echo", call.name()); + assertNotNull(call.arguments()); + assertTrue(call.arguments().get("x") instanceof Number); + assertEquals(123.0, ((Number) call.arguments().get("x")).doubleValue(), 0.0); + + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/AssertTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/AssertTests.java similarity index 87% rename from mcp/src/test/java/io/modelcontextprotocol/util/AssertTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/util/AssertTests.java index 08555fef5..0038d4e1b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/util/AssertTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/AssertTests.java @@ -8,7 +8,9 @@ import java.util.List; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; class AssertTests { diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java similarity index 98% rename from mcp/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java index 4de9363c2..d5ef8a91c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java @@ -16,7 +16,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSession; @@ -259,7 +259,7 @@ private static class MockMcpSession implements McpSession { private boolean shouldFailPing = false; @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + public Mono sendRequest(String method, Object requestParams, TypeRef typeRef) { if (McpSchema.METHOD_PING.equals(method)) { pingCount.incrementAndGet(); if (shouldFailPing) { diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java new file mode 100644 index 000000000..803372056 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java @@ -0,0 +1,13 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.McpJsonMapper; + +public final class McpJsonMapperUtils { + + private McpJsonMapperUtils() { + } + + public static final McpJsonMapper JSON_MAPPER = McpJsonDefaults.getMapper(); + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolNameValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolNameValidatorTests.java new file mode 100644 index 000000000..f8e301f82 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolNameValidatorTests.java @@ -0,0 +1,147 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.List; +import java.util.function.Consumer; + +import ch.qos.logback.classic.Level; +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.slf4j.LoggerFactory; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link ToolNameValidator}. + */ +class ToolNameValidatorTests { + + private final Logger logger = (Logger) LoggerFactory.getLogger(ToolNameValidator.class); + + private final ListAppender logAppender = new ListAppender<>(); + + @BeforeEach + void setUp() { + logAppender.start(); + logger.addAppender(logAppender); + } + + @AfterEach + void tearDown() { + logger.detachAppender(logAppender); + logAppender.stop(); + } + + @ParameterizedTest + @ValueSource(strings = { "getUser", "DATA_EXPORT_v2", "admin.tools.list", "my-tool", "Tool123", "a", "A", + "_private", "tool_name", "tool-name", "tool.name", "UPPERCASE", "lowercase", "MixedCase123" }) + void validToolNames(String name) { + assertThatCode(() -> ToolNameValidator.validate(name, true)).doesNotThrowAnyException(); + ToolNameValidator.validate(name, false); + assertThat(logAppender.list).isEmpty(); + } + + @Test + void validToolNameMaxLength() { + String name = "a".repeat(128); + assertThatCode(() -> ToolNameValidator.validate(name, true)).doesNotThrowAnyException(); + ToolNameValidator.validate(name, false); + assertThat(logAppender.list).isEmpty(); + } + + @Test + void nullOrEmpty() { + assertThatThrownBy(() -> ToolNameValidator.validate(null, true)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("null or empty"); + assertThatThrownBy(() -> ToolNameValidator.validate("", true)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("null or empty"); + } + + @Test + void strictLength() { + String name = "a".repeat(129); + assertThatThrownBy(() -> ToolNameValidator.validate(name, true)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("128 characters"); + } + + @ParameterizedTest + @ValueSource(strings = { "tool name", // space + "tool,name", // comma + "tool@name", // at sign + "tool#name", // hash + "tool$name", // dollar + "tool%name", // percent + "tool&name", // ampersand + "tool*name", // asterisk + "tool+name", // plus + "tool=name", // equals + "tool/name", // slash + "tool\\name", // backslash + "tool:name", // colon + "tool;name", // semicolon + "tool'name", // single quote + "tool\"name", // double quote + "toolname", // greater than + "tool?name", // question mark + "tool!name", // exclamation + "tool(name)", // parentheses + "tool[name]", // brackets + "tool{name}", // braces + "tool|name", // pipe + "tool~name", // tilde + "tool`name", // backtick + "tool^name", // caret + "tΓΆΓΆl", // non-ASCII + "ε·₯ε…·" // unicode + }) + void strictInvalidCharacters(String name) { + assertThatThrownBy(() -> ToolNameValidator.validate(name, true)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("invalid characters"); + } + + @Test + void lenientNull() { + assertThatCode(() -> ToolNameValidator.validate(null, false)).doesNotThrowAnyException(); + assertThat(logAppender.list).satisfies(hasWarning("null or empty")); + } + + @Test + void lenientEmpty() { + assertThatCode(() -> ToolNameValidator.validate("", false)).doesNotThrowAnyException(); + assertThat(logAppender.list).satisfies(hasWarning("null or empty")); + } + + @Test + void lenientLength() { + assertThatCode(() -> ToolNameValidator.validate("a".repeat(129), false)).doesNotThrowAnyException(); + assertThat(logAppender.list).satisfies(hasWarning("128 characters")); + } + + @Test + void lenientInvalidCharacters() { + assertThatCode(() -> ToolNameValidator.validate("invalid name", false)).doesNotThrowAnyException(); + assertThat(logAppender.list).satisfies(hasWarning("invalid characters")); + } + + private Consumer> hasWarning(String errorMessage) { + return logs -> { + assertThat(logs).hasSize(1).first().satisfies(log -> { + assertThat(log.getLevel()).isEqualTo(Level.WARN); + assertThat(log.getFormattedMessage()).contains(errorMessage); + }); + }; + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java new file mode 100644 index 000000000..ce8755223 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java @@ -0,0 +1,15 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.spec.McpSchema; + +import java.util.Collections; + +public final class ToolsUtils { + + private ToolsUtils() { + } + + public static final McpSchema.JsonSchema EMPTY_JSON_SCHEMA = new McpSchema.JsonSchema("object", + Collections.emptyMap(), null, null, null, null); + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/UtilsTests.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java rename to mcp-core/src/test/java/io/modelcontextprotocol/util/UtilsTests.java diff --git a/mcp/src/test/resources/logback.xml b/mcp-core/src/test/resources/logback.xml similarity index 100% rename from mcp/src/test/resources/logback.xml rename to mcp-core/src/test/resources/logback.xml diff --git a/mcp-json-jackson2/pom.xml b/mcp-json-jackson2/pom.xml new file mode 100644 index 000000000..f25877cd3 --- /dev/null +++ b/mcp-json-jackson2/pom.xml @@ -0,0 +1,107 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + mcp-parent + 1.1.0-SNAPSHOT + + mcp-json-jackson2 + jar + Java MCP SDK JSON Jackson 2 + Java MCP SDK JSON implementation based on Jackson 2 + https://github.com/modelcontextprotocol/java-sdk + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + + biz.aQute.bnd + bnd-maven-plugin + ${bnd-maven-plugin.version} + + + bnd-process + + bnd-process + + + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + ${project.build.outputDirectory}/META-INF/MANIFEST.MF + + + + + + + + com.fasterxml.jackson.core + jackson-databind + ${jackson2.version} + + + io.modelcontextprotocol.sdk + mcp-core + 1.1.0-SNAPSHOT + + + com.networknt + json-schema-validator + ${json-schema-validator-jackson2.version} + + + + org.assertj + assertj-core + ${assert4j.version} + test + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + + diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson2/JacksonMcpJsonMapper.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson2/JacksonMcpJsonMapper.java new file mode 100644 index 000000000..1760cf472 --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson2/JacksonMcpJsonMapper.java @@ -0,0 +1,88 @@ +/* + * Copyright 2026 - 2026 the original author or authors. + */ + +package io.modelcontextprotocol.json.jackson2; + +import java.io.IOException; + +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; + +/** + * Jackson-based implementation of JsonMapper. Wraps a Jackson ObjectMapper but keeps the + * SDK decoupled from Jackson at the API level. + */ +public final class JacksonMcpJsonMapper implements McpJsonMapper { + + private final ObjectMapper objectMapper; + + /** + * Constructs a new JacksonMcpJsonMapper instance with the given ObjectMapper. + * @param objectMapper the ObjectMapper to be used for JSON serialization and + * deserialization. Must not be null. + * @throws IllegalArgumentException if the provided ObjectMapper is null. + */ + public JacksonMcpJsonMapper(ObjectMapper objectMapper) { + if (objectMapper == null) { + throw new IllegalArgumentException("ObjectMapper must not be null"); + } + this.objectMapper = objectMapper; + } + + /** + * Returns the underlying Jackson {@link ObjectMapper} used for JSON serialization and + * deserialization. + * @return the ObjectMapper instance + */ + public ObjectMapper getObjectMapper() { + return objectMapper; + } + + @Override + public T readValue(String content, Class type) throws IOException { + return objectMapper.readValue(content, type); + } + + @Override + public T readValue(byte[] content, Class type) throws IOException { + return objectMapper.readValue(content, type); + } + + @Override + public T readValue(String content, TypeRef type) throws IOException { + JavaType javaType = objectMapper.getTypeFactory().constructType(type.getType()); + return objectMapper.readValue(content, javaType); + } + + @Override + public T readValue(byte[] content, TypeRef type) throws IOException { + JavaType javaType = objectMapper.getTypeFactory().constructType(type.getType()); + return objectMapper.readValue(content, javaType); + } + + @Override + public T convertValue(Object fromValue, Class type) { + return objectMapper.convertValue(fromValue, type); + } + + @Override + public T convertValue(Object fromValue, TypeRef type) { + JavaType javaType = objectMapper.getTypeFactory().constructType(type.getType()); + return objectMapper.convertValue(fromValue, javaType); + } + + @Override + public String writeValueAsString(Object value) throws IOException { + return objectMapper.writeValueAsString(value); + } + + @Override + public byte[] writeValueAsBytes(Object value) throws IOException { + return objectMapper.writeValueAsBytes(value); + } + +} diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson2/JacksonMcpJsonMapperSupplier.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson2/JacksonMcpJsonMapperSupplier.java new file mode 100644 index 000000000..acd5dddaa --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/jackson2/JacksonMcpJsonMapperSupplier.java @@ -0,0 +1,32 @@ +/* + * Copyright 2026 - 2026 the original author or authors. + */ + +package io.modelcontextprotocol.json.jackson2; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.McpJsonMapperSupplier; + +/** + * A supplier of {@link McpJsonMapper} instances that uses the Jackson library for JSON + * serialization and deserialization. + *

+ * This implementation provides a {@link McpJsonMapper} backed by a Jackson + * {@link com.fasterxml.jackson.databind.ObjectMapper}. + */ +public class JacksonMcpJsonMapperSupplier implements McpJsonMapperSupplier { + + /** + * Returns a new instance of {@link McpJsonMapper} that uses the Jackson library for + * JSON serialization and deserialization. + *

+ * The returned {@link McpJsonMapper} is backed by a new instance of + * {@link com.fasterxml.jackson.databind.ObjectMapper}. + * @return a new {@link McpJsonMapper} instance + */ + @Override + public McpJsonMapper get() { + return new JacksonMcpJsonMapper(new com.fasterxml.jackson.databind.ObjectMapper()); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidator.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson2/DefaultJsonSchemaValidator.java similarity index 65% rename from mcp/src/main/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidator.java rename to mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson2/DefaultJsonSchemaValidator.java index f4bdc02eb..e07bf1759 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidator.java +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson2/DefaultJsonSchemaValidator.java @@ -1,11 +1,10 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2026-2026 the original author or authors. */ +package io.modelcontextprotocol.json.schema.jackson2; -package io.modelcontextprotocol.spec; - +import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; @@ -14,13 +13,12 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.networknt.schema.JsonSchema; -import com.networknt.schema.JsonSchemaFactory; -import com.networknt.schema.SpecVersion; -import com.networknt.schema.ValidationMessage; +import com.networknt.schema.Error; +import com.networknt.schema.Schema; +import com.networknt.schema.SchemaRegistry; +import com.networknt.schema.dialect.Dialects; -import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; /** * Default implementation of the {@link JsonSchemaValidator} interface. This class @@ -35,10 +33,10 @@ public class DefaultJsonSchemaValidator implements JsonSchemaValidator { private final ObjectMapper objectMapper; - private final JsonSchemaFactory schemaFactory; + private final SchemaRegistry schemaFactory; // TODO: Implement a strategy to purge the cache (TTL, size limit, etc.) - private final ConcurrentHashMap schemaCache; + private final ConcurrentHashMap schemaCache; public DefaultJsonSchemaValidator() { this(new ObjectMapper()); @@ -46,21 +44,27 @@ public DefaultJsonSchemaValidator() { public DefaultJsonSchemaValidator(ObjectMapper objectMapper) { this.objectMapper = objectMapper; - this.schemaFactory = JsonSchemaFactory.getInstance(SpecVersion.VersionFlag.V202012); + this.schemaFactory = SchemaRegistry.withDialect(Dialects.getDraft202012()); this.schemaCache = new ConcurrentHashMap<>(); } @Override - public ValidationResponse validate(Map schema, Map structuredContent) { + public ValidationResponse validate(Map schema, Object structuredContent) { - Assert.notNull(schema, "Schema must not be null"); - Assert.notNull(structuredContent, "Structured content must not be null"); + if (schema == null) { + throw new IllegalArgumentException("Schema must not be null"); + } + if (structuredContent == null) { + throw new IllegalArgumentException("Structured content must not be null"); + } try { - JsonNode jsonStructuredOutput = this.objectMapper.valueToTree(structuredContent); + JsonNode jsonStructuredOutput = (structuredContent instanceof String) + ? this.objectMapper.readTree((String) structuredContent) + : this.objectMapper.valueToTree(structuredContent); - Set validationResult = this.getOrCreateJsonSchema(schema).validate(jsonStructuredOutput); + List validationResult = this.getOrCreateJsonSchema(schema).validate(jsonStructuredOutput); // Check if validation passed if (!validationResult.isEmpty()) { @@ -83,36 +87,36 @@ public ValidationResponse validate(Map schema, Map schema) throws JsonProcessingException { + private Schema getOrCreateJsonSchema(Map schema) throws JsonProcessingException { // Generate cache key based on schema content String cacheKey = this.generateCacheKey(schema); // Try to get from cache first - JsonSchema cachedSchema = this.schemaCache.get(cacheKey); + Schema cachedSchema = this.schemaCache.get(cacheKey); if (cachedSchema != null) { return cachedSchema; } // Create new schema if not in cache - JsonSchema newSchema = this.createJsonSchema(schema); + Schema newSchema = this.createJsonSchema(schema); // Cache the schema - JsonSchema existingSchema = this.schemaCache.putIfAbsent(cacheKey, newSchema); + Schema existingSchema = this.schemaCache.putIfAbsent(cacheKey, newSchema); return existingSchema != null ? existingSchema : newSchema; } /** - * Creates a new JsonSchema from the given schema map. + * Creates a new Schema from the given schema map. * @param schema the schema map - * @return the compiled JsonSchema + * @return the compiled Schema * @throws JsonProcessingException if schema processing fails */ - private JsonSchema createJsonSchema(Map schema) throws JsonProcessingException { + private Schema createJsonSchema(Map schema) throws JsonProcessingException { // Convert schema map directly to JsonNode (more efficient than string // serialization) JsonNode schemaNode = this.objectMapper.valueToTree(schema); @@ -123,17 +127,6 @@ private JsonSchema createJsonSchema(Map schema) throws JsonProce }; } - // Handle additionalProperties setting - if (schemaNode.isObject()) { - ObjectNode objectSchemaNode = (ObjectNode) schemaNode; - if (!objectSchemaNode.has("additionalProperties")) { - // Clone the node before modification to avoid mutating the original - objectSchemaNode = objectSchemaNode.deepCopy(); - objectSchemaNode.put("additionalProperties", false); - schemaNode = objectSchemaNode; - } - } - return this.schemaFactory.getSchema(schemaNode); } diff --git a/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson2/JacksonJsonSchemaValidatorSupplier.java b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson2/JacksonJsonSchemaValidatorSupplier.java new file mode 100644 index 000000000..aa280a38e --- /dev/null +++ b/mcp-json-jackson2/src/main/java/io/modelcontextprotocol/json/schema/jackson2/JacksonJsonSchemaValidatorSupplier.java @@ -0,0 +1,29 @@ +/* + * Copyright 2026 - 2026 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema.jackson2; + +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier; + +/** + * A concrete implementation of {@link JsonSchemaValidatorSupplier} that provides a + * {@link JsonSchemaValidator} instance based on the Jackson library. + * + * @see JsonSchemaValidatorSupplier + * @see JsonSchemaValidator + */ +public class JacksonJsonSchemaValidatorSupplier implements JsonSchemaValidatorSupplier { + + /** + * Returns a new instance of {@link JsonSchemaValidator} that uses the Jackson library + * for JSON schema validation. + * @return A {@link JsonSchemaValidator} instance. + */ + @Override + public JsonSchemaValidator get() { + return new DefaultJsonSchemaValidator(); + } + +} diff --git a/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier new file mode 100644 index 000000000..0c62b6478 --- /dev/null +++ b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier @@ -0,0 +1 @@ +io.modelcontextprotocol.json.jackson2.JacksonMcpJsonMapperSupplier \ No newline at end of file diff --git a/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier new file mode 100644 index 000000000..1b2f05f97 --- /dev/null +++ b/mcp-json-jackson2/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier @@ -0,0 +1 @@ +io.modelcontextprotocol.json.schema.jackson2.JacksonJsonSchemaValidatorSupplier \ No newline at end of file diff --git a/mcp-json-jackson2/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.jackson2.JacksonMcpJsonMapperSupplier.xml b/mcp-json-jackson2/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.jackson2.JacksonMcpJsonMapperSupplier.xml new file mode 100644 index 000000000..1d6705f56 --- /dev/null +++ b/mcp-json-jackson2/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.jackson2.JacksonMcpJsonMapperSupplier.xml @@ -0,0 +1,7 @@ + + + + + + + diff --git a/mcp-json-jackson2/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.schema.jackson2.JacksonJsonSchemaValidatorSupplier.xml b/mcp-json-jackson2/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.schema.jackson2.JacksonJsonSchemaValidatorSupplier.xml new file mode 100644 index 000000000..ad628745f --- /dev/null +++ b/mcp-json-jackson2/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.schema.jackson2.JacksonJsonSchemaValidatorSupplier.xml @@ -0,0 +1,7 @@ + + + + + + + diff --git a/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/McpJsonMapperTest.java b/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/McpJsonMapperTest.java new file mode 100644 index 000000000..7ae5d0887 --- /dev/null +++ b/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/McpJsonMapperTest.java @@ -0,0 +1,20 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.json.jackson2.JacksonMcpJsonMapper; + +class McpJsonMapperTest { + + @Test + void shouldUseJackson2Mapper() { + assertThat(McpJsonDefaults.getMapper()).isInstanceOf(JacksonMcpJsonMapper.class); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidatorTests.java b/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/jackson2/DefaultJsonSchemaValidatorTests.java similarity index 83% rename from mcp/src/test/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidatorTests.java rename to mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/jackson2/DefaultJsonSchemaValidatorTests.java index 30158543d..5ae3fbed4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidatorTests.java +++ b/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/jackson2/DefaultJsonSchemaValidatorTests.java @@ -1,8 +1,8 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2026-2026 the original author or authors. */ -package io.modelcontextprotocol.spec; +package io.modelcontextprotocol.json.jackson2; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -13,6 +13,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; +import java.util.List; import java.util.Map; import java.util.stream.Stream; @@ -27,7 +28,8 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.JsonSchemaValidator.ValidationResponse; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator.ValidationResponse; +import io.modelcontextprotocol.json.schema.jackson2.DefaultJsonSchemaValidator; /** * Tests for {@link DefaultJsonSchemaValidator}. @@ -63,6 +65,16 @@ private Map toMap(String json) { } } + private List> toListMap(String json) { + try { + return objectMapper.readValue(json, new TypeReference>>() { + }); + } + catch (Exception e) { + throw new RuntimeException("Failed to parse JSON: " + json, e); + } + } + @Test void testDefaultConstructor() { DefaultJsonSchemaValidator defaultValidator = new DefaultJsonSchemaValidator(); @@ -197,6 +209,74 @@ void testValidateWithValidArraySchema() { assertNull(response.errorMessage()); } + @Test + void testValidateWithValidArraySchemaTopLevelArray() { + String schemaJson = """ + { + "$schema" : "https://json-schema.org/draft/2020-12/schema", + "type" : "array", + "items" : { + "type" : "object", + "properties" : { + "city" : { + "type" : "string" + }, + "summary" : { + "type" : "string" + }, + "temperatureC" : { + "type" : "number", + "format" : "float" + } + }, + "required" : [ "city", "summary", "temperatureC" ] + }, + "additionalProperties" : false + } + """; + + String contentJson = """ + [ + { + "city": "London", + "summary": "Generally mild with frequent rainfall. Winters are cool and damp, summers are warm but rarely hot. Cloudy conditions are common throughout the year.", + "temperatureC": 11.3 + }, + { + "city": "New York", + "summary": "Four distinct seasons with hot and humid summers, cold winters with snow, and mild springs and autumns. Precipitation is fairly evenly distributed throughout the year.", + "temperatureC": 12.8 + }, + { + "city": "San Francisco", + "summary": "Mild year-round with a distinctive Mediterranean climate. Famous for summer fog, mild winters, and little temperature variation throughout the year. Very little rainfall in summer months.", + "temperatureC": 14.6 + }, + { + "city": "Tokyo", + "summary": "Humid subtropical climate with hot, wet summers and mild winters. Experiences a rainy season in early summer and occasional typhoons in late summer to early autumn.", + "temperatureC": 15.4 + } + ] + """; + + Map schema = toMap(schemaJson); + + // Validate as JSON string + ValidationResponse response = validator.validate(schema, contentJson); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + + List> structuredContent = toListMap(contentJson); + + // Validate as List> + response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + @Test void testValidateWithInvalidTypeSchema() { String schemaJson = """ @@ -265,7 +345,8 @@ void testValidateWithAdditionalPropertiesNotAllowed() { "properties": { "name": {"type": "string"} }, - "required": ["name"] + "required": ["name"], + "additionalProperties": false } """; @@ -315,6 +396,35 @@ void testValidateWithAdditionalPropertiesExplicitlyAllowed() { assertNull(response.errorMessage()); } + @Test + void testValidateWithDefaultAdditionalProperties() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": true + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + @Test void testValidateWithAdditionalPropertiesExplicitlyDisallowed() { String schemaJson = """ diff --git a/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorTest.java b/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorTest.java new file mode 100644 index 000000000..92a80cb9b --- /dev/null +++ b/mcp-json-jackson2/src/test/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorTest.java @@ -0,0 +1,21 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.schema.jackson2.DefaultJsonSchemaValidator; + +class JsonSchemaValidatorTest { + + @Test + void shouldUseJackson2Mapper() { + assertThat(McpJsonDefaults.getSchemaValidator()).isInstanceOf(DefaultJsonSchemaValidator.class); + } + +} diff --git a/mcp-json-jackson3/pom.xml b/mcp-json-jackson3/pom.xml new file mode 100644 index 000000000..99baf14e1 --- /dev/null +++ b/mcp-json-jackson3/pom.xml @@ -0,0 +1,106 @@ + + + 4.0.0 + + io.modelcontextprotocol.sdk + mcp-parent + 1.1.0-SNAPSHOT + + mcp-json-jackson3 + jar + Java MCP SDK JSON Jackson 3 + Java MCP SDK JSON implementation based on Jackson 3 + https://github.com/modelcontextprotocol/java-sdk + + https://github.com/modelcontextprotocol/java-sdk + git://github.com/modelcontextprotocol/java-sdk.git + git@github.com/modelcontextprotocol/java-sdk.git + + + + + biz.aQute.bnd + bnd-maven-plugin + ${bnd-maven-plugin.version} + + + bnd-process + + bnd-process + + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + ${project.build.outputDirectory}/META-INF/MANIFEST.MF + + + + + + + + io.modelcontextprotocol.sdk + mcp-core + 1.1.0-SNAPSHOT + + + tools.jackson.core + jackson-databind + ${jackson3.version} + + + com.networknt + json-schema-validator + ${json-schema-validator-jackson3.version} + + + + org.assertj + assertj-core + ${assert4j.version} + test + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + + diff --git a/mcp-json-jackson3/src/main/java/io/modelcontextprotocol/json/jackson3/JacksonMcpJsonMapper.java b/mcp-json-jackson3/src/main/java/io/modelcontextprotocol/json/jackson3/JacksonMcpJsonMapper.java new file mode 100644 index 000000000..a0dbdd555 --- /dev/null +++ b/mcp-json-jackson3/src/main/java/io/modelcontextprotocol/json/jackson3/JacksonMcpJsonMapper.java @@ -0,0 +1,119 @@ +/* + * Copyright 2026 - 2026 the original author or authors. + */ + +package io.modelcontextprotocol.json.jackson3; + +import java.io.IOException; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; + +import tools.jackson.core.JacksonException; +import tools.jackson.databind.JavaType; +import tools.jackson.databind.json.JsonMapper; + +/** + * Jackson-based implementation of JsonMapper. Wraps a Jackson JsonMapper but keeps the + * SDK decoupled from Jackson at the API level. + */ +public final class JacksonMcpJsonMapper implements McpJsonMapper { + + private final JsonMapper jsonMapper; + + /** + * Constructs a new JacksonMcpJsonMapper instance with the given JsonMapper. + * @param jsonMapper the JsonMapper to be used for JSON serialization and + * deserialization. Must not be null. + * @throws IllegalArgumentException if the provided JsonMapper is null. + */ + public JacksonMcpJsonMapper(JsonMapper jsonMapper) { + if (jsonMapper == null) { + throw new IllegalArgumentException("JsonMapper must not be null"); + } + this.jsonMapper = jsonMapper; + } + + /** + * Returns the underlying Jackson {@link JsonMapper} used for JSON serialization and + * deserialization. + * @return the JsonMapper instance + */ + public JsonMapper getJsonMapper() { + return jsonMapper; + } + + @Override + public T readValue(String content, Class type) throws IOException { + try { + return jsonMapper.readValue(content, type); + } + catch (JacksonException ex) { + throw new IOException("Failed to read value", ex); + } + } + + @Override + public T readValue(byte[] content, Class type) throws IOException { + try { + return jsonMapper.readValue(content, type); + } + catch (JacksonException ex) { + throw new IOException("Failed to read value", ex); + } + } + + @Override + public T readValue(String content, TypeRef type) throws IOException { + JavaType javaType = jsonMapper.getTypeFactory().constructType(type.getType()); + try { + return jsonMapper.readValue(content, javaType); + } + catch (JacksonException ex) { + throw new IOException("Failed to read value", ex); + } + } + + @Override + public T readValue(byte[] content, TypeRef type) throws IOException { + JavaType javaType = jsonMapper.getTypeFactory().constructType(type.getType()); + try { + return jsonMapper.readValue(content, javaType); + } + catch (JacksonException ex) { + throw new IOException("Failed to read value", ex); + } + } + + @Override + public T convertValue(Object fromValue, Class type) { + return jsonMapper.convertValue(fromValue, type); + } + + @Override + public T convertValue(Object fromValue, TypeRef type) { + JavaType javaType = jsonMapper.getTypeFactory().constructType(type.getType()); + return jsonMapper.convertValue(fromValue, javaType); + } + + @Override + public String writeValueAsString(Object value) throws IOException { + try { + return jsonMapper.writeValueAsString(value); + } + catch (JacksonException ex) { + throw new IOException("Failed to write value as string", ex); + } + } + + @Override + public byte[] writeValueAsBytes(Object value) throws IOException { + try { + return jsonMapper.writeValueAsBytes(value); + } + catch (JacksonException ex) { + throw new IOException("Failed to write value as bytes", ex); + } + } + +} diff --git a/mcp-json-jackson3/src/main/java/io/modelcontextprotocol/json/jackson3/JacksonMcpJsonMapperSupplier.java b/mcp-json-jackson3/src/main/java/io/modelcontextprotocol/json/jackson3/JacksonMcpJsonMapperSupplier.java new file mode 100644 index 000000000..839862ffe --- /dev/null +++ b/mcp-json-jackson3/src/main/java/io/modelcontextprotocol/json/jackson3/JacksonMcpJsonMapperSupplier.java @@ -0,0 +1,34 @@ +/* + * Copyright 2026 - 2026 the original author or authors. + */ + +package io.modelcontextprotocol.json.jackson3; + +import tools.jackson.databind.json.JsonMapper; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.McpJsonMapperSupplier; + +/** + * A supplier of {@link McpJsonMapper} instances that uses the Jackson library for JSON + * serialization and deserialization. + *

+ * This implementation provides a {@link McpJsonMapper} backed by + * {@link JsonMapper#shared() JsonMapper shared instance}. + */ +public class JacksonMcpJsonMapperSupplier implements McpJsonMapperSupplier { + + /** + * Returns a new instance of {@link McpJsonMapper} that uses the Jackson library for + * JSON serialization and deserialization. + *

+ * The returned {@link McpJsonMapper} is backed by {@link JsonMapper#shared() + * JsonMapper shared instance}. + * @return a new {@link McpJsonMapper} instance + */ + @Override + public McpJsonMapper get() { + return new JacksonMcpJsonMapper(JsonMapper.shared()); + } + +} diff --git a/mcp-json-jackson3/src/main/java/io/modelcontextprotocol/json/schema/jackson3/DefaultJsonSchemaValidator.java b/mcp-json-jackson3/src/main/java/io/modelcontextprotocol/json/schema/jackson3/DefaultJsonSchemaValidator.java new file mode 100644 index 000000000..8c9b7ccdb --- /dev/null +++ b/mcp-json-jackson3/src/main/java/io/modelcontextprotocol/json/schema/jackson3/DefaultJsonSchemaValidator.java @@ -0,0 +1,162 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ +package io.modelcontextprotocol.json.schema.jackson3; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import com.networknt.schema.Schema; +import com.networknt.schema.SchemaRegistry; +import com.networknt.schema.Error; +import com.networknt.schema.dialect.Dialects; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import tools.jackson.core.JacksonException; +import tools.jackson.databind.JsonNode; +import tools.jackson.databind.json.JsonMapper; + +/** + * Default implementation of the {@link JsonSchemaValidator} interface. This class + * provides methods to validate structured content against a JSON schema. It uses the + * NetworkNT JSON Schema Validator library for validation. + * + * @author Filip Hrisafov + */ +public class DefaultJsonSchemaValidator implements JsonSchemaValidator { + + private static final Logger logger = LoggerFactory.getLogger(DefaultJsonSchemaValidator.class); + + private final JsonMapper jsonMapper; + + private final SchemaRegistry schemaFactory; + + // TODO: Implement a strategy to purge the cache (TTL, size limit, etc.) + private final ConcurrentHashMap schemaCache; + + public DefaultJsonSchemaValidator() { + this(JsonMapper.shared()); + } + + public DefaultJsonSchemaValidator(JsonMapper jsonMapper) { + this.jsonMapper = jsonMapper; + this.schemaFactory = SchemaRegistry.withDialect(Dialects.getDraft202012()); + this.schemaCache = new ConcurrentHashMap<>(); + } + + @Override + public ValidationResponse validate(Map schema, Object structuredContent) { + + if (schema == null) { + throw new IllegalArgumentException("Schema must not be null"); + } + if (structuredContent == null) { + throw new IllegalArgumentException("Structured content must not be null"); + } + + try { + + JsonNode jsonStructuredOutput = (structuredContent instanceof String) + ? this.jsonMapper.readTree((String) structuredContent) + : this.jsonMapper.valueToTree(structuredContent); + + List validationResult = this.getOrCreateJsonSchema(schema).validate(jsonStructuredOutput); + + // Check if validation passed + if (!validationResult.isEmpty()) { + return ValidationResponse + .asInvalid("Validation failed: structuredContent does not match tool outputSchema. " + + "Validation errors: " + validationResult); + } + + return ValidationResponse.asValid(jsonStructuredOutput.toString()); + + } + catch (JacksonException e) { + logger.error("Failed to validate CallToolResult: Error parsing schema: {}", e); + return ValidationResponse.asInvalid("Error parsing tool JSON Schema: " + e.getMessage()); + } + catch (Exception e) { + logger.error("Failed to validate CallToolResult: Unexpected error: {}", e); + return ValidationResponse.asInvalid("Unexpected validation error: " + e.getMessage()); + } + } + + /** + * Gets a cached Schema or creates and caches a new one. + * @param schema the schema map to convert + * @return the compiled Schema + * @throws JacksonException if schema processing fails + */ + private Schema getOrCreateJsonSchema(Map schema) throws JacksonException { + // Generate cache key based on schema content + String cacheKey = this.generateCacheKey(schema); + + // Try to get from cache first + Schema cachedSchema = this.schemaCache.get(cacheKey); + if (cachedSchema != null) { + return cachedSchema; + } + + // Create new schema if not in cache + Schema newSchema = this.createJsonSchema(schema); + + // Cache the schema + Schema existingSchema = this.schemaCache.putIfAbsent(cacheKey, newSchema); + return existingSchema != null ? existingSchema : newSchema; + } + + /** + * Creates a new Schema from the given schema map. + * @param schema the schema map + * @return the compiled Schema + * @throws JacksonException if schema processing fails + */ + private Schema createJsonSchema(Map schema) throws JacksonException { + // Convert schema map directly to JsonNode (more efficient than string + // serialization) + JsonNode schemaNode = this.jsonMapper.valueToTree(schema); + + // Handle case where ObjectMapper might return null (e.g., in mocked scenarios) + if (schemaNode == null) { + throw new JacksonException("Failed to convert schema to JsonNode") { + }; + } + + return this.schemaFactory.getSchema(schemaNode); + } + + /** + * Generates a cache key for the given schema map. + * @param schema the schema map + * @return a cache key string + */ + protected String generateCacheKey(Map schema) { + if (schema.containsKey("$id")) { + // Use the (optional) "$id" field as the cache key if present + return "" + schema.get("$id"); + } + // Fall back to schema's hash code as a simple cache key + // For more sophisticated caching, could use content-based hashing + return String.valueOf(schema.hashCode()); + } + + /** + * Clears the schema cache. Useful for testing or memory management. + */ + public void clearCache() { + this.schemaCache.clear(); + } + + /** + * Returns the current size of the schema cache. + * @return the number of cached schemas + */ + public int getCacheSize() { + return this.schemaCache.size(); + } + +} diff --git a/mcp-json-jackson3/src/main/java/io/modelcontextprotocol/json/schema/jackson3/JacksonJsonSchemaValidatorSupplier.java b/mcp-json-jackson3/src/main/java/io/modelcontextprotocol/json/schema/jackson3/JacksonJsonSchemaValidatorSupplier.java new file mode 100644 index 000000000..87cead5db --- /dev/null +++ b/mcp-json-jackson3/src/main/java/io/modelcontextprotocol/json/schema/jackson3/JacksonJsonSchemaValidatorSupplier.java @@ -0,0 +1,29 @@ +/* + * Copyright 2026 - 2026 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema.jackson3; + +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier; + +/** + * A concrete implementation of {@link JsonSchemaValidatorSupplier} that provides a + * {@link JsonSchemaValidator} instance based on the Jackson library. + * + * @see JsonSchemaValidatorSupplier + * @see JsonSchemaValidator + */ +public class JacksonJsonSchemaValidatorSupplier implements JsonSchemaValidatorSupplier { + + /** + * Returns a new instance of {@link JsonSchemaValidator} that uses the Jackson library + * for JSON schema validation. + * @return A {@link JsonSchemaValidator} instance. + */ + @Override + public JsonSchemaValidator get() { + return new DefaultJsonSchemaValidator(); + } + +} diff --git a/mcp-json-jackson3/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier b/mcp-json-jackson3/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier new file mode 100644 index 000000000..6abfb347f --- /dev/null +++ b/mcp-json-jackson3/src/main/resources/META-INF/services/io.modelcontextprotocol.json.McpJsonMapperSupplier @@ -0,0 +1 @@ +io.modelcontextprotocol.json.jackson3.JacksonMcpJsonMapperSupplier \ No newline at end of file diff --git a/mcp-json-jackson3/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier b/mcp-json-jackson3/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier new file mode 100644 index 000000000..2bab3ba8e --- /dev/null +++ b/mcp-json-jackson3/src/main/resources/META-INF/services/io.modelcontextprotocol.json.schema.JsonSchemaValidatorSupplier @@ -0,0 +1 @@ +io.modelcontextprotocol.json.schema.jackson3.JacksonJsonSchemaValidatorSupplier \ No newline at end of file diff --git a/mcp-json-jackson3/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.jackson3.JacksonMcpJsonMapperSupplier.xml b/mcp-json-jackson3/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.jackson3.JacksonMcpJsonMapperSupplier.xml new file mode 100644 index 000000000..0ad8a7b42 --- /dev/null +++ b/mcp-json-jackson3/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.jackson3.JacksonMcpJsonMapperSupplier.xml @@ -0,0 +1,7 @@ + + + + + + + diff --git a/mcp-json-jackson3/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.schema.jackson3.JacksonJsonSchemaValidatorSupplier.xml b/mcp-json-jackson3/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.schema.jackson3.JacksonJsonSchemaValidatorSupplier.xml new file mode 100644 index 000000000..d14d8bea3 --- /dev/null +++ b/mcp-json-jackson3/src/main/resources/OSGI-INF/io.modelcontextprotocol.json.schema.jackson3.JacksonJsonSchemaValidatorSupplier.xml @@ -0,0 +1,7 @@ + + + + + + + diff --git a/mcp-json-jackson3/src/test/java/io/modelcontextprotocol/json/DefaultJsonSchemaValidatorTests.java b/mcp-json-jackson3/src/test/java/io/modelcontextprotocol/json/DefaultJsonSchemaValidatorTests.java new file mode 100644 index 000000000..37c52caf7 --- /dev/null +++ b/mcp-json-jackson3/src/test/java/io/modelcontextprotocol/json/DefaultJsonSchemaValidatorTests.java @@ -0,0 +1,807 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import io.modelcontextprotocol.json.schema.jackson3.DefaultJsonSchemaValidator; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import tools.jackson.core.type.TypeReference; +import tools.jackson.databind.json.JsonMapper; + +import io.modelcontextprotocol.json.schema.JsonSchemaValidator.ValidationResponse; + +/** + * Tests for {@link DefaultJsonSchemaValidator}. + * + * @author Filip Hrisafov + */ +class DefaultJsonSchemaValidatorTests { + + private DefaultJsonSchemaValidator validator; + + private JsonMapper jsonMapper; + + @Mock + private JsonMapper mockJsonMapper; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + validator = new DefaultJsonSchemaValidator(); + jsonMapper = JsonMapper.shared(); + } + + /** + * Utility method to convert JSON string to Map + */ + private Map toMap(String json) { + try { + return jsonMapper.readValue(json, new TypeReference<>() { + }); + } + catch (Exception e) { + throw new RuntimeException("Failed to parse JSON: " + json, e); + } + } + + private List> toListMap(String json) { + try { + return jsonMapper.readValue(json, new TypeReference<>() { + }); + } + catch (Exception e) { + throw new RuntimeException("Failed to parse JSON: " + json, e); + } + } + + @Test + void testDefaultConstructor() { + DefaultJsonSchemaValidator defaultValidator = new DefaultJsonSchemaValidator(); + + String schemaJson = """ + { + "type": "object", + "properties": { + "test": {"type": "string"} + } + } + """; + String contentJson = """ + { + "test": "value" + } + """; + + ValidationResponse response = defaultValidator.validate(toMap(schemaJson), toMap(contentJson)); + assertTrue(response.valid()); + } + + @Test + void testConstructorWithObjectMapper() { + JsonMapper customMapper = JsonMapper.builder().build(); + DefaultJsonSchemaValidator customValidator = new DefaultJsonSchemaValidator(customMapper); + + String schemaJson = """ + { + "type": "object", + "properties": { + "test": {"type": "string"} + } + } + """; + String contentJson = """ + { + "test": "value" + } + """; + + ValidationResponse response = customValidator.validate(toMap(schemaJson), toMap(contentJson)); + assertTrue(response.valid()); + } + + @Test + void testValidateWithValidStringSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + """; + + String contentJson = """ + { + "name": "John Doe", + "age": 30 + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + assertNotNull(response.jsonStructuredOutput()); + } + + @Test + void testValidateWithValidNumberSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "price": {"type": "number", "minimum": 0}, + "quantity": {"type": "integer", "minimum": 1} + }, + "required": ["price", "quantity"] + } + """; + + String contentJson = """ + { + "price": 19.99, + "quantity": 5 + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithValidArraySchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": {"type": "string"} + } + }, + "required": ["items"] + } + """; + + String contentJson = """ + { + "items": ["apple", "banana", "cherry"] + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithValidArraySchemaTopLevelArray() { + String schemaJson = """ + { + "$schema" : "https://json-schema.org/draft/2020-12/schema", + "type" : "array", + "items" : { + "type" : "object", + "properties" : { + "city" : { + "type" : "string" + }, + "summary" : { + "type" : "string" + }, + "temperatureC" : { + "type" : "number", + "format" : "float" + } + }, + "required" : [ "city", "summary", "temperatureC" ] + }, + "additionalProperties" : false + } + """; + + String contentJson = """ + [ + { + "city": "London", + "summary": "Generally mild with frequent rainfall. Winters are cool and damp, summers are warm but rarely hot. Cloudy conditions are common throughout the year.", + "temperatureC": 11.3 + }, + { + "city": "New York", + "summary": "Four distinct seasons with hot and humid summers, cold winters with snow, and mild springs and autumns. Precipitation is fairly evenly distributed throughout the year.", + "temperatureC": 12.8 + }, + { + "city": "San Francisco", + "summary": "Mild year-round with a distinctive Mediterranean climate. Famous for summer fog, mild winters, and little temperature variation throughout the year. Very little rainfall in summer months.", + "temperatureC": 14.6 + }, + { + "city": "Tokyo", + "summary": "Humid subtropical climate with hot, wet summers and mild winters. Experiences a rainy season in early summer and occasional typhoons in late summer to early autumn.", + "temperatureC": 15.4 + } + ] + """; + + Map schema = toMap(schemaJson); + + // Validate as JSON string + ValidationResponse response = validator.validate(schema, contentJson); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + + List> structuredContent = toListMap(contentJson); + + // Validate as List> + response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithInvalidTypeSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + """; + + String contentJson = """ + { + "name": "John Doe", + "age": "thirty" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + assertTrue(response.errorMessage().contains("structuredContent does not match tool outputSchema")); + } + + @Test + void testValidateWithMissingRequiredField() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + """; + + String contentJson = """ + { + "name": "John Doe" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + @Test + void testValidateWithAdditionalPropertiesNotAllowed() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": false + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should not be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + @Test + void testValidateWithAdditionalPropertiesExplicitlyAllowed() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": true + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithDefaultAdditionalProperties() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": true + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithAdditionalPropertiesExplicitlyDisallowed() { + String schemaJson = """ + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"], + "additionalProperties": false + } + """; + + String contentJson = """ + { + "name": "John Doe", + "extraField": "should not be allowed" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + @Test + void testValidateWithEmptySchema() { + String schemaJson = """ + { + "additionalProperties": true + } + """; + + String contentJson = """ + { + "anything": "goes" + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithEmptyContent() { + String schemaJson = """ + { + "type": "object", + "properties": {} + } + """; + + String contentJson = """ + {} + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithNestedObjectSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + }, + "required": ["name", "address"] + } + }, + "required": ["person"] + } + """; + + String contentJson = """ + { + "person": { + "name": "John Doe", + "address": { + "street": "123 Main St", + "city": "Anytown" + } + } + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertTrue(response.valid()); + assertNull(response.errorMessage()); + } + + @Test + void testValidateWithInvalidNestedObjectSchema() { + String schemaJson = """ + { + "type": "object", + "properties": { + "person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + }, + "required": ["street", "city"] + } + }, + "required": ["name", "address"] + } + }, + "required": ["person"] + } + """; + + String contentJson = """ + { + "person": { + "name": "John Doe", + "address": { + "street": "123 Main St" + } + } + } + """; + + Map schema = toMap(schemaJson); + Map structuredContent = toMap(contentJson); + + ValidationResponse response = validator.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + @Test + void testValidateWithJsonProcessingException() { + DefaultJsonSchemaValidator validatorWithMockMapper = new DefaultJsonSchemaValidator(mockJsonMapper); + + Map schema = Map.of("type", "object"); + Map structuredContent = Map.of("key", "value"); + + // This will trigger our null check and throw JsonProcessingException + when(mockJsonMapper.valueToTree(any())).thenReturn(null); + + ValidationResponse response = validatorWithMockMapper.validate(schema, structuredContent); + + assertFalse(response.valid()); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Error parsing tool JSON Schema")); + assertTrue(response.errorMessage().contains("Failed to convert schema to JsonNode")); + } + + @ParameterizedTest + @MethodSource("provideValidSchemaAndContentPairs") + void testValidateWithVariousValidInputs(Map schema, Map content) { + ValidationResponse response = validator.validate(schema, content); + + assertTrue(response.valid(), "Expected validation to pass for schema: " + schema + " and content: " + content); + assertNull(response.errorMessage()); + } + + @ParameterizedTest + @MethodSource("provideInvalidSchemaAndContentPairs") + void testValidateWithVariousInvalidInputs(Map schema, Map content) { + ValidationResponse response = validator.validate(schema, content); + + assertFalse(response.valid(), "Expected validation to fail for schema: " + schema + " and content: " + content); + assertNotNull(response.errorMessage()); + assertTrue(response.errorMessage().contains("Validation failed")); + } + + private static Map staticToMap(String json) { + try { + return JsonMapper.shared().readValue(json, new TypeReference<>() { + }); + } + catch (Exception e) { + throw new RuntimeException("Failed to parse JSON: " + json, e); + } + } + + private static Stream provideValidSchemaAndContentPairs() { + return Stream.of( + // Boolean schema + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "flag": {"type": "boolean"} + } + } + """), staticToMap(""" + { + "flag": true + } + """)), + // String with additional properties allowed + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "additionalProperties": true + } + """), staticToMap(""" + { + "name": "test", + "extra": "allowed" + } + """)), + // Array with specific items + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "numbers": { + "type": "array", + "items": {"type": "number"} + } + } + } + """), staticToMap(""" + { + "numbers": [1.0, 2.5, 3.14] + } + """)), + // Enum validation + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": ["active", "inactive", "pending"] + } + } + } + """), staticToMap(""" + { + "status": "active" + } + """))); + } + + private static Stream provideInvalidSchemaAndContentPairs() { + return Stream.of( + // Wrong boolean type + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "flag": {"type": "boolean"} + } + } + """), staticToMap(""" + { + "flag": "true" + } + """)), + // Array with wrong item types + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "numbers": { + "type": "array", + "items": {"type": "number"} + } + } + } + """), staticToMap(""" + { + "numbers": ["one", "two", "three"] + } + """)), + // Invalid enum value + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": ["active", "inactive", "pending"] + } + } + } + """), staticToMap(""" + { + "status": "unknown" + } + """)), + // Minimum constraint violation + Arguments.of(staticToMap(""" + { + "type": "object", + "properties": { + "age": {"type": "integer", "minimum": 0} + } + } + """), staticToMap(""" + { + "age": -5 + } + """))); + } + + @Test + void testValidationResponseToValid() { + String jsonOutput = "{\"test\":\"value\"}"; + ValidationResponse response = ValidationResponse.asValid(jsonOutput); + assertTrue(response.valid()); + assertNull(response.errorMessage()); + assertEquals(jsonOutput, response.jsonStructuredOutput()); + } + + @Test + void testValidationResponseToInvalid() { + String errorMessage = "Test error message"; + ValidationResponse response = ValidationResponse.asInvalid(errorMessage); + assertFalse(response.valid()); + assertEquals(errorMessage, response.errorMessage()); + assertNull(response.jsonStructuredOutput()); + } + + @Test + void testValidationResponseRecord() { + ValidationResponse response1 = new ValidationResponse(true, null, "{\"valid\":true}"); + ValidationResponse response2 = new ValidationResponse(false, "Error", null); + + assertTrue(response1.valid()); + assertNull(response1.errorMessage()); + assertEquals("{\"valid\":true}", response1.jsonStructuredOutput()); + + assertFalse(response2.valid()); + assertEquals("Error", response2.errorMessage()); + assertNull(response2.jsonStructuredOutput()); + + // Test equality + ValidationResponse response3 = new ValidationResponse(true, null, "{\"valid\":true}"); + assertEquals(response1, response3); + assertNotEquals(response1, response2); + } + +} diff --git a/mcp-json-jackson3/src/test/java/io/modelcontextprotocol/json/McpJsonMapperTest.java b/mcp-json-jackson3/src/test/java/io/modelcontextprotocol/json/McpJsonMapperTest.java new file mode 100644 index 000000000..0307fceb5 --- /dev/null +++ b/mcp-json-jackson3/src/test/java/io/modelcontextprotocol/json/McpJsonMapperTest.java @@ -0,0 +1,20 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.json; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.json.jackson3.JacksonMcpJsonMapper; + +class McpJsonMapperTest { + + @Test + void shouldUseJackson2Mapper() { + assertThat(McpJsonDefaults.getMapper()).isInstanceOf(JacksonMcpJsonMapper.class); + } + +} diff --git a/mcp-json-jackson3/src/test/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorTest.java b/mcp-json-jackson3/src/test/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorTest.java new file mode 100644 index 000000000..05dba4f42 --- /dev/null +++ b/mcp-json-jackson3/src/test/java/io/modelcontextprotocol/json/schema/JsonSchemaValidatorTest.java @@ -0,0 +1,21 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.json.schema; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.schema.jackson3.DefaultJsonSchemaValidator; + +class JsonSchemaValidatorTest { + + @Test + void shouldUseJackson2Mapper() { + assertThat(McpJsonDefaults.getSchemaValidator()).isInstanceOf(DefaultJsonSchemaValidator.class); + } + +} diff --git a/mcp-spring/mcp-spring-webflux/README.md b/mcp-spring/mcp-spring-webflux/README.md deleted file mode 100644 index e701e41e6..000000000 --- a/mcp-spring/mcp-spring-webflux/README.md +++ /dev/null @@ -1,55 +0,0 @@ -# WebFlux SSE Transport - -```xml - - io.modelcontextprotocol.sdk - mcp-spring-webflux - -``` - -```java -String MESSAGE_ENDPOINT = "/mcp/message"; - -@Configuration -static class MyConfig { - - // SSE transport - @Bean - public WebFluxSseServerTransport sseServerTransport() { - return new WebFluxSseServerTransport(new ObjectMapper(), "/mcp/message"); - } - - // Router function for SSE transport used by Spring WebFlux to start an HTTP - // server. - @Bean - public RouterFunction mcpRouterFunction(WebFluxSseServerTransport transport) { - return transport.getRouterFunction(); - } - - @Bean - public McpAsyncServer mcpServer(ServerMcpTransport transport, OpenLibrary openLibrary) { - - // Configure server capabilities with resource support - var capabilities = McpSchema.ServerCapabilities.builder() - .resources(false, true) // No subscribe support, but list changes notifications - .tools(true) // Tool support with list changes notifications - .prompts(true) // Prompt support with list changes notifications - .logging() // Logging support - .build(); - - // Create the server with both tool and resource capabilities - var server = McpServer.using(transport) - .serverInfo("MCP Demo Server", "1.0.0") - .capabilities(capabilities) - .resources(systemInfoResourceRegistration()) - .prompts(greetingPromptRegistration()) - .tools(openLibraryToolRegistrations(openLibrary)) - .async(); - - return server; - } - - // ... - -} -``` diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java deleted file mode 100644 index 853aed2bf..000000000 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ /dev/null @@ -1,566 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - */ - -package io.modelcontextprotocol.client.transport; - -import java.io.IOException; -import java.util.List; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; -import java.util.function.Function; - -import org.reactivestreams.Publisher; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.http.codec.ServerSentEvent; -import org.springframework.web.reactive.function.client.ClientResponse; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.reactive.function.client.WebClientResponseException; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.spec.DefaultMcpTransportSession; -import io.modelcontextprotocol.spec.DefaultMcpTransportStream; -import io.modelcontextprotocol.spec.HttpHeaders; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpTransportException; -import io.modelcontextprotocol.spec.McpTransportSession; -import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; -import io.modelcontextprotocol.spec.McpTransportStream; -import io.modelcontextprotocol.spec.ProtocolVersions; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.Utils; -import reactor.core.Disposable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.util.function.Tuple2; -import reactor.util.function.Tuples; - -/** - * An implementation of the Streamable HTTP protocol as defined by the - * 2025-03-26 version of the MCP specification. - * - *

- * The transport is capable of resumability and reconnects. It reacts to transport-level - * session invalidation and will propagate {@link McpTransportSessionNotFoundException - * appropriate exceptions} to the higher level abstraction layer when needed in order to - * allow proper state management. The implementation handles servers that are stateful and - * provide session meta information, but can also communicate with stateless servers that - * do not provide a session identifier and do not support SSE streams. - *

- *

- * This implementation does not handle backwards compatibility with the "HTTP - * with SSE" transport. In order to communicate over the phased-out - * 2024-11-05 protocol, use {@link HttpClientSseClientTransport} or - * {@link WebFluxSseClientTransport}. - *

- * - * @author Dariusz JΔ™drzejczyk - * @see Streamable - * HTTP transport specification - */ -public class WebClientStreamableHttpTransport implements McpClientTransport { - - private static final String MISSING_SESSION_ID = "[missing_session_id]"; - - private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class); - - private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_03_26; - - private static final String DEFAULT_ENDPOINT = "/mcp"; - - /** - * Event type for JSON-RPC messages received through the SSE connection. The server - * sends messages with this event type to transmit JSON-RPC protocol data. - */ - private static final String MESSAGE_EVENT_TYPE = "message"; - - private static final ParameterizedTypeReference> PARAMETERIZED_TYPE_REF = new ParameterizedTypeReference<>() { - }; - - private final ObjectMapper objectMapper; - - private final WebClient webClient; - - private final String endpoint; - - private final boolean openConnectionOnStartup; - - private final boolean resumableStreams; - - private final AtomicReference activeSession = new AtomicReference<>(); - - private final AtomicReference, Mono>> handler = new AtomicReference<>(); - - private final AtomicReference> exceptionHandler = new AtomicReference<>(); - - private WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Builder webClientBuilder, - String endpoint, boolean resumableStreams, boolean openConnectionOnStartup) { - this.objectMapper = objectMapper; - this.webClient = webClientBuilder.build(); - this.endpoint = endpoint; - this.resumableStreams = resumableStreams; - this.openConnectionOnStartup = openConnectionOnStartup; - this.activeSession.set(createTransportSession()); - } - - @Override - public List protocolVersions() { - return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); - } - - /** - * Create a stateful builder for creating {@link WebClientStreamableHttpTransport} - * instances. - * @param webClientBuilder the {@link WebClient.Builder} to use - * @return a builder which will create an instance of - * {@link WebClientStreamableHttpTransport} once {@link Builder#build()} is called - */ - public static Builder builder(WebClient.Builder webClientBuilder) { - return new Builder(webClientBuilder); - } - - @Override - public Mono connect(Function, Mono> handler) { - return Mono.deferContextual(ctx -> { - this.handler.set(handler); - if (openConnectionOnStartup) { - logger.debug("Eagerly opening connection on startup"); - return this.reconnect(null).then(); - } - return Mono.empty(); - }); - } - - private DefaultMcpTransportSession createTransportSession() { - Function> onClose = sessionId -> sessionId == null ? Mono.empty() - : webClient.delete() - .uri(this.endpoint) - .header(HttpHeaders.MCP_SESSION_ID, sessionId) - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) - .retrieve() - .toBodilessEntity() - .onErrorComplete(e -> { - logger.warn("Got error when closing transport", e); - return true; - }) - .then(); - return new DefaultMcpTransportSession(onClose); - } - - @Override - public void setExceptionHandler(Consumer handler) { - logger.debug("Exception handler registered"); - this.exceptionHandler.set(handler); - } - - private void handleException(Throwable t) { - logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); - if (t instanceof McpTransportSessionNotFoundException) { - McpTransportSession invalidSession = this.activeSession.getAndSet(createTransportSession()); - logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); - invalidSession.close(); - } - Consumer handler = this.exceptionHandler.get(); - if (handler != null) { - handler.accept(t); - } - } - - @Override - public Mono closeGracefully() { - return Mono.defer(() -> { - logger.debug("Graceful close triggered"); - DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(createTransportSession()); - if (currentSession != null) { - return currentSession.closeGracefully(); - } - return Mono.empty(); - }); - } - - private Mono reconnect(McpTransportStream stream) { - return Mono.deferContextual(ctx -> { - if (stream != null) { - logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); - } - else { - logger.debug("Reconnecting with no prior stream"); - } - // Here we attempt to initialize the client. In case the server supports SSE, - // we will establish a long-running - // session here and listen for messages. If it doesn't, that's ok, the server - // is a simple, stateless one. - final AtomicReference disposableRef = new AtomicReference<>(); - final McpTransportSession transportSession = this.activeSession.get(); - - Disposable connection = webClient.get() - .uri(this.endpoint) - .accept(MediaType.TEXT_EVENT_STREAM) - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) - .headers(httpHeaders -> { - transportSession.sessionId().ifPresent(id -> httpHeaders.add(HttpHeaders.MCP_SESSION_ID, id)); - if (stream != null) { - stream.lastId().ifPresent(id -> httpHeaders.add(HttpHeaders.LAST_EVENT_ID, id)); - } - }) - .exchangeToFlux(response -> { - if (isEventStream(response)) { - logger.debug("Established SSE stream via GET"); - return eventStream(stream, response); - } - else if (isNotAllowed(response)) { - logger.debug("The server does not support SSE streams, using request-response mode."); - return Flux.empty(); - } - else if (isNotFound(response)) { - if (transportSession.sessionId().isPresent()) { - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - return mcpSessionNotFoundError(sessionIdRepresentation); - } - else { - return this.extractError(response, MISSING_SESSION_ID); - } - } - else { - return response.createError().doOnError(e -> { - logger.info("Opening an SSE stream failed. This can be safely ignored.", e); - }).flux(); - } - }) - .flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) - .onErrorComplete(t -> { - this.handleException(t); - return true; - }) - .doFinally(s -> { - Disposable ref = disposableRef.getAndSet(null); - if (ref != null) { - transportSession.removeConnection(ref); - } - }) - .contextWrite(ctx) - .subscribe(); - - disposableRef.set(connection); - transportSession.addConnection(connection); - return Mono.just(connection); - }); - } - - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - return Mono.create(sink -> { - logger.debug("Sending message {}", message); - // Here we attempt to initialize the client. - // In case the server supports SSE, we will establish a long-running session - // here and - // listen for messages. - // If it doesn't, nothing actually happens here, that's just the way it is... - final AtomicReference disposableRef = new AtomicReference<>(); - final McpTransportSession transportSession = this.activeSession.get(); - - Disposable connection = webClient.post() - .uri(this.endpoint) - .accept(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM) - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) - .headers(httpHeaders -> { - transportSession.sessionId().ifPresent(id -> httpHeaders.add(HttpHeaders.MCP_SESSION_ID, id)); - }) - .bodyValue(message) - .exchangeToFlux(response -> { - if (transportSession - .markInitialized(response.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID))) { - // Once we have a session, we try to open an async stream for - // the server to send notifications and requests out-of-band. - reconnect(null).contextWrite(sink.contextView()).subscribe(); - } - - String sessionRepresentation = sessionIdOrPlaceholder(transportSession); - - // The spec mentions only ACCEPTED, but the existing SDKs can return - // 200 OK for notifications - if (response.statusCode().is2xxSuccessful()) { - Optional contentType = response.headers().contentType(); - // Existing SDKs consume notifications with no response body nor - // content type - if (contentType.isEmpty()) { - logger.trace("Message was successfully sent via POST for session {}", - sessionRepresentation); - // signal the caller that the message was successfully - // delivered - sink.success(); - // communicate to downstream there is no streamed data coming - return Flux.empty(); - } - else { - MediaType mediaType = contentType.get(); - if (mediaType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { - logger.debug("Established SSE stream via POST"); - // communicate to caller that the message was delivered - sink.success(); - // starting a stream - return newEventStream(response, sessionRepresentation); - } - else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { - logger.trace("Received response to POST for session {}", sessionRepresentation); - // communicate to caller the message was delivered - sink.success(); - return directResponseFlux(message, response); - } - else { - logger.warn("Unknown media type {} returned for POST in session {}", contentType, - sessionRepresentation); - return Flux.error(new RuntimeException("Unknown media type returned: " + contentType)); - } - } - } - else { - if (isNotFound(response) && !sessionRepresentation.equals(MISSING_SESSION_ID)) { - return mcpSessionNotFoundError(sessionRepresentation); - } - return this.extractError(response, sessionRepresentation); - } - }) - .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) - .onErrorComplete(t -> { - // handle the error first - this.handleException(t); - // inform the caller of sendMessage - sink.error(t); - return true; - }) - .doFinally(s -> { - Disposable ref = disposableRef.getAndSet(null); - if (ref != null) { - transportSession.removeConnection(ref); - } - }) - .contextWrite(sink.contextView()) - .subscribe(); - disposableRef.set(connection); - transportSession.addConnection(connection); - }); - } - - private static Flux mcpSessionNotFoundError(String sessionRepresentation) { - logger.warn("Session {} was not found on the MCP server", sessionRepresentation); - // inform the stream/connection subscriber - return Flux.error(new McpTransportSessionNotFoundException(sessionRepresentation)); - } - - private Flux extractError(ClientResponse response, String sessionRepresentation) { - return response.createError().onErrorResume(e -> { - WebClientResponseException responseException = (WebClientResponseException) e; - byte[] body = responseException.getResponseBodyAsByteArray(); - McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError = null; - Exception toPropagate; - try { - McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body, - McpSchema.JSONRPCResponse.class); - jsonRpcError = jsonRpcResponse.error(); - toPropagate = jsonRpcError != null ? new McpError(jsonRpcError) - : new McpTransportException("Can't parse the jsonResponse " + jsonRpcResponse); - } - catch (IOException ex) { - toPropagate = new McpTransportException("Sending request failed, " + e.getMessage(), e); - logger.debug("Received content together with {} HTTP code response: {}", response.statusCode(), body); - } - - // Some implementations can return 400 when presented with a - // session id that it doesn't know about, so we will - // invalidate the session - // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 - if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { - if (!sessionRepresentation.equals(MISSING_SESSION_ID)) { - return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate)); - } - return Mono.error(new McpTransportException("Received 400 BAD REQUEST for session " - + sessionRepresentation + ". " + toPropagate.getMessage(), toPropagate)); - } - return Mono.error(toPropagate); - }).flux(); - } - - private Flux eventStream(McpTransportStream stream, ClientResponse response) { - McpTransportStream sessionStream = stream != null ? stream - : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); - logger.debug("Connected stream {}", sessionStream.streamId()); - - var idWithMessages = response.bodyToFlux(PARAMETERIZED_TYPE_REF).map(this::parse); - return Flux.from(sessionStream.consumeSseStream(idWithMessages)); - } - - private static boolean isNotFound(ClientResponse response) { - return response.statusCode().isSameCodeAs(HttpStatus.NOT_FOUND); - } - - private static boolean isNotAllowed(ClientResponse response) { - return response.statusCode().isSameCodeAs(HttpStatus.METHOD_NOT_ALLOWED); - } - - private static boolean isEventStream(ClientResponse response) { - return response.statusCode().is2xxSuccessful() && response.headers().contentType().isPresent() - && response.headers().contentType().get().isCompatibleWith(MediaType.TEXT_EVENT_STREAM); - } - - private static String sessionIdOrPlaceholder(McpTransportSession transportSession) { - return transportSession.sessionId().orElse(MISSING_SESSION_ID); - } - - private Flux directResponseFlux(McpSchema.JSONRPCMessage sentMessage, - ClientResponse response) { - return response.bodyToMono(String.class).>handle((responseMessage, s) -> { - try { - if (sentMessage instanceof McpSchema.JSONRPCNotification && Utils.hasText(responseMessage)) { - logger.warn("Notification: {} received non-compliant response: {}", sentMessage, responseMessage); - s.complete(); - } - else { - McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, - responseMessage); - s.next(List.of(jsonRpcResponse)); - } - } - catch (IOException e) { - s.error(new McpTransportException(e)); - } - }).flatMapIterable(Function.identity()); - } - - private Flux newEventStream(ClientResponse response, String sessionRepresentation) { - McpTransportStream sessionStream = new DefaultMcpTransportStream<>(this.resumableStreams, - this::reconnect); - logger.trace("Sent POST and opened a stream ({}) for session {}", sessionStream.streamId(), - sessionRepresentation); - return eventStream(sessionStream, response); - } - - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - - private Tuple2, Iterable> parse(ServerSentEvent event) { - if (MESSAGE_EVENT_TYPE.equals(event.event())) { - try { - // We don't support batching ATM and probably won't since the next version - // considers removing it. - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); - return Tuples.of(Optional.ofNullable(event.id()), List.of(message)); - } - catch (IOException ioException) { - throw new McpTransportException("Error parsing JSON-RPC message: " + event.data(), ioException); - } - } - else { - logger.debug("Received SSE event with type: {}", event); - return Tuples.of(Optional.empty(), List.of()); - } - } - - /** - * Builder for {@link WebClientStreamableHttpTransport}. - */ - public static class Builder { - - private ObjectMapper objectMapper; - - private WebClient.Builder webClientBuilder; - - private String endpoint = DEFAULT_ENDPOINT; - - private boolean resumableStreams = true; - - private boolean openConnectionOnStartup = false; - - private Builder(WebClient.Builder webClientBuilder) { - Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); - this.webClientBuilder = webClientBuilder; - } - - /** - * Configure the {@link ObjectMapper} to use. - * @param objectMapper instance to use - * @return the builder instance - */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - /** - * Configure the {@link WebClient.Builder} to construct the {@link WebClient}. - * @param webClientBuilder instance to use - * @return the builder instance - */ - public Builder webClientBuilder(WebClient.Builder webClientBuilder) { - Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); - this.webClientBuilder = webClientBuilder; - return this; - } - - /** - * Configure the endpoint to make HTTP requests against. - * @param endpoint endpoint to use - * @return the builder instance - */ - public Builder endpoint(String endpoint) { - Assert.hasText(endpoint, "endpoint must be a non-empty String"); - this.endpoint = endpoint; - return this; - } - - /** - * Configure whether to use the stream resumability feature by keeping track of - * SSE event ids. - * @param resumableStreams if {@code true} event ids will be tracked and upon - * disconnection, the last seen id will be used upon reconnection as a header to - * resume consuming messages. - * @return the builder instance - */ - public Builder resumableStreams(boolean resumableStreams) { - this.resumableStreams = resumableStreams; - return this; - } - - /** - * Configure whether the client should open an SSE connection upon startup. Not - * all servers support this (although it is in theory possible with the current - * specification), so use with caution. By default, this value is {@code false}. - * @param openConnectionOnStartup if {@code true} the {@link #connect(Function)} - * method call will try to open an SSE connection before sending any JSON-RPC - * request - * @return the builder instance - */ - public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { - this.openConnectionOnStartup = openConnectionOnStartup; - return this; - } - - /** - * Construct a fresh instance of {@link WebClientStreamableHttpTransport} using - * the current builder configuration. - * @return a new instance of {@link WebClientStreamableHttpTransport} - */ - public WebClientStreamableHttpTransport build() { - ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - - return new WebClientStreamableHttpTransport(objectMapper, this.webClientBuilder, endpoint, resumableStreams, - openConnectionOnStartup); - } - - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java deleted file mode 100644 index 51d21d18b..000000000 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ /dev/null @@ -1,423 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ - -package io.modelcontextprotocol.client.transport; - -import java.io.IOException; -import java.util.List; -import java.util.function.BiConsumer; -import java.util.function.Function; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.spec.HttpHeaders; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; -import io.modelcontextprotocol.spec.ProtocolVersions; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.Disposable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.core.publisher.SynchronousSink; -import reactor.core.scheduler.Schedulers; -import reactor.util.retry.Retry; -import reactor.util.retry.Retry.RetrySignal; - -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.http.MediaType; -import org.springframework.http.codec.ServerSentEvent; -import org.springframework.web.reactive.function.client.WebClient; - -/** - * Server-Sent Events (SSE) implementation of the - * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE - * transport specification. - * - *

- * This transport establishes a bidirectional communication channel where: - *

    - *
  • Inbound messages are received through an SSE connection from the server
  • - *
  • Outbound messages are sent via HTTP POST requests to a server-provided - * endpoint
  • - *
- * - *

- * The message flow follows these steps: - *

    - *
  1. The client establishes an SSE connection to the server's /sse endpoint
  2. - *
  3. The server sends an 'endpoint' event containing the URI for sending messages
  4. - *
- * - * This implementation uses {@link WebClient} for HTTP communications and supports JSON - * serialization/deserialization of messages. - * - * @author Christian Tzolov - * @see MCP - * HTTP with SSE Transport Specification - */ -public class WebFluxSseClientTransport implements McpClientTransport { - - private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class); - - private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2024_11_05; - - /** - * Event type for JSON-RPC messages received through the SSE connection. The server - * sends messages with this event type to transmit JSON-RPC protocol data. - */ - private static final String MESSAGE_EVENT_TYPE = "message"; - - /** - * Event type for receiving the message endpoint URI from the server. The server MUST - * send this event when a client connects, providing the URI where the client should - * send its messages via HTTP POST. - */ - private static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - /** - * Default SSE endpoint path as specified by the MCP transport specification. This - * endpoint is used to establish the SSE connection with the server. - */ - private static final String DEFAULT_SSE_ENDPOINT = "/sse"; - - /** - * Type reference for parsing SSE events containing string data. - */ - private static final ParameterizedTypeReference> SSE_TYPE = new ParameterizedTypeReference<>() { - }; - - /** - * WebClient instance for handling both SSE connections and HTTP POST requests. Used - * for establishing the SSE connection and sending outbound messages. - */ - private final WebClient webClient; - - /** - * ObjectMapper for serializing outbound messages and deserializing inbound messages. - * Handles conversion between JSON-RPC messages and their string representation. - */ - protected ObjectMapper objectMapper; - - /** - * Subscription for the SSE connection handling inbound messages. Used for cleanup - * during transport shutdown. - */ - private Disposable inboundSubscription; - - /** - * Flag indicating if the transport is in the process of shutting down. Used to - * prevent new operations during shutdown and handle cleanup gracefully. - */ - private volatile boolean isClosing = false; - - /** - * Sink for managing the message endpoint URI provided by the server. Stores the most - * recent endpoint URI and makes it available for outbound message processing. - */ - protected final Sinks.One messageEndpointSink = Sinks.one(); - - /** - * The SSE endpoint URI provided by the server. Used for sending outbound messages via - * HTTP POST requests. - */ - private String sseEndpoint; - - /** - * Constructs a new SseClientTransport with the specified WebClient builder. Uses a - * default ObjectMapper instance for JSON processing. - * @param webClientBuilder the WebClient.Builder to use for creating the WebClient - * instance - * @throws IllegalArgumentException if webClientBuilder is null - */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder) { - this(webClientBuilder, new ObjectMapper()); - } - - /** - * Constructs a new SseClientTransport with the specified WebClient builder and - * ObjectMapper. Initializes both inbound and outbound message processing pipelines. - * @param webClientBuilder the WebClient.Builder to use for creating the WebClient - * instance - * @param objectMapper the ObjectMapper to use for JSON processing - * @throws IllegalArgumentException if either parameter is null - */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { - this(webClientBuilder, objectMapper, DEFAULT_SSE_ENDPOINT); - } - - /** - * Constructs a new SseClientTransport with the specified WebClient builder and - * ObjectMapper. Initializes both inbound and outbound message processing pipelines. - * @param webClientBuilder the WebClient.Builder to use for creating the WebClient - * instance - * @param objectMapper the ObjectMapper to use for JSON processing - * @param sseEndpoint the SSE endpoint URI to use for establishing the connection - * @throws IllegalArgumentException if either parameter is null - */ - public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper, - String sseEndpoint) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); - Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty"); - - this.objectMapper = objectMapper; - this.webClient = webClientBuilder.build(); - this.sseEndpoint = sseEndpoint; - } - - @Override - public List protocolVersions() { - return List.of(MCP_PROTOCOL_VERSION); - } - - /** - * Establishes a connection to the MCP server using Server-Sent Events (SSE). This - * method initiates the SSE connection and sets up the message processing pipeline. - * - *

- * The connection process follows these steps: - *

    - *
  1. Establishes an SSE connection to the server's /sse endpoint
  2. - *
  3. Waits for the server to send an 'endpoint' event with the message posting - * URI
  4. - *
  5. Sets up message handling for incoming JSON-RPC messages
  6. - *
- * - *

- * The connection is considered established only after receiving the endpoint event - * from the server. - * @param handler a function that processes incoming JSON-RPC messages and returns - * responses - * @return a Mono that completes when the connection is fully established - */ - @Override - public Mono connect(Function, Mono> handler) { - // TODO: Avoid eager connection opening and enable resilience - // -> upon disconnects, re-establish connection - // -> allow optimizing for eager connection start using a constructor flag - Flux> events = eventStream(); - this.inboundSubscription = events.concatMap(event -> Mono.just(event).handle((e, s) -> { - if (ENDPOINT_EVENT_TYPE.equals(event.event())) { - String messageEndpointUri = event.data(); - if (messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { - s.complete(); - } - else { - // TODO: clarify with the spec if multiple events can be - // received - s.error(new RuntimeException("Failed to handle SSE endpoint event")); - } - } - else if (MESSAGE_EVENT_TYPE.equals(event.event())) { - try { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, event.data()); - s.next(message); - } - catch (IOException ioException) { - s.error(ioException); - } - } - else { - logger.debug("Received unrecognized SSE event type: {}", event); - s.complete(); - } - }).transform(handler)).subscribe(); - - // The connection is established once the server sends the endpoint event - return messageEndpointSink.asMono().then(); - } - - /** - * Sends a JSON-RPC message to the server using the endpoint provided during - * connection. - * - *

- * Messages are sent via HTTP POST requests to the server-provided endpoint URI. The - * message is serialized to JSON before transmission. If the transport is in the - * process of closing, the message send operation is skipped gracefully. - * @param message the JSON-RPC message to send - * @return a Mono that completes when the message has been sent successfully - * @throws RuntimeException if message serialization fails - */ - @Override - public Mono sendMessage(JSONRPCMessage message) { - // The messageEndpoint is the endpoint URI to send the messages - // It is provided by the server as part of the endpoint event - return messageEndpointSink.asMono().flatMap(messageEndpointUri -> { - if (isClosing) { - return Mono.empty(); - } - try { - String jsonText = this.objectMapper.writeValueAsString(message); - return webClient.post() - .uri(messageEndpointUri) - .contentType(MediaType.APPLICATION_JSON) - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) - .bodyValue(jsonText) - .retrieve() - .toBodilessEntity() - .doOnSuccess(response -> { - logger.debug("Message sent successfully"); - }) - .doOnError(error -> { - if (!isClosing) { - logger.error("Error sending message: {}", error.getMessage()); - } - }); - } - catch (IOException e) { - if (!isClosing) { - return Mono.error(new RuntimeException("Failed to serialize message", e)); - } - return Mono.empty(); - } - }).then(); // TODO: Consider non-200-ok response - } - - /** - * Initializes and starts the inbound SSE event processing. Establishes the SSE - * connection and sets up event handling for both message and endpoint events. - * Includes automatic retry logic for handling transient connection failures. - */ - // visible for tests - protected Flux> eventStream() {// @formatter:off - return this.webClient - .get() - .uri(this.sseEndpoint) - .accept(MediaType.TEXT_EVENT_STREAM) - .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) - .retrieve() - .bodyToFlux(SSE_TYPE) - .retryWhen(Retry.from(retrySignal -> retrySignal.handle(inboundRetryHandler))); - } // @formatter:on - - /** - * Retry handler for the inbound SSE stream. Implements the retry logic for handling - * connection failures and other errors. - */ - private BiConsumer> inboundRetryHandler = (retrySpec, sink) -> { - if (isClosing) { - logger.debug("SSE connection closed during shutdown"); - sink.error(retrySpec.failure()); - return; - } - if (retrySpec.failure() instanceof IOException) { - logger.debug("Retrying SSE connection after IO error"); - sink.next(retrySpec); - return; - } - logger.error("Fatal SSE error, not retrying: {}", retrySpec.failure().getMessage()); - sink.error(retrySpec.failure()); - }; - - /** - * Implements graceful shutdown of the transport. Cleans up all resources including - * subscriptions and schedulers. Ensures orderly shutdown of both inbound and outbound - * message processing. - * @return a Mono that completes when shutdown is finished - */ - @Override - public Mono closeGracefully() { // @formatter:off - return Mono.fromRunnable(() -> { - isClosing = true; - - // Dispose of subscriptions - - if (inboundSubscription != null) { - inboundSubscription.dispose(); - } - - }) - .then() - .subscribeOn(Schedulers.boundedElastic()); - } // @formatter:on - - /** - * Unmarshalls data from a generic Object into the specified type using the configured - * ObjectMapper. - * - *

- * This method is particularly useful when working with JSON-RPC parameters or result - * objects that need to be converted to specific Java types. It leverages Jackson's - * type conversion capabilities to handle complex object structures. - * @param the target type to convert the data into - * @param data the source object to convert - * @param typeRef the TypeReference describing the target type - * @return the unmarshalled object of type T - * @throws IllegalArgumentException if the conversion cannot be performed - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - - /** - * Creates a new builder for {@link WebFluxSseClientTransport}. - * @param webClientBuilder the WebClient.Builder to use for creating the WebClient - * instance - * @return a new builder instance - */ - public static Builder builder(WebClient.Builder webClientBuilder) { - return new Builder(webClientBuilder); - } - - /** - * Builder for {@link WebFluxSseClientTransport}. - */ - public static class Builder { - - private final WebClient.Builder webClientBuilder; - - private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - - private ObjectMapper objectMapper = new ObjectMapper(); - - /** - * Creates a new builder with the specified WebClient.Builder. - * @param webClientBuilder the WebClient.Builder to use - */ - public Builder(WebClient.Builder webClientBuilder) { - Assert.notNull(webClientBuilder, "WebClient.Builder must not be null"); - this.webClientBuilder = webClientBuilder; - } - - /** - * Sets the SSE endpoint path. - * @param sseEndpoint the SSE endpoint path - * @return this builder - */ - public Builder sseEndpoint(String sseEndpoint) { - Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); - this.sseEndpoint = sseEndpoint; - return this; - } - - /** - * Sets the object mapper for JSON serialization/deserialization. - * @param objectMapper the object mapper - * @return this builder - */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "objectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - /** - * Builds a new {@link WebFluxSseClientTransport} instance. - * @return a new transport instance - */ - public WebFluxSseClientTransport build() { - return new WebFluxSseClientTransport(webClientBuilder, objectMapper, sseEndpoint); - } - - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java deleted file mode 100644 index ead7380f0..000000000 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ /dev/null @@ -1,589 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.io.IOException; -import java.time.Duration; -import java.util.List; -import java.util.concurrent.ConcurrentHashMap; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.server.DefaultMcpTransportContext; -import io.modelcontextprotocol.server.McpTransportContext; -import io.modelcontextprotocol.server.McpTransportContextExtractor; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransport; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.ProtocolVersions; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.KeepAliveScheduler; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.Exceptions; -import reactor.core.publisher.Flux; -import reactor.core.publisher.FluxSink; -import reactor.core.publisher.Mono; - -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.http.codec.ServerSentEvent; -import org.springframework.web.reactive.function.server.RouterFunction; -import org.springframework.web.reactive.function.server.RouterFunctions; -import org.springframework.web.reactive.function.server.ServerRequest; -import org.springframework.web.reactive.function.server.ServerResponse; - -/** - * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using - * Server-Sent Events (SSE). This implementation provides a bidirectional communication - * channel between MCP clients and servers using HTTP POST for client-to-server messages - * and SSE for server-to-client messages. - * - *

- * Key features: - *

    - *
  • Implements the {@link McpServerTransportProvider} interface that allows managing - * {@link McpServerSession} instances and enabling their communication with the - * {@link McpServerTransport} abstraction.
  • - *
  • Uses WebFlux for non-blocking request handling and SSE support
  • - *
  • Maintains client sessions for reliable message delivery
  • - *
  • Supports graceful shutdown with session cleanup
  • - *
  • Thread-safe message broadcasting to multiple clients
  • - *
- * - *

- * The transport sets up two main endpoints: - *

    - *
  • SSE endpoint (/sse) - For establishing SSE connections with clients
  • - *
  • Message endpoint (configurable) - For receiving JSON-RPC messages from clients
  • - *
- * - *

- * This implementation is thread-safe and can handle multiple concurrent client - * connections. It uses {@link ConcurrentHashMap} for session management and Project - * Reactor's non-blocking APIs for message processing and delivery. - * - * @author Christian Tzolov - * @author Alexandros Pappas - * @author Dariusz JΔ™drzejczyk - * @see McpServerTransport - * @see ServerSentEvent - */ -public class WebFluxSseServerTransportProvider implements McpServerTransportProvider { - - private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransportProvider.class); - - /** - * Event type for JSON-RPC messages sent through the SSE connection. - */ - public static final String MESSAGE_EVENT_TYPE = "message"; - - /** - * Event type for sending the message endpoint URI to clients. - */ - public static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - private static final String MCP_PROTOCOL_VERSION = "2025-06-18"; - - /** - * Default SSE endpoint path as specified by the MCP transport specification. - */ - public static final String DEFAULT_SSE_ENDPOINT = "/sse"; - - public static final String DEFAULT_BASE_URL = ""; - - private final ObjectMapper objectMapper; - - /** - * Base URL for the message endpoint. This is used to construct the full URL for - * clients to send their JSON-RPC messages. - */ - private final String baseUrl; - - private final String messageEndpoint; - - private final String sseEndpoint; - - private final RouterFunction routerFunction; - - private McpServerSession.Factory sessionFactory; - - /** - * Map of active client sessions, keyed by session ID. - */ - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); - - private McpTransportContextExtractor contextExtractor; - - /** - * Flag indicating if the transport is shutting down. - */ - private volatile boolean isClosing = false; - - /** - * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is - * set. Disabled by default. - */ - private KeepAliveScheduler keepAliveScheduler; - - /** - * Constructs a new WebFlux SSE server transport provider instance with the default - * SSE endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - - /** - * Constructs a new WebFlux SSE server transport provider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); - } - - /** - * Constructs a new WebFlux SSE server transport provider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param baseUrl webflux message base path - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @throws IllegalArgumentException if either parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint) { - this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); - } - - /** - * Constructs a new WebFlux SSE server transport provider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param baseUrl webflux message base path - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @param sseEndpoint The SSE endpoint path. Must not be null. - * @param keepAliveInterval The interval for sending keep-alive pings to clients. - * @throws IllegalArgumentException if either parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint, Duration keepAliveInterval) { - this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, - (serverRequest, context) -> context); - } - - /** - * Constructs a new WebFlux SSE server transport provider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of MCP messages. Must not be null. - * @param baseUrl webflux message base path - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages. This endpoint will be communicated to clients during SSE connection - * setup. Must not be null. - * @param sseEndpoint The SSE endpoint path. Must not be null. - * @param keepAliveInterval The interval for sending keep-alive pings to clients. - * @param contextExtractor The context extractor to use for extracting MCP transport - * context from HTTP requests. Must not be null. - * @throws IllegalArgumentException if either parameter is null - */ - private WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint, Duration keepAliveInterval, - McpTransportContextExtractor contextExtractor) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.notNull(baseUrl, "Message base path must not be null"); - Assert.notNull(messageEndpoint, "Message endpoint must not be null"); - Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); - Assert.notNull(contextExtractor, "Context extractor must not be null"); - - this.objectMapper = objectMapper; - this.baseUrl = baseUrl; - this.messageEndpoint = messageEndpoint; - this.sseEndpoint = sseEndpoint; - this.contextExtractor = contextExtractor; - this.routerFunction = RouterFunctions.route() - .GET(this.sseEndpoint, this::handleSseConnection) - .POST(this.messageEndpoint, this::handleMessage) - .build(); - - if (keepAliveInterval != null) { - - this.keepAliveScheduler = KeepAliveScheduler - .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) - .initialDelay(keepAliveInterval) - .interval(keepAliveInterval) - .build(); - - this.keepAliveScheduler.start(); - } - } - - @Override - public List protocolVersions() { - return List.of(ProtocolVersions.MCP_2024_11_05); - } - - @Override - public void setSessionFactory(McpServerSession.Factory sessionFactory) { - this.sessionFactory = sessionFactory; - } - - /** - * Broadcasts a JSON-RPC message to all connected clients through their SSE - * connections. The message is serialized to JSON and sent as a server-sent event to - * each active session. - * - *

- * The method: - *

    - *
  • Serializes the message to JSON
  • - *
  • Creates a server-sent event with the message data
  • - *
  • Attempts to send the event to all active sessions
  • - *
  • Tracks and reports any delivery failures
  • - *
- * @param method The JSON-RPC method to send to clients - * @param params The method parameters to send to clients - * @return A Mono that completes when the message has been sent to all sessions, or - * errors if any session fails to receive the message - */ - @Override - public Mono notifyClients(String method, Object params) { - if (sessions.isEmpty()) { - logger.debug("No active sessions to broadcast message to"); - return Mono.empty(); - } - - logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - - return Flux.fromIterable(sessions.values()) - .flatMap(session -> session.sendNotification(method, params) - .doOnError( - e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) - .onErrorComplete()) - .then(); - } - - // FIXME: This javadoc makes claims about using isClosing flag but it's not - // actually - // doing that. - /** - * Initiates a graceful shutdown of all the sessions. This method ensures all active - * sessions are properly closed and cleaned up. - * @return A Mono that completes when all sessions have been closed - */ - @Override - public Mono closeGracefully() { - return Flux.fromIterable(sessions.values()) - .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) - .flatMap(McpServerSession::closeGracefully) - .then() - .doOnSuccess(v -> { - logger.debug("Graceful shutdown completed"); - sessions.clear(); - if (this.keepAliveScheduler != null) { - this.keepAliveScheduler.shutdown(); - } - }); - } - - /** - * Returns the WebFlux router function that defines the transport's HTTP endpoints. - * This router function should be integrated into the application's web configuration. - * - *

- * The router function defines two endpoints: - *

    - *
  • GET {sseEndpoint} - For establishing SSE connections
  • - *
  • POST {messageEndpoint} - For receiving client messages
  • - *
- * @return The configured {@link RouterFunction} for handling HTTP requests - */ - public RouterFunction getRouterFunction() { - return this.routerFunction; - } - - /** - * Handles new SSE connection requests from clients. Creates a new session for each - * connection and sets up the SSE event stream. - * @param request The incoming server request - * @return A Mono which emits a response with the SSE event stream - */ - private Mono handleSseConnection(ServerRequest request) { - if (isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); - } - - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); - - return ServerResponse.ok() - .contentType(MediaType.TEXT_EVENT_STREAM) - .body(Flux.>create(sink -> { - WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink); - - McpServerSession session = sessionFactory.create(sessionTransport); - String sessionId = session.getId(); - - logger.debug("Created new SSE connection for session: {}", sessionId); - sessions.put(sessionId, session); - - // Send initial endpoint event - logger.debug("Sending initial endpoint event to session: {}", sessionId); - sink.next(ServerSentEvent.builder() - .event(ENDPOINT_EVENT_TYPE) - .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId) - .build()); - sink.onCancel(() -> { - logger.debug("Session {} cancelled", sessionId); - sessions.remove(sessionId); - }); - }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class); - } - - /** - * Handles incoming JSON-RPC messages from clients. Deserializes the message and - * processes it through the configured message handler. - * - *

- * The handler: - *

    - *
  • Deserializes the incoming JSON-RPC message
  • - *
  • Passes it through the message handler chain
  • - *
  • Returns appropriate HTTP responses based on processing results
  • - *
  • Handles various error conditions with appropriate error responses
  • - *
- * @param request The incoming server request containing the JSON-RPC message - * @return A Mono emitting the response indicating the message processing result - */ - private Mono handleMessage(ServerRequest request) { - if (isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); - } - - if (request.queryParam("sessionId").isEmpty()) { - return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint")); - } - - McpServerSession session = sessions.get(request.queryParam("sessionId").get()); - - if (session == null) { - return ServerResponse.status(HttpStatus.NOT_FOUND) - .bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get())); - } - - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); - - return request.bodyToMono(String.class).flatMap(body -> { - try { - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> { - logger.error("Error processing message: {}", error.getMessage()); - // TODO: instead of signalling the error, just respond with 200 OK - // - the error is signalled on the SSE connection - // return ServerResponse.ok().build(); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) - .bodyValue(new McpError(error.getMessage())); - }); - } - catch (IllegalArgumentException | IOException e) { - logger.error("Failed to deserialize message: {}", e.getMessage()); - return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); - } - }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); - } - - private class WebFluxMcpSessionTransport implements McpServerTransport { - - private final FluxSink> sink; - - public WebFluxMcpSessionTransport(FluxSink> sink) { - this.sink = sink; - } - - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - return Mono.fromSupplier(() -> { - try { - return objectMapper.writeValueAsString(message); - } - catch (IOException e) { - throw Exceptions.propagate(e); - } - }).doOnNext(jsonText -> { - ServerSentEvent event = ServerSentEvent.builder() - .event(MESSAGE_EVENT_TYPE) - .data(jsonText) - .build(); - sink.next(event); - }).doOnError(e -> { - // TODO log with sessionid - Throwable exception = Exceptions.unwrap(e); - sink.error(exception); - }).then(); - } - - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); - } - - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(sink::complete); - } - - @Override - public void close() { - sink.complete(); - } - - } - - public static Builder builder() { - return new Builder(); - } - - /** - * Builder for creating instances of {@link WebFluxSseServerTransportProvider}. - *

- * This builder provides a fluent API for configuring and creating instances of - * WebFluxSseServerTransportProvider with custom settings. - */ - public static class Builder { - - private ObjectMapper objectMapper; - - private String baseUrl = DEFAULT_BASE_URL; - - private String messageEndpoint; - - private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - - private Duration keepAliveInterval; - - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; - - /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP - * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. - * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null - */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - /** - * Sets the project basePath as endpoint prefix where clients should send their - * JSON-RPC messages - * @param baseUrl the message basePath . Must not be null. - * @return this builder instance - * @throws IllegalArgumentException if basePath is null - */ - public Builder basePath(String baseUrl) { - Assert.notNull(baseUrl, "basePath must not be null"); - this.baseUrl = baseUrl; - return this; - } - - /** - * Sets the endpoint URI where clients should send their JSON-RPC messages. - * @param messageEndpoint The message endpoint URI. Must not be null. - * @return this builder instance - * @throws IllegalArgumentException if messageEndpoint is null - */ - public Builder messageEndpoint(String messageEndpoint) { - Assert.notNull(messageEndpoint, "Message endpoint must not be null"); - this.messageEndpoint = messageEndpoint; - return this; - } - - /** - * Sets the SSE endpoint path. - * @param sseEndpoint The SSE endpoint path. Must not be null. - * @return this builder instance - * @throws IllegalArgumentException if sseEndpoint is null - */ - public Builder sseEndpoint(String sseEndpoint) { - Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); - this.sseEndpoint = sseEndpoint; - return this; - } - - /** - * Sets the interval for sending keep-alive pings to clients. - * @param keepAliveInterval The keep-alive interval duration. If null, keep-alive - * is disabled. - * @return this builder instance - */ - public Builder keepAliveInterval(Duration keepAliveInterval) { - this.keepAliveInterval = keepAliveInterval; - return this; - } - - /** - * Sets the context extractor that allows providing the MCP feature - * implementations to inspect HTTP transport level metadata that was present at - * HTTP request processing time. This allows to extract custom headers and other - * useful data for use during execution later on in the process. - * @param contextExtractor The contextExtractor to fill in a - * {@link McpTransportContext}. - * @return this builder instance - * @throws IllegalArgumentException if contextExtractor is null - */ - public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { - Assert.notNull(contextExtractor, "contextExtractor must not be null"); - this.contextExtractor = contextExtractor; - return this; - } - - /** - * Builds a new instance of {@link WebFluxSseServerTransportProvider} with the - * configured settings. - * @return A new WebFluxSseServerTransportProvider instance - * @throws IllegalStateException if required parameters are not set - */ - public WebFluxSseServerTransportProvider build() { - Assert.notNull(objectMapper, "ObjectMapper must be set"); - Assert.notNull(messageEndpoint, "Message endpoint must be set"); - - return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, - keepAliveInterval, contextExtractor); - } - - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java deleted file mode 100644 index 23fff25b3..000000000 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java +++ /dev/null @@ -1,216 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.McpStatelessServerHandler; -import io.modelcontextprotocol.server.DefaultMcpTransportContext; -import io.modelcontextprotocol.server.McpTransportContextExtractor; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpStatelessServerTransport; -import io.modelcontextprotocol.server.McpTransportContext; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.web.reactive.function.server.RouterFunction; -import org.springframework.web.reactive.function.server.RouterFunctions; -import org.springframework.web.reactive.function.server.ServerRequest; -import org.springframework.web.reactive.function.server.ServerResponse; -import reactor.core.publisher.Mono; - -import java.io.IOException; -import java.util.List; - -/** - * Implementation of a WebFlux based {@link McpStatelessServerTransport}. - * - * @author Dariusz JΔ™drzejczyk - */ -public class WebFluxStatelessServerTransport implements McpStatelessServerTransport { - - private static final Logger logger = LoggerFactory.getLogger(WebFluxStatelessServerTransport.class); - - private final ObjectMapper objectMapper; - - private final String mcpEndpoint; - - private final RouterFunction routerFunction; - - private McpStatelessServerHandler mcpHandler; - - private McpTransportContextExtractor contextExtractor; - - private volatile boolean isClosing = false; - - private WebFluxStatelessServerTransport(ObjectMapper objectMapper, String mcpEndpoint, - McpTransportContextExtractor contextExtractor) { - Assert.notNull(objectMapper, "objectMapper must not be null"); - Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); - Assert.notNull(contextExtractor, "contextExtractor must not be null"); - - this.objectMapper = objectMapper; - this.mcpEndpoint = mcpEndpoint; - this.contextExtractor = contextExtractor; - this.routerFunction = RouterFunctions.route() - .GET(this.mcpEndpoint, this::handleGet) - .POST(this.mcpEndpoint, this::handlePost) - .build(); - } - - @Override - public void setMcpHandler(McpStatelessServerHandler mcpHandler) { - this.mcpHandler = mcpHandler; - } - - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> this.isClosing = true); - } - - /** - * Returns the WebFlux router function that defines the transport's HTTP endpoints. - * This router function should be integrated into the application's web configuration. - * - *

- * The router function defines one endpoint handling two HTTP methods: - *

    - *
  • GET {messageEndpoint} - Unsupported, returns 405 METHOD NOT ALLOWED
  • - *
  • POST {messageEndpoint} - For handling client requests and notifications
  • - *
- * @return The configured {@link RouterFunction} for handling HTTP requests - */ - public RouterFunction getRouterFunction() { - return this.routerFunction; - } - - private Mono handleGet(ServerRequest request) { - return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); - } - - private Mono handlePost(ServerRequest request) { - if (isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); - } - - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); - - List acceptHeaders = request.headers().asHttpHeaders().getAccept(); - if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) - && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { - return ServerResponse.badRequest().build(); - } - - return request.bodyToMono(String.class).flatMap(body -> { - try { - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - - if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { - return this.mcpHandler.handleRequest(transportContext, jsonrpcRequest) - .flatMap(jsonrpcResponse -> ServerResponse.ok() - .contentType(MediaType.APPLICATION_JSON) - .bodyValue(jsonrpcResponse)); - } - else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { - return this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) - .then(ServerResponse.accepted().build()); - } - else { - return ServerResponse.badRequest() - .bodyValue(new McpError("The server accepts either requests or notifications")); - } - } - catch (IllegalArgumentException | IOException e) { - logger.error("Failed to deserialize message: {}", e.getMessage()); - return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); - } - }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); - } - - /** - * Create a builder for the server. - * @return a fresh {@link Builder} instance. - */ - public static Builder builder() { - return new Builder(); - } - - /** - * Builder for creating instances of {@link WebFluxStatelessServerTransport}. - *

- * This builder provides a fluent API for configuring and creating instances of - * WebFluxSseServerTransportProvider with custom settings. - */ - public static class Builder { - - private ObjectMapper objectMapper; - - private String mcpEndpoint = "/mcp"; - - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; - - private Builder() { - // used by a static method - } - - /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP - * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. - * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null - */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - /** - * Sets the endpoint URI where clients should send their JSON-RPC messages. - * @param messageEndpoint The message endpoint URI. Must not be null. - * @return this builder instance - * @throws IllegalArgumentException if messageEndpoint is null - */ - public Builder messageEndpoint(String messageEndpoint) { - Assert.notNull(messageEndpoint, "Message endpoint must not be null"); - this.mcpEndpoint = messageEndpoint; - return this; - } - - /** - * Sets the context extractor that allows providing the MCP feature - * implementations to inspect HTTP transport level metadata that was present at - * HTTP request processing time. This allows to extract custom headers and other - * useful data for use during execution later on in the process. - * @param contextExtractor The contextExtractor to fill in a - * {@link McpTransportContext}. - * @return this builder instance - * @throws IllegalArgumentException if contextExtractor is null - */ - public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { - Assert.notNull(contextExtractor, "Context extractor must not be null"); - this.contextExtractor = contextExtractor; - return this; - } - - /** - * Builds a new instance of {@link WebFluxStatelessServerTransport} with the - * configured settings. - * @return A new WebFluxSseServerTransportProvider instance - * @throws IllegalStateException if required parameters are not set - */ - public WebFluxStatelessServerTransport build() { - Assert.notNull(objectMapper, "ObjectMapper must be set"); - Assert.notNull(mcpEndpoint, "Message endpoint must be set"); - - return new WebFluxStatelessServerTransport(objectMapper, mcpEndpoint, contextExtractor); - } - - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java deleted file mode 100644 index 963a50249..000000000 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ /dev/null @@ -1,494 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.DefaultMcpTransportContext; -import io.modelcontextprotocol.server.McpTransportContextExtractor; -import io.modelcontextprotocol.spec.HttpHeaders; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpStreamableServerSession; -import io.modelcontextprotocol.spec.McpStreamableServerTransport; -import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; -import io.modelcontextprotocol.spec.ProtocolVersions; -import io.modelcontextprotocol.server.McpTransportContext; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.KeepAliveScheduler; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.http.codec.ServerSentEvent; -import org.springframework.web.reactive.function.server.RouterFunction; -import org.springframework.web.reactive.function.server.RouterFunctions; -import org.springframework.web.reactive.function.server.ServerRequest; -import org.springframework.web.reactive.function.server.ServerResponse; -import reactor.core.Disposable; -import reactor.core.Exceptions; -import reactor.core.publisher.Flux; -import reactor.core.publisher.FluxSink; -import reactor.core.publisher.Mono; - -import java.io.IOException; -import java.time.Duration; -import java.util.List; -import java.util.concurrent.ConcurrentHashMap; - -/** - * Implementation of a WebFlux based {@link McpStreamableServerTransportProvider}. - * - * @author Dariusz JΔ™drzejczyk - */ -public class WebFluxStreamableServerTransportProvider implements McpStreamableServerTransportProvider { - - private static final Logger logger = LoggerFactory.getLogger(WebFluxStreamableServerTransportProvider.class); - - public static final String MESSAGE_EVENT_TYPE = "message"; - - private final ObjectMapper objectMapper; - - private final String mcpEndpoint; - - private final boolean disallowDelete; - - private final RouterFunction routerFunction; - - private McpStreamableServerSession.Factory sessionFactory; - - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); - - private McpTransportContextExtractor contextExtractor; - - private volatile boolean isClosing = false; - - private KeepAliveScheduler keepAliveScheduler; - - private WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, - McpTransportContextExtractor contextExtractor, boolean disallowDelete, - Duration keepAliveInterval) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.notNull(mcpEndpoint, "Message endpoint must not be null"); - Assert.notNull(contextExtractor, "Context extractor must not be null"); - - this.objectMapper = objectMapper; - this.mcpEndpoint = mcpEndpoint; - this.contextExtractor = contextExtractor; - this.disallowDelete = disallowDelete; - this.routerFunction = RouterFunctions.route() - .GET(this.mcpEndpoint, this::handleGet) - .POST(this.mcpEndpoint, this::handlePost) - .DELETE(this.mcpEndpoint, this::handleDelete) - .build(); - - if (keepAliveInterval != null) { - this.keepAliveScheduler = KeepAliveScheduler - .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values())) - .initialDelay(keepAliveInterval) - .interval(keepAliveInterval) - .build(); - - this.keepAliveScheduler.start(); - } - } - - @Override - public List protocolVersions() { - return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); - } - - @Override - public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { - this.sessionFactory = sessionFactory; - } - - @Override - public Mono notifyClients(String method, Object params) { - if (sessions.isEmpty()) { - logger.debug("No active sessions to broadcast message to"); - return Mono.empty(); - } - - logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - - return Flux.fromIterable(sessions.values()) - .flatMap(session -> session.sendNotification(method, params) - .doOnError( - e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) - .onErrorComplete()) - .then(); - } - - @Override - public Mono closeGracefully() { - return Mono.defer(() -> { - this.isClosing = true; - return Flux.fromIterable(sessions.values()) - .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) - .flatMap(McpStreamableServerSession::closeGracefully) - .then(); - }).then().doOnSuccess(v -> { - sessions.clear(); - if (this.keepAliveScheduler != null) { - this.keepAliveScheduler.shutdown(); - } - }); - } - - /** - * Returns the WebFlux router function that defines the transport's HTTP endpoints. - * This router function should be integrated into the application's web configuration. - * - *

- * The router function defines one endpoint with three methods: - *

    - *
  • GET {messageEndpoint} - For the client listening SSE stream
  • - *
  • POST {messageEndpoint} - For receiving client messages
  • - *
  • DELETE {messageEndpoint} - For removing sessions
  • - *
- * @return The configured {@link RouterFunction} for handling HTTP requests - */ - public RouterFunction getRouterFunction() { - return this.routerFunction; - } - - /** - * Opens the listening SSE streams for clients. - * @param request The incoming server request - * @return A Mono which emits a response with the SSE event stream - */ - private Mono handleGet(ServerRequest request) { - if (isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); - } - - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); - - return Mono.defer(() -> { - List acceptHeaders = request.headers().asHttpHeaders().getAccept(); - if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) { - return ServerResponse.badRequest().build(); - } - - if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { - return ServerResponse.badRequest().build(); // TODO: say we need a session - // id - } - - String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); - - McpStreamableServerSession session = this.sessions.get(sessionId); - - if (session == null) { - return ServerResponse.notFound().build(); - } - - if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) { - String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); - return ServerResponse.ok() - .contentType(MediaType.TEXT_EVENT_STREAM) - .body(session.replay(lastId) - .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), - ServerSentEvent.class); - } - - return ServerResponse.ok() - .contentType(MediaType.TEXT_EVENT_STREAM) - .body(Flux.>create(sink -> { - WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport( - sink); - McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session - .listeningStream(sessionTransport); - sink.onDispose(listeningStream::close); - // TODO Clarify why the outer context is not present in the - // Flux.create sink? - }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class); - - }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); - } - - /** - * Handles incoming JSON-RPC messages from clients. - * @param request The incoming server request containing the JSON-RPC message - * @return A Mono with the response appropriate to a particular Streamable HTTP flow. - */ - private Mono handlePost(ServerRequest request) { - if (isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); - } - - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); - - List acceptHeaders = request.headers().asHttpHeaders().getAccept(); - if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) - && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { - return ServerResponse.badRequest().build(); - } - - return request.bodyToMono(String.class).flatMap(body -> { - try { - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest - && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { - McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), - new TypeReference() { - }); - McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory - .startSession(initializeRequest); - sessions.put(init.session().getId(), init.session()); - return init.initResult().map(initializeResult -> { - McpSchema.JSONRPCResponse jsonrpcResponse = new McpSchema.JSONRPCResponse( - McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initializeResult, null); - try { - return this.objectMapper.writeValueAsString(jsonrpcResponse); - } - catch (IOException e) { - logger.warn("Failed to serialize initResponse", e); - throw Exceptions.propagate(e); - } - }) - .flatMap(initResult -> ServerResponse.ok() - .contentType(MediaType.APPLICATION_JSON) - .header(HttpHeaders.MCP_SESSION_ID, init.session().getId()) - .bodyValue(initResult)); - } - - if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { - return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing")); - } - - String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); - McpStreamableServerSession session = sessions.get(sessionId); - - if (session == null) { - return ServerResponse.status(HttpStatus.NOT_FOUND) - .bodyValue(new McpError("Session not found: " + sessionId)); - } - - if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { - return session.accept(jsonrpcResponse).then(ServerResponse.accepted().build()); - } - else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { - return session.accept(jsonrpcNotification).then(ServerResponse.accepted().build()); - } - else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { - return ServerResponse.ok() - .contentType(MediaType.TEXT_EVENT_STREAM) - .body(Flux.>create(sink -> { - WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport(sink); - Mono stream = session.responseStream(jsonrpcRequest, st); - Disposable streamSubscription = stream.onErrorComplete(err -> { - sink.error(err); - return true; - }).contextWrite(sink.contextView()).subscribe(); - sink.onCancel(streamSubscription); - // TODO Clarify why the outer context is not present in the - // Flux.create sink? - }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), - ServerSentEvent.class); - } - else { - return ServerResponse.badRequest().bodyValue(new McpError("Unknown message type")); - } - } - catch (IllegalArgumentException | IOException e) { - logger.error("Failed to deserialize message: {}", e.getMessage()); - return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); - } - }) - .switchIfEmpty(ServerResponse.badRequest().build()) - .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); - } - - private Mono handleDelete(ServerRequest request) { - if (isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); - } - - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); - - return Mono.defer(() -> { - if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { - return ServerResponse.badRequest().build(); // TODO: say we need a session - // id - } - - if (this.disallowDelete) { - return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); - } - - String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); - - McpStreamableServerSession session = this.sessions.get(sessionId); - - if (session == null) { - return ServerResponse.notFound().build(); - } - - return session.delete().then(ServerResponse.ok().build()); - }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); - } - - private class WebFluxStreamableMcpSessionTransport implements McpStreamableServerTransport { - - private final FluxSink> sink; - - public WebFluxStreamableMcpSessionTransport(FluxSink> sink) { - this.sink = sink; - } - - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - return this.sendMessage(message, null); - } - - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { - return Mono.fromSupplier(() -> { - try { - return objectMapper.writeValueAsString(message); - } - catch (IOException e) { - throw Exceptions.propagate(e); - } - }).doOnNext(jsonText -> { - ServerSentEvent event = ServerSentEvent.builder() - .id(messageId) - .event(MESSAGE_EVENT_TYPE) - .data(jsonText) - .build(); - sink.next(event); - }).doOnError(e -> { - // TODO log with sessionid - Throwable exception = Exceptions.unwrap(e); - sink.error(exception); - }).then(); - } - - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); - } - - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(sink::complete); - } - - @Override - public void close() { - sink.complete(); - } - - } - - public static Builder builder() { - return new Builder(); - } - - /** - * Builder for creating instances of {@link WebFluxStreamableServerTransportProvider}. - *

- * This builder provides a fluent API for configuring and creating instances of - * WebFluxStreamableServerTransportProvider with custom settings. - */ - public static class Builder { - - private ObjectMapper objectMapper; - - private String mcpEndpoint = "/mcp"; - - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; - - private boolean disallowDelete; - - private Duration keepAliveInterval; - - private Builder() { - // used by a static method - } - - /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP - * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. - * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null - */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - /** - * Sets the endpoint URI where clients should send their JSON-RPC messages. - * @param messageEndpoint The message endpoint URI. Must not be null. - * @return this builder instance - * @throws IllegalArgumentException if messageEndpoint is null - */ - public Builder messageEndpoint(String messageEndpoint) { - Assert.notNull(messageEndpoint, "Message endpoint must not be null"); - this.mcpEndpoint = messageEndpoint; - return this; - } - - /** - * Sets the context extractor that allows providing the MCP feature - * implementations to inspect HTTP transport level metadata that was present at - * HTTP request processing time. This allows to extract custom headers and other - * useful data for use during execution later on in the process. - * @param contextExtractor The contextExtractor to fill in a - * {@link McpTransportContext}. - * @return this builder instance - * @throws IllegalArgumentException if contextExtractor is null - */ - public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { - Assert.notNull(contextExtractor, "contextExtractor must not be null"); - this.contextExtractor = contextExtractor; - return this; - } - - /** - * Sets whether the session removal capability is disabled. - * @param disallowDelete if {@code true}, the DELETE endpoint will not be - * supported and sessions won't be deleted. - * @return this builder instance - */ - public Builder disallowDelete(boolean disallowDelete) { - this.disallowDelete = disallowDelete; - return this; - } - - /** - * Sets the keep-alive interval for the server transport. - * @param keepAliveInterval The interval for sending keep-alive messages. If null, - * no keep-alive will be scheduled. - * @return this builder instance - */ - public Builder keepAliveInterval(Duration keepAliveInterval) { - this.keepAliveInterval = keepAliveInterval; - return this; - } - - /** - * Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with - * the configured settings. - * @return A new WebFluxStreamableServerTransportProvider instance - * @throws IllegalStateException if required parameters are not set - */ - public WebFluxStreamableServerTransportProvider build() { - Assert.notNull(objectMapper, "ObjectMapper must be set"); - Assert.notNull(mcpEndpoint, "Message endpoint must be set"); - - return new WebFluxStreamableServerTransportProvider(objectMapper, mcpEndpoint, contextExtractor, - disallowDelete, keepAliveInterval); - } - - } - -} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java deleted file mode 100644 index c8dc6e90b..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ - -package io.modelcontextprotocol; - -import java.time.Duration; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Timeout; -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.reactive.function.server.RouterFunctions; -import org.springframework.web.reactive.function.server.ServerRequest; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServer.AsyncSpecification; -import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; -import io.modelcontextprotocol.server.McpTransportContextExtractor; -import io.modelcontextprotocol.server.TestUtil; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; -import reactor.netty.DisposableServer; -import reactor.netty.http.server.HttpServer; - -@Timeout(15) -class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; - - private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; - - private DisposableServer httpServer; - - private WebFluxSseServerTransportProvider mcpServerTransportProvider; - - static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r, tc) -> { - tc.put("important", "value"); - return tc; - }; - - @Override - protected void prepareClients(int port, String mcpEndpoint) { - - clientBuilders - .put("httpclient", - McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build()).requestTimeout(Duration.ofHours(10))); - - clientBuilders.put("webflux", - McpClient - .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build()) - .requestTimeout(Duration.ofHours(10))); - - } - - @Override - protected AsyncSpecification prepareAsyncServerBuilder() { - return McpServer.async(mcpServerTransportProvider); - } - - @Override - protected SingleSessionSyncSpecification prepareSyncServerBuilder() { - return McpServer.sync(mcpServerTransportProvider); - } - - @BeforeEach - public void before() { - - this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .contextExtractor(TEST_CONTEXT_EXTRACTOR) - .build(); - - HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - - prepareClients(PORT, null); - } - - @AfterEach - public void after() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java deleted file mode 100644 index 5516e55b7..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ - -package io.modelcontextprotocol; - -import java.time.Duration; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Timeout; -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.reactive.function.server.RouterFunctions; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; -import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; -import io.modelcontextprotocol.server.TestUtil; -import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; -import reactor.netty.DisposableServer; -import reactor.netty.http.server.HttpServer; - -@Timeout(15) -class WebFluxStatelessIntegrationTests extends AbstractStatelessIntegrationTests { - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; - - private DisposableServer httpServer; - - private WebFluxStatelessServerTransport mcpStreamableServerTransport; - - @Override - protected void prepareClients(int port, String mcpEndpoint) { - clientBuilders - .put("httpclient", - McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) - .endpoint(CUSTOM_MESSAGE_ENDPOINT) - .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); - clientBuilders - .put("webflux", McpClient - .sync(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) - .endpoint(CUSTOM_MESSAGE_ENDPOINT) - .build()) - .initializationTimeout(Duration.ofHours(10)) - .requestTimeout(Duration.ofHours(10))); - } - - @Override - protected StatelessAsyncSpecification prepareAsyncServerBuilder() { - return McpServer.async(this.mcpStreamableServerTransport); - } - - @Override - protected StatelessSyncSpecification prepareSyncServerBuilder() { - return McpServer.sync(this.mcpStreamableServerTransport); - } - - @BeforeEach - public void before() { - this.mcpStreamableServerTransport = WebFluxStatelessServerTransport.builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) - .build(); - - HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpStreamableServerTransport.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - - prepareClients(PORT, null); - } - - @AfterEach - public void after() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java deleted file mode 100644 index a7aac0f1e..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ - -package io.modelcontextprotocol; - -import java.time.Duration; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Timeout; -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.reactive.function.server.RouterFunctions; -import org.springframework.web.reactive.function.server.ServerRequest; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServer.AsyncSpecification; -import io.modelcontextprotocol.server.McpServer.SyncSpecification; -import io.modelcontextprotocol.server.McpTransportContextExtractor; -import io.modelcontextprotocol.server.TestUtil; -import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; -import reactor.netty.DisposableServer; -import reactor.netty.http.server.HttpServer; - -@Timeout(15) -class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; - - private DisposableServer httpServer; - - private WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider; - - static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r, tc) -> { - tc.put("important", "value"); - return tc; - }; - - @Override - protected void prepareClients(int port, String mcpEndpoint) { - - clientBuilders - .put("httpclient", - McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) - .endpoint(CUSTOM_MESSAGE_ENDPOINT) - .build()).requestTimeout(Duration.ofHours(10))); - clientBuilders.put("webflux", - McpClient - .sync(WebClientStreamableHttpTransport - .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) - .endpoint(CUSTOM_MESSAGE_ENDPOINT) - .build()) - .requestTimeout(Duration.ofHours(10))); - } - - @Override - protected AsyncSpecification prepareAsyncServerBuilder() { - return McpServer.async(mcpStreamableServerTransportProvider); - } - - @Override - protected SyncSpecification prepareSyncServerBuilder() { - return McpServer.sync(mcpStreamableServerTransportProvider); - } - - @BeforeEach - public void before() { - - this.mcpStreamableServerTransportProvider = WebFluxStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) - .contextExtractor(TEST_CONTEXT_EXTRACTOR) - .build(); - - HttpHandler httpHandler = RouterFunctions - .toHttpHandler(mcpStreamableServerTransportProvider.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - - prepareClients(PORT, null); - } - - @AfterEach - public void after() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java deleted file mode 100644 index 191f10376..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; -import org.junit.jupiter.api.Timeout; -import org.springframework.web.reactive.function.client.WebClient; - -@Timeout(15) -public class WebClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { - - @Override - protected McpClientTransport createMcpTransport() { - return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java deleted file mode 100644 index f8a16c153..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import org.junit.jupiter.api.Timeout; -import org.springframework.web.reactive.function.client.WebClient; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; - -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; - -@Timeout(15) -public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { - - static String host = "http://localhost:3001"; - - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image - @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js streamableHttp") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - @Override - protected McpClientTransport createMcpTransport() { - return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); - } - - @Override - protected void onStart() { - container.start(); - int port = container.getMappedPort(3001); - host = "http://" + container.getHost() + ":" + port; - } - - @Override - public void onClose() { - container.stop(); - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java deleted file mode 100644 index 5e9960d0e..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import org.junit.jupiter.api.Timeout; -import org.springframework.web.reactive.function.client.WebClient; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; - -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; - -@Timeout(15) -public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { - - static String host = "http://localhost:3001"; - - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image - @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js streamableHttp") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - @Override - protected McpClientTransport createMcpTransport() { - return WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(host)).build(); - } - - @Override - protected void onStart() { - container.start(); - int port = container.getMappedPort(3001); - host = "http://" + container.getHost() + ":" + port; - } - - @Override - public void onClose() { - container.stop(); - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java deleted file mode 100644 index 0edf4cd54..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import java.time.Duration; - -import org.junit.jupiter.api.Timeout; -import org.springframework.web.reactive.function.client.WebClient; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; - -import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; - -/** - * Tests for the {@link McpAsyncClient} with {@link WebFluxSseClientTransport}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { - - static String host = "http://localhost:3001"; - - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image - @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - @Override - protected McpClientTransport createMcpTransport() { - return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build(); - } - - @Override - protected void onStart() { - container.start(); - int port = container.getMappedPort(3001); - host = "http://" + container.getHost() + ":" + port; - } - - @Override - public void onClose() { - container.stop(); - } - - protected Duration getInitializationTimeout() { - return Duration.ofSeconds(1); - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java deleted file mode 100644 index 9b0959a35..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import java.time.Duration; - -import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; -import org.junit.jupiter.api.Timeout; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; - -import org.springframework.web.reactive.function.client.WebClient; - -/** - * Tests for the {@link McpSyncClient} with {@link WebFluxSseClientTransport}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { - - static String host = "http://localhost:3001"; - - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image - @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - @Override - protected McpClientTransport createMcpTransport() { - return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build(); - } - - @Override - protected void onStart() { - container.start(); - int port = container.getMappedPort(3001); - host = "http://" + container.getHost() + ":" + port; - } - - @Override - protected void onClose() { - container.stop(); - } - - protected Duration getInitializationTimeout() { - return Duration.ofSeconds(1); - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java deleted file mode 100644 index cdbb97e17..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransportErrorHandlingTest.java +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - */ - -package io.modelcontextprotocol.client.transport; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.timeout; -import static org.mockito.Mockito.verify; - -import java.io.IOException; -import java.net.InetSocketAddress; -import java.time.Duration; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; -import org.springframework.web.reactive.function.client.WebClient; - -import com.sun.net.httpserver.HttpServer; - -import io.modelcontextprotocol.server.TestUtil; -import io.modelcontextprotocol.spec.HttpHeaders; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpTransportException; -import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; -import io.modelcontextprotocol.spec.ProtocolVersions; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -/** - * Tests for error handling in WebClientStreamableHttpTransport. Addresses concurrency - * issues with proper Reactor patterns. - * - * @author Christian Tzolov - */ -@Timeout(15) -public class WebClientStreamableHttpTransportErrorHandlingTest { - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String HOST = "http://localhost:" + PORT; - - private HttpServer server; - - private AtomicReference serverResponseStatus = new AtomicReference<>(200); - - private AtomicReference currentServerSessionId = new AtomicReference<>(null); - - private AtomicReference lastReceivedSessionId = new AtomicReference<>(null); - - private McpClientTransport transport; - - // Initialize latches for proper request synchronization - CountDownLatch firstRequestLatch; - - CountDownLatch secondRequestLatch; - - CountDownLatch getRequestLatch; - - @BeforeEach - void startServer() throws IOException { - - // Initialize latches for proper synchronization - firstRequestLatch = new CountDownLatch(1); - secondRequestLatch = new CountDownLatch(1); - getRequestLatch = new CountDownLatch(1); - - server = HttpServer.create(new InetSocketAddress(PORT), 0); - - // Configure the /mcp endpoint with dynamic response - server.createContext("/mcp", exchange -> { - String method = exchange.getRequestMethod(); - - if ("GET".equals(method)) { - // This is the SSE connection attempt after session establishment - getRequestLatch.countDown(); - // Return 405 Method Not Allowed to indicate SSE not supported - exchange.sendResponseHeaders(405, 0); - exchange.close(); - return; - } - - String requestSessionId = exchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); - lastReceivedSessionId.set(requestSessionId); - - int status = serverResponseStatus.get(); - - // Track which request this is - if (firstRequestLatch.getCount() > 0) { - // // First request - should have no session ID - firstRequestLatch.countDown(); - } - else if (secondRequestLatch.getCount() > 0) { - // Second request - should have session ID - secondRequestLatch.countDown(); - } - - exchange.getResponseHeaders().set("Content-Type", "application/json"); - - // Don't include session ID in 404 and 400 responses - the implementation - // checks if the transport has a session stored locally - String responseSessionId = currentServerSessionId.get(); - if (responseSessionId != null && status == 200) { - exchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); - } - if (status == 200) { - String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; - exchange.sendResponseHeaders(200, response.length()); - exchange.getResponseBody().write(response.getBytes()); - } - else { - exchange.sendResponseHeaders(status, 0); - } - exchange.close(); - }); - - server.setExecutor(null); - server.start(); - - transport = WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(HOST)).build(); - } - - @AfterEach - void stopServer() { - if (server != null) { - server.stop(0); - } - StepVerifier.create(transport.closeGracefully()).verifyComplete(); - } - - /** - * Test that 404 response WITHOUT session ID throws McpTransportException (not - * SessionNotFoundException) - */ - @Test - void test404WithoutSessionId() { - serverResponseStatus.set(404); - currentServerSessionId.set(null); // No session ID in response - - var testMessage = createTestMessage(); - - StepVerifier.create(transport.sendMessage(testMessage)) - .expectErrorMatches(throwable -> throwable instanceof McpTransportException - && throwable.getMessage().contains("Not Found") && throwable.getMessage().contains("404") - && !(throwable instanceof McpTransportSessionNotFoundException)) - .verify(Duration.ofSeconds(5)); - } - - /** - * Test that 404 response WITH session ID throws McpTransportSessionNotFoundException - * Fixed version using proper async coordination - */ - @Test - void test404WithSessionId() throws InterruptedException { - // First establish a session - serverResponseStatus.set(200); - currentServerSessionId.set("test-session-123"); - - // Set up exception handler to verify session invalidation - @SuppressWarnings("unchecked") - Consumer exceptionHandler = mock(Consumer.class); - transport.setExceptionHandler(exceptionHandler); - - // Connect with handler - StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); - - // Send initial message to establish session - var testMessage = createTestMessage(); - - // Send first message to establish session - StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - - // Wait for first request to complete - assertThat(firstRequestLatch.await(5, TimeUnit.SECONDS)).isTrue(); - - // Wait for the GET request (SSE connection attempt) to complete - assertThat(getRequestLatch.await(5, TimeUnit.SECONDS)).isTrue(); - - // Now return 404 for next request - serverResponseStatus.set(404); - - // Use delaySubscription to ensure session is fully processed before next - // request - StepVerifier.create(Mono.delay(Duration.ofMillis(200)).then(transport.sendMessage(testMessage))) - .expectError(McpTransportSessionNotFoundException.class) - .verify(Duration.ofSeconds(5)); - - // Wait for second request to be made - assertThat(secondRequestLatch.await(5, TimeUnit.SECONDS)).isTrue(); - - // Verify the second request included the session ID - assertThat(lastReceivedSessionId.get()).isEqualTo("test-session-123"); - - // Verify exception handler was called with SessionNotFoundException using - // timeout - verify(exceptionHandler, timeout(5000)).accept(any(McpTransportSessionNotFoundException.class)); - } - - /** - * Test that 400 response WITHOUT session ID throws McpTransportException (not - * SessionNotFoundException) - */ - @Test - void test400WithoutSessionId() { - serverResponseStatus.set(400); - currentServerSessionId.set(null); // No session ID - - var testMessage = createTestMessage(); - - StepVerifier.create(transport.sendMessage(testMessage)) - .expectErrorMatches(throwable -> throwable instanceof McpTransportException - && throwable.getMessage().contains("Bad Request") && throwable.getMessage().contains("400") - && !(throwable instanceof McpTransportSessionNotFoundException)) - .verify(Duration.ofSeconds(5)); - } - - /** - * Test that 400 response WITH session ID throws McpTransportSessionNotFoundException - * Fixed version using proper async coordination - */ - @Test - void test400WithSessionId() throws InterruptedException { - - // First establish a session - serverResponseStatus.set(200); - currentServerSessionId.set("test-session-456"); - - // Set up exception handler - @SuppressWarnings("unchecked") - Consumer exceptionHandler = mock(Consumer.class); - transport.setExceptionHandler(exceptionHandler); - - // Connect with handler - StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); - - // Send initial message to establish session - var testMessage = createTestMessage(); - - // Send first message to establish session - StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - - // Wait for first request to complete - boolean firstCompleted = firstRequestLatch.await(5, TimeUnit.SECONDS); - assertThat(firstCompleted).isTrue(); - - // Wait for the GET request (SSE connection attempt) to complete - boolean getCompleted = getRequestLatch.await(5, TimeUnit.SECONDS); - assertThat(getCompleted).isTrue(); - - // Now return 400 for next request (simulating unknown session ID) - serverResponseStatus.set(400); - - // Use delaySubscription to ensure session is fully processed before next - // request - StepVerifier.create(Mono.delay(Duration.ofMillis(200)).then(transport.sendMessage(testMessage))) - .expectError(McpTransportSessionNotFoundException.class) - .verify(Duration.ofSeconds(5)); - - // Wait for second request to be made - boolean secondCompleted = secondRequestLatch.await(5, TimeUnit.SECONDS); - assertThat(secondCompleted).isTrue(); - - // Verify the second request included the session ID - assertThat(lastReceivedSessionId.get()).isEqualTo("test-session-456"); - - // Verify exception handler was called with timeout - verify(exceptionHandler, timeout(5000)).accept(any(McpTransportSessionNotFoundException.class)); - } - - /** - * Test session recovery after SessionNotFoundException Fixed version using reactive - * patterns and proper synchronization - */ - @Test - void testSessionRecoveryAfter404() { - // First establish a session - serverResponseStatus.set(200); - currentServerSessionId.set("session-1"); - - // Send initial message to establish session - var testMessage = createTestMessage(); - - // Use Mono.defer to ensure proper sequencing - Mono establishSession = transport.sendMessage(testMessage).then(Mono.defer(() -> { - // Simulate session loss - return 404 - serverResponseStatus.set(404); - return transport.sendMessage(testMessage).onErrorResume(McpTransportSessionNotFoundException.class, e -> { - // Expected error, continue with recovery - return Mono.empty(); - }); - })).then(Mono.defer(() -> { - // Now server is back with new session - serverResponseStatus.set(200); - currentServerSessionId.set("session-2"); - lastReceivedSessionId.set(null); // Reset to verify new session - - // Should be able to establish new session - return transport.sendMessage(testMessage); - })).then(Mono.defer(() -> { - // Verify no session ID was sent (since old session was invalidated) - assertThat(lastReceivedSessionId.get()).isNull(); - - // Next request should use the new session ID - return transport.sendMessage(testMessage); - })).doOnSuccess(v -> { - // Session ID should now be sent with requests - assertThat(lastReceivedSessionId.get()).isEqualTo("session-2"); - }); - - StepVerifier.create(establishSession).verifyComplete(); - } - - /** - * Test that reconnect (GET request) also properly handles 404/400 errors Fixed - * version with proper async handling - */ - @Test - void testReconnectErrorHandling() throws InterruptedException { - // Initialize latch for SSE connection - CountDownLatch sseConnectionLatch = new CountDownLatch(1); - - // Set up SSE endpoint for GET requests - server.createContext("/mcp-sse", exchange -> { - String method = exchange.getRequestMethod(); - String requestSessionId = exchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); - - if ("GET".equals(method)) { - sseConnectionLatch.countDown(); - int status = serverResponseStatus.get(); - - if (status == 404 && requestSessionId != null) { - // 404 with session ID - should trigger SessionNotFoundException - exchange.sendResponseHeaders(404, 0); - } - else if (status == 404) { - // 404 without session ID - should trigger McpTransportException - exchange.sendResponseHeaders(404, 0); - } - else { - // Normal SSE response - exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); - exchange.sendResponseHeaders(200, 0); - // Send a test SSE event - String sseData = "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{}}\n\n"; - exchange.getResponseBody().write(sseData.getBytes()); - } - } - else { - // POST request handling - exchange.getResponseHeaders().set("Content-Type", "application/json"); - String responseSessionId = currentServerSessionId.get(); - if (responseSessionId != null) { - exchange.getResponseHeaders().set(HttpHeaders.MCP_SESSION_ID, responseSessionId); - } - String response = "{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":\"test-id\"}"; - exchange.sendResponseHeaders(200, response.length()); - exchange.getResponseBody().write(response.getBytes()); - } - exchange.close(); - }); - - // Test with session ID - should get SessionNotFoundException - serverResponseStatus.set(200); - currentServerSessionId.set("sse-session-1"); - - var transport = WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl(HOST)) - .endpoint("/mcp-sse") - .openConnectionOnStartup(true) // This will trigger GET request on connect - .build(); - - // First connect successfully - StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); - - // Wait for SSE connection to be established - boolean connected = sseConnectionLatch.await(5, TimeUnit.SECONDS); - assertThat(connected).isTrue(); - - // Send message to establish session - var testMessage = createTestMessage(); - StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - - // Clean up - StepVerifier.create(transport.closeGracefully()).verifyComplete(); - } - - private McpSchema.JSONRPCRequest createTestMessage() { - var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, - McpSchema.ClientCapabilities.builder().roots(true).build(), - new McpSchema.Implementation("Test Client", "1.0.0")); - return new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "test-id", - initializeRequest); - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java deleted file mode 100644 index 1cf5dffe2..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ /dev/null @@ -1,369 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client.transport; - -import java.time.Duration; -import java.util.Map; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Function; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.test.StepVerifier; - -import org.springframework.http.codec.ServerSentEvent; -import org.springframework.web.reactive.function.client.WebClient; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Tests for the {@link WebFluxSseClientTransport} class. - * - * @author Christian Tzolov - */ -@Timeout(15) -class WebFluxSseClientTransportTests { - - static String host = "http://localhost:3001"; - - @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - private TestSseClientTransport transport; - - private WebClient.Builder webClientBuilder; - - private ObjectMapper objectMapper; - - // Test class to access protected methods - static class TestSseClientTransport extends WebFluxSseClientTransport { - - private final AtomicInteger inboundMessageCount = new AtomicInteger(0); - - private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); - - public TestSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) { - super(webClientBuilder, objectMapper); - } - - @Override - protected Flux> eventStream() { - return super.eventStream().mergeWith(events.asFlux()); - } - - public String getLastEndpoint() { - return messageEndpointSink.asMono().block(); - } - - public int getInboundMessageCount() { - return inboundMessageCount.get(); - } - - public void simulateSseComment(String comment) { - events.tryEmitNext(ServerSentEvent.builder().comment(comment).build()); - inboundMessageCount.incrementAndGet(); - } - - public void simulateEndpointEvent(String jsonMessage) { - events.tryEmitNext(ServerSentEvent.builder().event("endpoint").data(jsonMessage).build()); - inboundMessageCount.incrementAndGet(); - } - - public void simulateMessageEvent(String jsonMessage) { - events.tryEmitNext(ServerSentEvent.builder().event("message").data(jsonMessage).build()); - inboundMessageCount.incrementAndGet(); - } - - } - - void startContainer() { - container.start(); - int port = container.getMappedPort(3001); - host = "http://" + container.getHost() + ":" + port; - } - - @BeforeEach - void setUp() { - startContainer(); - webClientBuilder = WebClient.builder().baseUrl(host); - objectMapper = new ObjectMapper(); - transport = new TestSseClientTransport(webClientBuilder, objectMapper); - transport.connect(Function.identity()).block(); - } - - @AfterEach - void afterEach() { - if (transport != null) { - assertThatCode(() -> transport.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - cleanup(); - } - - void cleanup() { - container.stop(); - } - - @Test - void testEndpointEventHandling() { - assertThat(transport.getLastEndpoint()).startsWith("/message?"); - } - - @Test - void constructorValidation() { - assertThatThrownBy(() -> new WebFluxSseClientTransport(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("WebClient.Builder must not be null"); - - assertThatThrownBy(() -> new WebFluxSseClientTransport(webClientBuilder, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("ObjectMapper must not be null"); - } - - @Test - void testBuilderPattern() { - // Test default builder - WebFluxSseClientTransport transport1 = WebFluxSseClientTransport.builder(webClientBuilder).build(); - assertThatCode(() -> transport1.closeGracefully().block()).doesNotThrowAnyException(); - - // Test builder with custom ObjectMapper - ObjectMapper customMapper = new ObjectMapper(); - WebFluxSseClientTransport transport2 = WebFluxSseClientTransport.builder(webClientBuilder) - .objectMapper(customMapper) - .build(); - assertThatCode(() -> transport2.closeGracefully().block()).doesNotThrowAnyException(); - - // Test builder with custom SSE endpoint - WebFluxSseClientTransport transport3 = WebFluxSseClientTransport.builder(webClientBuilder) - .sseEndpoint("/custom-sse") - .build(); - assertThatCode(() -> transport3.closeGracefully().block()).doesNotThrowAnyException(); - - // Test builder with all custom parameters - WebFluxSseClientTransport transport4 = WebFluxSseClientTransport.builder(webClientBuilder) - .objectMapper(customMapper) - .sseEndpoint("/custom-sse") - .build(); - assertThatCode(() -> transport4.closeGracefully().block()).doesNotThrowAnyException(); - } - - @Test - void testCommentSseMessage() { - // If the line starts with a character (:) are comment lins and should be ingored - // https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation - - CopyOnWriteArrayList droppedErrors = new CopyOnWriteArrayList<>(); - reactor.core.publisher.Hooks.onErrorDropped(droppedErrors::add); - - try { - // Simulate receiving the SSE comment line - transport.simulateSseComment("sse comment"); - - StepVerifier.create(transport.closeGracefully()).verifyComplete(); - - assertThat(droppedErrors).hasSize(0); - } - finally { - reactor.core.publisher.Hooks.resetOnErrorDropped(); - } - } - - @Test - void testMessageProcessing() { - // Create a test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", - Map.of("key", "value")); - - // Simulate receiving the message - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "test-method", - "id": "test-id", - "params": {"key": "value"} - } - """); - - // Subscribe to messages and verify - StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - - assertThat(transport.getInboundMessageCount()).isEqualTo(1); - } - - @Test - void testResponseMessageProcessing() { - // Simulate receiving a response message - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "id": "test-id", - "result": {"status": "success"} - } - """); - - // Create and send a request message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", - Map.of("key", "value")); - - // Verify message handling - StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - - assertThat(transport.getInboundMessageCount()).isEqualTo(1); - } - - @Test - void testErrorMessageProcessing() { - // Simulate receiving an error message - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "id": "test-id", - "error": { - "code": -32600, - "message": "Invalid Request" - } - } - """); - - // Create and send a request message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", - Map.of("key", "value")); - - // Verify message handling - StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - - assertThat(transport.getInboundMessageCount()).isEqualTo(1); - } - - @Test - void testNotificationMessageProcessing() { - // Simulate receiving a notification message (no id) - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "update", - "params": {"status": "processing"} - } - """); - - // Verify the notification was processed - assertThat(transport.getInboundMessageCount()).isEqualTo(1); - } - - @Test - void testGracefulShutdown() { - // Test graceful shutdown - StepVerifier.create(transport.closeGracefully()).verifyComplete(); - - // Create a test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", - Map.of("key", "value")); - - // Verify message is not processed after shutdown - StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); - - // Message count should remain 0 after shutdown - assertThat(transport.getInboundMessageCount()).isEqualTo(0); - } - - @Test - void testRetryBehavior() { - // Create a WebClient that simulates connection failures - WebClient.Builder failingWebClientBuilder = WebClient.builder().baseUrl("http://non-existent-host"); - - WebFluxSseClientTransport failingTransport = WebFluxSseClientTransport.builder(failingWebClientBuilder).build(); - - // Verify that the transport attempts to reconnect - StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete(); - - // Clean up - failingTransport.closeGracefully().block(); - } - - @Test - void testMultipleMessageProcessing() { - // Simulate receiving multiple messages in sequence - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "method1", - "id": "id1", - "params": {"key": "value1"} - } - """); - - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "method2", - "id": "id2", - "params": {"key": "value2"} - } - """); - - // Create and send corresponding messages - JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", "id1", - Map.of("key", "value1")); - - JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", "id2", - Map.of("key", "value2")); - - // Verify both messages are processed - StepVerifier.create(transport.sendMessage(message1).then(transport.sendMessage(message2))).verifyComplete(); - - // Verify message count - assertThat(transport.getInboundMessageCount()).isEqualTo(2); - } - - @Test - void testMessageOrderPreservation() { - // Simulate receiving messages in a specific order - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "first", - "id": "1", - "params": {"sequence": 1} - } - """); - - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "second", - "id": "2", - "params": {"sequence": 2} - } - """); - - transport.simulateMessageEvent(""" - { - "jsonrpc": "2.0", - "method": "third", - "id": "3", - "params": {"sequence": 3} - } - """); - - // Verify message count and order - assertThat(transport.getInboundMessageCount()).isEqualTo(3); - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java deleted file mode 100644 index a3bdf10b0..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import org.junit.jupiter.api.Timeout; -import reactor.netty.DisposableServer; -import reactor.netty.http.server.HttpServer; - -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.reactive.function.server.RouterFunctions; - -/** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private DisposableServer httpServer; - - private McpServerTransportProvider createMcpTransportProvider() { - var transportProvider = new WebFluxSseServerTransportProvider.Builder().objectMapper(new ObjectMapper()) - .messageEndpoint(MESSAGE_ENDPOINT) - .build(); - - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - return transportProvider; - } - - @Override - protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { - return McpServer.async(createMcpTransportProvider()); - } - - @Override - protected void onStart() { - } - - @Override - protected void onClose() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java deleted file mode 100644 index 3e28e96b8..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import org.junit.jupiter.api.Timeout; -import reactor.netty.DisposableServer; -import reactor.netty.http.server.HttpServer; - -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.reactive.function.server.RouterFunctions; - -/** - * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransportProvider}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private DisposableServer httpServer; - - private WebFluxSseServerTransportProvider transportProvider; - - @Override - protected McpServer.SyncSpecification prepareSyncServerBuilder() { - return McpServer.sync(createMcpTransportProvider()); - } - - private McpServerTransportProvider createMcpTransportProvider() { - transportProvider = new WebFluxSseServerTransportProvider.Builder().objectMapper(new ObjectMapper()) - .messageEndpoint(MESSAGE_ENDPOINT) - .build(); - return transportProvider; - } - - @Override - protected void onStart() { - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - } - - @Override - protected void onClose() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java deleted file mode 100644 index 959f2f472..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; -import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; -import org.junit.jupiter.api.Timeout; -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.reactive.function.server.RouterFunctions; -import reactor.netty.DisposableServer; -import reactor.netty.http.server.HttpServer; - -/** - * Tests for {@link McpAsyncServer} using - * {@link WebFluxStreamableServerTransportProvider}. - * - * @author Christian Tzolov - * @author Dariusz JΔ™drzejczyk - */ -@Timeout(15) // Giving extra time beyond the client timeout -class WebFluxStreamableMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private DisposableServer httpServer; - - private McpStreamableServerTransportProvider createMcpTransportProvider() { - var transportProvider = WebFluxStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(MESSAGE_ENDPOINT) - .build(); - - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - return transportProvider; - } - - @Override - protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { - return McpServer.async(createMcpTransportProvider()); - } - - @Override - protected void onStart() { - } - - @Override - protected void onClose() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java deleted file mode 100644 index 3396d489c..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; -import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; -import org.junit.jupiter.api.Timeout; -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.reactive.function.server.RouterFunctions; -import reactor.netty.DisposableServer; -import reactor.netty.http.server.HttpServer; - -/** - * Tests for {@link McpAsyncServer} using - * {@link WebFluxStreamableServerTransportProvider}. - * - * @author Christian Tzolov - * @author Dariusz JΔ™drzejczyk - */ -@Timeout(15) // Giving extra time beyond the client timeout -class WebFluxStreamableMcpSyncServerTests extends AbstractMcpSyncServerTests { - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private DisposableServer httpServer; - - private McpStreamableServerTransportProvider createMcpTransportProvider() { - var transportProvider = WebFluxStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(MESSAGE_ENDPOINT) - .build(); - - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - return transportProvider; - } - - @Override - protected McpServer.SyncSpecification prepareSyncServerBuilder() { - return McpServer.sync(createMcpTransportProvider()); - } - - @Override - protected void onStart() { - } - - @Override - protected void onClose() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - -} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java deleted file mode 100644 index dfb004e9b..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java +++ /dev/null @@ -1,70 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ - -package io.modelcontextprotocol.server.transport; - -import java.io.IOException; -import java.io.InputStream; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; - -public class BlockingInputStream extends InputStream { - - private final BlockingQueue queue = new LinkedBlockingQueue<>(); - - private volatile boolean completed = false; - - private volatile boolean closed = false; - - @Override - public int read() throws IOException { - if (closed) { - throw new IOException("Stream is closed"); - } - - try { - Integer value = queue.poll(); - if (value == null) { - if (completed) { - return -1; - } - value = queue.take(); // Blocks until data is available - if (value == null && completed) { - return -1; - } - } - return value; - } - catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IOException("Read interrupted", e); - } - } - - public void write(int b) { - if (!closed && !completed) { - queue.offer(b); - } - } - - public void write(byte[] data) { - if (!closed && !completed) { - for (byte b : data) { - queue.offer((int) b & 0xFF); - } - } - } - - public void complete() { - this.completed = true; - } - - @Override - public void close() { - this.closed = true; - this.completed = true; - this.queue.clear(); - } - -} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml deleted file mode 100644 index abc831d13..000000000 --- a/mcp-spring/mcp-spring-webflux/src/test/resources/logback.xml +++ /dev/null @@ -1,25 +0,0 @@ - - - - - - - %d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n - - - - - - - - - - - - - - - - - - diff --git a/mcp-spring/mcp-spring-webmvc/README.md b/mcp-spring/mcp-spring-webmvc/README.md deleted file mode 100644 index 9adf5b2ee..000000000 --- a/mcp-spring/mcp-spring-webmvc/README.md +++ /dev/null @@ -1,29 +0,0 @@ -# WebMVC SSE Server Transport - -```xml - - io.modelcontextprotocol.sdk - mcp-spring-webmvc - -``` - - - -```java -String MESSAGE_ENDPOINT = "/mcp/message"; - -@Configuration -@EnableWebMvc -static class MyConfig { - - @Bean - public WebMvcSseServerTransport webMvcSseServerTransport() { - return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - } - - @Bean - public RouterFunction routerFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); - } -} -``` diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml deleted file mode 100644 index ea262d3a1..000000000 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ /dev/null @@ -1,148 +0,0 @@ - - - 4.0.0 - - io.modelcontextprotocol.sdk - mcp-parent - 0.12.0-SNAPSHOT - ../../pom.xml - - mcp-spring-webmvc - jar - Spring Web MVC transports - Web MVC implementation for the SSE and Streamable Http Server transports - https://github.com/modelcontextprotocol/java-sdk - - - https://github.com/modelcontextprotocol/java-sdk - git://github.com/modelcontextprotocol/java-sdk.git - git@github.com/modelcontextprotocol/java-sdk.git - - - - - io.modelcontextprotocol.sdk - mcp - 0.12.0-SNAPSHOT - - - - org.springframework - spring-webmvc - ${springframework.version} - - - - io.modelcontextprotocol.sdk - mcp-test - 0.12.0-SNAPSHOT - test - - - - io.modelcontextprotocol.sdk - mcp-spring-webflux - 0.12.0-SNAPSHOT - test - - - - - - org.springframework - spring-context - ${springframework.version} - test - - - - org.springframework - spring-test - ${springframework.version} - test - - - - org.assertj - assertj-core - ${assert4j.version} - test - - - org.junit.jupiter - junit-jupiter-api - ${junit.version} - test - - - org.mockito - mockito-core - ${mockito.version} - test - - - net.bytebuddy - byte-buddy - ${byte-buddy.version} - test - - - org.testcontainers - junit-jupiter - ${testcontainers.version} - test - - - - org.awaitility - awaitility - ${awaitility.version} - test - - - - ch.qos.logback - logback-classic - ${logback.version} - test - - - - io.projectreactor.netty - reactor-netty-http - test - - - io.projectreactor - reactor-test - test - - - jakarta.servlet - jakarta.servlet-api - ${jakarta.servlet.version} - provided - - - - org.apache.tomcat.embed - tomcat-embed-core - ${tomcat.version} - test - - - - net.javacrumbs.json-unit - json-unit-assertj - ${json-unit-assertj.version} - test - - - - - - diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java deleted file mode 100644 index 6e92cf10c..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ /dev/null @@ -1,649 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.io.IOException; -import java.time.Duration; -import java.util.List; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.locks.ReentrantLock; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.server.DefaultMcpTransportContext; -import io.modelcontextprotocol.server.McpTransportContext; -import io.modelcontextprotocol.server.McpTransportContextExtractor; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerTransport; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.ProtocolVersions; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.KeepAliveScheduler; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import org.springframework.http.HttpStatus; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.RouterFunctions; -import org.springframework.web.servlet.function.ServerRequest; -import org.springframework.web.servlet.function.ServerResponse; -import org.springframework.web.servlet.function.ServerResponse.SseBuilder; - -/** - * Server-side implementation of the Model Context Protocol (MCP) transport layer using - * HTTP with Server-Sent Events (SSE) through Spring WebMVC. This implementation provides - * a bridge between synchronous WebMVC operations and reactive programming patterns to - * maintain compatibility with the reactive transport interface. - * - *

- * Key features: - *

    - *
  • Implements bidirectional communication using HTTP POST for client-to-server - * messages and SSE for server-to-client messages
  • - *
  • Manages client sessions with unique IDs for reliable message delivery
  • - *
  • Supports graceful shutdown with proper session cleanup
  • - *
  • Provides JSON-RPC message handling through configured endpoints
  • - *
  • Includes built-in error handling and logging
  • - *
- * - *

- * The transport operates on two main endpoints: - *

    - *
  • {@code /sse} - The SSE endpoint where clients establish their event stream - * connection
  • - *
  • A configurable message endpoint where clients send their JSON-RPC messages via HTTP - * POST
  • - *
- * - *

- * This implementation uses {@link ConcurrentHashMap} to safely manage multiple client - * sessions in a thread-safe manner. Each client session is assigned a unique ID and - * maintains its own SSE connection. - * - * @author Christian Tzolov - * @author Alexandros Pappas - * @see McpServerTransportProvider - * @see RouterFunction - */ -public class WebMvcSseServerTransportProvider implements McpServerTransportProvider { - - private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransportProvider.class); - - /** - * Event type for JSON-RPC messages sent through the SSE connection. - */ - public static final String MESSAGE_EVENT_TYPE = "message"; - - /** - * Event type for sending the message endpoint URI to clients. - */ - public static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - /** - * Default SSE endpoint path as specified by the MCP transport specification. - */ - public static final String DEFAULT_SSE_ENDPOINT = "/sse"; - - private final ObjectMapper objectMapper; - - private final String messageEndpoint; - - private final String sseEndpoint; - - private final String baseUrl; - - private final RouterFunction routerFunction; - - private McpServerSession.Factory sessionFactory; - - /** - * Map of active client sessions, keyed by session ID. - */ - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); - - private McpTransportContextExtractor contextExtractor; - - /** - * Flag indicating if the transport is shutting down. - */ - private volatile boolean isClosing = false; - - private KeepAliveScheduler keepAliveScheduler; - - /** - * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE - * endpoint. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); - } - - /** - * Constructs a new WebMvcSseServerTransportProvider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @param sseEndpoint The endpoint URI where clients establish their SSE connections. - * @throws IllegalArgumentException if any parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, "", messageEndpoint, sseEndpoint); - } - - /** - * Constructs a new WebMvcSseServerTransportProvider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param baseUrl The base URL for the message endpoint, used to construct the full - * endpoint URL for clients. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @param sseEndpoint The endpoint URI where clients establish their SSE connections. - * @throws IllegalArgumentException if any parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint) { - this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); - } - - /** - * Constructs a new WebMvcSseServerTransportProvider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param baseUrl The base URL for the message endpoint, used to construct the full - * endpoint URL for clients. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @param sseEndpoint The endpoint URI where clients establish their SSE connections. - * @param keepAliveInterval The interval for sending keep-alive messages to clients. - * @throws IllegalArgumentException if any parameter is null - * @deprecated Use the builder {@link #builder()} instead for better configuration - * options. - */ - @Deprecated - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint, Duration keepAliveInterval) { - this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, - (serverRequest, context) -> context); - } - - /** - * Constructs a new WebMvcSseServerTransportProvider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param baseUrl The base URL for the message endpoint, used to construct the full - * endpoint URL for clients. - * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP POST. This endpoint will be communicated to clients through the - * SSE connection's initial endpoint event. - * @param sseEndpoint The endpoint URI where clients establish their SSE connections. - * @param keepAliveInterval The interval for sending keep-alive messages to clients. - * @param contextExtractor The contextExtractor to fill in a - * {@link McpTransportContext}. - * @throws IllegalArgumentException if any parameter is null - */ - private WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint, Duration keepAliveInterval, - McpTransportContextExtractor contextExtractor) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.notNull(baseUrl, "Message base URL must not be null"); - Assert.notNull(messageEndpoint, "Message endpoint must not be null"); - Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); - Assert.notNull(contextExtractor, "Context extractor must not be null"); - - this.objectMapper = objectMapper; - this.baseUrl = baseUrl; - this.messageEndpoint = messageEndpoint; - this.sseEndpoint = sseEndpoint; - this.contextExtractor = contextExtractor; - this.routerFunction = RouterFunctions.route() - .GET(this.sseEndpoint, this::handleSseConnection) - .POST(this.messageEndpoint, this::handleMessage) - .build(); - - if (keepAliveInterval != null) { - - this.keepAliveScheduler = KeepAliveScheduler - .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) - .initialDelay(keepAliveInterval) - .interval(keepAliveInterval) - .build(); - - this.keepAliveScheduler.start(); - } - } - - @Override - public List protocolVersions() { - return List.of(ProtocolVersions.MCP_2024_11_05); - } - - @Override - public void setSessionFactory(McpServerSession.Factory sessionFactory) { - this.sessionFactory = sessionFactory; - } - - /** - * Broadcasts a notification to all connected clients through their SSE connections. - * The message is serialized to JSON and sent as an SSE event with type "message". If - * any errors occur during sending to a particular client, they are logged but don't - * prevent sending to other clients. - * @param method The method name for the notification - * @param params The parameters for the notification - * @return A Mono that completes when the broadcast attempt is finished - */ - @Override - public Mono notifyClients(String method, Object params) { - if (sessions.isEmpty()) { - logger.debug("No active sessions to broadcast message to"); - return Mono.empty(); - } - - logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - - return Flux.fromIterable(sessions.values()) - .flatMap(session -> session.sendNotification(method, params) - .doOnError( - e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) - .onErrorComplete()) - .then(); - } - - /** - * Initiates a graceful shutdown of the transport. This method: - *

    - *
  • Sets the closing flag to prevent new connections
  • - *
  • Closes all active SSE connections
  • - *
  • Removes all session records
  • - *
- * @return A Mono that completes when all cleanup operations are finished - */ - @Override - public Mono closeGracefully() { - return Flux.fromIterable(sessions.values()).doFirst(() -> { - this.isClosing = true; - logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - }).flatMap(McpServerSession::closeGracefully).then().doOnSuccess(v -> { - logger.debug("Graceful shutdown completed"); - sessions.clear(); - if (this.keepAliveScheduler != null) { - this.keepAliveScheduler.shutdown(); - } - }); - } - - /** - * Returns the RouterFunction that defines the HTTP endpoints for this transport. The - * router function handles two endpoints: - *
    - *
  • GET /sse - For establishing SSE connections
  • - *
  • POST [messageEndpoint] - For receiving JSON-RPC messages from clients
  • - *
- * @return The configured RouterFunction for handling HTTP requests - */ - public RouterFunction getRouterFunction() { - return this.routerFunction; - } - - /** - * Handles new SSE connection requests from clients by creating a new session and - * establishing an SSE connection. This method: - *
    - *
  • Generates a unique session ID
  • - *
  • Creates a new session with a WebMvcMcpSessionTransport
  • - *
  • Sends an initial endpoint event to inform the client where to send - * messages
  • - *
  • Maintains the session in the sessions map
  • - *
- * @param request The incoming server request - * @return A ServerResponse configured for SSE communication, or an error response if - * the server is shutting down or the connection fails - */ - private ServerResponse handleSseConnection(ServerRequest request) { - if (this.isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); - } - - String sessionId = UUID.randomUUID().toString(); - logger.debug("Creating new SSE connection for session: {}", sessionId); - - // Send initial endpoint event - try { - return ServerResponse.sse(sseBuilder -> { - sseBuilder.onComplete(() -> { - logger.debug("SSE connection completed for session: {}", sessionId); - sessions.remove(sessionId); - }); - sseBuilder.onTimeout(() -> { - logger.debug("SSE connection timed out for session: {}", sessionId); - sessions.remove(sessionId); - }); - - WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder); - McpServerSession session = sessionFactory.create(sessionTransport); - this.sessions.put(sessionId, session); - - try { - sseBuilder.id(sessionId) - .event(ENDPOINT_EVENT_TYPE) - .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); - } - catch (Exception e) { - logger.error("Failed to send initial endpoint event: {}", e.getMessage()); - sseBuilder.error(e); - } - }, Duration.ZERO); - } - catch (Exception e) { - logger.error("Failed to send initial endpoint event to session {}: {}", sessionId, e.getMessage()); - sessions.remove(sessionId); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); - } - } - - /** - * Handles incoming JSON-RPC messages from clients. This method: - *
    - *
  • Deserializes the request body into a JSON-RPC message
  • - *
  • Processes the message through the session's handle method
  • - *
  • Returns appropriate HTTP responses based on the processing result
  • - *
- * @param request The incoming server request containing the JSON-RPC message - * @return A ServerResponse indicating success (200 OK) or appropriate error status - * with error details in case of failures - */ - private ServerResponse handleMessage(ServerRequest request) { - if (this.isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); - } - - if (request.param("sessionId").isEmpty()) { - return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); - } - - String sessionId = request.param("sessionId").get(); - McpServerSession session = sessions.get(sessionId); - - if (session == null) { - return ServerResponse.status(HttpStatus.NOT_FOUND).body(new McpError("Session not found: " + sessionId)); - } - - try { - final McpTransportContext transportContext = this.contextExtractor.extract(request, - new DefaultMcpTransportContext()); - - String body = request.body(String.class); - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - - // Process the message through the session's handle method - session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); // Block - // for - // WebMVC - // compatibility - - return ServerResponse.ok().build(); - } - catch (IllegalArgumentException | IOException e) { - logger.error("Failed to deserialize message: {}", e.getMessage()); - return ServerResponse.badRequest().body(new McpError("Invalid message format")); - } - catch (Exception e) { - logger.error("Error handling message: {}", e.getMessage()); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); - } - } - - /** - * Implementation of McpServerTransport for WebMVC SSE sessions. This class handles - * the transport-level communication for a specific client session. - */ - private class WebMvcMcpSessionTransport implements McpServerTransport { - - private final String sessionId; - - private final SseBuilder sseBuilder; - - /** - * Lock to ensure thread-safe access to the SSE builder when sending messages. - * This prevents concurrent modifications that could lead to corrupted SSE events. - */ - private final ReentrantLock sseBuilderLock = new ReentrantLock(); - - /** - * Creates a new session transport with the specified ID and SSE builder. - * @param sessionId The unique identifier for this session - * @param sseBuilder The SSE builder for sending server events to the client - */ - WebMvcMcpSessionTransport(String sessionId, SseBuilder sseBuilder) { - this.sessionId = sessionId; - this.sseBuilder = sseBuilder; - logger.debug("Session transport {} initialized with SSE builder", sessionId); - } - - /** - * Sends a JSON-RPC message to the client through the SSE connection. - * @param message The JSON-RPC message to send - * @return A Mono that completes when the message has been sent - */ - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - return Mono.fromRunnable(() -> { - sseBuilderLock.lock(); - try { - String jsonText = objectMapper.writeValueAsString(message); - sseBuilder.id(sessionId).event(MESSAGE_EVENT_TYPE).data(jsonText); - logger.debug("Message sent to session {}", sessionId); - } - catch (Exception e) { - logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); - sseBuilder.error(e); - } - finally { - sseBuilderLock.unlock(); - } - }); - } - - /** - * Converts data from one type to another using the configured ObjectMapper. - * @param data The source data object to convert - * @param typeRef The target type reference - * @return The converted object of type T - * @param The target type - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); - } - - /** - * Initiates a graceful shutdown of the transport. - * @return A Mono that completes when the shutdown is complete - */ - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> { - logger.debug("Closing session transport: {}", sessionId); - sseBuilderLock.lock(); - try { - sseBuilder.complete(); - logger.debug("Successfully completed SSE builder for session {}", sessionId); - } - catch (Exception e) { - logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); - } - finally { - sseBuilderLock.unlock(); - } - }); - } - - /** - * Closes the transport immediately. - */ - @Override - public void close() { - sseBuilderLock.lock(); - try { - sseBuilder.complete(); - logger.debug("Successfully completed SSE builder for session {}", sessionId); - } - catch (Exception e) { - logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); - } - finally { - sseBuilderLock.unlock(); - } - } - - } - - /** - * Creates a new Builder instance for configuring and creating instances of - * WebMvcSseServerTransportProvider. - * @return A new Builder instance - */ - public static Builder builder() { - return new Builder(); - } - - /** - * Builder for creating instances of WebMvcSseServerTransportProvider. - *

- * This builder provides a fluent API for configuring and creating instances of - * WebMvcSseServerTransportProvider with custom settings. - */ - public static class Builder { - - private ObjectMapper objectMapper = new ObjectMapper(); - - private String baseUrl = ""; - - private String messageEndpoint; - - private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - - private Duration keepAliveInterval; - - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; - - /** - * Sets the JSON object mapper to use for message serialization/deserialization. - * @param objectMapper The object mapper to use - * @return This builder instance for method chaining - */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - /** - * Sets the base URL for the server transport. - * @param baseUrl The base URL to use - * @return This builder instance for method chaining - */ - public Builder baseUrl(String baseUrl) { - Assert.notNull(baseUrl, "Base URL must not be null"); - this.baseUrl = baseUrl; - return this; - } - - /** - * Sets the endpoint path where clients will send their messages. - * @param messageEndpoint The message endpoint path - * @return This builder instance for method chaining - */ - public Builder messageEndpoint(String messageEndpoint) { - Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); - this.messageEndpoint = messageEndpoint; - return this; - } - - /** - * Sets the endpoint path where clients will establish SSE connections. - *

- * If not specified, the default value of {@link #DEFAULT_SSE_ENDPOINT} will be - * used. - * @param sseEndpoint The SSE endpoint path - * @return This builder instance for method chaining - */ - public Builder sseEndpoint(String sseEndpoint) { - Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); - this.sseEndpoint = sseEndpoint; - return this; - } - - /** - * Sets the interval for keep-alive pings. - *

- * If not specified, keep-alive pings will be disabled. - * @param keepAliveInterval The interval duration for keep-alive pings - * @return This builder instance for method chaining - */ - public Builder keepAliveInterval(Duration keepAliveInterval) { - this.keepAliveInterval = keepAliveInterval; - return this; - } - - /** - * Sets the context extractor that allows providing the MCP feature - * implementations to inspect HTTP transport level metadata that was present at - * HTTP request processing time. This allows to extract custom headers and other - * useful data for use during execution later on in the process. - * @param contextExtractor The contextExtractor to fill in a - * {@link McpTransportContext}. - * @return this builder instance - * @throws IllegalArgumentException if contextExtractor is null - */ - public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { - Assert.notNull(contextExtractor, "contextExtractor must not be null"); - this.contextExtractor = contextExtractor; - return this; - } - - /** - * Builds a new instance of WebMvcSseServerTransportProvider with the configured - * settings. - * @return A new WebMvcSseServerTransportProvider instance - * @throws IllegalStateException if objectMapper or messageEndpoint is not set - */ - public WebMvcSseServerTransportProvider build() { - if (messageEndpoint == null) { - throw new IllegalStateException("MessageEndpoint must be set"); - } - return new WebMvcSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, - keepAliveInterval, contextExtractor); - } - - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java deleted file mode 100644 index fef1920fc..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java +++ /dev/null @@ -1,241 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.McpStatelessServerHandler; -import io.modelcontextprotocol.server.DefaultMcpTransportContext; -import io.modelcontextprotocol.server.McpTransportContextExtractor; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpStatelessServerTransport; -import io.modelcontextprotocol.server.McpTransportContext; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.RouterFunctions; -import org.springframework.web.servlet.function.ServerRequest; -import org.springframework.web.servlet.function.ServerResponse; -import reactor.core.publisher.Mono; - -import java.io.IOException; -import java.util.List; - -/** - * Implementation of a WebMVC based {@link McpStatelessServerTransport}. - * - *

- * This is the non-reactive version of - * {@link io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport} - * - * @author Christian Tzolov - */ -public class WebMvcStatelessServerTransport implements McpStatelessServerTransport { - - private static final Logger logger = LoggerFactory.getLogger(WebMvcStatelessServerTransport.class); - - private final ObjectMapper objectMapper; - - private final String mcpEndpoint; - - private final RouterFunction routerFunction; - - private McpStatelessServerHandler mcpHandler; - - private McpTransportContextExtractor contextExtractor; - - private volatile boolean isClosing = false; - - private WebMvcStatelessServerTransport(ObjectMapper objectMapper, String mcpEndpoint, - McpTransportContextExtractor contextExtractor) { - Assert.notNull(objectMapper, "objectMapper must not be null"); - Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); - Assert.notNull(contextExtractor, "contextExtractor must not be null"); - - this.objectMapper = objectMapper; - this.mcpEndpoint = mcpEndpoint; - this.contextExtractor = contextExtractor; - this.routerFunction = RouterFunctions.route() - .GET(this.mcpEndpoint, this::handleGet) - .POST(this.mcpEndpoint, this::handlePost) - .build(); - } - - @Override - public void setMcpHandler(McpStatelessServerHandler mcpHandler) { - this.mcpHandler = mcpHandler; - } - - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> this.isClosing = true); - } - - /** - * Returns the WebMVC router function that defines the transport's HTTP endpoints. - * This router function should be integrated into the application's web configuration. - * - *

- * The router function defines one endpoint handling two HTTP methods: - *

    - *
  • GET {messageEndpoint} - Unsupported, returns 405 METHOD NOT ALLOWED
  • - *
  • POST {messageEndpoint} - For handling client requests and notifications
  • - *
- * @return The configured {@link RouterFunction} for handling HTTP requests - */ - public RouterFunction getRouterFunction() { - return this.routerFunction; - } - - private ServerResponse handleGet(ServerRequest request) { - return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); - } - - private ServerResponse handlePost(ServerRequest request) { - if (isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); - } - - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); - - List acceptHeaders = request.headers().asHttpHeaders().getAccept(); - if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) - && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { - return ServerResponse.badRequest().build(); - } - - try { - String body = request.body(String.class); - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - - if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { - try { - McpSchema.JSONRPCResponse jsonrpcResponse = this.mcpHandler - .handleRequest(transportContext, jsonrpcRequest) - .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) - .block(); - return ServerResponse.ok().contentType(MediaType.APPLICATION_JSON).body(jsonrpcResponse); - } - catch (Exception e) { - logger.error("Failed to handle request: {}", e.getMessage()); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) - .body(new McpError("Failed to handle request: " + e.getMessage())); - } - } - else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { - try { - this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) - .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) - .block(); - return ServerResponse.accepted().build(); - } - catch (Exception e) { - logger.error("Failed to handle notification: {}", e.getMessage()); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) - .body(new McpError("Failed to handle notification: " + e.getMessage())); - } - } - else { - return ServerResponse.badRequest() - .body(new McpError("The server accepts either requests or notifications")); - } - } - catch (IllegalArgumentException | IOException e) { - logger.error("Failed to deserialize message: {}", e.getMessage()); - return ServerResponse.badRequest().body(new McpError("Invalid message format")); - } - catch (Exception e) { - logger.error("Unexpected error handling message: {}", e.getMessage()); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) - .body(new McpError("Unexpected error: " + e.getMessage())); - } - } - - /** - * Create a builder for the server. - * @return a fresh {@link Builder} instance. - */ - public static Builder builder() { - return new Builder(); - } - - /** - * Builder for creating instances of {@link WebMvcStatelessServerTransport}. - *

- * This builder provides a fluent API for configuring and creating instances of - * WebMvcStatelessServerTransport with custom settings. - */ - public static class Builder { - - private ObjectMapper objectMapper; - - private String mcpEndpoint = "/mcp"; - - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; - - private Builder() { - // used by a static method - } - - /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP - * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. - * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null - */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - /** - * Sets the endpoint URI where clients should send their JSON-RPC messages. - * @param messageEndpoint The message endpoint URI. Must not be null. - * @return this builder instance - * @throws IllegalArgumentException if messageEndpoint is null - */ - public Builder messageEndpoint(String messageEndpoint) { - Assert.notNull(messageEndpoint, "Message endpoint must not be null"); - this.mcpEndpoint = messageEndpoint; - return this; - } - - /** - * Sets the context extractor that allows providing the MCP feature - * implementations to inspect HTTP transport level metadata that was present at - * HTTP request processing time. This allows to extract custom headers and other - * useful data for use during execution later on in the process. - * @param contextExtractor The contextExtractor to fill in a - * {@link McpTransportContext}. - * @return this builder instance - * @throws IllegalArgumentException if contextExtractor is null - */ - public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { - Assert.notNull(contextExtractor, "Context extractor must not be null"); - this.contextExtractor = contextExtractor; - return this; - } - - /** - * Builds a new instance of {@link WebMvcStatelessServerTransport} with the - * configured settings. - * @return A new WebMvcStatelessServerTransport instance - * @throws IllegalStateException if required parameters are not set - */ - public WebMvcStatelessServerTransport build() { - Assert.notNull(objectMapper, "ObjectMapper must be set"); - Assert.notNull(mcpEndpoint, "Message endpoint must be set"); - - return new WebMvcStatelessServerTransport(objectMapper, mcpEndpoint, contextExtractor); - } - - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java deleted file mode 100644 index fa51a0130..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java +++ /dev/null @@ -1,690 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.io.IOException; -import java.time.Duration; -import java.util.List; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.locks.ReentrantLock; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.RouterFunctions; -import org.springframework.web.servlet.function.ServerRequest; -import org.springframework.web.servlet.function.ServerResponse; -import org.springframework.web.servlet.function.ServerResponse.SseBuilder; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.server.DefaultMcpTransportContext; -import io.modelcontextprotocol.server.McpTransportContext; -import io.modelcontextprotocol.server.McpTransportContextExtractor; -import io.modelcontextprotocol.spec.HttpHeaders; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpStreamableServerSession; -import io.modelcontextprotocol.spec.McpStreamableServerTransport; -import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; -import io.modelcontextprotocol.spec.ProtocolVersions; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.KeepAliveScheduler; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** - * Server-side implementation of the Model Context Protocol (MCP) streamable transport - * layer using HTTP with Server-Sent Events (SSE) through Spring WebMVC. This - * implementation provides a bridge between synchronous WebMVC operations and reactive - * programming patterns to maintain compatibility with the reactive transport interface. - * - *

- * This is the non-reactive version of - * {@link io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider} - * - * @author Christian Tzolov - * @author Dariusz JΔ™drzejczyk - * @see McpStreamableServerTransportProvider - * @see RouterFunction - */ -public class WebMvcStreamableServerTransportProvider implements McpStreamableServerTransportProvider { - - private static final Logger logger = LoggerFactory.getLogger(WebMvcStreamableServerTransportProvider.class); - - /** - * Event type for JSON-RPC messages sent through the SSE connection. - */ - public static final String MESSAGE_EVENT_TYPE = "message"; - - /** - * Event type for sending the message endpoint URI to clients. - */ - public static final String ENDPOINT_EVENT_TYPE = "endpoint"; - - /** - * Default base URL for the message endpoint. - */ - public static final String DEFAULT_BASE_URL = ""; - - /** - * The endpoint URI where clients should send their JSON-RPC messages. Defaults to - * "/mcp". - */ - private final String mcpEndpoint; - - /** - * Flag indicating whether DELETE requests are disallowed on the endpoint. - */ - private final boolean disallowDelete; - - private final ObjectMapper objectMapper; - - private final RouterFunction routerFunction; - - private McpStreamableServerSession.Factory sessionFactory; - - /** - * Map of active client sessions, keyed by mcp-session-id. - */ - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); - - private McpTransportContextExtractor contextExtractor; - - /** - * Flag indicating if the transport is shutting down. - */ - private volatile boolean isClosing = false; - - private KeepAliveScheduler keepAliveScheduler; - - /** - * Constructs a new WebMvcStreamableServerTransportProvider instance. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - * of messages. - * @param baseUrl The base URL for the message endpoint, used to construct the full - * endpoint URL for clients. - * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC - * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests. - * @param disallowDelete Whether to disallow DELETE requests on the endpoint. - * @throws IllegalArgumentException if any parameter is null - */ - private WebMvcStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, - boolean disallowDelete, McpTransportContextExtractor contextExtractor, - Duration keepAliveInterval) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); - Assert.notNull(contextExtractor, "McpTransportContextExtractor must not be null"); - - this.objectMapper = objectMapper; - this.mcpEndpoint = mcpEndpoint; - this.disallowDelete = disallowDelete; - this.contextExtractor = contextExtractor; - this.routerFunction = RouterFunctions.route() - .GET(this.mcpEndpoint, this::handleGet) - .POST(this.mcpEndpoint, this::handlePost) - .DELETE(this.mcpEndpoint, this::handleDelete) - .build(); - - if (keepAliveInterval != null) { - this.keepAliveScheduler = KeepAliveScheduler - .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values())) - .initialDelay(keepAliveInterval) - .interval(keepAliveInterval) - .build(); - - this.keepAliveScheduler.start(); - } - } - - @Override - public List protocolVersions() { - return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); - } - - @Override - public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { - this.sessionFactory = sessionFactory; - } - - /** - * Broadcasts a notification to all connected clients through their SSE connections. - * If any errors occur during sending to a particular client, they are logged but - * don't prevent sending to other clients. - * @param method The method name for the notification - * @param params The parameters for the notification - * @return A Mono that completes when the broadcast attempt is finished - */ - @Override - public Mono notifyClients(String method, Object params) { - if (this.sessions.isEmpty()) { - logger.debug("No active sessions to broadcast message to"); - return Mono.empty(); - } - - logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size()); - - return Mono.fromRunnable(() -> { - this.sessions.values().parallelStream().forEach(session -> { - try { - session.sendNotification(method, params).block(); - } - catch (Exception e) { - logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()); - } - }); - }); - } - - /** - * Initiates a graceful shutdown of the transport. - * @return A Mono that completes when all cleanup operations are finished - */ - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> { - this.isClosing = true; - logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()); - - this.sessions.values().parallelStream().forEach(session -> { - try { - session.closeGracefully().block(); - } - catch (Exception e) { - logger.error("Failed to close session {}: {}", session.getId(), e.getMessage()); - } - }); - - this.sessions.clear(); - logger.debug("Graceful shutdown completed"); - }).then().doOnSuccess(v -> { - if (this.keepAliveScheduler != null) { - this.keepAliveScheduler.shutdown(); - } - }); - } - - /** - * Returns the RouterFunction that defines the HTTP endpoints for this transport. The - * router function handles three endpoints: - *

    - *
  • GET [mcpEndpoint] - For establishing SSE connections and message replay
  • - *
  • POST [mcpEndpoint] - For receiving JSON-RPC messages from clients
  • - *
  • DELETE [mcpEndpoint] - For session deletion (if enabled)
  • - *
- * @return The configured RouterFunction for handling HTTP requests - */ - public RouterFunction getRouterFunction() { - return this.routerFunction; - } - - /** - * Setup the listening SSE connections and message replay. - * @param request The incoming server request - * @return A ServerResponse configured for SSE communication, or an error response - */ - private ServerResponse handleGet(ServerRequest request) { - if (this.isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); - } - - List acceptHeaders = request.headers().asHttpHeaders().getAccept(); - if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) { - return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM"); - } - - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); - - if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { - return ServerResponse.badRequest().body("Session ID required in mcp-session-id header"); - } - - String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); - McpStreamableServerSession session = this.sessions.get(sessionId); - - if (session == null) { - return ServerResponse.notFound().build(); - } - - logger.debug("Handling GET request for session: {}", sessionId); - - try { - return ServerResponse.sse(sseBuilder -> { - sseBuilder.onTimeout(() -> { - logger.debug("SSE connection timed out for session: {}", sessionId); - }); - - WebMvcStreamableMcpSessionTransport sessionTransport = new WebMvcStreamableMcpSessionTransport( - sessionId, sseBuilder); - - // Check if this is a replay request - if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) { - String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); - - try { - session.replay(lastId) - .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) - .toIterable() - .forEach(message -> { - try { - sessionTransport.sendMessage(message) - .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) - .block(); - } - catch (Exception e) { - logger.error("Failed to replay message: {}", e.getMessage()); - sseBuilder.error(e); - } - }); - } - catch (Exception e) { - logger.error("Failed to replay messages: {}", e.getMessage()); - sseBuilder.error(e); - } - } - else { - // Establish new listening stream - McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session - .listeningStream(sessionTransport); - - sseBuilder.onComplete(() -> { - logger.debug("SSE connection completed for session: {}", sessionId); - listeningStream.close(); - }); - } - }, Duration.ZERO); - } - catch (Exception e) { - logger.error("Failed to handle GET request for session {}: {}", sessionId, e.getMessage()); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); - } - } - - /** - * Handles POST requests for incoming JSON-RPC messages from clients. - * @param request The incoming server request containing the JSON-RPC message - * @return A ServerResponse indicating success or appropriate error status - */ - private ServerResponse handlePost(ServerRequest request) { - if (this.isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); - } - - List acceptHeaders = request.headers().asHttpHeaders().getAccept(); - if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM) - || !acceptHeaders.contains(MediaType.APPLICATION_JSON)) { - return ServerResponse.badRequest() - .body(new McpError("Invalid Accept headers. Expected TEXT_EVENT_STREAM and APPLICATION_JSON")); - } - - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); - - try { - String body = request.body(String.class); - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - - // Handle initialization request - if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest - && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { - McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), - new TypeReference() { - }); - McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory - .startSession(initializeRequest); - this.sessions.put(init.session().getId(), init.session()); - - try { - McpSchema.InitializeResult initResult = init.initResult().block(); - - return ServerResponse.ok() - .contentType(MediaType.APPLICATION_JSON) - .header(HttpHeaders.MCP_SESSION_ID, init.session().getId()) - .body(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initResult, - null)); - } - catch (Exception e) { - logger.error("Failed to initialize session: {}", e.getMessage()); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); - } - } - - // Handle other messages that require a session - if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { - return ServerResponse.badRequest().body(new McpError("Session ID missing")); - } - - String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); - McpStreamableServerSession session = this.sessions.get(sessionId); - - if (session == null) { - return ServerResponse.status(HttpStatus.NOT_FOUND) - .body(new McpError("Session not found: " + sessionId)); - } - - if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { - session.accept(jsonrpcResponse) - .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) - .block(); - return ServerResponse.accepted().build(); - } - else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { - session.accept(jsonrpcNotification) - .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) - .block(); - return ServerResponse.accepted().build(); - } - else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { - // For streaming responses, we need to return SSE - return ServerResponse.sse(sseBuilder -> { - sseBuilder.onComplete(() -> { - logger.debug("Request response stream completed for session: {}", sessionId); - }); - sseBuilder.onTimeout(() -> { - logger.debug("Request response stream timed out for session: {}", sessionId); - }); - - WebMvcStreamableMcpSessionTransport sessionTransport = new WebMvcStreamableMcpSessionTransport( - sessionId, sseBuilder); - - try { - session.responseStream(jsonrpcRequest, sessionTransport) - .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) - .block(); - } - catch (Exception e) { - logger.error("Failed to handle request stream: {}", e.getMessage()); - sseBuilder.error(e); - } - }, Duration.ZERO); - } - else { - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) - .body(new McpError("Unknown message type")); - } - } - catch (IllegalArgumentException | IOException e) { - logger.error("Failed to deserialize message: {}", e.getMessage()); - return ServerResponse.badRequest().body(new McpError("Invalid message format")); - } - catch (Exception e) { - logger.error("Error handling message: {}", e.getMessage()); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); - } - } - - /** - * Handles DELETE requests for session deletion. - * @param request The incoming server request - * @return A ServerResponse indicating success or appropriate error status - */ - private ServerResponse handleDelete(ServerRequest request) { - if (this.isClosing) { - return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); - } - - if (this.disallowDelete) { - return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); - } - - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); - - if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { - return ServerResponse.badRequest().body("Session ID required in mcp-session-id header"); - } - - String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); - McpStreamableServerSession session = this.sessions.get(sessionId); - - if (session == null) { - return ServerResponse.notFound().build(); - } - - try { - session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); - this.sessions.remove(sessionId); - return ServerResponse.ok().build(); - } - catch (Exception e) { - logger.error("Failed to delete session {}: {}", sessionId, e.getMessage()); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); - } - } - - /** - * Implementation of McpStreamableServerTransport for WebMVC SSE sessions. This class - * handles the transport-level communication for a specific client session. - * - *

- * This class is thread-safe and uses a ReentrantLock to synchronize access to the - * underlying SSE builder to prevent race conditions when multiple threads attempt to - * send messages concurrently. - */ - private class WebMvcStreamableMcpSessionTransport implements McpStreamableServerTransport { - - private final String sessionId; - - private final SseBuilder sseBuilder; - - private final ReentrantLock lock = new ReentrantLock(); - - private volatile boolean closed = false; - - /** - * Creates a new session transport with the specified ID and SSE builder. - * @param sessionId The unique identifier for this session - * @param sseBuilder The SSE builder for sending server events to the client - */ - WebMvcStreamableMcpSessionTransport(String sessionId, SseBuilder sseBuilder) { - this.sessionId = sessionId; - this.sseBuilder = sseBuilder; - logger.debug("Streamable session transport {} initialized with SSE builder", sessionId); - } - - /** - * Sends a JSON-RPC message to the client through the SSE connection. - * @param message The JSON-RPC message to send - * @return A Mono that completes when the message has been sent - */ - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - return sendMessage(message, null); - } - - /** - * Sends a JSON-RPC message to the client through the SSE connection with a - * specific message ID. - * @param message The JSON-RPC message to send - * @param messageId The message ID for SSE event identification - * @return A Mono that completes when the message has been sent - */ - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { - return Mono.fromRunnable(() -> { - if (this.closed) { - logger.debug("Attempted to send message to closed session: {}", this.sessionId); - return; - } - - this.lock.lock(); - try { - if (this.closed) { - logger.debug("Session {} was closed during message send attempt", this.sessionId); - return; - } - - String jsonText = objectMapper.writeValueAsString(message); - this.sseBuilder.id(messageId != null ? messageId : this.sessionId) - .event(MESSAGE_EVENT_TYPE) - .data(jsonText); - logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId); - } - catch (Exception e) { - logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage()); - try { - this.sseBuilder.error(e); - } - catch (Exception errorException) { - logger.error("Failed to send error to SSE builder for session {}: {}", this.sessionId, - errorException.getMessage()); - } - } - finally { - this.lock.unlock(); - } - }); - } - - /** - * Converts data from one type to another using the configured ObjectMapper. - * @param data The source data object to convert - * @param typeRef The target type reference - * @return The converted object of type T - * @param The target type - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); - } - - /** - * Initiates a graceful shutdown of the transport. - * @return A Mono that completes when the shutdown is complete - */ - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> { - WebMvcStreamableMcpSessionTransport.this.close(); - }); - } - - /** - * Closes the transport immediately. - */ - @Override - public void close() { - this.lock.lock(); - try { - if (this.closed) { - logger.debug("Session transport {} already closed", this.sessionId); - return; - } - - this.closed = true; - - this.sseBuilder.complete(); - logger.debug("Successfully completed SSE builder for session {}", sessionId); - } - catch (Exception e) { - logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); - } - finally { - this.lock.unlock(); - } - } - - } - - public static Builder builder() { - return new Builder(); - } - - /** - * Builder for creating instances of {@link WebMvcStreamableServerTransportProvider}. - */ - public static class Builder { - - private ObjectMapper objectMapper; - - private String mcpEndpoint = "/mcp"; - - private boolean disallowDelete = false; - - private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; - - private Duration keepAliveInterval; - - /** - * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP - * messages. - * @param objectMapper The ObjectMapper instance. Must not be null. - * @return this builder instance - * @throws IllegalArgumentException if objectMapper is null - */ - public Builder objectMapper(ObjectMapper objectMapper) { - Assert.notNull(objectMapper, "ObjectMapper must not be null"); - this.objectMapper = objectMapper; - return this; - } - - /** - * Sets the endpoint URI where clients should send their JSON-RPC messages. - * @param mcpEndpoint The MCP endpoint URI. Must not be null. - * @return this builder instance - * @throws IllegalArgumentException if mcpEndpoint is null - */ - public Builder mcpEndpoint(String mcpEndpoint) { - Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); - this.mcpEndpoint = mcpEndpoint; - return this; - } - - /** - * Sets whether to disallow DELETE requests on the endpoint. - * @param disallowDelete true to disallow DELETE requests, false otherwise - * @return this builder instance - */ - public Builder disallowDelete(boolean disallowDelete) { - this.disallowDelete = disallowDelete; - return this; - } - - /** - * Sets the context extractor that allows providing the MCP feature - * implementations to inspect HTTP transport level metadata that was present at - * HTTP request processing time. This allows to extract custom headers and other - * useful data for use during execution later on in the process. - * @param contextExtractor The contextExtractor to fill in a - * {@link McpTransportContext}. - * @return this builder instance - * @throws IllegalArgumentException if contextExtractor is null - */ - public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { - Assert.notNull(contextExtractor, "contextExtractor must not be null"); - this.contextExtractor = contextExtractor; - return this; - } - - /** - * Sets the keep-alive interval for the transport. If set, a keep-alive scheduler - * will be created to periodically check and send keep-alive messages to clients. - * @param keepAliveInterval The interval duration for keep-alive messages, or null - * to disable keep-alive - * @return this builder instance - */ - public Builder keepAliveInterval(Duration keepAliveInterval) { - this.keepAliveInterval = keepAliveInterval; - return this; - } - - /** - * Builds a new instance of {@link WebMvcStreamableServerTransportProvider} with - * the configured settings. - * @return A new WebMvcStreamableServerTransportProvider instance - * @throws IllegalStateException if required parameters are not set - */ - public WebMvcStreamableServerTransportProvider build() { - Assert.notNull(this.objectMapper, "ObjectMapper must be set"); - Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); - - return new WebMvcStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint, this.disallowDelete, - this.contextExtractor, this.keepAliveInterval); - } - - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java deleted file mode 100644 index 8625b6a70..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java +++ /dev/null @@ -1,64 +0,0 @@ -/* -* Copyright 2025 - 2025 the original author or authors. -*/ -package io.modelcontextprotocol.server; - -import org.apache.catalina.Context; -import org.apache.catalina.startup.Tomcat; - -import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import org.springframework.web.servlet.DispatcherServlet; - -/** - * @author Christian Tzolov - */ -public class TomcatTestUtil { - - TomcatTestUtil() { - // Prevent instantiation - } - - public record TomcatServer(Tomcat tomcat, AnnotationConfigWebApplicationContext appContext) { - } - - public static TomcatServer createTomcatServer(String contextPath, int port, Class componentClass) { - - // Set up Tomcat first - var tomcat = new Tomcat(); - tomcat.setPort(port); - - // Set Tomcat base directory to java.io.tmpdir to avoid permission issues - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - // Use the same directory for document base - Context context = tomcat.addContext(contextPath, baseDir); - - // Create and configure Spring WebMvc context - var appContext = new AnnotationConfigWebApplicationContext(); - appContext.register(componentClass); - appContext.setServletContext(context.getServletContext()); - appContext.refresh(); - - // Create DispatcherServlet with our Spring context - DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - - // Add servlet to Tomcat and get the wrapper - var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); - wrapper.setLoadOnStartup(1); - wrapper.setAsyncSupported(true); - context.addServletMappingDecoded("/*", "dispatcherServlet"); - - try { - // Configure and start the connector with async support - var connector = tomcat.getConnector(); - connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests - } - catch (Exception e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - return new TomcatServer(tomcat, appContext); - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java deleted file mode 100644 index 66349216d..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import org.apache.catalina.Context; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.Timeout; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import org.springframework.web.servlet.DispatcherServlet; -import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.ServerResponse; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; -import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; -import reactor.netty.DisposableServer; - -/** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class WebMcpStreamableAsyncServerTransportTests extends AbstractMcpAsyncServerTests { - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String MCP_ENDPOINT = "/mcp"; - - private DisposableServer httpServer; - - private AnnotationConfigWebApplicationContext appContext; - - private Tomcat tomcat; - - private McpStreamableServerTransportProvider transportProvider; - - @Configuration - @EnableWebMvc - static class TestConfig { - - @Bean - public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() { - return WebMvcStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .mcpEndpoint(MCP_ENDPOINT) - .build(); - } - - @Bean - public RouterFunction routerFunction( - WebMvcStreamableServerTransportProvider transportProvider) { - return transportProvider.getRouterFunction(); - } - - } - - private McpStreamableServerTransportProvider createMcpTransportProvider() { - // Set up Tomcat first - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - // Set Tomcat base directory to java.io.tmpdir to avoid permission issues - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - // Use the same directory for document base - Context context = tomcat.addContext("", baseDir); - - // Create and configure Spring WebMvc context - appContext = new AnnotationConfigWebApplicationContext(); - appContext.register(TestConfig.class); - appContext.setServletContext(context.getServletContext()); - appContext.refresh(); - - // Get the transport from Spring context - transportProvider = appContext.getBean(McpStreamableServerTransportProvider.class); - - // Create DispatcherServlet with our Spring context - DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - - // Add servlet to Tomcat and get the wrapper - var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); - wrapper.setLoadOnStartup(1); - context.addServletMappingDecoded("/*", "dispatcherServlet"); - - try { - tomcat.start(); - tomcat.getConnector(); // Create and start the connector - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - return transportProvider; - } - - @Override - protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { - return McpServer.async(createMcpTransportProvider()); - } - - @Override - protected void onStart() { - } - - @Override - protected void onClose() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java deleted file mode 100644 index cab487f12..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import org.apache.catalina.Context; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.Timeout; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import org.springframework.web.servlet.DispatcherServlet; -import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.ServerResponse; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; -import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; -import reactor.netty.DisposableServer; - -/** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class WebMcpStreamableSyncServerTransportTests extends AbstractMcpSyncServerTests { - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String MCP_ENDPOINT = "/mcp"; - - private DisposableServer httpServer; - - private AnnotationConfigWebApplicationContext appContext; - - private Tomcat tomcat; - - private McpStreamableServerTransportProvider transportProvider; - - @Configuration - @EnableWebMvc - static class TestConfig { - - @Bean - public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() { - return WebMvcStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .mcpEndpoint(MCP_ENDPOINT) - .build(); - } - - @Bean - public RouterFunction routerFunction( - WebMvcStreamableServerTransportProvider transportProvider) { - return transportProvider.getRouterFunction(); - } - - } - - private McpStreamableServerTransportProvider createMcpTransportProvider() { - // Set up Tomcat first - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - // Set Tomcat base directory to java.io.tmpdir to avoid permission issues - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - // Use the same directory for document base - Context context = tomcat.addContext("", baseDir); - - // Create and configure Spring WebMvc context - appContext = new AnnotationConfigWebApplicationContext(); - appContext.register(TestConfig.class); - appContext.setServletContext(context.getServletContext()); - appContext.refresh(); - - // Get the transport from Spring context - transportProvider = appContext.getBean(McpStreamableServerTransportProvider.class); - - // Create DispatcherServlet with our Spring context - DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - - // Add servlet to Tomcat and get the wrapper - var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); - wrapper.setLoadOnStartup(1); - context.addServletMappingDecoded("/*", "dispatcherServlet"); - - try { - tomcat.start(); - tomcat.getConnector(); // Create and start the connector - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - return transportProvider; - } - - @Override - protected McpServer.SyncSpecification prepareSyncServerBuilder() { - return McpServer.sync(createMcpTransportProvider()); - } - - @Override - protected void onStart() { - } - - @Override - protected void onClose() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java deleted file mode 100644 index bb4c2bf37..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import org.apache.catalina.Context; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.Timeout; - -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import org.springframework.web.servlet.DispatcherServlet; -import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.ServerResponse; - -@Timeout(15) -class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerTests { - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private static final int PORT = TestUtil.findAvailablePort(); - - private Tomcat tomcat; - - private McpServerTransportProvider transportProvider; - - @Configuration - @EnableWebMvc - static class TestConfig { - - @Bean - public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); - } - - @Bean - public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { - return transportProvider.getRouterFunction(); - } - - } - - private AnnotationConfigWebApplicationContext appContext; - - private McpServerTransportProvider createMcpTransportProvider() { - // Set up Tomcat first - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - // Set Tomcat base directory to java.io.tmpdir to avoid permission issues - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - // Use the same directory for document base - Context context = tomcat.addContext("", baseDir); - - // Create and configure Spring WebMvc context - appContext = new AnnotationConfigWebApplicationContext(); - appContext.register(TestConfig.class); - appContext.setServletContext(context.getServletContext()); - appContext.refresh(); - - // Get the transport from Spring context - transportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); - - // Create DispatcherServlet with our Spring context - DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - - // Add servlet to Tomcat and get the wrapper - var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); - wrapper.setLoadOnStartup(1); - context.addServletMappingDecoded("/*", "dispatcherServlet"); - - try { - tomcat.start(); - tomcat.getConnector(); // Create and start the connector - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - return transportProvider; - } - - @Override - protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { - return McpServer.async(createMcpTransportProvider()); - } - - @Override - protected void onStart() { - } - - @Override - protected void onClose() { - if (transportProvider != null) { - transportProvider.closeGracefully().block(); - } - if (appContext != null) { - appContext.close(); - } - if (tomcat != null) { - try { - tomcat.stop(); - tomcat.destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java deleted file mode 100644 index cce36d191..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ -package io.modelcontextprotocol.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; -import io.modelcontextprotocol.spec.McpSchema; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.ServerResponse; - -import static org.assertj.core.api.Assertions.assertThat; - -class WebMvcSseCustomContextPathTests { - - private static final String CUSTOM_CONTEXT_PATH = "/app/1"; - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private WebMvcSseServerTransportProvider mcpServerTransportProvider; - - McpClient.SyncSpec clientBuilder; - - private TomcatTestUtil.TomcatServer tomcatServer; - - @BeforeEach - public void before() { - - tomcatServer = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class); - - try { - tomcatServer.tomcat().start(); - assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); - } - catch (Exception e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) - .sseEndpoint(CUSTOM_CONTEXT_PATH + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) - .build(); - - clientBuilder = McpClient.sync(clientTransport); - - mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); - } - - @AfterEach - public void after() { - if (mcpServerTransportProvider != null) { - mcpServerTransportProvider.closeGracefully().block(); - } - if (tomcatServer.appContext() != null) { - tomcatServer.appContext().close(); - } - if (tomcatServer.tomcat() != null) { - try { - tomcatServer.tomcat().stop(); - tomcatServer.tomcat().destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - - @Test - void testCustomContextPath() { - McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); - assertThat(client.initialize()).isNotNull(); - } - - @Configuration - @EnableWebMvc - static class TestConfig { - - @Bean - public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - - return WebMvcSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .baseUrl(CUSTOM_CONTEXT_PATH) - .messageEndpoint(MESSAGE_ENDPOINT) - .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) - .build(); - // return new WebMvcSseServerTransportProvider(new ObjectMapper(), - // CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT, - // WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT); - } - - @Bean - public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { - return transportProvider.getRouterFunction(); - } - - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java deleted file mode 100644 index 8cb2973ed..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ -package io.modelcontextprotocol.server; - -import static org.assertj.core.api.Assertions.assertThat; - -import java.time.Duration; - -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Timeout; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.ServerRequest; -import org.springframework.web.servlet.function.ServerResponse; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.server.McpServer.AsyncSpecification; -import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; -import reactor.core.scheduler.Schedulers; - -@Timeout(15) -class WebMvcSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private WebMvcSseServerTransportProvider mcpServerTransportProvider; - - static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r, tc) -> { - tc.put("important", "value"); - return tc; - }; - - @Override - protected void prepareClients(int port, String mcpEndpoint) { - - clientBuilders.put("httpclient", - McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + port).build()) - .requestTimeout(Duration.ofHours(10))); - - clientBuilders.put("webflux", McpClient - .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + port)).build()) - .requestTimeout(Duration.ofHours(10))); - } - - @Configuration - @EnableWebMvc - static class TestConfig { - - @Bean - public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return WebMvcSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(MESSAGE_ENDPOINT) - .contextExtractor(TEST_CONTEXT_EXTRACTOR) - .build(); - } - - @Bean - public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { - return transportProvider.getRouterFunction(); - } - - } - - private TomcatTestUtil.TomcatServer tomcatServer; - - @BeforeEach - public void before() { - - tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); - - try { - tomcatServer.tomcat().start(); - assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); - } - catch (Exception e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - prepareClients(PORT, MESSAGE_ENDPOINT); - - // Get the transport from Spring context - mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); - - } - - @AfterEach - public void after() { - reactor.netty.http.HttpResources.disposeLoopsAndConnections(); - if (mcpServerTransportProvider != null) { - mcpServerTransportProvider.closeGracefully().block(); - } - Schedulers.shutdownNow(); - if (tomcatServer.appContext() != null) { - tomcatServer.appContext().close(); - } - if (tomcatServer.tomcat() != null) { - try { - tomcatServer.tomcat().stop(); - tomcatServer.tomcat().destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - - @Override - protected AsyncSpecification prepareAsyncServerBuilder() { - return McpServer.async(mcpServerTransportProvider); - } - - @Override - protected SingleSessionSyncSpecification prepareSyncServerBuilder() { - return McpServer.sync(mcpServerTransportProvider); - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java deleted file mode 100644 index 101a067ad..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; -import org.apache.catalina.Context; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.Timeout; - -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; -import org.springframework.web.servlet.DispatcherServlet; -import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.ServerResponse; - -@Timeout(15) -class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerTests { - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private static final int PORT = TestUtil.findAvailablePort(); - - private Tomcat tomcat; - - private WebMvcSseServerTransportProvider transportProvider; - - @Configuration - @EnableWebMvc - static class TestConfig { - - @Bean - public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return WebMvcSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(MESSAGE_ENDPOINT) - .build(); - } - - @Bean - public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { - return transportProvider.getRouterFunction(); - } - - } - - private AnnotationConfigWebApplicationContext appContext; - - @Override - protected McpServer.SyncSpecification prepareSyncServerBuilder() { - return McpServer.sync(createMcpTransportProvider()); - } - - private WebMvcSseServerTransportProvider createMcpTransportProvider() { - // Set up Tomcat first - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - // Set Tomcat base directory to java.io.tmpdir to avoid permission issues - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - // Use the same directory for document base - Context context = tomcat.addContext("", baseDir); - - // Create and configure Spring WebMvc context - appContext = new AnnotationConfigWebApplicationContext(); - appContext.register(TestConfig.class); - appContext.setServletContext(context.getServletContext()); - appContext.refresh(); - - // Get the transport from Spring context - transportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); - - // Create DispatcherServlet with our Spring context - DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); - - // Add servlet to Tomcat and get the wrapper - var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); - wrapper.setLoadOnStartup(1); - context.addServletMappingDecoded("/*", "dispatcherServlet"); - - try { - tomcat.start(); - tomcat.getConnector(); // Create and start the connector - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - return transportProvider; - } - - @Override - protected void onStart() { - } - - @Override - protected void onClose() { - if (transportProvider != null) { - transportProvider.closeGracefully().block(); - } - if (appContext != null) { - appContext.close(); - } - if (tomcat != null) { - try { - tomcat.stop(); - tomcat.destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java deleted file mode 100644 index c7c1e710d..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ -package io.modelcontextprotocol.server; - -import static org.assertj.core.api.Assertions.assertThat; - -import java.time.Duration; - -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Timeout; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.ServerResponse; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.AbstractStatelessIntegrationTests; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; -import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; -import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport; -import reactor.core.scheduler.Schedulers; - -@Timeout(15) -class WebMvcStatelessIntegrationTests extends AbstractStatelessIntegrationTests { - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private WebMvcStatelessServerTransport mcpServerTransport; - - @Configuration - @EnableWebMvc - static class TestConfig { - - @Bean - public WebMvcStatelessServerTransport webMvcStatelessServerTransport() { - - return WebMvcStatelessServerTransport.builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(MESSAGE_ENDPOINT) - .build(); - - } - - @Bean - public RouterFunction routerFunction(WebMvcStatelessServerTransport statelessServerTransport) { - return statelessServerTransport.getRouterFunction(); - } - - } - - private TomcatTestUtil.TomcatServer tomcatServer; - - @Override - protected StatelessAsyncSpecification prepareAsyncServerBuilder() { - return McpServer.async(this.mcpServerTransport); - } - - @Override - protected StatelessSyncSpecification prepareSyncServerBuilder() { - return McpServer.sync(this.mcpServerTransport); - } - - @Override - protected void prepareClients(int port, String mcpEndpoint) { - - clientBuilders.put("httpclient", McpClient - .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) - .requestTimeout(Duration.ofHours(10))); - - clientBuilders.put("webflux", - McpClient - .sync(WebClientStreamableHttpTransport - .builder(WebClient.builder().baseUrl("http://localhost:" + port)) - .endpoint(mcpEndpoint) - .build()) - .requestTimeout(Duration.ofHours(10))); - } - - @BeforeEach - public void before() { - - tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); - - try { - tomcatServer.tomcat().start(); - assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); - } - catch (Exception e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - prepareClients(PORT, MESSAGE_ENDPOINT); - - // Get the transport from Spring context - this.mcpServerTransport = tomcatServer.appContext().getBean(WebMvcStatelessServerTransport.class); - - } - - @AfterEach - public void after() { - reactor.netty.http.HttpResources.disposeLoopsAndConnections(); - if (this.mcpServerTransport != null) { - this.mcpServerTransport.closeGracefully().block(); - } - Schedulers.shutdownNow(); - if (tomcatServer.appContext() != null) { - tomcatServer.appContext().close(); - } - if (tomcatServer.tomcat() != null) { - try { - tomcatServer.tomcat().stop(); - tomcatServer.tomcat().destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java deleted file mode 100644 index 2f4c651fd..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ -package io.modelcontextprotocol.server; - -import static org.assertj.core.api.Assertions.assertThat; - -import java.time.Duration; - -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Timeout; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.servlet.function.ServerRequest; -import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.ServerResponse; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.server.McpServer.AsyncSpecification; -import io.modelcontextprotocol.server.McpServer.SyncSpecification; -import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; -import reactor.core.scheduler.Schedulers; - -@Timeout(15) -class WebMvcStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { - - private static final int PORT = TestUtil.findAvailablePort(); - - private static final String MESSAGE_ENDPOINT = "/mcp/message"; - - private WebMvcStreamableServerTransportProvider mcpServerTransportProvider; - - static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r, tc) -> { - tc.put("important", "value"); - return tc; - }; - - @Configuration - @EnableWebMvc - static class TestConfig { - - @Bean - public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransportProvider() { - return WebMvcStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .contextExtractor(TEST_CONTEXT_EXTRACTOR) - .mcpEndpoint(MESSAGE_ENDPOINT) - .build(); - } - - @Bean - public RouterFunction routerFunction( - WebMvcStreamableServerTransportProvider transportProvider) { - return transportProvider.getRouterFunction(); - } - - } - - private TomcatTestUtil.TomcatServer tomcatServer; - - @BeforeEach - public void before() { - - tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); - - try { - tomcatServer.tomcat().start(); - assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); - } - catch (Exception e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - clientBuilders - .put("httpclient", - McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) - .endpoint(MESSAGE_ENDPOINT) - .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); - - clientBuilders.put("webflux", - McpClient.sync(WebClientStreamableHttpTransport - .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) - .endpoint(MESSAGE_ENDPOINT) - .build())); - - // Get the transport from Spring context - this.mcpServerTransportProvider = tomcatServer.appContext() - .getBean(WebMvcStreamableServerTransportProvider.class); - - } - - @Override - protected AsyncSpecification prepareAsyncServerBuilder() { - return McpServer.async(this.mcpServerTransportProvider); - } - - @Override - protected SyncSpecification prepareSyncServerBuilder() { - return McpServer.sync(this.mcpServerTransportProvider); - } - - @AfterEach - public void after() { - reactor.netty.http.HttpResources.disposeLoopsAndConnections(); - if (mcpServerTransportProvider != null) { - mcpServerTransportProvider.closeGracefully().block(); - } - Schedulers.shutdownNow(); - if (tomcatServer.appContext() != null) { - tomcatServer.appContext().close(); - } - if (tomcatServer.tomcat() != null) { - try { - tomcatServer.tomcat().stop(); - tomcatServer.tomcat().destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - - @Override - protected void prepareClients(int port, String mcpEndpoint) { - - clientBuilders.put("httpclient", McpClient - .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) - .requestTimeout(Duration.ofHours(10))); - - clientBuilders.put("webflux", - McpClient - .sync(WebClientStreamableHttpTransport - .builder(WebClient.builder().baseUrl("http://localhost:" + port)) - .endpoint(mcpEndpoint) - .build()) - .requestTimeout(Duration.ofHours(10))); - } - -} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml deleted file mode 100644 index d4ccbc173..000000000 --- a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml +++ /dev/null @@ -1,27 +0,0 @@ - - - - - - - %d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n - - - - - - - - - - - - - - - - - - - - diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index 563f60de9..531c0bbc5 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -1,12 +1,12 @@ + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 io.modelcontextprotocol.sdk mcp-parent - 0.12.0-SNAPSHOT + 1.1.0-SNAPSHOT mcp-test jar @@ -23,8 +23,8 @@ io.modelcontextprotocol.sdk - mcp - 0.12.0-SNAPSHOT + mcp-core + 1.1.0-SNAPSHOT @@ -33,12 +33,6 @@ ${slf4j-api.version} - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - - io.projectreactor reactor-core @@ -97,8 +91,91 @@ ${json-unit-assertj.version} + + + org.springframework + spring-webmvc + ${springframework.version} + test + + + + org.springframework + spring-context + ${springframework.version} + test + + + + org.springframework + spring-test + ${springframework.version} + test + + + + io.projectreactor.netty + reactor-netty-http + test + + + + org.apache.tomcat.embed + tomcat-embed-core + ${tomcat.version} + test + + + + org.apache.tomcat.embed + tomcat-embed-websocket + ${tomcat.version} + test + + + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + + + + jakarta.servlet + jakarta.servlet-api + ${jakarta.servlet.version} + test + + + + jackson3 + + true + + + + io.modelcontextprotocol.sdk + mcp-json-jackson3 + 1.1.0-SNAPSHOT + test + + + + + jackson2 + + + io.modelcontextprotocol.sdk + mcp-json-jackson2 + 1.1.0-SNAPSHOT + test + + + + + \ No newline at end of file diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java index 5246c1e2d..270bc4308 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -4,14 +4,6 @@ package io.modelcontextprotocol; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertWith; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; - import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -29,15 +21,12 @@ import java.util.function.Function; import java.util.stream.Collectors; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; - import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.McpSyncServer; import io.modelcontextprotocol.server.McpSyncServerExchange; -import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -56,12 +45,25 @@ import io.modelcontextprotocol.spec.McpSchema.Role; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.util.Utils; import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + public abstract class AbstractMcpClientServerIntegrationTests { protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); @@ -73,7 +75,7 @@ public abstract class AbstractMcpClientServerIntegrationTests { abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void simple(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -81,7 +83,6 @@ void simple(String clientType) { var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .requestTimeout(Duration.ofSeconds(1000)) .build(); - try ( // Create client without sampling capabilities var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) @@ -91,23 +92,25 @@ void simple(String clientType) { assertThat(client.initialize()).isNotNull(); } - server.closeGracefully(); + finally { + server.closeGracefully().block(); + } } // --------------------------------------- // Sampling Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateMessageWithoutSamplingCapabilities(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { - exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - return Mono.just(mock(CallToolResult.class)); + return exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)) + .then(Mono.just(mock(CallToolResult.class))); }) .build(); @@ -128,11 +131,13 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { .hasMessage("Client must be configured with sampling capabilities"); } } - server.closeGracefully(); + finally { + server.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateMessageSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -145,13 +150,14 @@ void testCreateMessageSuccess(String clientType) { CreateMessageResult.StopReason.STOP_SEQUENCE); }; - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); AtomicReference samplingResult = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -195,11 +201,13 @@ void testCreateMessageSuccess(String clientType) { assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); }); } - mcpServer.close(); + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { // Client @@ -219,20 +227,16 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - // Server - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); AtomicReference samplingResult = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -256,30 +260,35 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr .requestTimeout(Duration.ofSeconds(4)) .tools(tool) .build(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); - assertWith(samplingResult.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }); - - mcpClient.close(); - mcpServer.close(); + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { var clientBuilder = clientBuilders.get(clientType); @@ -297,16 +306,12 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) .build(); - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var createMessageRequest = McpSchema.CreateMessageRequest.builder() @@ -329,28 +334,34 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("1000ms"); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - mcpClient.close(); - mcpServer.close(); + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("1000ms"); + } + finally { + mcpServer.closeGracefully().block(); + } } // --------------------------------------- // Elicitation Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithoutElicitationCapabilities(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> exchange.createElicitation(mock(ElicitRequest.class)) .then(Mono.just(mock(CallToolResult.class)))) .build(); @@ -370,11 +381,13 @@ void testCreateElicitationWithoutElicitationCapabilities(String clientType) { .hasMessage("Client must be configured with elicitation capabilities"); } } - server.closeGracefully().block(); + finally { + server.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateElicitationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -387,11 +400,12 @@ void testCreateElicitationSuccess(String clientType) { Map.of("message", request.message())); }; - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var elicitationRequest = McpSchema.ElicitRequest.builder() @@ -425,11 +439,13 @@ void testCreateElicitationSuccess(String clientType) { assertThat(response).isNotNull(); assertThat(response).isEqualTo(callResponse); } - mcpServer.closeGracefully().block(); + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -440,18 +456,14 @@ void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) + CallToolResult callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) .build(); - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - AtomicReference resultRef = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var elicitationRequest = McpSchema.ElicitRequest.builder() @@ -471,25 +483,31 @@ void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - assertWith(resultRef.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + assertWith(resultRef.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }); + } + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCreateElicitationWithRequestTimeoutFail(String clientType) { var latch = new CountDownLatch(1); @@ -511,17 +529,12 @@ void testCreateElicitationWithRequestTimeoutFail(String clientType) { return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); }; - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + CallToolResult callResponse = CallToolResult.builder().addContent(new TextContent("CALL RESPONSE")).build(); AtomicReference resultRef = new AtomicReference<>(); McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { var elicitationRequest = ElicitRequest.builder() @@ -541,25 +554,31 @@ void testCreateElicitationWithRequestTimeoutFail(String clientType) { .tools(tool) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("within 1000ms"); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); - ElicitResult elicitResult = resultRef.get(); - assertThat(elicitResult).isNull(); + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); + ElicitResult elicitResult = resultRef.get(); + assertThat(elicitResult).isNull(); + } + finally { + mcpServer.closeGracefully().block(); + } } // --------------------------------------- // Roots Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -601,18 +620,19 @@ void testRootsSuccess(String clientType) { assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsWithoutCapability(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { exchange.listRoots(); // try to list roots @@ -639,12 +659,13 @@ void testRootsWithoutCapability(String clientType) { assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); } } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsNotificationWithEmptyRootsList(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -668,12 +689,13 @@ void testRootsNotificationWithEmptyRootsList(String clientType) { assertThat(rootsRef.get()).isEmpty(); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsWithMultipleHandlers(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -701,12 +723,13 @@ void testRootsWithMultipleHandlers(String clientType) { assertThat(rootsRef2.get()).containsAll(roots); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testRootsServerCloseWithActiveSubscription(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -732,31 +755,26 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { assertThat(rootsRef.get()).containsAll(roots); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolCallSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); var responseBodyIsNullOrBlank = new AtomicBoolean(false); - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=importantValue")) + .build(); McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { try { @@ -793,12 +811,13 @@ void testToolCallSuccess(String clientType) { assertThat(responseBodyIsNullOrBlank.get()).isFalse(); assertThat(response).isNotNull().isEqualTo(callResponse); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -809,7 +828,7 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { .tool(Tool.builder() .name("tool1") .description("tool1 description") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> { // We trigger a timeout on blocking read, raising an exception @@ -824,18 +843,18 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { assertThat(initResult).isNotNull(); // We expect the tool call to fail immediately with the exception raised by - // the offending tool - // instead of getting back a timeout. + // the offending tool instead of getting back a timeout. assertThatExceptionOfType(McpError.class) .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) .withMessageContaining("Timeout on blocking read"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolCallSuccessWithTranportContextExtraction(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -844,10 +863,11 @@ void testToolCallSuccessWithTranportContextExtraction(String clientType) { var transportContextIsEmpty = new AtomicBoolean(false); var responseBodyIsNullOrBlank = new AtomicBoolean(false); - var expectedCallResponse = new McpSchema.CallToolResult( - List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=value")), null); + var expectedCallResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=value")) + .build(); McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { McpTransportContext transportContext = exchange.transportContext(); @@ -863,8 +883,9 @@ void testToolCallSuccessWithTranportContextExtraction(String clientType) { e.printStackTrace(); } - return new McpSchema.CallToolResult( - List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)), null); + return McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)) + .build(); }) .build(); @@ -886,19 +907,23 @@ void testToolCallSuccessWithTranportContextExtraction(String clientType) { assertThat(responseBodyIsNullOrBlank.get()).isFalse(); assertThat(response).isNotNull().isEqualTo(expectedCallResponse); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolListChangeHandlingSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + var callResponse = McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("CALL RESPONSE")) + .build(); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { // perform a blocking call to a remote service try { @@ -967,7 +992,7 @@ void testToolListChangeHandlingSuccess(String clientType) { .tool(Tool.builder() .name("tool2") .description("tool2 description") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> callResponse) .build(); @@ -978,12 +1003,13 @@ void testToolListChangeHandlingSuccess(String clientType) { assertThat(toolsRef.get()).containsAll(List.of(tool2.tool())); }); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testInitialize(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -995,15 +1021,16 @@ void testInitialize(String clientType) { InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } // --------------------------------------- // Logging Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testLoggingNotification(String clientType) throws InterruptedException { int expectedNotificationsCount = 3; CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); @@ -1017,7 +1044,7 @@ void testLoggingNotification(String clientType) throws InterruptedException { .tool(Tool.builder() .name("logging-test") .description("Test logging notifications") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> { @@ -1054,7 +1081,10 @@ void testLoggingNotification(String clientType) throws InterruptedException { .logger("test-logger") .data("Another error message") .build())) - .thenReturn(new CallToolResult("Logging test completed", false)); + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Logging test completed"))) + .isError(false) + .build()); //@formatter:on }) .build(); @@ -1107,14 +1137,16 @@ void testLoggingNotification(String clientType) throws InterruptedException { assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); } - mcpServer.close(); + finally { + mcpServer.closeGracefully().block(); + } } // --------------------------------------- // Progress Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testProgressNotification(String clientType) throws InterruptedException { int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress // token @@ -1129,7 +1161,7 @@ void testProgressNotification(String clientType) throws InterruptedException { .tool(McpSchema.Tool.builder() .name("progress-test") .description("Test progress notifications") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> { @@ -1147,7 +1179,10 @@ void testProgressNotification(String clientType) throws InterruptedException { 0.0, 1.0, "Another processing started"))) .then(exchange.progressNotification( new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) - .thenReturn(new CallToolResult(("Progress test completed"), false)); + .thenReturn(CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Progress test completed"))) + .isError(false) + .build()); }) .build(); @@ -1212,7 +1247,7 @@ void testProgressNotification(String clientType) throws InterruptedException { assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); } finally { - mcpServer.close(); + mcpServer.closeGracefully().block(); } } @@ -1220,7 +1255,7 @@ void testProgressNotification(String clientType) throws InterruptedException { // Completion Tests // --------------------------------------- @ParameterizedTest(name = "{0} : Completion call") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testCompletionShouldReturnExpectedSuggestions(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1242,7 +1277,8 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { List.of(new PromptArgument("language", "Language", "string", false))), (mcpSyncServerExchange, getPromptRequest) -> null)) .completions(new McpServerFeatures.SyncCompletionSpecification( - new McpSchema.PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + new McpSchema.PromptReference(PromptReference.TYPE, "code_review", "Code review"), + completionHandler)) .build(); try (var mcpClient = clientBuilder.build()) { @@ -1251,7 +1287,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { assertThat(initResult).isNotNull(); CompleteRequest request = new CompleteRequest( - new PromptReference("ref/prompt", "code_review", "Code review"), + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), new CompleteRequest.CompleteArgument("language", "py")); CompleteResult result = mcpClient.completeCompletion(request); @@ -1260,17 +1296,18 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); - assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + assertThat(samplingRequest.get().ref().type()).isEqualTo(PromptReference.TYPE); + } + finally { + mcpServer.closeGracefully(); } - - mcpServer.close(); } // --------------------------------------- // Ping Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testPingSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1282,7 +1319,7 @@ void testPingSuccess(String clientType) { .tool(Tool.builder() .name("ping-async-test") .description("Test ping async behavior") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> { @@ -1299,7 +1336,10 @@ void testPingSuccess(String clientType) { assertThat(result).isNotNull(); }).then(Mono.fromCallable(() -> { executionOrder.set(executionOrder.get() + "3"); - return new CallToolResult("Async ping test completed", false); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Async ping test completed"))) + .isError(false) + .build(); })); }) .build(); @@ -1324,15 +1364,16 @@ void testPingSuccess(String clientType) { // Verify execution order assertThat(executionOrder.get()).isEqualTo("123"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully().block(); + } } // --------------------------------------- // Tool Structured Output Schema Tests // --------------------------------------- @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputValidationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1384,7 +1425,7 @@ void testStructuredOutputValidationSuccess(String clientType) { // In WebMVC, structured content is returned properly if (response.structuredContent() != null) { - assertThat(response.structuredContent()).containsEntry("result", 5.0) + assertThat((Map) response.structuredContent()).containsEntry("result", 5.0) .containsEntry("operation", "2 + 3") .containsEntry("timestamp", "2024-01-01T10:00:00Z"); } @@ -1400,12 +1441,129 @@ void testStructuredOutputValidationSuccess(String clientType) { .isEqualTo(json(""" {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } - mcpServer.close(); + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that returns an error result + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @ValueSource(strings = { "httpclient" }) void testStructuredOutputValidationFailure(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1454,12 +1612,13 @@ void testStructuredOutputValidationFailure(String clientType) { String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); assertThat(errorMessage).contains("Validation failed"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputMissingStructuredContent(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1504,12 +1663,13 @@ void testStructuredOutputMissingStructuredContent(String clientType) { assertThat(errorMessage).isEqualTo( "Response missing structured content which is expected when calling tool with non-empty outputSchema"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputRuntimeToolAddition(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -1581,8 +1741,9 @@ void testStructuredOutputRuntimeToolAddition(String clientType) { .isEqualTo(json(""" {"count":3,"message":"Dynamic execution"}""")); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } private double evaluateExpression(String expression) { diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java index 618247d61..7755ce456 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java @@ -4,12 +4,6 @@ package io.modelcontextprotocol; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.awaitility.Awaitility.await; - import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -20,9 +14,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; - import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; @@ -33,10 +24,21 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.awaitility.Awaitility.await; + public abstract class AbstractStatelessIntegrationTests { protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); @@ -48,7 +50,7 @@ public abstract class AbstractStatelessIntegrationTests { abstract protected StatelessSyncSpecification prepareSyncServerBuilder(); @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void simple(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -66,31 +68,27 @@ void simple(String clientType) { assertThat(client.initialize()).isNotNull(); } - server.closeGracefully(); + finally { + server.closeGracefully().block(); + } } // --------------------------------------- // Tools Tests // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolCallSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + var callResponse = McpSchema.CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("CALL RESPONSE"))) + .isError(false) + .build(); McpStatelessServerFeatures.SyncToolSpecification tool1 = McpStatelessServerFeatures.SyncToolSpecification .builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((ctx, request) -> { try { @@ -126,12 +124,13 @@ void testToolCallSuccess(String clientType) { assertThat(response).isNotNull().isEqualTo(callResponse); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully().block(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -142,7 +141,7 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { .tool(Tool.builder() .name("tool1") .description("tool1 description") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((context, request) -> { // We trigger a timeout on blocking read, raising an exception @@ -163,20 +162,24 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) .withMessageContaining("Timeout on blocking read"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testToolListChangeHandlingSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + var callResponse = McpSchema.CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("CALL RESPONSE"))) + .isError(false) + .build(); McpStatelessServerFeatures.SyncToolSpecification tool1 = McpStatelessServerFeatures.SyncToolSpecification .builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((ctx, request) -> { // perform a blocking call to a remote service try { @@ -237,19 +240,20 @@ void testToolListChangeHandlingSuccess(String clientType) { .tool(Tool.builder() .name("tool2") .description("tool2 description") - .inputSchema(emptyJsonSchema) + .inputSchema(EMPTY_JSON_SCHEMA) .build()) .callHandler((exchange, request) -> callResponse) .build(); mcpServer.addTool(tool2); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testInitialize(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -261,16 +265,16 @@ void testInitialize(String clientType) { InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } // --------------------------------------- // Tool Structured Output Schema Tests // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputValidationSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -323,7 +327,7 @@ void testStructuredOutputValidationSuccess(String clientType) { // In WebMVC, structured content is returned properly if (response.structuredContent() != null) { - assertThat(response.structuredContent()).containsEntry("result", 5.0) + assertThat((Map) response.structuredContent()).containsEntry("result", 5.0) .containsEntry("operation", "2 + 3") .containsEntry("timestamp", "2024-01-01T10:00:00Z"); } @@ -339,12 +343,131 @@ void testStructuredOutputValidationSuccess(String clientType) { .isEqualTo(json(""" {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); - mcpServer.close(); + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that throws an exception to simulate an error + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") void testStructuredOutputValidationFailure(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -394,12 +517,13 @@ void testStructuredOutputValidationFailure(String clientType) { String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); assertThat(errorMessage).contains("Validation failed"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputMissingStructuredContent(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -444,12 +568,13 @@ void testStructuredOutputMissingStructuredContent(String clientType) { assertThat(errorMessage).isEqualTo( "Response missing structured content which is expected when calling tool with non-empty outputSchema"); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) + @MethodSource("clientsForTesting") void testStructuredOutputRuntimeToolAddition(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -521,8 +646,9 @@ void testStructuredOutputRuntimeToolAddition(String clientType) { .isEqualTo(json(""" {"count":3,"message":"Dynamic execution"}""")); } - - mcpServer.close(); + finally { + mcpServer.closeGracefully(); + } } private double evaluateExpression(String expression) { diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index ed34ebff6..8fb8093ac 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -10,6 +10,7 @@ import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.McpTransportSessionClosedException; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,12 +45,12 @@ public abstract class AbstractMcpAsyncClientResiliencyTests { private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class); static Network network = Network.newNetwork(); - static String host = "http://localhost:3001"; - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + public static String host = "http://localhost:3001"; + @SuppressWarnings("resource") - static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js streamableHttp") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withNetwork(network) .withNetworkAliases("everything-server") @@ -134,10 +135,13 @@ McpAsyncClient client(McpClientTransport transport, Function client = new AtomicReference<>(); assertThatCode(() -> { + // Do not advertise roots. Otherwise, the server will list roots during + // initialization. The client responds asynchronously, and there might be a + // rest condition in tests where we disconnect right after initialization. McpClient.AsyncSpec builder = McpClient.async(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + .capabilities(McpSchema.ClientCapabilities.builder().build()); builder = customizer.apply(builder); client.set(builder.build()); }).doesNotThrowAnyException(); @@ -217,9 +221,10 @@ void testSessionClose() { // In case of Streamable HTTP this call should issue a HTTP DELETE request // invalidating the session StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); - // The next use should immediately re-initialize with no issue and send the - // request without any broken connections. - StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + // The next tries to use the closed session and fails + StepVerifier.create(mcpAsyncClient.ping()) + .expectErrorMatches(err -> err.getCause() instanceof McpTransportSessionClosedException) + .verify(); }); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index ea3739da5..bee8f4f16 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.client; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -22,8 +23,6 @@ import java.util.function.Consumer; import java.util.function.Function; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -67,18 +66,12 @@ public abstract class AbstractMcpAsyncClientTests { abstract protected McpClientTransport createMcpTransport(); - protected void onStart() { - } - - protected void onClose() { - } - protected Duration getRequestTimeout() { return Duration.ofSeconds(14); } protected Duration getInitializationTimeout() { - return Duration.ofSeconds(2); + return Duration.ofSeconds(20); } McpAsyncClient client(McpClientTransport transport) { @@ -117,16 +110,6 @@ void withClient(McpClientTransport transport, Function void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, String action) { withClient(createMcpTransport(), mcpAsyncClient -> { @@ -192,7 +175,12 @@ void testListAllToolsReturnsImmutableList() { .consumeNextWith(result -> { assertThat(result.tools()).isNotNull(); // Verify that the returned list is immutable - assertThatThrownBy(() -> result.tools().add(new Tool("test", "test", "{\"type\":\"object\"}"))) + assertThatThrownBy(() -> result.tools() + .add(Tool.builder() + .name("test") + .title("test") + .inputSchema(JSON_MAPPER, "{\"type\":\"object\"}") + .build())) .isInstanceOf(UnsupportedOperationException.class); }) .verifyComplete(); @@ -514,57 +502,64 @@ void testRemoveNonExistentRoot() { @Test void testReadResource() { + AtomicInteger resourceCount = new AtomicInteger(); withClient(createMcpTransport(), client -> { Flux resources = client.initialize() .then(client.listResources(null)) - .flatMapMany(r -> Flux.fromIterable(r.resources())) + .flatMapMany(r -> { + List l = r.resources(); + resourceCount.set(l.size()); + return Flux.fromIterable(l); + }) .flatMap(r -> client.readResource(r)); - StepVerifier.create(resources).recordWith(ArrayList::new).consumeRecordedWith(readResourceResults -> { - - for (ReadResourceResult result : readResourceResults) { - - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull().isNotEmpty(); - - // Validate each content item - for (ResourceContents content : result.contents()) { - assertThat(content).isNotNull(); - assertThat(content.uri()).isNotNull().isNotEmpty(); - assertThat(content.mimeType()).isNotNull().isNotEmpty(); - - // Validate content based on its type with more comprehensive - // checks - switch (content.mimeType()) { - case "text/plain" -> { - TextResourceContents textContent = assertInstanceOf(TextResourceContents.class, - content); - assertThat(textContent.text()).isNotNull().isNotEmpty(); - assertThat(textContent.uri()).isNotEmpty(); - } - case "application/octet-stream" -> { - BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class, - content); - assertThat(blobContent.blob()).isNotNull().isNotEmpty(); - assertThat(blobContent.uri()).isNotNull().isNotEmpty(); - // Validate base64 encoding format - assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$"); - } - default -> { - - // Still validate basic properties - if (content instanceof TextResourceContents textContent) { - assertThat(textContent.text()).isNotNull(); + StepVerifier.create(resources) + .recordWith(ArrayList::new) + .thenConsumeWhile(res -> true) + .consumeRecordedWith(readResourceResults -> { + assertThat(readResourceResults.size()).isEqualTo(resourceCount.get()); + for (ReadResourceResult result : readResourceResults) { + + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull().isNotEmpty(); + + // Validate each content item + for (ResourceContents content : result.contents()) { + assertThat(content).isNotNull(); + assertThat(content.uri()).isNotNull().isNotEmpty(); + assertThat(content.mimeType()).isNotNull().isNotEmpty(); + + // Validate content based on its type with more comprehensive + // checks + switch (content.mimeType()) { + case "text/plain" -> { + TextResourceContents textContent = assertInstanceOf(TextResourceContents.class, + content); + assertThat(textContent.text()).isNotNull().isNotEmpty(); + assertThat(textContent.uri()).isNotEmpty(); + } + case "application/octet-stream" -> { + BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class, + content); + assertThat(blobContent.blob()).isNotNull().isNotEmpty(); + assertThat(blobContent.uri()).isNotNull().isNotEmpty(); + // Validate base64 encoding format + assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$"); } - else if (content instanceof BlobResourceContents blobContent) { - assertThat(blobContent.blob()).isNotNull(); + default -> { + + // Still validate basic properties + if (content instanceof TextResourceContents textContent) { + assertThat(textContent.text()).isNotNull(); + } + else if (content instanceof BlobResourceContents blobContent) { + assertThat(blobContent.blob()).isNotNull(); + } } } } } - } - }) - .expectNextCount(10) // Expect 10 elements + }) .verifyComplete(); }); } @@ -684,7 +679,7 @@ void testInitializeWithElicitationCapability() { @Test void testInitializeWithAllCapabilities() { var capabilities = ClientCapabilities.builder() - .experimental(Map.of("feature", "test")) + .experimental(Map.of("feature", Map.of("featureFlag", true))) .roots(true) .sampling() .build(); @@ -704,7 +699,6 @@ void testInitializeWithAllCapabilities() { assertThat(result.capabilities()).isNotNull(); }).verifyComplete()); } - // --------------------------------------- // Logging Tests // --------------------------------------- @@ -784,7 +778,7 @@ void testSampling() { if (!(content instanceof McpSchema.TextContent text)) return; - assertThat(text.text()).endsWith(response); // Prefixed + assertThat(text.text()).contains(response); }); // Verify sampling request parameters received in our callback diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 175a0107c..26d60568a 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -22,8 +22,6 @@ import java.util.function.Consumer; import java.util.function.Function; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -66,12 +64,6 @@ public abstract class AbstractMcpSyncClientTests { abstract protected McpClientTransport createMcpTransport(); - protected void onStart() { - } - - protected void onClose() { - } - protected Duration getRequestTimeout() { return Duration.ofSeconds(14); } @@ -114,17 +106,6 @@ void withClient(McpClientTransport transport, Function void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { @@ -554,11 +535,13 @@ void testNotificationHandlers() { AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); + AtomicBoolean resourcesUpdatedNotificationReceived = new AtomicBoolean(false); withClient(createMcpTransport(), builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) + .resourcesUpdateConsumer(resources -> resourcesUpdatedNotificationReceived.set(true)), client -> { assertThatCode(() -> { @@ -641,7 +624,7 @@ void testSampling() { if (!(content instanceof McpSchema.TextContent text)) return; - assertThat(text.text()).endsWith(response); // Prefixed + assertThat(text.text()).contains(response); }); // Verify sampling request parameters received in our callback diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 1e87d4420..9cd1191d1 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -7,7 +7,6 @@ import java.time.Duration; import java.util.List; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -26,6 +25,7 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -89,115 +89,95 @@ void testGracefulShutdown() { void testImmediateClose() { var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); + assertThatCode(mcpAsyncServer::close).doesNotThrowAnyException(); } // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - @Deprecated - void testAddTool() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - @Test void testAddToolCall() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() .tool(newTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build())).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } - @Test - @Deprecated - void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier - .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - @Test void testAddDuplicateToolCall() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .build())).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .build())).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test void testDuplicateToolCallDuringBuilding() { - Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", - emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) // Duplicate! + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .toolCall(duplicateTool, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) // Duplicate! .build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); } @Test void testDuplicateToolsInBatchListRegistration() { - Tool duplicateTool = new Tool("batch-list-tool", "Duplicate tool in batch list", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + List specs = List.of( McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(), McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build() // Duplicate! ); @@ -210,17 +190,23 @@ void testDuplicateToolsInBatchListRegistration() { @Test void testDuplicateToolsInBatchVarargsRegistration() { - Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(), McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .callHandler((exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build() // Duplicate! ) .build()).isInstanceOf(IllegalArgumentException.class) @@ -229,11 +215,17 @@ void testDuplicateToolsInBatchVarargsRegistration() { @Test void testRemoveTool() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) + .toolCall(too, + (exchange, request) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -247,20 +239,23 @@ void testRemoveNonexistentTool() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); - }); + StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test void testNotifyToolsListChanged() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool too = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) + .toolCall(too, + (exchange, args) -> Mono.just(CallToolResult.builder().content(List.of()).isError(false).build())) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); @@ -299,8 +294,13 @@ void testAddResource() { .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); @@ -317,7 +317,7 @@ void testAddResourceWithNullSpecification() { StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); + assertThat(error).isInstanceOf(IllegalArgumentException.class).hasMessage("Resource must not be null"); }); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -328,14 +328,19 @@ void testAddResourceWithoutCapability() { // Create a server without resource capabilities McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); }); } @@ -345,11 +350,191 @@ void testRemoveResourceWithoutCapability() { McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); }); } + @Test + void testListResources() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.listResources().collectList())) + .expectNextMatches(resources -> resources.size() == 1 && resources.get(0).uri().equals(TEST_RESOURCE_URI)) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier + .create(mcpAsyncServer.addResource(specification).then(mcpAsyncServer.removeResource(TEST_RESOURCE_URI))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + StepVerifier.create(mcpAsyncServer.removeResource("nonexistent://resource")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResourceTemplate(specification)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResourceTemplate(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("test://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("nonexistent://template/{id}")).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + // Note: Based on the current implementation, listResourceTemplates() returns + // Flux + // This appears to be a bug in the implementation that should return + // Flux + StepVerifier.create(mcpAsyncServer.listResourceTemplates().collectList()) + .expectNextMatches(resources -> resources.size() >= 0) // Just verify it + // doesn't error + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + // --------------------------------------- // Prompts Tests // --------------------------------------- @@ -371,7 +556,8 @@ void testAddPromptWithNullSpecification() { StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt specification must not be null"); + assertThat(error).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Prompt specification must not be null"); }); } @@ -386,7 +572,7 @@ void testAddPromptWithoutCapability() { .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); }); } @@ -397,7 +583,7 @@ void testRemovePromptWithoutCapability() { McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(error).isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); }); } @@ -427,10 +613,7 @@ void testRemoveNonexistentPrompt() { .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); - StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - }); + StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyComplete(); assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) .doesNotThrowAnyException(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 5d70ae4c0..eee5f1a4d 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -4,17 +4,8 @@ package io.modelcontextprotocol.server; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - import java.util.List; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -25,6 +16,14 @@ import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Test suite for the {@link McpSyncServer} that can be used with different @@ -77,14 +76,14 @@ void testConstructorWithInvalidArguments() { void testGracefulShutdown() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testImmediateClose() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::close).doesNotThrowAnyException(); } @Test @@ -93,111 +92,90 @@ void testGetAsyncServer() { assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- // Tools Tests // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @Test - @Deprecated - void testAddTool() { + void testAddToolCall() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) - .doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddToolCall() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) + Tool newTool = McpSchema.Tool.builder() + .name("new-tool") + .title("New test tool") + .inputSchema(EMPTY_JSON_SCHEMA) .build(); - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() .tool(newTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build())).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - @Deprecated - void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) - .isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testAddDuplicateToolCall() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Duplicate tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() + assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) - .build())).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build())).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testDuplicateToolCallDuringBuilding() { - Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", - emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("duplicate-build-toolcall") + .title("Duplicate toolcall during building") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) - .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) // Duplicate! + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .toolCall(duplicateTool, + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) // Duplicate! .build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); } @Test void testDuplicateToolsInBatchListRegistration() { - Tool duplicateTool = new Tool("batch-list-tool", "Duplicate tool in batch list", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-list-tool") + .title("Duplicate tool in batch list") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); List specs = List.of( McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(), McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler( + (exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build() // Duplicate! ); @@ -210,17 +188,22 @@ void testDuplicateToolsInBatchListRegistration() { @Test void testDuplicateToolsInBatchVarargsRegistration() { - Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); + Tool duplicateTool = McpSchema.Tool.builder() + .name("batch-varargs-tool") + .title("Duplicate tool in batch varargs") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler((exchange, request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(), McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) + .callHandler((exchange, + request) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build() // Duplicate! ) .build()).isInstanceOf(IllegalArgumentException.class) @@ -229,16 +212,20 @@ void testDuplicateToolsInBatchVarargsRegistration() { @Test void testRemoveTool() { - Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); + Tool tool = McpSchema.Tool.builder() + .name(TEST_TOOL_NAME) + .title("Test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(tool, (exchange, args) -> new CallToolResult(List.of(), false)) + .toolCall(tool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -247,19 +234,18 @@ void testRemoveNonexistentTool() { .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) - .hasMessage("Tool with name 'nonexistent-tool' not found"); + assertThatCode(() -> mcpSyncServer.removeTool("nonexistent-tool")).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testNotifyToolsListChanged() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyToolsListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -270,9 +256,9 @@ void testNotifyToolsListChanged() { void testNotifyResourcesListChanged() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyResourcesListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -283,7 +269,7 @@ void testNotifyResourcesUpdated() { .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) .doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -292,14 +278,19 @@ void testAddResource() { .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( resource, (exchange, req) -> new ReadResourceResult(List.of())); assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -309,31 +300,211 @@ void testAddResourceWithNullSpecification() { .build(); assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Resource must not be null"); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test void testAddResourceWithoutCapability() { var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatThrownBy(() -> serverWithoutResources.addResource(specification)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThatThrownBy(() -> serverWithoutResources.addResource(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); } @Test void testRemoveResourceWithoutCapability() { var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); + assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testListResources() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + List resources = mcpSyncServer.listResources(); + + assertThat(resources).hasSize(1); + assertThat(resources.get(0).uri()).isEqualTo(TEST_RESOURCE_URI); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = Resource.builder() + .uri(TEST_RESOURCE_URI) + .name("Test Resource") + .title("Test Resource") + .mimeType("text/plain") + .description("Test resource description") + .build(); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); + + mcpSyncServer.addResource(specification); + assertThatCode(() -> mcpSyncServer.removeResource(TEST_RESOURCE_URI)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResource() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource should complete successfully (no error) + // as per the new implementation that just logs a warning + assertThatCode(() -> mcpSyncServer.removeResource("nonexistent://resource")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResourceTemplate(specification)).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResourceTemplate(specification)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceTemplate() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("test://template/{id}")).doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResourceTemplate("test://template/{id}")) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate("nonexistent://template/{id}")) + .doesNotThrowAnyException(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testListResourceTemplates() { + McpSchema.ResourceTemplate template = McpSchema.ResourceTemplate.builder() + .uriTemplate("test://template/{id}") + .name("test-template") + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + List templates = mcpSyncServer.listResourceTemplates(); + + assertThat(templates).isNotNull(); + + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -344,9 +515,9 @@ void testRemoveResourceWithoutCapability() { void testNotifyPromptsListChanged() { var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::notifyPromptsListChanged).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -356,7 +527,7 @@ void testAddPromptWithNullSpecification() { .build(); assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) - .isInstanceOf(McpError.class) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Prompt specification must not be null"); } @@ -369,7 +540,8 @@ void testAddPromptWithoutCapability() { (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)) + .isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); } @@ -377,7 +549,8 @@ void testAddPromptWithoutCapability() { void testRemovePromptWithoutCapability() { var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)) + .isInstanceOf(IllegalStateException.class) .hasMessage("Server must be configured with prompt capabilities"); } @@ -395,7 +568,7 @@ void testRemovePrompt() { assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } @Test @@ -404,10 +577,9 @@ void testRemoveNonexistentPrompt() { .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + assertThatCode(() -> mcpSyncServer.removePrompt("nonexistent://template/{id}")).doesNotThrowAnyException(); - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(mcpSyncServer::closeGracefully).doesNotThrowAnyException(); } // --------------------------------------- @@ -428,9 +600,8 @@ void testRootsChangeHandlers() { } })) .build(); - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(singleConsumerServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test with multiple consumers @@ -446,7 +617,7 @@ void testRootsChangeHandlers() { .build(); assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(multipleConsumersServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test error handling @@ -457,14 +628,14 @@ void testRootsChangeHandlers() { .build(); assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(errorHandlingServer::closeGracefully).doesNotThrowAnyException(); onClose(); // Test without consumers var noConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); + assertThatCode(noConsumersServer::closeGracefully).doesNotThrowAnyException(); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java b/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java new file mode 100644 index 000000000..a72fc1db8 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/util/McpJsonMapperUtils.java @@ -0,0 +1,13 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.McpJsonMapper; + +public final class McpJsonMapperUtils { + + private McpJsonMapperUtils() { + } + + public static final McpJsonMapper JSON_MAPPER = McpJsonDefaults.getMapper(); + +} \ No newline at end of file diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java b/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java new file mode 100644 index 000000000..ce8755223 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java @@ -0,0 +1,15 @@ +package io.modelcontextprotocol.util; + +import io.modelcontextprotocol.spec.McpSchema; + +import java.util.Collections; + +public final class ToolsUtils { + + private ToolsUtils() { + } + + public static final McpSchema.JsonSchema EMPTY_JSON_SCHEMA = new McpSchema.JsonSchema("object", + Collections.emptyMap(), null, null, null, null); + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp-test/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java similarity index 70% rename from mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java rename to mcp-test/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java index 5484a63c2..4e74dac3e 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java @@ -9,40 +9,47 @@ import java.util.function.BiConsumer; import java.util.function.Function; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; -import io.modelcontextprotocol.spec.McpServerTransport; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; /** - * A mock implementation of the {@link McpClientTransport} and {@link McpServerTransport} - * interfaces. - * - * @deprecated not used. to be removed in the future. + * A mock implementation of the {@link McpClientTransport} interfaces. */ -@Deprecated -public class MockMcpTransport implements McpClientTransport, McpServerTransport { +public class MockMcpClientTransport implements McpClientTransport { private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); private final List sent = new ArrayList<>(); - private final BiConsumer interceptor; + private final BiConsumer interceptor; - public MockMcpTransport() { + private String protocolVersion = ProtocolVersions.MCP_2025_11_25; + + public MockMcpClientTransport() { this((t, msg) -> { }); } - public MockMcpTransport(BiConsumer interceptor) { + public MockMcpClientTransport(BiConsumer interceptor) { this.interceptor = interceptor; } + public MockMcpClientTransport withProtocolVersion(String protocolVersion) { + return this; + } + + @Override + public List protocolVersions() { + return List.of(protocolVersion); + } + public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { if (inbound.tryEmitNext(message).isFailure()) { throw new RuntimeException("Failed to process incoming message " + message); @@ -93,8 +100,8 @@ public Mono closeGracefully() { } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return new ObjectMapper().convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return McpJsonDefaults.getMapper().convertValue(data, typeRef); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java b/mcp-test/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java similarity index 80% rename from mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java rename to mcp-test/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java index 4be680e11..fac26596a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/MockMcpServerTransport.java @@ -8,8 +8,8 @@ import java.util.List; import java.util.function.BiConsumer; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; @@ -53,14 +53,22 @@ public McpSchema.JSONRPCMessage getLastSentMessage() { return !sent.isEmpty() ? sent.get(sent.size() - 1) : null; } + public void clearSentMessages() { + sent.clear(); + } + + public List getAllSentMessages() { + return new ArrayList<>(sent); + } + @Override public Mono closeGracefully() { return Mono.empty(); } @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return new ObjectMapper().convertValue(data, typeRef); + public T unmarshalFrom(Object data, TypeRef typeRef) { + return McpJsonDefaults.getMapper().convertValue(data, typeRef); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp-test/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java rename to mcp-test/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java similarity index 70% rename from mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java index aef2ab8dd..a29ca16db 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java @@ -4,41 +4,40 @@ package io.modelcontextprotocol.client; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; -import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; - @Timeout(15) public class HttpClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { - private String host = "http://localhost:3001"; + private static String host = "http://localhost:3001"; - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js streamableHttp") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override protected McpClientTransport createMcpTransport() { - return HttpClientStreamableHttpTransport.builder(host).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - public void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java new file mode 100644 index 000000000..ee5e5de05 --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.net.URI; +import java.util.Map; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpClientTransport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +@Timeout(15) +public class HttpClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + private final McpSyncHttpClientRequestCustomizer requestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientStreamableHttpTransport.builder(host).httpRequestCustomizer(requestCustomizer).build(); + } + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + @Test + void customizesRequests() { + var mcpTransportContext = McpTransportContext.create(Map.of("some-key", "some-value")); + withClient(createMcpTransport(), syncSpec -> syncSpec.transportContextProvider(() -> mcpTransportContext), + mcpSyncClient -> { + mcpSyncClient.initialize(); + + verify(requestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(URI.create(host + "/mcp")), + eq("{\"jsonrpc\":\"2.0\",\"method\":\"notifications/initialized\"}"), + eq(mcpTransportContext)); + }); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java similarity index 90% rename from mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java index 0a72b785d..e2037f415 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java @@ -28,7 +28,7 @@ import io.modelcontextprotocol.spec.McpSchema; import reactor.test.StepVerifier; -@Timeout(15) +@Timeout(20) public class HttpSseMcpAsyncClientLostConnectionTests { private static final Logger logger = LoggerFactory.getLogger(HttpSseMcpAsyncClientLostConnectionTests.class); @@ -36,10 +36,9 @@ public class HttpSseMcpAsyncClientLostConnectionTests { static Network network = Network.newNetwork(); static String host = "http://localhost:3001"; - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withNetwork(network) .withNetworkAliases("everything-server") @@ -98,10 +97,13 @@ McpAsyncClient client(McpClientTransport transport) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { + // Do not advertise roots. Otherwise, the server will list roots during + // initialization. The client responds asynchronously, and there might be a + // rest condition in tests where we disconnect right after initialization. McpClient.AsyncSpec builder = McpClient.async(transport) .requestTimeout(Duration.ofSeconds(14)) .initializationTimeout(Duration.ofSeconds(2)) - .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + .capabilities(McpSchema.ClientCapabilities.builder().build()); client.set(builder.build()); }).doesNotThrowAnyException(); @@ -119,7 +121,7 @@ void withClient(McpClientTransport transport, Consumer c) { } @Test - void testPingWithEaxctExceptionType() { + void testPingWithExactExceptionType() { withClient(HttpClientSseClientTransport.builder(host).build(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java similarity index 72% rename from mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index 6cb3f7b65..91a8b6c82 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.client; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -19,12 +21,11 @@ @Timeout(15) class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { - String host = "http://localhost:3004"; + private static String host = "http://localhost:3004"; - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); @@ -34,15 +35,15 @@ protected McpClientTransport createMcpTransport() { return HttpClientSseClientTransport.builder(host).build(); } - @Override - protected void onStart() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; } - @Override - protected void onClose() { + @AfterAll + static void stopContainer() { container.stop(); } diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java new file mode 100644 index 000000000..d903b3b3c --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.net.URI; +import java.util.Map; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpClientTransport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3003"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + private final McpSyncHttpClientRequestCustomizer requestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientSseClientTransport.builder(host).httpRequestCustomizer(requestCustomizer).build(); + } + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + @Test + void customizesRequests() { + var mcpTransportContext = McpTransportContext.create(Map.of("some-key", "some-value")); + withClient(createMcpTransport(), syncSpec -> syncSpec.transportContextProvider(() -> mcpTransportContext), + mcpSyncClient -> { + mcpSyncClient.initialize(); + + verify(requestCustomizer, atLeastOnce()).customize(any(), eq("GET"), eq(URI.create(host + "/sse")), + isNull(), eq(mcpTransportContext)); + }); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java similarity index 95% rename from mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index cab847512..47a229afd 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -4,16 +4,16 @@ package io.modelcontextprotocol.client; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.function.Function; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.MockMcpClientTransport; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; @@ -24,6 +24,7 @@ import reactor.core.publisher.Mono; import static io.modelcontextprotocol.spec.McpSchema.METHOD_INITIALIZE; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -42,7 +43,7 @@ private static MockMcpClientTransport initializationEnabledTransport() { private static MockMcpClientTransport initializationEnabledTransport( McpSchema.ServerCapabilities mockServerCapabilities, McpSchema.Implementation mockServerInfo) { - McpSchema.InitializeResult mockInitResult = new McpSchema.InitializeResult(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.InitializeResult mockInitResult = new McpSchema.InitializeResult(ProtocolVersions.MCP_2025_11_25, mockServerCapabilities, mockServerInfo, "Test instructions"); return new MockMcpClientTransport((t, message) -> { @@ -51,7 +52,7 @@ private static MockMcpClientTransport initializationEnabledTransport( r.id(), mockInitResult, null); t.simulateIncomingMessage(initResponse); } - }).withProtocolVersion(McpSchema.LATEST_PROTOCOL_VERSION); + }).withProtocolVersion(ProtocolVersions.MCP_2025_11_25); } @Test @@ -93,7 +94,7 @@ void testSuccessfulInitialization() { } @Test - void testToolsChangeNotificationHandling() throws JsonProcessingException { + void testToolsChangeNotificationHandling() throws IOException { MockMcpClientTransport transport = initializationEnabledTransport(); // Create a list to store received tools for verification @@ -110,8 +111,11 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException { // Create a mock tools list that the server will return Map inputSchema = Map.of("type", "object", "properties", Map.of(), "required", List.of()); - McpSchema.Tool mockTool = new McpSchema.Tool("test-tool-1", "Test Tool 1 Description", - new ObjectMapper().writeValueAsString(inputSchema)); + McpSchema.Tool mockTool = McpSchema.Tool.builder() + .name("test-tool-1") + .description("Test Tool 1 Description") + .inputSchema(JSON_MAPPER, JSON_MAPPER.writeValueAsString(inputSchema)) + .build(); // Create page 1 response with nextPageToken String nextPageToken = "page2Token"; @@ -131,9 +135,11 @@ void testToolsChangeNotificationHandling() throws JsonProcessingException { transport.simulateIncomingMessage(toolsListResponse1); // Create mock tools for page 2 - McpSchema.Tool mockTool2 = new McpSchema.Tool("test-tool-2", "Test Tool 2 Description", - new ObjectMapper().writeValueAsString(inputSchema)); - + McpSchema.Tool mockTool2 = McpSchema.Tool.builder() + .name("test-tool-2") + .description("Test Tool 2 Description") + .inputSchema(JSON_MAPPER, JSON_MAPPER.writeValueAsString(inputSchema)) + .build(); // Create page 2 response with no nextPageToken (last page) McpSchema.ListToolsResult mockToolsResult2 = new McpSchema.ListToolsResult(List.of(mockTool2), null); @@ -207,8 +213,12 @@ void testResourcesChangeNotificationHandling() { assertThat(asyncMcpClient.initialize().block()).isNotNull(); // Create a mock resources list that the server will return - McpSchema.Resource mockResource = new McpSchema.Resource("test://resource", "Test Resource", "A test resource", - "text/plain", null); + McpSchema.Resource mockResource = McpSchema.Resource.builder() + .uri("test://resource") + .name("Test Resource") + .description("A test resource") + .mimeType("text/plain") + .build(); McpSchema.ListResourcesResult mockResourcesResult = new McpSchema.ListResourcesResult(List.of(mockResource), null); @@ -321,7 +331,7 @@ void testSamplingCreateMessageRequestHandling() { assertThat(response.error()).isNull(); McpSchema.CreateMessageResult result = transport.unmarshalFrom(response.result(), - new TypeReference() { + new TypeRef() { }); assertThat(result).isNotNull(); assertThat(result.role()).isEqualTo(McpSchema.Role.ASSISTANT); @@ -425,7 +435,7 @@ void testElicitationCreateRequestHandling() { assertThat(response.id()).isEqualTo("test-id"); assertThat(response.error()).isNull(); - McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeRef<>() { }); assertThat(result).isNotNull(); assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); @@ -470,7 +480,7 @@ void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { assertThat(response.id()).isEqualTo("test-id"); assertThat(response.error()).isNull(); - McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { + McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeRef<>() { }); assertThat(result).isNotNull(); assertThat(result.action()).isEqualTo(action); @@ -551,4 +561,4 @@ void testPingMessageRequestHandling() { asyncMcpClient.closeGracefully(); } -} \ No newline at end of file +} diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java new file mode 100644 index 000000000..48bf1da5b --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java @@ -0,0 +1,310 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.stream.Collectors; + +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +class McpAsyncClientTests { + + public static final McpSchema.Implementation MOCK_SERVER_INFO = new McpSchema.Implementation("test-server", + "1.0.0"); + + public static final McpSchema.ServerCapabilities MOCK_SERVER_CAPABILITIES = McpSchema.ServerCapabilities.builder() + .tools(true) + .build(); + + public static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult( + ProtocolVersions.MCP_2024_11_05, MOCK_SERVER_CAPABILITIES, MOCK_SERVER_INFO, "Test instructions"); + + private static final String CONTEXT_KEY = "context.key"; + + private McpClientTransport createMockTransportForToolValidation(boolean hasOutputSchema, boolean invalidOutput) { + + // Create tool with or without output schema + Map inputSchemaMap = Map.of("type", "object", "properties", + Map.of("expression", Map.of("type", "string")), "required", List.of("expression")); + + McpSchema.JsonSchema inputSchema = new McpSchema.JsonSchema("object", inputSchemaMap, null, null, null, null); + McpSchema.Tool.Builder toolBuilder = McpSchema.Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .inputSchema(inputSchema); + + if (hasOutputSchema) { + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + toolBuilder.outputSchema(outputSchema); + } + + McpSchema.Tool calculatorTool = toolBuilder.build(); + McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(calculatorTool), null); + + // Create call tool result - valid or invalid based on parameter + Map structuredContent = invalidOutput ? Map.of("result", "5", "operation", "add") + : Map.of("result", 5, "operation", "add"); + + McpSchema.CallToolResult mockCallToolResult = McpSchema.CallToolResult.builder() + .addTextContent("Calculation result") + .structuredContent(structuredContent) + .build(); + + return new McpClientTransport() { + Function, Mono> handler; + + @Override + public Mono connect( + Function, Mono> handler) { + this.handler = handler; + return Mono.empty(); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!(message instanceof McpSchema.JSONRPCRequest request)) { + return Mono.empty(); + } + + McpSchema.JSONRPCResponse response; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), MOCK_INIT_RESULT, + null); + } + else if (McpSchema.METHOD_TOOLS_LIST.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), mockToolsResult, + null); + } + else if (McpSchema.METHOD_TOOLS_CALL.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + mockCallToolResult, null); + } + else { + return Mono.empty(); + } + + return handler.apply(Mono.just(response)).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return JSON_MAPPER.convertValue(data, new TypeRef<>() { + @Override + public java.lang.reflect.Type getType() { + return typeRef.getType(); + } + }); + } + }; + } + + @Test + void validateContextPassedToTransportConnect() { + McpClientTransport transport = new McpClientTransport() { + Function, Mono> handler; + + final AtomicReference contextValue = new AtomicReference<>(); + + @Override + public Mono connect( + Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler = handler; + if (ctx.hasKey(CONTEXT_KEY)) { + this.contextValue.set(ctx.get(CONTEXT_KEY)); + } + return Mono.empty(); + }); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!"hello".equals(this.contextValue.get())) { + return Mono.error(new RuntimeException("Context value not propagated via #connect method")); + } + // We're only interested in handling the init request to provide an init + // response + if (!(message instanceof McpSchema.JSONRPCRequest)) { + return Mono.empty(); + } + McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + ((McpSchema.JSONRPCRequest) message).id(), MOCK_INIT_RESULT, null); + return handler.apply(Mono.just(initResponse)).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return JSON_MAPPER.convertValue(data, new TypeRef<>() { + @Override + public java.lang.reflect.Type getType() { + return typeRef.getType(); + } + }); + } + }; + + assertThatCode(() -> { + McpAsyncClient client = McpClient.async(transport).build(); + client.initialize().contextWrite(ctx -> ctx.put(CONTEXT_KEY, "hello")).block(); + }).doesNotThrowAnyException(); + } + + @Test + void testCallToolWithOutputSchemaValidationSuccess() { + McpClientTransport transport = createMockTransportForToolValidation(true, false); + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")))) + .expectNextMatches(response -> { + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.structuredContent()).isInstanceOf(Map.class); + assertThat((Map) response.structuredContent()).hasSize(2); + assertThat(response.content()).hasSize(1); + return true; + }) + .verifyComplete(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); + } + + @Test + void testCallToolWithNoOutputSchemaSuccess() { + McpClientTransport transport = createMockTransportForToolValidation(false, false); + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")))) + .expectNextMatches(response -> { + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.structuredContent()).isInstanceOf(Map.class); + assertThat((Map) response.structuredContent()).hasSize(2); + assertThat(response.content()).hasSize(1); + return true; + }) + .verifyComplete(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); + } + + @Test + void testCallToolWithOutputSchemaValidationFailure() { + McpClientTransport transport = createMockTransportForToolValidation(true, true); + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + + StepVerifier.create(client.callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")))) + .expectErrorMatches(ex -> ex instanceof IllegalArgumentException + && ex.getMessage().contains("Tool call result validation failed")) + .verify(); + + StepVerifier.create(client.closeGracefully()).verifyComplete(); + } + + @Test + void testListToolsWithEmptyCursor() { + McpSchema.Tool addTool = McpSchema.Tool.builder().name("add").description("calculate add").build(); + McpSchema.Tool subtractTool = McpSchema.Tool.builder() + .name("subtract") + .description("calculate subtract") + .build(); + McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(addTool, subtractTool), ""); + + McpClientTransport transport = new McpClientTransport() { + Function, Mono> handler; + + @Override + public Mono connect( + Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler = handler; + return Mono.empty(); + }); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!(message instanceof McpSchema.JSONRPCRequest request)) { + return Mono.empty(); + } + + McpSchema.JSONRPCResponse response; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), MOCK_INIT_RESULT, + null); + } + else if (McpSchema.METHOD_TOOLS_LIST.equals(request.method())) { + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), mockToolsResult, + null); + } + else { + return Mono.empty(); + } + + return handler.apply(Mono.just(response)).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return JSON_MAPPER.convertValue(data, new TypeRef<>() { + @Override + public java.lang.reflect.Type getType() { + return typeRef.getType(); + } + }); + } + }; + + McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + + Mono mono = client.listTools(); + McpSchema.ListToolsResult toolsResult = mono.block(); + assertThat(toolsResult).isNotNull(); + + Set names = toolsResult.tools().stream().map(McpSchema.Tool::name).collect(Collectors.toSet()); + assertThat(names).containsExactlyInAnyOrder("subtract", "add"); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java similarity index 88% rename from mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java index 3feb1d05c..03f64aa64 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java @@ -8,9 +8,10 @@ import java.util.List; import io.modelcontextprotocol.MockMcpClientTransport; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.ProtocolVersions; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -22,7 +23,7 @@ */ class McpClientProtocolVersionTests { - private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(30); + private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(300); private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0"); @@ -46,13 +47,12 @@ void shouldUseLatestVersionByDefault() { assertThat(initRequest.protocolVersion()).isEqualTo(transport.protocolVersions().get(0)); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(protocolVersion, null, + new McpSchema.InitializeResult(protocolVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).assertNext(result -> { assertThat(result.protocolVersion()).isEqualTo(protocolVersion); }).verifyComplete(); - } finally { // Ensure cleanup happens even if test fails @@ -69,7 +69,7 @@ void shouldNegotiateSpecificVersion() { .requestTimeout(REQUEST_TIMEOUT) .build(); - client.setProtocolVersions(List.of(oldVersion, McpSchema.LATEST_PROTOCOL_VERSION)); + client.setProtocolVersions(List.of(oldVersion, ProtocolVersions.MCP_2025_11_25)); try { Mono initializeResultMono = client.initialize(); @@ -78,10 +78,10 @@ void shouldNegotiateSpecificVersion() { McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); assertThat(request.params()).isInstanceOf(McpSchema.InitializeRequest.class); McpSchema.InitializeRequest initRequest = (McpSchema.InitializeRequest) request.params(); - assertThat(initRequest.protocolVersion()).isIn(List.of(oldVersion, McpSchema.LATEST_PROTOCOL_VERSION)); + assertThat(initRequest.protocolVersion()).isIn(List.of(oldVersion, ProtocolVersions.MCP_2025_11_25)); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(oldVersion, null, + new McpSchema.InitializeResult(oldVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).assertNext(result -> { @@ -110,7 +110,7 @@ void shouldFailForUnsupportedVersion() { assertThat(request.params()).isInstanceOf(McpSchema.InitializeRequest.class); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(unsupportedVersion, null, + new McpSchema.InitializeResult(unsupportedVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).expectError(RuntimeException.class).verify(); @@ -124,7 +124,7 @@ void shouldFailForUnsupportedVersion() { void shouldUseHighestVersionWhenMultipleSupported() { String oldVersion = "0.1.0"; String middleVersion = "0.2.0"; - String latestVersion = McpSchema.LATEST_PROTOCOL_VERSION; + String latestVersion = ProtocolVersions.MCP_2025_11_25; MockMcpClientTransport transport = new MockMcpClientTransport(); McpAsyncClient client = McpClient.async(transport) @@ -143,7 +143,7 @@ void shouldUseHighestVersionWhenMultipleSupported() { assertThat(initRequest.protocolVersion()).isEqualTo(latestVersion); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(latestVersion, null, + new McpSchema.InitializeResult(latestVersion, ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).assertNext(result -> { diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/ServerParameterUtils.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/ServerParameterUtils.java new file mode 100644 index 000000000..547ccc52f --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/ServerParameterUtils.java @@ -0,0 +1,21 @@ +package io.modelcontextprotocol.client; + +import io.modelcontextprotocol.client.transport.ServerParameters; + +public final class ServerParameterUtils { + + private ServerParameterUtils() { + } + + public static ServerParameters createServerParameters() { + if (System.getProperty("os.name").toLowerCase().contains("win")) { + return ServerParameters.builder("cmd.exe") + .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything@2025.12.18", "stdio") + .build(); + } + return ServerParameters.builder("npx") + .args("-y", "@modelcontextprotocol/server-everything@2025.12.18", "stdio") + .build(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java similarity index 50% rename from mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index e9356d0c0..aa8aaa397 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -11,33 +11,37 @@ import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; +import static io.modelcontextprotocol.client.ServerParameterUtils.createServerParameters; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + /** * Tests for the {@link McpAsyncClient} with {@link StdioClientTransport}. * + *

+ * These tests use npx to download and run the MCP "everything" server locally. The first + * test execution will download the everything server scripts and cache them locally, + * which can take more than 15 seconds. Subsequent test runs will use the cached version + * and execute faster. + * * @author Christian Tzolov * @author Dariusz JΔ™drzejczyk */ -@Timeout(15) // Giving extra time beyond the client timeout +@Timeout(25) // Giving extra time beyond the client timeout to account for initial server + // download class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { @Override protected McpClientTransport createMcpTransport() { - ServerParameters stdioParams; - if (System.getProperty("os.name").toLowerCase().contains("win")) { - stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") - .build(); - } - else { - stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "stdio") - .build(); - } - return new StdioClientTransport(stdioParams); + return new StdioClientTransport(createServerParameters(), JSON_MAPPER); } protected Duration getInitializationTimeout() { return Duration.ofSeconds(20); } + @Override + protected Duration getRequestTimeout() { + return Duration.ofSeconds(25); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java similarity index 69% rename from mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 4b5f4f9c0..08e5ea61a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -17,31 +17,30 @@ import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; +import static io.modelcontextprotocol.client.ServerParameterUtils.createServerParameters; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; /** * Tests for the {@link McpSyncClient} with {@link StdioClientTransport}. * + *

+ * These tests use npx to download and run the MCP "everything" server locally. The first + * test execution will download the everything server scripts and cache them locally, + * which can take more than 15 seconds. Subsequent test runs will use the cached version + * and execute faster. + * * @author Christian Tzolov * @author Dariusz JΔ™drzejczyk */ -@Timeout(15) // Giving extra time beyond the client timeout +@Timeout(25) // Giving extra time beyond the client timeout to account for initial server + // download class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests { @Override protected McpClientTransport createMcpTransport() { - ServerParameters stdioParams; - if (System.getProperty("os.name").toLowerCase().contains("win")) { - stdioParams = ServerParameters.builder("cmd.exe") - .args("/c", "npx.cmd", "-y", "@modelcontextprotocol/server-everything", "stdio") - .build(); - } - else { - stdioParams = ServerParameters.builder("npx") - .args("-y", "@modelcontextprotocol/server-everything", "stdio") - .build(); - } - return new StdioClientTransport(stdioParams); + ServerParameters stdioParams = createServerParameters(); + return new StdioClientTransport(stdioParams, JSON_MAPPER); } @Test @@ -68,7 +67,12 @@ void customErrorHandlerShouldReceiveErrors() throws InterruptedException { } protected Duration getInitializationTimeout() { - return Duration.ofSeconds(10); + return Duration.ofSeconds(25); + } + + @Override + protected Duration getRequestTimeout() { + return Duration.ofSeconds(25); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java similarity index 88% rename from mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 46b9207f6..a24805a30 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -14,9 +14,12 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; + import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -33,6 +36,7 @@ import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.util.UriComponentsBuilder; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; @@ -54,14 +58,16 @@ class HttpClientSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); private TestHttpClientSseClientTransport transport; + private final McpTransportContext context = McpTransportContext.create(Map.of("some-key", "some-value")); + // Test class to access protected methods static class TestHttpClientSseClientTransport extends HttpClientSseClientTransport { @@ -71,8 +77,8 @@ static class TestHttpClientSseClientTransport extends HttpClientSseClientTranspo public TestHttpClientSseClientTransport(final String baseUri) { super(HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(), - HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse", - new ObjectMapper(), AsyncHttpRequestCustomizer.NOOP); + HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse", JSON_MAPPER, + McpAsyncHttpClientRequestCustomizer.NOOP); } public int getInboundMessageCount() { @@ -389,7 +395,7 @@ void testChainedCustomizations() { @Test void testRequestCustomizer() { - var mockCustomizer = mock(SyncHttpRequestCustomizer.class); + var mockCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); // Create a transport with the customizer var customizedTransport = HttpClientSseClientTransport.builder(host) @@ -397,11 +403,14 @@ void testRequestCustomizer() { .build(); // Connect - StepVerifier.create(customizedTransport.connect(Function.identity())).verifyComplete(); + StepVerifier + .create(customizedTransport.connect(Function.identity()) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); // Verify the customizer was called verify(mockCustomizer).customize(any(), eq("GET"), - eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull()); + eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull(), eq(context)); clearInvocations(mockCustomizer); // Send test message @@ -409,12 +418,16 @@ void testRequestCustomizer() { Map.of("key", "value")); // Subscribe to messages and verify - StepVerifier.create(customizedTransport.sendMessage(testMessage)).verifyComplete(); + StepVerifier + .create(customizedTransport.sendMessage(testMessage) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); // Verify the customizer was called var uriArgumentCaptor = ArgumentCaptor.forClass(URI.class); verify(mockCustomizer).customize(any(), eq("POST"), uriArgumentCaptor.capture(), eq( - "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}")); + "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}"), + eq(context)); assertThat(uriArgumentCaptor.getValue().toString()).startsWith(host + "/message?sessionId="); // Clean up @@ -423,8 +436,8 @@ void testRequestCustomizer() { @Test void testAsyncRequestCustomizer() { - var mockCustomizer = mock(AsyncHttpRequestCustomizer.class); - when(mockCustomizer.customize(any(), any(), any(), any())) + var mockCustomizer = mock(McpAsyncHttpClientRequestCustomizer.class); + when(mockCustomizer.customize(any(), any(), any(), any(), any())) .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); // Create a transport with the customizer @@ -433,11 +446,14 @@ void testAsyncRequestCustomizer() { .build(); // Connect - StepVerifier.create(customizedTransport.connect(Function.identity())).verifyComplete(); + StepVerifier + .create(customizedTransport.connect(Function.identity()) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); // Verify the customizer was called verify(mockCustomizer).customize(any(), eq("GET"), - eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull()); + eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull(), eq(context)); clearInvocations(mockCustomizer); // Send test message @@ -445,12 +461,16 @@ void testAsyncRequestCustomizer() { Map.of("key", "value")); // Subscribe to messages and verify - StepVerifier.create(customizedTransport.sendMessage(testMessage)).verifyComplete(); + StepVerifier + .create(customizedTransport.sendMessage(testMessage) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); // Verify the customizer was called var uriArgumentCaptor = ArgumentCaptor.forClass(URI.class); verify(mockCustomizer).customize(any(), eq("POST"), uriArgumentCaptor.capture(), eq( - "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}")); + "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}"), + eq(context)); assertThat(uriArgumentCaptor.getValue().toString()).startsWith(host + "/message?sessionId="); // Clean up diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java similarity index 90% rename from mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java rename to mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java index 8b3668671..81e642681 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java @@ -22,6 +22,7 @@ import com.sun.net.httpserver.HttpServer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; import io.modelcontextprotocol.server.transport.TomcatTestUtil; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.ProtocolVersions; @@ -70,14 +71,14 @@ static void stopContainer() { void testNotificationInitialized() throws URISyntaxException { var uri = new URI(host + "/mcp"); - var mockRequestCustomizer = mock(SyncHttpRequestCustomizer.class); + var mockRequestCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); var transport = HttpClientStreamableHttpTransport.builder(host) .httpRequestCustomizer(mockRequestCustomizer) .build(); var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, McpSchema.ClientCapabilities.builder().roots(true).build(), - new McpSchema.Implementation("Spring AI MCP Client", "0.3.1")); + new McpSchema.Implementation("MCP Client", "0.3.1")); var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "test-id", initializeRequest); @@ -85,7 +86,8 @@ void testNotificationInitialized() throws URISyntaxException { // Verify the customizer was called verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( - "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"Spring AI MCP Client\",\"version\":\"0.3.1\"}}}")); + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"MCP Client\",\"version\":\"0.3.1\"}}}"), + any()); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java similarity index 99% rename from mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java rename to mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java index 2b502a83b..b82d6eb2c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java @@ -63,7 +63,7 @@ void startServer() throws IOException { if ("DELETE".equals(httpExchange.getRequestMethod())) { httpExchange.sendResponseHeaders(200, 0); } - else { + else if ("POST".equals(httpExchange.getRequestMethod())) { // Capture session ID from request if present String requestSessionId = httpExchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); lastReceivedSessionId.set(requestSessionId); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java similarity index 53% rename from mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java rename to mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java index d645bb0b3..f88736a5d 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java @@ -4,9 +4,14 @@ package io.modelcontextprotocol.client.transport; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; import java.net.URI; import java.net.URISyntaxException; +import java.util.Map; import java.util.function.Consumer; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -32,9 +37,12 @@ class HttpClientStreamableHttpTransportTest { static String host = "http://localhost:3001"; + private McpTransportContext context = McpTransportContext + .create(Map.of("test-transport-context-key", "some-value")); + @SuppressWarnings("resource") - static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js streamableHttp") + static GenericContainer container = new GenericContainer<>("docker.io/node:lts-alpine3.23") + .withCommand("npx -y @modelcontextprotocol/server-everything@2025.12.18 streamableHttp") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) .waitingFor(Wait.forHttp("/").forStatusCode(404)); @@ -63,7 +71,7 @@ void withTransport(HttpClientStreamableHttpTransport transport, Consumer { // Send test message - var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_11_25, McpSchema.ClientCapabilities.builder().roots(true).build(), - new McpSchema.Implementation("Spring AI MCP Client", "0.3.1")); + new McpSchema.Implementation("MCP Client", "0.3.1")); var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "test-id", initializeRequest); - StepVerifier.create(t.sendMessage(testMessage)).verifyComplete(); + StepVerifier + .create(t.sendMessage(testMessage).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); // Verify the customizer was called verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( - "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"Spring AI MCP Client\",\"version\":\"0.3.1\"}}}")); + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-11-25\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"MCP Client\",\"version\":\"0.3.1\"}}}"), + eq(context)); }); } @Test void testAsyncRequestCustomizer() throws URISyntaxException { var uri = new URI(host + "/mcp"); - var mockRequestCustomizer = mock(AsyncHttpRequestCustomizer.class); - when(mockRequestCustomizer.customize(any(), any(), any(), any())) + var mockRequestCustomizer = mock(McpAsyncHttpClientRequestCustomizer.class); + when(mockRequestCustomizer.customize(any(), any(), any(), any(), any())) .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); var transport = HttpClientStreamableHttpTransport.builder(host) @@ -98,18 +109,56 @@ void testAsyncRequestCustomizer() throws URISyntaxException { withTransport(transport, (t) -> { // Send test message - var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_11_25, McpSchema.ClientCapabilities.builder().roots(true).build(), - new McpSchema.Implementation("Spring AI MCP Client", "0.3.1")); + new McpSchema.Implementation("MCP Client", "0.3.1")); var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, "test-id", initializeRequest); - StepVerifier.create(t.sendMessage(testMessage)).verifyComplete(); + StepVerifier + .create(t.sendMessage(testMessage).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, context))) + .verifyComplete(); // Verify the customizer was called verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( - "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"Spring AI MCP Client\",\"version\":\"0.3.1\"}}}")); + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-11-25\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"MCP Client\",\"version\":\"0.3.1\"}}}"), + eq(context)); }); } + @Test + void testCloseUninitialized() { + var transport = HttpClientStreamableHttpTransport.builder(host).build(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_11_25, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMessage("MCP session has been closed") + .verify(); + } + + @Test + void testCloseInitialized() { + var transport = HttpClientStreamableHttpTransport.builder(host).build(); + + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_11_25, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + StepVerifier.create(transport.sendMessage(testMessage)) + .expectErrorMatches(err -> err.getMessage().matches("MCP session with ID [a-zA-Z0-9-]* has been closed")) + .verify(); + } + } diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..ce381436d --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,288 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import jakarta.servlet.Servlet; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP clients and + * async servers. + * + *

+ * This test class validates the end-to-end flow of transport context propagation in MCP + * communication, demonstrating how contextual information can be passed from client to + * server through HTTP headers and accessed within server-side handlers. + * + *

Test Scenarios

+ *

+ * The tests cover multiple transport configurations with async servers: + *

    + *
  • Stateless server with async streamable HTTP clients
  • + *
  • Streamable server with async streamable HTTP clients
  • + *
  • SSE (Server-Sent Events) server with async SSE clients
  • + *
+ * + *

Context Propagation Flow

+ *
    + *
  1. Client-side: Context data is stored in the Reactor Context and injected into HTTP + * headers via {@link McpSyncHttpClientRequestCustomizer}
  2. + *
  3. Transport: The context travels as HTTP headers (specifically "x-test" header in + * these tests)
  4. + *
  5. Server-side: A {@link McpTransportContextExtractor} extracts the header value and + * makes it available to request handlers through {@link McpTransportContext}
  6. + *
  7. Verification: The server echoes back the received context value as the tool call + * result
  8. + *
+ * + *

+ * All tests use an embedded Tomcat server running on a dynamically allocated port to + * ensure isolation and prevent port conflicts during parallel test execution. + * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + */ +@Timeout(15) +public class AsyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private Tomcat tomcat; + + private static final String HEADER_NAME = "x-test"; + + private final McpAsyncHttpClientRequestCustomizer asyncClientRequestCustomizer = (builder, method, endpoint, body, + context) -> { + var headerValue = context.get("client-side-header-value"); + if (headerValue != null) { + builder.header(HEADER_NAME, headerValue.toString()); + } + return Mono.just(builder); + }; + + private final McpTransportContextExtractor serverContextExtractor = (HttpServletRequest r) -> { + var headerValue = r.getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final HttpServletStatelessServerTransport statelessServerTransport = HttpServletStatelessServerTransport + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletStreamableServerTransportProvider streamableServerTransport = HttpServletStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletSseServerTransportProvider sseServerTransport = HttpServletSseServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/message") + .build(); + + private final McpAsyncClient asyncStreamableClient = McpClient + .async(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .asyncHttpRequestCustomizer(asyncClientRequestCustomizer) + .build()) + .build(); + + private final McpAsyncClient asyncSseClient = McpClient + .async(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .asyncHttpRequestCustomizer(asyncClientRequestCustomizer) + .build()) + .build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + private final BiFunction> asyncStatelessHandler = ( + transportContext, request) -> { + return Mono.just(McpSchema.CallToolResult.builder() + .addTextContent(transportContext.get("server-side-header-value").toString()) + .isError(false) + .build()); + }; + + private final BiFunction> asyncStatefulHandler = ( + exchange, request) -> { + return asyncStatelessHandler.apply(exchange.transportContext(), request); + }; + + @AfterEach + public void after() { + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (asyncStreamableClient != null) { + asyncStreamableClient.closeGracefully().block(); + } + if (asyncSseClient != null) { + asyncSseClient.closeGracefully().block(); + } + stopTomcat(); + } + + @Test + void asyncClinetStatelessServer() { + startTomcat(statelessServerTransport); + + var mcpServer = McpServer.async(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientStreamableServer() { + startTomcat(streamableServerTransport); + + var mcpServer = McpServer.async(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.AsyncToolSpecification.builder() + .tool(tool) + .callHandler(asyncStatefulHandler) + .build()) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void asyncClientSseServer() { + startTomcat(sseServerTransport); + + var mcpServer = McpServer.async(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.AsyncToolSpecification.builder() + .tool(tool) + .callHandler(asyncStatefulHandler) + .build()) + .build(); + + StepVerifier.create(asyncSseClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncSseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + private void startTomcat(Servlet transport) { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/common/HttpClientStreamableHttpVersionNegotiationIntegrationTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/common/HttpClientStreamableHttpVersionNegotiationIntegrationTests.java new file mode 100644 index 000000000..29eef1410 --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/common/HttpClientStreamableHttpVersionNegotiationIntegrationTests.java @@ -0,0 +1,147 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.McpTestRequestRecordingServletFilter; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class HttpClientStreamableHttpVersionNegotiationIntegrationTests { + + private Tomcat tomcat; + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private final McpTestRequestRecordingServletFilter requestRecordingFilter = new McpTestRequestRecordingServletFilter(); + + private final HttpServletStreamableServerTransportProvider transport = HttpServletStreamableServerTransportProvider + .builder() + .contextExtractor( + req -> McpTransportContext.create(Map.of("protocol-version", req.getHeader("MCP-protocol-version")))) + .build(); + + private final McpSchema.Tool toolSpec = McpSchema.Tool.builder() + .name("test-tool") + .description("return the protocol version used") + .build(); + + private final BiFunction toolHandler = ( + exchange, request) -> McpSchema.CallToolResult.builder() + .addTextContent(exchange.transportContext().get("protocol-version").toString()) + .isError(false) + .build(); + + McpSyncServer mcpServer = McpServer.sync(transport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(false).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder().tool(toolSpec).callHandler(toolHandler).build()) + .build(); + + @AfterEach + void tearDown() { + stopTomcat(); + } + + @Test + void usesLatestVersion() { + startTomcat(); + + var client = McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT).build()) + .build(); + + client.initialize(); + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = requestRecordingFilter.getCalls(); + + assertThat(calls).filteredOn(c -> !c.body().contains("\"method\":\"initialize\"")) + // GET /mcp ; POST notification/initialized ; POST tools/call + .hasSize(3) + .map(McpTestRequestRecordingServletFilter.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_11_25)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_11_25); + mcpServer.close(); + } + + @Test + void usesServerSupportedVersion() { + startTomcat(); + + var transport = HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .supportedProtocolVersions(List.of(ProtocolVersions.MCP_2025_11_25, "2263-03-18")) + .build(); + var client = McpClient.sync(transport).build(); + + client.initialize(); + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + var calls = requestRecordingFilter.getCalls(); + // Initialize tells the server the Client's latest supported version + // FIXME: Set the correct protocol version on GET /mcp + assertThat(calls).filteredOn(c -> c.method().equals("POST") && !c.body().contains("\"method\":\"initialize\"")) + // POST notification/initialized ; POST tools/call + .hasSize(2) + .map(McpTestRequestRecordingServletFilter.Call::headers) + .allSatisfy(headers -> assertThat(headers).containsEntry("mcp-protocol-version", + ProtocolVersions.MCP_2025_11_25)); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo(ProtocolVersions.MCP_2025_11_25); + mcpServer.close(); + } + + private void startTomcat() { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport, requestRecordingFilter); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..563e2167d --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,245 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpClient.SyncSpec; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import jakarta.servlet.Servlet; +import jakarta.servlet.http.HttpServletRequest; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test both Client and Server {@link McpTransportContext} integration, in two steps. + *

+ * First, the client calls a tool and writes data stored in a thread-local to an HTTP + * header using {@link SyncSpec#transportContextProvider(Supplier)} and + * {@link McpSyncHttpClientRequestCustomizer}. + *

+ * Then the server reads the header with a {@link McpTransportContextExtractor} and + * returns the value as the result of the tool call. + * + * @author Daniel Garnier-Moiroux + */ +@Timeout(15) +public class SyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private Tomcat tomcat; + + private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + private final Supplier clientContextProvider = () -> { + var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final McpSyncHttpClientRequestCustomizer clientRequestCustomizer = (builder, method, endpoint, body, + context) -> { + var headerValue = context.get("client-side-header-value"); + if (headerValue != null) { + builder.header(HEADER_NAME, headerValue.toString()); + } + }; + + private final McpTransportContextExtractor serverContextExtractor = (HttpServletRequest r) -> { + var headerValue = r.getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final BiFunction statelessHandler = ( + transportContext, request) -> McpSchema.CallToolResult.builder() + .addTextContent(transportContext.get("server-side-header-value").toString()) + .isError(false) + .build(); + + private final BiFunction statefulHandler = ( + exchange, request) -> statelessHandler.apply(exchange.transportContext(), request); + + private final HttpServletStatelessServerTransport statelessServerTransport = HttpServletStatelessServerTransport + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletStreamableServerTransportProvider streamableServerTransport = HttpServletStreamableServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletSseServerTransportProvider sseServerTransport = HttpServletSseServerTransportProvider + .builder() + .contextExtractor(serverContextExtractor) + .messageEndpoint("/message") + .build(); + + private final McpSyncClient streamableClient = McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSyncClient sseClient = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + @AfterEach + public void after() { + CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + if (streamableClient != null) { + streamableClient.closeGracefully(); + } + if (sseClient != null) { + sseClient.closeGracefully(); + } + stopTomcat(); + } + + @Test + void statelessServer() { + startTomcat(statelessServerTransport); + + var mcpServer = McpServer.sync(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void streamableServer() { + startTomcat(streamableServerTransport); + + var mcpServer = McpServer.sync(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder().tool(tool).callHandler(statefulHandler).build()) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void sseServer() { + startTomcat(sseServerTransport); + + var mcpServer = McpServer.sync(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder().tool(tool).callHandler(statefulHandler).build()) + .build(); + + McpSchema.InitializeResult initResult = sseClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + private void startTomcat(Servlet transport) { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java similarity index 87% rename from mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java index 0f2991a9f..5841c13da 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java @@ -4,26 +4,28 @@ package io.modelcontextprotocol.server; -import static org.assertj.core.api.Assertions.assertThat; - import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Timeout; - -import com.fasterxml.jackson.databind.ObjectMapper; - +import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpServer.AsyncSpecification; import io.modelcontextprotocol.server.McpServer.SyncSpecification; import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; import io.modelcontextprotocol.server.transport.TomcatTestUtil; import jakarta.servlet.http.HttpServletRequest; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import static org.assertj.core.api.Assertions.assertThat; @Timeout(15) class HttpServletSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { @@ -38,11 +40,14 @@ class HttpServletSseIntegrationTests extends AbstractMcpClientServerIntegrationT private Tomcat tomcat; + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient")); + } + @BeforeEach public void before() { // Create and configure the transport provider mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) .contextExtractor(TEST_CONTEXT_EXTRACTOR) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .sseEndpoint(CUSTOM_SSE_ENDPOINT) @@ -94,9 +99,7 @@ public void after() { protected void prepareClients(int port, String mcpEndpoint) { } - static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r, tc) -> { - tc.put("important", "value"); - return tc; - }; + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java similarity index 76% rename from mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java index a8951e6dc..491c2d4ed 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java @@ -4,22 +4,30 @@ package io.modelcontextprotocol.server; -import com.fasterxml.jackson.databind.ObjectMapper; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; + import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; import io.modelcontextprotocol.server.transport.TomcatTestUtil; import io.modelcontextprotocol.spec.HttpHeaders; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; import io.modelcontextprotocol.spec.McpSchema.Prompt; import io.modelcontextprotocol.spec.McpSchema.PromptArgument; import io.modelcontextprotocol.spec.McpSchema.PromptReference; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.ProtocolVersions; import net.javacrumbs.jsonunit.core.Option; @@ -32,19 +40,15 @@ import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.web.client.RestClient; -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; - import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.APPLICATION_JSON; import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.TEXT_EVENT_STREAM; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; import static org.assertj.core.api.Assertions.assertThat; @@ -66,7 +70,6 @@ class HttpServletStatelessIntegrationTests { @BeforeEach public void before() { this.mcpStatelessServerTransport = HttpServletStatelessServerTransport.builder() - .objectMapper(new ObjectMapper()) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .build(); @@ -105,24 +108,19 @@ public void after() { // --------------------------------------- // Tools Tests // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient" }) void testToolCallSuccess(String clientType) { var clientBuilder = clientBuilders.get(clientType); - var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + var callResponse = CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("CALL RESPONSE"))) + .isError(false) + .build(); McpStatelessServerFeatures.SyncToolSpecification tool1 = new McpStatelessServerFeatures.SyncToolSpecification( - new Tool("tool1", "tool1 description", emptyJsonSchema), (transportContext, request) -> { + Tool.builder().name("tool1").title("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build(), + (transportContext, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -150,8 +148,9 @@ void testToolCallSuccess(String clientType) { assertThat(response).isNotNull(); assertThat(response).isEqualTo(callResponse); } - - mcpServer.close(); + finally { + mcpServer.close(); + } } @ParameterizedTest(name = "{0} : {displayName} ") @@ -166,8 +165,9 @@ void testInitialize(String clientType) { InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); } - - mcpServer.close(); + finally { + mcpServer.close(); + } } // --------------------------------------- @@ -197,7 +197,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { List.of(new PromptArgument("language", "Language", "string", false))), (transportContext, getPromptRequest) -> null)) .completions(new McpStatelessServerFeatures.SyncCompletionSpecification( - new PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), completionHandler)) .build(); try (var mcpClient = clientBuilder.build()) { @@ -206,7 +206,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { assertThat(initResult).isNotNull(); CompleteRequest request = new CompleteRequest( - new PromptReference("ref/prompt", "code_review", "Code review"), + new PromptReference(PromptReference.TYPE, "code_review", "Code review"), new CompleteRequest.CompleteArgument("language", "py")); CompleteResult result = mcpClient.completeCompletion(request); @@ -215,10 +215,11 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); - assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + assertThat(samplingRequest.get().ref().type()).isEqualTo(PromptReference.TYPE); + } + finally { + mcpServer.close(); } - - mcpServer.close(); } // --------------------------------------- @@ -289,8 +290,129 @@ void testStructuredOutputValidationSuccess(String clientType) { .isEqualTo(json(""" {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); } + finally { + mcpServer.close(); + } + } - mcpServer.close(); + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema that returns an array of objects + Map outputSchema = Map + .of( // @formatter:off + "type", "array", + "items", Map.of( + "type", "object", + "properties", Map.of( + "name", Map.of("type", "string"), + "age", Map.of("type", "number")), + "required", List.of("name", "age"))); // @formatter:on + + Tool calculatorTool = Tool.builder() + .name("getMembers") + .description("Returns a list of members") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + return CallToolResult.builder() + .structuredContent(List.of(Map.of("name", "John", "age", 30), Map.of("name", "Peter", "age", 25))) + .build(); + }) + .build(); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + assertThat(mcpClient.initialize()).isNotNull(); + + // Call tool with valid structured output of type array + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("getMembers", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isArray() + .hasSize(2) + .containsExactlyInAnyOrder(json(""" + {"name":"John","age":30}"""), json(""" + {"name":"Peter","age":25}""")); + } + finally { + mcpServer.closeGracefully(); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputWithInHandlerError(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + // Handler that returns an error result + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> CallToolResult.builder() + .isError(true) + .content(List.of(new TextContent("Error calling tool: Simulated in-handler error"))) + .build()) + .build(); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).isNotEmpty(); + assertThat(response.content()) + .containsExactly(new McpSchema.TextContent("Error calling tool: Simulated in-handler error")); + assertThat(response.structuredContent()).isNull(); + } + finally { + mcpServer.closeGracefully(); + } } @ParameterizedTest(name = "{0} : {displayName} ") @@ -341,8 +463,9 @@ void testStructuredOutputValidationFailure(String clientType) { String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); assertThat(errorMessage).contains("Validation failed"); } - - mcpServer.close(); + finally { + mcpServer.close(); + } } @ParameterizedTest(name = "{0} : {displayName} ") @@ -390,8 +513,9 @@ void testStructuredOutputMissingStructuredContent(String clientType) { assertThat(errorMessage).isEqualTo( "Response missing structured content which is expected when calling tool with non-empty outputSchema"); } - - mcpServer.close(); + finally { + mcpServer.close(); + } } @ParameterizedTest(name = "{0} : {displayName} ") @@ -464,12 +588,13 @@ void testStructuredOutputRuntimeToolAddition(String clientType) { .isEqualTo(json(""" {"count":3,"message":"Dynamic execution"}""")); } - - mcpServer.close(); + finally { + mcpServer.close(); + } } @Test - void testThrownMcpError() throws Exception { + void testThrownMcpErrorAndJsonRpcError() throws Exception { var mcpServer = McpServer.sync(mcpStatelessServerTransport) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) @@ -479,7 +604,7 @@ void testThrownMcpError() throws Exception { McpStatelessServerFeatures.SyncToolSpecification toolSpec = new McpStatelessServerFeatures.SyncToolSpecification( testTool, (transportContext, request) -> { - throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(12345, "testing", Map.of("a", "b"))); + throw new RuntimeException("testing"); }); mcpServer.addTool(toolSpec); @@ -491,7 +616,7 @@ void testThrownMcpError() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest("POST", CUSTOM_MESSAGE_ENDPOINT); MockHttpServletResponse response = new MockHttpServletResponse(); - byte[] content = new ObjectMapper().writeValueAsBytes(jsonrpcRequest); + byte[] content = JSON_MAPPER.writeValueAsBytes(jsonrpcRequest); request.setContent(content); request.addHeader("Content-Type", "application/json"); request.addHeader("Content-Length", Integer.toString(content.length)); @@ -500,13 +625,16 @@ void testThrownMcpError() throws Exception { request.addHeader("Content-Type", APPLICATION_JSON); request.addHeader("Cache-Control", "no-cache"); request.addHeader(HttpHeaders.PROTOCOL_VERSION, ProtocolVersions.MCP_2025_03_26); + mcpStatelessServerTransport.service(request, response); - McpSchema.JSONRPCResponse jsonrpcResponse = new ObjectMapper().readValue(response.getContentAsByteArray(), + McpSchema.JSONRPCResponse jsonrpcResponse = JSON_MAPPER.readValue(response.getContentAsByteArray(), McpSchema.JSONRPCResponse.class); - assertThat(jsonrpcResponse.error()) - .isEqualTo(new McpSchema.JSONRPCResponse.JSONRPCError(12345, "testing", Map.of("a", "b"))); + assertThat(jsonrpcResponse).isNotNull(); + assertThat(jsonrpcResponse.error()).isNotNull(); + assertThat(jsonrpcResponse.error().code()).isEqualTo(ErrorCodes.INTERNAL_ERROR); + assertThat(jsonrpcResponse.error().message()).isEqualTo("testing"); mcpServer.close(); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java similarity index 79% rename from mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java index 327ec1b21..96f1524b7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java @@ -6,8 +6,6 @@ import org.junit.jupiter.api.Timeout; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; @@ -21,10 +19,7 @@ class HttpServletStreamableAsyncServerTests extends AbstractMcpAsyncServerTests { protected McpStreamableServerTransportProvider createMcpTransportProvider() { - return HttpServletStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .mcpEndpoint("/mcp/message") - .build(); + return HttpServletStreamableServerTransportProvider.builder().mcpEndpoint("/mcp/message").build(); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java similarity index 87% rename from mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java index 2e9b4cbad..5b934e4e9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java @@ -4,26 +4,28 @@ package io.modelcontextprotocol.server; -import static org.assertj.core.api.Assertions.assertThat; - import java.time.Duration; +import java.util.Map; +import java.util.stream.Stream; -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Timeout; - -import com.fasterxml.jackson.databind.ObjectMapper; - +import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpServer.AsyncSpecification; import io.modelcontextprotocol.server.McpServer.SyncSpecification; import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; import io.modelcontextprotocol.server.transport.TomcatTestUtil; import jakarta.servlet.http.HttpServletRequest; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.provider.Arguments; + +import static org.assertj.core.api.Assertions.assertThat; @Timeout(15) class HttpServletStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { @@ -36,11 +38,14 @@ class HttpServletStreamableIntegrationTests extends AbstractMcpClientServerInteg private Tomcat tomcat; + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient")); + } + @BeforeEach public void before() { // Create and configure the transport provider mcpServerTransportProvider = HttpServletStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) .contextExtractor(TEST_CONTEXT_EXTRACTOR) .mcpEndpoint(MESSAGE_ENDPOINT) .keepAliveInterval(Duration.ofSeconds(1)) @@ -92,9 +97,7 @@ public void after() { protected void prepareClients(int port, String mcpEndpoint) { } - static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r, tc) -> { - tc.put("important", "value"); - return tc; - }; + static McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext + .create(Map.of("important", "value")); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java similarity index 79% rename from mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java index 66fa2b2ac..87c0712dc 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java @@ -6,8 +6,6 @@ import org.junit.jupiter.api.Timeout; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; @@ -21,10 +19,7 @@ class HttpServletStreamableSyncServerTests extends AbstractMcpSyncServerTests { protected McpStreamableServerTransportProvider createMcpTransportProvider() { - return HttpServletStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .mcpEndpoint("/mcp/message") - .build(); + return HttpServletStreamableServerTransportProvider.builder().mcpEndpoint("/mcp/message").build(); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java similarity index 93% rename from mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java index f915895be..54fb80a78 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java @@ -12,14 +12,13 @@ import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.apache.catalina.startup.Tomcat; + import static org.assertj.core.api.Assertions.assertThat; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; @@ -59,7 +58,6 @@ class McpCompletionTests { public void before() { // Create and con figure the transport provider mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .build(); @@ -99,7 +97,7 @@ void testCompletionHandlerReceivesContext() { return new CompleteResult(new CompleteResult.CompleteCompletion(List.of("test-completion"), 1, false)); }; - ResourceReference resourceRef = new ResourceReference("ref/resource", "test://resource/{param}"); + ResourceReference resourceRef = new ResourceReference(ResourceReference.TYPE, "test://resource/{param}"); var resource = Resource.builder() .uri("test://resource/{param}") @@ -154,7 +152,7 @@ void testCompletionBackwardCompatibility() { .prompts(new McpServerFeatures.SyncPromptSpecification(prompt, (mcpSyncServerExchange, getPromptRequest) -> null)) .completions(new McpServerFeatures.SyncCompletionSpecification( - new PromptReference("ref/prompt", "test-prompt"), completionHandler)) + new PromptReference(PromptReference.TYPE, "test-prompt"), completionHandler)) .build(); try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) @@ -163,7 +161,7 @@ void testCompletionBackwardCompatibility() { assertThat(initResult).isNotNull(); // Test without context - CompleteRequest request = new CompleteRequest(new PromptReference("ref/prompt", "test-prompt"), + CompleteRequest request = new CompleteRequest(new PromptReference(PromptReference.TYPE, "test-prompt"), new CompleteRequest.CompleteArgument("arg", "val")); CompleteResult result = mcpClient.completeCompletion(request); @@ -219,7 +217,7 @@ else if ("products_db".equals(db)) { .resources(new McpServerFeatures.SyncResourceSpecification(resource, (exchange, req) -> new ReadResourceResult(List.of()))) .completions(new McpServerFeatures.SyncCompletionSpecification( - new ResourceReference("ref/resource", "db://{database}/{table}"), completionHandler)) + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), completionHandler)) .build(); try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) @@ -229,7 +227,7 @@ else if ("products_db".equals(db)) { // First, complete database CompleteRequest dbRequest = new CompleteRequest( - new ResourceReference("ref/resource", "db://{database}/{table}"), + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), new CompleteRequest.CompleteArgument("database", "")); CompleteResult dbResult = mcpClient.completeCompletion(dbRequest); @@ -237,7 +235,7 @@ else if ("products_db".equals(db)) { // Then complete table with database context CompleteRequest tableRequest = new CompleteRequest( - new ResourceReference("ref/resource", "db://{database}/{table}"), + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), new CompleteRequest.CompleteArgument("table", ""), new CompleteRequest.CompleteContext(Map.of("database", "users_db"))); @@ -246,7 +244,7 @@ else if ("products_db".equals(db)) { // Different database gives different tables CompleteRequest tableRequest2 = new CompleteRequest( - new ResourceReference("ref/resource", "db://{database}/{table}"), + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), new CompleteRequest.CompleteArgument("table", ""), new CompleteRequest.CompleteContext(Map.of("database", "products_db"))); @@ -296,7 +294,7 @@ void testCompletionErrorOnMissingContext() { .resources(new McpServerFeatures.SyncResourceSpecification(resource, (exchange, req) -> new ReadResourceResult(List.of()))) .completions(new McpServerFeatures.SyncCompletionSpecification( - new ResourceReference("ref/resource", "db://{database}/{table}"), completionHandler)) + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), completionHandler)) .build(); try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample" + "client", "0.0.0")) @@ -306,7 +304,7 @@ void testCompletionErrorOnMissingContext() { // Try to complete table without database context - should raise error CompleteRequest requestWithoutContext = new CompleteRequest( - new ResourceReference("ref/resource", "db://{database}/{table}"), + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), new CompleteRequest.CompleteArgument("table", "")); assertThatExceptionOfType(McpError.class) @@ -315,7 +313,7 @@ void testCompletionErrorOnMissingContext() { // Now complete with proper context - should work normally CompleteRequest requestWithContext = new CompleteRequest( - new ResourceReference("ref/resource", "db://{database}/{table}"), + new ResourceReference(ResourceReference.TYPE, "db://{database}/{table}"), new CompleteRequest.CompleteArgument("table", ""), new CompleteRequest.CompleteContext(Map.of("database", "test_db"))); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java similarity index 94% rename from mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java index cdd2bacb7..d9f899020 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java @@ -10,6 +10,7 @@ import io.modelcontextprotocol.MockMcpServerTransport; import io.modelcontextprotocol.MockMcpServerTransportProvider; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; @@ -36,8 +37,7 @@ void shouldUseLatestVersionByDefault() { String requestId = UUID.randomUUID().toString(); - transportProvider - .simulateIncomingMessage(jsonRpcInitializeRequest(requestId, McpSchema.LATEST_PROTOCOL_VERSION)); + transportProvider.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, ProtocolVersions.MCP_2025_11_25)); McpSchema.JSONRPCMessage response = serverTransport.getLastSentMessage(); assertThat(response).isInstanceOf(McpSchema.JSONRPCResponse.class); @@ -60,7 +60,7 @@ void shouldNegotiateSpecificVersion() { McpAsyncServer server = McpServer.async(transportProvider).serverInfo(SERVER_INFO).build(); - server.setProtocolVersions(List.of(oldVersion, McpSchema.LATEST_PROTOCOL_VERSION)); + server.setProtocolVersions(List.of(oldVersion, ProtocolVersions.MCP_2025_11_25)); String requestId = UUID.randomUUID().toString(); @@ -105,7 +105,7 @@ void shouldSuggestLatestVersionForUnsupportedVersion() { void shouldUseHighestVersionWhenMultipleSupported() { String oldVersion = "0.1.0"; String middleVersion = "0.2.0"; - String latestVersion = McpSchema.LATEST_PROTOCOL_VERSION; + String latestVersion = ProtocolVersions.MCP_2025_11_25; MockMcpServerTransport serverTransport = new MockMcpServerTransport(); var transportProvider = new MockMcpServerTransportProvider(serverTransport); diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/server/ResourceTemplateManagementTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/ResourceTemplateManagementTests.java new file mode 100644 index 000000000..b7d46a967 --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/ResourceTemplateManagementTests.java @@ -0,0 +1,299 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; + +import io.modelcontextprotocol.MockMcpServerTransport; +import io.modelcontextprotocol.MockMcpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Test suite for Resource Template Management functionality. Tests the new + * addResourceTemplate() and removeResourceTemplate() methods, as well as the Map-based + * resource template storage. + * + * @author Christian Tzolov + */ +public class ResourceTemplateManagementTests { + + private static final String TEST_TEMPLATE_URI = "test://resource/{param}"; + + private static final String TEST_TEMPLATE_NAME = "test-template"; + + private MockMcpServerTransportProvider mockTransportProvider; + + private McpAsyncServer mcpAsyncServer; + + @BeforeEach + void setUp() { + mockTransportProvider = new MockMcpServerTransportProvider(new MockMcpServerTransport()); + } + + @AfterEach + void tearDown() { + if (mcpAsyncServer != null) { + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + } + + // --------------------------------------- + // Async Resource Template Tests + // --------------------------------------- + + @Test + void testAddResourceTemplate() { + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResourceTemplate(specification)).verifyComplete(); + } + + @Test + void testAddResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .build(); + + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResourceTemplate(specification)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + + assertThatCode(() -> serverWithoutResources.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + @Test + void testRemoveResourceTemplate() { + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + StepVerifier.create(mcpAsyncServer.removeResourceTemplate(TEST_TEMPLATE_URI)).verifyComplete(); + } + + @Test + void testRemoveResourceTemplateWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutResources.removeResourceTemplate(TEST_TEMPLATE_URI)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Server must be configured with resource capabilities"); + }); + + assertThatCode(() -> serverWithoutResources.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentResourceTemplate() { + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + // Removing a non-existent resource template should complete successfully (no + // error) + // as per the new implementation that just logs a warning + StepVerifier.create(mcpAsyncServer.removeResourceTemplate("nonexistent://template/{id}")).verifyComplete(); + } + + @Test + void testReplaceExistingResourceTemplate() { + ResourceTemplate originalTemplate = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Original template") + .mimeType("text/plain") + .build(); + + ResourceTemplate updatedTemplate = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Updated template") + .mimeType("application/json") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification originalSpec = new McpServerFeatures.AsyncResourceTemplateSpecification( + originalTemplate, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + McpServerFeatures.AsyncResourceTemplateSpecification updatedSpec = new McpServerFeatures.AsyncResourceTemplateSpecification( + updatedTemplate, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(originalSpec) + .build(); + + // Adding a resource template with the same URI should replace the existing one + StepVerifier.create(mcpAsyncServer.addResourceTemplate(updatedSpec)).verifyComplete(); + } + + // --------------------------------------- + // Sync Resource Template Tests + // --------------------------------------- + + @Test + void testSyncAddResourceTemplate() { + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = McpServer.sync(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatCode(() -> mcpSyncServer.addResourceTemplate(specification)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testSyncRemoveResourceTemplate() { + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.SyncResourceTemplateSpecification specification = new McpServerFeatures.SyncResourceTemplateSpecification( + template, (exchange, req) -> new ReadResourceResult(List.of())); + + var mcpSyncServer = McpServer.sync(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build(); + + assertThatCode(() -> mcpSyncServer.removeResourceTemplate(TEST_TEMPLATE_URI)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Map-based Storage Tests + // --------------------------------------- + + @Test + void testResourceTemplateMapBasedStorage() { + ResourceTemplate template1 = ResourceTemplate.builder() + .uriTemplate("test://template1/{id}") + .name("template1") + .description("First template") + .mimeType("text/plain") + .build(); + + ResourceTemplate template2 = ResourceTemplate.builder() + .uriTemplate("test://template2/{id}") + .name("template2") + .description("Second template") + .mimeType("application/json") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification spec1 = new McpServerFeatures.AsyncResourceTemplateSpecification( + template1, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + McpServerFeatures.AsyncResourceTemplateSpecification spec2 = new McpServerFeatures.AsyncResourceTemplateSpecification( + template2, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + mcpAsyncServer = McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(spec1, spec2) + .build(); + + // Verify both templates are stored (this would be tested through integration + // tests + // or by accessing internal state, but for unit tests we verify no exceptions) + assertThat(mcpAsyncServer).isNotNull(); + } + + @Test + void testResourceTemplateBuilderWithMap() { + // Test that the new Map-based builder methods work correctly + ResourceTemplate template = ResourceTemplate.builder() + .uriTemplate(TEST_TEMPLATE_URI) + .name(TEST_TEMPLATE_NAME) + .description("Test resource template") + .mimeType("text/plain") + .build(); + + McpServerFeatures.AsyncResourceTemplateSpecification specification = new McpServerFeatures.AsyncResourceTemplateSpecification( + template, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); + + // Test varargs builder method + assertThatCode(() -> { + McpServer.async(mockTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .resourceTemplates(specification) + .build() + .closeGracefully() + .block(Duration.ofSeconds(10)); + }).doesNotThrowAnyException(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java similarity index 100% rename from mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java similarity index 84% rename from mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java index 97db5fa06..b2dfbea25 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java @@ -8,6 +8,8 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + /** * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. * @@ -17,7 +19,7 @@ class StdioMcpAsyncServerTests extends AbstractMcpAsyncServerTests { protected McpServerTransportProvider createMcpTransportProvider() { - return new StdioServerTransportProvider(); + return new StdioServerTransportProvider(JSON_MAPPER); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java similarity index 84% rename from mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java index 1e01962e9..c97c75d38 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java @@ -8,6 +8,8 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; + /** * Tests for {@link McpSyncServer} using {@link StdioServerTransportProvider}. * @@ -17,7 +19,7 @@ class StdioMcpSyncServerTests extends AbstractMcpSyncServerTests { protected McpServerTransportProvider createMcpTransportProvider() { - return new StdioServerTransportProvider(); + return new StdioServerTransportProvider(JSON_MAPPER); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java similarity index 96% rename from mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java index 0462cbafe..be88097b3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java @@ -4,8 +4,6 @@ package io.modelcontextprotocol.server.transport; -import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.server.McpServer; @@ -40,7 +38,6 @@ public void before() { // Create and configure the transport provider mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) .baseUrl(CUSTOM_CONTEXT_PATH) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .sseEndpoint(CUSTOM_SSE_ENDPOINT) diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/McpTestRequestRecordingServletFilter.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/McpTestRequestRecordingServletFilter.java new file mode 100644 index 000000000..b94552d12 --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/McpTestRequestRecordingServletFilter.java @@ -0,0 +1,128 @@ +/* + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ReadListener; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequestWrapper; + +/** + * Simple {@link Filter} which records calls made to an MCP server. + * + * @author Daniel Garnier-Moiroux + */ +public class McpTestRequestRecordingServletFilter implements Filter { + + private final List calls = new ArrayList<>(); + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) + throws IOException, ServletException { + + if (servletRequest instanceof HttpServletRequest req) { + var headers = Collections.list(req.getHeaderNames()) + .stream() + .collect(Collectors.toUnmodifiableMap(Function.identity(), + name -> String.join(",", Collections.list(req.getHeaders(name))))); + var request = new CachedBodyHttpServletRequest(req); + calls.add(new Call(req.getMethod(), headers, request.getBodyAsString())); + filterChain.doFilter(request, servletResponse); + } + else { + filterChain.doFilter(servletRequest, servletResponse); + } + + } + + public List getCalls() { + + return List.copyOf(calls); + } + + public record Call(String method, Map headers, String body) { + + } + + public static class CachedBodyHttpServletRequest extends HttpServletRequestWrapper { + + private final byte[] cachedBody; + + public CachedBodyHttpServletRequest(HttpServletRequest request) throws IOException { + super(request); + this.cachedBody = request.getInputStream().readAllBytes(); + } + + @Override + public ServletInputStream getInputStream() { + return new CachedBodyServletInputStream(cachedBody); + } + + @Override + public BufferedReader getReader() { + return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8)); + } + + public String getBodyAsString() { + return new String(cachedBody, StandardCharsets.UTF_8); + } + + } + + public static class CachedBodyServletInputStream extends ServletInputStream { + + private InputStream cachedBodyInputStream; + + public CachedBodyServletInputStream(byte[] cachedBody) { + this.cachedBodyInputStream = new ByteArrayInputStream(cachedBody); + } + + @Override + public boolean isFinished() { + try { + return cachedBodyInputStream.available() == 0; + } + catch (IOException e) { + e.printStackTrace(); + } + return false; + } + + @Override + public boolean isReady() { + return true; + } + + @Override + public void setReadListener(ReadListener readListener) { + throw new UnsupportedOperationException(); + } + + @Override + public int read() throws IOException { + return cachedBodyInputStream.read(); + } + + } + +} diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java new file mode 100644 index 000000000..10bb30568 --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityIntegrationTests.java @@ -0,0 +1,339 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.time.Duration; +import java.util.stream.Stream; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.spec.McpSchema; +import jakarta.servlet.http.HttpServlet; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.BeforeParameterizedClassInvocation; +import org.junit.jupiter.params.Parameter; +import org.junit.jupiter.params.ParameterizedClass; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Named.named; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +/** + * Test the header security validation for all transport types. + * + * @author Daniel Garnier-Moiroux + */ +@ParameterizedClass +@MethodSource("transports") +class ServerTransportSecurityIntegrationTests { + + private static final String DISALLOWED_ORIGIN = "https://malicious.example.com"; + + private static final String DISALLOWED_HOST = "malicious.example.com:8080"; + + @Parameter + private static Transport transport; + + private static Tomcat tomcat; + + private static String baseUrl; + + @BeforeParameterizedClassInvocation + static void createTransportAndStartTomcat(Transport transport) { + var port = TomcatTestUtil.findAvailablePort(); + baseUrl = "http://localhost:" + port; + startTomcat(transport.servlet(), port); + } + + @AfterAll + static void afterAll() { + stopTomcat(); + } + + private McpSyncClient mcpClient; + + private final TestRequestCustomizer requestCustomizer = new TestRequestCustomizer(); + + @BeforeEach + void setUp() { + requestCustomizer.reset(); + mcpClient = transport.createMcpClient(baseUrl, requestCustomizer); + } + + @AfterEach + void tearDown() { + mcpClient.close(); + } + + @Test + void originAllowed() { + requestCustomizer.setOriginHeader(baseUrl); + var result = mcpClient.initialize(); + var tools = mcpClient.listTools(); + + assertThat(result.protocolVersion()).isNotEmpty(); + assertThat(tools.tools()).isEmpty(); + } + + @Test + void noOrigin() { + requestCustomizer.setOriginHeader(null); + var result = mcpClient.initialize(); + var tools = mcpClient.listTools(); + + assertThat(result.protocolVersion()).isNotEmpty(); + assertThat(tools.tools()).isEmpty(); + } + + @Test + void connectOriginNotAllowed() { + requestCustomizer.setOriginHeader(DISALLOWED_ORIGIN); + assertThatThrownBy(() -> mcpClient.initialize()); + } + + @Test + void messageOriginNotAllowed() { + requestCustomizer.setOriginHeader(baseUrl); + mcpClient.initialize(); + requestCustomizer.setOriginHeader(DISALLOWED_ORIGIN); + assertThatThrownBy(() -> mcpClient.listTools()); + } + + @Test + void hostAllowed() { + // Host header is set by default by HttpClient to the request URI host + var result = mcpClient.initialize(); + var tools = mcpClient.listTools(); + + assertThat(result.protocolVersion()).isNotEmpty(); + assertThat(tools.tools()).isEmpty(); + } + + @Test + void connectHostNotAllowed() { + requestCustomizer.setHostHeader(DISALLOWED_HOST); + assertThatThrownBy(() -> mcpClient.initialize()); + } + + @Test + void messageHostNotAllowed() { + mcpClient.initialize(); + requestCustomizer.setHostHeader(DISALLOWED_HOST); + assertThatThrownBy(() -> mcpClient.listTools()); + } + + // ---------------------------------------------------- + // Tomcat management + // ---------------------------------------------------- + + private static void startTomcat(jakarta.servlet.Servlet servlet, int port) { + tomcat = TomcatTestUtil.createTomcatServer("", port, servlet); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private static void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + // ---------------------------------------------------- + // Transport servers to test + // ---------------------------------------------------- + + /** + * All transport types we want to test. We use a {@link MethodSource} rather than a + * {@link org.junit.jupiter.params.provider.ValueSource} to provide a readable name. + */ + static Stream transports() { + //@formatter:off + return Stream.of( + arguments(named("SSE", new Sse())), + arguments(named("Streamable HTTP", new StreamableHttp())), + arguments(named("Stateless", new Stateless())) + ); + //@formatter:on + } + + /** + * Represents a server transport we want to test, and how to create a client for the + * resulting MCP Server. + */ + interface Transport { + + McpSyncClient createMcpClient(String baseUrl, TestRequestCustomizer requestCustomizer); + + HttpServlet servlet(); + + } + + /** + * SSE-based transport. + */ + static class Sse implements Transport { + + private final HttpServletSseServerTransportProvider transport; + + public Sse() { + transport = HttpServletSseServerTransportProvider.builder() + .messageEndpoint("/mcp/message") + .securityValidator(DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build()) + .build(); + McpServer.sync(transport) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + @Override + public McpSyncClient createMcpClient(String baseUrl, TestRequestCustomizer requestCustomizer) { + var transport = HttpClientSseClientTransport.builder(baseUrl) + .httpRequestCustomizer(requestCustomizer) + .jsonMapper(McpJsonDefaults.getMapper()) + .build(); + return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build(); + } + + @Override + public HttpServlet servlet() { + return transport; + } + + } + + static class StreamableHttp implements Transport { + + private final HttpServletStreamableServerTransportProvider transport; + + public StreamableHttp() { + transport = HttpServletStreamableServerTransportProvider.builder() + .securityValidator(DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build()) + .build(); + McpServer.sync(transport) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + @Override + public McpSyncClient createMcpClient(String baseUrl, TestRequestCustomizer requestCustomizer) { + var transport = HttpClientStreamableHttpTransport.builder(baseUrl) + .httpRequestCustomizer(requestCustomizer) + .jsonMapper(McpJsonDefaults.getMapper()) + .openConnectionOnStartup(true) + .build(); + return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build(); + } + + @Override + public HttpServlet servlet() { + return transport; + } + + } + + static class Stateless implements Transport { + + private final HttpServletStatelessServerTransport transport; + + public Stateless() { + transport = HttpServletStatelessServerTransport.builder() + .securityValidator(DefaultServerTransportSecurityValidator.builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build()) + .build(); + McpServer.sync(transport) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + @Override + public McpSyncClient createMcpClient(String baseUrl, TestRequestCustomizer requestCustomizer) { + var transport = HttpClientStreamableHttpTransport.builder(baseUrl) + .httpRequestCustomizer(requestCustomizer) + .jsonMapper(McpJsonDefaults.getMapper()) + .openConnectionOnStartup(true) + .build(); + return McpClient.sync(transport).initializationTimeout(Duration.ofMillis(500)).build(); + } + + @Override + public HttpServlet servlet() { + return transport; + } + + } + + static class TestRequestCustomizer implements McpSyncHttpClientRequestCustomizer { + + private String originHeader = null; + + private String hostHeader = null; + + @Override + public void customize(HttpRequest.Builder builder, String method, URI endpoint, String body, + McpTransportContext context) { + if (originHeader != null) { + builder.header("Origin", originHeader); + } + if (hostHeader != null) { + // HttpClient normally sets Host automatically, but we can override it + builder.header("Host", hostHeader); + } + } + + public void setOriginHeader(String originHeader) { + this.originHeader = originHeader; + } + + public void setHostHeader(String hostHeader) { + this.hostHeader = hostHeader; + } + + public void reset() { + this.originHeader = null; + this.hostHeader = null; + } + + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java similarity index 92% rename from mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 14987b5ac..5390cc4c2 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -14,7 +14,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.json.McpJsonDefaults; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -37,7 +37,6 @@ * * @author Christian Tzolov */ -@Disabled class StdioServerTransportProviderTests { private final PrintStream originalOut = System.out; @@ -50,8 +49,6 @@ class StdioServerTransportProviderTests { private StdioServerTransportProvider transportProvider; - private ObjectMapper objectMapper; - private McpServerSession.Factory sessionFactory; private McpServerSession mockSession; @@ -64,8 +61,6 @@ void setUp() { System.setOut(testOutPrintStream); System.setErr(testOutPrintStream); - objectMapper = new ObjectMapper(); - // Create mocks for session factory and session mockSession = mock(McpServerSession.class); sessionFactory = mock(McpServerSession.Factory.class); @@ -75,7 +70,8 @@ void setUp() { when(mockSession.closeGracefully()).thenReturn(Mono.empty()); when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty()); - transportProvider = new StdioServerTransportProvider(objectMapper, System.in, testOutPrintStream); + transportProvider = new StdioServerTransportProvider(McpJsonDefaults.getMapper(), System.in, + testOutPrintStream); } @AfterEach @@ -105,7 +101,7 @@ void shouldHandleIncomingMessages() throws Exception { String jsonMessage = "{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{},\"id\":1}\n"; InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); - transportProvider = new StdioServerTransportProvider(objectMapper, stream, System.out); + transportProvider = new StdioServerTransportProvider(McpJsonDefaults.getMapper(), stream, System.out); // Set up a real session to capture the message AtomicReference capturedMessage = new AtomicReference<>(); CountDownLatch messageLatch = new CountDownLatch(1); @@ -185,11 +181,11 @@ void shouldHandleMultipleCloseGracefullyCalls() { @Test void shouldHandleNotificationBeforeSessionFactoryIsSet() { - transportProvider = new StdioServerTransportProvider(objectMapper); + transportProvider = new StdioServerTransportProvider(McpJsonDefaults.getMapper()); // Send notification before setting session factory StepVerifier.create(transportProvider.notifyClients("testNotification", Map.of("key", "value"))) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class); + assertThat(error).isInstanceOf(IllegalStateException.class); }); } @@ -200,7 +196,7 @@ void shouldHandleInvalidJsonMessage() throws Exception { String jsonMessage = "{invalid json}\n"; InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); - transportProvider = new StdioServerTransportProvider(objectMapper, stream, testOutPrintStream); + transportProvider = new StdioServerTransportProvider(McpJsonDefaults.getMapper(), stream, testOutPrintStream); // Set up a session factory transportProvider.setSessionFactory(sessionFactory); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java similarity index 76% rename from mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java rename to mcp-test/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java index 2cf95dc94..490e29838 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -8,6 +8,7 @@ import java.net.InetSocketAddress; import java.net.ServerSocket; +import jakarta.servlet.Filter; import jakarta.servlet.Servlet; import org.apache.catalina.Context; import org.apache.catalina.startup.Tomcat; @@ -24,7 +25,8 @@ public class TomcatTestUtil { // Prevent instantiation } - public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet) { + public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet, + Filter... additionalFilters) { var tomcat = new Tomcat(); tomcat.setPort(port); @@ -43,15 +45,17 @@ public static Tomcat createTomcatServer(String contextPath, int port, Servlet se context.addChild(wrapper); context.addServletMappingDecoded("/*", "mcpServlet"); - var filterDef = new FilterDef(); - filterDef.setFilterClass(McpTestServletFilter.class.getName()); - filterDef.setFilterName(McpTestServletFilter.class.getSimpleName()); - context.addFilterDef(filterDef); + for (var filter : additionalFilters) { + var filterDef = new FilterDef(); + filterDef.setFilter(filter); + filterDef.setFilterName(McpTestRequestRecordingServletFilter.class.getSimpleName()); + context.addFilterDef(filterDef); - var filterMap = new FilterMap(); - filterMap.setFilterName(McpTestServletFilter.class.getSimpleName()); - filterMap.addURLPattern("/*"); - context.addFilterMap(filterMap); + var filterMap = new FilterMap(); + filterMap.setFilterName(McpTestRequestRecordingServletFilter.class.getSimpleName()); + filterMap.addURLPattern("/*"); + context.addFilterMap(filterMap); + } var connector = tomcat.getConnector(); connector.setAsyncTimeout(3000); diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/spec/CompleteCompletionSerializationTest.java b/mcp-test/src/test/java/io/modelcontextprotocol/spec/CompleteCompletionSerializationTest.java new file mode 100644 index 000000000..195b6ec6d --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/spec/CompleteCompletionSerializationTest.java @@ -0,0 +1,29 @@ +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.json.McpJsonDefaults; +import io.modelcontextprotocol.json.McpJsonMapper; +import org.junit.jupiter.api.Test; +import java.io.IOException; +import java.util.Collections; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class CompleteCompletionSerializationTest { + + @Test + void codeCompletionSerialization() throws IOException { + McpJsonMapper jsonMapper = McpJsonDefaults.getMapper(); + McpSchema.CompleteResult.CompleteCompletion codeComplete = new McpSchema.CompleteResult.CompleteCompletion( + Collections.emptyList(), 0, false); + String json = jsonMapper.writeValueAsString(codeComplete); + String expected = """ + {"values":[],"total":0,"hasMore":false}"""; + assertEquals(expected, json, json); + + McpSchema.CompleteResult completeResult = new McpSchema.CompleteResult(codeComplete); + json = jsonMapper.writeValueAsString(completeResult); + expected = """ + {"completion":{"values":[],"total":0,"hasMore":false}}"""; + assertEquals(expected, json, json); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java similarity index 80% rename from mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java rename to mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index a5b2137fd..942e0a6e2 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -4,23 +4,23 @@ package io.modelcontextprotocol.spec; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; - -import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; import net.javacrumbs.jsonunit.core.Option; /** @@ -29,14 +29,12 @@ */ public class McpSchemaTests { - ObjectMapper mapper = new ObjectMapper(); - // Content Types Tests @Test void testTextContent() throws Exception { McpSchema.TextContent test = new McpSchema.TextContent("XXX"); - String value = mapper.writeValueAsString(test); + String value = JSON_MAPPER.writeValueAsString(test); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -47,7 +45,7 @@ void testTextContent() throws Exception { @Test void testTextContentDeserialization() throws Exception { - McpSchema.TextContent textContent = mapper.readValue(""" + McpSchema.TextContent textContent = JSON_MAPPER.readValue(""" {"type":"text","text":"XXX","_meta":{"metaKey":"metaValue"}}""", McpSchema.TextContent.class); assertThat(textContent).isNotNull(); @@ -57,19 +55,25 @@ void testTextContentDeserialization() throws Exception { } @Test - void testContentDeserializationWrongType() throws Exception { - - assertThatThrownBy(() -> mapper.readValue(""" - {"type":"WRONG","text":"XXX"}""", McpSchema.TextContent.class)) - .isInstanceOf(InvalidTypeIdException.class) + void testContentDeserializationWrongType() { + assertThatThrownBy(() -> JSON_MAPPER.readValue(""" + {"type":"WRONG","text":"XXX"}""", McpSchema.TextContent.class)).isInstanceOf(IOException.class) + // Jackson 2 throws the InvalidTypeException directly, but Jackson 3 wraps it. + // Try to unwrap in case it's Jackson 3. + .extracting(throwable -> throwable.getCause() != null ? throwable.getCause() : throwable) + .asInstanceOf(InstanceOfAssertFactories.THROWABLE) .hasMessageContaining( - "Could not resolve type id 'WRONG' as a subtype of `io.modelcontextprotocol.spec.McpSchema$TextContent`: known type ids = [audio, image, resource, resource_link, text]"); + "Could not resolve type id 'WRONG' as a subtype of `io.modelcontextprotocol.spec.McpSchema$TextContent`: known type ids = [audio, image, resource, resource_link, text]") + .extracting(Object::getClass) + .extracting(Class::getSimpleName) + // Class name is the same for both Jackson 2 and 3, only the package differs. + .isEqualTo("InvalidTypeIdException"); } @Test void testImageContent() throws Exception { - McpSchema.ImageContent test = new McpSchema.ImageContent(null, null, "base64encodeddata", "image/png"); - String value = mapper.writeValueAsString(test); + McpSchema.ImageContent test = new McpSchema.ImageContent(null, "base64encodeddata", "image/png"); + String value = JSON_MAPPER.writeValueAsString(test); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -80,7 +84,7 @@ void testImageContent() throws Exception { @Test void testImageContentDeserialization() throws Exception { - McpSchema.ImageContent imageContent = mapper.readValue(""" + McpSchema.ImageContent imageContent = JSON_MAPPER.readValue(""" {"type":"image","data":"base64encodeddata","mimeType":"image/png","_meta":{"metaKey":"metaValue"}}""", McpSchema.ImageContent.class); assertThat(imageContent).isNotNull(); @@ -93,7 +97,7 @@ void testImageContentDeserialization() throws Exception { @Test void testAudioContent() throws Exception { McpSchema.AudioContent audioContent = new McpSchema.AudioContent(null, "base64encodeddata", "audio/wav"); - String value = mapper.writeValueAsString(audioContent); + String value = JSON_MAPPER.writeValueAsString(audioContent); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -104,7 +108,7 @@ void testAudioContent() throws Exception { @Test void testAudioContentDeserialization() throws Exception { - McpSchema.AudioContent audioContent = mapper.readValue(""" + McpSchema.AudioContent audioContent = JSON_MAPPER.readValue(""" {"type":"audio","data":"base64encodeddata","mimeType":"audio/wav","_meta":{"metaKey":"metaValue"}}""", McpSchema.AudioContent.class); assertThat(audioContent).isNotNull(); @@ -140,7 +144,7 @@ void testCreateMessageRequestWithMeta() throws Exception { .meta(meta) .build(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -156,9 +160,9 @@ void testEmbeddedResource() throws Exception { McpSchema.TextResourceContents resourceContents = new McpSchema.TextResourceContents("resource://test", "text/plain", "Sample resource content"); - McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); + McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, resourceContents); - String value = mapper.writeValueAsString(test); + String value = JSON_MAPPER.writeValueAsString(test); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -169,7 +173,7 @@ void testEmbeddedResource() throws Exception { @Test void testEmbeddedResourceDeserialization() throws Exception { - McpSchema.EmbeddedResource embeddedResource = mapper.readValue( + McpSchema.EmbeddedResource embeddedResource = JSON_MAPPER.readValue( """ {"type":"resource","resource":{"uri":"resource://test","mimeType":"text/plain","text":"Sample resource content"},"_meta":{"metaKey":"metaValue"}}""", McpSchema.EmbeddedResource.class); @@ -187,9 +191,9 @@ void testEmbeddedResourceWithBlobContents() throws Exception { McpSchema.BlobResourceContents resourceContents = new McpSchema.BlobResourceContents("resource://test", "application/octet-stream", "base64encodedblob"); - McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, null, resourceContents); + McpSchema.EmbeddedResource test = new McpSchema.EmbeddedResource(null, resourceContents); - String value = mapper.writeValueAsString(test); + String value = JSON_MAPPER.writeValueAsString(test); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -200,7 +204,7 @@ void testEmbeddedResourceWithBlobContents() throws Exception { @Test void testEmbeddedResourceWithBlobContentsDeserialization() throws Exception { - McpSchema.EmbeddedResource embeddedResource = mapper.readValue( + McpSchema.EmbeddedResource embeddedResource = JSON_MAPPER.readValue( """ {"type":"resource","resource":{"uri":"resource://test","mimeType":"application/octet-stream","blob":"base64encodedblob","_meta":{"metaKey":"metaValue"}}}""", McpSchema.EmbeddedResource.class); @@ -219,7 +223,7 @@ void testResourceLink() throws Exception { McpSchema.ResourceLink resourceLink = new McpSchema.ResourceLink("main.rs", "Main file", "file:///project/src/main.rs", "Primary application entry point", "text/x-rust", null, null, Map.of("metaKey", "metaValue")); - String value = mapper.writeValueAsString(resourceLink); + String value = JSON_MAPPER.writeValueAsString(resourceLink); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -231,7 +235,7 @@ void testResourceLink() throws Exception { @Test void testResourceLinkDeserialization() throws Exception { - McpSchema.ResourceLink resourceLink = mapper.readValue( + McpSchema.ResourceLink resourceLink = JSON_MAPPER.readValue( """ {"type":"resource_link","name":"main.rs","uri":"file:///project/src/main.rs","description":"Primary application entry point","mimeType":"text/x-rust","_meta":{"metaKey":"metaValue"}}""", McpSchema.ResourceLink.class); @@ -254,7 +258,7 @@ void testJSONRPCRequest() throws Exception { McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", 1, params); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -270,7 +274,7 @@ void testJSONRPCNotification() throws Exception { McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, "notification_method", params); - String value = mapper.writeValueAsString(notification); + String value = JSON_MAPPER.writeValueAsString(notification); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -285,7 +289,7 @@ void testJSONRPCResponse() throws Exception { McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, result, null); - String value = mapper.writeValueAsString(response); + String value = JSON_MAPPER.writeValueAsString(response); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -300,7 +304,7 @@ void testJSONRPCResponseWithError() throws Exception { McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, null, error); - String value = mapper.writeValueAsString(response); + String value = JSON_MAPPER.writeValueAsString(response); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -323,7 +327,7 @@ void testInitializeRequest() throws Exception { McpSchema.InitializeRequest request = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2024_11_05, capabilities, clientInfo, meta); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -346,7 +350,7 @@ void testInitializeResult() throws Exception { McpSchema.InitializeResult result = new McpSchema.InitializeResult(ProtocolVersions.MCP_2024_11_05, capabilities, serverInfo, "Server initialized successfully"); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -362,10 +366,15 @@ void testResource() throws Exception { McpSchema.Annotations annotations = new McpSchema.Annotations( Arrays.asList(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), 0.8); - McpSchema.Resource resource = new McpSchema.Resource("resource://test", "Test Resource", "A test resource", - "text/plain", annotations); + McpSchema.Resource resource = McpSchema.Resource.builder() + .uri("resource://test") + .name("Test Resource") + .description("A test resource") + .mimeType("text/plain") + .annotations(annotations) + .build(); - String value = mapper.writeValueAsString(resource); + String value = JSON_MAPPER.writeValueAsString(resource); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -389,7 +398,7 @@ void testResourceBuilder() throws Exception { .meta(Map.of("metaKey", "metaValue")) .build(); - String value = mapper.writeValueAsString(resource); + String value = JSON_MAPPER.writeValueAsString(resource); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -436,7 +445,7 @@ void testResourceTemplate() throws Exception { McpSchema.ResourceTemplate template = new McpSchema.ResourceTemplate("resource://{param}/test", "Test Template", "Test Template", "A test resource template", "text/plain", annotations, meta); - String value = mapper.writeValueAsString(template); + String value = JSON_MAPPER.writeValueAsString(template); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -447,18 +456,26 @@ void testResourceTemplate() throws Exception { @Test void testListResourcesResult() throws Exception { - McpSchema.Resource resource1 = new McpSchema.Resource("resource://test1", "Test Resource 1", - "First test resource", "text/plain", null); + McpSchema.Resource resource1 = McpSchema.Resource.builder() + .uri("resource://test1") + .name("Test Resource 1") + .description("First test resource") + .mimeType("text/plain") + .build(); - McpSchema.Resource resource2 = new McpSchema.Resource("resource://test2", "Test Resource 2", - "Second test resource", "application/json", null); + McpSchema.Resource resource2 = McpSchema.Resource.builder() + .uri("resource://test2") + .name("Test Resource 2") + .description("Second test resource") + .mimeType("application/json") + .build(); Map meta = Map.of("metaKey", "metaValue"); McpSchema.ListResourcesResult result = new McpSchema.ListResourcesResult(Arrays.asList(resource1, resource2), "next-cursor", meta); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -478,7 +495,7 @@ void testListResourceTemplatesResult() throws Exception { McpSchema.ListResourceTemplatesResult result = new McpSchema.ListResourceTemplatesResult( Arrays.asList(template1, template2), "next-cursor"); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -492,7 +509,7 @@ void testReadResourceRequest() throws Exception { McpSchema.ReadResourceRequest request = new McpSchema.ReadResourceRequest("resource://test", Map.of("metaKey", "metaValue")); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -507,7 +524,7 @@ void testReadResourceRequestWithMeta() throws Exception { McpSchema.ReadResourceRequest request = new McpSchema.ReadResourceRequest("resource://test", meta); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -521,7 +538,7 @@ void testReadResourceRequestWithMeta() throws Exception { @Test void testReadResourceRequestDeserialization() throws Exception { - McpSchema.ReadResourceRequest request = mapper.readValue(""" + McpSchema.ReadResourceRequest request = JSON_MAPPER.readValue(""" {"uri":"resource://test","_meta":{"progressToken":"test-token"}}""", McpSchema.ReadResourceRequest.class); @@ -541,7 +558,7 @@ void testReadResourceResult() throws Exception { McpSchema.ReadResourceResult result = new McpSchema.ReadResourceResult(Arrays.asList(contents1, contents2), Map.of("metaKey", "metaValue")); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -562,7 +579,7 @@ void testPrompt() throws Exception { McpSchema.Prompt prompt = new McpSchema.Prompt("test-prompt", "Test Prompt", "A test prompt", Arrays.asList(arg1, arg2), Map.of("metaKey", "metaValue")); - String value = mapper.writeValueAsString(prompt); + String value = JSON_MAPPER.writeValueAsString(prompt); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -577,7 +594,7 @@ void testPromptMessage() throws Exception { McpSchema.PromptMessage message = new McpSchema.PromptMessage(McpSchema.Role.USER, content); - String value = mapper.writeValueAsString(message); + String value = JSON_MAPPER.writeValueAsString(message); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -598,7 +615,7 @@ void testListPromptsResult() throws Exception { McpSchema.ListPromptsResult result = new McpSchema.ListPromptsResult(Arrays.asList(prompt1, prompt2), "next-cursor"); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -615,7 +632,7 @@ void testGetPromptRequest() throws Exception { McpSchema.GetPromptRequest request = new McpSchema.GetPromptRequest("test-prompt", arguments); - assertThat(mapper.readValue(""" + assertThat(JSON_MAPPER.readValue(""" {"name":"test-prompt","arguments":{"arg1":"value1","arg2":42}}""", McpSchema.GetPromptRequest.class)) .isEqualTo(request); } @@ -631,7 +648,7 @@ void testGetPromptRequestWithMeta() throws Exception { McpSchema.GetPromptRequest request = new McpSchema.GetPromptRequest("test-prompt", arguments, meta); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -656,7 +673,7 @@ void testGetPromptResult() throws Exception { McpSchema.GetPromptResult result = new McpSchema.GetPromptResult("A test prompt result", Arrays.asList(message1, message2)); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -696,16 +713,16 @@ void testJsonSchema() throws Exception { """; // Deserialize the original string to a JsonSchema object - McpSchema.JsonSchema schema = mapper.readValue(schemaJson, McpSchema.JsonSchema.class); + McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); // Serialize the object back to a string - String serialized = mapper.writeValueAsString(schema); + String serialized = JSON_MAPPER.writeValueAsString(schema); // Deserialize again - McpSchema.JsonSchema deserialized = mapper.readValue(serialized, McpSchema.JsonSchema.class); + McpSchema.JsonSchema deserialized = JSON_MAPPER.readValue(serialized, McpSchema.JsonSchema.class); // Serialize one more time and compare with the first serialization - String serializedAgain = mapper.writeValueAsString(deserialized); + String serializedAgain = JSON_MAPPER.writeValueAsString(deserialized); // The two serialized strings should be the same assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); @@ -739,16 +756,16 @@ void testJsonSchemaWithDefinitions() throws Exception { """; // Deserialize the original string to a JsonSchema object - McpSchema.JsonSchema schema = mapper.readValue(schemaJson, McpSchema.JsonSchema.class); + McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); // Serialize the object back to a string - String serialized = mapper.writeValueAsString(schema); + String serialized = JSON_MAPPER.writeValueAsString(schema); // Deserialize again - McpSchema.JsonSchema deserialized = mapper.readValue(serialized, McpSchema.JsonSchema.class); + McpSchema.JsonSchema deserialized = JSON_MAPPER.readValue(serialized, McpSchema.JsonSchema.class); // Serialize one more time and compare with the first serialization - String serializedAgain = mapper.writeValueAsString(deserialized); + String serializedAgain = JSON_MAPPER.writeValueAsString(deserialized); // The two serialized strings should be the same assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); @@ -771,9 +788,13 @@ void testTool() throws Exception { } """; - McpSchema.Tool tool = new McpSchema.Tool("test-tool", "A test tool", schemaJson); + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, schemaJson) + .build(); - String value = mapper.writeValueAsString(tool); + String value = JSON_MAPPER.writeValueAsString(tool); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -805,16 +826,20 @@ void testToolWithComplexSchema() throws Exception { } """; - McpSchema.Tool tool = new McpSchema.Tool("addressTool", "Handles addresses", complexSchemaJson); + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("addressTool") + .title("Handles addresses") + .inputSchema(JSON_MAPPER, complexSchemaJson) + .build(); // Serialize the tool to a string - String serialized = mapper.writeValueAsString(tool); + String serialized = JSON_MAPPER.writeValueAsString(tool); // Deserialize back to a Tool object - McpSchema.Tool deserializedTool = mapper.readValue(serialized, McpSchema.Tool.class); + McpSchema.Tool deserializedTool = JSON_MAPPER.readValue(serialized, McpSchema.Tool.class); // Serialize again and compare with first serialization - String serializedAgain = mapper.writeValueAsString(deserializedTool); + String serializedAgain = JSON_MAPPER.writeValueAsString(deserializedTool); // The two serialized strings should be the same assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); @@ -841,11 +866,16 @@ void testToolWithMeta() throws Exception { } """; - McpSchema.JsonSchema schema = mapper.readValue(schemaJson, McpSchema.JsonSchema.class); + McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); Map meta = Map.of("metaKey", "metaValue"); - McpSchema.Tool tool = new McpSchema.Tool("addressTool", "addressTool", "Handles addresses", schema, null, null, - meta); + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("addressTool") + .title("addressTool") + .description("Handles addresses") + .inputSchema(schema) + .meta(meta) + .build(); // Verify that meta value was preserved assertThat(tool.meta()).isNotNull(); @@ -871,9 +901,14 @@ void testToolWithAnnotations() throws Exception { McpSchema.ToolAnnotations annotations = new McpSchema.ToolAnnotations("A test tool", false, false, false, false, false); - McpSchema.Tool tool = new McpSchema.Tool("test-tool", "A test tool", schemaJson, annotations); + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, schemaJson) + .annotations(annotations) + .build(); - String value = mapper.writeValueAsString(tool); + String value = JSON_MAPPER.writeValueAsString(tool); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -934,9 +969,14 @@ void testToolWithOutputSchema() throws Exception { } """; - McpSchema.Tool tool = new McpSchema.Tool("test-tool", "A test tool", inputSchemaJson, outputSchemaJson, null); + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, inputSchemaJson) + .outputSchema(JSON_MAPPER, outputSchemaJson) + .build(); - String value = mapper.writeValueAsString(tool); + String value = JSON_MAPPER.writeValueAsString(tool); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -996,10 +1036,15 @@ void testToolWithOutputSchemaAndAnnotations() throws Exception { McpSchema.ToolAnnotations annotations = new McpSchema.ToolAnnotations("A test tool with output", true, false, true, false, true); - McpSchema.Tool tool = new McpSchema.Tool("test-tool", "A test tool", inputSchemaJson, outputSchemaJson, - annotations); + McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("A test tool") + .inputSchema(JSON_MAPPER, inputSchemaJson) + .outputSchema(JSON_MAPPER, outputSchemaJson) + .annotations(annotations) + .build(); - String value = mapper.writeValueAsString(tool); + String value = JSON_MAPPER.writeValueAsString(tool); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1063,7 +1108,7 @@ void testToolDeserialization() throws Exception { } """; - McpSchema.Tool tool = mapper.readValue(toolJson, McpSchema.Tool.class); + McpSchema.Tool tool = JSON_MAPPER.readValue(toolJson, McpSchema.Tool.class); assertThat(tool).isNotNull(); assertThat(tool.name()).isEqualTo("test-tool"); @@ -1097,7 +1142,7 @@ void testToolDeserializationWithoutOutputSchema() throws Exception { } """; - McpSchema.Tool tool = mapper.readValue(toolJson, McpSchema.Tool.class); + McpSchema.Tool tool = JSON_MAPPER.readValue(toolJson, McpSchema.Tool.class); assertThat(tool).isNotNull(); assertThat(tool.name()).isEqualTo("test-tool"); @@ -1115,7 +1160,7 @@ void testCallToolRequest() throws Exception { McpSchema.CallToolRequest request = new McpSchema.CallToolRequest("test-tool", arguments); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1127,14 +1172,14 @@ void testCallToolRequest() throws Exception { @Test void testCallToolRequestJsonArguments() throws Exception { - McpSchema.CallToolRequest request = new McpSchema.CallToolRequest("test-tool", """ + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(JSON_MAPPER, "test-tool", """ { "name": "test", "value": 42 } """); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1151,7 +1196,7 @@ void testCallToolRequestWithMeta() throws Exception { .arguments(Map.of("name", "test", "value", 42)) .progressToken("tool-progress-123") .build(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1170,14 +1215,18 @@ void testCallToolRequestBuilderWithJsonArguments() throws Exception { Map meta = new HashMap<>(); meta.put("progressToken", "json-builder-789"); - McpSchema.CallToolRequest request = McpSchema.CallToolRequest.builder().name("test-tool").arguments(""" - { - "name": "test", - "value": 42 - } - """).meta(meta).build(); + McpSchema.CallToolRequest request = McpSchema.CallToolRequest.builder() + .name("test-tool") + .arguments(JSON_MAPPER, """ + { + "name": "test", + "value": 42 + } + """) + .meta(meta) + .build(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1206,9 +1255,11 @@ void testCallToolRequestBuilderNameRequired() { void testCallToolResult() throws Exception { McpSchema.TextContent content = new McpSchema.TextContent("Tool execution result"); - McpSchema.CallToolResult result = new McpSchema.CallToolResult(Collections.singletonList(content), false); + McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() + .content(Collections.singletonList(content)) + .build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1224,7 +1275,7 @@ void testCallToolResultBuilder() throws Exception { .isError(false) .build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1236,7 +1287,7 @@ void testCallToolResultBuilder() throws Exception { @Test void testCallToolResultBuilderWithMultipleContents() throws Exception { McpSchema.TextContent textContent = new McpSchema.TextContent("Text result"); - McpSchema.ImageContent imageContent = new McpSchema.ImageContent(null, null, "base64data", "image/png"); + McpSchema.ImageContent imageContent = new McpSchema.ImageContent(null, "base64data", "image/png"); McpSchema.CallToolResult result = McpSchema.CallToolResult.builder() .addContent(textContent) @@ -1244,7 +1295,7 @@ void testCallToolResultBuilderWithMultipleContents() throws Exception { .isError(false) .build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1257,12 +1308,12 @@ void testCallToolResultBuilderWithMultipleContents() throws Exception { @Test void testCallToolResultBuilderWithContentList() throws Exception { McpSchema.TextContent textContent = new McpSchema.TextContent("Text result"); - McpSchema.ImageContent imageContent = new McpSchema.ImageContent(null, null, "base64data", "image/png"); + McpSchema.ImageContent imageContent = new McpSchema.ImageContent(null, "base64data", "image/png"); List contents = Arrays.asList(textContent, imageContent); McpSchema.CallToolResult result = McpSchema.CallToolResult.builder().content(contents).isError(true).build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1279,7 +1330,7 @@ void testCallToolResultBuilderWithErrorResult() throws Exception { .isError(true) .build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1288,27 +1339,6 @@ void testCallToolResultBuilderWithErrorResult() throws Exception { {"content":[{"type":"text","text":"Error: Operation failed"}],"isError":true}""")); } - @Test - void testCallToolResultStringConstructor() throws Exception { - // Test the existing string constructor alongside the builder - McpSchema.CallToolResult result1 = new McpSchema.CallToolResult("Simple result", false); - McpSchema.CallToolResult result2 = McpSchema.CallToolResult.builder() - .addTextContent("Simple result") - .isError(false) - .build(); - - String value1 = mapper.writeValueAsString(result1); - String value2 = mapper.writeValueAsString(result2); - - // Both should produce the same JSON - assertThat(value1).isEqualTo(value2); - assertThatJson(value1).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"content":[{"type":"text","text":"Simple result"}],"isError":false}""")); - } - // Sampling Tests @Test @@ -1336,7 +1366,7 @@ void testCreateMessageRequest() throws Exception { .metadata(metadata) .build(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1357,7 +1387,7 @@ void testCreateMessageResult() throws Exception { .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) .build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1372,7 +1402,7 @@ void testCreateMessageResultUnknownStopReason() throws Exception { String input = """ {"role":"assistant","content":{"type":"text","text":"Assistant response"},"model":"gpt-4","stopReason":"arbitrary value"}"""; - McpSchema.CreateMessageResult value = mapper.readValue(input, McpSchema.CreateMessageResult.class); + McpSchema.CreateMessageResult value = JSON_MAPPER.readValue(input, McpSchema.CreateMessageResult.class); McpSchema.TextContent expectedContent = new McpSchema.TextContent("Assistant response"); McpSchema.CreateMessageResult expected = McpSchema.CreateMessageResult.builder() @@ -1393,7 +1423,7 @@ void testCreateElicitationRequest() throws Exception { Map.of("foo", Map.of("type", "string")))) .build(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1409,7 +1439,7 @@ void testCreateElicitationResult() throws Exception { .message(McpSchema.ElicitResult.Action.ACCEPT) .build(); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1432,7 +1462,7 @@ void testElicitRequestWithMeta() throws Exception { .meta(meta) .build(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1449,7 +1479,7 @@ void testElicitRequestWithMeta() throws Exception { void testPaginatedRequestNoArgs() throws Exception { McpSchema.PaginatedRequest request = new McpSchema.PaginatedRequest(); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1465,7 +1495,7 @@ void testPaginatedRequestNoArgs() throws Exception { void testPaginatedRequestWithCursor() throws Exception { McpSchema.PaginatedRequest request = new McpSchema.PaginatedRequest("cursor123"); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1484,7 +1514,7 @@ void testPaginatedRequestWithMeta() throws Exception { McpSchema.PaginatedRequest request = new McpSchema.PaginatedRequest("cursor123", meta); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1498,7 +1528,7 @@ void testPaginatedRequestWithMeta() throws Exception { @Test void testPaginatedRequestDeserialization() throws Exception { - McpSchema.PaginatedRequest request = mapper.readValue(""" + McpSchema.PaginatedRequest request = JSON_MAPPER.readValue(""" {"cursor":"test-cursor","_meta":{"progressToken":"test-token"}}""", McpSchema.PaginatedRequest.class); assertThat(request.cursor()).isEqualTo("test-cursor"); @@ -1516,7 +1546,7 @@ void testCompleteRequest() throws Exception { McpSchema.CompleteRequest request = new McpSchema.CompleteRequest(promptRef, argument); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1540,7 +1570,7 @@ void testCompleteRequestWithMeta() throws Exception { McpSchema.CompleteRequest request = new McpSchema.CompleteRequest(resourceRef, argument, meta, null); - String value = mapper.writeValueAsString(request); + String value = JSON_MAPPER.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1559,7 +1589,7 @@ void testCompleteRequestWithMeta() throws Exception { void testRoot() throws Exception { McpSchema.Root root = new McpSchema.Root("file:///path/to/root", "Test Root", Map.of("metaKey", "metaValue")); - String value = mapper.writeValueAsString(root); + String value = JSON_MAPPER.writeValueAsString(root); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1575,7 +1605,7 @@ void testListRootsResult() throws Exception { McpSchema.ListRootsResult result = new McpSchema.ListRootsResult(Arrays.asList(root1, root2), "next-cursor"); - String value = mapper.writeValueAsString(result); + String value = JSON_MAPPER.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) @@ -1586,6 +1616,107 @@ void testListRootsResult() throws Exception { } + // Elicitation Capability Tests (Issue #724) + + @Test + void testElicitationCapabilityWithFormField() throws Exception { + // Test that elicitation with "form" field can be deserialized (2025-11-25 spec) + String json = """ + {"protocolVersion":"2024-11-05","capabilities":{"elicitation":{"form":{}}},"clientInfo":{"name":"test-client","version":"1.0.0"}} + """; + + McpSchema.InitializeRequest request = JSON_MAPPER.readValue(json, McpSchema.InitializeRequest.class); + + assertThat(request).isNotNull(); + assertThat(request.capabilities()).isNotNull(); + assertThat(request.capabilities().elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityWithFormAndUrlFields() throws Exception { + // Test that elicitation with both "form" and "url" fields can be deserialized + String json = """ + {"protocolVersion":"2024-11-05","capabilities":{"elicitation":{"form":{},"url":{}}},"clientInfo":{"name":"test-client","version":"1.0.0"}} + """; + + McpSchema.InitializeRequest request = JSON_MAPPER.readValue(json, McpSchema.InitializeRequest.class); + + assertThat(request).isNotNull(); + assertThat(request.capabilities()).isNotNull(); + assertThat(request.capabilities().elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityBackwardCompatibilityEmptyObject() throws Exception { + // Test backward compatibility: empty elicitation {} should still work + String json = """ + {"protocolVersion":"2024-11-05","capabilities":{"elicitation":{}},"clientInfo":{"name":"test-client","version":"1.0.0"}} + """; + + McpSchema.InitializeRequest request = JSON_MAPPER.readValue(json, McpSchema.InitializeRequest.class); + + assertThat(request).isNotNull(); + assertThat(request.capabilities()).isNotNull(); + assertThat(request.capabilities().elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityBuilderBackwardCompatibility() throws Exception { + // Test that the existing builder API still works and produces valid JSON + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder().elicitation().build(); + + assertThat(capabilities.elicitation()).isNotNull(); + + // Serialize and verify it produces valid JSON (should be {} for backward compat) + String json = JSON_MAPPER.writeValueAsString(capabilities); + assertThat(json).contains("\"elicitation\""); + } + + @Test + void testElicitationCapabilitySerializationRoundTrip() throws Exception { + // Test that serialization and deserialization round-trip works + McpSchema.ClientCapabilities original = McpSchema.ClientCapabilities.builder().elicitation().build(); + + String json = JSON_MAPPER.writeValueAsString(original); + McpSchema.ClientCapabilities deserialized = JSON_MAPPER.readValue(json, McpSchema.ClientCapabilities.class); + + assertThat(deserialized.elicitation()).isNotNull(); + } + + @Test + void testElicitationCapabilityBuilderWithFormAndUrl() throws Exception { + // Test the new builder method that explicitly sets form and url support + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder() + .elicitation(true, true) + .build(); + + assertThat(capabilities.elicitation()).isNotNull(); + assertThat(capabilities.elicitation().form()).isNotNull(); + assertThat(capabilities.elicitation().url()).isNotNull(); + + // Verify serialization produces the expected JSON + String json = JSON_MAPPER.writeValueAsString(capabilities); + assertThatJson(json).when(Option.IGNORING_ARRAY_ORDER).isObject().containsKey("elicitation"); + assertThat(json).contains("\"form\""); + assertThat(json).contains("\"url\""); + } + + @Test + void testElicitationCapabilityBuilderFormOnly() throws Exception { + // Test builder with form only + McpSchema.ClientCapabilities capabilities = McpSchema.ClientCapabilities.builder() + .elicitation(true, false) + .build(); + + assertThat(capabilities.elicitation()).isNotNull(); + assertThat(capabilities.elicitation().form()).isNotNull(); + assertThat(capabilities.elicitation().url()).isNull(); + + String json = JSON_MAPPER.writeValueAsString(capabilities); + assertThat(json).contains("\"form\""); + assertThat(json).doesNotContain("\"url\""); + } + // Progress Notification Tests @Test @@ -1593,7 +1724,7 @@ void testProgressNotificationWithMessage() throws Exception { McpSchema.ProgressNotification notification = new McpSchema.ProgressNotification("progress-token-123", 0.5, 1.0, "Processing file 1 of 2", Map.of("key", "value")); - String value = mapper.writeValueAsString(notification); + String value = JSON_MAPPER.writeValueAsString(notification); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() @@ -1604,7 +1735,7 @@ void testProgressNotificationWithMessage() throws Exception { @Test void testProgressNotificationDeserialization() throws Exception { - McpSchema.ProgressNotification notification = mapper.readValue( + McpSchema.ProgressNotification notification = JSON_MAPPER.readValue( """ {"progressToken":"token-456","progress":0.75,"total":1.0,"message":"Almost done","_meta":{"key":"value"}}""", McpSchema.ProgressNotification.class); @@ -1621,7 +1752,7 @@ void testProgressNotificationWithoutMessage() throws Exception { McpSchema.ProgressNotification notification = new McpSchema.ProgressNotification("progress-token-789", 0.25, null, null); - String value = mapper.writeValueAsString(notification); + String value = JSON_MAPPER.writeValueAsString(notification); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) .isObject() diff --git a/mcp/pom.xml b/mcp/pom.xml index 1cf61c48f..937974228 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.12.0-SNAPSHOT + 1.1.0-SNAPSHOT mcp jar @@ -20,204 +20,20 @@ git@github.com/modelcontextprotocol/java-sdk.git - - - - biz.aQute.bnd - bnd-maven-plugin - ${bnd-maven-plugin.version} - - - bnd-process - - bnd-process - - - - - - - - - - - org.apache.maven.plugins - maven-jar-plugin - - - ${project.build.outputDirectory}/META-INF/MANIFEST.MF - - - - - - - org.slf4j - slf4j-api - ${slf4j-api.version} - - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - - - - io.projectreactor - reactor-core - - - - com.networknt - json-schema-validator - ${json-schema-validator.version} - - - - - jakarta.servlet - jakarta.servlet-api - ${jakarta.servlet.version} - provided - - - - - - org.springframework - spring-webmvc - ${springframework.version} - test - - - - - io.projectreactor.netty - reactor-netty-http - test - - - - - org.springframework - spring-context - ${springframework.version} - test - - - - org.springframework - spring-test - ${springframework.version} - test - - - - org.assertj - assertj-core - ${assert4j.version} - test - - - org.junit.jupiter - junit-jupiter-api - ${junit.version} - test - - - org.junit.jupiter - junit-jupiter-params - ${junit.version} - test - - - org.mockito - mockito-core - ${mockito.version} - test - - - - - net.bytebuddy - byte-buddy - ${byte-buddy.version} - test - - - io.projectreactor - reactor-test - test - - - org.testcontainers - junit-jupiter - ${testcontainers.version} - test - - - - org.awaitility - awaitility - ${awaitility.version} - test - - - - ch.qos.logback - logback-classic - ${logback.version} - test - - - - net.javacrumbs.json-unit - json-unit-assertj - ${json-unit-assertj.version} - test + io.modelcontextprotocol.sdk + mcp-json-jackson3 + 1.1.0-SNAPSHOT - - org.apache.tomcat.embed - tomcat-embed-core - ${tomcat.version} - test + io.modelcontextprotocol.sdk + mcp-core + 1.1.0-SNAPSHOT - - org.apache.tomcat.embed - tomcat-embed-websocket - ${tomcat.version} - test - - - - org.testcontainers - toxiproxy - ${toxiproxy.version} - test - - - - \ No newline at end of file + diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java deleted file mode 100644 index 72b6e6c1b..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - */ - -package io.modelcontextprotocol.client.transport; - -import java.net.URI; -import java.net.http.HttpRequest; -import reactor.util.annotation.Nullable; - -/** - * Customize {@link HttpRequest.Builder} before executing the request, either in SSE or - * Streamable HTTP transport. - * - * @author Daniel Garnier-Moiroux - */ -public interface SyncHttpRequestCustomizer { - - void customize(HttpRequest.Builder builder, String method, URI endpoint, @Nullable String body); - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java deleted file mode 100644 index 9e18e189d..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -/** - * Default implementation for {@link McpTransportContext} which uses a Thread-safe map. - * Objects of this kind are mutable. - * - * @author Dariusz JΔ™drzejczyk - */ -public class DefaultMcpTransportContext implements McpTransportContext { - - private final Map storage; - - /** - * Create an empty instance. - */ - public DefaultMcpTransportContext() { - this.storage = new ConcurrentHashMap<>(); - } - - DefaultMcpTransportContext(Map storage) { - this.storage = storage; - } - - @Override - public Object get(String key) { - return this.storage.get(key); - } - - @Override - public void put(String key, Object value) { - this.storage.put(key, value); - } - - /** - * Allows copying the contents. - * @return new instance with the copy of the underlying map - */ - public McpTransportContext copy() { - return new DefaultMcpTransportContext(new ConcurrentHashMap<>(this.storage)); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java b/mcp/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java deleted file mode 100644 index 65b80957c..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - */ - -package io.modelcontextprotocol.spec; - -/** - * Names of HTTP headers in use by MCP HTTP transports. - * - * @author Dariusz JΔ™drzejczyk - */ -public interface HttpHeaders { - - /** - * Identifies individual MCP sessions. - */ - String MCP_SESSION_ID = "mcp-session-id"; - - /** - * Identifies events within an SSE Stream. - */ - String LAST_EVENT_ID = "Last-Event-ID"; - - /** - * Identifies the MCP protocol version. - */ - String PROTOCOL_VERSION = "MCP-Protocol-Version"; - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java deleted file mode 100644 index 6172d8637..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java +++ /dev/null @@ -1,61 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ - -package io.modelcontextprotocol.spec; - -import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse.JSONRPCError; -import io.modelcontextprotocol.util.Assert; - -public class McpError extends RuntimeException { - - private JSONRPCError jsonRpcError; - - public McpError(JSONRPCError jsonRpcError) { - super(jsonRpcError.message()); - this.jsonRpcError = jsonRpcError; - } - - @Deprecated - public McpError(Object error) { - super(error.toString()); - } - - public JSONRPCError getJsonRpcError() { - return jsonRpcError; - } - - public static Builder builder(int errorCode) { - return new Builder(errorCode); - } - - public static class Builder { - - private final int code; - - private String message; - - private Object data; - - private Builder(int code) { - this.code = code; - } - - public Builder message(String message) { - this.message = message; - return this; - } - - public Builder data(Object data) { - this.data = data; - return this; - } - - public McpError build() { - Assert.hasText(message, "message must not be empty"); - return new McpError(new JSONRPCError(code, message, data)); - } - - } - -} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java deleted file mode 100644 index ec23e21dc..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import eu.rekawek.toxiproxy.Proxy; -import eu.rekawek.toxiproxy.ToxiproxyClient; -import eu.rekawek.toxiproxy.model.ToxicDirection; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpTransport; -import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.Network; -import org.testcontainers.containers.ToxiproxyContainer; -import org.testcontainers.containers.wait.strategy.Wait; -import reactor.test.StepVerifier; - -import java.io.IOException; -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; -import java.util.function.Function; - -import static org.assertj.core.api.Assertions.assertThatCode; - -/** - * Resiliency test suite for the {@link McpAsyncClient} that can be used with different - * {@link McpTransport} implementations that support Streamable HTTP. - * - * The purpose of these tests is to allow validating the transport layer resiliency - * instead of the functionality offered by the logical layer of MCP concepts such as - * tools, resources, prompts, etc. - * - * @author Dariusz JΔ™drzejczyk - */ -// KEEP IN SYNC with the class in mcp-test module -public abstract class AbstractMcpAsyncClientResiliencyTests { - - private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class); - - static Network network = Network.newNetwork(); - static String host = "http://localhost:3001"; - - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image - @SuppressWarnings("resource") - static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js streamableHttp") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withNetwork(network) - .withNetworkAliases("everything-server") - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) - .withExposedPorts(8474, 3000); - - static Proxy proxy; - - static { - container.start(); - - toxiproxy.start(); - - final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); - try { - proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); - } - catch (IOException e) { - throw new RuntimeException("Can't create proxy!", e); - } - - final String ipAddressViaToxiproxy = toxiproxy.getHost(); - final int portViaToxiproxy = toxiproxy.getMappedPort(3000); - - host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; - } - - static void disconnect() { - long start = System.nanoTime(); - try { - // disconnect - // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM", - // ToxicDirection.DOWNSTREAM, 0); - // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM", - // ToxicDirection.UPSTREAM, 0); - proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); - proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); - logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); - } - catch (IOException e) { - throw new RuntimeException("Failed to disconnect", e); - } - } - - static void reconnect() { - long start = System.nanoTime(); - try { - proxy.toxics().get("RESET_UPSTREAM").remove(); - proxy.toxics().get("RESET_DOWNSTREAM").remove(); - // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove(); - // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove(); - logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); - } - catch (IOException e) { - throw new RuntimeException("Failed to reconnect", e); - } - } - - static void restartMcpServer() { - container.stop(); - container.start(); - } - - abstract McpClientTransport createMcpTransport(); - - protected Duration getRequestTimeout() { - return Duration.ofSeconds(14); - } - - protected Duration getInitializationTimeout() { - return Duration.ofSeconds(2); - } - - McpAsyncClient client(McpClientTransport transport) { - return client(transport, Function.identity()); - } - - McpAsyncClient client(McpClientTransport transport, Function customizer) { - AtomicReference client = new AtomicReference<>(); - - assertThatCode(() -> { - McpClient.AsyncSpec builder = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .initializationTimeout(getInitializationTimeout()) - .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); - builder = customizer.apply(builder); - client.set(builder.build()); - }).doesNotThrowAnyException(); - - return client.get(); - } - - void withClient(McpClientTransport transport, Consumer c) { - withClient(transport, Function.identity(), c); - } - - void withClient(McpClientTransport transport, Function customizer, - Consumer c) { - var client = client(transport, customizer); - try { - c.accept(client); - } - finally { - StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); - } - } - - @Test - void testPing() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); - - disconnect(); - - StepVerifier.create(mcpAsyncClient.ping()).expectError().verify(); - - reconnect(); - - StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); - }); - } - - @Test - void testSessionInvalidation() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); - - restartMcpServer(); - - // The first try will face the session mismatch exception and the second one - // will go through the re-initialization process. - StepVerifier.create(mcpAsyncClient.ping().retry(1)).expectNextCount(1).verifyComplete(); - }); - } - - @Test - void testCallTool() { - withClient(createMcpTransport(), mcpAsyncClient -> { - AtomicReference> tools = new AtomicReference<>(); - StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); - StepVerifier.create(mcpAsyncClient.listTools()) - .consumeNextWith(list -> tools.set(list.tools())) - .verifyComplete(); - - disconnect(); - - String name = tools.get().get(0).name(); - // Assuming this is the echo tool - McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(name, Map.of("message", "hello")); - StepVerifier.create(mcpAsyncClient.callTool(request)).expectError().verify(); - - reconnect(); - - StepVerifier.create(mcpAsyncClient.callTool(request)).expectNextCount(1).verifyComplete(); - }); - } - - @Test - void testSessionClose() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); - // In case of Streamable HTTP this call should issue a HTTP DELETE request - // invalidating the session - StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); - // The next use should immediately re-initialize with no issue and send the - // request without any broken connections. - StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); - }); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java deleted file mode 100644 index 3626d8ca0..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ /dev/null @@ -1,835 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assertions.fail; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; -import java.util.function.Function; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; - -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents; -import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; -import io.modelcontextprotocol.spec.McpSchema.ElicitResult; -import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.ResourceContents; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest; -import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest; -import io.modelcontextprotocol.spec.McpTransport; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.test.StepVerifier; - -/** - * Test suite for the {@link McpAsyncClient} that can be used with different - * {@link McpTransport} implementations. - * - * @author Christian Tzolov - * @author Dariusz JΔ™drzejczyk - */ -// KEEP IN SYNC with the class in mcp-test module -public abstract class AbstractMcpAsyncClientTests { - - private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - - abstract protected McpClientTransport createMcpTransport(); - - protected void onStart() { - } - - protected void onClose() { - } - - protected Duration getRequestTimeout() { - return Duration.ofSeconds(14); - } - - protected Duration getInitializationTimeout() { - return Duration.ofSeconds(2); - } - - McpAsyncClient client(McpClientTransport transport) { - return client(transport, Function.identity()); - } - - McpAsyncClient client(McpClientTransport transport, Function customizer) { - AtomicReference client = new AtomicReference<>(); - - assertThatCode(() -> { - McpClient.AsyncSpec builder = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .initializationTimeout(getInitializationTimeout()) - .sampling(req -> Mono.just(new CreateMessageResult(McpSchema.Role.USER, - new McpSchema.TextContent("Oh, hi!"), "modelId", CreateMessageResult.StopReason.END_TURN))) - .capabilities(ClientCapabilities.builder().roots(true).sampling().build()); - builder = customizer.apply(builder); - client.set(builder.build()); - }).doesNotThrowAnyException(); - - return client.get(); - } - - void withClient(McpClientTransport transport, Consumer c) { - withClient(transport, Function.identity(), c); - } - - void withClient(McpClientTransport transport, Function customizer, - Consumer c) { - var client = client(transport, customizer); - try { - c.accept(client); - } - finally { - StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); - } - } - - @BeforeEach - void setUp() { - onStart(); - } - - @AfterEach - void tearDown() { - onClose(); - } - - void verifyNotificationSucceedsWithImplicitInitialization(Function> operation, - String action) { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(operation.apply(mcpAsyncClient)).verifyComplete(); - }); - } - - void verifyCallSucceedsWithImplicitInitialization(Function> operation, String action) { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(operation.apply(mcpAsyncClient)).expectNextCount(1).verifyComplete(); - }); - } - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Request timeout must not be null"); - } - - @Test - void testListToolsWithoutInitialization() { - verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(McpSchema.FIRST_PAGE), "listing tools"); - } - - @Test - void testListTools() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(McpSchema.FIRST_PAGE))) - .consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); - - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }) - .verifyComplete(); - }); - } - - @Test - void testListAllTools() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools())) - .consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); - - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }) - .verifyComplete(); - }); - } - - @Test - void testListAllToolsReturnsImmutableList() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools())) - .consumeNextWith(result -> { - assertThat(result.tools()).isNotNull(); - // Verify that the returned list is immutable - assertThatThrownBy(() -> result.tools().add(new Tool("test", "test", "{\"type\":\"object\"}"))) - .isInstanceOf(UnsupportedOperationException.class); - }) - .verifyComplete(); - }); - } - - @Test - void testPingWithoutInitialization() { - verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); - } - - @Test - void testPing() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())) - .expectNextCount(1) - .verifyComplete(); - }); - } - - @Test - void testCallToolWithoutInitialization() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); - } - - @Test - void testCallTool() { - withClient(createMcpTransport(), mcpAsyncClient -> { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) - .consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); - }); - }) - .verifyComplete(); - }); - } - - @Test - void testCallToolWithInvalidTool() { - withClient(createMcpTransport(), mcpAsyncClient -> { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", - Map.of("message", ECHO_TEST_MESSAGE)); - - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) - .consumeErrorWith( - e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool")) - .verify(); - }); - } - - @ParameterizedTest - @ValueSource(strings = { "success", "error", "debug" }) - void testCallToolWithMessageAnnotations(String messageType) { - McpClientTransport transport = createMcpTransport(); - - withClient(transport, mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize() - .then(mcpAsyncClient.callTool(new McpSchema.CallToolRequest("annotatedMessage", - Map.of("messageType", messageType, "includeImage", true))))) - .consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.isError()).isNotEqualTo(true); - assertThat(result.content()).isNotEmpty(); - assertThat(result.content()).allSatisfy(content -> { - switch (content.type()) { - case "text": - McpSchema.TextContent textContent = assertInstanceOf(McpSchema.TextContent.class, - content); - assertThat(textContent.text()).isNotEmpty(); - assertThat(textContent.annotations()).isNotNull(); - - switch (messageType) { - case "error": - assertThat(textContent.annotations().priority()).isEqualTo(1.0); - assertThat(textContent.annotations().audience()) - .containsOnly(McpSchema.Role.USER, McpSchema.Role.ASSISTANT); - break; - case "success": - assertThat(textContent.annotations().priority()).isEqualTo(0.7); - assertThat(textContent.annotations().audience()) - .containsExactly(McpSchema.Role.USER); - break; - case "debug": - assertThat(textContent.annotations().priority()).isEqualTo(0.3); - assertThat(textContent.annotations().audience()) - .containsExactly(McpSchema.Role.ASSISTANT); - break; - default: - throw new IllegalStateException("Unexpected value: " + content.type()); - } - break; - case "image": - McpSchema.ImageContent imageContent = assertInstanceOf(McpSchema.ImageContent.class, - content); - assertThat(imageContent.data()).isNotEmpty(); - assertThat(imageContent.annotations()).isNotNull(); - assertThat(imageContent.annotations().priority()).isEqualTo(0.5); - assertThat(imageContent.annotations().audience()).containsExactly(McpSchema.Role.USER); - break; - default: - fail("Unexpected content type: " + content.type()); - } - }); - }) - .verifyComplete(); - }); - } - - @Test - void testListResourcesWithoutInitialization() { - verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(McpSchema.FIRST_PAGE), - "listing resources"); - } - - @Test - void testListResources() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(McpSchema.FIRST_PAGE))) - .consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }) - .verifyComplete(); - }); - } - - @Test - void testListAllResources() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources())) - .consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }) - .verifyComplete(); - }); - } - - @Test - void testListAllResourcesReturnsImmutableList() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources())) - .consumeNextWith(result -> { - assertThat(result.resources()).isNotNull(); - // Verify that the returned list is immutable - assertThatThrownBy( - () -> result.resources().add(Resource.builder().uri("test://uri").name("test").build())) - .isInstanceOf(UnsupportedOperationException.class); - }) - .verifyComplete(); - }); - } - - @Test - void testMcpAsyncClientState() { - withClient(createMcpTransport(), mcpAsyncClient -> { - assertThat(mcpAsyncClient).isNotNull(); - }); - } - - @Test - void testListPromptsWithoutInitialization() { - verifyCallSucceedsWithImplicitInitialization(client -> client.listPrompts(McpSchema.FIRST_PAGE), - "listing " + "prompts"); - } - - @Test - void testListPrompts() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(McpSchema.FIRST_PAGE))) - .consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }) - .verifyComplete(); - }); - } - - @Test - void testListAllPrompts() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts())) - .consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }) - .verifyComplete(); - }); - } - - @Test - void testListAllPromptsReturnsImmutableList() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts())) - .consumeNextWith(result -> { - assertThat(result.prompts()).isNotNull(); - // Verify that the returned list is immutable - assertThatThrownBy(() -> result.prompts().add(new Prompt("test", "test", "test", null))) - .isInstanceOf(UnsupportedOperationException.class); - }) - .verifyComplete(); - }); - } - - @Test - void testGetPromptWithoutInitialization() { - GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - verifyCallSucceedsWithImplicitInitialization(client -> client.getPrompt(request), "getting " + "prompts"); - } - - @Test - void testGetPrompt() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier - .create(mcpAsyncClient.initialize() - .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))) - .consumeNextWith(prompt -> { - assertThat(prompt).isNotNull().satisfies(result -> { - assertThat(result.messages()).isNotEmpty(); - assertThat(result.messages()).hasSize(1); - }); - }) - .verifyComplete(); - }); - } - - @Test - void testRootsListChangedWithoutInitialization() { - verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), - "sending roots list changed notification"); - } - - @Test - void testRootsListChanged() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) - .verifyComplete(); - }); - } - - @Test - void testInitializeWithRootsListProviders() { - withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), - client -> { - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); - }); - } - - @Test - void testAddRoot() { - withClient(createMcpTransport(), mcpAsyncClient -> { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); - }); - } - - @Test - void testAddRootWithNullValue() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.addRoot(null)) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Root must not be null")) - .verify(); - }); - } - - @Test - void testRemoveRoot() { - withClient(createMcpTransport(), mcpAsyncClient -> { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete(); - - StepVerifier.create(mcpAsyncClient.removeRoot(root.uri())).verifyComplete(); - }); - } - - @Test - void testRemoveNonExistentRoot() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) - .consumeErrorWith(e -> assertThat(e).isInstanceOf(IllegalStateException.class) - .hasMessage("Root with uri 'nonexistent-uri' not found")) - .verify(); - }); - } - - @Test - void testReadResource() { - withClient(createMcpTransport(), client -> { - Flux resources = client.initialize() - .then(client.listResources(null)) - .flatMapMany(r -> Flux.fromIterable(r.resources())) - .flatMap(r -> client.readResource(r)); - - StepVerifier.create(resources).recordWith(ArrayList::new).consumeRecordedWith(readResourceResults -> { - - for (ReadResourceResult result : readResourceResults) { - - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull().isNotEmpty(); - - // Validate each content item - for (ResourceContents content : result.contents()) { - assertThat(content).isNotNull(); - assertThat(content.uri()).isNotNull().isNotEmpty(); - assertThat(content.mimeType()).isNotNull().isNotEmpty(); - - // Validate content based on its type with more comprehensive - // checks - switch (content.mimeType()) { - case "text/plain" -> { - TextResourceContents textContent = assertInstanceOf(TextResourceContents.class, - content); - assertThat(textContent.text()).isNotNull().isNotEmpty(); - assertThat(textContent.uri()).isNotEmpty(); - } - case "application/octet-stream" -> { - BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class, - content); - assertThat(blobContent.blob()).isNotNull().isNotEmpty(); - assertThat(blobContent.uri()).isNotNull().isNotEmpty(); - // Validate base64 encoding format - assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$"); - } - default -> { - - // Still validate basic properties - if (content instanceof TextResourceContents textContent) { - assertThat(textContent.text()).isNotNull(); - } - else if (content instanceof BlobResourceContents blobContent) { - assertThat(blobContent.blob()).isNotNull(); - } - } - } - } - } - }) - .expectNextCount(10) // Expect 10 elements - .verifyComplete(); - }); - } - - @Test - void testListResourceTemplatesWithoutInitialization() { - verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(McpSchema.FIRST_PAGE), - "listing resource templates"); - } - - @Test - void testListResourceTemplates() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier - .create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates(McpSchema.FIRST_PAGE))) - .consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }) - .verifyComplete(); - }); - } - - @Test - void testListAllResourceTemplates() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) - .consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }) - .verifyComplete(); - }); - } - - @Test - void testListAllResourceTemplatesReturnsImmutableList() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) - .consumeNextWith(result -> { - assertThat(result.resourceTemplates()).isNotNull(); - // Verify that the returned list is immutable - assertThatThrownBy(() -> result.resourceTemplates() - .add(new McpSchema.ResourceTemplate("test://template", "test", "test", null, null, null))) - .isInstanceOf(UnsupportedOperationException.class); - }) - .verifyComplete(); - }); - } - - // @Test - void testResourceSubscription() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - - // Test subscribe - StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .verifyComplete(); - - // Test unsubscribe - StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .verifyComplete(); - } - }).verifyComplete(); - }); - } - - @Test - void testNotificationHandlers() { - AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); - AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); - AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - - withClient(createMcpTransport(), - builder -> builder - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) - .resourcesChangeConsumer( - resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) - .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))), - mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.initialize()) - .expectNextMatches(Objects::nonNull) - .verifyComplete(); - }); - } - - @Test - void testInitializeWithSamplingCapability() { - ClientCapabilities capabilities = ClientCapabilities.builder().sampling().build(); - CreateMessageResult createMessageResult = CreateMessageResult.builder() - .message("test") - .model("test-model") - .build(); - withClient(createMcpTransport(), - builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)), - client -> { - StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); - }); - } - - @Test - void testInitializeWithElicitationCapability() { - ClientCapabilities capabilities = ClientCapabilities.builder().elicitation().build(); - ElicitResult elicitResult = ElicitResult.builder() - .message(ElicitResult.Action.ACCEPT) - .content(Map.of("foo", "bar")) - .build(); - withClient(createMcpTransport(), - builder -> builder.capabilities(capabilities).elicitation(request -> Mono.just(elicitResult)), - client -> { - StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); - }); - } - - @Test - void testInitializeWithAllCapabilities() { - var capabilities = ClientCapabilities.builder() - .experimental(Map.of("feature", "test")) - .roots(true) - .sampling() - .build(); - - Function> samplingHandler = request -> Mono - .just(CreateMessageResult.builder().message("test").model("test-model").build()); - - Function> elicitationHandler = request -> Mono - .just(ElicitResult.builder().message(ElicitResult.Action.ACCEPT).content(Map.of("foo", "bar")).build()); - - withClient(createMcpTransport(), - builder -> builder.capabilities(capabilities).sampling(samplingHandler).elicitation(elicitationHandler), - client -> - - StepVerifier.create(client.initialize()).assertNext(result -> { - assertThat(result).isNotNull(); - assertThat(result.capabilities()).isNotNull(); - }).verifyComplete()); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevelsWithoutInitialization() { - verifyNotificationSucceedsWithImplicitInitialization( - client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); - } - - @Test - void testLoggingLevels() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier - .create(mcpAsyncClient.initialize() - .thenMany(Flux.fromArray(McpSchema.LoggingLevel.values()).flatMap(mcpAsyncClient::setLoggingLevel))) - .verifyComplete(); - }); - } - - @Test - void testLoggingConsumer() { - AtomicBoolean logReceived = new AtomicBoolean(false); - - withClient(createMcpTransport(), - builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))), - client -> { - StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); - StepVerifier.create(client.closeGracefully()).verifyComplete(); - - }); - - } - - @Test - void testLoggingWithNullNotification() { - withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) - .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) - .verify(); - }); - } - - @Test - void testSampling() { - McpClientTransport transport = createMcpTransport(); - - final String message = "Hello, world!"; - final String response = "Goodbye, world!"; - final int maxTokens = 100; - - AtomicReference receivedPrompt = new AtomicReference<>(); - AtomicReference receivedMessage = new AtomicReference<>(); - AtomicInteger receivedMaxTokens = new AtomicInteger(); - - withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build()) - .sampling(request -> { - McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class, - request.messages().get(0).content()); - receivedPrompt.set(request.systemPrompt()); - receivedMessage.set(messageText.text()); - receivedMaxTokens.set(request.maxTokens()); - - return Mono - .just(new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response), - "modelId", McpSchema.CreateMessageResult.StopReason.END_TURN)); - }), client -> { - StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); - - StepVerifier.create(client.callTool( - new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens)))) - .consumeNextWith(result -> { - // Verify tool response to ensure our sampling response was passed - // through - assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class); - assertThat(result.content()).allSatisfy(content -> { - if (!(content instanceof McpSchema.TextContent text)) - return; - - assertThat(text.text()).endsWith(response); // Prefixed - }); - - // Verify sampling request parameters received in our callback - assertThat(receivedPrompt.get()).isNotEmpty(); - assertThat(receivedMessage.get()).endsWith(message); // Prefixed - assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens); - }) - .verifyComplete(); - }); - } - - // --------------------------------------- - // Progress Notification Tests - // --------------------------------------- - - @Test - void testProgressConsumer() { - Sinks.Many sink = Sinks.many().unicast().onBackpressureBuffer(); - List receivedNotifications = new CopyOnWriteArrayList<>(); - - withClient(createMcpTransport(), builder -> builder.progressConsumer(notification -> { - receivedNotifications.add(notification); - sink.tryEmitNext(notification); - return Mono.empty(); - }), client -> { - StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); - - // Call a tool that sends progress notifications - CallToolRequest request = CallToolRequest.builder() - .name("longRunningOperation") - .arguments(Map.of("duration", 1, "steps", 2)) - .progressToken("test-token") - .build(); - - StepVerifier.create(client.callTool(request)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - }).verifyComplete(); - - // Use StepVerifier to verify the progress notifications via the sink - StepVerifier.create(sink.asFlux()).expectNextCount(2).thenCancel().verify(Duration.ofSeconds(3)); - - assertThat(receivedNotifications).hasSize(2); - assertThat(receivedNotifications.get(0).progressToken()).isEqualTo("test-token"); - }); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java deleted file mode 100644 index c74255060..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ /dev/null @@ -1,699 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assertions.fail; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; - -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; -import java.util.function.Function; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents; -import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.ListResourceTemplatesResult; -import io.modelcontextprotocol.spec.McpSchema.ListResourcesResult; -import io.modelcontextprotocol.spec.McpSchema.ListToolsResult; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.ResourceContents; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest; -import io.modelcontextprotocol.spec.McpSchema.TextContent; -import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; -import reactor.test.StepVerifier; - -/** - * Unit tests for MCP Client Session functionality. - * - * @author Christian Tzolov - * @author Dariusz JΔ™drzejczyk - */ -// KEEP IN SYNC with the class in mcp-test module -public abstract class AbstractMcpSyncClientTests { - - private static final Logger logger = LoggerFactory.getLogger(AbstractMcpSyncClientTests.class); - - private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - - abstract protected McpClientTransport createMcpTransport(); - - protected void onStart() { - } - - protected void onClose() { - } - - protected Duration getRequestTimeout() { - return Duration.ofSeconds(14); - } - - protected Duration getInitializationTimeout() { - return Duration.ofSeconds(2); - } - - McpSyncClient client(McpClientTransport transport) { - return client(transport, Function.identity()); - } - - McpSyncClient client(McpClientTransport transport, Function customizer) { - AtomicReference client = new AtomicReference<>(); - - assertThatCode(() -> { - McpClient.SyncSpec builder = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .initializationTimeout(getInitializationTimeout()) - .capabilities(ClientCapabilities.builder().roots(true).build()); - builder = customizer.apply(builder); - client.set(builder.build()); - }).doesNotThrowAnyException(); - - return client.get(); - } - - void withClient(McpClientTransport transport, Consumer c) { - withClient(transport, Function.identity(), c); - } - - void withClient(McpClientTransport transport, Function customizer, - Consumer c) { - var client = client(transport, customizer); - try { - c.accept(client); - } - finally { - assertThat(client.closeGracefully()).isTrue(); - } - } - - @BeforeEach - void setUp() { - onStart(); - - } - - @AfterEach - void tearDown() { - onClose(); - } - - static final Object DUMMY_RETURN_VALUE = new Object(); - - void verifyNotificationSucceedsWithImplicitInitialization(Consumer operation, String action) { - verifyCallSucceedsWithImplicitInitialization(client -> { - operation.accept(client); - return DUMMY_RETURN_VALUE; - }, action); - } - - void verifyCallSucceedsWithImplicitInitialization(Function blockingOperation, String action) { - withClient(createMcpTransport(), mcpSyncClient -> { - StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)) - // Offload the blocking call to the real scheduler - .subscribeOn(Schedulers.boundedElastic())).expectNextCount(1).verifyComplete(); - }); - } - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); - - assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Request timeout must not be null"); - } - - @Test - void testListToolsWithoutInitialization() { - verifyCallSucceedsWithImplicitInitialization(client -> client.listTools(McpSchema.FIRST_PAGE), "listing tools"); - } - - @Test - void testListTools() { - withClient(createMcpTransport(), mcpSyncClient -> { - mcpSyncClient.initialize(); - ListToolsResult tools = mcpSyncClient.listTools(McpSchema.FIRST_PAGE); - - assertThat(tools).isNotNull().satisfies(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); - - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }); - }); - } - - @Test - void testListAllTools() { - withClient(createMcpTransport(), mcpSyncClient -> { - mcpSyncClient.initialize(); - ListToolsResult tools = mcpSyncClient.listTools(); - - assertThat(tools).isNotNull().satisfies(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); - - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }); - }); - } - - @Test - void testCallToolsWithoutInitialization() { - verifyCallSucceedsWithImplicitInitialization( - client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), "calling tools"); - } - - @Test - void testCallTools() { - withClient(createMcpTransport(), mcpSyncClient -> { - mcpSyncClient.initialize(); - CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); - - assertThat(toolResult).isNotNull().satisfies(result -> { - - assertThat(result.content()).hasSize(1); - - TextContent content = (TextContent) result.content().get(0); - - assertThat(content).isNotNull(); - assertThat(content.text()).isNotNull(); - assertThat(content.text()).contains("7"); - }); - }); - } - - @Test - void testPingWithoutInitialization() { - verifyCallSucceedsWithImplicitInitialization(client -> client.ping(), "pinging the server"); - } - - @Test - void testPing() { - withClient(createMcpTransport(), mcpSyncClient -> { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); - }); - } - - @Test - void testCallToolWithoutInitialization() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - verifyCallSucceedsWithImplicitInitialization(client -> client.callTool(callToolRequest), "calling tools"); - } - - @Test - void testCallTool() { - withClient(createMcpTransport(), mcpSyncClient -> { - mcpSyncClient.initialize(); - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - - CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); - - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); - }); - }); - } - - @Test - void testCallToolWithInvalidTool() { - withClient(createMcpTransport(), mcpSyncClient -> { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); - }); - } - - @ParameterizedTest - @ValueSource(strings = { "success", "error", "debug" }) - void testCallToolWithMessageAnnotations(String messageType) { - McpClientTransport transport = createMcpTransport(); - - withClient(transport, client -> { - client.initialize(); - - McpSchema.CallToolResult result = client.callTool(new McpSchema.CallToolRequest("annotatedMessage", - Map.of("messageType", messageType, "includeImage", true))); - - assertThat(result).isNotNull(); - assertThat(result.isError()).isNotEqualTo(true); - assertThat(result.content()).isNotEmpty(); - assertThat(result.content()).allSatisfy(content -> { - switch (content.type()) { - case "text": - McpSchema.TextContent textContent = assertInstanceOf(McpSchema.TextContent.class, content); - assertThat(textContent.text()).isNotEmpty(); - assertThat(textContent.annotations()).isNotNull(); - - switch (messageType) { - case "error": - assertThat(textContent.annotations().priority()).isEqualTo(1.0); - assertThat(textContent.annotations().audience()).containsOnly(McpSchema.Role.USER, - McpSchema.Role.ASSISTANT); - break; - case "success": - assertThat(textContent.annotations().priority()).isEqualTo(0.7); - assertThat(textContent.annotations().audience()).containsExactly(McpSchema.Role.USER); - break; - case "debug": - assertThat(textContent.annotations().priority()).isEqualTo(0.3); - assertThat(textContent.annotations().audience()) - .containsExactly(McpSchema.Role.ASSISTANT); - break; - default: - throw new IllegalStateException("Unexpected value: " + content.type()); - } - break; - case "image": - McpSchema.ImageContent imageContent = assertInstanceOf(McpSchema.ImageContent.class, content); - assertThat(imageContent.data()).isNotEmpty(); - assertThat(imageContent.annotations()).isNotNull(); - assertThat(imageContent.annotations().priority()).isEqualTo(0.5); - assertThat(imageContent.annotations().audience()).containsExactly(McpSchema.Role.USER); - break; - default: - fail("Unexpected content type: " + content.type()); - } - }); - }); - } - - @Test - void testRootsListChangedWithoutInitialization() { - verifyNotificationSucceedsWithImplicitInitialization(client -> client.rootsListChangedNotification(), - "sending roots list changed notification"); - } - - @Test - void testRootsListChanged() { - withClient(createMcpTransport(), mcpSyncClient -> { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); - }); - } - - @Test - void testListResourcesWithoutInitialization() { - verifyCallSucceedsWithImplicitInitialization(client -> client.listResources(McpSchema.FIRST_PAGE), - "listing resources"); - } - - @Test - void testListResources() { - withClient(createMcpTransport(), mcpSyncClient -> { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(McpSchema.FIRST_PAGE); - - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }); - } - - @Test - void testListAllResources() { - withClient(createMcpTransport(), mcpSyncClient -> { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(); - - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }); - } - - @Test - void testClientSessionState() { - withClient(createMcpTransport(), mcpSyncClient -> { - assertThat(mcpSyncClient).isNotNull(); - }); - } - - @Test - void testInitializeWithRootsListProviders() { - withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), - mcpSyncClient -> { - - assertThatCode(() -> { - mcpSyncClient.initialize(); - mcpSyncClient.close(); - }).doesNotThrowAnyException(); - }); - } - - @Test - void testAddRoot() { - withClient(createMcpTransport(), mcpSyncClient -> { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); - }); - } - - @Test - void testAddRootWithNullValue() { - withClient(createMcpTransport(), mcpSyncClient -> { - assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); - }); - } - - @Test - void testRemoveRoot() { - withClient(createMcpTransport(), mcpSyncClient -> { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpSyncClient.addRoot(root); - mcpSyncClient.removeRoot(root.uri()); - }).doesNotThrowAnyException(); - }); - } - - @Test - void testRemoveNonExistentRoot() { - withClient(createMcpTransport(), mcpSyncClient -> { - assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); - }); - } - - @Test - void testReadResourceWithoutInitialization() { - AtomicReference> resources = new AtomicReference<>(); - withClient(createMcpTransport(), mcpSyncClient -> { - mcpSyncClient.initialize(); - resources.set(mcpSyncClient.listResources().resources()); - }); - - verifyCallSucceedsWithImplicitInitialization(client -> client.readResource(resources.get().get(0)), - "reading resources"); - } - - @Test - void testReadResource() { - withClient(createMcpTransport(), mcpSyncClient -> { - - int readResourceCount = 0; - - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); - - assertThat(resources).isNotNull(); - assertThat(resources.resources()).isNotNull(); - - assertThat(resources.resources()).isNotNull().isNotEmpty(); - - // Test reading each resource individually for better error isolation - for (Resource resource : resources.resources()) { - ReadResourceResult result = mcpSyncClient.readResource(resource); - - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull().isNotEmpty(); - - readResourceCount++; - - // Validate each content item - for (ResourceContents content : result.contents()) { - assertThat(content).isNotNull(); - assertThat(content.uri()).isNotNull().isNotEmpty(); - assertThat(content.mimeType()).isNotNull().isNotEmpty(); - - // Validate content based on its type with more comprehensive - // checks - switch (content.mimeType()) { - case "text/plain" -> { - TextResourceContents textContent = assertInstanceOf(TextResourceContents.class, content); - assertThat(textContent.text()).isNotNull().isNotEmpty(); - // Verify URI consistency - assertThat(textContent.uri()).isEqualTo(resource.uri()); - } - case "application/octet-stream" -> { - BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class, content); - assertThat(blobContent.blob()).isNotNull().isNotEmpty(); - // Verify URI consistency - assertThat(blobContent.uri()).isEqualTo(resource.uri()); - // Validate base64 encoding format - assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$"); - } - default -> { - // More flexible handling of additional MIME types - // Log the unexpected type for debugging but don't fail - // the test - logger.warn("Warning: Encountered unexpected MIME type: {} for resource: {}", - content.mimeType(), resource.uri()); - - // Still validate basic properties - if (content instanceof TextResourceContents textContent) { - assertThat(textContent.text()).isNotNull(); - } - else if (content instanceof BlobResourceContents blobContent) { - assertThat(blobContent.blob()).isNotNull(); - } - } - } - } - } - - // Assert that we read exactly 10 resources - assertThat(readResourceCount).isEqualTo(10); - }); - } - - @Test - void testListResourceTemplatesWithoutInitialization() { - verifyCallSucceedsWithImplicitInitialization(client -> client.listResourceTemplates(McpSchema.FIRST_PAGE), - "listing resource templates"); - } - - @Test - void testListResourceTemplates() { - withClient(createMcpTransport(), mcpSyncClient -> { - mcpSyncClient.initialize(); - ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(McpSchema.FIRST_PAGE); - - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }); - } - - @Test - void testListAllResourceTemplates() { - withClient(createMcpTransport(), mcpSyncClient -> { - mcpSyncClient.initialize(); - ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(); - - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }); - } - - // @Test - void testResourceSubscription() { - withClient(createMcpTransport(), mcpSyncClient -> { - ListResourcesResult resources = mcpSyncClient.listResources(null); - - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - - // Test subscribe - assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - - // Test unsubscribe - assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - } - }); - } - - @Test - void testNotificationHandlers() { - AtomicBoolean toolsNotificationReceived = new AtomicBoolean(false); - AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); - AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - - withClient(createMcpTransport(), - builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) - .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), - client -> { - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); - }); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @Test - void testLoggingLevelsWithoutInitialization() { - verifyNotificationSucceedsWithImplicitInitialization( - client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), "setting logging level"); - } - - @Test - void testLoggingLevels() { - withClient(createMcpTransport(), mcpSyncClient -> { - mcpSyncClient.initialize(); - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); - } - }); - } - - @Test - void testLoggingConsumer() { - AtomicBoolean logReceived = new AtomicBoolean(false); - withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout()) - .loggingConsumer(notification -> logReceived.set(true)), client -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); - }); - } - - @Test - void testLoggingWithNullNotification() { - withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) - .hasMessageContaining("Logging level must not be null")); - } - - @Test - void testSampling() { - McpClientTransport transport = createMcpTransport(); - - final String message = "Hello, world!"; - final String response = "Goodbye, world!"; - final int maxTokens = 100; - - AtomicReference receivedPrompt = new AtomicReference<>(); - AtomicReference receivedMessage = new AtomicReference<>(); - AtomicInteger receivedMaxTokens = new AtomicInteger(); - - withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build()) - .sampling(request -> { - McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class, - request.messages().get(0).content()); - receivedPrompt.set(request.systemPrompt()); - receivedMessage.set(messageText.text()); - receivedMaxTokens.set(request.maxTokens()); - - return new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response), - "modelId", McpSchema.CreateMessageResult.StopReason.END_TURN); - }), client -> { - client.initialize(); - - McpSchema.CallToolResult result = client.callTool( - new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens))); - - // Verify tool response to ensure our sampling response was passed through - assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class); - assertThat(result.content()).allSatisfy(content -> { - if (!(content instanceof McpSchema.TextContent text)) - return; - - assertThat(text.text()).endsWith(response); // Prefixed - }); - - // Verify sampling request parameters received in our callback - assertThat(receivedPrompt.get()).isNotEmpty(); - assertThat(receivedMessage.get()).endsWith(message); // Prefixed - assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens); - }); - } - - // --------------------------------------- - // Progress Notification Tests - // --------------------------------------- - - @Test - void testProgressConsumer() { - AtomicInteger progressNotificationCount = new AtomicInteger(0); - List receivedNotifications = new CopyOnWriteArrayList<>(); - CountDownLatch latch = new CountDownLatch(2); - - withClient(createMcpTransport(), builder -> builder.progressConsumer(notification -> { - System.out.println("Received progress notification: " + notification); - receivedNotifications.add(notification); - progressNotificationCount.incrementAndGet(); - latch.countDown(); - }), client -> { - client.initialize(); - - // Call a tool that sends progress notifications - CallToolRequest request = CallToolRequest.builder() - .name("longRunningOperation") - .arguments(Map.of("duration", 1, "steps", 2)) - .progressToken("test-token") - .build(); - - CallToolResult result = client.callTool(request); - - assertThat(result).isNotNull(); - - try { - // Wait for progress notifications to be processed - latch.await(3, TimeUnit.SECONDS); - } - catch (InterruptedException e) { - e.printStackTrace(); - } - - assertThat(progressNotificationCount.get()).isEqualTo(2); - - assertThat(receivedNotifications).isNotEmpty(); - assertThat(receivedNotifications.get(0).progressToken()).isEqualTo("test-token"); - }); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java deleted file mode 100644 index 7f00de60e..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import org.junit.jupiter.api.Timeout; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; - -import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; - -@Timeout(15) -public class HttpClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { - - static String host = "http://localhost:3001"; - - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image - @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js streamableHttp") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - @Override - protected McpClientTransport createMcpTransport() { - return HttpClientStreamableHttpTransport.builder(host).build(); - } - - @Override - protected void onStart() { - container.start(); - int port = container.getMappedPort(3001); - host = "http://" + container.getHost() + ":" + port; - } - - @Override - public void onClose() { - container.stop(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java deleted file mode 100644 index 8646c1b4c..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; -import org.junit.jupiter.api.Timeout; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.Wait; - -/** - * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. - * - * @author Christian Tzolov - */ -@Timeout(15) // Giving extra time beyond the client timeout -class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { - - String host = "http://localhost:3003"; - - // Uses the https://github.com/tzolov/mcp-everything-server-docker-image - @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") - .withCommand("node dist/index.js sse") - .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) - .withExposedPorts(3001) - .waitingFor(Wait.forHttp("/").forStatusCode(404)); - - @Override - protected McpClientTransport createMcpTransport() { - return HttpClientSseClientTransport.builder(host).build(); - } - - @Override - protected void onStart() { - container.start(); - int port = container.getMappedPort(3001); - host = "http://" + container.getHost() + ":" + port; - } - - @Override - protected void onClose() { - container.stop(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java deleted file mode 100644 index ae33898b7..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2024-2025 the original author or authors. - */ - -package io.modelcontextprotocol.client; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ProtocolVersions; - -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; - -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; - -import static org.assertj.core.api.Assertions.assertThatCode; - -class McpAsyncClientTests { - - public static final McpSchema.Implementation MOCK_SERVER_INFO = new McpSchema.Implementation("test-server", - "1.0.0"); - - public static final McpSchema.ServerCapabilities MOCK_SERVER_CAPABILITIES = McpSchema.ServerCapabilities.builder() - .build(); - - public static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult( - ProtocolVersions.MCP_2024_11_05, MOCK_SERVER_CAPABILITIES, MOCK_SERVER_INFO, "Test instructions"); - - private static final String CONTEXT_KEY = "context.key"; - - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - - @Test - void validateContextPassedToTransportConnect() { - McpClientTransport transport = new McpClientTransport() { - Function, Mono> handler; - - final AtomicReference contextValue = new AtomicReference<>(); - - @Override - public Mono connect( - Function, Mono> handler) { - return Mono.deferContextual(ctx -> { - this.handler = handler; - if (ctx.hasKey(CONTEXT_KEY)) { - this.contextValue.set(ctx.get(CONTEXT_KEY)); - } - return Mono.empty(); - }); - } - - @Override - public Mono closeGracefully() { - return Mono.empty(); - } - - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - if (!"hello".equals(this.contextValue.get())) { - return Mono.error(new RuntimeException("Context value not propagated via #connect method")); - } - // We're only interested in handling the init request to provide an init - // response - if (!(message instanceof McpSchema.JSONRPCRequest)) { - return Mono.empty(); - } - McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, - ((McpSchema.JSONRPCRequest) message).id(), MOCK_INIT_RESULT, null); - return handler.apply(Mono.just(initResponse)).then(); - } - - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return OBJECT_MAPPER.convertValue(data, typeRef); - } - }; - - assertThatCode(() -> { - McpAsyncClient client = McpClient.async(transport).build(); - client.initialize().contextWrite(ctx -> ctx.put(CONTEXT_KEY, "hello")).block(); - }).doesNotThrowAnyException(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java deleted file mode 100644 index 0ba8bf929..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ /dev/null @@ -1,497 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -import java.time.Duration; -import java.util.List; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptMessage; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -/** - * Test suite for the {@link McpAsyncServer} that can be used with different - * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations. - * - * @author Christian Tzolov - */ -// KEEP IN SYNC with the class in mcp-test module -public abstract class AbstractMcpAsyncServerTests { - - private static final String TEST_TOOL_NAME = "test-tool"; - - private static final String TEST_RESOURCE_URI = "test://resource"; - - private static final String TEST_PROMPT_NAME = "test-prompt"; - - abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); - - protected void onStart() { - } - - protected void onClose() { - } - - @BeforeEach - void setUp() { - } - - @AfterEach - void tearDown() { - onClose(); - } - - // --------------------------------------- - // Server Lifecycle Tests - // --------------------------------------- - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport provider must not be null"); - - assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo((McpSchema.Implementation) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Server info must not be null"); - } - - @Test - void testGracefulShutdown() { - McpServer.AsyncSpecification builder = prepareAsyncServerBuilder(); - var mcpAsyncServer = builder.serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); - } - - @Test - void testImmediateClose() { - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - @Deprecated - void testAddTool() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddToolCall() { - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() - .tool(newTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .build())).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - @Deprecated - void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier - .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, - (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddDuplicateToolCall() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.addTool(McpServerFeatures.AsyncToolSpecification.builder() - .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .build())).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testDuplicateToolCallDuringBuilding() { - Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", - emptyJsonSchema); - - assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) // Duplicate! - .build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); - } - - @Test - void testDuplicateToolsInBatchListRegistration() { - Tool duplicateTool = new Tool("batch-list-tool", "Duplicate tool in batch list", emptyJsonSchema); - List specs = List.of( - McpServerFeatures.AsyncToolSpecification.builder() - .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .build(), - McpServerFeatures.AsyncToolSpecification.builder() - .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .build() // Duplicate! - ); - - assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(specs) - .build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Tool with name 'batch-list-tool' is already registered."); - } - - @Test - void testDuplicateToolsInBatchVarargsRegistration() { - Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); - - assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(McpServerFeatures.AsyncToolSpecification.builder() - .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .build(), - McpServerFeatures.AsyncToolSpecification.builder() - .tool(duplicateTool) - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .build() // Duplicate! - ) - .build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Tool with name 'batch-varargs-tool' is already registered."); - } - - @Test - void testRemoveTool() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentTool() { - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testNotifyToolsListChanged() { - Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) - .build(); - - StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Resources Tests - // --------------------------------------- - - @Test - void testNotifyResourcesListChanged() { - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testNotifyResourcesUpdated() { - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - StepVerifier - .create(mcpAsyncServer - .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) - .verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResource() { - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( - resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - - StepVerifier.create(mcpAsyncServer.addResource(specification)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithNullSpecification() { - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); - }); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithoutCapability() { - // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( - resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - - StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - }); - } - - @Test - void testRemoveResourceWithoutCapability() { - // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - }); - } - - // --------------------------------------- - // Prompts Tests - // --------------------------------------- - - @Test - void testNotifyPromptsListChanged() { - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testAddPromptWithNullSpecification() { - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(false).build()) - .build(); - - StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) - .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt specification must not be null"); - }); - } - - @Test - void testAddPromptWithoutCapability() { - // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( - prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - - StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - }); - } - - @Test - void testRemovePromptWithoutCapability() { - // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - }); - } - - @Test - void testRemovePrompt() { - String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; - - Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( - prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - - var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(specification) - .build(); - - StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); - - assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .build(); - - StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - }); - - assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - - @Test - void testRootsChangeHandlers() { - // Test with single consumer - var rootsReceived = new McpSchema.Root[1]; - var consumerCalled = new boolean[1]; - - var singleConsumerServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { - consumerCalled[0] = true; - if (!roots.isEmpty()) { - rootsReceived[0] = roots.get(0); - } - }))) - .build(); - - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test with multiple consumers - var consumer1Called = new boolean[1]; - var consumer2Called = new boolean[1]; - var rootsContent = new List[1]; - - var multipleConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { - consumer1Called[0] = true; - rootsContent[0] = roots; - }), (exchange, roots) -> Mono.fromRunnable(() -> consumer2Called[0] = true))) - .build(); - - assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test error handling - var errorHandlingServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchange, roots) -> { - throw new RuntimeException("Test error"); - })) - .build(); - - assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - onClose(); - - // Test without consumers - var noConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) - .doesNotThrowAnyException(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java deleted file mode 100644 index acaf0c8a9..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java +++ /dev/null @@ -1,1594 +0,0 @@ -/* - * Copyright 2024 - 2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertWith; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; - -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.stream.Collectors; - -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; - -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; -import io.modelcontextprotocol.spec.McpSchema.CompleteResult; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; -import io.modelcontextprotocol.spec.McpSchema.ElicitResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptArgument; -import io.modelcontextprotocol.spec.McpSchema.PromptReference; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.util.Utils; -import net.javacrumbs.jsonunit.core.Option; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -public abstract class AbstractMcpClientServerIntegrationTests { - - protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); - - abstract protected void prepareClients(int port, String mcpEndpoint); - - abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); - - abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void simple(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1000)) - .build(); - - try ( - // Create client without sampling capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .requestTimeout(Duration.ofSeconds(1000)) - .build()) { - - assertThat(client.initialize()).isNotNull(); - - } - server.closeGracefully(); - } - - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testCreateMessageWithoutSamplingCapabilities(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) - .callHandler((exchange, request) -> { - exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - return Mono.just(mock(CallToolResult.class)); - }) - .build(); - - var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); - - try ( - // Create client without sampling capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - } - } - server.closeGracefully(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testCreateMessageSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - AtomicReference samplingResult = new AtomicReference<>(); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) - .callHandler((exchange, request) -> { - - var createMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - return exchange.createMessage(createMessageRequest) - .doOnNext(samplingResult::set) - .thenReturn(callResponse); - }) - .build(); - - var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - assertWith(samplingResult.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }); - } - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { - - // Client - - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - AtomicReference samplingResult = new AtomicReference<>(); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) - .callHandler((exchange, request) -> { - - var createMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - return exchange.createMessage(createMessageRequest) - .doOnNext(samplingResult::set) - .thenReturn(callResponse); - }) - .build(); - - var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(4)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - assertWith(samplingResult.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }); - - mcpClient.close(); - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { - - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) - .callHandler((exchange, request) -> { - - var createMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - return exchange.createMessage(createMessageRequest).thenReturn(callResponse); - }) - .build(); - - var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("1000ms"); - - mcpClient.close(); - mcpServer.close(); - } - - // --------------------------------------- - // Elicitation Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testCreateElicitationWithoutElicitationCapabilities(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) - .callHandler((exchange, request) -> exchange.createElicitation(mock(ElicitRequest.class)) - .then(Mono.just(mock(CallToolResult.class)))) - .build(); - - var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); - - // Create client without elicitation capabilities - try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with elicitation capabilities"); - } - } - server.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testCreateElicitationSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, - Map.of("message", request.message())); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) - .callHandler((exchange, request) -> { - - var elicitationRequest = McpSchema.ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - AtomicReference resultRef = new AtomicReference<>(); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) - .callHandler((exchange, request) -> { - - var elicitationRequest = McpSchema.ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - return exchange.createElicitation(elicitationRequest) - .doOnNext(resultRef::set) - .then(Mono.just(callResponse)); - }) - .build(); - - var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - assertWith(resultRef.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testCreateElicitationWithRequestTimeoutFail(String clientType) { - - var latch = new CountDownLatch(1); - - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - try { - if (!latch.await(2, TimeUnit.SECONDS)) { - throw new RuntimeException("Timeout waiting for elicitation processing"); - } - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - - AtomicReference resultRef = new AtomicReference<>(); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - return exchange.createElicitation(elicitationRequest) - .doOnNext(resultRef::set) - .then(Mono.just(callResponse)); - }) - .build(); - - var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) // 1 second. - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("within 1000ms"); - - ElicitResult elicitResult = resultRef.get(); - assertThat(elicitResult).isNull(); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testRootsSuccess(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = prepareSyncServerBuilder() - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testRootsWithoutCapability(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) - .callHandler((exchange, request) -> { - - exchange.listRoots(); // try to list roots - - return mock(CallToolResult.class); - }) - .build(); - - var mcpServer = prepareSyncServerBuilder().rootsChangeHandler((exchange, rootsUpdate) -> { - }).tools(tool).build(); - - try ( - // Create client without roots capability - // No roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); - } - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testRootsNotificationWithEmptyRootsList(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = prepareSyncServerBuilder() - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testRootsWithMultipleHandlers(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = prepareSyncServerBuilder() - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testRootsServerCloseWithActiveSubscription(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = prepareSyncServerBuilder() - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testToolCallSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var responseBodyIsNullOrBlank = new AtomicBoolean(false); - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) - .callHandler((exchange, request) -> { - - try { - HttpResponse response = HttpClient.newHttpClient() - .send(HttpRequest.newBuilder() - .uri(URI.create( - "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) - .GET() - .build(), HttpResponse.BodyHandlers.ofString()); - String responseBody = response.body(); - responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); - } - catch (Exception e) { - e.printStackTrace(); - } - - return callResponse; - }) - .build(); - - var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(responseBodyIsNullOrBlank.get()).isFalse(); - assertThat(response).isNotNull().isEqualTo(callResponse); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpSyncServer mcpServer = prepareSyncServerBuilder() - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder() - .name("tool1") - .description("tool1 description") - .inputSchema(emptyJsonSchema) - .build()) - .callHandler((exchange, request) -> { - // We trigger a timeout on blocking read, raising an exception - Mono.never().block(Duration.ofSeconds(1)); - return null; - }) - .build()) - .build(); - - try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // We expect the tool call to fail immediately with the exception raised by - // the offending tool - // instead of getting back a timeout. - assertThatExceptionOfType(McpError.class) - .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) - .withMessageContaining("Timeout on blocking read"); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testToolCallSuccessWithTranportContextExtraction(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var transportContextIsNull = new AtomicBoolean(false); - var transportContextIsEmpty = new AtomicBoolean(false); - var responseBodyIsNullOrBlank = new AtomicBoolean(false); - - var expectedCallResponse = new McpSchema.CallToolResult( - List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=value")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) - .callHandler((exchange, request) -> { - - McpTransportContext transportContext = exchange.transportContext(); - transportContextIsNull.set(transportContext == null); - transportContextIsEmpty.set(transportContext.equals(McpTransportContext.EMPTY)); - String ctxValue = (String) transportContext.get("important"); - - try { - String responseBody = "TOOL RESPONSE"; - responseBodyIsNullOrBlank.set(!Utils.hasText(responseBody)); - } - catch (Exception e) { - e.printStackTrace(); - } - - return new McpSchema.CallToolResult( - List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)), null); - }) - .build(); - - var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(transportContextIsNull.get()).isFalse(); - assertThat(transportContextIsEmpty.get()).isFalse(); - assertThat(responseBodyIsNullOrBlank.get()).isFalse(); - assertThat(response).isNotNull().isEqualTo(expectedCallResponse); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testToolListChangeHandlingSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) - .callHandler((exchange, request) -> { - // perform a blocking call to a remote service - try { - HttpResponse response = HttpClient.newHttpClient() - .send(HttpRequest.newBuilder() - .uri(URI.create( - "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) - .GET() - .build(), HttpResponse.BodyHandlers.ofString()); - String responseBody = response.body(); - assertThat(responseBody).isNotBlank(); - } - catch (Exception e) { - e.printStackTrace(); - } - return callResponse; - }) - .build(); - - AtomicReference> toolsRef = new AtomicReference<>(); - - var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - try { - HttpResponse response = HttpClient.newHttpClient() - .send(HttpRequest.newBuilder() - .uri(URI.create( - "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) - .GET() - .build(), HttpResponse.BodyHandlers.ofString()); - String responseBody = response.body(); - assertThat(responseBody).isNotBlank(); - toolsRef.set(toolsUpdate); - } - catch (Exception e) { - e.printStackTrace(); - } - }).build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(toolsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(toolsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(toolsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() - .tool(Tool.builder() - .name("tool2") - .description("tool2 description") - .inputSchema(emptyJsonSchema) - .build()) - .callHandler((exchange, request) -> callResponse) - .build(); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(toolsRef.get()).containsAll(List.of(tool2.tool())); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testInitialize(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var mcpServer = prepareSyncServerBuilder().build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testLoggingNotification(String clientType) throws InterruptedException { - int expectedNotificationsCount = 3; - CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); - // Create a list to store received logging notifications - List receivedNotifications = new CopyOnWriteArrayList<>(); - - var clientBuilder = clientBuilders.get(clientType); - - // Create server with a tool that sends logging notifications - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder() - .name("logging-test") - .description("Test logging notifications") - .inputSchema(emptyJsonSchema) - .build()) - .callHandler((exchange, request) -> { - - // Create and send notifications with different levels - - //@formatter:off - return exchange // This should be filtered out (DEBUG < NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.DEBUG) - .logger("test-logger") - .data("Debug message") - .build()) - .then(exchange // This should be sent (NOTICE >= NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.NOTICE) - .logger("test-logger") - .data("Notice message") - .build())) - .then(exchange // This should be sent (ERROR > NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Error message") - .build())) - .then(exchange // This should be filtered out (INFO < NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Another info message") - .build())) - .then(exchange // This should be sent (ERROR >= NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Another error message") - .build())) - .thenReturn(new CallToolResult("Logging test completed", false)); - //@formatter:on - }) - .build(); - - var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().tools(true).build()) - .tools(tool) - .build(); - - try ( - // Create client with logging notification handler - var mcpClient = clientBuilder.loggingConsumer(notification -> { - receivedNotifications.add(notification); - latch.countDown(); - }).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Set minimum logging level to NOTICE - mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); - - // Call the tool that sends logging notifications - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); - - assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); - - // Should have received 3 notifications (1 NOTICE and 2 ERROR) - assertThat(receivedNotifications).hasSize(expectedNotificationsCount); - - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.data(), n -> n)); - - // First notification should be NOTICE level - assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); - - // Second notification should be ERROR level - assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); - - // Third notification should be ERROR level - assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); - } - mcpServer.close(); - } - - // --------------------------------------- - // Progress Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testProgressNotification(String clientType) throws InterruptedException { - int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress - // token - CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); - // Create a list to store received logging notifications - List receivedNotifications = new CopyOnWriteArrayList<>(); - - var clientBuilder = clientBuilders.get(clientType); - - // Create server with a tool that sends logging notifications - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(McpSchema.Tool.builder() - .name("progress-test") - .description("Test progress notifications") - .inputSchema(emptyJsonSchema) - .build()) - .callHandler((exchange, request) -> { - - // Create and send notifications - var progressToken = (String) request.meta().get("progressToken"); - - return exchange - .progressNotification( - new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) - .then(exchange.progressNotification( - new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) - .then(// Send a progress notification with another progress value - // should - exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", - 0.0, 1.0, "Another processing started"))) - .then(exchange.progressNotification( - new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) - .thenReturn(new CallToolResult(("Progress test completed"), false)); - }) - .build(); - - var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try ( - // Create client with progress notification handler - var mcpClient = clientBuilder.progressConsumer(notification -> { - receivedNotifications.add(notification); - latch.countDown(); - }).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that sends progress notifications - McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() - .name("progress-test") - .meta(Map.of("progressToken", "test-progress-token")) - .build(); - CallToolResult result = mcpClient.callTool(callToolRequest); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); - - assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); - - // Should have received 3 notifications - assertThat(receivedNotifications).hasSize(expectedNotificationsCount); - - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.message(), n -> n)); - - // First notification should be 0.0/1.0 progress - assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); - assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); - assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); - - // Second notification should be 0.5/1.0 progress - assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); - assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); - assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); - - // Third notification should be another progress token with 0.0/1.0 progress - assertThat(notificationMap.get("Another processing started").progressToken()) - .isEqualTo("another-progress-token"); - assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); - assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Another processing started").message()) - .isEqualTo("Another processing started"); - - // Fourth notification should be 1.0/1.0 progress - assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); - assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); - } - finally { - mcpServer.close(); - } - } - - // --------------------------------------- - // Completion Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : Completion call") - @ValueSource(strings = { "httpclient" }) - void testCompletionShouldReturnExpectedSuggestions(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - var expectedValues = List.of("python", "pytorch", "pyside"); - var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total - true // hasMore - )); - - AtomicReference samplingRequest = new AtomicReference<>(); - BiFunction completionHandler = (mcpSyncServerExchange, - request) -> { - samplingRequest.set(request); - return completionResponse; - }; - - var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().completions().build()) - .prompts(new McpServerFeatures.SyncPromptSpecification( - new Prompt("code_review", "Code review", "this is code review prompt", - List.of(new PromptArgument("language", "Language", "string", false))), - (mcpSyncServerExchange, getPromptRequest) -> null)) - .completions(new McpServerFeatures.SyncCompletionSpecification( - new McpSchema.PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CompleteRequest request = new CompleteRequest( - new PromptReference("ref/prompt", "code_review", "Code review"), - new CompleteRequest.CompleteArgument("language", "py")); - - CompleteResult result = mcpClient.completeCompletion(request); - - assertThat(result).isNotNull(); - - assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); - assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); - assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Ping Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testPingSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - // Create server with a tool that uses ping functionality - AtomicReference executionOrder = new AtomicReference<>(""); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(Tool.builder() - .name("ping-async-test") - .description("Test ping async behavior") - .inputSchema(emptyJsonSchema) - .build()) - .callHandler((exchange, request) -> { - - executionOrder.set(executionOrder.get() + "1"); - - // Test async ping behavior - return exchange.ping().doOnNext(result -> { - - assertThat(result).isNotNull(); - // Ping should return an empty object or map - assertThat(result).isInstanceOf(Map.class); - - executionOrder.set(executionOrder.get() + "2"); - assertThat(result).isNotNull(); - }).then(Mono.fromCallable(() -> { - executionOrder.set(executionOrder.get() + "3"); - return new CallToolResult("Async ping test completed", false); - })); - }) - .build(); - - var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that tests ping async behavior - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); - - // Verify execution order - assertThat(executionOrder.get()).isEqualTo("123"); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tool Structured Output Schema Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testStructuredOutputValidationSuccess(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of( - "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", - Map.of("type", "string"), "timestamp", Map.of("type", "string")), - "required", List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(calculatorTool) - .callHandler((exchange, request) -> { - String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); - double result = evaluateExpression(expression); - return CallToolResult.builder() - .structuredContent( - Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) - .build(); - }) - .build(); - - var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Verify tool is listed with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call tool with valid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - - // In WebMVC, structured content is returned properly - if (response.structuredContent() != null) { - assertThat(response.structuredContent()).containsEntry("result", 5.0) - .containsEntry("operation", "2 + 3") - .containsEntry("timestamp", "2024-01-01T10:00:00Z"); - } - else { - // Fallback to checking content if structured content is not available - assertThat(response.content()).isNotEmpty(); - } - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testStructuredOutputValidationFailure(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", - List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(calculatorTool) - .callHandler((exchange, request) -> { - // Return invalid structured output. Result should be number, missing - // operation - return CallToolResult.builder() - .addTextContent("Invalid calculation") - .structuredContent(Map.of("result", "not-a-number", "extra", "field")) - .build(); - }) - .build(); - - var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool with invalid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).contains("Validation failed"); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testStructuredOutputMissingStructuredContent(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number")), "required", List.of("result")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(calculatorTool) - .callHandler((exchange, request) -> { - // Return result without structured content but tool has output schema - return CallToolResult.builder().addTextContent("Calculation completed").build(); - }) - .build(); - - var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool that should return structured content but doesn't - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).isEqualTo( - "Response missing structured content which is expected when calling tool with non-empty outputSchema"); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient" }) - void testStructuredOutputRuntimeToolAddition(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - // Start server without tools - var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Initially no tools - assertThat(mcpClient.listTools().tools()).isEmpty(); - - // Add tool with output schema at runtime - Map outputSchema = Map.of("type", "object", "properties", - Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", - List.of("message", "count")); - - Tool dynamicTool = Tool.builder() - .name("dynamic-tool") - .description("Dynamically added tool") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification toolSpec = McpServerFeatures.SyncToolSpecification.builder() - .tool(dynamicTool) - .callHandler((exchange, request) -> { - int count = (Integer) request.arguments().getOrDefault("count", 1); - return CallToolResult.builder() - .addTextContent("Dynamic tool executed " + count + " times") - .structuredContent(Map.of("message", "Dynamic execution", "count", count)) - .build(); - }) - .build(); - - // Add tool to server - mcpServer.addTool(toolSpec); - - // Wait for tool list change notification - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(mcpClient.listTools().tools()).hasSize(1); - }); - - // Verify tool was added with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call dynamically added tool - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) response.content().get(0)).text()) - .isEqualTo("Dynamic tool executed 3 times"); - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"count":3,"message":"Dynamic execution"}""")); - } - - mcpServer.close(); - } - - private double evaluateExpression(String expression) { - // Simple expression evaluator for testing - return switch (expression) { - case "2 + 3" -> 5.0; - case "10 * 2" -> 20.0; - case "7 + 8" -> 15.0; - case "5 + 3" -> 8.0; - default -> 0.0; - }; - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java deleted file mode 100644 index 67579ce72..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ /dev/null @@ -1,471 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -import java.util.List; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptMessage; -import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; -import io.modelcontextprotocol.spec.McpSchema.Resource; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpServerTransportProvider; - -/** - * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpServerTransportProvider} implementations. - * - * @author Christian Tzolov - */ -// KEEP IN SYNC with the class in mcp-test module -public abstract class AbstractMcpSyncServerTests { - - private static final String TEST_TOOL_NAME = "test-tool"; - - private static final String TEST_RESOURCE_URI = "test://resource"; - - private static final String TEST_PROMPT_NAME = "test-prompt"; - - abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); - - protected void onStart() { - } - - protected void onClose() { - } - - @BeforeEach - void setUp() { - // onStart(); - } - - @AfterEach - void tearDown() { - onClose(); - } - - // --------------------------------------- - // Server Lifecycle Tests - // --------------------------------------- - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync((McpServerTransportProvider) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport provider must not be null"); - - assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Server info must not be null"); - } - - @Test - void testGracefulShutdown() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testImmediateClose() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); - } - - @Test - void testGetAsyncServer() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - @Deprecated - void testAddTool() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) - .doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddToolCall() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - assertThatCode(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() - .tool(newTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) - .build())).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - @Deprecated - void testAddDuplicateTool() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, - (exchange, args) -> new CallToolResult(List.of(), false)))) - .isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddDuplicateToolCall() { - Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addTool(McpServerFeatures.SyncToolSpecification.builder() - .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) - .build())).isInstanceOf(McpError.class) - .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testDuplicateToolCallDuringBuilding() { - Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", - emptyJsonSchema); - - assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) - .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) // Duplicate! - .build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Tool with name 'duplicate-build-toolcall' is already registered."); - } - - @Test - void testDuplicateToolsInBatchListRegistration() { - Tool duplicateTool = new Tool("batch-list-tool", "Duplicate tool in batch list", emptyJsonSchema); - List specs = List.of( - McpServerFeatures.SyncToolSpecification.builder() - .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) - .build(), - McpServerFeatures.SyncToolSpecification.builder() - .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) - .build() // Duplicate! - ); - - assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(specs) - .build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Tool with name 'batch-list-tool' is already registered."); - } - - @Test - void testDuplicateToolsInBatchVarargsRegistration() { - Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); - - assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(McpServerFeatures.SyncToolSpecification.builder() - .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) - .build(), - McpServerFeatures.SyncToolSpecification.builder() - .tool(duplicateTool) - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) - .build() // Duplicate! - ) - .build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Tool with name 'batch-varargs-tool' is already registered."); - } - - @Test - void testRemoveTool() { - Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); - - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .toolCall(tool, (exchange, args) -> new CallToolResult(List.of(), false)) - .build(); - - assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentTool() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) - .hasMessage("Tool with name 'nonexistent-tool' not found"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testNotifyToolsListChanged() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Resources Tests - // --------------------------------------- - - @Test - void testNotifyResourcesListChanged() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testNotifyResourcesUpdated() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer - .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) - .doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResource() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( - resource, (exchange, req) -> new ReadResourceResult(List.of())); - - assertThatCode(() -> mcpSyncServer.addResource(specification)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithNullSpecification() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().resources(true, false).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) - .isInstanceOf(McpError.class) - .hasMessage("Resource must not be null"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddResourceWithoutCapability() { - var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", - null); - McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( - resource, (exchange, req) -> new ReadResourceResult(List.of())); - - assertThatThrownBy(() -> serverWithoutResources.addResource(specification)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - } - - @Test - void testRemoveResourceWithoutCapability() { - var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with resource capabilities"); - } - - // --------------------------------------- - // Prompts Tests - // --------------------------------------- - - @Test - void testNotifyPromptsListChanged() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testAddPromptWithNullSpecification() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(false).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) - .isInstanceOf(McpError.class) - .hasMessage("Prompt specification must not be null"); - } - - @Test - void testAddPromptWithoutCapability() { - var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); - McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, - (exchange, req) -> new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specification)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - } - - @Test - void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) - .hasMessage("Server must be configured with prompt capabilities"); - } - - @Test - void testRemovePrompt() { - Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); - McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, - (exchange, req) -> new GetPromptResult("Test prompt description", List - .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(specification) - .build(); - - assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - @Test - void testRemoveNonexistentPrompt() { - var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().prompts(true).build()) - .build(); - - assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) - .hasMessage("Prompt with name 'nonexistent-prompt' not found"); - - assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - - @Test - void testRootsChangeHandlers() { - // Test with single consumer - var rootsReceived = new McpSchema.Root[1]; - var consumerCalled = new boolean[1]; - - var singleConsumerServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchange, roots) -> { - consumerCalled[0] = true; - if (!roots.isEmpty()) { - rootsReceived[0] = roots.get(0); - } - })) - .build(); - - assertThat(singleConsumerServer).isNotNull(); - assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test with multiple consumers - var consumer1Called = new boolean[1]; - var consumer2Called = new boolean[1]; - var rootsContent = new List[1]; - - var multipleConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchange, roots) -> { - consumer1Called[0] = true; - rootsContent[0] = roots; - }, (exchange, roots) -> consumer2Called[0] = true)) - .build(); - - assertThat(multipleConsumersServer).isNotNull(); - assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test error handling - var errorHandlingServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") - .rootsChangeHandlers(List.of((exchange, roots) -> { - throw new RuntimeException("Test error"); - })) - .build(); - - assertThat(errorHandlingServer).isNotNull(); - assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); - onClose(); - - // Test without consumers - var noConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); - - assertThat(noConsumersServer).isNotNull(); - assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java b/mcp/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java deleted file mode 100644 index 6744826c9..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -import java.util.List; -import java.util.Map; - -import org.junit.jupiter.api.Test; - -import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.TextContent; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -/** - * Tests for {@link McpServerFeatures.AsyncToolSpecification.Builder}. - * - * @author Christian Tzolov - */ -class AsyncToolSpecificationBuilderTest { - - String emptyJsonSchema = """ - { - "type": "object" - } - """; - - @Test - void builderShouldCreateValidAsyncToolSpecification() { - - Tool tool = new Tool("test-tool", "A test tool", emptyJsonSchema); - - McpServerFeatures.AsyncToolSpecification specification = McpServerFeatures.AsyncToolSpecification.builder() - .tool(tool) - .callHandler((exchange, request) -> Mono - .just(new CallToolResult(List.of(new TextContent("Test result")), false))) - .build(); - - assertThat(specification).isNotNull(); - assertThat(specification.tool()).isEqualTo(tool); - assertThat(specification.callHandler()).isNotNull(); - assertThat(specification.call()).isNull(); // deprecated field should be null - } - - @Test - void builderShouldThrowExceptionWhenToolIsNull() { - assertThatThrownBy(() -> McpServerFeatures.AsyncToolSpecification.builder() - .callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) - .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("Tool must not be null"); - } - - @Test - void builderShouldThrowExceptionWhenCallToolIsNull() { - Tool tool = new Tool("test-tool", "A test tool", emptyJsonSchema); - - assertThatThrownBy(() -> McpServerFeatures.AsyncToolSpecification.builder().tool(tool).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Call handler function must not be null"); - } - - @Test - void builderShouldAllowMethodChaining() { - Tool tool = new Tool("test-tool", "A test tool", emptyJsonSchema); - McpServerFeatures.AsyncToolSpecification.Builder builder = McpServerFeatures.AsyncToolSpecification.builder(); - - // Then - verify method chaining returns the same builder instance - assertThat(builder.tool(tool)).isSameAs(builder); - assertThat(builder.callHandler((exchange, request) -> Mono.just(new CallToolResult(List.of(), false)))) - .isSameAs(builder); - } - - @Test - void builtSpecificationShouldExecuteCallToolCorrectly() { - Tool tool = new Tool("calculator", "Simple calculator", emptyJsonSchema); - String expectedResult = "42"; - - McpServerFeatures.AsyncToolSpecification specification = McpServerFeatures.AsyncToolSpecification.builder() - .tool(tool) - .callHandler((exchange, request) -> { - return Mono.just(new CallToolResult(List.of(new TextContent(expectedResult)), false)); - }) - .build(); - - CallToolRequest request = new CallToolRequest("calculator", Map.of()); - Mono resultMono = specification.callHandler().apply(null, request); - - StepVerifier.create(resultMono).assertNext(result -> { - assertThat(result).isNotNull(); - assertThat(result.content()).hasSize(1); - assertThat(result.content().get(0)).isInstanceOf(TextContent.class); - assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); - assertThat(result.isError()).isFalse(); - }).verifyComplete(); - } - - @Test - @SuppressWarnings("deprecation") - void deprecatedConstructorShouldWorkCorrectly() { - Tool tool = new Tool("deprecated-tool", "A deprecated tool", emptyJsonSchema); - String expectedResult = "deprecated result"; - - // Test the deprecated constructor that takes a 'call' function - McpServerFeatures.AsyncToolSpecification specification = new McpServerFeatures.AsyncToolSpecification(tool, - (exchange, arguments) -> Mono - .just(new CallToolResult(List.of(new TextContent(expectedResult)), false))); - - assertThat(specification).isNotNull(); - assertThat(specification.tool()).isEqualTo(tool); - assertThat(specification.call()).isNotNull(); // deprecated field should be set - assertThat(specification.callHandler()).isNotNull(); // should be automatically - // created - - // Test that the callTool function works (it should delegate to the call function) - CallToolRequest request = new CallToolRequest("deprecated-tool", Map.of("arg1", "value1")); - Mono resultMono = specification.callHandler().apply(null, request); - - StepVerifier.create(resultMono).assertNext(result -> { - assertThat(result).isNotNull(); - assertThat(result.content()).hasSize(1); - assertThat(result.content().get(0)).isInstanceOf(TextContent.class); - assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); - assertThat(result.isError()).isFalse(); - }).verifyComplete(); - - // Test that the deprecated call function also works directly - Mono callResultMono = specification.call().apply(null, request.arguments()); - - StepVerifier.create(callResultMono).assertNext(result -> { - assertThat(result).isNotNull(); - assertThat(result.content()).hasSize(1); - assertThat(result.content().get(0)).isInstanceOf(TextContent.class); - assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); - assertThat(result.isError()).isFalse(); - }).verifyComplete(); - } - - @Test - void fromSyncShouldConvertSyncToolSpecificationCorrectly() { - Tool tool = new Tool("sync-tool", "A sync tool", emptyJsonSchema); - String expectedResult = "sync result"; - - // Create a sync tool specification - McpServerFeatures.SyncToolSpecification syncSpec = McpServerFeatures.SyncToolSpecification.builder() - .tool(tool) - .callHandler((exchange, request) -> new CallToolResult(List.of(new TextContent(expectedResult)), false)) - .build(); - - // Convert to async using fromSync - McpServerFeatures.AsyncToolSpecification asyncSpec = McpServerFeatures.AsyncToolSpecification - .fromSync(syncSpec); - - assertThat(asyncSpec).isNotNull(); - assertThat(asyncSpec.tool()).isEqualTo(tool); - assertThat(asyncSpec.callHandler()).isNotNull(); - assertThat(asyncSpec.call()).isNull(); // should be null since sync spec doesn't - // have deprecated call - - // Test that the converted async specification works correctly - CallToolRequest request = new CallToolRequest("sync-tool", Map.of("param", "value")); - Mono resultMono = asyncSpec.callHandler().apply(null, request); - - StepVerifier.create(resultMono).assertNext(result -> { - assertThat(result).isNotNull(); - assertThat(result.content()).hasSize(1); - assertThat(result.content().get(0)).isInstanceOf(TextContent.class); - assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); - assertThat(result.isError()).isFalse(); - }).verifyComplete(); - } - - @Test - @SuppressWarnings("deprecation") - void fromSyncShouldConvertSyncToolSpecificationWithDeprecatedCallCorrectly() { - Tool tool = new Tool("sync-deprecated-tool", "A sync tool with deprecated call", emptyJsonSchema); - String expectedResult = "sync deprecated result"; - McpAsyncServerExchange nullExchange = null; // Mock or create a suitable exchange - // if needed - - // Create a sync tool specification using the deprecated constructor - McpServerFeatures.SyncToolSpecification syncSpec = new McpServerFeatures.SyncToolSpecification(tool, - (exchange, arguments) -> new CallToolResult(List.of(new TextContent(expectedResult)), false)); - - // Convert to async using fromSync - McpServerFeatures.AsyncToolSpecification asyncSpec = McpServerFeatures.AsyncToolSpecification - .fromSync(syncSpec); - - assertThat(asyncSpec).isNotNull(); - assertThat(asyncSpec.tool()).isEqualTo(tool); - assertThat(asyncSpec.callHandler()).isNotNull(); - assertThat(asyncSpec.call()).isNotNull(); // should be set since sync spec has - // deprecated call - - // Test that the converted async specification works correctly via callTool - CallToolRequest request = new CallToolRequest("sync-deprecated-tool", Map.of("param", "value")); - Mono resultMono = asyncSpec.callHandler().apply(nullExchange, request); - - StepVerifier.create(resultMono).assertNext(result -> { - assertThat(result).isNotNull(); - assertThat(result.content()).hasSize(1); - assertThat(result.content().get(0)).isInstanceOf(TextContent.class); - assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); - assertThat(result.isError()).isFalse(); - }).verifyComplete(); - - // Test that the deprecated call function also works - Mono callResultMono = asyncSpec.call().apply(nullExchange, request.arguments()); - - StepVerifier.create(callResultMono).assertNext(result -> { - assertThat(result).isNotNull(); - assertThat(result.content()).hasSize(1); - assertThat(result.content().get(0)).isInstanceOf(TextContent.class); - assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); - assertThat(result.isError()).isFalse(); - }).verifyComplete(); - } - - @Test - void fromSyncShouldReturnNullWhenSyncSpecIsNull() { - assertThat(McpServerFeatures.AsyncToolSpecification.fromSync(null)).isNull(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java b/mcp/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java deleted file mode 100644 index 4aac46952..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.server; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -import java.util.List; -import java.util.Map; - -import org.junit.jupiter.api.Test; - -import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.TextContent; -import io.modelcontextprotocol.spec.McpSchema.Tool; - -/** - * Tests for {@link McpServerFeatures.SyncToolSpecification.Builder}. - * - * @author Christian Tzolov - */ -class SyncToolSpecificationBuilderTest { - - String emptyJsonSchema = """ - { - "type": "object" - } - """; - - @Test - void builderShouldCreateValidSyncToolSpecification() { - - Tool tool = new Tool("test-tool", "A test tool", emptyJsonSchema); - - McpServerFeatures.SyncToolSpecification specification = McpServerFeatures.SyncToolSpecification.builder() - .tool(tool) - .callHandler((exchange, request) -> new CallToolResult(List.of(new TextContent("Test result")), false)) - .build(); - - assertThat(specification).isNotNull(); - assertThat(specification.tool()).isEqualTo(tool); - assertThat(specification.callHandler()).isNotNull(); - assertThat(specification.call()).isNull(); // deprecated field should be null - } - - @Test - void builderShouldThrowExceptionWhenToolIsNull() { - assertThatThrownBy(() -> McpServerFeatures.SyncToolSpecification.builder() - .callHandler((exchange, request) -> new CallToolResult(List.of(), false)) - .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("Tool must not be null"); - } - - @Test - void builderShouldThrowExceptionWhenCallToolIsNull() { - Tool tool = new Tool("test-tool", "A test tool", emptyJsonSchema); - - assertThatThrownBy(() -> McpServerFeatures.SyncToolSpecification.builder().tool(tool).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("CallTool function must not be null"); - } - - @Test - void builderShouldAllowMethodChaining() { - Tool tool = new Tool("test-tool", "A test tool", emptyJsonSchema); - McpServerFeatures.SyncToolSpecification.Builder builder = McpServerFeatures.SyncToolSpecification.builder(); - - // Then - verify method chaining returns the same builder instance - assertThat(builder.tool(tool)).isSameAs(builder); - assertThat(builder.callHandler((exchange, request) -> new CallToolResult(List.of(), false))).isSameAs(builder); - } - - @Test - void builtSpecificationShouldExecuteCallToolCorrectly() { - Tool tool = new Tool("calculator", "Simple calculator", emptyJsonSchema); - String expectedResult = "42"; - - McpServerFeatures.SyncToolSpecification specification = McpServerFeatures.SyncToolSpecification.builder() - .tool(tool) - .callHandler((exchange, request) -> { - // Simple test implementation - return new CallToolResult(List.of(new TextContent(expectedResult)), false); - }) - .build(); - - CallToolRequest request = new CallToolRequest("calculator", Map.of()); - CallToolResult result = specification.callHandler().apply(null, request); - - assertThat(result).isNotNull(); - assertThat(result.content()).hasSize(1); - assertThat(result.content().get(0)).isInstanceOf(TextContent.class); - assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedResult); - assertThat(result.isError()).isFalse(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/McpTestServletFilter.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/McpTestServletFilter.java deleted file mode 100644 index cc2543aa9..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/McpTestServletFilter.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2025 - 2025 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.io.IOException; - -import jakarta.servlet.Filter; -import jakarta.servlet.FilterChain; -import jakarta.servlet.ServletException; -import jakarta.servlet.ServletRequest; -import jakarta.servlet.ServletResponse; - -/** - * Simple {@link Filter} which sets a value in a thread local. Used to verify whether MCP - * executions happen on the thread processing the request or are offloaded. - * - * @author Daniel Garnier-Moiroux - */ -public class McpTestServletFilter implements Filter { - - public static final String THREAD_LOCAL_VALUE = McpTestServletFilter.class.getName(); - - private static final ThreadLocal holder = new ThreadLocal<>(); - - @Override - public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) - throws IOException, ServletException { - holder.set(THREAD_LOCAL_VALUE); - try { - filterChain.doFilter(servletRequest, servletResponse); - } - finally { - holder.remove(); - } - } - - public static String getThreadLocalValue() { - return holder.get(); - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java deleted file mode 100644 index 85dcd26c2..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - */ - -package io.modelcontextprotocol.spec; - -import java.time.Duration; -import java.util.Map; - -import com.fasterxml.jackson.core.type.TypeReference; -import io.modelcontextprotocol.MockMcpClientTransport; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.test.StepVerifier; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Test suite for {@link McpClientSession} that verifies its JSON-RPC message handling, - * request-response correlation, and notification processing. - * - * @author Christian Tzolov - */ -class McpClientSessionTests { - - private static final Logger logger = LoggerFactory.getLogger(McpClientSessionTests.class); - - private static final Duration TIMEOUT = Duration.ofSeconds(5); - - private static final String TEST_METHOD = "test.method"; - - private static final String TEST_NOTIFICATION = "test.notification"; - - private static final String ECHO_METHOD = "echo"; - - private McpClientSession session; - - private MockMcpClientTransport transport; - - @BeforeEach - void setUp() { - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, Map.of(), - Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: {}", params)))); - } - - @AfterEach - void tearDown() { - if (session != null) { - session.close(); - } - } - - @Test - void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> new McpClientSession(null, transport, Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("The requestTimeout can not be null"); - - assertThatThrownBy(() -> new McpClientSession(TIMEOUT, null, Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("transport can not be null"); - } - - TypeReference responseType = new TypeReference<>() { - }; - - @Test - void testSendRequest() { - String testParam = "test parameter"; - String responseData = "test response"; - - // Create a Mono that will emit the response after the request is sent - Mono responseMono = session.sendRequest(TEST_METHOD, testParam, responseType); - // Verify response handling - StepVerifier.create(responseMono).then(() -> { - McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); - transport.simulateIncomingMessage( - new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), responseData, null)); - }).consumeNextWith(response -> { - // Verify the request was sent - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessageAsRequest(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCRequest.class); - McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) sentMessage; - assertThat(request.method()).isEqualTo(TEST_METHOD); - assertThat(request.params()).isEqualTo(testParam); - assertThat(response).isEqualTo(responseData); - }).verifyComplete(); - } - - @Test - void testSendRequestWithError() { - Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); - - // Verify error handling - StepVerifier.create(responseMono).then(() -> { - McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); - // Simulate error response - McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( - McpSchema.ErrorCodes.METHOD_NOT_FOUND, "Method not found", null); - transport.simulateIncomingMessage( - new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, error)); - }).expectError(McpError.class).verify(); - } - - @Test - void testRequestTimeout() { - Mono responseMono = session.sendRequest(TEST_METHOD, "test", responseType); - - // Verify timeout - StepVerifier.create(responseMono) - .expectError(java.util.concurrent.TimeoutException.class) - .verify(TIMEOUT.plusSeconds(1)); - } - - @Test - void testSendNotification() { - Map params = Map.of("key", "value"); - Mono notificationMono = session.sendNotification(TEST_NOTIFICATION, params); - - // Verify notification was sent - StepVerifier.create(notificationMono).consumeSubscriptionWith(response -> { - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCNotification.class); - McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) sentMessage; - assertThat(notification.method()).isEqualTo(TEST_NOTIFICATION); - assertThat(notification.params()).isEqualTo(params); - }).verifyComplete(); - } - - @Test - void testRequestHandling() { - String echoMessage = "Hello MCP!"; - Map> requestHandlers = Map.of(ECHO_METHOD, - params -> Mono.just(params)); - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()); - - // Simulate incoming request - McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, - "test-id", echoMessage); - transport.simulateIncomingMessage(request); - - // Verify response - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); - McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.result()).isEqualTo(echoMessage); - assertThat(response.error()).isNull(); - } - - @Test - void testNotificationHandling() { - Sinks.One receivedParams = Sinks.one(); - - transport = new MockMcpClientTransport(); - session = new McpClientSession(TIMEOUT, transport, Map.of(), - Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))); - - // Simulate incoming notification from the server - Map notificationParams = Map.of("status", "ready"); - - McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - TEST_NOTIFICATION, notificationParams); - - transport.simulateIncomingMessage(notification); - - // Verify handler was called - assertThat(receivedParams.asMono().block(Duration.ofSeconds(1))).isEqualTo(notificationParams); - } - - @Test - void testUnknownMethodHandling() { - // Simulate incoming request for unknown method - McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "unknown.method", - "test-id", null); - transport.simulateIncomingMessage(request); - - // Verify error response - McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage(); - assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); - McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.error()).isNotNull(); - assertThat(response.error().code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND); - } - - @Test - void testGracefulShutdown() { - StepVerifier.create(session.closeGracefully()).verifyComplete(); - } - -} diff --git a/migration-0.8.0.md b/migration-0.8.0.md deleted file mode 100644 index 3ba29a10b..000000000 --- a/migration-0.8.0.md +++ /dev/null @@ -1,328 +0,0 @@ -# MCP Java SDK Migration Guide: 0.7.0 to 0.8.0 - -This document outlines the breaking changes and provides guidance on how to migrate your code from version 0.7.0 to 0.8.0. - -The 0.8.0 refactoring introduces a session-based architecture for server-side MCP implementations. -It improves the SDK's ability to handle multiple concurrent client connections and provides an API better aligned with the MCP specification. -The main changes include: - -1. Introduction of a session-based architecture -2. New transport provider abstraction -3. Exchange objects for client interaction -4. Renamed and reorganized interfaces -5. Updated handler signatures - -## Breaking Changes - -### 1. Interface Renaming - -Several interfaces have been renamed to better reflect their roles: - -| 0.7.0 (Old) | 0.8.0 (New) | -|-------------|-------------| -| `ClientMcpTransport` | `McpClientTransport` | -| `ServerMcpTransport` | `McpServerTransport` | -| `DefaultMcpSession` | `McpClientSession`, `McpServerSession` | - -### 2. New Server Transport Architecture - -The most significant change is the introduction of the `McpServerTransportProvider` interface, which replaces direct usage of `ServerMcpTransport` when creating servers. This new pattern separates the concerns of: - -1. **Transport Provider**: Manages connections with clients and creates individual transports for each connection -2. **Server Transport**: Handles communication with a specific client connection - -| 0.7.0 (Old) | 0.8.0 (New) | -|-------------|-------------| -| `ServerMcpTransport` | `McpServerTransportProvider` + `McpServerTransport` | -| Direct transport usage | Session-based transport usage | - -#### Before (0.7.0): - -```java -// Create a transport -ServerMcpTransport transport = new WebFluxSseServerTransport(objectMapper, "/mcp/message"); - -// Create a server with the transport -McpServer.sync(transport) - .serverInfo("my-server", "1.0.0") - .build(); -``` - -#### After (0.8.0): - -```java -// Create a transport provider -McpServerTransportProvider transportProvider = new WebFluxSseServerTransportProvider(objectMapper, "/mcp/message"); - -// Create a server with the transport provider -McpServer.sync(transportProvider) - .serverInfo("my-server", "1.0.0") - .build(); -``` - -### 3. Handler Method Signature Changes - -Tool, resource, and prompt handlers now receive an additional `exchange` parameter that provides access to client capabilities and methods to interact with the client: - -| 0.7.0 (Old) | 0.8.0 (New) | -|-------------|-------------| -| `(args) -> result` | `(exchange, args) -> result` | - -The exchange objects (`McpAsyncServerExchange` and `McpSyncServerExchange`) provide context for the current session and access to session-specific operations. - -#### Before (0.7.0): - -```java -// Tool handler -.tool(calculatorTool, args -> new CallToolResult("Result: " + calculate(args))) - -// Resource handler -.resource(fileResource, req -> new ReadResourceResult(readFile(req))) - -// Prompt handler -.prompt(analysisPrompt, req -> new GetPromptResult("Analysis prompt")) -``` - -#### After (0.8.0): - -```java -// Tool handler -.tool(calculatorTool, (exchange, args) -> new CallToolResult("Result: " + calculate(args))) - -// Resource handler -.resource(fileResource, (exchange, req) -> new ReadResourceResult(readFile(req))) - -// Prompt handler -.prompt(analysisPrompt, (exchange, req) -> new GetPromptResult("Analysis prompt")) -``` - -### 4. Registration vs. Specification - -The naming convention for handlers has changed from "Registration" to "Specification": - -| 0.7.0 (Old) | 0.8.0 (New) | -|-------------|-------------| -| `AsyncToolRegistration` | `AsyncToolSpecification` | -| `SyncToolRegistration` | `SyncToolSpecification` | -| `AsyncResourceRegistration` | `AsyncResourceSpecification` | -| `SyncResourceRegistration` | `SyncResourceSpecification` | -| `AsyncPromptRegistration` | `AsyncPromptSpecification` | -| `SyncPromptRegistration` | `SyncPromptSpecification` | - -### 5. Roots Change Handler Updates - -The roots change handlers now receive an exchange parameter: - -#### Before (0.7.0): - -```java -.rootsChangeConsumers(List.of( - roots -> { - // Process roots - } -)) -``` - -#### After (0.8.0): - -```java -.rootsChangeHandlers(List.of( - (exchange, roots) -> { - // Process roots with access to exchange - } -)) -``` - -### 6. Server Creation Method Changes - -The `McpServer` factory methods now accept `McpServerTransportProvider` instead of `ServerMcpTransport`: - -| 0.7.0 (Old) | 0.8.0 (New) | -|-------------|-------------| -| `McpServer.async(ServerMcpTransport)` | `McpServer.async(McpServerTransportProvider)` | -| `McpServer.sync(ServerMcpTransport)` | `McpServer.sync(McpServerTransportProvider)` | - -The method names for creating servers have been updated: - -Root change handlers now receive an exchange object: - -| 0.7.0 (Old) | 0.8.0 (New) | -|-------------|-------------| -| `rootsChangeConsumers(List>>)` | `rootsChangeHandlers(List>>)` | -| `rootsChangeConsumer(Consumer>)` | `rootsChangeHandler(BiConsumer>)` | - -### 7. Direct Server Methods Moving to Exchange - -Several methods that were previously available directly on the server are now accessed through the exchange object: - -| 0.7.0 (Old) | 0.8.0 (New) | -|-------------|-------------| -| `server.listRoots()` | `exchange.listRoots()` | -| `server.createMessage()` | `exchange.createMessage()` | -| `server.getClientCapabilities()` | `exchange.getClientCapabilities()` | -| `server.getClientInfo()` | `exchange.getClientInfo()` | - -The direct methods are deprecated and will be removed in 0.9.0: - -- `McpSyncServer.listRoots()` -- `McpSyncServer.getClientCapabilities()` -- `McpSyncServer.getClientInfo()` -- `McpSyncServer.createMessage()` -- `McpAsyncServer.listRoots()` -- `McpAsyncServer.getClientCapabilities()` -- `McpAsyncServer.getClientInfo()` -- `McpAsyncServer.createMessage()` - -## Deprecation Notices - -The following components are deprecated in 0.8.0 and will be removed in 0.9.0: - -- `ClientMcpTransport` interface (use `McpClientTransport` instead) -- `ServerMcpTransport` interface (use `McpServerTransport` instead) -- `DefaultMcpSession` class (use `McpClientSession` instead) -- `WebFluxSseServerTransport` class (use `WebFluxSseServerTransportProvider` instead) -- `WebMvcSseServerTransport` class (use `WebMvcSseServerTransportProvider` instead) -- `StdioServerTransport` class (use `StdioServerTransportProvider` instead) -- All `*Registration` classes (use corresponding `*Specification` classes instead) -- Direct server methods for client interaction (use exchange object instead) - -## Migration Examples - -### Example 1: Creating a Server - -#### Before (0.7.0): - -```java -// Create a transport -ServerMcpTransport transport = new WebFluxSseServerTransport(objectMapper, "/mcp/message"); - -// Create a server with the transport -var server = McpServer.sync(transport) - .serverInfo("my-server", "1.0.0") - .tool(calculatorTool, args -> new CallToolResult("Result: " + calculate(args))) - .rootsChangeConsumers(List.of( - roots -> System.out.println("Roots changed: " + roots) - )) - .build(); - -// Get client capabilities directly from server -ClientCapabilities capabilities = server.getClientCapabilities(); -``` - -#### After (0.8.0): - -```java -// Create a transport provider -McpServerTransportProvider transportProvider = new WebFluxSseServerTransportProvider(objectMapper, "/mcp/message"); - -// Create a server with the transport provider -var server = McpServer.sync(transportProvider) - .serverInfo("my-server", "1.0.0") - .tool(calculatorTool, (exchange, args) -> { - // Get client capabilities from exchange - ClientCapabilities capabilities = exchange.getClientCapabilities(); - return new CallToolResult("Result: " + calculate(args)); - }) - .rootsChangeHandlers(List.of( - (exchange, roots) -> System.out.println("Roots changed: " + roots) - )) - .build(); -``` - -### Example 2: Implementing a Tool with Client Interaction - -#### Before (0.7.0): - -```java -McpServerFeatures.SyncToolRegistration tool = new McpServerFeatures.SyncToolRegistration( - new Tool("weather", "Get weather information", schema), - args -> { - String location = (String) args.get("location"); - // Cannot interact with client from here - return new CallToolResult("Weather for " + location + ": Sunny"); - } -); - -var server = McpServer.sync(transport) - .tools(tool) - .build(); - -// Separate call to create a message -CreateMessageResult result = server.createMessage(new CreateMessageRequest(...)); -``` - -#### After (0.8.0): - -```java -McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( - new Tool("weather", "Get weather information", schema), - (exchange, args) -> { - String location = (String) args.get("location"); - - // Can interact with client directly from the tool handler - CreateMessageResult result = exchange.createMessage(new CreateMessageRequest(...)); - - return new CallToolResult("Weather for " + location + ": " + result.content()); - } -); - -var server = McpServer.sync(transportProvider) - .tools(tool) - .build(); -``` - -### Example 3: Converting Existing Registration Classes - -If you have custom implementations of the registration classes, you can convert them to the new specification classes: - -#### Before (0.7.0): - -```java -McpServerFeatures.AsyncToolRegistration toolReg = new McpServerFeatures.AsyncToolRegistration( - tool, - args -> Mono.just(new CallToolResult("Result")) -); - -McpServerFeatures.AsyncResourceRegistration resourceReg = new McpServerFeatures.AsyncResourceRegistration( - resource, - req -> Mono.just(new ReadResourceResult(List.of())) -); -``` - -#### After (0.8.0): - -```java -// Option 1: Create new specification directly -McpServerFeatures.AsyncToolSpecification toolSpec = new McpServerFeatures.AsyncToolSpecification( - tool, - (exchange, args) -> Mono.just(new CallToolResult("Result")) -); - -// Option 2: Convert from existing registration (during transition) -McpServerFeatures.AsyncToolRegistration oldToolReg = /* existing registration */; -McpServerFeatures.AsyncToolSpecification toolSpec = oldToolReg.toSpecification(); - -// Similarly for resources -McpServerFeatures.AsyncResourceSpecification resourceSpec = new McpServerFeatures.AsyncResourceSpecification( - resource, - (exchange, req) -> Mono.just(new ReadResourceResult(List.of())) -); -``` - -## Architecture Changes - -### Session-Based Architecture - -In 0.8.0, the MCP Java SDK introduces a session-based architecture where each client connection has its own session. This allows for better isolation between clients and more efficient resource management. - -The `McpServerTransportProvider` is responsible for creating `McpServerTransport` instances for each session, and the `McpServerSession` manages the communication with a specific client. - -### Exchange Objects - -The new exchange objects (`McpAsyncServerExchange` and `McpSyncServerExchange`) provide access to client-specific information and methods. They are passed to handler functions as the first parameter, allowing handlers to interact with the specific client that made the request. - -## Conclusion - -The changes in version 0.8.0 represent a significant architectural improvement to the MCP Java SDK. While they require some code changes, the new design provides a more flexible and maintainable foundation for building MCP applications. - -For assistance with migration or to report issues, please open an issue on the GitHub repository. diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 000000000..3e27c3fb5 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,97 @@ +site_name: MCP Java SDK +site_url: https://modelcontextprotocol.github.io/java-sdk/ +site_description: Java SDK for the Model Context Protocol - standardized integration between AI models and tools +repo_url: https://github.com/modelcontextprotocol/java-sdk +repo_name: modelcontextprotocol/java-sdk +edit_uri: edit/main/docs/ + +theme: + name: material + favicon: images/favicon.svg + logo: images/logo-light.svg + palette: + - scheme: default + primary: blue grey + accent: blue grey + toggle: + icon: material/brightness-7 + name: Switch to dark mode + - scheme: slate + primary: blue grey + accent: blue grey + toggle: + icon: material/brightness-4 + name: Switch to light mode + features: + - navigation.instant + - navigation.instant.progress + - navigation.tabs + - navigation.tabs.sticky + - navigation.sections + - navigation.top + - navigation.path + - navigation.indexes + - toc.follow + - search.suggest + - search.highlight + - content.code.copy + - content.code.annotate + - content.tabs.link + +nav: + - Documentation: + - Overview: overview.md + - Quickstart: quickstart.md + - MCP Components: + - MCP Client: client.md + - MCP Server: server.md + - Contributing: + - Contributing Guide: contribute.md + - Documentation: development.md + - API Reference: https://javadoc.io/doc/io.modelcontextprotocol.sdk/mcp-core/latest + - News: + - blog/index.md + +markdown_extensions: + - admonition + - pymdownx.details + - pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format + - pymdownx.tabbed: + alternate_style: true + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + - pymdownx.inlinehilite + - pymdownx.snippets + - pymdownx.mark + - pymdownx.critic + - pymdownx.caret + - pymdownx.keys + - pymdownx.tilde + - pymdownx.emoji: + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg + - attr_list + - md_in_html + - tables + - toc: + permalink: true + +extra: + version: + provider: mike + default: + - latest-snapshot + - latest + social: + - icon: fontawesome/brands/github + link: https://github.com/modelcontextprotocol/java-sdk + generator: false + +plugins: + - search + - blog diff --git a/pom.xml b/pom.xml index c0b1f7a44..049536e0d 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.12.0-SNAPSHOT + 1.1.0-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk @@ -59,23 +59,25 @@ 17 - 3.26.3 - 5.10.2 - 5.17.0 - 1.20.4 - 1.17.5 + 3.27.6 + 6.0.2 + 5.20.0 + 1.21.4 + 1.17.8 1.21.0 2.0.16 1.5.15 - 2.17.0 + 2.20 + 2.20.1 + 3.0.3 6.2.1 3.11.0 3.1.2 3.5.2 - 3.5.0 + 3.11.2 3.3.0 0.8.10 1.5.0 @@ -90,22 +92,24 @@ 1.0.0-alpha.4 0.0.4 1.6.2 - 5.10.5 11.0.2 6.1.0 4.2.0 7.1.0 4.1.0 - 1.5.7 + 2.0.0 + 3.0.0 mcp-bom mcp - mcp-spring/mcp-spring-webflux - mcp-spring/mcp-spring-webmvc + mcp-core + mcp-json-jackson2 + mcp-json-jackson3 mcp-test + conformance-tests @@ -276,6 +280,7 @@ ${maven-javadoc-plugin.version} false + true false none @@ -315,6 +320,9 @@ true central + + mcp-parent,conformance-tests,client-jdk-http-client,client-spring-http-client,server-servlet + true