diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml index 36c88040e..514f979d7 100644 --- a/.github/workflows/claude-code-review.yml +++ b/.github/workflows/claude-code-review.yml @@ -19,13 +19,13 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 1 - name: Run Claude Code Review id: claude-review - uses: anthropics/claude-code-action@v1 + uses: anthropics/claude-code-action@2f8ba26a219c06cfb0f468eef8d97055fa814f97 # v1.0.53 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} plugin_marketplaces: "https://github.com/anthropics/claude-code.git" diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index 490e9ae2c..8421cf954 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -27,13 +27,13 @@ jobs: actions: read # Required for Claude to read CI results on PRs steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 1 - name: Run Claude Code id: claude - uses: anthropics/claude-code-action@v1 + uses: anthropics/claude-code-action@2f8ba26a219c06cfb0f468eef8d97055fa814f97 # v1.0.53 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} use_commit_signing: true diff --git a/.github/workflows/publish-docs-manually.yml b/.github/workflows/publish-docs-manually.yml index f058174ab..ee45ab5c8 100644 --- a/.github/workflows/publish-docs-manually.yml +++ b/.github/workflows/publish-docs-manually.yml @@ -31,3 +31,5 @@ jobs: - run: uv sync --frozen --group docs - run: uv run --frozen --no-sync mkdocs gh-deploy --force + env: + ENABLE_SOCIAL_CARDS: "true" diff --git a/.github/workflows/shared.yml b/.github/workflows/shared.yml index fe97895f6..72e328b54 100644 --- a/.github/workflows/shared.yml +++ b/.github/workflows/shared.yml @@ -26,7 +26,19 @@ jobs: with: extra_args: --all-files --verbose env: - SKIP: no-commit-to-branch + SKIP: no-commit-to-branch,readme-v1-frozen + + # TODO(Max): Drop this in v2. + - name: Check README.md is not modified + if: github.event_name == 'pull_request' + run: | + git fetch --no-tags --depth=1 origin "$BASE_SHA" + if git diff --name-only "$BASE_SHA" -- README.md | grep -q .; then + echo "::error::README.md is frozen at v1. Edit README.v2.md instead." + exit 1 + fi + env: + BASE_SHA: ${{ github.event.pull_request.base.sha }} test: name: test (${{ matrix.python-version }}, ${{ matrix.dep-resolution.name }}, ${{ matrix.os }}) diff --git a/.github/workflows/weekly-lockfile-update.yml b/.github/workflows/weekly-lockfile-update.yml index 880882247..96507d793 100644 --- a/.github/workflows/weekly-lockfile-update.yml +++ b/.github/workflows/weekly-lockfile-update.yml @@ -14,9 +14,9 @@ jobs: update-lockfile: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - - uses: astral-sh/setup-uv@v7.2.1 + - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 with: version: 0.9.5 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 03a8ae038..42c12fded 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,6 +55,12 @@ repos: language: system files: ^(pyproject\.toml|uv\.lock)$ pass_filenames: false + # TODO(Max): Drop this in v2. + - id: readme-v1-frozen + name: README.md is frozen (v1 docs) + entry: README.md is frozen at v1. Edit README.v2.md instead. + language: fail + files: ^README\.md$ - id: readme-snippets name: Check README snippets are up to date entry: uv run --frozen python scripts/update_readme_snippets.py --check diff --git a/CLAUDE.md b/CLAUDE.md index d7b175636..e48ce6e70 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -28,6 +28,14 @@ This document contains critical information about working with this codebase. Fo - Bug fixes require regression tests - IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns. - IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible. + - IMPORTANT: Before pushing, verify 100% branch coverage on changed files by running + `uv run --frozen pytest -x` (coverage is configured in `pyproject.toml` with `fail_under = 100` + and `branch = true`). If any branch is uncovered, add a test for it before pushing. + - Avoid `anyio.sleep()` with a fixed duration to wait for async operations. Instead: + - Use `anyio.Event` — set it in the callback/handler, `await event.wait()` in the test + - For stream messages, use `await stream.receive()` instead of `sleep()` + `receive_nowait()` + - Exception: `sleep()` is appropriate when testing time-based features (e.g., timeouts) + - Wrap indefinite waits (`event.wait()`, `stream.receive()`) in `anyio.fail_after(5)` to prevent hangs Test files mirror the source tree: `src/mcp/client/streamable_http.py` → `tests/client/test_streamable_http.py` Add tests to the existing file for that module. diff --git a/README.md b/README.md index dc23d0d1d..487d48bee 100644 --- a/README.md +++ b/README.md @@ -13,12 +13,13 @@ -> [!IMPORTANT] -> **This is the `main` branch which contains v2 of the SDK (currently in development, pre-alpha).** -> -> We anticipate a stable v2 release in Q1 2026. Until then, **v1.x remains the recommended version** for production use. v1.x will continue to receive bug fixes and security updates for at least 6 months after v2 ships to give people time to upgrade. + + +> [!NOTE] +> **This README documents v1.x of the MCP Python SDK (the current stable release).** > -> For v1 documentation and code, see the [`v1.x` branch](https://github.com/modelcontextprotocol/python-sdk/tree/v1.x). +> For v1.x code and documentation, see the [`v1.x` branch](https://github.com/modelcontextprotocol/python-sdk/tree/v1.x). +> For the upcoming v2 documentation (pre-alpha, in development on `main`), see [`README.v2.md`](README.v2.md). ## Table of Contents @@ -45,7 +46,7 @@ - [Sampling](#sampling) - [Logging and Notifications](#logging-and-notifications) - [Authentication](#authentication) - - [MCPServer Properties](#mcpserver-properties) + - [FastMCP Properties](#fastmcp-properties) - [Session Properties and Methods](#session-properties-and-methods) - [Request Context Properties](#request-context-properties) - [Running Your Server](#running-your-server) @@ -134,18 +135,19 @@ uv run mcp Let's create a simple MCP server that exposes a calculator tool and some data: - + ```python -"""MCPServer quickstart example. +""" +FastMCP quickstart example. Run from the repository root: - uv run examples/snippets/servers/mcpserver_quickstart.py + uv run examples/snippets/servers/fastmcp_quickstart.py """ -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create an MCP server -mcp = MCPServer("Demo") +mcp = FastMCP("Demo", json_response=True) # Add an addition tool @@ -177,16 +179,16 @@ def greet_user(name: str, style: str = "friendly") -> str: # Run with streamable HTTP transport if __name__ == "__main__": - mcp.run(transport="streamable-http", json_response=True) + mcp.run(transport="streamable-http") ``` -_Full example: [examples/snippets/servers/mcpserver_quickstart.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/mcpserver_quickstart.py)_ +_Full example: [examples/snippets/servers/fastmcp_quickstart.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/fastmcp_quickstart.py)_ You can install this server in [Claude Code](https://docs.claude.com/en/docs/claude-code/mcp) and interact with it right away. First, run the server: ```bash -uv run --with mcp examples/snippets/servers/mcpserver_quickstart.py +uv run --with mcp examples/snippets/servers/fastmcp_quickstart.py ``` Then add it to Claude Code: @@ -216,7 +218,7 @@ The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you bui ### Server -The MCPServer server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing: +The FastMCP server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing: ```python @@ -226,7 +228,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession @@ -256,7 +258,7 @@ class AppContext: @asynccontextmanager -async def app_lifespan(server: MCPServer) -> AsyncIterator[AppContext]: +async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: """Manage application lifecycle with type-safe context.""" # Initialize on startup db = await Database.connect() @@ -268,7 +270,7 @@ async def app_lifespan(server: MCPServer) -> AsyncIterator[AppContext]: # Pass lifespan to server -mcp = MCPServer("My App", lifespan=app_lifespan) +mcp = FastMCP("My App", lifespan=app_lifespan) # Access type-safe lifespan context in tools @@ -288,9 +290,9 @@ Resources are how you expose data to LLMs. They're similar to GET endpoints in a ```python -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer(name="Resource Example") +mcp = FastMCP(name="Resource Example") @mcp.resource("file://documents/{name}") @@ -319,9 +321,9 @@ Tools let LLMs take actions through your server. Unlike resources, tools are exp ```python -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer(name="Tool Example") +mcp = FastMCP(name="Tool Example") @mcp.tool() @@ -340,14 +342,14 @@ def get_weather(city: str, unit: str = "celsius") -> str: _Full example: [examples/snippets/servers/basic_tool.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/basic_tool.py)_ -Tools can optionally receive a Context object by including a parameter with the `Context` type annotation. This context is automatically injected by the MCPServer framework and provides access to MCP capabilities: +Tools can optionally receive a Context object by including a parameter with the `Context` type annotation. This context is automatically injected by the FastMCP framework and provides access to MCP capabilities: ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession -mcp = MCPServer(name="Progress Example") +mcp = FastMCP(name="Progress Example") @mcp.tool() @@ -395,7 +397,7 @@ validated data that clients can easily process. **Note:** For backward compatibility, unstructured results are also returned. Unstructured results are provided for backward compatibility with previous versions of the MCP specification, and are quirks-compatible -with previous versions of MCPServer in the current version of the SDK. +with previous versions of FastMCP in the current version of the SDK. **Note:** In cases where a tool function's return type annotation causes the tool to be classified as structured _and this is undesirable_, @@ -414,10 +416,10 @@ from typing import Annotated from pydantic import BaseModel -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP from mcp.types import CallToolResult, TextContent -mcp = MCPServer("CallToolResult Example") +mcp = FastMCP("CallToolResult Example") class ValidationModel(BaseModel): @@ -441,7 +443,7 @@ def validated_tool() -> Annotated[CallToolResult, ValidationModel]: """Return CallToolResult with structured output validation.""" return CallToolResult( content=[TextContent(type="text", text="Validated response")], - structured_content={"status": "success", "data": {"result": 42}}, + structuredContent={"status": "success", "data": {"result": 42}}, _meta={"internal": "metadata"}, ) @@ -465,9 +467,9 @@ from typing import TypedDict from pydantic import BaseModel, Field -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer("Structured Output Example") +mcp = FastMCP("Structured Output Example") # Using Pydantic models for rich structured data @@ -567,10 +569,10 @@ Prompts are reusable templates that help LLMs interact with your server effectiv ```python -from mcp.server.mcpserver import MCPServer -from mcp.server.mcpserver.prompts import base +from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp.prompts import base -mcp = MCPServer(name="Prompt Example") +mcp = FastMCP(name="Prompt Example") @mcp.prompt(title="Code Review") @@ -595,7 +597,7 @@ _Full example: [examples/snippets/servers/basic_prompt.py](https://github.com/mo MCP servers can provide icons for UI display. Icons can be added to the server implementation, tools, resources, and prompts: ```python -from mcp.server.mcpserver import MCPServer, Icon +from mcp.server.fastmcp import FastMCP, Icon # Create an icon from a file path or URL icon = Icon( @@ -605,7 +607,7 @@ icon = Icon( ) # Add icons to server -mcp = MCPServer( +mcp = FastMCP( "My Server", website_url="https://example.com", icons=[icon] @@ -623,21 +625,21 @@ def my_resource(): return "content" ``` -_Full example: [examples/mcpserver/icons_demo.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/mcpserver/icons_demo.py)_ +_Full example: [examples/fastmcp/icons_demo.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/fastmcp/icons_demo.py)_ ### Images -MCPServer provides an `Image` class that automatically handles image data: +FastMCP provides an `Image` class that automatically handles image data: ```python -"""Example showing image handling with MCPServer.""" +"""Example showing image handling with FastMCP.""" from PIL import Image as PILImage -from mcp.server.mcpserver import Image, MCPServer +from mcp.server.fastmcp import FastMCP, Image -mcp = MCPServer("Image Example") +mcp = FastMCP("Image Example") @mcp.tool() @@ -660,9 +662,9 @@ The Context object is automatically injected into tool and resource functions th To use context in a tool or resource function, add a parameter with the `Context` type annotation: ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP -mcp = MCPServer(name="Context Example") +mcp = FastMCP(name="Context Example") @mcp.tool() @@ -678,11 +680,11 @@ The Context object provides the following capabilities: - `ctx.request_id` - Unique ID for the current request - `ctx.client_id` - Client ID if available -- `ctx.mcp_server` - Access to the MCPServer server instance (see [MCPServer Properties](#mcpserver-properties)) +- `ctx.fastmcp` - Access to the FastMCP server instance (see [FastMCP Properties](#fastmcp-properties)) - `ctx.session` - Access to the underlying session for advanced communication (see [Session Properties and Methods](#session-properties-and-methods)) - `ctx.request_context` - Access to request-specific data and lifespan resources (see [Request Context Properties](#request-context-properties)) - `await ctx.debug(message)` - Send debug log message -- `await ctx.info(message)` - Send info log message +- `await ctx.info(message)` - Send info log message - `await ctx.warning(message)` - Send warning log message - `await ctx.error(message)` - Send error log message - `await ctx.log(level, message, logger_name=None)` - Send log with custom level @@ -692,10 +694,10 @@ The Context object provides the following capabilities: ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession -mcp = MCPServer(name="Progress Example") +mcp = FastMCP(name="Progress Example") @mcp.tool() @@ -726,8 +728,9 @@ Client usage: ```python -"""cd to the `examples/snippets` directory and run: -uv run completion-client +""" +cd to the `examples/snippets` directory and run: + uv run completion-client """ import asyncio @@ -755,8 +758,8 @@ async def run(): # List available resource templates templates = await session.list_resource_templates() print("Available resource templates:") - for template in templates.resource_templates: - print(f" - {template.uri_template}") + for template in templates.resourceTemplates: + print(f" - {template.uriTemplate}") # List available prompts prompts = await session.list_prompts() @@ -765,20 +768,20 @@ async def run(): print(f" - {prompt.name}") # Complete resource template arguments - if templates.resource_templates: - template = templates.resource_templates[0] - print(f"\nCompleting arguments for resource template: {template.uri_template}") + if templates.resourceTemplates: + template = templates.resourceTemplates[0] + print(f"\nCompleting arguments for resource template: {template.uriTemplate}") # Complete without context result = await session.complete( - ref=ResourceTemplateReference(type="ref/resource", uri=template.uri_template), + ref=ResourceTemplateReference(type="ref/resource", uri=template.uriTemplate), argument={"name": "owner", "value": "model"}, ) print(f"Completions for 'owner' starting with 'model': {result.completion.values}") # Complete with context - repo suggestions based on owner result = await session.complete( - ref=ResourceTemplateReference(type="ref/resource", uri=template.uri_template), + ref=ResourceTemplateReference(type="ref/resource", uri=template.uriTemplate), argument={"name": "repo", "value": ""}, context_arguments={"owner": "modelcontextprotocol"}, ) @@ -824,12 +827,12 @@ import uuid from pydantic import BaseModel, Field -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ElicitRequestURLParams -mcp = MCPServer(name="Elicitation Example") +mcp = FastMCP(name="Elicitation Example") class BookingPreferences(BaseModel): @@ -908,7 +911,7 @@ async def connect_service(service_name: str, ctx: Context[ServerSession, None]) mode="url", message=f"Authorization required to connect to {service_name}", url=f"https://{service_name}.example.com/oauth/authorize?elicit={elicitation_id}", - elicitation_id=elicitation_id, + elicitationId=elicitation_id, ) ] ) @@ -931,11 +934,11 @@ Tools can interact with LLMs through sampling (generating text): ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.types import SamplingMessage, TextContent -mcp = MCPServer(name="Sampling Example") +mcp = FastMCP(name="Sampling Example") @mcp.tool() @@ -968,10 +971,10 @@ Tools can send logs and notifications through the context: ```python -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession -mcp = MCPServer(name="Notifications Example") +mcp = FastMCP(name="Notifications Example") @mcp.tool() @@ -1002,15 +1005,16 @@ MCP servers can use authentication by providing an implementation of the `TokenV ```python -"""Run from the repository root: -uv run examples/snippets/servers/oauth_server.py +""" +Run from the repository root: + uv run examples/snippets/servers/oauth_server.py """ from pydantic import AnyHttpUrl from mcp.server.auth.provider import AccessToken, TokenVerifier from mcp.server.auth.settings import AuthSettings -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP class SimpleTokenVerifier(TokenVerifier): @@ -1020,9 +1024,10 @@ class SimpleTokenVerifier(TokenVerifier): pass # This is where you would implement actual token validation -# Create MCPServer instance as a Resource Server -mcp = MCPServer( +# Create FastMCP instance as a Resource Server +mcp = FastMCP( "Weather Service", + json_response=True, # Token verifier for authentication token_verifier=SimpleTokenVerifier(), # Auth settings for RFC 9728 Protected Resource Metadata @@ -1046,7 +1051,7 @@ async def get_weather(city: str = "London") -> dict[str, str]: if __name__ == "__main__": - mcp.run(transport="streamable-http", json_response=True) + mcp.run(transport="streamable-http") ``` _Full example: [examples/snippets/servers/oauth_server.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/oauth_server.py)_ @@ -1062,19 +1067,19 @@ For a complete example with separate Authorization Server and Resource Server im See [TokenVerifier](src/mcp/server/auth/provider.py) for more details on implementing token validation. -### MCPServer Properties +### FastMCP Properties -The MCPServer server instance accessible via `ctx.mcp_server` provides access to server configuration and metadata: +The FastMCP server instance accessible via `ctx.fastmcp` provides access to server configuration and metadata: -- `ctx.mcp_server.name` - The server's name as defined during initialization -- `ctx.mcp_server.instructions` - Server instructions/description provided to clients -- `ctx.mcp_server.website_url` - Optional website URL for the server -- `ctx.mcp_server.icons` - Optional list of icons for UI display -- `ctx.mcp_server.settings` - Complete server configuration object containing: +- `ctx.fastmcp.name` - The server's name as defined during initialization +- `ctx.fastmcp.instructions` - Server instructions/description provided to clients +- `ctx.fastmcp.website_url` - Optional website URL for the server +- `ctx.fastmcp.icons` - Optional list of icons for UI display +- `ctx.fastmcp.settings` - Complete server configuration object containing: - `debug` - Debug mode flag - `log_level` - Current logging level - `host` and `port` - Server network configuration - - `sse_path`, `streamable_http_path` - Transport paths + - `mount_path`, `sse_path`, `streamable_http_path` - Transport paths - `stateless_http` - Whether the server operates in stateless mode - And other configuration options @@ -1083,12 +1088,12 @@ The MCPServer server instance accessible via `ctx.mcp_server` provides access to def server_info(ctx: Context) -> dict: """Get information about the current server.""" return { - "name": ctx.mcp_server.name, - "instructions": ctx.mcp_server.instructions, - "debug_mode": ctx.mcp_server.settings.debug, - "log_level": ctx.mcp_server.settings.log_level, - "host": ctx.mcp_server.settings.host, - "port": ctx.mcp_server.settings.port, + "name": ctx.fastmcp.name, + "instructions": ctx.fastmcp.instructions, + "debug_mode": ctx.fastmcp.settings.debug, + "log_level": ctx.fastmcp.settings.log_level, + "host": ctx.fastmcp.settings.host, + "port": ctx.fastmcp.settings.port, } ``` @@ -1110,13 +1115,13 @@ The session object accessible via `ctx.session` provides advanced control over c async def notify_data_update(resource_uri: str, ctx: Context) -> str: """Update data and notify clients of the change.""" # Perform data update logic here - + # Notify clients that this specific resource changed await ctx.session.send_resource_updated(AnyUrl(resource_uri)) - + # If this affects the overall resource list, notify about that too await ctx.session.send_resource_list_changed() - + return f"Updated {resource_uri} and notified clients" ``` @@ -1145,11 +1150,11 @@ def query_with_config(query: str, ctx: Context) -> str: """Execute a query using shared database and configuration.""" # Access typed lifespan context app_ctx: AppContext = ctx.request_context.lifespan_context - + # Use shared resources connection = app_ctx.db settings = app_ctx.config - + # Execute query with configuration result = connection.execute(query, timeout=settings.query_timeout) return str(result) @@ -1203,9 +1208,9 @@ cd to the `examples/snippets` directory and run: python servers/direct_execution.py """ -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer("My App") +mcp = FastMCP("My App") @mcp.tool() @@ -1234,7 +1239,7 @@ python servers/direct_execution.py uv run mcp run servers/direct_execution.py ``` -Note that `uv run mcp run` or `uv run mcp dev` only supports server using MCPServer and not the low-level server variant. +Note that `uv run mcp run` or `uv run mcp dev` only supports server using FastMCP and not the low-level server variant. ### Streamable HTTP Transport @@ -1242,13 +1247,22 @@ Note that `uv run mcp run` or `uv run mcp dev` only supports server using MCPSer ```python -"""Run from the repository root: -uv run examples/snippets/servers/streamable_config.py +""" +Run from the repository root: + uv run examples/snippets/servers/streamable_config.py """ -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer("StatelessServer") +# Stateless server with JSON responses (recommended) +mcp = FastMCP("StatelessServer", stateless_http=True, json_response=True) + +# Other configuration options: +# Stateless server with SSE streaming responses +# mcp = FastMCP("StatelessServer", stateless_http=True) + +# Stateful server with session persistence +# mcp = FastMCP("StatefulServer") # Add a simple tool to demonstrate the server @@ -1259,28 +1273,20 @@ def greet(name: str = "World") -> str: # Run server with streamable_http transport -# Transport-specific options (stateless_http, json_response) are passed to run() if __name__ == "__main__": - # Stateless server with JSON responses (recommended) - mcp.run(transport="streamable-http", stateless_http=True, json_response=True) - - # Other configuration options: - # Stateless server with SSE streaming responses - # mcp.run(transport="streamable-http", stateless_http=True) - - # Stateful server with session persistence - # mcp.run(transport="streamable-http") + mcp.run(transport="streamable-http") ``` _Full example: [examples/snippets/servers/streamable_config.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/streamable_config.py)_ -You can mount multiple MCPServer servers in a Starlette application: +You can mount multiple FastMCP servers in a Starlette application: ```python -"""Run from the repository root: -uvicorn examples.snippets.servers.streamable_starlette_mount:app --reload +""" +Run from the repository root: + uvicorn examples.snippets.servers.streamable_starlette_mount:app --reload """ import contextlib @@ -1288,10 +1294,10 @@ import contextlib from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create the Echo server -echo_mcp = MCPServer(name="EchoServer") +echo_mcp = FastMCP(name="EchoServer", stateless_http=True, json_response=True) @echo_mcp.tool() @@ -1301,7 +1307,7 @@ def echo(message: str) -> str: # Create the Math server -math_mcp = MCPServer(name="MathServer") +math_mcp = FastMCP(name="MathServer", stateless_http=True, json_response=True) @math_mcp.tool() @@ -1322,16 +1328,16 @@ async def lifespan(app: Starlette): # Create the Starlette app and mount the MCP servers app = Starlette( routes=[ - Mount("/echo", echo_mcp.streamable_http_app(stateless_http=True, json_response=True)), - Mount("/math", math_mcp.streamable_http_app(stateless_http=True, json_response=True)), + Mount("/echo", echo_mcp.streamable_http_app()), + Mount("/math", math_mcp.streamable_http_app()), ], lifespan=lifespan, ) # Note: Clients connect to http://localhost:8000/echo/mcp and http://localhost:8000/math/mcp # To mount at the root of each path (e.g., /echo instead of /echo/mcp): -# echo_mcp.streamable_http_app(streamable_http_path="/", stateless_http=True, json_response=True) -# math_mcp.streamable_http_app(streamable_http_path="/", stateless_http=True, json_response=True) +# echo_mcp.settings.streamable_http_path = "/" +# math_mcp.settings.streamable_http_path = "/" ``` _Full example: [examples/snippets/servers/streamable_starlette_mount.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/streamable_starlette_mount.py)_ @@ -1389,7 +1395,8 @@ You can mount the StreamableHTTP server to an existing ASGI server using the `st ```python -"""Basic example showing how to mount StreamableHTTP server in Starlette. +""" +Basic example showing how to mount StreamableHTTP server in Starlette. Run from the repository root: uvicorn examples.snippets.servers.streamable_http_basic_mounting:app --reload @@ -1400,10 +1407,10 @@ import contextlib from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create MCP server -mcp = MCPServer("My App") +mcp = FastMCP("My App", json_response=True) @mcp.tool() @@ -1420,10 +1427,9 @@ async def lifespan(app: Starlette): # Mount the StreamableHTTP server to the existing ASGI server -# Transport-specific options are passed to streamable_http_app() app = Starlette( routes=[ - Mount("/", app=mcp.streamable_http_app(json_response=True)), + Mount("/", app=mcp.streamable_http_app()), ], lifespan=lifespan, ) @@ -1436,7 +1442,8 @@ _Full example: [examples/snippets/servers/streamable_http_basic_mounting.py](htt ```python -"""Example showing how to mount StreamableHTTP server using Host-based routing. +""" +Example showing how to mount StreamableHTTP server using Host-based routing. Run from the repository root: uvicorn examples.snippets.servers.streamable_http_host_mounting:app --reload @@ -1447,10 +1454,10 @@ import contextlib from starlette.applications import Starlette from starlette.routing import Host -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create MCP server -mcp = MCPServer("MCP Host App") +mcp = FastMCP("MCP Host App", json_response=True) @mcp.tool() @@ -1467,10 +1474,9 @@ async def lifespan(app: Starlette): # Mount using Host-based routing -# Transport-specific options are passed to streamable_http_app() app = Starlette( routes=[ - Host("mcp.acme.corp", app=mcp.streamable_http_app(json_response=True)), + Host("mcp.acme.corp", app=mcp.streamable_http_app()), ], lifespan=lifespan, ) @@ -1483,7 +1489,8 @@ _Full example: [examples/snippets/servers/streamable_http_host_mounting.py](http ```python -"""Example showing how to mount multiple StreamableHTTP servers with path configuration. +""" +Example showing how to mount multiple StreamableHTTP servers with path configuration. Run from the repository root: uvicorn examples.snippets.servers.streamable_http_multiple_servers:app --reload @@ -1494,11 +1501,11 @@ import contextlib from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create multiple MCP servers -api_mcp = MCPServer("API Server") -chat_mcp = MCPServer("Chat Server") +api_mcp = FastMCP("API Server", json_response=True) +chat_mcp = FastMCP("Chat Server", json_response=True) @api_mcp.tool() @@ -1513,6 +1520,12 @@ def send_message(message: str) -> str: return f"Message sent: {message}" +# Configure servers to mount at the root of each path +# This means endpoints will be at /api and /chat instead of /api/mcp and /chat/mcp +api_mcp.settings.streamable_http_path = "/" +chat_mcp.settings.streamable_http_path = "/" + + # Create a combined lifespan to manage both session managers @contextlib.asynccontextmanager async def lifespan(app: Starlette): @@ -1522,12 +1535,11 @@ async def lifespan(app: Starlette): yield -# Mount the servers with transport-specific options passed to streamable_http_app() -# streamable_http_path="/" means endpoints will be at /api and /chat instead of /api/mcp and /chat/mcp +# Mount the servers app = Starlette( routes=[ - Mount("/api", app=api_mcp.streamable_http_app(json_response=True, streamable_http_path="/")), - Mount("/chat", app=chat_mcp.streamable_http_app(json_response=True, streamable_http_path="/")), + Mount("/api", app=api_mcp.streamable_http_app()), + Mount("/chat", app=chat_mcp.streamable_http_app()), ], lifespan=lifespan, ) @@ -1540,7 +1552,8 @@ _Full example: [examples/snippets/servers/streamable_http_multiple_servers.py](h ```python -"""Example showing path configuration when mounting MCPServer. +""" +Example showing path configuration during FastMCP initialization. Run from the repository root: uvicorn examples.snippets.servers.streamable_http_path_config:app --reload @@ -1549,10 +1562,15 @@ Run from the repository root: from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -# Create a simple MCPServer server -mcp_at_root = MCPServer("My Server") +# Configure streamable_http_path during initialization +# This server will mount at the root of wherever it's mounted +mcp_at_root = FastMCP( + "My Server", + json_response=True, + streamable_http_path="/", +) @mcp_at_root.tool() @@ -1561,14 +1579,10 @@ def process_data(data: str) -> str: return f"Processed: {data}" -# Mount at /process with streamable_http_path="/" so the endpoint is /process (not /process/mcp) -# Transport-specific options like json_response are passed to streamable_http_app() +# Mount at /process - endpoints will be at /process instead of /process/mcp app = Starlette( routes=[ - Mount( - "/process", - app=mcp_at_root.streamable_http_app(json_response=True, streamable_http_path="/"), - ), + Mount("/process", app=mcp_at_root.streamable_http_app()), ] ) ``` @@ -1585,10 +1599,10 @@ You can mount the SSE server to an existing ASGI server using the `sse_app` meth ```python from starlette.applications import Starlette from starlette.routing import Mount, Host -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP -mcp = MCPServer("My App") +mcp = FastMCP("My App") # Mount the SSE server to the existing ASGI server app = Starlette( @@ -1601,28 +1615,41 @@ app = Starlette( app.router.routes.append(Host('mcp.acme.corp', app=mcp.sse_app())) ``` -You can also mount multiple MCP servers at different sub-paths. The SSE transport automatically detects the mount path via ASGI's `root_path` mechanism, so message endpoints are correctly routed: +When mounting multiple MCP servers under different paths, you can configure the mount path in several ways: ```python from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.mcpserver import MCPServer +from mcp.server.fastmcp import FastMCP # Create multiple MCP servers -github_mcp = MCPServer("GitHub API") -browser_mcp = MCPServer("Browser") -search_mcp = MCPServer("Search") +github_mcp = FastMCP("GitHub API") +browser_mcp = FastMCP("Browser") +curl_mcp = FastMCP("Curl") +search_mcp = FastMCP("Search") + +# Method 1: Configure mount paths via settings (recommended for persistent configuration) +github_mcp.settings.mount_path = "/github" +browser_mcp.settings.mount_path = "/browser" -# Mount each server at its own sub-path -# The SSE transport automatically uses ASGI's root_path to construct -# the correct message endpoint (e.g., /github/messages/, /browser/messages/) +# Method 2: Pass mount path directly to sse_app (preferred for ad-hoc mounting) +# This approach doesn't modify the server's settings permanently + +# Create Starlette app with multiple mounted servers app = Starlette( routes=[ + # Using settings-based configuration Mount("/github", app=github_mcp.sse_app()), Mount("/browser", app=browser_mcp.sse_app()), - Mount("/search", app=search_mcp.sse_app()), + # Using direct mount path parameter + Mount("/curl", app=curl_mcp.sse_app("/curl")), + Mount("/search", app=search_mcp.sse_app("/search")), ] ) + +# Method 3: For direct execution, you can also pass the mount path to run() +if __name__ == "__main__": + search_mcp.run(transport="sse", mount_path="/search") ``` For more information on mounting applications in Starlette, see the [Starlette documentation](https://www.starlette.io/routing/#submounting-routes). @@ -1635,8 +1662,9 @@ For more control, you can use the low-level server implementation directly. This ```python -"""Run from the repository root: -uv run examples/snippets/servers/lowlevel/lifespan.py +""" +Run from the repository root: + uv run examples/snippets/servers/lowlevel/lifespan.py """ from collections.abc import AsyncIterator @@ -1644,7 +1672,7 @@ from contextlib import asynccontextmanager from typing import Any import mcp.server.stdio -from mcp import types +import mcp.types as types from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions @@ -1692,7 +1720,7 @@ async def handle_list_tools() -> list[types.Tool]: types.Tool( name="query_db", description="Query the database", - input_schema={ + inputSchema={ "type": "object", "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, "required": ["query"], @@ -1751,14 +1779,15 @@ The lifespan API provides: ```python -"""Run from the repository root: +""" +Run from the repository root: uv run examples/snippets/servers/lowlevel/basic.py """ import asyncio import mcp.server.stdio -from mcp import types +import mcp.types as types from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions @@ -1829,15 +1858,16 @@ The low-level server supports structured output for tools, allowing you to retur ```python -"""Run from the repository root: -uv run examples/snippets/servers/lowlevel/structured_output.py +""" +Run from the repository root: + uv run examples/snippets/servers/lowlevel/structured_output.py """ import asyncio from typing import Any import mcp.server.stdio -from mcp import types +import mcp.types as types from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions @@ -1851,12 +1881,12 @@ async def list_tools() -> list[types.Tool]: types.Tool( name="get_weather", description="Get current weather for a city", - input_schema={ + inputSchema={ "type": "object", "properties": {"city": {"type": "string", "description": "City name"}}, "required": ["city"], }, - output_schema={ + outputSchema={ "type": "object", "properties": { "temperature": {"type": "number", "description": "Temperature in Celsius"}, @@ -1931,15 +1961,16 @@ For full control over the response including the `_meta` field (for passing data ```python -"""Run from the repository root: -uv run examples/snippets/servers/lowlevel/direct_call_tool_result.py +""" +Run from the repository root: + uv run examples/snippets/servers/lowlevel/direct_call_tool_result.py """ import asyncio from typing import Any import mcp.server.stdio -from mcp import types +import mcp.types as types from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions @@ -1953,7 +1984,7 @@ async def list_tools() -> list[types.Tool]: types.Tool( name="advanced_tool", description="Tool with full control including _meta field", - input_schema={ + inputSchema={ "type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"], @@ -1969,7 +2000,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallTo message = str(arguments.get("message", "")) return types.CallToolResult( content=[types.TextContent(type="text", text=f"Processed: {message}")], - structured_content={"result": "success", "message": message}, + structuredContent={"result": "success", "message": message}, _meta={"hidden": "data for client applications only"}, ) @@ -2010,9 +2041,13 @@ For servers that need to handle large datasets, the low-level server provides pa ```python -"""Example of implementing pagination with MCP server decorators.""" +""" +Example of implementing pagination with MCP server decorators. +""" + +from pydantic import AnyUrl -from mcp import types +import mcp.types as types from mcp.server.lowlevel import Server # Initialize the server @@ -2036,14 +2071,14 @@ async def list_resources_paginated(request: types.ListResourcesRequest) -> types # Get page of resources page_items = [ - types.Resource(uri=f"resource://items/{item}", name=item, description=f"Description for {item}") + types.Resource(uri=AnyUrl(f"resource://items/{item}"), name=item, description=f"Description for {item}") for item in ITEMS[start:end] ] # Determine next cursor next_cursor = str(end) if end < len(ITEMS) else None - return types.ListResourcesResult(resources=page_items, next_cursor=next_cursor) + return types.ListResourcesResult(resources=page_items, nextCursor=next_cursor) ``` _Full example: [examples/snippets/servers/pagination_example.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/pagination_example.py)_ @@ -2053,7 +2088,9 @@ _Full example: [examples/snippets/servers/pagination_example.py](https://github. ```python -"""Example of consuming paginated MCP endpoints from a client.""" +""" +Example of consuming paginated MCP endpoints from a client. +""" import asyncio @@ -2082,8 +2119,8 @@ async def list_all_resources() -> None: print(f"Fetched {len(result.resources)} resources") # Check if there are more pages - if result.next_cursor: - cursor = result.next_cursor + if result.nextCursor: + cursor = result.nextCursor else: break @@ -2112,28 +2149,31 @@ The SDK provides a high-level client interface for connecting to MCP servers usi ```python -"""cd to the `examples/snippets/clients` directory and run: -uv run client +""" +cd to the `examples/snippets/clients` directory and run: + uv run client """ import asyncio import os +from pydantic import AnyUrl + from mcp import ClientSession, StdioServerParameters, types -from mcp.client.context import ClientRequestContext from mcp.client.stdio import stdio_client +from mcp.shared.context import RequestContext # Create server parameters for stdio connection server_params = StdioServerParameters( command="uv", # Using uv to run the server - args=["run", "server", "mcpserver_quickstart", "stdio"], # We're already in snippets dir + args=["run", "server", "fastmcp_quickstart", "stdio"], # We're already in snippets dir env={"UV_INDEX": os.environ.get("UV_INDEX", "")}, ) # Optional: create a sampling callback async def handle_sampling_message( - context: ClientRequestContext, params: types.CreateMessageRequestParams + context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( @@ -2143,7 +2183,7 @@ async def handle_sampling_message( text="Hello, world! from model", ), model="gpt-3.5-turbo", - stop_reason="endTurn", + stopReason="endTurn", ) @@ -2157,7 +2197,7 @@ async def run(): prompts = await session.list_prompts() print(f"Available prompts: {[p.name for p in prompts.prompts]}") - # Get a prompt (greet_user prompt from mcpserver_quickstart) + # Get a prompt (greet_user prompt from fastmcp_quickstart) if prompts.prompts: prompt = await session.get_prompt("greet_user", arguments={"name": "Alice", "style": "friendly"}) print(f"Prompt result: {prompt.messages[0].content}") @@ -2170,18 +2210,18 @@ async def run(): tools = await session.list_tools() print(f"Available tools: {[t.name for t in tools.tools]}") - # Read a resource (greeting resource from mcpserver_quickstart) - resource_content = await session.read_resource("greeting://World") + # Read a resource (greeting resource from fastmcp_quickstart) + resource_content = await session.read_resource(AnyUrl("greeting://World")) content_block = resource_content.contents[0] if isinstance(content_block, types.TextContent): print(f"Resource content: {content_block.text}") - # Call a tool (add tool from mcpserver_quickstart) + # Call a tool (add tool from fastmcp_quickstart) result = await session.call_tool("add", arguments={"a": 5, "b": 3}) result_unstructured = result.content[0] if isinstance(result_unstructured, types.TextContent): print(f"Tool result: {result_unstructured.text}") - result_structured = result.structured_content + result_structured = result.structuredContent print(f"Structured tool result: {result_structured}") @@ -2201,8 +2241,9 @@ Clients can also connect using [Streamable HTTP transport](https://modelcontextp ```python -"""Run from the repository root: -uv run examples/snippets/clients/streamable_basic.py +""" +Run from the repository root: + uv run examples/snippets/clients/streamable_basic.py """ import asyncio @@ -2240,8 +2281,9 @@ When building MCP clients, the SDK provides utilities to help display human-read ```python -"""cd to the `examples/snippets` directory and run: -uv run display-utilities-client +""" +cd to the `examples/snippets` directory and run: + uv run display-utilities-client """ import asyncio @@ -2254,7 +2296,7 @@ from mcp.shared.metadata_utils import get_display_name # Create server parameters for stdio connection server_params = StdioServerParameters( command="uv", # Using uv to run the server - args=["run", "server", "mcpserver_quickstart", "stdio"], + args=["run", "server", "fastmcp_quickstart", "stdio"], env={"UV_INDEX": os.environ.get("UV_INDEX", "")}, ) @@ -2280,7 +2322,7 @@ async def display_resources(session: ClientSession): print(f"Resource: {display_name} ({resource.uri})") templates_response = await session.list_resource_templates() - for template in templates_response.resource_templates: + for template in templates_response.resourceTemplates: display_name = get_display_name(template) print(f"Resource Template: {display_name}") @@ -2324,7 +2366,8 @@ The SDK includes [authorization support](https://modelcontextprotocol.io/specifi ```python -"""Before running, specify running MCP RS server URL. +""" +Before running, specify running MCP RS server URL. To spin up RS server locally, see examples/servers/simple-auth/README.md diff --git a/README.v2.md b/README.v2.md index 67f181811..bd6927bf9 100644 --- a/README.v2.md +++ b/README.v2.md @@ -1642,12 +1642,11 @@ uv run examples/snippets/servers/lowlevel/lifespan.py from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any +from typing import TypedDict import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext # Mock database class for example @@ -1670,52 +1669,58 @@ class Database: return [{"id": "1", "name": "Example", "query": query_str}] +class AppContext(TypedDict): + db: Database + + @asynccontextmanager -async def server_lifespan(_server: Server) -> AsyncIterator[dict[str, Any]]: +async def server_lifespan(_server: Server[AppContext]) -> AsyncIterator[AppContext]: """Manage server startup and shutdown lifecycle.""" - # Initialize resources on startup db = await Database.connect() try: yield {"db": db} finally: - # Clean up on shutdown await db.disconnect() -# Pass lifespan to server -server = Server("example-server", lifespan=server_lifespan) - - -@server.list_tools() -async def handle_list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="query_db", - description="Query the database", - input_schema={ - "type": "object", - "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, - "required": ["query"], - }, - ) - ] + return types.ListToolsResult( + tools=[ + types.Tool( + name="query_db", + description="Query the database", + input_schema={ + "type": "object", + "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, + "required": ["query"], + }, + ) + ] + ) -@server.call_tool() -async def query_db(name: str, arguments: dict[str, Any]) -> list[types.TextContent]: +async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: types.CallToolRequestParams +) -> types.CallToolResult: """Handle database query tool call.""" - if name != "query_db": - raise ValueError(f"Unknown tool: {name}") + if params.name != "query_db": + raise ValueError(f"Unknown tool: {params.name}") - # Access lifespan context - ctx = server.request_context db = ctx.lifespan_context["db"] + results = await db.query((params.arguments or {})["query"]) - # Execute query - results = await db.query(arguments["query"]) + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Query results: {results}")]) - return [types.TextContent(type="text", text=f"Query results: {results}")] + +server = Server( + "example-server", + lifespan=server_lifespan, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -1724,14 +1729,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example-server", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) @@ -1760,32 +1758,30 @@ import asyncio import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions - -# Create a server instance -server = Server("example-server") +from mcp.server import Server, ServerRequestContext -@server.list_prompts() -async def handle_list_prompts() -> list[types.Prompt]: +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: """List available prompts.""" - return [ - types.Prompt( - name="example-prompt", - description="An example prompt template", - arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], - ) - ] + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="example-prompt", + description="An example prompt template", + arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], + ) + ] + ) -@server.get_prompt() -async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: """Get a specific prompt by name.""" - if name != "example-prompt": - raise ValueError(f"Unknown prompt: {name}") + if params.name != "example-prompt": + raise ValueError(f"Unknown prompt: {params.name}") - arg1_value = (arguments or {}).get("arg1", "default") + arg1_value = (params.arguments or {}).get("arg1", "default") return types.GetPromptResult( description="Example prompt", @@ -1798,20 +1794,20 @@ async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> type ) +server = Server( + "example-server", + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, +) + + async def run(): """Run the basic low-level server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) @@ -1835,62 +1831,67 @@ uv run examples/snippets/servers/lowlevel/structured_output.py """ import asyncio -from typing import Any +import json import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -server = Server("example-server") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools with structured output schemas.""" - return [ - types.Tool( - name="get_weather", - description="Get current weather for a city", - input_schema={ - "type": "object", - "properties": {"city": {"type": "string", "description": "City name"}}, - "required": ["city"], - }, - output_schema={ - "type": "object", - "properties": { - "temperature": {"type": "number", "description": "Temperature in Celsius"}, - "condition": {"type": "string", "description": "Weather condition"}, - "humidity": {"type": "number", "description": "Humidity percentage"}, - "city": {"type": "string", "description": "City name"}, + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Get current weather for a city", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string", "description": "City name"}}, + "required": ["city"], }, - "required": ["temperature", "condition", "humidity", "city"], - }, - ) - ] + output_schema={ + "type": "object", + "properties": { + "temperature": {"type": "number", "description": "Temperature in Celsius"}, + "condition": {"type": "string", "description": "Weather condition"}, + "humidity": {"type": "number", "description": "Humidity percentage"}, + "city": {"type": "string", "description": "City name"}, + }, + "required": ["temperature", "condition", "humidity", "city"], + }, + ) + ] + ) -@server.call_tool() -async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool calls with structured output.""" - if name == "get_weather": - city = arguments["city"] + if params.name == "get_weather": + city = (params.arguments or {})["city"] - # Simulated weather data - in production, call a weather API weather_data = { "temperature": 22.5, "condition": "partly cloudy", "humidity": 65, - "city": city, # Include the requested city + "city": city, } - # low-level server will validate structured output against the tool's - # output schema, and additionally serialize it into a TextContent block - # for backwards compatibility with pre-2025-06-18 clients. - return weather_data - else: - raise ValueError(f"Unknown tool: {name}") + return types.CallToolResult( + content=[types.TextContent(type="text", text=json.dumps(weather_data, indent=2))], + structured_content=weather_data, + ) + + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -1899,14 +1900,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="structured-output-example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) @@ -1917,18 +1911,11 @@ if __name__ == "__main__": _Full example: [examples/snippets/servers/lowlevel/structured_output.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/lowlevel/structured_output.py)_ -Tools can return data in four ways: +With the low-level server, handlers always return `CallToolResult` directly. You construct both the human-readable `content` and the machine-readable `structured_content` yourself, giving you full control over the response. -1. **Content only**: Return a list of content blocks (default behavior before spec revision 2025-06-18) -2. **Structured data only**: Return a dictionary that will be serialized to JSON (Introduced in spec revision 2025-06-18) -3. **Both**: Return a tuple of (content, structured_data) preferred option to use for backwards compatibility -4. **Direct CallToolResult**: Return `CallToolResult` directly for full control (including `_meta` field) +##### Returning CallToolResult with `_meta` -When an `outputSchema` is defined, the server automatically validates the structured output against the schema. This ensures type safety and helps catch errors early. - -##### Returning CallToolResult Directly - -For full control over the response including the `_meta` field (for passing data to client applications without exposing it to the model), return `CallToolResult` directly: +For passing data to client applications without exposing it to the model, use the `_meta` field on `CallToolResult`: ```python @@ -1937,44 +1924,49 @@ uv run examples/snippets/servers/lowlevel/direct_call_tool_result.py """ import asyncio -from typing import Any import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions - -server = Server("example-server") +from mcp.server import Server, ServerRequestContext -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="advanced_tool", - description="Tool with full control including _meta field", - input_schema={ - "type": "object", - "properties": {"message": {"type": "string"}}, - "required": ["message"], - }, - ) - ] + return types.ListToolsResult( + tools=[ + types.Tool( + name="advanced_tool", + description="Tool with full control including _meta field", + input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + ) + ] + ) -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult: +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool calls by returning CallToolResult directly.""" - if name == "advanced_tool": - message = str(arguments.get("message", "")) + if params.name == "advanced_tool": + message = (params.arguments or {}).get("message", "") return types.CallToolResult( content=[types.TextContent(type="text", text=f"Processed: {message}")], structured_content={"result": "success", "message": message}, _meta={"hidden": "data for client applications only"}, ) - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -1983,14 +1975,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) @@ -2001,8 +1986,6 @@ if __name__ == "__main__": _Full example: [examples/snippets/servers/lowlevel/direct_call_tool_result.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/lowlevel/direct_call_tool_result.py)_ -**Note:** When returning `CallToolResult`, you bypass the automatic content/structured conversion. You must construct the complete response yourself. - ### Pagination (Advanced) For servers that need to handle large datasets, the low-level server provides paginated versions of list operations. This is an optional optimization - most servers won't need pagination unless they're dealing with hundreds or thousands of items. @@ -2011,25 +1994,23 @@ For servers that need to handle large datasets, the low-level server provides pa ```python -"""Example of implementing pagination with MCP server decorators.""" +"""Example of implementing pagination with the low-level MCP server.""" from mcp import types -from mcp.server.lowlevel import Server - -# Initialize the server -server = Server("paginated-server") +from mcp.server import Server, ServerRequestContext # Sample data to paginate ITEMS = [f"Item {i}" for i in range(1, 101)] # 100 items -@server.list_resources() -async def list_resources_paginated(request: types.ListResourcesRequest) -> types.ListResourcesResult: +async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: """List resources with pagination support.""" page_size = 10 # Extract cursor from request params - cursor = request.params.cursor if request.params is not None else None + cursor = params.cursor if params is not None else None # Parse cursor to get offset start = 0 if cursor is None else int(cursor) @@ -2045,6 +2026,9 @@ async def list_resources_paginated(request: types.ListResourcesRequest) -> types next_cursor = str(end) if end < len(ITEMS) else None return types.ListResourcesResult(resources=page_items, next_cursor=next_cursor) + + +server = Server("paginated-server", on_list_resources=handle_list_resources) ``` _Full example: [examples/snippets/servers/pagination_example.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/pagination_example.py)_ diff --git a/SECURITY.md b/SECURITY.md index 654515610..502924200 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,15 +1,21 @@ # Security Policy -Thank you for helping us keep the SDKs and systems they interact with secure. +Thank you for helping keep the Model Context Protocol and its ecosystem secure. ## Reporting Security Issues -This SDK is maintained by [Anthropic](https://www.anthropic.com/) as part of the Model Context Protocol project. +If you discover a security vulnerability in this repository, please report it through +the [GitHub Security Advisory process](https://docs.github.com/en/code-security/security-advisories/guidance-on-reporting-and-writing-information-about-vulnerabilities/privately-reporting-a-security-vulnerability) +for this repository. -The security of our systems and user data is Anthropic’s top priority. We appreciate the work of security researchers acting in good faith in identifying and reporting potential vulnerabilities. +Please **do not** report security vulnerabilities through public GitHub issues, discussions, +or pull requests. -Our security program is managed on HackerOne and we ask that any validated vulnerability in this functionality be reported through their [submission form](https://hackerone.com/anthropic-vdp/reports/new?type=team&report_type=vulnerability). +## What to Include -## Vulnerability Disclosure Program +To help us triage and respond quickly, please include: -Our Vulnerability Program Guidelines are defined on our [HackerOne program page](https://hackerone.com/anthropic-vdp). +- A description of the vulnerability +- Steps to reproduce the issue +- The potential impact +- Any suggested fixes (optional) diff --git a/docs/experimental/index.md b/docs/experimental/index.md index 1d496b3f1..c97fe2a3d 100644 --- a/docs/experimental/index.md +++ b/docs/experimental/index.md @@ -27,10 +27,9 @@ Tasks are useful for: Experimental features are accessed via the `.experimental` property: ```python -# Server-side -@server.experimental.get_task() -async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - ... +# Server-side: enable task support (auto-registers default handlers) +server = Server(name="my-server") +server.experimental.enable_tasks() # Client-side result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) diff --git a/docs/migration.md b/docs/migration.md index 7d30f0ac9..631683693 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -351,7 +351,6 @@ The nested `RequestParams.Meta` Pydantic model class has been replaced with a to - `RequestParams.Meta` (Pydantic model) → `RequestParamsMeta` (TypedDict) - Attribute access (`meta.progress_token`) → Dictionary access (`meta.get("progress_token")`) - `progress_token` field changed from `ProgressToken | None = None` to `NotRequired[ProgressToken]` -` **In request context handlers:** @@ -364,14 +363,15 @@ async def handle_tool(name: str, arguments: dict) -> list[TextContent]: await ctx.session.send_progress_notification(ctx.meta.progress_token, 0.5, 100) # After (v2) -@server.call_tool() -async def handle_tool(name: str, arguments: dict) -> list[TextContent]: - ctx = server.request_context +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: if ctx.meta and "progress_token" in ctx.meta: await ctx.session.send_progress_notification(ctx.meta["progress_token"], 0.5, 100) + ... + +server = Server("my-server", on_call_tool=handle_call_tool) ``` -### `RequestContext` and `ProgressContext` type parameters simplified +### `RequestContext` type parameters simplified The `RequestContext` class has been split to separate shared fields from server-specific fields. The shared `RequestContext` now only takes 1 type parameter (the session type) instead of 3. @@ -380,40 +380,59 @@ The `RequestContext` class has been split to separate shared fields from server- - Type parameters reduced from `RequestContext[SessionT, LifespanContextT, RequestT]` to `RequestContext[SessionT]` - Server-specific fields (`lifespan_context`, `experimental`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in `mcp.server.context` -**`ProgressContext` changes:** - -- Type parameters reduced from `ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]` to `ProgressContext[SessionT]` - **Before (v1):** ```python from mcp.client.session import ClientSession from mcp.shared.context import RequestContext, LifespanContextT, RequestT -from mcp.shared.progress import ProgressContext # RequestContext with 3 type parameters ctx: RequestContext[ClientSession, LifespanContextT, RequestT] - -# ProgressContext with 5 type parameters -progress_ctx: ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT] ``` **After (v2):** ```python from mcp.client.context import ClientRequestContext -from mcp.client.session import ClientSession from mcp.server.context import ServerRequestContext, LifespanContextT, RequestT -from mcp.shared.progress import ProgressContext # For client-side context (sampling, elicitation, list_roots callbacks) ctx: ClientRequestContext # For server-specific context with lifespan and request types server_ctx: ServerRequestContext[LifespanContextT, RequestT] +``` + +### `ProgressContext` and `progress()` context manager removed -# ProgressContext with 1 type parameter -progress_ctx: ProgressContext[ClientSession] +The `mcp.shared.progress` module (`ProgressContext`, `Progress`, and the `progress()` context manager) has been removed. This module had no real-world adoption — all users send progress notifications via `Context.report_progress()` or `session.send_progress_notification()` directly. + +**Before:** + +```python +from mcp.shared.progress import progress + +with progress(ctx, total=100) as p: + await p.progress(25) +``` + +**After — use `Context.report_progress()` (recommended):** + +```python +@server.tool() +async def my_tool(x: int, ctx: Context) -> str: + await ctx.report_progress(25, 100) + return "done" +``` + +**After — use `session.send_progress_notification()` (low-level):** + +```python +await session.send_progress_notification( + progress_token=progress_token, + progress=25, + total=100, +) ``` ### Resource URI type changed from `AnyUrl` to `str` @@ -471,12 +490,292 @@ await client.read_resource("test://resource") await client.read_resource(str(my_any_url)) ``` +### Lowlevel `Server`: constructor parameters are now keyword-only + +All parameters after `name` are now keyword-only. If you were passing `version` or other parameters positionally, use keyword arguments instead: + +```python +# Before (v1) +server = Server("my-server", "1.0") + +# After (v2) +server = Server("my-server", version="1.0") +``` + +### Lowlevel `Server`: type parameter reduced from 2 to 1 + +The `Server` class previously had two type parameters: `Server[LifespanResultT, RequestT]`. The `RequestT` parameter has been removed — handlers now receive typed params directly rather than a generic request type. + +```python +# Before (v1) +from typing import Any + +from mcp.server.lowlevel.server import Server + +server: Server[dict[str, Any], Any] = Server(...) + +# After (v2) +from typing import Any + +from mcp.server import Server + +server: Server[dict[str, Any]] = Server(...) +``` + +### Lowlevel `Server`: `request_handlers` and `notification_handlers` attributes removed + +The public `server.request_handlers` and `server.notification_handlers` dictionaries have been removed. Handler registration is now done exclusively through constructor `on_*` keyword arguments. There is no public API to register handlers after construction. + +```python +# Before (v1) — direct dict access +from mcp.types import ListToolsRequest + +if ListToolsRequest in server.request_handlers: + ... + +# After (v2) — no public access to handler dicts +# Use the on_* constructor params to register handlers +server = Server("my-server", on_list_tools=handle_list_tools) +``` + +### Lowlevel `Server`: decorator-based handlers replaced with constructor `on_*` params + +The lowlevel `Server` class no longer uses decorator methods for handler registration. Instead, handlers are passed as `on_*` keyword arguments to the constructor. + +**Before (v1):** + +```python +from mcp.server.lowlevel.server import Server + +server = Server("my-server") + +@server.list_tools() +async def handle_list_tools(): + return [types.Tool(name="my_tool", description="A tool", inputSchema={})] + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict): + return [types.TextContent(type="text", text=f"Called {name}")] +``` + +**After (v2):** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) + +async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="my_tool", description="A tool", input_schema={})]) + + +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + return CallToolResult( + content=[TextContent(type="text", text=f"Called {params.name}")], + is_error=False, + ) + +server = Server("my-server", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) +``` + +**Key differences:** + +- Handlers receive `(ctx, params)` instead of the full request object or unpacked arguments. `ctx` is a `ServerRequestContext` with `session`, `lifespan_context`, and `experimental` fields (plus `request_id`, `meta`, etc. for request handlers). `params` is the typed request params object. +- Handlers return the full result type (e.g. `ListToolsResult`) rather than unwrapped values (e.g. `list[Tool]`). +- The automatic `jsonschema` input/output validation that the old `call_tool()` decorator performed has been removed. There is no built-in replacement — if you relied on schema validation in the lowlevel server, you will need to validate inputs yourself in your handler. + +**Notification handlers:** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import ProgressNotificationParams + + +async def handle_progress(ctx: ServerRequestContext, params: ProgressNotificationParams) -> None: + print(f"Progress: {params.progress}/{params.total}") + +server = Server("my-server", on_progress=handle_progress) +``` + +### Lowlevel `Server`: automatic return value wrapping removed + +The old decorator-based handlers performed significant automatic wrapping of return values. This magic has been removed — handlers now return fully constructed result types. If you want these conveniences, use `MCPServer` (previously `FastMCP`) instead of the lowlevel `Server`. + +**`call_tool()` — structured output wrapping removed:** + +The old decorator accepted several return types and auto-wrapped them into `CallToolResult`: + +```python +# Before (v1) — returning a dict auto-wrapped into structured_content + JSON TextContent +@server.call_tool() +async def handle(name: str, arguments: dict) -> dict: + return {"temperature": 22.5, "city": "London"} + +# Before (v1) — returning a list auto-wrapped into CallToolResult.content +@server.call_tool() +async def handle(name: str, arguments: dict) -> list[TextContent]: + return [TextContent(type="text", text="Done")] +``` + +```python +# After (v2) — construct the full result yourself +import json + +async def handle(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + data = {"temperature": 22.5, "city": "London"} + return CallToolResult( + content=[TextContent(type="text", text=json.dumps(data, indent=2))], + structured_content=data, + ) +``` + +Note: `params.arguments` can be `None` (the old decorator defaulted it to `{}`). Use `params.arguments or {}` to preserve the old behavior. + +**`read_resource()` — content type wrapping removed:** + +The old decorator auto-wrapped `str` into `TextResourceContents` and `bytes` into `BlobResourceContents` (with base64 encoding), and applied a default mime type of `text/plain`: + +```python +# Before (v1) — str/bytes auto-wrapped with mime type defaulting +@server.read_resource() +async def handle(uri: str) -> str: + return "file contents" + +@server.read_resource() +async def handle(uri: str) -> bytes: + return b"\x89PNG..." +``` + +```python +# After (v2) — construct TextResourceContents or BlobResourceContents yourself +import base64 + +async def handle_read(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + # Text content + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text="file contents", mime_type="text/plain")] + ) + +async def handle_read(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + # Binary content — you must base64-encode it yourself + return ReadResourceResult( + contents=[BlobResourceContents( + uri=str(params.uri), + blob=base64.b64encode(b"\x89PNG...").decode("utf-8"), + mime_type="image/png", + )] + ) +``` + +**`list_tools()`, `list_resources()`, `list_prompts()` — list wrapping removed:** + +The old decorators accepted bare lists and wrapped them into the result type: + +```python +# Before (v1) +@server.list_tools() +async def handle() -> list[Tool]: + return [Tool(name="my_tool", ...)] + +# After (v2) +async def handle(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="my_tool", ...)]) +``` + +**Using `MCPServer` instead:** + +If you prefer the convenience of automatic wrapping, use `MCPServer` which still provides these features through its `@mcp.tool()`, `@mcp.resource()`, and `@mcp.prompt()` decorators. The lowlevel `Server` is intentionally minimal — it provides no magic and gives you full control over the MCP protocol types. + +### Lowlevel `Server`: `request_context` property removed + +The `server.request_context` property has been removed. Request context is now passed directly to handlers as the first argument (`ctx`). The `request_ctx` module-level contextvar is now an internal implementation detail and should not be relied upon. + +**Before (v1):** + +```python +from mcp.server.lowlevel.server import request_ctx + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict): + ctx = server.request_context # or request_ctx.get() + await ctx.session.send_log_message(level="info", data="Processing...") + return [types.TextContent(type="text", text="Done")] +``` + +**After (v2):** + +```python +from mcp.server import ServerRequestContext +from mcp.types import CallToolRequestParams, CallToolResult, TextContent + + +async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + await ctx.session.send_log_message(level="info", data="Processing...") + return CallToolResult( + content=[TextContent(type="text", text="Done")], + is_error=False, + ) +``` + +### `RequestContext`: request-specific fields are now optional + +The `RequestContext` class now uses optional fields for request-specific data (`request_id`, `meta`, etc.) so it can be used for both request and notification handlers. In notification handlers, these fields are `None`. + +```python +from mcp.server import ServerRequestContext + +# request_id, meta, etc. are available in request handlers +# but None in notification handlers +``` + +### Experimental: task handler decorators removed + +The experimental decorator methods on `ExperimentalHandlers` (`@server.experimental.list_tasks()`, `@server.experimental.get_task()`, etc.) have been removed. + +Default task handlers are still registered automatically via `server.experimental.enable_tasks()`. Custom handlers can be passed as `on_*` kwargs to override specific defaults. + +**Before (v1):** + +```python +server = Server("my-server") +server.experimental.enable_tasks() + +@server.experimental.get_task() +async def custom_get_task(request: GetTaskRequest) -> GetTaskResult: + ... +``` + +**After (v2):** + +```python +from mcp.server import Server, ServerRequestContext +from mcp.types import GetTaskRequestParams, GetTaskResult + + +async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: + ... + + +server = Server("my-server") +server.experimental.enable_tasks(on_get_task=custom_get_task) +``` + ## Deprecations ## Bug Fixes +### Lowlevel `Server`: `subscribe` capability now correctly reported + +Previously, the lowlevel `Server` hardcoded `subscribe=False` in resource capabilities even when a `subscribe_resource()` handler was registered. The `subscribe` capability is now dynamically set to `True` when an `on_subscribe_resource` handler is provided. Clients that previously didn't see `subscribe: true` in capabilities will now see it when a handler is registered, which may change client behavior. + ### Extra fields no longer allowed on top-level MCP types MCP protocol types no longer accept arbitrary extra fields at the top level. This matches the MCP specification which only allows extra fields within `_meta` objects, not on the types themselves. @@ -506,16 +805,16 @@ params = CallToolRequestParams( The `streamable_http_app()` method is now available directly on the lowlevel `Server` class, not just `MCPServer`. This allows using the streamable HTTP transport without the MCPServer wrapper. ```python -from mcp.server.lowlevel.server import Server +from mcp.server import Server, ServerRequestContext +from mcp.types import ListToolsResult, PaginatedRequestParams -server = Server("my-server") -# Register handlers... -@server.list_tools() -async def list_tools(): - return [...] +async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[...]) + + +server = Server("my-server", on_list_tools=handle_list_tools) -# Create a Starlette app for streamable HTTP app = server.streamable_http_app( streamable_http_path="/mcp", json_response=False, diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py index 4fb7d9a1d..2101cff28 100644 --- a/examples/servers/everything-server/mcp_everything_server/server.py +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -10,6 +10,7 @@ import logging import click +from mcp.server import ServerRequestContext from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.prompts.base import UserMessage from mcp.server.session import ServerSession @@ -20,13 +21,17 @@ CompletionArgument, CompletionContext, EmbeddedResource, + EmptyResult, ImageContent, JSONRPCMessage, PromptReference, ResourceTemplateReference, SamplingMessage, + SetLevelRequestParams, + SubscribeRequestParams, TextContent, TextResourceContents, + UnsubscribeRequestParams, ) from pydantic import BaseModel, Field @@ -393,28 +398,29 @@ def test_prompt_with_image() -> list[UserMessage]: # Custom request handlers # TODO(felix): Add public APIs to MCPServer for subscribe_resource, unsubscribe_resource, # and set_logging_level to avoid accessing protected _lowlevel_server attribute. -@mcp._lowlevel_server.set_logging_level() # pyright: ignore[reportPrivateUsage] -async def handle_set_logging_level(level: str) -> None: +async def handle_set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: """Handle logging level changes""" - logger.info(f"Log level set to: {level}") - # In a real implementation, you would adjust the logging level here - # For conformance testing, we just acknowledge the request + logger.info(f"Log level set to: {params.level}") + return EmptyResult() -async def handle_subscribe(uri: str) -> None: +async def handle_subscribe(ctx: ServerRequestContext, params: SubscribeRequestParams) -> EmptyResult: """Handle resource subscription""" - resource_subscriptions.add(str(uri)) - logger.info(f"Subscribed to resource: {uri}") + resource_subscriptions.add(str(params.uri)) + logger.info(f"Subscribed to resource: {params.uri}") + return EmptyResult() -async def handle_unsubscribe(uri: str) -> None: +async def handle_unsubscribe(ctx: ServerRequestContext, params: UnsubscribeRequestParams) -> EmptyResult: """Handle resource unsubscription""" - resource_subscriptions.discard(str(uri)) - logger.info(f"Unsubscribed from resource: {uri}") + resource_subscriptions.discard(str(params.uri)) + logger.info(f"Unsubscribed from resource: {params.uri}") + return EmptyResult() -mcp._lowlevel_server.subscribe_resource()(handle_subscribe) # pyright: ignore[reportPrivateUsage] -mcp._lowlevel_server.unsubscribe_resource()(handle_unsubscribe) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server._add_request_handler("logging/setLevel", handle_set_logging_level) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server._add_request_handler("resources/subscribe", handle_subscribe) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server._add_request_handler("resources/unsubscribe", handle_unsubscribe) # pyright: ignore[reportPrivateUsage] @mcp.completion() diff --git a/examples/servers/simple-pagination/mcp_simple_pagination/server.py b/examples/servers/simple-pagination/mcp_simple_pagination/server.py index ff45ae224..bac27a0f1 100644 --- a/examples/servers/simple-pagination/mcp_simple_pagination/server.py +++ b/examples/servers/simple-pagination/mcp_simple_pagination/server.py @@ -1,17 +1,19 @@ """Simple MCP server demonstrating pagination for tools, resources, and prompts. -This example shows how to use the paginated decorators to handle large lists -of items that need to be split across multiple pages. +This example shows how to implement pagination with the low-level server API +to handle large lists of items that need to be split across multiple pages. """ -from typing import Any +from typing import TypeVar import anyio import click from mcp import types -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from starlette.requests import Request +T = TypeVar("T") + # Sample data - in real scenarios, this might come from a database SAMPLE_TOOLS = [ types.Tool( @@ -44,6 +46,102 @@ ] +def _paginate(cursor: str | None, items: list[T], page_size: int) -> tuple[list[T], str | None]: + """Helper to paginate a list of items given a cursor.""" + if cursor is not None: + try: + start_idx = int(cursor) + except (ValueError, TypeError): + return [], None + else: + start_idx = 0 + + page = items[start_idx : start_idx + page_size] + next_cursor = str(start_idx + page_size) if start_idx + page_size < len(items) else None + return page, next_cursor + + +# Paginated list_tools - returns 5 tools per page +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + cursor = params.cursor if params is not None else None + page, next_cursor = _paginate(cursor, SAMPLE_TOOLS, page_size=5) + return types.ListToolsResult(tools=page, next_cursor=next_cursor) + + +# Paginated list_resources - returns 10 resources per page +async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: + cursor = params.cursor if params is not None else None + page, next_cursor = _paginate(cursor, SAMPLE_RESOURCES, page_size=10) + return types.ListResourcesResult(resources=page, next_cursor=next_cursor) + + +# Paginated list_prompts - returns 7 prompts per page +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + cursor = params.cursor if params is not None else None + page, next_cursor = _paginate(cursor, SAMPLE_PROMPTS, page_size=7) + return types.ListPromptsResult(prompts=page, next_cursor=next_cursor) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + # Find the tool in our sample data + tool = next((t for t in SAMPLE_TOOLS if t.name == params.name), None) + if not tool: + raise ValueError(f"Unknown tool: {params.name}") + + return types.CallToolResult( + content=[ + types.TextContent( + type="text", + text=f"Called tool '{params.name}' with arguments: {params.arguments}", + ) + ] + ) + + +async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams +) -> types.ReadResourceResult: + resource = next((r for r in SAMPLE_RESOURCES if r.uri == str(params.uri)), None) + if not resource: + raise ValueError(f"Unknown resource: {params.uri}") + + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=str(params.uri), + text=f"Content of {resource.name}: This is sample content for the resource.", + mime_type="text/plain", + ) + ] + ) + + +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: + prompt = next((p for p in SAMPLE_PROMPTS if p.name == params.name), None) + if not prompt: + raise ValueError(f"Unknown prompt: {params.name}") + + message_text = f"This is the prompt '{params.name}'" + if params.arguments: + message_text += f" with arguments: {params.arguments}" + + return types.GetPromptResult( + description=prompt.description, + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text=message_text), + ) + ], + ) + + @click.command() @click.option("--port", default=8000, help="Port to listen on for SSE") @click.option( @@ -53,142 +151,15 @@ help="Transport type", ) def main(port: int, transport: str) -> int: - app = Server("mcp-simple-pagination") - - # Paginated list_tools - returns 5 tools per page - @app.list_tools() - async def list_tools_paginated(request: types.ListToolsRequest) -> types.ListToolsResult: - page_size = 5 - - cursor = request.params.cursor if request.params is not None else None - if cursor is None: - # First page - start_idx = 0 - else: - # Parse cursor to get the start index - try: - start_idx = int(cursor) - except (ValueError, TypeError): - # Invalid cursor, return empty - return types.ListToolsResult(tools=[], next_cursor=None) - - # Get the page of tools - page_tools = SAMPLE_TOOLS[start_idx : start_idx + page_size] - - # Determine if there are more pages - next_cursor = None - if start_idx + page_size < len(SAMPLE_TOOLS): - next_cursor = str(start_idx + page_size) - - return types.ListToolsResult(tools=page_tools, next_cursor=next_cursor) - - # Paginated list_resources - returns 10 resources per page - @app.list_resources() - async def list_resources_paginated( - request: types.ListResourcesRequest, - ) -> types.ListResourcesResult: - page_size = 10 - - cursor = request.params.cursor if request.params is not None else None - if cursor is None: - # First page - start_idx = 0 - else: - # Parse cursor to get the start index - try: - start_idx = int(cursor) - except (ValueError, TypeError): - # Invalid cursor, return empty - return types.ListResourcesResult(resources=[], next_cursor=None) - - # Get the page of resources - page_resources = SAMPLE_RESOURCES[start_idx : start_idx + page_size] - - # Determine if there are more pages - next_cursor = None - if start_idx + page_size < len(SAMPLE_RESOURCES): - next_cursor = str(start_idx + page_size) - - return types.ListResourcesResult(resources=page_resources, next_cursor=next_cursor) - - # Paginated list_prompts - returns 7 prompts per page - @app.list_prompts() - async def list_prompts_paginated( - request: types.ListPromptsRequest, - ) -> types.ListPromptsResult: - page_size = 7 - - cursor = request.params.cursor if request.params is not None else None - if cursor is None: - # First page - start_idx = 0 - else: - # Parse cursor to get the start index - try: - start_idx = int(cursor) - except (ValueError, TypeError): - # Invalid cursor, return empty - return types.ListPromptsResult(prompts=[], next_cursor=None) - - # Get the page of prompts - page_prompts = SAMPLE_PROMPTS[start_idx : start_idx + page_size] - - # Determine if there are more pages - next_cursor = None - if start_idx + page_size < len(SAMPLE_PROMPTS): - next_cursor = str(start_idx + page_size) - - return types.ListPromptsResult(prompts=page_prompts, next_cursor=next_cursor) - - # Implement call_tool handler - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - # Find the tool in our sample data - tool = next((t for t in SAMPLE_TOOLS if t.name == name), None) - if not tool: - raise ValueError(f"Unknown tool: {name}") - - # Simple mock response - return [ - types.TextContent( - type="text", - text=f"Called tool '{name}' with arguments: {arguments}", - ) - ] - - # Implement read_resource handler - @app.read_resource() - async def read_resource(uri: str) -> str: - # Find the resource in our sample data - resource = next((r for r in SAMPLE_RESOURCES if r.uri == uri), None) - if not resource: - raise ValueError(f"Unknown resource: {uri}") - - # Return a simple string - the decorator will convert it to TextResourceContents - return f"Content of {resource.name}: This is sample content for the resource." - - # Implement get_prompt handler - @app.get_prompt() - async def get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: - # Find the prompt in our sample data - prompt = next((p for p in SAMPLE_PROMPTS if p.name == name), None) - if not prompt: - raise ValueError(f"Unknown prompt: {name}") - - # Simple mock response - message_text = f"This is the prompt '{name}'" - if arguments: - message_text += f" with arguments: {arguments}" - - return types.GetPromptResult( - description=prompt.description, - messages=[ - types.PromptMessage( - role="user", - content=types.TextContent(type="text", text=message_text), - ) - ], - ) + app = Server( + "mcp-simple-pagination", + on_list_tools=handle_list_tools, + on_list_resources=handle_list_resources, + on_list_prompts=handle_list_prompts, + on_call_tool=handle_call_tool, + on_read_resource=handle_read_resource, + on_get_prompt=handle_get_prompt, + ) if transport == "sse": from mcp.server.sse import SseServerTransport diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index cbc5a9d68..6cf99d4b6 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -1,7 +1,7 @@ import anyio import click from mcp import types -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from starlette.requests import Request @@ -30,20 +30,11 @@ def create_messages(context: str | None = None, topic: str | None = None) -> lis return messages -@click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") -@click.option( - "--transport", - type=click.Choice(["stdio", "sse"]), - default="stdio", - help="Transport type", -) -def main(port: int, transport: str) -> int: - app = Server("mcp-simple-prompt") - - @app.list_prompts() - async def list_prompts() -> list[types.Prompt]: - return [ +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + return types.ListPromptsResult( + prompts=[ types.Prompt( name="simple", title="Simple Assistant Prompt", @@ -62,19 +53,35 @@ async def list_prompts() -> list[types.Prompt]: ], ) ] + ) - @app.get_prompt() - async def get_prompt(name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: - if name != "simple": - raise ValueError(f"Unknown prompt: {name}") - if arguments is None: - arguments = {} +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: + if params.name != "simple": + raise ValueError(f"Unknown prompt: {params.name}") - return types.GetPromptResult( - messages=create_messages(context=arguments.get("context"), topic=arguments.get("topic")), - description="A simple prompt with optional context and topic arguments", - ) + arguments = params.arguments or {} + + return types.GetPromptResult( + messages=create_messages(context=arguments.get("context"), topic=arguments.get("topic")), + description="A simple prompt with optional context and topic arguments", + ) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option( + "--transport", + type=click.Choice(["stdio", "sse"]), + default="stdio", + help="Transport type", +) +def main(port: int, transport: str) -> int: + app = Server( + "mcp-simple-prompt", + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, + ) if transport == "sse": from mcp.server.sse import SseServerTransport diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index 588d1044a..b9b6a1d96 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -1,8 +1,9 @@ +from urllib.parse import urlparse + import anyio import click from mcp import types -from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server import Server, ServerRequestContext from starlette.requests import Request SAMPLE_RESOURCES = { @@ -21,20 +22,11 @@ } -@click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") -@click.option( - "--transport", - type=click.Choice(["stdio", "sse"]), - default="stdio", - help="Transport type", -) -def main(port: int, transport: str) -> int: - app = Server("mcp-simple-resource") - - @app.list_resources() - async def list_resources() -> list[types.Resource]: - return [ +async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: + return types.ListResourcesResult( + resources=[ types.Resource( uri=f"file:///{name}.txt", name=name, @@ -44,20 +36,45 @@ async def list_resources() -> list[types.Resource]: ) for name in SAMPLE_RESOURCES.keys() ] + ) - @app.read_resource() - async def read_resource(uri: str): - from urllib.parse import urlparse - parsed = urlparse(uri) - if not parsed.path: - raise ValueError(f"Invalid resource path: {uri}") - name = parsed.path.replace(".txt", "").lstrip("/") +async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams +) -> types.ReadResourceResult: + parsed = urlparse(str(params.uri)) + if not parsed.path: + raise ValueError(f"Invalid resource path: {params.uri}") + name = parsed.path.replace(".txt", "").lstrip("/") - if name not in SAMPLE_RESOURCES: - raise ValueError(f"Unknown resource: {uri}") + if name not in SAMPLE_RESOURCES: + raise ValueError(f"Unknown resource: {params.uri}") - return [ReadResourceContents(content=SAMPLE_RESOURCES[name]["content"], mime_type="text/plain")] + return types.ReadResourceResult( + contents=[ + types.TextResourceContents( + uri=str(params.uri), + text=SAMPLE_RESOURCES[name]["content"], + mime_type="text/plain", + ) + ] + ) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option( + "--transport", + type=click.Choice(["stdio", "sse"]), + default="stdio", + help="Transport type", +) +def main(port: int, transport: str) -> int: + app = Server( + "mcp-simple-resource", + on_list_resources=handle_list_resources, + on_read_resource=handle_read_resource, + ) if transport == "sse": from mcp.server.sse import SseServerTransport diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index 9fed2f0aa..cb4a6503c 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -1,13 +1,12 @@ import contextlib import logging from collections.abc import AsyncIterator -from typing import Any import anyio import click import uvicorn from mcp import types -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware @@ -17,6 +16,64 @@ logger = logging.getLogger(__name__) +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="start-notification-stream", + description=("Sends a stream of notifications with configurable count and interval"), + input_schema={ + "type": "object", + "required": ["interval", "count", "caller"], + "properties": { + "interval": { + "type": "number", + "description": "Interval between notifications in seconds", + }, + "count": { + "type": "number", + "description": "Number of notifications to send", + }, + "caller": { + "type": "string", + "description": ("Identifier of the caller to include in notifications"), + }, + }, + }, + ) + ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + arguments = params.arguments or {} + interval = arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + await ctx.session.send_log_message( + level="info", + data=f"Notification {i + 1}/{count} from caller: {caller}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + return types.CallToolResult( + content=[ + types.TextContent( + type="text", + text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), + ) + ] + ) + + @click.command() @click.option("--port", default=3000, help="Port to listen on for HTTP") @click.option( @@ -41,59 +98,11 @@ def main( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - app = Server("mcp-streamable-http-stateless-demo") - - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - ctx = app.request_context - interval = arguments.get("interval", 1.0) - count = arguments.get("count", 5) - caller = arguments.get("caller", "unknown") - - # Send the specified number of notifications with the given interval - for i in range(count): - await ctx.session.send_log_message( - level="info", - data=f"Notification {i + 1}/{count} from caller: {caller}", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - if i < count - 1: # Don't wait after the last notification - await anyio.sleep(interval) - - return [ - types.TextContent( - type="text", - text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), - ) - ] - - @app.list_tools() - async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="start-notification-stream", - description=("Sends a stream of notifications with configurable count and interval"), - input_schema={ - "type": "object", - "required": ["interval", "count", "caller"], - "properties": { - "interval": { - "type": "number", - "description": "Interval between notifications in seconds", - }, - "count": { - "type": "number", - "description": "Number of notifications to send", - }, - "caller": { - "type": "string", - "description": ("Identifier of the caller to include in notifications"), - }, - }, - }, - ) - ] + app = Server( + "mcp-streamable-http-stateless-demo", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Create the session manager with true stateless mode session_manager = StreamableHTTPSessionManager( diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index ef03d9b08..2f2a53b1b 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -1,12 +1,11 @@ import contextlib import logging from collections.abc import AsyncIterator -from typing import Any import anyio import click from mcp import types -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware @@ -19,6 +18,75 @@ logger = logging.getLogger(__name__) +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="start-notification-stream", + description="Sends a stream of notifications with configurable count and interval", + input_schema={ + "type": "object", + "required": ["interval", "count", "caller"], + "properties": { + "interval": { + "type": "number", + "description": "Interval between notifications in seconds", + }, + "count": { + "type": "number", + "description": "Number of notifications to send", + }, + "caller": { + "type": "string", + "description": "Identifier of the caller to include in notifications", + }, + }, + }, + ) + ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + arguments = params.arguments or {} + interval = arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + # Include more detailed message for resumability demonstration + notification_msg = f"[{i + 1}/{count}] Event from '{caller}' - Use Last-Event-ID to resume if disconnected" + await ctx.session.send_log_message( + level="info", + data=notification_msg, + logger="notification_stream", + # Associates this notification with the original request + # Ensures notifications are sent to the correct response stream + # Without this, notifications will either go to: + # - a standalone SSE stream (if GET request is supported) + # - nowhere (if GET request isn't supported) + related_request_id=ctx.request_id, + ) + logger.debug(f"Sent notification {i + 1}/{count} for caller: {caller}") + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + # This will send a resource notification through standalone SSE + # established by GET request + await ctx.session.send_resource_updated(uri="http:///test_resource") + return types.CallToolResult( + content=[ + types.TextContent( + type="text", + text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), + ) + ] + ) + + @click.command() @click.option("--port", default=3000, help="Port to listen on for HTTP") @click.option( @@ -43,70 +111,11 @@ def main( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - app = Server("mcp-streamable-http-demo") - - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - ctx = app.request_context - interval = arguments.get("interval", 1.0) - count = arguments.get("count", 5) - caller = arguments.get("caller", "unknown") - - # Send the specified number of notifications with the given interval - for i in range(count): - # Include more detailed message for resumability demonstration - notification_msg = f"[{i + 1}/{count}] Event from '{caller}' - Use Last-Event-ID to resume if disconnected" - await ctx.session.send_log_message( - level="info", - data=notification_msg, - logger="notification_stream", - # Associates this notification with the original request - # Ensures notifications are sent to the correct response stream - # Without this, notifications will either go to: - # - a standalone SSE stream (if GET request is supported) - # - nowhere (if GET request isn't supported) - related_request_id=ctx.request_id, - ) - logger.debug(f"Sent notification {i + 1}/{count} for caller: {caller}") - if i < count - 1: # Don't wait after the last notification - await anyio.sleep(interval) - - # This will send a resource notificaiton though standalone SSE - # established by GET request - await ctx.session.send_resource_updated(uri="http:///test_resource") - return [ - types.TextContent( - type="text", - text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), - ) - ] - - @app.list_tools() - async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="start-notification-stream", - description=("Sends a stream of notifications with configurable count and interval"), - input_schema={ - "type": "object", - "required": ["interval", "count", "caller"], - "properties": { - "interval": { - "type": "number", - "description": "Interval between notifications in seconds", - }, - "count": { - "type": "number", - "description": "Number of notifications to send", - }, - "caller": { - "type": "string", - "description": ("Identifier of the caller to include in notifications"), - }, - }, - }, - ) - ] + app = Server( + "mcp-streamable-http-demo", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Create event store for resumability # The InMemoryEventStore enables resumability support for StreamableHTTP transport. diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py index dc689ed94..6938b6552 100644 --- a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py +++ b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py @@ -13,42 +13,39 @@ import click import uvicorn from mcp import types +from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.lowlevel import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.routing import Mount -server = Server("simple-task-interactive") -# Enable task support - this auto-registers all handlers -server.experimental.enable_tasks() +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="confirm_delete", + description="Asks for confirmation before deleting (demonstrates elicitation)", + input_schema={ + "type": "object", + "properties": {"filename": {"type": "string"}}, + }, + execution=types.ToolExecution(task_support=types.TASK_REQUIRED), + ), + types.Tool( + name="write_haiku", + description="Asks LLM to write a haiku (demonstrates sampling)", + input_schema={"type": "object", "properties": {"topic": {"type": "string"}}}, + execution=types.ToolExecution(task_support=types.TASK_REQUIRED), + ), + ] + ) -@server.list_tools() -async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="confirm_delete", - description="Asks for confirmation before deleting (demonstrates elicitation)", - input_schema={ - "type": "object", - "properties": {"filename": {"type": "string"}}, - }, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - types.Tool( - name="write_haiku", - description="Asks LLM to write a haiku (demonstrates sampling)", - input_schema={"type": "object", "properties": {"topic": {"type": "string"}}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ), - ] - - -async def handle_confirm_delete(arguments: dict[str, Any]) -> types.CreateTaskResult: +async def handle_confirm_delete(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: """Handle the confirm_delete tool - demonstrates elicitation.""" - ctx = server.request_context ctx.experimental.validate_task_mode(types.TASK_REQUIRED) filename = arguments.get("filename", "unknown.txt") @@ -80,9 +77,8 @@ async def work(task: ServerTaskContext) -> types.CallToolResult: return await ctx.experimental.run_task(work) -async def handle_write_haiku(arguments: dict[str, Any]) -> types.CreateTaskResult: +async def handle_write_haiku(ctx: ServerRequestContext, arguments: dict[str, Any]) -> types.CreateTaskResult: """Handle the write_haiku tool - demonstrates sampling.""" - ctx = server.request_context ctx.experimental.validate_task_mode(types.TASK_REQUIRED) topic = arguments.get("topic", "nature") @@ -111,18 +107,31 @@ async def work(task: ServerTaskContext) -> types.CallToolResult: return await ctx.experimental.run_task(work) -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: +async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.CreateTaskResult: """Dispatch tool calls to their handlers.""" - if name == "confirm_delete": - return await handle_confirm_delete(arguments) - elif name == "write_haiku": - return await handle_write_haiku(arguments) - else: - return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], - is_error=True, - ) + arguments = params.arguments or {} + + if params.name == "confirm_delete": + return await handle_confirm_delete(ctx, arguments) + elif params.name == "write_haiku": + return await handle_write_haiku(ctx, arguments) + + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], + is_error=True, + ) + + +server = Server( + "simple-task-interactive", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) + +# Enable task support - this auto-registers all handlers +server.experimental.enable_tasks() def create_app(session_manager: StreamableHTTPSessionManager) -> Starlette: diff --git a/examples/servers/simple-task/mcp_simple_task/server.py b/examples/servers/simple-task/mcp_simple_task/server.py index ec16b15ae..50ae3ca9a 100644 --- a/examples/servers/simple-task/mcp_simple_task/server.py +++ b/examples/servers/simple-task/mcp_simple_task/server.py @@ -2,66 +2,68 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any import anyio import click import uvicorn from mcp import types +from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext -from mcp.server.lowlevel import Server from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.routing import Mount -server = Server("simple-task-server") -# One-line setup: auto-registers get_task, get_task_result, list_tasks, cancel_task -server.experimental.enable_tasks() +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="long_running_task", + description="A task that takes a few seconds to complete with status updates", + input_schema={"type": "object", "properties": {}}, + execution=types.ToolExecution(task_support=types.TASK_REQUIRED), + ) + ] + ) -@server.list_tools() -async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="long_running_task", - description="A task that takes a few seconds to complete with status updates", - input_schema={"type": "object", "properties": {}}, - execution=types.ToolExecution(task_support=types.TASK_REQUIRED), - ) - ] +async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams +) -> types.CallToolResult | types.CreateTaskResult: + """Dispatch tool calls to their handlers.""" + if params.name == "long_running_task": + ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Starting work...") + await anyio.sleep(1) -async def handle_long_running_task(arguments: dict[str, Any]) -> types.CreateTaskResult: - """Handle the long_running_task tool - demonstrates status updates.""" - ctx = server.request_context - ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + await task.update_status("Processing step 1...") + await anyio.sleep(1) - async def work(task: ServerTaskContext) -> types.CallToolResult: - await task.update_status("Starting work...") - await anyio.sleep(1) + await task.update_status("Processing step 2...") + await anyio.sleep(1) - await task.update_status("Processing step 1...") - await anyio.sleep(1) + return types.CallToolResult(content=[types.TextContent(type="text", text="Task completed!")]) - await task.update_status("Processing step 2...") - await anyio.sleep(1) + return await ctx.experimental.run_task(work) - return types.CallToolResult(content=[types.TextContent(type="text", text="Task completed!")]) + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")], + is_error=True, + ) - return await ctx.experimental.run_task(work) +server = Server( + "simple-task-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: - """Dispatch tool calls to their handlers.""" - if name == "long_running_task": - return await handle_long_running_task(arguments) - else: - return types.CallToolResult( - content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], - is_error=True, - ) +# One-line setup: auto-registers get_task, get_task_result, list_tasks, cancel_task +server.experimental.enable_tasks() @click.command() diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 1c253a22e..9fe71e5b7 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -1,9 +1,7 @@ -from typing import Any - import anyio import click from mcp import types -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.shared._httpx_utils import create_mcp_http_client from starlette.requests import Request @@ -18,28 +16,11 @@ async def fetch_website( return [types.TextContent(type="text", text=response.text)] -@click.command() -@click.option("--port", default=8000, help="Port to listen on for SSE") -@click.option( - "--transport", - type=click.Choice(["stdio", "sse"]), - default="stdio", - help="Transport type", -) -def main(port: int, transport: str) -> int: - app = Server("mcp-website-fetcher") - - @app.call_tool() - async def fetch_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - if name != "fetch": - raise ValueError(f"Unknown tool: {name}") - if "url" not in arguments: - raise ValueError("Missing required argument 'url'") - return await fetch_website(arguments["url"]) - - @app.list_tools() - async def list_tools() -> list[types.Tool]: - return [ +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ types.Tool( name="fetch", title="Website Fetcher", @@ -56,6 +37,33 @@ async def list_tools() -> list[types.Tool]: }, ) ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name != "fetch": + raise ValueError(f"Unknown tool: {params.name}") + arguments = params.arguments or {} + if "url" not in arguments: + raise ValueError("Missing required argument 'url'") + content = await fetch_website(arguments["url"]) + return types.CallToolResult(content=content) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on for SSE") +@click.option( + "--transport", + type=click.Choice(["stdio", "sse"]), + default="stdio", + help="Transport type", +) +def main(port: int, transport: str) -> int: + app = Server( + "mcp-website-fetcher", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) if transport == "sse": from mcp.server.sse import SseServerTransport diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py index 9d7071ca7..c8178c35a 100644 --- a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py @@ -15,12 +15,11 @@ import contextlib import logging from collections.abc import AsyncIterator -from typing import Any import anyio import click from mcp import types -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.routing import Mount @@ -31,111 +30,124 @@ logger = logging.getLogger(__name__) -@click.command() -@click.option("--port", default=3000, help="Port to listen on") -@click.option( - "--log-level", - default="INFO", - help="Logging level (DEBUG, INFO, WARNING, ERROR)", -) -@click.option( - "--retry-interval", - default=100, - help="SSE retry interval in milliseconds (sent to client)", -) -def main(port: int, log_level: str, retry_interval: int) -> int: - """Run the SSE Polling Demo server.""" - logging.basicConfig( - level=getattr(logging, log_level.upper()), - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + """List available tools.""" + return types.ListToolsResult( + tools=[ + types.Tool( + name="process_batch", + description=( + "Process a batch of items with periodic checkpoints. " + "Demonstrates SSE polling where server closes stream periodically." + ), + input_schema={ + "type": "object", + "properties": { + "items": { + "type": "integer", + "description": "Number of items to process (1-100)", + "default": 10, + }, + "checkpoint_every": { + "type": "integer", + "description": "Close stream after this many items (1-20)", + "default": 3, + }, + }, + }, + ) + ] ) - # Create the lowlevel server - app = Server("sse-polling-demo") - @app.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: - """Handle tool calls.""" - ctx = app.request_context +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + """Handle tool calls.""" + arguments = params.arguments or {} - if name == "process_batch": - items = arguments.get("items", 10) - checkpoint_every = arguments.get("checkpoint_every", 3) + if params.name == "process_batch": + items = arguments.get("items", 10) + checkpoint_every = arguments.get("checkpoint_every", 3) - if items < 1 or items > 100: - return [types.TextContent(type="text", text="Error: items must be between 1 and 100")] - if checkpoint_every < 1 or checkpoint_every > 20: - return [types.TextContent(type="text", text="Error: checkpoint_every must be between 1 and 20")] + if items < 1 or items > 100: + return types.CallToolResult( + content=[types.TextContent(type="text", text="Error: items must be between 1 and 100")] + ) + if checkpoint_every < 1 or checkpoint_every > 20: + return types.CallToolResult( + content=[types.TextContent(type="text", text="Error: checkpoint_every must be between 1 and 20")] + ) + + await ctx.session.send_log_message( + level="info", + data=f"Starting batch processing of {items} items...", + logger="process_batch", + related_request_id=ctx.request_id, + ) + for i in range(1, items + 1): + # Simulate work + await anyio.sleep(0.5) + + # Report progress await ctx.session.send_log_message( level="info", - data=f"Starting batch processing of {items} items...", + data=f"[{i}/{items}] Processing item {i}", logger="process_batch", related_request_id=ctx.request_id, ) - for i in range(1, items + 1): - # Simulate work - await anyio.sleep(0.5) - - # Report progress + # Checkpoint: close stream to trigger client reconnect + if i % checkpoint_every == 0 and i < items: await ctx.session.send_log_message( level="info", - data=f"[{i}/{items}] Processing item {i}", + data=f"Checkpoint at item {i} - closing SSE stream for polling", logger="process_batch", related_request_id=ctx.request_id, ) - - # Checkpoint: close stream to trigger client reconnect - if i % checkpoint_every == 0 and i < items: - await ctx.session.send_log_message( - level="info", - data=f"Checkpoint at item {i} - closing SSE stream for polling", - logger="process_batch", - related_request_id=ctx.request_id, - ) - if ctx.close_sse_stream: - logger.info(f"Closing SSE stream at checkpoint {i}") - await ctx.close_sse_stream() - # Wait for client to reconnect (must be > retry_interval of 100ms) - await anyio.sleep(0.2) - - return [ + if ctx.close_sse_stream: + logger.info(f"Closing SSE stream at checkpoint {i}") + await ctx.close_sse_stream() + # Wait for client to reconnect (must be > retry_interval of 100ms) + await anyio.sleep(0.2) + + return types.CallToolResult( + content=[ types.TextContent( type="text", text=f"Successfully processed {items} items with checkpoints every {checkpoint_every} items", ) ] + ) - return [types.TextContent(type="text", text=f"Unknown tool: {name}")] + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Unknown tool: {params.name}")]) - @app.list_tools() - async def list_tools() -> list[types.Tool]: - """List available tools.""" - return [ - types.Tool( - name="process_batch", - description=( - "Process a batch of items with periodic checkpoints. " - "Demonstrates SSE polling where server closes stream periodically." - ), - input_schema={ - "type": "object", - "properties": { - "items": { - "type": "integer", - "description": "Number of items to process (1-100)", - "default": 10, - }, - "checkpoint_every": { - "type": "integer", - "description": "Close stream after this many items (1-20)", - "default": 3, - }, - }, - }, - ) - ] + +@click.command() +@click.option("--port", default=3000, help="Port to listen on") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR)", +) +@click.option( + "--retry-interval", + default=100, + help="SSE retry interval in milliseconds (sent to client)", +) +def main(port: int, log_level: str, retry_interval: int) -> int: + """Run the SSE Polling Demo server.""" + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + app = Server( + "sse-polling-demo", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Create event store for resumability event_store = InMemoryEventStore() diff --git a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py index fd73a54cd..95fb90854 100644 --- a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py +++ b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py @@ -2,60 +2,54 @@ """Example low-level MCP server demonstrating structured output support. This example shows how to use the low-level server API to return -structured data from tools, with automatic validation against output -schemas. +structured data from tools. """ import asyncio +import json +import random from datetime import datetime -from typing import Any import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -# Create low-level server instance -server = Server("structured-output-lowlevel-example") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools with their schemas.""" - return [ - types.Tool( - name="get_weather", - description="Get weather information (simulated)", - input_schema={ - "type": "object", - "properties": {"city": {"type": "string", "description": "City name"}}, - "required": ["city"], - }, - output_schema={ - "type": "object", - "properties": { - "temperature": {"type": "number"}, - "conditions": {"type": "string"}, - "humidity": {"type": "integer", "minimum": 0, "maximum": 100}, - "wind_speed": {"type": "number"}, - "timestamp": {"type": "string", "format": "date-time"}, + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Get weather information (simulated)", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string", "description": "City name"}}, + "required": ["city"], + }, + output_schema={ + "type": "object", + "properties": { + "temperature": {"type": "number"}, + "conditions": {"type": "string"}, + "humidity": {"type": "integer", "minimum": 0, "maximum": 100}, + "wind_speed": {"type": "number"}, + "timestamp": {"type": "string", "format": "date-time"}, + }, + "required": ["temperature", "conditions", "humidity", "wind_speed", "timestamp"], }, - "required": ["temperature", "conditions", "humidity", "wind_speed", "timestamp"], - }, - ), - ] + ), + ] + ) -@server.call_tool() -async def call_tool(name: str, arguments: dict[str, Any]) -> Any: +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool call with structured output.""" - if name == "get_weather": - # city = arguments["city"] # Would be used with real weather API - + if params.name == "get_weather": # Simulate weather data (in production, call a real weather API) - import random - weather_conditions = ["sunny", "cloudy", "rainy", "partly cloudy", "foggy"] weather_data = { @@ -66,12 +60,19 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> Any: "timestamp": datetime.now().isoformat(), } - # Return structured data only - # The low-level server will serialize this to JSON content automatically - return weather_data + return types.CallToolResult( + content=[types.TextContent(type="text", text=json.dumps(weather_data, indent=2))], + structured_content=weather_data, + ) + + raise ValueError(f"Unknown tool: {params.name}") - else: - raise ValueError(f"Unknown tool: {name}") + +server = Server( + "structured-output-lowlevel-example", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -80,14 +81,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="structured-output-lowlevel-example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/clients/url_elicitation_client.py b/examples/snippets/clients/url_elicitation_client.py index 9888c588e..2aecbeeee 100644 --- a/examples/snippets/clients/url_elicitation_client.py +++ b/examples/snippets/clients/url_elicitation_client.py @@ -24,8 +24,6 @@ import asyncio import json -import subprocess -import sys import webbrowser from typing import Any from urllib.parse import urlparse @@ -56,15 +54,19 @@ async def handle_elicitation( ) +ALLOWED_SCHEMES = {"http", "https"} + + async def handle_url_elicitation( params: types.ElicitRequestParams, ) -> types.ElicitResult: """Handle URL mode elicitation - show security warning and optionally open browser. This function demonstrates the security-conscious approach to URL elicitation: - 1. Display the full URL and domain for user inspection - 2. Show the server's reason for requesting this interaction - 3. Require explicit user consent before opening any URL + 1. Validate the URL scheme before prompting the user + 2. Display the full URL and domain for user inspection + 3. Show the server's reason for requesting this interaction + 4. Require explicit user consent before opening any URL """ # Extract URL parameters - these are available on URL mode requests url = getattr(params, "url", None) @@ -75,6 +77,12 @@ async def handle_url_elicitation( print("Error: No URL provided in elicitation request") return types.ElicitResult(action="cancel") + # Reject dangerous URL schemes before prompting the user + parsed = urlparse(str(url)) + if parsed.scheme.lower() not in ALLOWED_SCHEMES: + print(f"\nRejecting URL with disallowed scheme '{parsed.scheme}': {url}") + return types.ElicitResult(action="decline") + # Extract domain for security display domain = extract_domain(url) @@ -105,7 +113,11 @@ async def handle_url_elicitation( # Open the browser print(f"\nOpening browser to: {url}") - open_browser(url) + try: + webbrowser.open(url) + except Exception as e: + print(f"Failed to open browser: {e}") + print(f"Please manually open: {url}") print("Waiting for you to complete the interaction in your browser...") print("(The server will continue once you've finished)") @@ -121,20 +133,6 @@ def extract_domain(url: str) -> str: return "unknown" -def open_browser(url: str) -> None: - """Open URL in the default browser.""" - try: - if sys.platform == "darwin": - subprocess.run(["open", url], check=False) - elif sys.platform == "win32": - subprocess.run(["start", url], shell=True, check=False) - else: - webbrowser.open(url) - except Exception as e: - print(f"Failed to open browser: {e}") - print(f"Please manually open: {url}") - - async def call_tool_with_error_handling( session: ClientSession, tool_name: str, diff --git a/examples/snippets/servers/lowlevel/basic.py b/examples/snippets/servers/lowlevel/basic.py index 0d4432504..81f40e994 100644 --- a/examples/snippets/servers/lowlevel/basic.py +++ b/examples/snippets/servers/lowlevel/basic.py @@ -6,32 +6,30 @@ import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -# Create a server instance -server = Server("example-server") - -@server.list_prompts() -async def handle_list_prompts() -> list[types.Prompt]: +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: """List available prompts.""" - return [ - types.Prompt( - name="example-prompt", - description="An example prompt template", - arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], - ) - ] + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="example-prompt", + description="An example prompt template", + arguments=[types.PromptArgument(name="arg1", description="Example argument", required=True)], + ) + ] + ) -@server.get_prompt() -async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> types.GetPromptResult: +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: """Get a specific prompt by name.""" - if name != "example-prompt": - raise ValueError(f"Unknown prompt: {name}") + if params.name != "example-prompt": + raise ValueError(f"Unknown prompt: {params.name}") - arg1_value = (arguments or {}).get("arg1", "default") + arg1_value = (params.arguments or {}).get("arg1", "default") return types.GetPromptResult( description="Example prompt", @@ -44,20 +42,20 @@ async def handle_get_prompt(name: str, arguments: dict[str, str] | None) -> type ) +server = Server( + "example-server", + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, +) + + async def run(): """Run the basic low-level server.""" async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/servers/lowlevel/direct_call_tool_result.py b/examples/snippets/servers/lowlevel/direct_call_tool_result.py index 725f5711a..7e8fc4dcb 100644 --- a/examples/snippets/servers/lowlevel/direct_call_tool_result.py +++ b/examples/snippets/servers/lowlevel/direct_call_tool_result.py @@ -3,44 +3,49 @@ """ import asyncio -from typing import Any import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -server = Server("example-server") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="advanced_tool", - description="Tool with full control including _meta field", - input_schema={ - "type": "object", - "properties": {"message": {"type": "string"}}, - "required": ["message"], - }, - ) - ] - - -@server.call_tool() -async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="advanced_tool", + description="Tool with full control including _meta field", + input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + ) + ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool calls by returning CallToolResult directly.""" - if name == "advanced_tool": - message = str(arguments.get("message", "")) + if params.name == "advanced_tool": + message = (params.arguments or {}).get("message", "") return types.CallToolResult( content=[types.TextContent(type="text", text=f"Processed: {message}")], structured_content={"result": "success", "message": message}, _meta={"hidden": "data for client applications only"}, ) - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -49,14 +54,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/servers/lowlevel/lifespan.py b/examples/snippets/servers/lowlevel/lifespan.py index da8ff7bdf..bcd96c893 100644 --- a/examples/snippets/servers/lowlevel/lifespan.py +++ b/examples/snippets/servers/lowlevel/lifespan.py @@ -4,12 +4,11 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any +from typing import TypedDict import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext # Mock database class for example @@ -32,52 +31,58 @@ async def query(self, query_str: str) -> list[dict[str, str]]: return [{"id": "1", "name": "Example", "query": query_str}] +class AppContext(TypedDict): + db: Database + + @asynccontextmanager -async def server_lifespan(_server: Server) -> AsyncIterator[dict[str, Any]]: +async def server_lifespan(_server: Server[AppContext]) -> AsyncIterator[AppContext]: """Manage server startup and shutdown lifecycle.""" - # Initialize resources on startup db = await Database.connect() try: yield {"db": db} finally: - # Clean up on shutdown await db.disconnect() -# Pass lifespan to server -server = Server("example-server", lifespan=server_lifespan) - - -@server.list_tools() -async def handle_list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools.""" - return [ - types.Tool( - name="query_db", - description="Query the database", - input_schema={ - "type": "object", - "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, - "required": ["query"], - }, - ) - ] - - -@server.call_tool() -async def query_db(name: str, arguments: dict[str, Any]) -> list[types.TextContent]: + return types.ListToolsResult( + tools=[ + types.Tool( + name="query_db", + description="Query the database", + input_schema={ + "type": "object", + "properties": {"query": {"type": "string", "description": "SQL query to execute"}}, + "required": ["query"], + }, + ) + ] + ) + + +async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: types.CallToolRequestParams +) -> types.CallToolResult: """Handle database query tool call.""" - if name != "query_db": - raise ValueError(f"Unknown tool: {name}") + if params.name != "query_db": + raise ValueError(f"Unknown tool: {params.name}") - # Access lifespan context - ctx = server.request_context db = ctx.lifespan_context["db"] + results = await db.query((params.arguments or {})["query"]) + + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Query results: {results}")]) - # Execute query - results = await db.query(arguments["query"]) - return [types.TextContent(type="text", text=f"Query results: {results}")] +server = Server( + "example-server", + lifespan=server_lifespan, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -86,14 +91,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="example-server", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/servers/lowlevel/structured_output.py b/examples/snippets/servers/lowlevel/structured_output.py index cad8f67da..f93c8875f 100644 --- a/examples/snippets/servers/lowlevel/structured_output.py +++ b/examples/snippets/servers/lowlevel/structured_output.py @@ -3,62 +3,67 @@ """ import asyncio -from typing import Any +import json import mcp.server.stdio from mcp import types -from mcp.server.lowlevel import NotificationOptions, Server -from mcp.server.models import InitializationOptions +from mcp.server import Server, ServerRequestContext -server = Server("example-server") - -@server.list_tools() -async def list_tools() -> list[types.Tool]: +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: """List available tools with structured output schemas.""" - return [ - types.Tool( - name="get_weather", - description="Get current weather for a city", - input_schema={ - "type": "object", - "properties": {"city": {"type": "string", "description": "City name"}}, - "required": ["city"], - }, - output_schema={ - "type": "object", - "properties": { - "temperature": {"type": "number", "description": "Temperature in Celsius"}, - "condition": {"type": "string", "description": "Weather condition"}, - "humidity": {"type": "number", "description": "Humidity percentage"}, - "city": {"type": "string", "description": "City name"}, + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Get current weather for a city", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string", "description": "City name"}}, + "required": ["city"], }, - "required": ["temperature", "condition", "humidity", "city"], - }, - ) - ] + output_schema={ + "type": "object", + "properties": { + "temperature": {"type": "number", "description": "Temperature in Celsius"}, + "condition": {"type": "string", "description": "Weather condition"}, + "humidity": {"type": "number", "description": "Humidity percentage"}, + "city": {"type": "string", "description": "City name"}, + }, + "required": ["temperature", "condition", "humidity", "city"], + }, + ) + ] + ) -@server.call_tool() -async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: """Handle tool calls with structured output.""" - if name == "get_weather": - city = arguments["city"] + if params.name == "get_weather": + city = (params.arguments or {})["city"] - # Simulated weather data - in production, call a weather API weather_data = { "temperature": 22.5, "condition": "partly cloudy", "humidity": 65, - "city": city, # Include the requested city + "city": city, } - # low-level server will validate structured output against the tool's - # output schema, and additionally serialize it into a TextContent block - # for backwards compatibility with pre-2025-06-18 clients. - return weather_data - else: - raise ValueError(f"Unknown tool: {name}") + return types.CallToolResult( + content=[types.TextContent(type="text", text=json.dumps(weather_data, indent=2))], + structured_content=weather_data, + ) + + raise ValueError(f"Unknown tool: {params.name}") + + +server = Server( + "example-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, +) async def run(): @@ -67,14 +72,7 @@ async def run(): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="structured-output-example", - server_version="0.1.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) diff --git a/examples/snippets/servers/pagination_example.py b/examples/snippets/servers/pagination_example.py index bb406653e..bcd0ffb10 100644 --- a/examples/snippets/servers/pagination_example.py +++ b/examples/snippets/servers/pagination_example.py @@ -1,22 +1,20 @@ -"""Example of implementing pagination with MCP server decorators.""" +"""Example of implementing pagination with the low-level MCP server.""" from mcp import types -from mcp.server.lowlevel import Server - -# Initialize the server -server = Server("paginated-server") +from mcp.server import Server, ServerRequestContext # Sample data to paginate ITEMS = [f"Item {i}" for i in range(1, 101)] # 100 items -@server.list_resources() -async def list_resources_paginated(request: types.ListResourcesRequest) -> types.ListResourcesResult: +async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListResourcesResult: """List resources with pagination support.""" page_size = 10 # Extract cursor from request params - cursor = request.params.cursor if request.params is not None else None + cursor = params.cursor if params is not None else None # Parse cursor to get offset start = 0 if cursor is None else int(cursor) @@ -32,3 +30,6 @@ async def list_resources_paginated(request: types.ListResourcesRequest) -> types next_cursor = str(end) if end < len(ITEMS) else None return types.ListResourcesResult(resources=page_items, next_cursor=next_cursor) + + +server = Server("paginated-server", on_list_resources=handle_list_resources) diff --git a/mkdocs.yml b/mkdocs.yml index 3019f5214..070c533e3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -112,7 +112,8 @@ watch: plugins: - search - - social + - social: + enabled: !ENV [ENABLE_SOCIAL_CARDS, false] - glightbox - mkdocstrings: handlers: diff --git a/pyproject.toml b/pyproject.toml index bfc306713..737839a23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ dev = [ docs = [ "mkdocs>=1.6.1", "mkdocs-glightbox>=0.4.0", - "mkdocs-material>=9.5.45", + "mkdocs-material[imaging]>=9.5.45", "mkdocstrings-python>=2.0.1", ] @@ -173,6 +173,8 @@ xfail_strict = true addopts = """ --color=yes --capture=fd + -p anyio + -p examples """ filterwarnings = [ "error", diff --git a/src/mcp/cli/cli.py b/src/mcp/cli/cli.py index 858ab7db2..62334a4a2 100644 --- a/src/mcp/cli/cli.py +++ b/src/mcp/cli/cli.py @@ -317,12 +317,12 @@ def run( ) -> None: # pragma: no cover """Run an MCP server. - The server can be specified in two ways:\n - 1. Module approach: server.py - runs the module directly, expecting a server.run() call.\n - 2. Import approach: server.py:app - imports and runs the specified server object.\n\n + The server can be specified in two ways: + 1. Module approach: server.py - runs the module directly, expecting a server.run() call. + 2. Import approach: server.py:app - imports and runs the specified server object. Note: This command runs the server directly. You are responsible for ensuring - all dependencies are available.\n + all dependencies are available. For dependency management, use `mcp install` or `mcp dev` instead. """ # noqa: E501 file, server_object = _parse_file_path(file_spec) diff --git a/src/mcp/client/auth/extensions/client_credentials.py b/src/mcp/client/auth/extensions/client_credentials.py index 07f6180bf..cb6dafb40 100644 --- a/src/mcp/client/auth/extensions/client_credentials.py +++ b/src/mcp/client/auth/extensions/client_credentials.py @@ -450,7 +450,7 @@ def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # prag # When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2 token_data["client_assertion"] = assertion token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" - # We need to set the audience to the resource server, the audience is difference from the one in claims + # We need to set the audience to the resource server, the audience is different from the one in claims # it represents the resource server that will validate the token token_data["audience"] = self.context.get_resource_url() diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 41aecc6f2..7f5af5186 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -215,6 +215,7 @@ def prepare_token_auth( class OAuthClientProvider(httpx.Auth): """OAuth2 authentication for httpx. + Handles OAuth flow with automatic client registration and token storage. """ @@ -241,7 +242,7 @@ def __init__( callback_handler: Handler for authorization callbacks. timeout: Timeout for the OAuth flow. client_metadata_url: URL-based client ID. When provided and the server - advertises client_id_metadata_document_supported=true, this URL will be + advertises client_id_metadata_document_supported=True, this URL will be used as the client_id instead of performing dynamic client registration. Must be a valid HTTPS URL with a non-root pathname. validate_resource_url: Optional callback to override resource URL validation. @@ -493,12 +494,6 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None if not prm_resource: return # pragma: no cover default_resource = resource_url_from_server_url(self.context.server_url) - # Normalize: Pydantic AnyHttpUrl adds trailing slash to root URLs - # (e.g. "https://example.com/") while resource_url_from_server_url may not. - if not default_resource.endswith("/"): - default_resource += "/" - if not prm_resource.endswith("/"): - prm_resource += "/" if not check_resource_allowed(requested_resource=default_resource, configured_resource=prm_resource): raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}") diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index 1aa960b9c..0ca36b98d 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -38,7 +38,7 @@ def extract_field_from_www_auth(response: Response, field_name: str) -> str | No def extract_scope_from_www_auth(response: Response) -> str | None: - """Extract scope parameter from WWW-Authenticate header as per RFC6750. + """Extract scope parameter from WWW-Authenticate header as per RFC 6750. Returns: Scope string if found in WWW-Authenticate header, None otherwise @@ -47,7 +47,7 @@ def extract_scope_from_www_auth(response: Response) -> str | None: def extract_resource_metadata_from_www_auth(response: Response) -> str | None: - """Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. + """Extract protected resource metadata URL from WWW-Authenticate header as per RFC 9728. Returns: Resource metadata URL if found in WWW-Authenticate header, None otherwise @@ -67,8 +67,8 @@ def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, s 3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource Args: - www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header - server_url: server url + www_auth_url: Optional resource_metadata URL extracted from the WWW-Authenticate header + server_url: Server URL Returns: Ordered list of URLs to try for discovery @@ -120,10 +120,10 @@ def get_client_metadata_scopes( def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]: - """Generate ordered list of (url, type) tuples for discovery attempts. + """Generate an ordered list of URLs for authorization server metadata discovery. Args: - auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None + auth_server_url: OAuth Authorization Server Metadata URL if found, otherwise None server_url: URL for the MCP server, used as a fallback if auth_server_url is None """ @@ -170,7 +170,7 @@ async def handle_protected_resource_response( Per SEP-985, supports fallback when discovery fails at one URL. Returns: - True if metadata was successfully discovered, False if we should try next URL + ProtectedResourceMetadata if successfully discovered, None if we should try next URL """ if response.status_code == 200: try: @@ -206,7 +206,7 @@ def create_oauth_metadata_request(url: str) -> Request: def create_client_registration_request( auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str ) -> Request: - """Build registration request or skip if already registered.""" + """Build a client registration request.""" if auth_server_metadata and auth_server_metadata.registration_endpoint: registration_url = str(auth_server_metadata.registration_endpoint) @@ -261,7 +261,7 @@ def should_use_client_metadata_url( """Determine if URL-based client ID (CIMD) should be used instead of DCR. URL-based client IDs should be used when: - 1. The server advertises client_id_metadata_document_supported=true + 1. The server advertises client_id_metadata_document_supported=True 2. The client has a valid client_metadata_url configured Args: @@ -306,7 +306,7 @@ def create_client_info_from_metadata_url( async def handle_token_response_scopes( response: Response, ) -> OAuthToken: - """Parse and validate token response with optional scope validation. + """Parse and validate a token response. Parses token response JSON. Callers should check response.status_code before calling. diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 29d4a7035..7dc67c584 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -37,8 +37,8 @@ class Client: """A high-level MCP client for connecting to MCP servers. - Currently supports in-memory transport for testing. Pass a Server or - MCPServer instance directly to the constructor. + Supports in-memory transport for testing (pass a Server or MCPServer instance), + Streamable HTTP transport (pass a URL string), or a custom Transport instance. Example: ```python @@ -205,7 +205,7 @@ async def read_resource(self, uri: str, *, meta: RequestParamsMeta | None = None Args: uri: The URI of the resource to read. - meta: Additional metadata for the request + meta: Additional metadata for the request. Returns: The resource content. @@ -239,7 +239,7 @@ async def call_tool( meta: Additional metadata for the request Returns: - The tool result + The tool result. """ return await self.session.call_tool( name=name, diff --git a/src/mcp/client/experimental/task_handlers.py b/src/mcp/client/experimental/task_handlers.py index 28ff2b1f2..0ab513236 100644 --- a/src/mcp/client/experimental/task_handlers.py +++ b/src/mcp/client/experimental/task_handlers.py @@ -187,11 +187,13 @@ class ExperimentalTaskHandlers: WARNING: These APIs are experimental and may change without notice. Example: + ```python handlers = ExperimentalTaskHandlers( get_task=my_get_task_handler, list_tasks=my_list_tasks_handler, ) session = ClientSession(..., experimental_task_handlers=handlers) + ``` """ # Pure task request handlers diff --git a/src/mcp/client/experimental/tasks.py b/src/mcp/client/experimental/tasks.py index 8ddc4face..a566df766 100644 --- a/src/mcp/client/experimental/tasks.py +++ b/src/mcp/client/experimental/tasks.py @@ -5,6 +5,7 @@ WARNING: These APIs are experimental and may change without notice. Example: + ```python # Call a tool as a task result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) task_id = result.task.task_id @@ -21,6 +22,7 @@ # Cancel a task await session.experimental.cancel_task(task_id) + ``` """ from collections.abc import AsyncIterator @@ -72,6 +74,7 @@ async def call_tool_as_task( CreateTaskResult containing the task reference Example: + ```python # Create task result = await session.experimental.call_tool_as_task( "long_running_tool", {"input": "data"} @@ -83,10 +86,11 @@ async def call_tool_as_task( status = await session.experimental.get_task(task_id) if status.status == "completed": break - await asyncio.sleep(0.5) + await anyio.sleep(0.5) # Get result final = await session.experimental.get_task_result(task_id, CallToolResult) + ``` """ return await self._session.send_request( types.CallToolRequest( @@ -177,7 +181,7 @@ async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: """Poll a task until it reaches a terminal status. Yields GetTaskResult for each poll, allowing the caller to react to - status changes (e.g., handle input_required). Exits when task reaches + status changes (e.g., handle input_required). Exits when the task reaches a terminal status (completed, failed, cancelled). Respects the pollInterval hint from the server. @@ -189,6 +193,7 @@ async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: GetTaskResult for each poll Example: + ```python async for status in session.experimental.poll_task(task_id): print(f"Status: {status.status}") if status.status == "input_required": @@ -197,6 +202,7 @@ async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: # Task is now terminal, get the result result = await session.experimental.get_task_result(task_id, CallToolResult) + ``` """ async for status in poll_until_terminal(self.get_task, task_id): yield status diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0687f98c3..a0ca751bd 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -206,8 +206,10 @@ def experimental(self) -> ExperimentalClientFeatures: These APIs are experimental and may change without notice. Example: + ```python status = await session.experimental.get_task(task_id) result = await session.experimental.get_task_result(task_id, CallToolResult) + ``` """ if self._experimental_features is None: self._experimental_features = ExperimentalClientFeatures(self) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index f4e6293b7..961021264 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -3,7 +3,7 @@ Tools, resources, and prompts are aggregated across servers. Servers may be connected to or disconnected from at any point after initialization. -This abstractions can handle naming collisions using a custom user-provided hook. +This abstraction can handle naming collisions using a custom user-provided hook. """ import contextlib @@ -30,7 +30,7 @@ class SseServerParameters(BaseModel): - """Parameters for initializing a sse_client.""" + """Parameters for initializing an sse_client.""" # The endpoint URL. url: str @@ -67,8 +67,8 @@ class StreamableHttpParameters(BaseModel): ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters -# Use dataclass instead of pydantic BaseModel -# because pydantic BaseModel cannot handle Protocol fields. +# Use dataclass instead of Pydantic BaseModel +# because Pydantic BaseModel cannot handle Protocol fields. @dataclass class ClientSessionParameters: """Parameters for establishing a client session to an MCP server.""" @@ -91,13 +91,14 @@ class ClientSessionGroup: For auxiliary handlers, such as resource subscription, this is delegated to the client and can be accessed via the session. - Example Usage: + Example: + ```python name_fn = lambda name, server_info: f"{(server_info.name)}_{name}" async with ClientSessionGroup(component_name_hook=name_fn) as group: for server_param in server_params: await group.connect_to_server(server_param) ... - + ``` """ class _ComponentNames(BaseModel): @@ -119,7 +120,7 @@ class _ComponentNames(BaseModel): _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] # Optional fn consuming (component_name, server_info) for custom names. - # This is provide a means to mitigate naming conflicts across servers. + # This is to provide a means to mitigate naming conflicts across servers. # Example: (tool_name, server_info) => "{result.server_info.name}.{tool_name}" _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str] _component_name_hook: _ComponentNameHook | None diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 8f8e4dadc..61026aa0c 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -47,6 +47,7 @@ async def sse_client( headers: Optional headers to include in requests. timeout: HTTP timeout for regular operations (in seconds). sse_read_timeout: Timeout for SSE read operations (in seconds). + httpx_client_factory: Factory function for creating the HTTPX client. auth: Optional HTTPX authentication handler. on_session_created: Optional callback invoked with the session ID when received. """ @@ -138,7 +139,7 @@ async def post_writer(endpoint_url: str): json=session_message.message.model_dump( by_alias=True, mode="json", - exclude_none=True, + exclude_unset=True, ), ) response.raise_for_status() diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 605c5ea24..902dc8576 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -87,9 +87,9 @@ class StdioServerParameters(BaseModel): encoding: str = "utf-8" """ - The text encoding used when sending/receiving messages to the server + The text encoding used when sending/receiving messages to the server. - defaults to utf-8 + Defaults to utf-8. """ encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict" @@ -97,7 +97,7 @@ class StdioServerParameters(BaseModel): The text encoding error handler. See https://docs.python.org/3/library/codecs.html#codec-base-classes for - explanations of possible values + explanations of possible values. """ @@ -167,7 +167,7 @@ async def stdin_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) + json = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) await process.stdin.send( (json + "\n").encode( encoding=server.encoding, diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9d45bec6e..9f3dd5e0b 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -19,6 +19,7 @@ from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( + INTERNAL_ERROR, INVALID_REQUEST, PARSE_ERROR, ErrorData, @@ -259,7 +260,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: async with ctx.client.stream( "POST", self.url, - json=message.model_dump(by_alias=True, mode="json", exclude_none=True), + json=message.model_dump(by_alias=True, mode="json", exclude_unset=True), headers=headers, ) as response: if response.status_code == 202: @@ -273,7 +274,13 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: await ctx.read_stream_writer.send(session_message) return - response.raise_for_status() + if response.status_code >= 400: + if isinstance(message, JSONRPCRequest): + error_data = ErrorData(code=INTERNAL_ERROR, message="Server returned an error response") + session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) + await ctx.read_stream_writer.send(session_message) + return + if is_initialization: self._maybe_extract_session_id_from_response(response) @@ -351,7 +358,7 @@ async def _handle_sse_response( resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), is_initialization=is_initialization, ) - # If the SSE event indicates completion, like returning respose/error + # If the SSE event indicates completion, like returning response/error # break the loop if is_complete: await response.aclose() diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index cf4b86e99..79e75fad1 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -25,8 +25,8 @@ async def websocket_client( (read_stream, write_stream) - read_stream: As you read from this stream, you'll receive either valid - JSONRPCMessage objects or Exception objects (when validation fails). - - write_stream: Write JSONRPCMessage objects to this stream to send them + SessionMessage objects or Exception objects (when validation fails). + - write_stream: Write SessionMessage objects to this stream to send them over the WebSocket to the server. """ @@ -65,7 +65,7 @@ async def ws_writer(): async with write_stream_reader: async for session_message in write_stream_reader: # Convert to a dict, then to JSON - msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True) + msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_unset=True) await ws.send(json.dumps(msg_dict)) async with anyio.create_task_group() as tg: diff --git a/src/mcp/os/win32/utilities.py b/src/mcp/os/win32/utilities.py index fa4e4b399..0e188691f 100644 --- a/src/mcp/os/win32/utilities.py +++ b/src/mcp/os/win32/utilities.py @@ -138,9 +138,9 @@ async def create_windows_process( ) -> Process | FallbackProcess: """Creates a subprocess in a Windows-compatible way with Job Object support. - Attempt to use anyio's open_process for async subprocess creation. - In some cases this will throw NotImplementedError on Windows, e.g. - when using the SelectorEventLoop which does not support async subprocesses. + Attempts to use anyio's open_process for async subprocess creation. + In some cases this will throw NotImplementedError on Windows, e.g., + when using the SelectorEventLoop, which does not support async subprocesses. In that case, we fall back to using subprocess.Popen. The process is automatically added to a Job Object to ensure all child @@ -242,8 +242,9 @@ def _create_job_object() -> int | None: def _maybe_assign_process_to_job(process: Process | FallbackProcess, job: JobHandle | None) -> None: - """Try to assign a process to a job object. If assignment fails - for any reason, the job handle is closed. + """Try to assign a process to a job object. + + If assignment fails for any reason, the job handle is closed. """ if not job: return @@ -312,8 +313,8 @@ async def terminate_windows_process(process: Process | FallbackProcess): Note: On Windows, terminating a process with process.terminate() doesn't always guarantee immediate process termination. - So we give it 2s to exit, or we call process.kill() - which sends a SIGKILL equivalent signal. + If the process does not exit within 2 seconds, process.kill() is called + to send a SIGKILL-equivalent signal. Args: process: The process to terminate diff --git a/src/mcp/server/__init__.py b/src/mcp/server/__init__.py index a2dada3af..aab5c33f7 100644 --- a/src/mcp/server/__init__.py +++ b/src/mcp/server/__init__.py @@ -1,5 +1,6 @@ +from .context import ServerRequestContext from .lowlevel import NotificationOptions, Server from .mcpserver import MCPServer from .models import InitializationOptions -__all__ = ["Server", "MCPServer", "NotificationOptions", "InitializationOptions"] +__all__ = ["Server", "ServerRequestContext", "MCPServer", "NotificationOptions", "InitializationOptions"] diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 68a3392b4..4efd15400 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -15,7 +15,7 @@ class RevocationRequest(BaseModel): - """# See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1""" + """See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1""" token: str token_type_hint: Literal["access_token", "refresh_token"] | None = None diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 8a6a1b518..2832f8352 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -19,15 +19,17 @@ def __init__(self, message: str): class ClientAuthenticator: """ClientAuthenticator is a callable which validates requests from a client application, used to verify /token calls. + If, during registration, the client requested to be issued a secret, the authenticator asserts that /token calls must be authenticated with - that same token. + that same secret. + NOTE: clients can opt for no authentication during registration, in which case this logic is skipped. """ def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]): - """Initialize the dependency. + """Initialize the authenticator. Args: provider: Provider to look up client information @@ -83,7 +85,7 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation elif client.token_endpoint_auth_method == "client_secret_post": raw_form_data = form_data.get("client_secret") - # form_data.get() can return a UploadFile or None, so we need to check if it's a string + # form_data.get() can return an UploadFile or None, so we need to check if it's a string if isinstance(raw_form_data, str): request_client_secret = str(raw_form_data) diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 5eb577fd4..957082a85 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -131,8 +131,9 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> None """ async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: - """Called as part of the /authorize endpoint, and returns a URL that the client + """Handle the /authorize endpoint and return a URL that the client will be redirected to. + Many MCP implementations will redirect to a third-party provider to perform a second OAuth exchange with that provider. In this sort of setup, the client has an OAuth connection with the MCP server, and the MCP server has an OAuth @@ -151,7 +152,7 @@ async def authorize(self, client: OAuthClientInformationFull, params: Authorizat | | +------------+ - Implementations will need to define another handler on the MCP server return + Implementations will need to define another handler on the MCP server's return flow to perform the second redirect, and generate and store an authorization code as part of completing the OAuth authorization step. @@ -182,7 +183,7 @@ async def load_authorization_code( authorization_code: The authorization code to get the challenge for. Returns: - The AuthorizationCode, or None if not found + The AuthorizationCode, or None if not found. """ ... @@ -199,7 +200,7 @@ async def exchange_authorization_code( The OAuth token, containing access and refresh tokens. Raises: - TokenError: If the request is invalid + TokenError: If the request is invalid. """ ... @@ -234,18 +235,18 @@ async def exchange_refresh_token( The OAuth token, containing access and refresh tokens. Raises: - TokenError: If the request is invalid + TokenError: If the request is invalid. """ ... async def load_access_token(self, token: str) -> AccessTokenT | None: - """Loads an access token by its token. + """Loads an access token by its token string. Args: token: The access token to verify. Returns: - The AuthInfo, or None if the token is invalid. + The access token, or None if the token is invalid. """ async def revoke_token( @@ -261,7 +262,7 @@ async def revoke_token( provided. Args: - token: the token to revoke + token: The token to revoke. """ diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 08f735f36..a72e81947 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -25,25 +25,21 @@ def validate_issuer_url(url: AnyHttpUrl): """Validate that the issuer URL meets OAuth 2.0 requirements. Args: - url: The issuer URL to validate + url: The issuer URL to validate. Raises: - ValueError: If the issuer URL is invalid + ValueError: If the issuer URL is invalid. """ - # RFC 8414 requires HTTPS, but we allow localhost HTTP for testing - if ( - url.scheme != "https" - and url.host != "localhost" - and (url.host is not None and not url.host.startswith("127.0.0.1")) - ): - raise ValueError("Issuer URL must be HTTPS") # pragma: no cover + # RFC 8414 requires HTTPS, but we allow loopback/localhost HTTP for testing + if url.scheme != "https" and url.host not in ("localhost", "127.0.0.1", "[::1]"): + raise ValueError("Issuer URL must be HTTPS") # No fragments or query parameters allowed if url.fragment: - raise ValueError("Issuer URL must not have a fragment") # pragma: no cover + raise ValueError("Issuer URL must not have a fragment") if url.query: - raise ValueError("Issuer URL must not have a query string") # pragma: no cover + raise ValueError("Issuer URL must not have a query string") AUTHORIZATION_PATH = "/authorize" @@ -217,6 +213,8 @@ def create_protected_resource_routes( resource_url: The URL of this resource server authorization_servers: List of authorization servers that can issue tokens scopes_supported: Optional list of scopes supported by this resource + resource_name: Optional human-readable name for this resource + resource_documentation: Optional URL to documentation for this resource Returns: List of Starlette routes for protected resource metadata diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 43b9d3800..d8e11d78b 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -10,7 +10,7 @@ from mcp.shared._context import RequestContext from mcp.shared.message import CloseSSEStreamCallback -LifespanContextT = TypeVar("LifespanContextT") +LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any]) RequestT = TypeVar("RequestT", default=Any) diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 58e9fe448..731c914ed 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -112,8 +112,8 @@ async def elicit_with_validation( This method can be used to interactively ask for additional information from the client within a tool's execution. The client might display the message to the - user and collect a response according to the provided schema. Or in case a - client is an agent, it might decide how to handle the elicitation -- either by asking + user and collect a response according to the provided schema. If the client + is an agent, it might decide how to handle the elicitation -- either by asking the user or automatically generating a response. For sensitive data like credentials or OAuth flows, use elicit_url() instead. diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 80ae5912b..3eba65822 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -62,8 +62,8 @@ def validate_task_mode( """Validate that the request is compatible with the tool's task execution mode. Per MCP spec: - - "required": Clients MUST invoke as task. Server returns -32601 if not. - - "forbidden" (or None): Clients MUST NOT invoke as task. Server returns -32601 if they do. + - "required": Clients MUST invoke as a task. Server returns -32601 if not. + - "forbidden" (or None): Clients MUST NOT invoke as a task. Server returns -32601 if they do. - "optional": Either is acceptable. Args: @@ -111,7 +111,7 @@ def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool: """Check if this client can use a tool with the given task mode. Useful for filtering tool lists or providing warnings. - Returns False if tool requires "required" but client doesn't support tasks. + Returns False if the tool's task mode is "required" but the client doesn't support tasks. Args: tool_task_mode: The tool's execution.taskSupport value @@ -160,19 +160,18 @@ async def run_task( RuntimeError: If task support is not enabled or task_metadata is missing Example: - @server.call_tool() - async def handle_tool(name: str, args: dict): - ctx = server.request_context - + ```python + async def handle_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: async def work(task: ServerTaskContext) -> CallToolResult: result = await task.elicit( message="Are you sure?", - requestedSchema={"type": "object", ...} + requested_schema={"type": "object", ...} ) confirmed = result.content.get("confirm", False) return CallToolResult(content=[TextContent(text="Done" if confirmed else "Cancelled")]) return await ctx.experimental.run_task(work) + ``` WARNING: This API is experimental and may change without notice. """ diff --git a/src/mcp/server/experimental/task_context.py b/src/mcp/server/experimental/task_context.py index 9b626c986..1fc45badf 100644 --- a/src/mcp/server/experimental/task_context.py +++ b/src/mcp/server/experimental/task_context.py @@ -56,6 +56,7 @@ class ServerTaskContext: - Status notifications via the session Example: + ```python async def my_task_work(task: ServerTaskContext) -> CallToolResult: await task.update_status("Starting...") @@ -68,6 +69,7 @@ async def my_task_work(task: ServerTaskContext) -> CallToolResult: return CallToolResult(content=[TextContent(text="Done!")]) else: return CallToolResult(content=[TextContent(text="Cancelled")]) + ``` """ def __init__( diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 991221bd0..b2268bc1c 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -44,17 +44,14 @@ class TaskResultHandler: 5. Returns the final result Usage: - # Create handler with store and queue - handler = TaskResultHandler(task_store, message_queue) - - # Register it with the server - @server.experimental.get_task_result() - async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - ctx = server.request_context - return await handler.handle(req, ctx.session, ctx.request_id) - - # Or use the convenience method - handler.register(server) + async def handle_task_result( + ctx: ServerRequestContext, params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: + ... + + server.experimental.enable_tasks( + on_task_result=handle_task_result, + ) """ def __init__( diff --git a/src/mcp/server/experimental/task_support.py b/src/mcp/server/experimental/task_support.py index 23b5d9cc8..b54219504 100644 --- a/src/mcp/server/experimental/task_support.py +++ b/src/mcp/server/experimental/task_support.py @@ -31,14 +31,20 @@ class TaskSupport: - Manages a task group for background task execution Example: - # Simple in-memory setup + Simple in-memory setup: + + ```python server.experimental.enable_tasks() + ``` + + Custom store/queue for distributed systems: - # Custom store/queue for distributed systems + ```python server.experimental.enable_tasks( store=RedisTaskStore(redis_url), queue=RedisTaskMessageQueue(redis_url), ) + ``` """ store: TaskStore diff --git a/src/mcp/server/lowlevel/__init__.py b/src/mcp/server/lowlevel/__init__.py index 66df38991..37191ba1a 100644 --- a/src/mcp/server/lowlevel/__init__.py +++ b/src/mcp/server/lowlevel/__init__.py @@ -1,3 +1,3 @@ from .server import NotificationOptions, Server -__all__ = ["Server", "NotificationOptions"] +__all__ = ["NotificationOptions", "Server"] diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 9b472c023..5a907b640 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -7,10 +7,12 @@ import logging from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING +from typing import Any, Generic +from typing_extensions import TypeVar + +from mcp.server.context import ServerRequestContext from mcp.server.experimental.task_support import TaskSupport -from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.shared.exceptions import MCPError from mcp.shared.experimental.tasks.helpers import cancel_task from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore @@ -18,16 +20,16 @@ from mcp.shared.experimental.tasks.store import TaskStore from mcp.types import ( INVALID_PARAMS, - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, GetTaskPayloadRequest, + GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, + PaginatedRequestParams, ServerCapabilities, - ServerResult, ServerTasksCapability, ServerTasksRequestsCapability, TasksCallCapability, @@ -36,13 +38,12 @@ TasksToolsCapability, ) -if TYPE_CHECKING: - from mcp.server.lowlevel.server import Server - logger = logging.getLogger(__name__) +LifespanResultT = TypeVar("LifespanResultT", default=Any) -class ExperimentalHandlers: + +class ExperimentalHandlers(Generic[LifespanResultT]): """Experimental request/notification handlers. WARNING: These APIs are experimental and may change without notice. @@ -50,13 +51,13 @@ class ExperimentalHandlers: def __init__( self, - server: Server, - request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]], - notification_handlers: dict[type, Callable[..., Awaitable[None]]], - ): - self._server = server - self._request_handlers = request_handlers - self._notification_handlers = notification_handlers + add_request_handler: Callable[ + [str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]], None + ], + has_handler: Callable[[str], bool], + ) -> None: + self._add_request_handler = add_request_handler + self._has_handler = has_handler self._task_support: TaskSupport | None = None @property @@ -66,16 +67,13 @@ def task_support(self) -> TaskSupport | None: def update_capabilities(self, capabilities: ServerCapabilities) -> None: # Only add tasks capability if handlers are registered - if not any( - req_type in self._request_handlers - for req_type in [GetTaskRequest, ListTasksRequest, CancelTaskRequest, GetTaskPayloadRequest] - ): + if not any(self._has_handler(method) for method in ["tasks/get", "tasks/list", "tasks/cancel", "tasks/result"]): return capabilities.tasks = ServerTasksCapability() - if ListTasksRequest in self._request_handlers: + if self._has_handler("tasks/list"): capabilities.tasks.list = TasksListCapability() - if CancelTaskRequest in self._request_handlers: + if self._has_handler("tasks/cancel"): capabilities.tasks.cancel = TasksCancelCapability() capabilities.tasks.requests = ServerTasksRequestsCapability( @@ -86,28 +84,54 @@ def enable_tasks( self, store: TaskStore | None = None, queue: TaskMessageQueue | None = None, + *, + on_get_task: Callable[[ServerRequestContext[LifespanResultT], GetTaskRequestParams], Awaitable[GetTaskResult]] + | None = None, + on_task_result: Callable[ + [ServerRequestContext[LifespanResultT], GetTaskPayloadRequestParams], Awaitable[GetTaskPayloadResult] + ] + | None = None, + on_list_tasks: Callable[ + [ServerRequestContext[LifespanResultT], PaginatedRequestParams | None], Awaitable[ListTasksResult] + ] + | None = None, + on_cancel_task: Callable[ + [ServerRequestContext[LifespanResultT], CancelTaskRequestParams], Awaitable[CancelTaskResult] + ] + | None = None, ) -> TaskSupport: """Enable experimental task support. - This sets up the task infrastructure and auto-registers default handlers - for tasks/get, tasks/result, tasks/list, and tasks/cancel. + This sets up the task infrastructure and registers handlers for + tasks/get, tasks/result, tasks/list, and tasks/cancel. Custom handlers + can be provided via the on_* kwargs; any not provided will use defaults. Args: store: Custom TaskStore implementation (defaults to InMemoryTaskStore) queue: Custom TaskMessageQueue implementation (defaults to InMemoryTaskMessageQueue) + on_get_task: Custom handler for tasks/get + on_task_result: Custom handler for tasks/result + on_list_tasks: Custom handler for tasks/list + on_cancel_task: Custom handler for tasks/cancel Returns: The TaskSupport configuration object Example: - # Simple in-memory setup + Simple in-memory setup: + + ```python server.experimental.enable_tasks() + ``` - # Custom store/queue for distributed systems + Custom store/queue for distributed systems: + + ```python server.experimental.enable_tasks( store=RedisTaskStore(redis_url), queue=RedisTaskMessageQueue(redis_url), ) + ``` WARNING: This API is experimental and may change without notice. """ @@ -117,24 +141,27 @@ def enable_tasks( queue = InMemoryTaskMessageQueue() self._task_support = TaskSupport(store=store, queue=queue) - - # Auto-register default handlers - self._register_default_task_handlers() - - return self._task_support - - def _register_default_task_handlers(self) -> None: - """Register default handlers for task operations.""" - assert self._task_support is not None - support = self._task_support - - # Register get_task handler if not already registered - if GetTaskRequest not in self._request_handlers: - - async def _default_get_task(req: GetTaskRequest) -> ServerResult: - task = await support.store.get_task(req.params.task_id) + task_support = self._task_support + + # Register user-provided handlers + if on_get_task is not None: + self._add_request_handler("tasks/get", on_get_task) + if on_task_result is not None: + self._add_request_handler("tasks/result", on_task_result) + if on_list_tasks is not None: + self._add_request_handler("tasks/list", on_list_tasks) + if on_cancel_task is not None: + self._add_request_handler("tasks/cancel", on_cancel_task) + + # Fill in defaults for any not provided + if not self._has_handler("tasks/get"): + + async def _default_get_task( + ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams + ) -> GetTaskResult: + task = await task_support.store.get_task(params.task_id) if task is None: - raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {req.params.task_id}") + raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}") return GetTaskResult( task_id=task.task_id, status=task.status, @@ -145,136 +172,39 @@ async def _default_get_task(req: GetTaskRequest) -> ServerResult: poll_interval=task.poll_interval, ) - self._request_handlers[GetTaskRequest] = _default_get_task + self._add_request_handler("tasks/get", _default_get_task) - # Register get_task_result handler if not already registered - if GetTaskPayloadRequest not in self._request_handlers: + if not self._has_handler("tasks/result"): - async def _default_get_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - ctx = self._server.request_context - result = await support.handler.handle(req, ctx.session, ctx.request_id) + async def _default_get_task_result( + ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: + assert ctx.request_id is not None + req = GetTaskPayloadRequest(params=params) + result = await task_support.handler.handle(req, ctx.session, ctx.request_id) return result - self._request_handlers[GetTaskPayloadRequest] = _default_get_task_result + self._add_request_handler("tasks/result", _default_get_task_result) - # Register list_tasks handler if not already registered - if ListTasksRequest not in self._request_handlers: + if not self._has_handler("tasks/list"): - async def _default_list_tasks(req: ListTasksRequest) -> ListTasksResult: - cursor = req.params.cursor if req.params else None - tasks, next_cursor = await support.store.list_tasks(cursor) + async def _default_list_tasks( + ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListTasksResult: + cursor = params.cursor if params else None + tasks, next_cursor = await task_support.store.list_tasks(cursor) return ListTasksResult(tasks=tasks, next_cursor=next_cursor) - self._request_handlers[ListTasksRequest] = _default_list_tasks - - # Register cancel_task handler if not already registered - if CancelTaskRequest not in self._request_handlers: - - async def _default_cancel_task(req: CancelTaskRequest) -> CancelTaskResult: - result = await cancel_task(support.store, req.params.task_id) - return result - - self._request_handlers[CancelTaskRequest] = _default_cancel_task - - def list_tasks( - self, - ) -> Callable[ - [Callable[[ListTasksRequest], Awaitable[ListTasksResult]]], - Callable[[ListTasksRequest], Awaitable[ListTasksResult]], - ]: - """Register a handler for listing tasks. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[ListTasksRequest], Awaitable[ListTasksResult]], - ) -> Callable[[ListTasksRequest], Awaitable[ListTasksResult]]: - logger.debug("Registering handler for ListTasksRequest") - wrapper = create_call_wrapper(func, ListTasksRequest) - - async def handler(req: ListTasksRequest) -> ListTasksResult: - result = await wrapper(req) - return result - - self._request_handlers[ListTasksRequest] = handler - return func - - return decorator - - def get_task( - self, - ) -> Callable[ - [Callable[[GetTaskRequest], Awaitable[GetTaskResult]]], Callable[[GetTaskRequest], Awaitable[GetTaskResult]] - ]: - """Register a handler for getting task status. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]], - ) -> Callable[[GetTaskRequest], Awaitable[GetTaskResult]]: - logger.debug("Registering handler for GetTaskRequest") - wrapper = create_call_wrapper(func, GetTaskRequest) - - async def handler(req: GetTaskRequest) -> GetTaskResult: - result = await wrapper(req) - return result - - self._request_handlers[GetTaskRequest] = handler - return func - - return decorator - - def get_task_result( - self, - ) -> Callable[ - [Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]], - Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], - ]: - """Register a handler for getting task results/payload. - - WARNING: This API is experimental and may change without notice. - """ - - def decorator( - func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], - ) -> Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]: - logger.debug("Registering handler for GetTaskPayloadRequest") - wrapper = create_call_wrapper(func, GetTaskPayloadRequest) - - async def handler(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: - result = await wrapper(req) - return result - - self._request_handlers[GetTaskPayloadRequest] = handler - return func - - return decorator - - def cancel_task( - self, - ) -> Callable[ - [Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]], - Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], - ]: - """Register a handler for cancelling tasks. - - WARNING: This API is experimental and may change without notice. - """ + self._add_request_handler("tasks/list", _default_list_tasks) - def decorator( - func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], - ) -> Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]: - logger.debug("Registering handler for CancelTaskRequest") - wrapper = create_call_wrapper(func, CancelTaskRequest) + if not self._has_handler("tasks/cancel"): - async def handler(req: CancelTaskRequest) -> CancelTaskResult: - result = await wrapper(req) + async def _default_cancel_task( + ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams + ) -> CancelTaskResult: + result = await cancel_task(task_support.store, params.task_id) return result - self._request_handlers[CancelTaskRequest] = handler - return func + self._add_request_handler("tasks/cancel", _default_cancel_task) - return decorator + return task_support diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py deleted file mode 100644 index d17697090..000000000 --- a/src/mcp/server/lowlevel/func_inspection.py +++ /dev/null @@ -1,53 +0,0 @@ -import inspect -from collections.abc import Callable -from typing import Any, TypeVar, get_type_hints - -T = TypeVar("T") -R = TypeVar("R") - - -def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callable[[T], R]: - """Create a wrapper function that knows how to call func with the request object. - - Returns a wrapper function that takes the request and calls func appropriately. - - The wrapper handles three calling patterns: - 1. Positional-only parameter typed as request_type (no default): func(req) - 2. Positional/keyword parameter typed as request_type (no default): func(**{param_name: req}) - 3. No request parameter or parameter with default: func() - """ - try: - sig = inspect.signature(func) - type_hints = get_type_hints(func) - except (ValueError, TypeError, NameError): # pragma: no cover - return lambda _: func() - - # Check for positional-only parameter typed as request_type - for param_name, param in sig.parameters.items(): - if param.kind == inspect.Parameter.POSITIONAL_ONLY: - param_type = type_hints.get(param_name) - if param_type == request_type: # pragma: no branch - # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: # pragma: no cover - return lambda _: func() - # Found positional-only parameter with correct type and no default - return lambda req: func(req) - - # Check for any positional/keyword parameter typed as request_type - for param_name, param in sig.parameters.items(): - if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY): # pragma: no branch - param_type = type_hints.get(param_name) - if param_type == request_type: - # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: # pragma: no cover - return lambda _: func() - - # Found keyword parameter with correct type and no default - # Need to capture param_name in closure properly - def make_keyword_wrapper(name: str) -> Callable[[Any], Any]: - return lambda req: func(**{name: req}) - - return make_keyword_wrapper(param_name) - - # No request parameter found - use old style - return lambda _: func() diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 7bd79bb37..aee644040 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -2,82 +2,49 @@ This module provides a framework for creating an MCP (Model Context Protocol) server. It allows you to easily define and handle various types of requests and notifications -in an asynchronous manner. +using constructor-based handler registration. Usage: -1. Create a Server instance: - server = Server("your_server_name") - -2. Define request handlers using decorators: - @server.list_prompts() - async def handle_list_prompts(request: types.ListPromptsRequest) -> types.ListPromptsResult: - # Implementation - - @server.get_prompt() - async def handle_get_prompt( - name: str, arguments: dict[str, str] | None - ) -> types.GetPromptResult: - # Implementation - - @server.list_tools() - async def handle_list_tools(request: types.ListToolsRequest) -> types.ListToolsResult: - # Implementation - - @server.call_tool() - async def handle_call_tool( - name: str, arguments: dict | None - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - # Implementation - - @server.list_resource_templates() - async def handle_list_resource_templates() -> list[types.ResourceTemplate]: - # Implementation - -3. Define notification handlers if needed: - @server.progress_notification() - async def handle_progress( - progress_token: str | int, progress: float, total: float | None, - message: str | None - ) -> None: - # Implementation - -4. Run the server: +1. Define handler functions: + async def my_list_tools(ctx, params): + return types.ListToolsResult(tools=[...]) + + async def my_call_tool(ctx, params): + return types.CallToolResult(content=[...]) + +2. Create a Server instance with on_* handlers: + server = Server( + "your_server_name", + on_list_tools=my_list_tools, + on_call_tool=my_call_tool, + ) + +3. Run the server: async def main(): async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): await server.run( read_stream, write_stream, - InitializationOptions( - server_name="your_server_name", - server_version="your_version", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), + server.create_initialization_options(), ) asyncio.run(main()) -The Server class provides methods to register handlers for various MCP requests and -notifications. It automatically manages the request context and handles incoming -messages from the client. +The Server class dispatches incoming requests and notifications to registered +handler callables by method string. """ from __future__ import annotations -import base64 import contextvars -import json import logging import warnings -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from importlib.metadata import version as importlib_version -from typing import Any, Generic, TypeAlias, cast +from typing import Any, Generic import anyio -import jsonschema from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette.applications import Starlette from starlette.middleware import Middleware @@ -94,30 +61,20 @@ async def main(): from mcp.server.context import ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers -from mcp.server.lowlevel.func_inspection import create_call_wrapper -from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError +from mcp.shared.exceptions import MCPError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder -from mcp.shared.tool_name_validation import validate_and_warn_tool_name logger = logging.getLogger(__name__) LifespanResultT = TypeVar("LifespanResultT", default=Any) -RequestT = TypeVar("RequestT", default=Any) - -# type aliases for tool call results -StructuredContent: TypeAlias = dict[str, Any] -UnstructuredContent: TypeAlias = Iterable[types.ContentBlock] -CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] -# This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[ServerRequestContext[Any, Any]] = contextvars.ContextVar("request_ctx") +request_ctx: contextvars.ContextVar[ServerRequestContext[Any]] = contextvars.ContextVar("request_ctx") class NotificationOptions: @@ -128,22 +85,24 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals @asynccontextmanager -async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[str, Any]]: +async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]: """Default lifespan context manager that does nothing. - Args: - server: The server instance this lifespan is managing - Returns: An empty context object """ yield {} -class Server(Generic[LifespanResultT, RequestT]): +async def _ping_handler(ctx: ServerRequestContext[Any], params: types.RequestParams | None) -> types.EmptyResult: + return types.EmptyResult() + + +class Server(Generic[LifespanResultT]): def __init__( self, name: str, + *, version: str | None = None, title: str | None = None, description: str | None = None, @@ -151,9 +110,80 @@ def __init__( website_url: str | None = None, icons: list[types.Icon] | None = None, lifespan: Callable[ - [Server[LifespanResultT, RequestT]], + [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, + # Request handlers + on_list_tools: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListToolsResult], + ] + | None = None, + on_call_tool: Callable[ + [ServerRequestContext[LifespanResultT], types.CallToolRequestParams], + Awaitable[types.CallToolResult | types.CreateTaskResult], + ] + | None = None, + on_list_resources: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListResourcesResult], + ] + | None = None, + on_list_resource_templates: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListResourceTemplatesResult], + ] + | None = None, + on_read_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.ReadResourceRequestParams], + Awaitable[types.ReadResourceResult], + ] + | None = None, + on_subscribe_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.SubscribeRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_unsubscribe_resource: Callable[ + [ServerRequestContext[LifespanResultT], types.UnsubscribeRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_list_prompts: Callable[ + [ServerRequestContext[LifespanResultT], types.PaginatedRequestParams | None], + Awaitable[types.ListPromptsResult], + ] + | None = None, + on_get_prompt: Callable[ + [ServerRequestContext[LifespanResultT], types.GetPromptRequestParams], + Awaitable[types.GetPromptResult], + ] + | None = None, + on_completion: Callable[ + [ServerRequestContext[LifespanResultT], types.CompleteRequestParams], + Awaitable[types.CompleteResult], + ] + | None = None, + on_set_logging_level: Callable[ + [ServerRequestContext[LifespanResultT], types.SetLevelRequestParams], + Awaitable[types.EmptyResult], + ] + | None = None, + on_ping: Callable[ + [ServerRequestContext[LifespanResultT], types.RequestParams | None], + Awaitable[types.EmptyResult], + ] = _ping_handler, + # Notification handlers + on_roots_list_changed: Callable[ + [ServerRequestContext[LifespanResultT], types.NotificationParams | None], + Awaitable[None], + ] + | None = None, + on_progress: Callable[ + [ServerRequestContext[LifespanResultT], types.ProgressNotificationParams], + Awaitable[None], + ] + | None = None, ): self.name = name self.version = version @@ -163,15 +193,64 @@ def __init__( self.website_url = website_url self.icons = icons self.lifespan = lifespan - self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = { - types.PingRequest: _ping_handler, - } - self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} - self._tool_cache: dict[str, types.Tool] = {} - self._experimental_handlers: ExperimentalHandlers | None = None + self._request_handlers: dict[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]] = {} + self._notification_handlers: dict[ + str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]] + ] = {} + self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None logger.debug("Initializing server %r", name) + # Populate internal handler dicts from on_* kwargs + self._request_handlers.update( + { + method: handler + for method, handler in { + "ping": on_ping, + "prompts/list": on_list_prompts, + "prompts/get": on_get_prompt, + "resources/list": on_list_resources, + "resources/templates/list": on_list_resource_templates, + "resources/read": on_read_resource, + "resources/subscribe": on_subscribe_resource, + "resources/unsubscribe": on_unsubscribe_resource, + "tools/list": on_list_tools, + "tools/call": on_call_tool, + "logging/setLevel": on_set_logging_level, + "completion/complete": on_completion, + }.items() + if handler is not None + } + ) + + self._notification_handlers.update( + { + method: handler + for method, handler in { + "notifications/roots/list_changed": on_roots_list_changed, + "notifications/progress": on_progress, + }.items() + if handler is not None + } + ) + + def _add_request_handler( + self, + method: str, + handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]], + ) -> None: + """Add a request handler, silently replacing any existing handler for the same method.""" + self._request_handlers[method] = handler + + def _has_handler(self, method: str) -> bool: + """Check if a handler is registered for the given method.""" + return method in self._request_handlers or method in self._notification_handlers + + # TODO: Rethink capabilities API. Currently capabilities are derived from registered + # handlers but require NotificationOptions to be passed externally for list_changed + # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities + # entirely from server state (e.g. constructor params for list_changed) instead of + # requiring callers to assemble them at create_initialization_options() time. def create_initialization_options( self, notification_options: NotificationOptions | None = None, @@ -214,25 +293,26 @@ def get_capabilities( completions_capability = None # Set prompt capabilities if handler exists - if types.ListPromptsRequest in self.request_handlers: + if "prompts/list" in self._request_handlers: prompts_capability = types.PromptsCapability(list_changed=notification_options.prompts_changed) # Set resource capabilities if handler exists - if types.ListResourcesRequest in self.request_handlers: + if "resources/list" in self._request_handlers: resources_capability = types.ResourcesCapability( - subscribe=False, list_changed=notification_options.resources_changed + subscribe="resources/subscribe" in self._request_handlers, + list_changed=notification_options.resources_changed, ) # Set tool capabilities if handler exists - if types.ListToolsRequest in self.request_handlers: + if "tools/list" in self._request_handlers: tools_capability = types.ToolsCapability(list_changed=notification_options.tools_changed) # Set logging capabilities if handler exists - if types.SetLevelRequest in self.request_handlers: + if "logging/setLevel" in self._request_handlers: logging_capability = types.LoggingCapability() # Set completions capabilities if handler exists - if types.CompleteRequest in self.request_handlers: + if "completion/complete" in self._request_handlers: completions_capability = types.CompletionsCapability() capabilities = types.ServerCapabilities( @@ -248,12 +328,7 @@ def get_capabilities( return capabilities @property - def request_context(self) -> ServerRequestContext[LifespanResultT, RequestT]: - """If called outside of a request context, this will raise a LookupError.""" - return request_ctx.get() - - @property - def experimental(self) -> ExperimentalHandlers: + def experimental(self) -> ExperimentalHandlers[LifespanResultT]: """Experimental APIs for tasks and other features. WARNING: These APIs are experimental and may change without notice. @@ -261,7 +336,10 @@ def experimental(self) -> ExperimentalHandlers: # We create this inline so we only add these capabilities _if_ they're actually used if self._experimental_handlers is None: - self._experimental_handlers = ExperimentalHandlers(self, self.request_handlers, self.notification_handlers) + self._experimental_handlers = ExperimentalHandlers( + add_request_handler=self._add_request_handler, + has_handler=self._has_handler, + ) return self._experimental_handlers @property @@ -278,374 +356,6 @@ def session_manager(self) -> StreamableHTTPSessionManager: ) return self._session_manager # pragma: no cover - def list_prompts(self): - def decorator( - func: Callable[[], Awaitable[list[types.Prompt]]] - | Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]], - ): - logger.debug("Registering handler for PromptListRequest") - - wrapper = create_call_wrapper(func, types.ListPromptsRequest) - - async def handler(req: types.ListPromptsRequest): - result = await wrapper(req) - # Handle both old style (list[Prompt]) and new style (ListPromptsResult) - if isinstance(result, types.ListPromptsResult): - return result - else: - # Old style returns list[Prompt] - return types.ListPromptsResult(prompts=result) - - self.request_handlers[types.ListPromptsRequest] = handler - return func - - return decorator - - def get_prompt(self): - def decorator( - func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]], - ): - logger.debug("Registering handler for GetPromptRequest") - - async def handler(req: types.GetPromptRequest): - prompt_get = await func(req.params.name, req.params.arguments) - return prompt_get - - self.request_handlers[types.GetPromptRequest] = handler - return func - - return decorator - - def list_resources(self): - def decorator( - func: Callable[[], Awaitable[list[types.Resource]]] - | Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]], - ): - logger.debug("Registering handler for ListResourcesRequest") - - wrapper = create_call_wrapper(func, types.ListResourcesRequest) - - async def handler(req: types.ListResourcesRequest): - result = await wrapper(req) - # Handle both old style (list[Resource]) and new style (ListResourcesResult) - if isinstance(result, types.ListResourcesResult): - return result - else: - # Old style returns list[Resource] - return types.ListResourcesResult(resources=result) - - self.request_handlers[types.ListResourcesRequest] = handler - return func - - return decorator - - def list_resource_templates(self): - def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): - logger.debug("Registering handler for ListResourceTemplatesRequest") - - async def handler(_: Any): - templates = await func() - return types.ListResourceTemplatesResult(resource_templates=templates) - - self.request_handlers[types.ListResourceTemplatesRequest] = handler - return func - - return decorator - - def read_resource(self): - def decorator( - func: Callable[[str], Awaitable[str | bytes | Iterable[ReadResourceContents]]], - ): - logger.debug("Registering handler for ReadResourceRequest") - - async def handler(req: types.ReadResourceRequest): - result = await func(req.params.uri) - - def create_content(data: str | bytes, mime_type: str | None, meta: dict[str, Any] | None = None): - # Note: ResourceContents uses Field(alias="_meta"), so we must use the alias key - meta_kwargs: dict[str, Any] = {"_meta": meta} if meta is not None else {} - match data: - case str() as data: - return types.TextResourceContents( - uri=req.params.uri, - text=data, - mime_type=mime_type or "text/plain", - **meta_kwargs, - ) - case bytes() as data: # pragma: no branch - return types.BlobResourceContents( - uri=req.params.uri, - blob=base64.b64encode(data).decode(), - mime_type=mime_type or "application/octet-stream", - **meta_kwargs, - ) - - match result: - case str() | bytes() as data: # pragma: lax no cover - warnings.warn( - "Returning str or bytes from read_resource is deprecated. " - "Use Iterable[ReadResourceContents] instead.", - DeprecationWarning, - stacklevel=2, - ) - content = create_content(data, None) - case Iterable() as contents: - contents_list = [ - create_content( - content_item.content, content_item.mime_type, getattr(content_item, "meta", None) - ) - for content_item in contents - ] - return types.ReadResourceResult(contents=contents_list) - case _: # pragma: no cover - raise ValueError(f"Unexpected return type from read_resource: {type(result)}") - - return types.ReadResourceResult(contents=[content]) # pragma: no cover - - self.request_handlers[types.ReadResourceRequest] = handler - return func - - return decorator - - def set_logging_level(self): - def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): - logger.debug("Registering handler for SetLevelRequest") - - async def handler(req: types.SetLevelRequest): - await func(req.params.level) - return types.EmptyResult() - - self.request_handlers[types.SetLevelRequest] = handler - return func - - return decorator - - def subscribe_resource(self): - def decorator(func: Callable[[str], Awaitable[None]]): - logger.debug("Registering handler for SubscribeRequest") - - async def handler(req: types.SubscribeRequest): - await func(req.params.uri) - return types.EmptyResult() - - self.request_handlers[types.SubscribeRequest] = handler - return func - - return decorator - - def unsubscribe_resource(self): - def decorator(func: Callable[[str], Awaitable[None]]): - logger.debug("Registering handler for UnsubscribeRequest") - - async def handler(req: types.UnsubscribeRequest): - await func(req.params.uri) - return types.EmptyResult() - - self.request_handlers[types.UnsubscribeRequest] = handler - return func - - return decorator - - def list_tools(self): - def decorator( - func: Callable[[], Awaitable[list[types.Tool]]] - | Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], - ): - logger.debug("Registering handler for ListToolsRequest") - - wrapper = create_call_wrapper(func, types.ListToolsRequest) - - async def handler(req: types.ListToolsRequest): - result = await wrapper(req) - - # Handle both old style (list[Tool]) and new style (ListToolsResult) - if isinstance(result, types.ListToolsResult): - # Refresh the tool cache with returned tools - for tool in result.tools: - validate_and_warn_tool_name(tool.name) - self._tool_cache[tool.name] = tool - return result - else: - # Old style returns list[Tool] - # Clear and refresh the entire tool cache - self._tool_cache.clear() - for tool in result: - validate_and_warn_tool_name(tool.name) - self._tool_cache[tool.name] = tool - return types.ListToolsResult(tools=result) - - self.request_handlers[types.ListToolsRequest] = handler - return func - - return decorator - - def _make_error_result(self, error_message: str) -> types.CallToolResult: - """Create a CallToolResult with an error.""" - return types.CallToolResult( - content=[types.TextContent(type="text", text=error_message)], - is_error=True, - ) - - async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None: - """Get tool definition from cache, refreshing if necessary. - - Returns the Tool object if found, None otherwise. - """ - if tool_name not in self._tool_cache: - if types.ListToolsRequest in self.request_handlers: - logger.debug("Tool cache miss for %s, refreshing cache", tool_name) - await self.request_handlers[types.ListToolsRequest](None) - - tool = self._tool_cache.get(tool_name) - if tool is None: - logger.warning("Tool '%s' not listed, no validation will be performed", tool_name) - - return tool - - def call_tool(self, *, validate_input: bool = True): - """Register a tool call handler. - - Args: - validate_input: If True, validates input against inputSchema. Default is True. - - The handler validates input against inputSchema (if validate_input=True), calls the tool function, - and builds a CallToolResult with the results: - - Unstructured content (iterable of ContentBlock): returned in content - - Structured content (dict): returned in structuredContent, serialized JSON text returned in content - - Both: returned in content and structuredContent - - If outputSchema is defined, validates structuredContent or errors if missing. - """ - - def decorator( - func: Callable[ - [str, dict[str, Any]], - Awaitable[ - UnstructuredContent - | StructuredContent - | CombinationContent - | types.CallToolResult - | types.CreateTaskResult - ], - ], - ): - logger.debug("Registering handler for CallToolRequest") - - async def handler(req: types.CallToolRequest): - try: - tool_name = req.params.name - arguments = req.params.arguments or {} - tool = await self._get_cached_tool_definition(tool_name) - - # input validation - if validate_input and tool: - try: - jsonschema.validate(instance=arguments, schema=tool.input_schema) - except jsonschema.ValidationError as e: - return self._make_error_result(f"Input validation error: {e.message}") - - # tool call - results = await func(tool_name, arguments) - - # output normalization - unstructured_content: UnstructuredContent - maybe_structured_content: StructuredContent | None - if isinstance(results, types.CallToolResult): - return results - elif isinstance(results, types.CreateTaskResult): - # Task-augmented execution returns task info instead of result - return results - elif isinstance(results, tuple) and len(results) == 2: - # tool returned both structured and unstructured content - unstructured_content, maybe_structured_content = cast(CombinationContent, results) - elif isinstance(results, dict): - # tool returned structured content only - maybe_structured_content = cast(StructuredContent, results) - unstructured_content = [types.TextContent(type="text", text=json.dumps(results, indent=2))] - elif hasattr(results, "__iter__"): - # tool returned unstructured content only - unstructured_content = cast(UnstructuredContent, results) - maybe_structured_content = None - else: # pragma: no cover - return self._make_error_result(f"Unexpected return type from tool: {type(results).__name__}") - - # output validation - if tool and tool.output_schema is not None: - if maybe_structured_content is None: - return self._make_error_result( - "Output validation error: outputSchema defined but no structured output returned" - ) - else: - try: - jsonschema.validate(instance=maybe_structured_content, schema=tool.output_schema) - except jsonschema.ValidationError as e: - return self._make_error_result(f"Output validation error: {e.message}") - - # result - return types.CallToolResult( - content=list(unstructured_content), - structured_content=maybe_structured_content, - is_error=False, - ) - except UrlElicitationRequiredError: - # Re-raise UrlElicitationRequiredError so it can be properly handled - # by _handle_request, which converts it to an error response with code -32042 - raise - except Exception as e: - return self._make_error_result(str(e)) - - self.request_handlers[types.CallToolRequest] = handler - return func - - return decorator - - def progress_notification(self): - def decorator( - func: Callable[[str | int, float, float | None, str | None], Awaitable[None]], - ): - logger.debug("Registering handler for ProgressNotification") - - async def handler(req: types.ProgressNotification): - await func( - req.params.progress_token, - req.params.progress, - req.params.total, - req.params.message, - ) - - self.notification_handlers[types.ProgressNotification] = handler - return func - - return decorator - - def completion(self): - """Provides completions for prompts and resource templates""" - - def decorator( - func: Callable[ - [ - types.PromptReference | types.ResourceTemplateReference, - types.CompletionArgument, - types.CompletionContext | None, - ], - Awaitable[types.Completion | None], - ], - ): - logger.debug("Registering handler for CompleteRequest") - - async def handler(req: types.CompleteRequest): - completion = await func(req.params.ref, req.params.argument, req.params.context) - return types.CompleteResult( - completion=completion - if completion is not None - else types.Completion(values=[], total=None, has_more=None), - ) - - self.request_handlers[types.CompleteRequest] = handler - return func - - return decorator - async def run( self, read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], @@ -715,7 +425,7 @@ async def _handle_message( if raise_exceptions: raise message case _: - await self._handle_notification(message) + await self._handle_notification(message, session, lifespan_context) for warning in w: # pragma: lax no cover logger.info("Warning: %s: %s", warning.category.__name__, warning.message) @@ -730,10 +440,9 @@ async def _handle_request( ): logger.info("Processing request of type %s", type(req).__name__) - if handler := self.request_handlers.get(type(req)): + if handler := self._request_handlers.get(req.method): logger.debug("Dispatching request of type %s", type(req).__name__) - token = None try: # Extract request context and close_sse_stream from message metadata request_data = None @@ -744,32 +453,32 @@ async def _handle_request( close_sse_stream_cb = message.message_metadata.close_sse_stream close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream - # Set our global state that can be retrieved via - # app.get_request_context() client_capabilities = session.client_params.capabilities if session.client_params else None task_support = self._experimental_handlers.task_support if self._experimental_handlers else None # Get task metadata from request params if present task_metadata = None if hasattr(req, "params") and req.params is not None: task_metadata = getattr(req.params, "task", None) - token = request_ctx.set( - ServerRequestContext( - request_id=message.request_id, - meta=message.request_meta, - session=session, - lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=task_metadata, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), - request=request_data, - close_sse_stream=close_sse_stream_cb, - close_standalone_sse_stream=close_standalone_sse_stream_cb, - ) + ctx = ServerRequestContext( + request_id=message.request_id, + meta=message.request_meta, + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=task_metadata, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + request=request_data, + close_sse_stream=close_sse_stream_cb, + close_standalone_sse_stream=close_standalone_sse_stream_cb, ) - response = await handler(req) + token = request_ctx.set(ctx) + try: + response = await handler(ctx, req.params) + finally: + request_ctx.reset(token) except MCPError as err: response = err.error except anyio.get_cancelled_exc_class(): @@ -778,11 +487,7 @@ async def _handle_request( except Exception as err: if raise_exceptions: # pragma: no cover raise err - response = types.ErrorData(code=0, message=str(err), data=None) - finally: - # Reset the global state after we are done - if token is not None: # pragma: no branch - request_ctx.reset(token) + response = types.ErrorData(code=0, message=str(err)) await message.respond(response) else: # pragma: no cover @@ -790,12 +495,29 @@ async def _handle_request( logger.debug("Response sent") - async def _handle_notification(self, notify: Any): - if handler := self.notification_handlers.get(type(notify)): # type: ignore + async def _handle_notification( + self, + notify: types.ClientNotification, + session: ServerSession, + lifespan_context: LifespanResultT, + ) -> None: + if handler := self._notification_handlers.get(notify.method): logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - await handler(notify) + client_capabilities = session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + ctx = ServerRequestContext( + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=None, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + ) + await handler(ctx, notify.params) except Exception: # pragma: no cover logger.exception("Uncaught exception in notification handler") @@ -910,7 +632,3 @@ def streamable_http_app( middleware=middleware, lifespan=lambda app: session_manager.run(), ) - - -async def _ping_handler(request: types.PingRequest) -> types.ServerResult: - return types.EmptyResult() diff --git a/src/mcp/server/mcpserver/resources/types.py b/src/mcp/server/mcpserver/resources/types.py index 64e633806..42aecd6e3 100644 --- a/src/mcp/server/mcpserver/resources/types.py +++ b/src/mcp/server/mcpserver/resources/types.py @@ -109,7 +109,7 @@ def from_function( class FileResource(Resource): """A resource that reads from a file. - Set is_binary=True to read file as binary data instead of text. + Set is_binary=True to read the file as binary data instead of text. """ path: Path = Field(description="Path to the file") diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 8c1fc342b..9c7105a7b 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -2,7 +2,9 @@ from __future__ import annotations +import base64 import inspect +import json import re from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager @@ -29,7 +31,7 @@ from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, UrlElicitationResult, elicit_with_validation from mcp.server.elicitation import elicit_url as _elicit_url from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.server import LifespanResultT, Server +from mcp.server.lowlevel.server import LifespanResultT, Server, request_ctx from mcp.server.lowlevel.server import lifespan as default_lifespan from mcp.server.mcpserver.exceptions import ResourceError from mcp.server.mcpserver.prompts import Prompt, PromptManager @@ -42,7 +44,30 @@ from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Annotations, ContentBlock, GetPromptResult, Icon, ToolAnnotations +from mcp.shared.exceptions import MCPError +from mcp.types import ( + Annotations, + BlobResourceContents, + CallToolRequestParams, + CallToolResult, + CompleteRequestParams, + CompleteResult, + Completion, + ContentBlock, + GetPromptRequestParams, + GetPromptResult, + Icon, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ListToolsResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextContent, + TextResourceContents, + ToolAnnotations, +) from mcp.types import Prompt as MCPPrompt from mcp.types import PromptArgument as MCPPromptArgument from mcp.types import Resource as MCPResource @@ -83,7 +108,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): warn_on_duplicate_prompts: bool lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None - """A async context manager that will be called when the server is started.""" + """An async context manager that will be called when the server is started.""" auth: AuthSettings | None @@ -91,9 +116,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]): def lifespan_wrapper( app: MCPServer[LifespanResultT], lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[[Server[LifespanResultT, Request]], AbstractAsyncContextManager[LifespanResultT]]: +) -> Callable[[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]]: @asynccontextmanager - async def wrap(_: Server[LifespanResultT, Request]) -> AsyncIterator[LifespanResultT]: + async def wrap(_: Server[LifespanResultT]) -> AsyncIterator[LifespanResultT]: async with lifespan(app) as context: yield context @@ -132,6 +157,9 @@ def __init__( auth=auth, ) + self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) + self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) + self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) self._lowlevel_server = Server( name=name or "mcp-server", title=title, @@ -140,13 +168,17 @@ def __init__( website_url=website_url, icons=icons, version=version, + on_list_tools=self._handle_list_tools, + on_call_tool=self._handle_call_tool, + on_list_resources=self._handle_list_resources, + on_read_resource=self._handle_read_resource, + on_list_resource_templates=self._handle_list_resource_templates, + on_list_prompts=self._handle_list_prompts, + on_get_prompt=self._handle_get_prompt, # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an MCPServer and Server. # We need to create a Lifespan type that is a generic on the server type, like Starlette does. lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore ) - self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) - self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) - self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) # Validate auth configuration if self.settings.auth is not None: if auth_server_provider and token_verifier: # pragma: no cover @@ -164,9 +196,6 @@ def __init__( self._token_verifier = ProviderTokenVerifier(auth_server_provider) self._custom_starlette_routes: list[Route] = [] - # Set up MCP protocol handlers - self._setup_handlers() - # Configure logging configure_logging(self.settings.log_level) @@ -263,18 +292,83 @@ def run( case "streamable-http": # pragma: no cover anyio.run(lambda: self.run_streamable_http_async(**kwargs)) - def _setup_handlers(self) -> None: - """Set up core MCP protocol handlers.""" - self._lowlevel_server.list_tools()(self.list_tools) - # Note: we disable the lowlevel server's input validation. - # MCPServer does ad hoc conversion of incoming data before validating - - # for now we preserve this for backwards compatibility. - self._lowlevel_server.call_tool(validate_input=False)(self.call_tool) - self._lowlevel_server.list_resources()(self.list_resources) - self._lowlevel_server.read_resource()(self.read_resource) - self._lowlevel_server.list_prompts()(self.list_prompts) - self._lowlevel_server.get_prompt()(self.get_prompt) - self._lowlevel_server.list_resource_templates()(self.list_resource_templates) + async def _handle_list_tools( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListToolsResult: + return ListToolsResult(tools=await self.list_tools()) + + async def _handle_call_tool( + self, ctx: ServerRequestContext[LifespanResultT], params: CallToolRequestParams + ) -> CallToolResult: + try: + result = await self.call_tool(params.name, params.arguments or {}) + except MCPError: + raise + except Exception as e: + return CallToolResult(content=[TextContent(type="text", text=str(e))], is_error=True) + if isinstance(result, CallToolResult): + return result + if isinstance(result, tuple) and len(result) == 2: + unstructured_content, structured_content = result + return CallToolResult( + content=list(unstructured_content), # type: ignore[arg-type] + structured_content=structured_content, # type: ignore[arg-type] + ) + if isinstance(result, dict): # pragma: no cover + # TODO: this code path is unreachable — convert_result never returns a raw dict. + # The call_tool return type (Sequence[ContentBlock] | dict[str, Any]) is wrong + # and needs to be cleaned up. + return CallToolResult( + content=[TextContent(type="text", text=json.dumps(result, indent=2))], + structured_content=result, + ) + return CallToolResult(content=list(result)) + + async def _handle_list_resources( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=await self.list_resources()) + + async def _handle_read_resource( + self, ctx: ServerRequestContext[LifespanResultT], params: ReadResourceRequestParams + ) -> ReadResourceResult: + results = await self.read_resource(params.uri) + contents: list[TextResourceContents | BlobResourceContents] = [] + for item in results: + if isinstance(item.content, bytes): + contents.append( + BlobResourceContents( + uri=params.uri, + blob=base64.b64encode(item.content).decode(), + mime_type=item.mime_type or "application/octet-stream", + _meta=item.meta, + ) + ) + else: + contents.append( + TextResourceContents( + uri=params.uri, + text=item.content, + mime_type=item.mime_type or "text/plain", + _meta=item.meta, + ) + ) + return ReadResourceResult(contents=contents) + + async def _handle_list_resource_templates( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + return ListResourceTemplatesResult(resource_templates=await self.list_resource_templates()) + + async def _handle_list_prompts( + self, ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None + ) -> ListPromptsResult: + return ListPromptsResult(prompts=await self.list_prompts()) + + async def _handle_get_prompt( + self, ctx: ServerRequestContext[LifespanResultT], params: GetPromptRequestParams + ) -> GetPromptResult: + return await self.get_prompt(params.name, params.arguments) async def list_tools(self) -> list[MCPTool]: """List all available tools.""" @@ -294,11 +388,13 @@ async def list_tools(self) -> list[MCPTool]: ] def get_context(self) -> Context[LifespanResultT, Request]: - """Returns a Context object. Note that the context will only be valid - during a request; outside a request, most methods will error. + """Return a Context object. + + Note that the context will only be valid during a request; outside a + request, most methods will error. """ try: - request_context = self._lowlevel_server.request_context + request_context = request_ctx.get() except LookupError: request_context = None return Context(request_context=request_context, mcp_server=self) @@ -381,6 +477,8 @@ def add_tool( title: Optional human-readable title for the tool description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information + icons: Optional list of icons for the tool + meta: Optional metadata dictionary for the tool structured_output: Controls whether the tool's output is structured or unstructured - If None, auto-detects based on the function's return type annotation - If True, creates a structured tool (return type annotation permitting) @@ -429,25 +527,33 @@ def tool( title: Optional human-readable title for the tool description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information + icons: Optional list of icons for the tool + meta: Optional metadata dictionary for the tool structured_output: Controls whether the tool's output is structured or unstructured - If None, auto-detects based on the function's return type annotation - If True, creates a structured tool (return type annotation permitting) - If False, unconditionally creates an unstructured tool Example: + ```python @server.tool() def my_tool(x: int) -> str: return str(x) + ``` + ```python @server.tool() - def tool_with_context(x: int, ctx: Context) -> str: - ctx.info(f"Processing {x}") + async def tool_with_context(x: int, ctx: Context) -> str: + await ctx.info(f"Processing {x}") return str(x) + ``` + ```python @server.tool() async def async_tool(x: int, context: Context) -> str: await context.report_progress(50, 100) return str(x) + ``` """ # Check if user passed function directly instead of calling decorator if callable(name): @@ -479,14 +585,33 @@ def completion(self): - context: Optional CompletionContext with previously resolved arguments Example: + ```python @mcp.completion() async def handle_completion(ref, argument, context): if isinstance(ref, ResourceTemplateReference): # Return completions based on ref, argument, and context return Completion(values=["option1", "option2"]) return None + ``` """ - return self._lowlevel_server.completion() + + def decorator(func: _CallableT) -> _CallableT: + async def handler( + ctx: ServerRequestContext[LifespanResultT], params: CompleteRequestParams + ) -> CompleteResult: + result = await func(params.ref, params.argument, params.context) + return CompleteResult( + completion=result if result is not None else Completion(values=[], total=None, has_more=None), + ) + + # TODO(maxisbey): remove private access — completion needs post-construction + # handler registration, find a better pattern for this + self._lowlevel_server._add_request_handler( # pyright: ignore[reportPrivateUsage] + "completion/complete", handler + ) + return func + + return decorator def add_resource(self, resource: Resource) -> None: """Add a resource to the server. @@ -525,15 +650,18 @@ def resource( title: Optional human-readable title for the resource description: Optional description of the resource mime_type: Optional MIME type for the resource + icons: Optional list of icons for the resource + annotations: Optional annotations for the resource meta: Optional metadata dictionary for the resource Example: + ```python @server.resource("resource://my-resource") def get_data() -> str: return "Hello, world!" @server.resource("resource://my-resource") - async get_data() -> str: + async def get_data() -> str: data = await fetch_data() return f"Hello, world! {data}" @@ -545,6 +673,7 @@ def get_weather(city: str) -> str: async def get_weather(city: str) -> str: data = await fetch_weather(city) return f"Weather for {city}: {data}" + ``` """ # Check if user passed function directly instead of calling decorator if callable(uri): @@ -625,8 +754,10 @@ def prompt( name: Optional name for the prompt (defaults to function name) title: Optional human-readable title for the prompt description: Optional description of what the prompt does + icons: Optional list of icons for the prompt Example: + ```python @server.prompt() def analyze_table(table_name: str) -> list[Message]: schema = read_table_schema(table_name) @@ -652,6 +783,7 @@ async def analyze_file(path: str) -> list[Message]: } } ] + ``` """ # Check if user passed function directly instead of calling decorator if callable(name): @@ -693,9 +825,11 @@ def custom_route( include_in_schema: Whether to include in OpenAPI schema, defaults to True Example: + ```python @server.custom_route("/health", methods=["GET"]) async def health_check(request: Request) -> Response: return JSONResponse({"status": "ok"}) + ``` """ def decorator( # pragma: no cover @@ -981,18 +1115,18 @@ class Context(BaseModel, Generic[LifespanContextT, RequestT]): ```python @server.tool() - def my_tool(x: int, ctx: Context) -> str: + async def my_tool(x: int, ctx: Context) -> str: # Log messages to the client - ctx.info(f"Processing {x}") - ctx.debug("Debug info") - ctx.warning("Warning message") - ctx.error("Error message") + await ctx.info(f"Processing {x}") + await ctx.debug("Debug info") + await ctx.warning("Warning message") + await ctx.error("Error message") # Report progress - ctx.report_progress(50, 100) + await ctx.report_progress(50, 100) # Access resources - data = ctx.read_resource("resource://data") + data = await ctx.read_resource("resource://data") # Get request info request_id = ctx.request_id @@ -1038,9 +1172,9 @@ async def report_progress(self, progress: float, total: float | None = None, mes """Report progress for the current operation. Args: - progress: Current progress value e.g. 24 - total: Optional total value e.g. 100 - message: Optional message e.g. Starting render... + progress: Current progress value (e.g., 24) + total: Optional total value (e.g., 100) + message: Optional message (e.g., "Starting render...") """ progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None @@ -1052,6 +1186,7 @@ async def report_progress(self, progress: float, total: float | None = None, mes progress=progress, total=total, message=message, + related_request_id=self.request_id, ) async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: @@ -1075,15 +1210,14 @@ async def elicit( This method can be used to interactively ask for additional information from the client within a tool's execution. The client might display the message to the - user and collect a response according to the provided schema. Or in case a - client is an agent, it might decide how to handle the elicitation -- either by asking + user and collect a response according to the provided schema. If the client + is an agent, it might decide how to handle the elicitation -- either by asking the user or automatically generating a response. Args: - schema: A Pydantic model class defining the expected response structure, according to the specification, - only primitive types are allowed. - message: Optional message to present to the user. If not provided, will use - a default message based on the schema + message: Message to present to the user + schema: A Pydantic model class defining the expected response structure. + According to the specification, only primitive types are allowed. Returns: An ElicitationResult containing the action taken and the data if accepted @@ -1117,7 +1251,7 @@ async def elicit_url( The response indicates whether the user consented to navigate to the URL. The actual interaction happens out-of-band. When the elicitation completes, - call `self.session.send_elicit_complete(elicitation_id)` to notify the client. + call `ctx.session.send_elicit_complete(elicitation_id)` to notify the client. Args: message: Human-readable explanation of why the interaction is needed @@ -1187,7 +1321,7 @@ async def close_sse_stream(self) -> None: be replayed when the client reconnects with Last-Event-ID. Use this to implement polling behavior during long-running operations - - client will reconnect after the retry interval specified in the priming event. + the client will reconnect after the retry interval specified in the priming event. Note: This is a no-op if not using StreamableHTTP transport with event_store. diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index 4b539ce1f..062b47d0f 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -46,7 +46,7 @@ class ArgModelBase(BaseModel): def model_dump_one_level(self) -> dict[str, Any]: """Return a dict of the model's fields, one level deep. - That is, sub-models etc are not dumped - they are kept as pydantic models. + That is, sub-models etc are not dumped - they are kept as Pydantic models. """ kwargs: dict[str, Any] = {} for field_name, field_info in self.__class__.model_fields.items(): @@ -89,8 +89,7 @@ async def call_fn_with_arg_validation( return await anyio.to_thread.run_sync(functools.partial(fn, **arguments_parsed_dict)) def convert_result(self, result: Any) -> Any: - """Convert the result of a function call to the appropriate format for - the lowlevel server tool call handler: + """Convert a function call result to the format for the lowlevel tool call handler. - If output_model is None, return the unstructured content directly. - If output_model is not None, convert the result to structured output format @@ -126,11 +125,11 @@ def convert_result(self, result: Any) -> Any: def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: """Pre-parse data from JSON. - Return a dict with same keys as input but with values parsed from JSON + Return a dict with the same keys as input but with values parsed from JSON if appropriate. This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside - a string rather than an actual list. Claude desktop is prone to this - in fact + a string rather than an actual list. Claude Desktop is prone to this - in fact it seems incapable of NOT doing this. For sub-models, it tends to pass dicts (JSON objects) as JSON strings, which can be pre-parsed here. """ @@ -173,8 +172,7 @@ def func_metadata( skip_names: Sequence[str] = (), structured_output: bool | None = None, ) -> FuncMetadata: - """Given a function, return metadata including a pydantic model representing its - signature. + """Given a function, return metadata including a Pydantic model representing its signature. The use case for this is ``` @@ -183,11 +181,11 @@ def func_metadata( return func(**validated_args.model_dump_one_level()) ``` - **critically** it also provides pre-parse helper to attempt to parse things from + **critically** it also provides a pre-parse helper to attempt to parse things from JSON. Args: - func: The function to convert to a pydantic model + func: The function to convert to a Pydantic model skip_names: A list of parameter names to skip. These will not be included in the model. structured_output: Controls whether the tool's output is structured or unstructured @@ -195,8 +193,8 @@ def func_metadata( - If True, creates a structured tool (return type annotation permitting) - If False, unconditionally creates an unstructured tool - If structured, creates a Pydantic model for the function's result based on its annotation. - Supports various return types: + If structured, creates a Pydantic model for the function's result based on its annotation. + Supports various return types: - BaseModel subclasses (used directly) - Primitive types (str, int, float, bool, bytes, None) - wrapped in a model with a 'result' field @@ -206,9 +204,9 @@ def func_metadata( Returns: A FuncMetadata object containing: - - arg_model: A pydantic model representing the function's arguments - - output_model: A pydantic model for the return type if output is structured - - output_conversion: Records how function output should be converted before returning. + - arg_model: A Pydantic model representing the function's arguments + - output_model: A Pydantic model for the return type if the output is structured + - wrap_output: Whether the function result needs to be wrapped in `{"result": ...}` for structured output. """ try: sig = inspect.signature(func, eval_str=True) @@ -296,7 +294,7 @@ def func_metadata( ] # pragma: no cover else: # We only had `Annotated[CallToolResult, ReturnType]`, treat the original annotation - # as beging `ReturnType`: + # as being `ReturnType`: original_annotation = return_type_expr else: return FuncMetadata(arg_model=arguments_model) @@ -355,7 +353,7 @@ def _try_create_model_and_schema( if origin is dict: args = get_args(type_expr) if len(args) == 2 and args[0] is str: - # TODO: should we use the original annotation? We are loosing any potential `Annotated` + # TODO: should we use the original annotation? We are losing any potential `Annotated` # metadata for Pydantic here: model = _create_dict_model(func_name, type_expr) else: diff --git a/src/mcp/server/mcpserver/utilities/logging.py b/src/mcp/server/mcpserver/utilities/logging.py index c394f2bfa..04ca38853 100644 --- a/src/mcp/server/mcpserver/utilities/logging.py +++ b/src/mcp/server/mcpserver/utilities/logging.py @@ -8,10 +8,10 @@ def get_logger(name: str) -> logging.Logger: """Get a logger nested under MCP namespace. Args: - name: the name of the logger + name: The name of the logger. Returns: - a configured logger instance + A configured logger instance. """ return logging.getLogger(name) @@ -22,7 +22,7 @@ def configure_logging( """Configure logging for MCP. Args: - level: the log level to use + level: The log level to use. """ handlers: list[logging.Handler] = [] try: diff --git a/src/mcp/server/models.py b/src/mcp/server/models.py index 41b9224c1..3861f42a7 100644 --- a/src/mcp/server/models.py +++ b/src/mcp/server/models.py @@ -1,4 +1,4 @@ -"""This module provides simpler types to use with the server for managing prompts +"""This module provides simplified types to use with the server for managing prompts and tools. """ diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index f496121a3..759d2131a 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -6,30 +6,22 @@ Common usage pattern: ``` - server = Server(name) - - @server.call_tool() - async def handle_tool_call(ctx: RequestContext, arguments: dict[str, Any]) -> Any: + async def handle_call_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: # Check client capabilities before proceeding if ctx.session.check_client_capability( types.ClientCapabilities(experimental={"advanced_tools": dict()}) ): - # Perform advanced tool operations - result = await perform_advanced_tool_operation(arguments) + result = await perform_advanced_tool_operation(params.arguments) else: - # Fall back to basic tool operations - result = await perform_basic_tool_operation(arguments) - + result = await perform_basic_tool_operation(params.arguments) return result - @server.list_prompts() - async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: - # Access session for any necessary checks or operations + async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: if ctx.session.client_params: - # Customize prompts based on client initialization parameters - return generate_custom_prompts(ctx.session.client_params) - else: - return default_prompts + return ListPromptsResult(prompts=generate_custom_prompts(ctx.session.client_params)) + return ListPromptsResult(prompts=default_prompts) + + server = Server(name, on_call_tool=handle_call_tool, on_list_prompts=handle_list_prompts) ``` The ServerSession class is typically used internally by the Server class and should not @@ -371,12 +363,12 @@ async def elicit( """Send a form mode elicitation/create request. Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - related_request_id: Optional ID of the request that triggered this elicitation + message: The message to present to the user. + requested_schema: Schema defining the expected response structure. + related_request_id: Optional ID of the request that triggered this elicitation. Returns: - The client's response + The client's response. Note: This method is deprecated in favor of elicit_form(). It remains for @@ -393,12 +385,12 @@ async def elicit_form( """Send a form mode elicitation/create request. Args: - message: The message to present to the user - requested_schema: Schema defining the expected response structure - related_request_id: Optional ID of the request that triggered this elicitation + message: The message to present to the user. + requested_schema: Schema defining the expected response structure. + related_request_id: Optional ID of the request that triggered this elicitation. Returns: - The client's response with form data + The client's response with form data. Raises: StatelessModeNotSupported: If called in stateless HTTP mode. @@ -429,13 +421,13 @@ async def elicit_url( like OAuth flows, credential collection, or payment processing. Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - related_request_id: Optional ID of the request that triggered this elicitation + message: Human-readable explanation of why the interaction is needed. + url: The URL the user should navigate to. + elicitation_id: Unique identifier for tracking this elicitation. + related_request_id: Optional ID of the request that triggered this elicitation. Returns: - The client's response indicating acceptance, decline, or cancellation + The client's response indicating acceptance, decline, or cancellation. Raises: StatelessModeNotSupported: If called in stateless HTTP mode. @@ -507,7 +499,7 @@ async def send_elicit_complete( Args: elicitation_id: The unique identifier of the completed elicitation - related_request_id: Optional ID of the request that triggered this + related_request_id: Optional ID of the request that triggered this notification """ await self.send_notification( types.ElicitCompleteNotification( diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 5be6b78ca..9007230ce 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -2,8 +2,8 @@ This module implements a Server-Sent Events (SSE) transport layer for MCP servers. -Example usage: -``` +Example: + ```python # Create an SSE transport at an endpoint sse = SseServerTransport("/messages/") @@ -27,10 +27,10 @@ async def handle_sse(request): # Create and run Starlette app starlette_app = Starlette(routes=routes) uvicorn.run(starlette_app, host="127.0.0.1", port=port) -``` + ``` -Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType' -object is not callable" error when client disconnects. The example above returns +Note: The handle_sse function must return a Response to avoid a +"TypeError: 'NoneType' object is not callable" error when client disconnects. The example above returns an empty Response() after the SSE connection ends to fix this. See SseServerTransport class documentation for more details. @@ -61,8 +61,8 @@ async def handle_sse(request): class SseServerTransport: - """SSE server transport for MCP. This class provides _two_ ASGI applications, - suitable to be used with a framework like Starlette and a server like Hypercorn: + """SSE server transport for MCP. This class provides two ASGI applications, + suitable for use with a framework like Starlette and a server like Hypercorn: 1. connect_sse() is an ASGI application which receives incoming GET requests, and sets up a new SSE stream to send server messages to the client. @@ -170,7 +170,7 @@ async def sse_writer(): await sse_stream_writer.send( { "event": "message", - "data": session_message.message.model_dump_json(by_alias=True, exclude_none=True), + "data": session_message.message.model_dump_json(by_alias=True, exclude_unset=True), } ) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 7f3aa2ac2..e526bab56 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -4,8 +4,8 @@ that can be used to communicate with an MCP client through standard input/output streams. -Example usage: -``` +Example: + ```python async def run_server(): async with stdio_server() as (read_stream, write_stream): # read_stream contains incoming JSONRPCMessages from stdin @@ -14,7 +14,7 @@ async def run_server(): await server.run(read_stream, write_stream, init_options) anyio.run(run_server) -``` + ``` """ import sys @@ -71,7 +71,7 @@ async def stdout_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) + json = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) await stdout.write(json + "\n") await stdout.flush() except anyio.ClosedResourceError: # pragma: no cover diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 54ac7374a..04aed345e 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -89,7 +89,7 @@ async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) message: The JSON-RPC message to store, or None for priming events Returns: - The generated event ID for the stored event + The generated event ID for the stored event. """ pass # pragma: no cover @@ -106,7 +106,7 @@ async def replay_events_after( send_callback: A callback function to send events to the client Returns: - The stream ID of the replayed events + The stream ID of the replayed events, or None if no events were found. """ pass # pragma: no cover @@ -169,6 +169,8 @@ def __init__( ] = {} self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {} self._terminated = False + # Idle timeout cancel scope; managed by the session manager. + self.idle_scope: anyio.CancelScope | None = None @property def is_terminated(self) -> bool: @@ -183,7 +185,7 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover be replayed when the client reconnects with Last-Event-ID. Use this to implement polling behavior during long-running operations - - client will reconnect after the retry interval specified in the priming event. + the client will reconnect after the retry interval specified in the priming event. Args: request_id: The request ID whose SSE stream should be closed. @@ -211,7 +213,7 @@ def close_standalone_sse_stream(self) -> None: # pragma: no cover with Last-Event-ID to resume receiving notifications. Use this to implement polling behavior for the notification stream - - client will reconnect after the retry interval specified in the priming event. + the client will reconnect after the retry interval specified in the priming event. Note: This is a no-op if there is no active standalone SSE stream. @@ -298,12 +300,12 @@ def _create_error_response( # Return a properly formatted JSON error response error_response = JSONRPCError( jsonrpc="2.0", - id="server-error", # We don't have a request ID for general errors + id=None, error=ErrorData(code=error_code, message=error_message), ) return Response( - error_response.model_dump_json(by_alias=True, exclude_none=True), + error_response.model_dump_json(by_alias=True, exclude_unset=True), status_code=status_code, headers=response_headers, ) @@ -314,7 +316,7 @@ def _create_json_response( status_code: HTTPStatus = HTTPStatus.OK, headers: dict[str, str] | None = None, ) -> Response: - """Create a JSON response from a JSONRPCMessage""" + """Create a JSON response from a JSONRPCMessage.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} if headers: # pragma: lax no cover response_headers.update(headers) @@ -323,7 +325,7 @@ def _create_json_response( response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id return Response( - response_message.model_dump_json(by_alias=True, exclude_none=True) if response_message else None, + response_message.model_dump_json(by_alias=True, exclude_unset=True) if response_message else None, status_code=status_code, headers=response_headers, ) @@ -336,7 +338,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: """Create event data dictionary from an EventMessage.""" event_data = { "event": "message", - "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True), + "data": event_message.message.model_dump_json(by_alias=True, exclude_unset=True), } # If an event ID was provided, include it @@ -360,7 +362,7 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None: self._request_streams.pop(request_id, None) async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: - """Application entry point that handles all HTTP requests""" + """Application entry point that handles all HTTP requests.""" request = Request(scope, receive) # Validate request headers for DNS rebinding protection @@ -534,7 +536,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): response_message = event_message.message break - # For notifications and request, keep waiting + # For notifications and requests, keep waiting else: # pragma: no cover logger.debug(f"received: {event_message.message.method}") @@ -858,6 +860,7 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover """Replays events that would have been sent after the specified event ID. + Only used when resumability is enabled. """ event_store = self._event_store @@ -975,12 +978,11 @@ async def message_router(): # Determine which request stream(s) should receive this message message = session_message.message target_request_id = None - # Check if this is a response - if isinstance(message, JSONRPCResponse | JSONRPCError): - response_id = str(message.id) - # If this response is for an existing request stream, - # send it there - target_request_id = response_id + # Check if this is a response with a known request id. + # Null-id errors (e.g., parse errors) fall through to + # the GET stream since they can't be correlated. + if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None: + target_request_id = str(message.id) # Extract related_request_id from meta if it exists elif ( # pragma: no cover session_message.metadata is not None diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index ddc6e5014..50bcd5e79 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -39,6 +39,7 @@ class StreamableHTTPSessionManager: 2. Resumability via an optional event store 3. Connection management and lifecycle 4. Request handling and transport setup + 5. Idle session cleanup via optional timeout Important: Only one StreamableHTTPSessionManager instance should be created per application. The instance cannot be reused after its run() context has @@ -46,33 +47,44 @@ class StreamableHTTPSessionManager: Args: app: The MCP server instance - event_store: Optional event store for resumability support. - If provided, enables resumable connections where clients - can reconnect and receive missed events. - If None, sessions are still tracked but not resumable. + event_store: Optional event store for resumability support. If provided, enables resumable connections + where clients can reconnect and receive missed events. If None, sessions are still tracked but not + resumable. json_response: Whether to use JSON responses instead of SSE streams - stateless: If True, creates a completely fresh transport for each request - with no session tracking or state persistence between requests. + stateless: If True, creates a completely fresh transport for each request with no session tracking or + state persistence between requests. security_settings: Optional transport security settings. - retry_interval: Retry interval in milliseconds to suggest to clients in SSE - retry field. Used for SSE polling behavior. + retry_interval: Retry interval in milliseconds to suggest to clients in SSE retry field. Used for SSE + polling behavior. + session_idle_timeout: Optional idle timeout in seconds for stateful sessions. If set, sessions that + receive no HTTP requests for this duration will be automatically terminated and removed. When + retry_interval is also configured, ensure the idle timeout comfortably exceeds the retry interval to + avoid reaping sessions during normal SSE polling gaps. Default is None (no timeout). A value of 1800 + (30 minutes) is recommended for most deployments. """ def __init__( self, - app: Server[Any, Any], + app: Server[Any], event_store: EventStore | None = None, json_response: bool = False, stateless: bool = False, security_settings: TransportSecuritySettings | None = None, retry_interval: int | None = None, + session_idle_timeout: float | None = None, ): + if session_idle_timeout is not None and session_idle_timeout <= 0: + raise ValueError("session_idle_timeout must be a positive number of seconds") + if stateless and session_idle_timeout is not None: + raise RuntimeError("session_idle_timeout is not supported in stateless mode") + self.app = app self.event_store = event_store self.json_response = json_response self.stateless = stateless self.security_settings = security_settings self.retry_interval = retry_interval + self.session_idle_timeout = session_idle_timeout # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() @@ -184,6 +196,9 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] logger.debug("Session already exists, handling request directly") + # Push back idle deadline on activity + if transport.idle_scope is not None and self.session_idle_timeout is not None: + transport.idle_scope.deadline = anyio.current_time() + self.session_idle_timeout # pragma: no cover await transport.handle_request(scope, receive, send) return @@ -210,16 +225,31 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE read_stream, write_stream = streams task_status.started() try: - await self.app.run( - read_stream, - write_stream, - self.app.create_initialization_options(), - stateless=False, # Stateful mode - ) + # Use a cancel scope for idle timeout — when the + # deadline passes the scope cancels app.run() and + # execution continues after the ``with`` block. + # Incoming requests push the deadline forward. + idle_scope = anyio.CancelScope() + if self.session_idle_timeout is not None: + idle_scope.deadline = anyio.current_time() + self.session_idle_timeout + http_transport.idle_scope = idle_scope + + with idle_scope: + await self.app.run( + read_stream, + write_stream, + self.app.create_initialization_options(), + stateless=False, + ) + + if idle_scope.cancelled_caught: + assert http_transport.mcp_session_id is not None + logger.info(f"Session {http_transport.mcp_session_id} idle timeout") + self._server_instances.pop(http_transport.mcp_session_id, None) + await http_transport.terminate() except Exception: logger.exception(f"Session {http_transport.mcp_session_id} crashed") finally: - # Only remove from instances if not terminated if ( # pragma: no branch http_transport.mcp_session_id and http_transport.mcp_session_id in self._server_instances @@ -244,11 +274,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 error_response = JSONRPCError( jsonrpc="2.0", - id="server-error", + id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found"), ) response = Response( - content=error_response.model_dump_json(by_alias=True, exclude_none=True), + content=error_response.model_dump_json(by_alias=True, exclude_unset=True), status_code=HTTPStatus.NOT_FOUND, media_type="application/json", ) diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index a4c844811..3e675da5f 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -12,8 +12,8 @@ @asynccontextmanager # pragma: no cover async def websocket_server(scope: Scope, receive: Receive, send: Send): - """WebSocket server transport for MCP. This is an ASGI application, suitable to be - used with a framework like Starlette and a server like Hypercorn. + """WebSocket server transport for MCP. This is an ASGI application, suitable for use + with a framework like Starlette and a server like Hypercorn. """ websocket = WebSocket(scope, receive, send) @@ -47,7 +47,7 @@ async def ws_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - obj = session_message.message.model_dump_json(by_alias=True, exclude_none=True) + obj = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) await websocket.send_text(obj) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/_context.py b/src/mcp/shared/_context.py index 2facc2a49..bbcee2d02 100644 --- a/src/mcp/shared/_context.py +++ b/src/mcp/shared/_context.py @@ -13,8 +13,12 @@ @dataclass(kw_only=True) class RequestContext(Generic[SessionT]): - """Common context for handling incoming requests.""" + """Common context for handling incoming requests. + + For request handlers, request_id is always populated. + For notification handlers, request_id is None. + """ - request_id: RequestId - meta: RequestParamsMeta | None session: SessionT + request_id: RequestId | None = None + meta: RequestParamsMeta | None = None diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index 8cf7bda2a..251469eaa 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -44,26 +44,38 @@ def create_mcp_http_client( The returned AsyncClient must be used as a context manager to ensure proper cleanup of connections. - Examples: - # Basic usage with MCP defaults + Example: + Basic usage with MCP defaults: + + ```python async with create_mcp_http_client() as client: response = await client.get("https://api.example.com") + ``` + + With custom headers: - # With custom headers + ```python headers = {"Authorization": "Bearer token"} async with create_mcp_http_client(headers) as client: response = await client.get("/endpoint") + ``` - # With both custom headers and timeout + With both custom headers and timeout: + + ```python timeout = httpx.Timeout(60.0, read=300.0) async with create_mcp_http_client(headers, timeout) as client: response = await client.get("/long-request") + ``` + + With authentication: - # With authentication + ```python from httpx import BasicAuth auth = BasicAuth(username="user", password="pass") async with create_mcp_http_client(headers, timeout, auth) as client: response = await client.get("/protected-endpoint") + ``` """ # Set MCP defaults kwargs: dict[str, Any] = {"follow_redirects": True} diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index bf03a8b8d..ca5b7b45a 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -33,9 +33,8 @@ def __init__(self, message: str): class OAuthClientMetadata(BaseModel): - """RFC 7591 OAuth 2.0 Dynamic Client Registration metadata. + """RFC 7591 OAuth 2.0 Dynamic Client Registration Metadata. See https://datatracker.ietf.org/doc/html/rfc7591#section-2 - for the full specification. """ redirect_uris: list[AnyUrl] | None = Field(..., min_length=1) @@ -145,9 +144,9 @@ class ProtectedResourceMetadata(BaseModel): resource_documentation: AnyHttpUrl | None = None resource_policy_uri: AnyHttpUrl | None = None resource_tos_uri: AnyHttpUrl | None = None - # tls_client_certificate_bound_access_tokens default is False, but ommited here for clarity + # tls_client_certificate_bound_access_tokens default is False, but omitted here for clarity tls_client_certificate_bound_access_tokens: bool | None = None authorization_details_types_supported: list[str] | None = None dpop_signing_alg_values_supported: list[str] | None = None - # dpop_bound_access_tokens_required default is False, but ommited here for clarity + # dpop_bound_access_tokens_required default is False, but omitted here for clarity dpop_bound_access_tokens_required: bool | None = None diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py index 8f3c542f2..3ba880f40 100644 --- a/src/mcp/shared/auth_utils.py +++ b/src/mcp/shared/auth_utils.py @@ -51,22 +51,17 @@ def check_resource_allowed(requested_resource: str, configured_resource: str) -> if requested.scheme.lower() != configured.scheme.lower() or requested.netloc.lower() != configured.netloc.lower(): return False - # Handle cases like requested=/foo and configured=/foo/ + # Normalize trailing slashes before comparison so that + # "/foo" and "/foo/" are treated as equivalent. requested_path = requested.path configured_path = configured.path - - # If requested path is shorter, it cannot be a child - if len(requested_path) < len(configured_path): - return False - - # Check if the requested path starts with the configured path - # Ensure both paths end with / for proper comparison - # This ensures that paths like "/api123" don't incorrectly match "/api" if not requested_path.endswith("/"): requested_path += "/" if not configured_path.endswith("/"): configured_path += "/" + # Check hierarchical match: requested must start with configured path. + # The trailing-slash normalization ensures "/api123/" won't match "/api/". return requested_path.startswith(configured_path) diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index 7a2b2ded4..f153ea319 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -12,7 +12,10 @@ class MCPError(Exception): def __init__(self, code: int, message: str, data: Any = None): super().__init__(code, message, data) - self.error = ErrorData(code=code, message=message, data=data) + if data is not None: + self.error = ErrorData(code=code, message=message, data=data) + else: + self.error = ErrorData(code=code, message=message) @property def code(self) -> int: @@ -62,6 +65,7 @@ class UrlElicitationRequiredError(MCPError): must complete one or more URL elicitations before the request can be processed. Example: + ```python raise UrlElicitationRequiredError([ ElicitRequestURLParams( message="Authorization required for your files", @@ -69,6 +73,7 @@ class UrlElicitationRequiredError(MCPError): elicitation_id="auth-001" ) ]) + ``` """ def __init__(self, elicitations: list[ElicitRequestURLParams], message: str | None = None): diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py index 38ca802da..3f91cd0d0 100644 --- a/src/mcp/shared/experimental/tasks/helpers.py +++ b/src/mcp/shared/experimental/tasks/helpers.py @@ -72,9 +72,10 @@ async def cancel_task( - Task is already in a terminal state (completed, failed, cancelled) Example: - @server.experimental.cancel_task() - async def handle_cancel(request: CancelTaskRequest) -> CancelTaskResult: - return await cancel_task(store, request.params.taskId) + ```python + async def handle_cancel(ctx, params: CancelTaskRequestParams) -> CancelTaskResult: + return await cancel_task(store, params.task_id) + ``` """ task = await store.get_task(task_id) if task is None: diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index d01d28b80..f2d5e2b9a 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -17,7 +17,7 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageStream, MessageStream], None]: """Creates a pair of bidirectional memory streams for client-server communication. - Returns: + Yields: A tuple of (client_streams, server_streams) where each is a tuple of (read_stream, write_stream) """ diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 9dedd2e5d..1858eeac3 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -6,6 +6,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass +from typing import Any from mcp.types import JSONRPCMessage, RequestId @@ -30,8 +31,10 @@ class ServerMessageMetadata: """Metadata specific to server messages.""" related_request_id: RequestId | None = None - # Request-specific context (e.g., headers, auth info) - request_context: object | None = None + # Transport-specific request context (e.g. starlette Request for HTTP + # transports, None for stdio). Typed as Any because the server layer is + # transport-agnostic. + request_context: Any = None # Callback to close SSE stream for the current request without terminating close_sse_stream: CloseSSEStreamCallback | None = None # Callback to close the standalone GET SSE stream (for unsolicited notifications) diff --git a/src/mcp/shared/metadata_utils.py b/src/mcp/shared/metadata_utils.py index 2b66996bd..6e4d33da0 100644 --- a/src/mcp/shared/metadata_utils.py +++ b/src/mcp/shared/metadata_utils.py @@ -1,7 +1,7 @@ """Utility functions for working with metadata in MCP types. These utilities are primarily intended for client-side usage to properly display -human-readable names in user interfaces in a spec compliant way. +human-readable names in user interfaces in a spec-compliant way. """ from mcp.types import Implementation, Prompt, Resource, ResourceTemplate, Tool @@ -18,11 +18,13 @@ def get_display_name(obj: Tool | Resource | Prompt | ResourceTemplate | Implemen For other objects: title > name Example: + ```python # In a client displaying available tools tools = await session.list_tools() for tool in tools.tools: display_name = get_display_name(tool) print(f"Available tool: {display_name}") + ``` Args: obj: An MCP object with name and optional title fields diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py deleted file mode 100644 index 510bd8163..000000000 --- a/src/mcp/shared/progress.py +++ /dev/null @@ -1,45 +0,0 @@ -from collections.abc import Generator -from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Generic - -from pydantic import BaseModel - -from mcp.shared._context import RequestContext, SessionT -from mcp.types import ProgressToken - - -class Progress(BaseModel): - progress: float - total: float | None - - -@dataclass -class ProgressContext(Generic[SessionT]): - session: SessionT - progress_token: ProgressToken - total: float | None - current: float = field(default=0.0, init=False) - - async def progress(self, amount: float, message: str | None = None) -> None: - self.current += amount - - await self.session.send_progress_notification( - self.progress_token, self.current, total=self.total, message=message - ) - - -@contextmanager -def progress( - ctx: RequestContext[SessionT], - total: float | None = None, -) -> Generator[ProgressContext[SessionT], None]: - progress_token = ctx.meta.get("progress_token") if ctx.meta else None - if progress_token is None: # pragma: no cover - raise ValueError("No progress token provided") - - progress_ctx = ProgressContext(ctx.session, progress_token, total) - try: - yield progress_ctx - finally: - pass diff --git a/src/mcp/shared/response_router.py b/src/mcp/shared/response_router.py index 7ec4a443c..fe24b016f 100644 --- a/src/mcp/shared/response_router.py +++ b/src/mcp/shared/response_router.py @@ -25,6 +25,7 @@ class ResponseRouter(Protocol): and deliver the response/error to the appropriate handler. Example: + ```python class TaskResultHandler(ResponseRouter): def route_response(self, request_id, response): resolver = self._pending_requests.pop(request_id, None) @@ -32,6 +33,7 @@ def route_response(self, request_id, response): resolver.set_result(response) return True return False + ``` """ def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 453e36274..b617d702f 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -60,8 +60,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): cancellation handling: Example: + ```python with request_responder as resp: await resp.respond(result) + ``` The context manager ensures: 1. Proper cancellation scope setup and cleanup @@ -115,6 +117,7 @@ async def respond(self, response: SendResultT | ErrorData) -> None: """Send a response for this request. Must be called within a context manager block. + Raises: RuntimeError: If not used within a context manager AssertionError: If request was already responded to @@ -142,7 +145,7 @@ async def cancel(self) -> None: # Send an error response to indicate cancellation await self._session._send_response( # type: ignore[reportPrivateUsage] request_id=self.request_id, - response=ErrorData(code=0, message="Request cancelled", data=None), + response=ErrorData(code=0, message="Request cancelled"), ) @property @@ -235,7 +238,7 @@ async def send_request( metadata: MessageMetadata = None, progress_callback: ProgressFnT | None = None, ) -> ReceiveResultT: - """Sends a request and wait for a response. + """Sends a request and waits for a response. Raises an MCPError if the response contains an error. If a request read timeout is provided, it will take precedence over the session read timeout. @@ -458,6 +461,12 @@ async def _handle_response(self, message: SessionMessage) -> None: if not isinstance(message.message, JSONRPCResponse | JSONRPCError): return # pragma: no cover + if message.message.id is None: + # Narrows to JSONRPCError since JSONRPCResponse.id is always RequestId + error = message.message.error + logging.warning(f"Received error with null ID: {error.message}") + await self._handle_incoming(MCPError(error.code, error.message, error.data)) + return # Normalize response ID to handle type mismatches (e.g., "0" vs 0) response_id = self._normalize_request_id(message.message.id) @@ -506,4 +515,4 @@ async def send_progress_notification( async def _handle_incoming( self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception ) -> None: - """A generic handler for incoming messages. Overwritten by subclasses.""" + """A generic handler for incoming messages. Overridden by subclasses.""" diff --git a/src/mcp/types/_types.py b/src/mcp/types/_types.py index 320422636..9005d253a 100644 --- a/src/mcp/types/_types.py +++ b/src/mcp/types/_types.py @@ -22,7 +22,7 @@ provided by the client. See the "Protocol Version Header" at -https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#protocol-version-header). +https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#protocol-version-header. """ ProgressToken = str | int @@ -108,8 +108,7 @@ class Request(MCPModel, Generic[RequestParamsT, MethodT]): class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]): - """Base class for paginated requests, - matching the schema's PaginatedRequest interface.""" + """Base class for paginated requests, matching the schema's PaginatedRequest interface.""" params: PaginatedRequestParams | None = None @@ -174,10 +173,10 @@ class Icon(MCPModel): theme: IconTheme | None = None """Optional theme specifier. - - `"light"` indicates the icon is designed for a light background, `"dark"` indicates the icon + + `"light"` indicates the icon is designed for a light background, `"dark"` indicates the icon is designed for a dark background. - + See https://modelcontextprotocol.io/specification/2025-11-25/schema#icon for more details. """ @@ -536,7 +535,7 @@ class TaskStatusNotificationParams(NotificationParams, Task): class TaskStatusNotification(Notification[TaskStatusNotificationParams, Literal["notifications/tasks/status"]]): """An optional notification from the receiver to the requestor, informing them that a task's status has changed. - Receivers are not required to send these notifications + Receivers are not required to send these notifications. """ method: Literal["notifications/tasks/status"] = "notifications/tasks/status" @@ -608,7 +607,7 @@ class ProgressNotificationParams(NotificationParams): message: str | None = None """Message related to progress. - This should provide relevant human readable progress information. + This should provide relevant human-readable progress information. """ @@ -999,7 +998,9 @@ class ToolResultContent(MCPModel): SamplingContent: TypeAlias = TextContent | ImageContent | AudioContent """Basic content types for sampling responses (without tool use). -Used for backwards-compatible CreateMessageResult when tools are not used.""" + +Used for backwards-compatible CreateMessageResult when tools are not used. +""" class SamplingMessage(MCPModel): @@ -1117,7 +1118,7 @@ class ToolAnnotations(MCPModel): idempotent_hint: bool | None = None """ If true, calling the tool repeatedly with the same arguments - will have no additional effect on the its environment. + will have no additional effect on its environment. (This property is meaningful only when `read_only_hint == false`) Default: false """ @@ -1265,7 +1266,7 @@ class ModelPreferences(MCPModel): sampling. Because LLMs can vary along multiple dimensions, choosing the "best" model is - rarely straightforward. Different models excel in different areas—some are + rarely straightforward. Different models excel in different areas—some are faster but less capable, others are more capable but more expensive, and so on. This interface allows servers to express their priorities across multiple dimensions to help clients make an appropriate selection for their use case. @@ -1369,7 +1370,7 @@ class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling class CreateMessageResult(Result): - """The client's response to a sampling/create_message request from the server. + """The client's response to a sampling/createMessage request from the server. This is the backwards-compatible version that returns single content (no arrays). Used when the request does not include tools. @@ -1386,7 +1387,7 @@ class CreateMessageResult(Result): class CreateMessageResultWithTools(Result): - """The client's response to a sampling/create_message request when tools were provided. + """The client's response to a sampling/createMessage request when tools were provided. This version supports array content for tool use flows. """ @@ -1426,14 +1427,14 @@ class PromptReference(MCPModel): type: Literal["ref/prompt"] = "ref/prompt" name: str - """The name of the prompt or prompt template""" + """The name of the prompt or prompt template.""" class CompletionArgument(MCPModel): """The argument's information for completion requests.""" name: str - """The name of the argument""" + """The name of the argument.""" value: str """The value of the argument to use for completion matching.""" @@ -1451,7 +1452,7 @@ class CompleteRequestParams(RequestParams): ref: ResourceTemplateReference | PromptReference argument: CompletionArgument context: CompletionContext | None = None - """Additional, optional context for completions""" + """Additional, optional context for completions.""" class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]): @@ -1479,7 +1480,7 @@ class Completion(MCPModel): class CompleteResult(Result): - """The server's response to a completion/complete request""" + """The server's response to a completion/complete request.""" completion: Completion @@ -1522,6 +1523,7 @@ class Root(MCPModel): class ListRootsResult(Result): """The client's response to a roots/list request from the server. + This result contains an array of Root objects, each representing a root directory or file that the server can operate on. """ @@ -1643,7 +1645,7 @@ class ElicitRequestFormParams(RequestParams): requested_schema: ElicitRequestedSchema """ - A restricted subset of JSON Schema defining the structure of expected response. + A restricted subset of JSON Schema defining the structure of the expected response. Only top-level properties are allowed, without nesting. """ @@ -1697,8 +1699,8 @@ class ElicitResult(Result): content: dict[str, str | int | float | bool | list[str] | None] | None = None """ The submitted form data, only present when action is "accept" in form mode. - Contains values matching the requested schema. Values can be strings, integers, - booleans, or arrays of strings. + Contains values matching the requested schema. Values can be strings, integers, floats, + booleans, arrays of strings, or null. For URL mode, this field is omitted. """ diff --git a/src/mcp/types/jsonrpc.py b/src/mcp/types/jsonrpc.py index 0cfdc993a..84304a37c 100644 --- a/src/mcp/types/jsonrpc.py +++ b/src/mcp/types/jsonrpc.py @@ -75,7 +75,7 @@ class JSONRPCError(BaseModel): """A response to a request that indicates an error occurred.""" jsonrpc: Literal["2.0"] - id: RequestId + id: RequestId | None error: ErrorData diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 268e968aa..2e39f1363 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -77,7 +77,8 @@ def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNot def stream_spy() -> Generator[Callable[[], StreamSpyCollection], None, None]: """Fixture that provides spies for both client and server write streams. - Example usage: + Example: + ```python async def test_something(stream_spy): # ... set up server and client ... @@ -92,6 +93,7 @@ async def test_something(stream_spy): # Clear for the next operation spies.clear() + ``` """ client_spy = None server_spy = None diff --git a/tests/client/test_client.py b/tests/client/test_client.py index d483ae54b..45300063a 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -11,7 +11,7 @@ from mcp import types from mcp.client._memory import InMemoryTransport from mcp.client.client import Client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer from mcp.types import ( CallToolResult, @@ -41,33 +41,36 @@ @pytest.fixture def simple_server() -> Server: """Create a simple MCP server for testing.""" - server = Server(name="test_server") - @server.list_resources() - async def handle_list_resources(): - return [Resource(uri="memory://test", name="Test Resource", description="A test resource")] + async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[Resource(uri="memory://test", name="Test Resource", description="A test resource")] + ) - @server.subscribe_resource() - async def handle_subscribe_resource(uri: str): - pass + async def handle_subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + return EmptyResult() - @server.unsubscribe_resource() - async def handle_unsubscribe_resource(uri: str): - pass + async def handle_unsubscribe_resource( + ctx: ServerRequestContext, params: types.UnsubscribeRequestParams + ) -> EmptyResult: + return EmptyResult() - @server.set_logging_level() - async def handle_set_logging_level(level: str): - pass + async def handle_set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + return EmptyResult() - @server.completion() - async def handle_completion( - ref: types.PromptReference | types.ResourceTemplateReference, - argument: types.CompletionArgument, - context: types.CompletionContext | None, - ) -> types.Completion | None: - return types.Completion(values=[]) + async def handle_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + return types.CompleteResult(completion=types.Completion(values=[])) - return server + return Server( + name="test_server", + on_list_resources=handle_list_resources, + on_subscribe_resource=handle_subscribe_resource, + on_unsubscribe_resource=handle_unsubscribe_resource, + on_set_logging_level=handle_set_logging_level, + on_completion=handle_completion, + ) @pytest.fixture @@ -202,19 +205,14 @@ async def test_client_send_progress_notification(): """Test sending progress notification.""" received_from_client = None event = anyio.Event() - server = Server(name="test_server") - - @server.progress_notification() - async def handle_progress_notification( - progress_token: str | int, - progress: float = 0.0, - total: float | None = None, - message: str | None = None, - ) -> None: + + async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: nonlocal received_from_client - received_from_client = {"progress_token": progress_token, "progress": progress} + received_from_client = {"progress_token": params.progress_token, "progress": params.progress} event.set() + server = Server(name="test_server", on_progress=handle_progress) + async with Client(server) as client: await client.send_progress_notification(progress_token="token123", progress=50.0) await event.wait() diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index 5cca8c194..cc2e14e46 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -8,7 +8,6 @@ import socket from collections.abc import AsyncGenerator, Generator from contextlib import asynccontextmanager -from typing import Any import pytest from starlette.applications import Starlette @@ -17,7 +16,7 @@ from mcp import types from mcp.client.session import ClientSession from mcp.client.streamable_http import streamable_http_client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import TextContent, Tool from tests.test_helpers import wait_for_server @@ -47,54 +46,56 @@ def run_unicode_server(port: int) -> None: # pragma: no cover import uvicorn # Need to recreate the server setup in this process - server = Server(name="unicode_test_server") - - @server.list_tools() - async def list_tools() -> list[Tool]: - """List tools with Unicode descriptions.""" - return [ - Tool( - name="echo_unicode", - description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", - input_schema={ - "type": "object", - "properties": { - "text": {"type": "string", "description": "Text to echo back"}, + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + Tool( + name="echo_unicode", + description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", + input_schema={ + "type": "object", + "properties": { + "text": {"type": "string", "description": "Text to echo back"}, + }, + "required": ["text"], }, - "required": ["text"], - }, - ), - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: - """Handle tool calls with Unicode content.""" - if name == "echo_unicode": - text = arguments.get("text", "") if arguments else "" - return [ - TextContent( - type="text", - text=f"Echo: {text}", - ) + ), ] - else: - raise ValueError(f"Unknown tool: {name}") - - @server.list_prompts() - async def list_prompts() -> list[types.Prompt]: - """List prompts with Unicode names and descriptions.""" - return [ - types.Prompt( - name="unicode_prompt", - description="Unicode prompt - Слой хранилища, где располагаются", - arguments=[], + ) + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name == "echo_unicode": + text = params.arguments.get("text", "") if params.arguments else "" + return types.CallToolResult( + content=[ + TextContent( + type="text", + text=f"Echo: {text}", + ) + ] ) - ] + else: + raise ValueError(f"Unknown tool: {params.name}") + + async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="unicode_prompt", + description="Unicode prompt - Слой хранилища, где располагаются", + arguments=[], + ) + ] + ) - @server.get_prompt() - async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPromptResult: - """Get a prompt with Unicode content.""" - if name == "unicode_prompt": + async def handle_get_prompt( + ctx: ServerRequestContext, params: types.GetPromptRequestParams + ) -> types.GetPromptResult: + if params.name == "unicode_prompt": return types.GetPromptResult( messages=[ types.PromptMessage( @@ -106,7 +107,15 @@ async def get_prompt(name: str, arguments: dict[str, Any] | None) -> types.GetPr ) ] ) - raise ValueError(f"Unknown prompt: {name}") + raise ValueError(f"Unknown prompt: {params.name}") + + server = Server( + name="unicode_test_server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + on_list_prompts=handle_list_prompts, + on_get_prompt=handle_get_prompt, + ) # Create the session manager session_manager = StreamableHTTPSessionManager( diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index 4d7c53db2..f70fb9277 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -3,9 +3,9 @@ import pytest from mcp import Client, types -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer -from mcp.types import ListToolsRequest, ListToolsResult +from mcp.types import ListToolsResult from .conftest import StreamSpyCollection @@ -105,14 +105,16 @@ async def test_list_tools_with_strict_server_validation( async def test_list_tools_with_lowlevel_server(): """Test that list_tools works with a lowlevel Server using params.""" - server = Server("test-lowlevel") - @server.list_tools() - async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListToolsResult: # Echo back what cursor we received in the tool description - cursor = request.params.cursor if request.params else None + cursor = params.cursor if params else None return ListToolsResult(tools=[types.Tool(name="test_tool", description=f"cursor={cursor}", input_schema={})]) + server = Server("test-lowlevel", on_list_tools=handle_list_tools) + async with Client(server) as client: result = await client.list_tools() assert result.tools[0].description == "cursor=None" diff --git a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index 9e233acc3..69c8afeb8 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -116,6 +116,58 @@ async def test_unexpected_content_type_sends_jsonrpc_error() -> None: await session.list_tools() +def _create_http_error_app(error_status: int, *, error_on_notifications: bool = False) -> Starlette: + """Create a server that returns an HTTP error for non-init requests.""" + + async def handle_mcp_request(request: Request) -> Response: + body = await request.body() + data = json.loads(body) + + if data.get("method") == "initialize": + return _init_json_response(data) + + if "id" not in data: + if error_on_notifications: + return Response(status_code=error_status) + return Response(status_code=202) + + return Response(status_code=error_status) + + return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])]) + + +async def test_http_error_status_sends_jsonrpc_error() -> None: + """Verify HTTP 5xx errors unblock the pending request with an MCPError. + + When a server returns a non-2xx status code (e.g. 500), the client should + send a JSONRPCError so the pending request resolves immediately instead of + raising an unhandled httpx.HTTPStatusError that causes the caller to hang. + """ + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=_create_http_error_app(500))) as client: + async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + with pytest.raises(MCPError, match="Server returned an error response"): # pragma: no branch + await session.list_tools() + + +async def test_http_error_on_notification_does_not_hang() -> None: + """Verify HTTP errors on notifications are silently ignored. + + When a notification gets an HTTP error, there is no pending request to + unblock, so the client should just return without sending a JSONRPCError. + """ + app = _create_http_error_app(500, error_on_notifications=True) + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) as client: + async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + # Should not raise or hang — the error is silently ignored for notifications + await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed")) + + def _create_invalid_json_response_app() -> Starlette: """Create a server that returns invalid JSON for requests.""" diff --git a/tests/client/test_output_schema_validation.py b/tests/client/test_output_schema_validation.py index cc93d303b..d78197b5c 100644 --- a/tests/client/test_output_schema_validation.py +++ b/tests/client/test_output_schema_validation.py @@ -1,50 +1,41 @@ -import inspect import logging -from contextlib import contextmanager from typing import Any -from unittest.mock import patch -import jsonschema import pytest from mcp import Client -from mcp.server.lowlevel import Server -from mcp.types import Tool - - -@contextmanager -def bypass_server_output_validation(): - """Context manager that bypasses server-side output validation. - This simulates a malicious or non-compliant server that doesn't validate - its outputs, allowing us to test client-side validation. - """ - # Save the original validate function - original_validate = jsonschema.validate - - # Create a mock that tracks which module is calling it - def selective_mock(instance: Any = None, schema: Any = None, *args: Any, **kwargs: Any) -> None: - # Check the call stack to see where this is being called from - for frame_info in inspect.stack(): - # If called from the server module, skip validation - # TODO: fix this as it's a rather gross workaround and will eventually break - # Normalize path separators for cross-platform compatibility - normalized_path = frame_info.filename.replace("\\", "/") - if "mcp/server/lowlevel/server.py" in normalized_path: - return None - # Otherwise, use the real validation (for client-side) - return original_validate(instance=instance, schema=schema, *args, **kwargs) - - with patch("jsonschema.validate", selective_mock): - yield +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) + + +def _make_server( + tools: list[Tool], + structured_content: dict[str, Any], +) -> Server: + """Create a low-level server that returns the given structured_content for any tool call.""" + + async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=tools) + + async def on_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + return CallToolResult( + content=[TextContent(type="text", text="result")], + structured_content=structured_content, + ) + + return Server("test-server", on_list_tools=on_list_tools, on_call_tool=on_call_tool) @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_basemodel(): """Test that client validates structured content against schema for BaseModel outputs""" - # Create a malicious low-level server that returns invalid structured content - server = Server("test-server") - - # Define the expected schema for our tool output_schema = { "type": "object", "properties": {"name": {"type": "string", "title": "Name"}, "age": {"type": "integer", "title": "Age"}}, @@ -52,39 +43,27 @@ async def test_tool_structured_output_client_side_validation_basemodel(): "title": "UserOutput", } - @server.list_tools() - async def list_tools(): - return [ + server = _make_server( + tools=[ Tool( name="get_user", description="Get user data", input_schema={"type": "object"}, output_schema=output_schema, ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return invalid structured content - age is string instead of integer - # The low-level server will wrap this in CallToolResult - return {"name": "John", "age": "invalid"} # Invalid: age should be int + ], + structured_content={"name": "John", "age": "invalid"}, # Invalid: age should be int + ) - # Test that client validates the structured content - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("get_user", {}) - # Verify it's a validation error - assert "Invalid structured content returned by tool get_user" in str(exc_info.value) + async with Client(server) as client: + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("get_user", {}) + assert "Invalid structured content returned by tool get_user" in str(exc_info.value) @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_primitive(): """Test that client validates structured content for primitive outputs""" - server = Server("test-server") - - # Primitive types are wrapped in {"result": value} output_schema = { "type": "object", "properties": {"result": {"type": "integer", "title": "Result"}}, @@ -92,122 +71,95 @@ async def test_tool_structured_output_client_side_validation_primitive(): "title": "calculate_Output", } - @server.list_tools() - async def list_tools(): - return [ + server = _make_server( + tools=[ Tool( name="calculate", description="Calculate something", input_schema={"type": "object"}, output_schema=output_schema, ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return invalid structured content - result is string instead of integer - return {"result": "not_a_number"} # Invalid: should be int + ], + structured_content={"result": "not_a_number"}, # Invalid: should be int + ) - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("calculate", {}) - assert "Invalid structured content returned by tool calculate" in str(exc_info.value) + async with Client(server) as client: + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("calculate", {}) + assert "Invalid structured content returned by tool calculate" in str(exc_info.value) @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_dict_typed(): """Test that client validates dict[str, T] structured content""" - server = Server("test-server") - - # dict[str, int] schema output_schema = {"type": "object", "additionalProperties": {"type": "integer"}, "title": "get_scores_Output"} - @server.list_tools() - async def list_tools(): - return [ + server = _make_server( + tools=[ Tool( name="get_scores", description="Get scores", input_schema={"type": "object"}, output_schema=output_schema, ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return invalid structured content - values should be integers - return {"alice": "100", "bob": "85"} # Invalid: values should be int + ], + structured_content={"alice": "100", "bob": "85"}, # Invalid: values should be int + ) - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("get_scores", {}) - assert "Invalid structured content returned by tool get_scores" in str(exc_info.value) + async with Client(server) as client: + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("get_scores", {}) + assert "Invalid structured content returned by tool get_scores" in str(exc_info.value) @pytest.mark.anyio async def test_tool_structured_output_client_side_validation_missing_required(): """Test that client validates missing required fields""" - server = Server("test-server") - output_schema = { "type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}, "email": {"type": "string"}}, - "required": ["name", "age", "email"], # All fields required + "required": ["name", "age", "email"], "title": "PersonOutput", } - @server.list_tools() - async def list_tools(): - return [ + server = _make_server( + tools=[ Tool( name="get_person", description="Get person data", input_schema={"type": "object"}, output_schema=output_schema, ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return structured content missing required field 'email' - return {"name": "John", "age": 30} # Missing required 'email' + ], + structured_content={"name": "John", "age": 30}, # Missing required 'email' + ) - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("get_person", {}) - assert "Invalid structured content returned by tool get_person" in str(exc_info.value) + async with Client(server) as client: + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("get_person", {}) + assert "Invalid structured content returned by tool get_person" in str(exc_info.value) @pytest.mark.anyio async def test_tool_not_listed_warning(caplog: pytest.LogCaptureFixture): """Test that client logs warning when tool is not in list_tools but has output_schema""" - server = Server("test-server") - @server.list_tools() - async def list_tools() -> list[Tool]: - # Return empty list - tool is not listed - return [] + async def on_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[]) + + async def on_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + return CallToolResult( + content=[TextContent(type="text", text="result")], + structured_content={"result": 42}, + ) - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - # Server still responds to the tool call with structured content - return {"result": 42} + server = Server("test-server", on_list_tools=on_list_tools, on_call_tool=on_call_tool) - # Set logging level to capture warnings caplog.set_level(logging.WARNING) - with bypass_server_output_validation(): - async with Client(server) as client: - # Call a tool that wasn't listed - result = await client.call_tool("mystery_tool", {}) - assert result.structured_content == {"result": 42} - assert result.is_error is False + async with Client(server) as client: + result = await client.call_tool("mystery_tool", {}) + assert result.structured_content == {"result": 42} + assert result.is_error is False - # Check that warning was logged - assert "Tool mystery_tool not listed" in caplog.text + assert "Tool mystery_tool not listed" in caplog.text diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index 30ecb0ac3..47be3e208 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -2,31 +2,31 @@ import pytest -from mcp import Client +from mcp import Client, types from mcp.client._memory import InMemoryTransport -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer -from mcp.types import Resource +from mcp.types import ListResourcesResult, Resource @pytest.fixture def simple_server() -> Server: """Create a simple MCP server for testing.""" - server = Server(name="test_server") - - # pragma: no cover - handler exists only to register a resource capability. - # Transport tests verify stream creation, not handler invocation. - @server.list_resources() - async def handle_list_resources(): # pragma: no cover - return [ - Resource( - uri="memory://test", - name="Test Resource", - description="A test resource", - ) - ] - return server + async def handle_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: # pragma: no cover + return ListResourcesResult( + resources=[ + Resource( + uri="memory://test", + name="Test Resource", + description="A test resource", + ) + ] + ) + + return Server(name="test_server", on_list_resources=handle_list_resources) @pytest.fixture diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py index f21abf4d0..613c794eb 100644 --- a/tests/experimental/tasks/client/test_tasks.py +++ b/tests/experimental/tasks/client/test_tasks.py @@ -1,43 +1,38 @@ """Tests for the experimental client task methods (session.experimental).""" +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import Any import anyio import pytest from anyio import Event from anyio.abc import TaskGroup -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.shared.experimental.tasks.helpers import task_execution from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder from mcp.types import ( CallToolRequest, CallToolRequestParams, CallToolResult, - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, - ClientResult, CreateTaskResult, - GetTaskPayloadRequest, + GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, - ServerNotification, - ServerRequest, + ListToolsResult, + PaginatedRequestParams, TaskMetadata, TextContent, - Tool, ) +pytestmark = pytest.mark.anyio + @dataclass class AppContext: @@ -48,44 +43,53 @@ class AppContext: task_done_events: dict[str, Event] = field(default_factory=lambda: {}) -@pytest.mark.anyio -async def test_session_experimental_get_task() -> None: - """Test session.experimental.get_task() method.""" - # Note: We bypass the normal lifespan mechanism - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] - store = InMemoryTaskStore() +async def _handle_list_tools( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None +) -> ListToolsResult: + raise NotImplementedError - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) +async def _handle_call_tool_with_done_event( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams, *, result_text: str = "Done" +) -> CallToolResult | CreateTaskResult: + app = ctx.lifespan_context + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) - done_event = Event() - app.task_done_events[task.task_id] = done_event + done_event = Event() + app.task_done_events[task.task_id] = done_event - async def do_work(): - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) - done_event.set() + async def do_work() -> None: + async with task_execution(task.task_id, app.store) as task_ctx: + await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) + done_event.set() - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) - raise NotImplementedError + raise NotImplementedError + + +def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): + @asynccontextmanager + async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: + async with anyio.create_task_group() as tg: + yield AppContext(task_group=tg, store=store, task_done_events=task_done_events) - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + return app_lifespan + + +async def test_session_experimental_get_task() -> None: + """Test session.experimental.get_task() method.""" + store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} + + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -96,280 +100,141 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool_with_done_event, + ) + server.experimental.enable_tasks(on_get_task=handle_get_task) + + async with Client(server) as client: + # Create a task + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task to complete - await app_context.task_done_events[task_id].wait() + CreateTaskResult, + ) + task_id = create_result.task.task_id - # Use session.experimental to get task status - task_status = await client_session.experimental.get_task(task_id) + # Wait for task to complete + await task_done_events[task_id].wait() - assert task_status.task_id == task_id - assert task_status.status == "completed" + # Use session.experimental to get task status + task_status = await client.session.experimental.get_task(task_id) - tg.cancel_scope.cancel() + assert task_status.task_id == task_id + assert task_status.status == "completed" -@pytest.mark.anyio async def test_session_experimental_get_task_result() -> None: """Test session.experimental.get_task_result() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] + async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: + return await _handle_call_tool_with_done_event(ctx, params, result_text="Task result content") - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context - app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_work(): - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete( - CallToolResult(content=[TextContent(type="text", text="Task result content")]) - ) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - @server.experimental.get_task_result() async def handle_get_task_result( - request: GetTaskPayloadRequest, + ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: - app = server.request_context.lifespan_context - result = await app.store.get_result(request.params.task_id) - assert result is not None, f"Test setup error: result for {request.params.task_id} should exist" + app = ctx.lifespan_context + result = await app.store.get_result(params.task_id) + assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, CallToolResult) return GetTaskPayloadResult(**result.model_dump()) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks(on_task_result=handle_get_task_result) + + async with Client(server) as client: + # Create a task + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Wait for task to complete - await app_context.task_done_events[task_id].wait() + CreateTaskResult, + ) + task_id = create_result.task.task_id - # Use TaskClient to get task result - task_result = await client_session.experimental.get_task_result(task_id, CallToolResult) + # Wait for task to complete + await task_done_events[task_id].wait() - assert len(task_result.content) == 1 - content = task_result.content[0] - assert isinstance(content, TextContent) - assert content.text == "Task result content" + # Use TaskClient to get task result + task_result = await client.session.experimental.get_task_result(task_id, CallToolResult) - tg.cancel_scope.cancel() + assert len(task_result.content) == 1 + content = task_result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Task result content" -@pytest.mark.anyio async def test_session_experimental_list_tasks() -> None: """Test TaskClient.list_tasks() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_list_tasks( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListTasksResult: app = ctx.lifespan_context - if ctx.experimental.is_task: - task_metadata = ctx.experimental.task_metadata - assert task_metadata is not None - task = await app.store.create_task(task_metadata) - - done_event = Event() - app.task_done_events[task.task_id] = done_event - - async def do_work(): - async with task_execution(task.task_id, app.store) as task_ctx: - await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) - done_event.set() - - app.task_group.start_soon(do_work) - return CreateTaskResult(task=task) - - raise NotImplementedError - - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: - app = server.request_context.lifespan_context - tasks_list, next_cursor = await app.store.list_tasks(cursor=request.params.cursor if request.params else None) + cursor = params.cursor if params else None + tasks_list, next_cursor = await app.store.list_tasks(cursor=cursor) return ListTasksResult(tasks=tasks_list, next_cursor=next_cursor) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool_with_done_event, + ) + server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) + + async with Client(server) as client: + # Create two tasks + for _ in range(2): + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create two tasks - for _ in range(2): - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - await app_context.task_done_events[create_result.task.task_id].wait() - - # Use TaskClient to list tasks - list_result = await client_session.experimental.list_tasks() + CreateTaskResult, + ) + await task_done_events[create_result.task.task_id].wait() - assert len(list_result.tasks) == 2 + # Use TaskClient to list tasks + list_result = await client.session.experimental.list_tasks() - tg.cancel_scope.cancel() + assert len(list_result.tasks) == 2 -@pytest.mark.anyio async def test_session_experimental_cancel_task() -> None: """Test TaskClient.cancel_task() method.""" - server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} - @server.list_tools() - async def list_tools(): - return [Tool(name="test_tool", description="Test", input_schema={"type": "object"})] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool_no_work( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: app = ctx.lifespan_context if ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata @@ -377,14 +242,12 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextCon task = await app.store.create_task(task_metadata) # Don't start any work - task stays in "working" status return CreateTaskResult(task=task) - raise NotImplementedError - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -395,14 +258,14 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - @server.experimental.cancel_task() - async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" - await app.store.update_task(request.params.task_id, status="cancelled") - # CancelTaskResult extends Task, so we need to return the updated task info - updated_task = await app.store.get_task(request.params.task_id) + async def handle_cancel_task( + ctx: ServerRequestContext[AppContext], params: CancelTaskRequestParams + ) -> CancelTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" + await app.store.update_task(params.task_id, status="cancelled") + updated_task = await app.store.get_task(params.task_id) assert updated_task is not None return CancelTaskResult( task_id=updated_task.task_id, @@ -412,63 +275,35 @@ async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: ttl=updated_task.ttl, ) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server: Server[AppContext] = Server( + "test-server", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=_handle_list_tools, + on_call_tool=handle_call_tool_no_work, + ) + server.experimental.enable_tasks(on_get_task=handle_get_task, on_cancel_task=handle_cancel_task) + + async with Client(server) as client: + # Create a task (but don't complete it) + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Create a task (but don't complete it) - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="test_tool", - arguments={}, - task=TaskMetadata(ttl=60000), - ) - ), - CreateTaskResult, - ) - task_id = create_result.task.task_id - - # Verify task is working - status_before = await client_session.experimental.get_task(task_id) - assert status_before.status == "working" + CreateTaskResult, + ) + task_id = create_result.task.task_id - # Cancel the task - await client_session.experimental.cancel_task(task_id) + # Verify task is working + status_before = await client.session.experimental.get_task(task_id) + assert status_before.status == "working" - # Verify task is cancelled - status_after = await client_session.experimental.get_task(task_id) - assert status_after.status == "cancelled" + # Cancel the task + await client.session.experimental.cancel_task(task_id) - tg.cancel_scope.cancel() + # Verify task is cancelled + status_after = await client.session.experimental.get_task(task_id) + assert status_after.status == "cancelled" diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py index 41cecc129..b5b79033d 100644 --- a/tests/experimental/tasks/server/test_integration.py +++ b/tests/experimental/tasks/server/test_integration.py @@ -8,46 +8,37 @@ 5. Client retrieves result with tasks/result """ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import Any import anyio import pytest from anyio import Event from anyio.abc import TaskGroup -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.shared.experimental.tasks.helpers import task_execution from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder from mcp.types import ( - TASK_REQUIRED, CallToolRequest, CallToolRequestParams, CallToolResult, - ClientResult, CreateTaskResult, - GetTaskPayloadRequest, GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, - ServerNotification, - ServerRequest, + ListToolsResult, + PaginatedRequestParams, TaskMetadata, TextContent, - Tool, - ToolExecution, ) +pytestmark = pytest.mark.anyio + @dataclass class AppContext: @@ -55,77 +46,57 @@ class AppContext: task_group: TaskGroup store: InMemoryTaskStore - # Events to signal when tasks complete (for testing without sleeps) task_done_events: dict[str, Event] = field(default_factory=lambda: {}) -@pytest.mark.anyio +def _make_lifespan(store: InMemoryTaskStore, task_done_events: dict[str, Event]): + @asynccontextmanager + async def app_lifespan(server: Server[AppContext]) -> AsyncIterator[AppContext]: + async with anyio.create_task_group() as tg: + yield AppContext(task_group=tg, store=store, task_done_events=task_done_events) + + return app_lifespan + + async def test_task_lifecycle_with_task_execution() -> None: - """Test the complete task lifecycle using the task_execution pattern. - - This demonstrates the recommended way to implement task-augmented tools: - 1. Create task in store - 2. Spawn work using task_execution() context manager - 3. Return CreateTaskResult immediately - 4. Work executes in background, auto-fails on exception - """ - # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message - server: Server[AppContext, Any] = Server("test-tasks") # type: ignore[assignment] + """Test the complete task lifecycle using the task_execution pattern.""" store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} + + async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListToolsResult: + raise NotImplementedError - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="process_data", - description="Process data asynchronously", - input_schema={ - "type": "object", - "properties": {"input": {"type": "string"}}, - }, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: app = ctx.lifespan_context - if name == "process_data" and ctx.experimental.is_task: - # 1. Create task in store + if params.name == "process_data" and ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None task = await app.store.create_task(task_metadata) - # 2. Create event to signal completion (for testing) done_event = Event() app.task_done_events[task.task_id] = done_event - # 3. Define work function using task_execution for safety - async def do_work(): + async def do_work() -> None: async with task_execution(task.task_id, app.store) as task_ctx: await task_ctx.update_status("Processing input...") - # Simulate work - input_value = arguments.get("input", "") + input_value = (params.arguments or {}).get("input", "") result_text = f"Processed: {input_value.upper()}" await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) - # Signal completion done_event.set() - # 4. Spawn work in task group (from lifespan_context) app.task_group.start_soon(do_work) - - # 5. Return CreateTaskResult immediately return CreateTaskResult(task=task) raise NotImplementedError - # Register task query handlers (delegate to store) - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -136,134 +107,91 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - @server.experimental.get_task_result() async def handle_get_task_result( - request: GetTaskPayloadRequest, + ctx: ServerRequestContext[AppContext], params: GetTaskPayloadRequestParams ) -> GetTaskPayloadResult: - app = server.request_context.lifespan_context - result = await app.store.get_result(request.params.task_id) - assert result is not None, f"Test setup error: result for {request.params.task_id} should exist" + app = ctx.lifespan_context + result = await app.store.get_result(params.task_id) + assert result is not None, f"Test setup error: result for {params.task_id} should exist" assert isinstance(result, CallToolResult) - # Return as GetTaskPayloadResult (which accepts extra fields) return GetTaskPayloadResult(**result.model_dump()) - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def handle_list_tasks( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListTasksResult: raise NotImplementedError - # Set up client-server communication - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no cover - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, + server: Server[AppContext] = Server( + "test-tasks", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks( + on_get_task=handle_get_task, + on_task_result=handle_get_task_result, + on_list_tasks=handle_list_tasks, + ) + + async with Client(server) as client: + # Step 1: Send task-augmented tool call + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="process_data", + arguments={"input": "hello world"}, + task=TaskMetadata(ttl=60000), ), ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - # Create app context with task group and store - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # === Step 1: Send task-augmented tool call === - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="process_data", - arguments={"input": "hello world"}, - task=TaskMetadata(ttl=60000), - ), - ), - CreateTaskResult, - ) - - assert isinstance(create_result, CreateTaskResult) - assert create_result.task.status == "working" - task_id = create_result.task.task_id - - # === Step 2: Wait for task to complete === - await app_context.task_done_events[task_id].wait() + CreateTaskResult, + ) - task_status = await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)), - GetTaskResult, - ) + assert isinstance(create_result, CreateTaskResult) + assert create_result.task.status == "working" + task_id = create_result.task.task_id - assert task_status.task_id == task_id - assert task_status.status == "completed" + # Step 2: Wait for task to complete + await task_done_events[task_id].wait() - # === Step 3: Retrieve the actual result === - task_result = await client_session.send_request( - GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id=task_id)), - CallToolResult, - ) + task_status = await client.session.experimental.get_task(task_id) + assert task_status.task_id == task_id + assert task_status.status == "completed" - assert len(task_result.content) == 1 - content = task_result.content[0] - assert isinstance(content, TextContent) - assert content.text == "Processed: HELLO WORLD" + # Step 3: Retrieve the actual result + task_result = await client.session.experimental.get_task_result(task_id, CallToolResult) - tg.cancel_scope.cancel() + assert len(task_result.content) == 1 + content = task_result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Processed: HELLO WORLD" -@pytest.mark.anyio async def test_task_auto_fails_on_exception() -> None: """Test that task_execution automatically fails the task on unhandled exception.""" - # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message - server: Server[AppContext, Any] = Server("test-tasks-failure") # type: ignore[assignment] store = InMemoryTaskStore() + task_done_events: dict[str, Event] = {} + + async def handle_list_tools( + ctx: ServerRequestContext[AppContext], params: PaginatedRequestParams | None + ) -> ListToolsResult: + raise NotImplementedError - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="failing_task", - description="A task that fails", - input_schema={"type": "object", "properties": {}}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool( + ctx: ServerRequestContext[AppContext], params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: app = ctx.lifespan_context - if name == "failing_task" and ctx.experimental.is_task: + if params.name == "failing_task" and ctx.experimental.is_task: task_metadata = ctx.experimental.task_metadata assert task_metadata is not None task = await app.store.create_task(task_metadata) - # Create event to signal completion (for testing) done_event = Event() app.task_done_events[task.task_id] = done_event - async def do_failing_work(): + async def do_failing_work() -> None: async with task_execution(task.task_id, app.store) as task_ctx: await task_ctx.update_status("About to fail...") raise RuntimeError("Something went wrong!") - # Note: complete() is never called, but task_execution - # will automatically call fail() due to the exception # This line is reached because task_execution suppresses the exception done_event.set() @@ -272,11 +200,10 @@ async def do_failing_work(): raise NotImplementedError - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: - app = server.request_context.lifespan_context - task = await app.store.get_task(request.params.task_id) - assert task is not None, f"Test setup error: task {request.params.task_id} should exist" + async def handle_get_task(ctx: ServerRequestContext[AppContext], params: GetTaskRequestParams) -> GetTaskResult: + app = ctx.lifespan_context + task = await app.store.get_task(params.task_id) + assert task is not None, f"Test setup error: task {params.task_id} should exist" return GetTaskResult( task_id=task.task_id, status=task.status, @@ -287,64 +214,34 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=task.poll_interval, ) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no cover - - async def run_server(app_context: AppContext): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, + server: Server[AppContext] = Server( + "test-tasks-failure", + lifespan=_make_lifespan(store, task_done_events), + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks(on_get_task=handle_get_task) + + async with Client(server) as client: + # Send task request + create_result = await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="failing_task", + arguments={}, + task=TaskMetadata(ttl=60000), ), ), - ) as server_session: - async for message in server_session.incoming_messages: - await server._handle_message(message, server_session, app_context, raise_exceptions=False) - - async with anyio.create_task_group() as tg: - app_context = AppContext(task_group=tg, store=store) - tg.start_soon(run_server, app_context) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Send task request - create_result = await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="failing_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ), - CreateTaskResult, - ) - - task_id = create_result.task.task_id + CreateTaskResult, + ) - # Wait for task to complete (even though it fails) - await app_context.task_done_events[task_id].wait() + task_id = create_result.task.task_id - # Check that task was auto-failed - task_status = await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id=task_id)), GetTaskResult - ) + # Wait for task to complete (even though it fails) + await task_done_events[task_id].wait() - assert task_status.status == "failed" - assert task_status.status_message == "Something went wrong!" + # Check that task was auto-failed + task_status = await client.session.experimental.get_task(task_id) - tg.cancel_scope.cancel() + assert task_status.status == "failed" + assert task_status.status_message == "Something went wrong!" diff --git a/tests/experimental/tasks/server/test_run_task_flow.py b/tests/experimental/tasks/server/test_run_task_flow.py index 0d5d1df77..027382e69 100644 --- a/tests/experimental/tasks/server/test_run_task_flow.py +++ b/tests/experimental/tasks/server/test_run_task_flow.py @@ -8,159 +8,102 @@ These are integration tests that verify the complete flow works end-to-end. """ -from typing import Any from unittest.mock import Mock import anyio import pytest from anyio import Event -from mcp.client.session import ClientSession -from mcp.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.experimental.task_context import ServerTaskContext from mcp.server.experimental.task_support import TaskSupport from mcp.server.lowlevel import NotificationOptions from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue -from mcp.shared.message import SessionMessage from mcp.types import ( TASK_REQUIRED, + CallToolRequestParams, CallToolResult, - CancelTaskRequest, - CancelTaskResult, CreateTaskResult, - GetTaskPayloadRequest, - GetTaskPayloadResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, - ListTasksResult, + ListToolsResult, + PaginatedRequestParams, TextContent, - Tool, - ToolExecution, ) +pytestmark = pytest.mark.anyio -@pytest.mark.anyio -async def test_run_task_basic_flow() -> None: - """Test the basic run_task flow without elicitation. - 1. enable_tasks() sets up handlers - 2. Client calls tool with task field - 3. run_task() spawns work, returns CreateTaskResult - 4. Work completes in background - 5. Client polls and sees completed status - """ - server = Server("test-run-task") +async def _handle_list_tools_simple_task( + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + raise NotImplementedError - # One-line setup - server.experimental.enable_tasks() - # Track when work completes and capture received meta +async def test_run_task_basic_flow() -> None: + """Test the basic run_task flow without elicitation.""" work_completed = Event() received_meta: list[str | None] = [None] - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="simple_task", - description="A simple task", - input_schema={"type": "object", "properties": {"input": {"type": "string"}}}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) - # Capture the meta from the request (if present) if ctx.meta is not None: # pragma: no branch received_meta[0] = ctx.meta.get("custom_field") async def work(task: ServerTaskContext) -> CallToolResult: await task.update_status("Working...") - input_val = arguments.get("input", "default") + input_val = (params.arguments or {}).get("input", "default") result = CallToolResult(content=[TextContent(type="text", text=f"Processed: {input_val}")]) work_completed.set() return result return await ctx.experimental.run_task(work) - # Set up streams - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + server = Server( + "test-run-task", + on_list_tools=_handle_list_tools_simple_task, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() + + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task( + "simple_task", + {"input": "hello"}, + meta={"custom_field": "test_value"}, ) - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - # Initialize - await client_session.initialize() - - # Call tool as task (with meta to test that code path) - result = await client_session.experimental.call_tool_as_task( - "simple_task", - {"input": "hello"}, - meta={"custom_field": "test_value"}, - ) - - # Should get CreateTaskResult - task_id = result.task.task_id - assert result.task.status == "working" - - # Wait for work to complete - with anyio.fail_after(5): - await work_completed.wait() - - # Poll until task status is completed - with anyio.fail_after(5): - while True: - task_status = await client_session.experimental.get_task(task_id) - if task_status.status == "completed": # pragma: no branch - break - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) - - # Verify the meta was passed through correctly + task_id = result.task.task_id + assert result.task.status == "working" + + with anyio.fail_after(5): + await work_completed.wait() + + with anyio.fail_after(5): + while True: + task_status = await client.session.experimental.get_task(task_id) + if task_status.status == "completed": # pragma: no branch + break + assert received_meta[0] == "test_value" -@pytest.mark.anyio async def test_run_task_auto_fails_on_exception() -> None: """Test that run_task automatically fails the task when work raises.""" - server = Server("test-run-task-fail") - server.experimental.enable_tasks() - work_failed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="failing_task", - description="A task that fails", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -169,42 +112,29 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - await client_session.initialize() - - result = await client_session.experimental.call_tool_as_task("failing_task", {}) - task_id = result.task.task_id + server = Server( + "test-run-task-fail", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() - # Wait for work to fail - with anyio.fail_after(5): - await work_failed.wait() + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task("failing_task", {}) + task_id = result.task.task_id - # Poll until task status is failed - with anyio.fail_after(5): - while True: - task_status = await client_session.experimental.get_task(task_id) - if task_status.status == "failed": # pragma: no branch - break + with anyio.fail_after(5): + await work_failed.wait() - assert "Something went wrong" in (task_status.status_message or "") + with anyio.fail_after(5): + while True: + task_status = await client.session.experimental.get_task(task_id) + if task_status.status == "failed": # pragma: no branch + break - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) + assert "Something went wrong" in (task_status.status_message or "") -@pytest.mark.anyio async def test_enable_tasks_auto_registers_handlers() -> None: """Test that enable_tasks() auto-registers get_task, list_tasks, cancel_task handlers.""" server = Server("test-enable-tasks") @@ -221,63 +151,41 @@ async def test_enable_tasks_auto_registers_handlers() -> None: assert caps_after.tasks is not None assert caps_after.tasks.list is not None assert caps_after.tasks.cancel is not None - # Verify nested call capability is present assert caps_after.tasks.requests is not None assert caps_after.tasks.requests.tools is not None assert caps_after.tasks.requests.tools.call is not None -@pytest.mark.anyio async def test_enable_tasks_with_custom_store_and_queue() -> None: """Test that enable_tasks() uses provided store and queue instead of defaults.""" server = Server("test-custom-store-queue") - # Create custom store and queue custom_store = InMemoryTaskStore() custom_queue = InMemoryTaskMessageQueue() - # Enable tasks with custom implementations task_support = server.experimental.enable_tasks(store=custom_store, queue=custom_queue) - # Verify our custom implementations are used assert task_support.store is custom_store assert task_support.queue is custom_queue -@pytest.mark.anyio async def test_enable_tasks_skips_default_handlers_when_custom_registered() -> None: """Test that enable_tasks() doesn't override already-registered handlers.""" server = Server("test-custom-handlers") - # Register custom handlers BEFORE enable_tasks (never called, just for registration) - @server.experimental.get_task() - async def custom_get_task(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError - - @server.experimental.get_task_result() - async def custom_get_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: + # Register custom handlers via enable_tasks kwargs + async def custom_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: raise NotImplementedError - @server.experimental.list_tasks() - async def custom_list_tasks(req: ListTasksRequest) -> ListTasksResult: - raise NotImplementedError - - @server.experimental.cancel_task() - async def custom_cancel_task(req: CancelTaskRequest) -> CancelTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=custom_get_task) - # Now enable tasks - should NOT override our custom handlers - server.experimental.enable_tasks() + # Verify handler is registered + assert server._has_handler("tasks/get") + assert server._has_handler("tasks/list") + assert server._has_handler("tasks/cancel") + assert server._has_handler("tasks/result") - # Verify our custom handlers are still registered (not replaced by defaults) - # The handlers dict should contain our custom handlers - assert GetTaskRequest in server.request_handlers - assert GetTaskPayloadRequest in server.request_handlers - assert ListTasksRequest in server.request_handlers - assert CancelTaskRequest in server.request_handlers - -@pytest.mark.anyio async def test_run_task_without_enable_tasks_raises() -> None: """Test that run_task raises when enable_tasks() wasn't called.""" experimental = Experimental( @@ -294,7 +202,6 @@ async def work(task: ServerTaskContext) -> CallToolResult: await experimental.run_task(work) -@pytest.mark.anyio async def test_task_support_task_group_before_run_raises() -> None: """Test that accessing task_group before run() raises RuntimeError.""" task_support = TaskSupport.in_memory() @@ -303,7 +210,6 @@ async def test_task_support_task_group_before_run_raises() -> None: _ = task_support.task_group -@pytest.mark.anyio async def test_run_task_without_session_raises() -> None: """Test that run_task raises when session is not available.""" task_support = TaskSupport.in_memory() @@ -322,7 +228,6 @@ async def work(task: ServerTaskContext) -> CallToolResult: await experimental.run_task(work) -@pytest.mark.anyio async def test_run_task_without_task_metadata_raises() -> None: """Test that run_task raises when request is not task-augmented.""" task_support = TaskSupport.in_memory() @@ -342,29 +247,17 @@ async def work(task: ServerTaskContext) -> CallToolResult: await experimental.run_task(work) -@pytest.mark.anyio async def test_run_task_with_model_immediate_response() -> None: """Test that run_task includes model_immediate_response in CreateTaskResult._meta.""" - server = Server("test-run-task-immediate") - server.experimental.enable_tasks() - work_completed = Event() immediate_response_text = "Processing your request..." - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="task_with_immediate", - description="A task with immediate response", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -373,164 +266,102 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work, model_immediate_response=immediate_response_text) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - await client_session.initialize() - - result = await client_session.experimental.call_tool_as_task("task_with_immediate", {}) + server = Server( + "test-run-task-immediate", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() - # Verify the immediate response is in _meta - assert result.meta is not None - assert "io.modelcontextprotocol/model-immediate-response" in result.meta - assert result.meta["io.modelcontextprotocol/model-immediate-response"] == immediate_response_text + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task("task_with_immediate", {}) - with anyio.fail_after(5): - await work_completed.wait() + assert result.meta is not None + assert "io.modelcontextprotocol/model-immediate-response" in result.meta + assert result.meta["io.modelcontextprotocol/model-immediate-response"] == immediate_response_text - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) + with anyio.fail_after(5): + await work_completed.wait() -@pytest.mark.anyio async def test_run_task_doesnt_complete_if_already_terminal() -> None: """Test that run_task doesn't auto-complete if work manually completed the task.""" - server = Server("test-already-complete") - server.experimental.enable_tasks() - work_completed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="manual_complete_task", - description="A task that manually completes", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: - # Manually complete the task before returning manual_result = CallToolResult(content=[TextContent(type="text", text="Manually completed")]) await task.complete(manual_result, notify=False) work_completed.set() - # Return a different result - but it should be ignored since task is already terminal return CallToolResult(content=[TextContent(type="text", text="This should be ignored")]) return await ctx.experimental.run_task(work) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - await client_session.initialize() - - result = await client_session.experimental.call_tool_as_task("manual_complete_task", {}) - task_id = result.task.task_id + server = Server( + "test-already-complete", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() - with anyio.fail_after(5): - await work_completed.wait() + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task("manual_complete_task", {}) + task_id = result.task.task_id - # Poll until task status is completed - with anyio.fail_after(5): - while True: - status = await client_session.experimental.get_task(task_id) - if status.status == "completed": # pragma: no branch - break + with anyio.fail_after(5): + await work_completed.wait() - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) + with anyio.fail_after(5): + while True: + status = await client.session.experimental.get_task(task_id) + if status.status == "completed": # pragma: no branch + break -@pytest.mark.anyio async def test_run_task_doesnt_fail_if_already_terminal() -> None: """Test that run_task doesn't auto-fail if work manually failed/cancelled the task.""" - server = Server("test-already-failed") - server.experimental.enable_tasks() - work_completed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="manual_cancel_task", - description="A task that manually cancels then raises", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + async def handle_call_tool( + ctx: ServerRequestContext, params: CallToolRequestParams + ) -> CallToolResult | CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: - # Manually fail the task first await task.fail("Manually failed", notify=False) work_completed.set() - # Then raise - but the auto-fail should be skipped since task is already terminal raise RuntimeError("This error should not change status") return await ctx.experimental.run_task(work) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run( - client_to_server_receive, - server_to_client_send, - server.create_initialization_options(), - ) - - async def run_client() -> None: - async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: - await client_session.initialize() - - result = await client_session.experimental.call_tool_as_task("manual_cancel_task", {}) - task_id = result.task.task_id + server = Server( + "test-already-failed", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + server.experimental.enable_tasks() - with anyio.fail_after(5): - await work_completed.wait() + async with Client(server) as client: + result = await client.session.experimental.call_tool_as_task("manual_cancel_task", {}) + task_id = result.task.task_id - # Poll until task status is failed - with anyio.fail_after(5): - while True: - status = await client_session.experimental.get_task(task_id) - if status.status == "failed": # pragma: no branch - break + with anyio.fail_after(5): + await work_completed.wait() - # Task should still be failed (from manual fail, not auto-fail from exception) - assert status.status_message == "Manually failed" # Not "This error should not change status" + with anyio.fail_after(5): + while True: + status = await client.session.experimental.get_task(task_id) + if status.status == "failed": # pragma: no branch + break - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - tg.start_soon(run_client) + assert status.status_message == "Manually failed" diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 8005380d2..6a28b274e 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -6,8 +6,9 @@ import anyio import pytest +from mcp import Client from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -23,7 +24,6 @@ CallToolRequest, CallToolRequestParams, CallToolResult, - CancelTaskRequest, CancelTaskRequestParams, CancelTaskResult, ClientResult, @@ -31,21 +31,18 @@ GetTaskPayloadRequest, GetTaskPayloadRequestParams, GetTaskPayloadResult, - GetTaskRequest, GetTaskRequestParams, GetTaskResult, JSONRPCError, JSONRPCNotification, JSONRPCResponse, - ListTasksRequest, ListTasksResult, - ListToolsRequest, ListToolsResult, + PaginatedRequestParams, SamplingMessage, ServerCapabilities, ServerNotification, ServerRequest, - ServerResult, Task, TaskMetadata, TextContent, @@ -53,57 +50,37 @@ ToolExecution, ) +pytestmark = pytest.mark.anyio -@pytest.mark.anyio -async def test_list_tasks_handler() -> None: - """Test that experimental list_tasks handler works.""" - server = Server("test") +async def test_list_tasks_handler() -> None: + """Test that experimental list_tasks handler works via Client.""" now = datetime.now(timezone.utc) test_tasks = [ - Task( - task_id="task-1", - status="working", - created_at=now, - last_updated_at=now, - ttl=60000, - poll_interval=1000, - ), - Task( - task_id="task-2", - status="completed", - created_at=now, - last_updated_at=now, - ttl=60000, - poll_interval=1000, - ), + Task(task_id="task-1", status="working", created_at=now, last_updated_at=now, ttl=60000, poll_interval=1000), + Task(task_id="task-2", status="completed", created_at=now, last_updated_at=now, ttl=60000, poll_interval=1000), ] - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def handle_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: return ListTasksResult(tasks=test_tasks) - handler = server.request_handlers[ListTasksRequest] - request = ListTasksRequest(method="tasks/list") - result = await handler(request) + server = Server("test") + server.experimental.enable_tasks(on_list_tasks=handle_list_tasks) - assert isinstance(result, ServerResult) - assert isinstance(result, ListTasksResult) - assert len(result.tasks) == 2 - assert result.tasks[0].task_id == "task-1" - assert result.tasks[1].task_id == "task-2" + async with Client(server) as client: + result = await client.session.experimental.list_tasks() + assert len(result.tasks) == 2 + assert result.tasks[0].task_id == "task-1" + assert result.tasks[1].task_id == "task-2" -@pytest.mark.anyio async def test_get_task_handler() -> None: - """Test that experimental get_task handler works.""" - server = Server("test") + """Test that experimental get_task handler works via Client.""" - @server.experimental.get_task() - async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + async def handle_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: now = datetime.now(timezone.utc) return GetTaskResult( - task_id=request.params.task_id, + task_id=params.task_id, status="working", created_at=now, last_updated_at=now, @@ -111,85 +88,69 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: poll_interval=1000, ) - handler = server.request_handlers[GetTaskRequest] - request = GetTaskRequest( - method="tasks/get", - params=GetTaskRequestParams(task_id="test-task-123"), - ) - result = await handler(request) + server = Server("test") + server.experimental.enable_tasks(on_get_task=handle_get_task) - assert isinstance(result, ServerResult) - assert isinstance(result, GetTaskResult) - assert result.task_id == "test-task-123" - assert result.status == "working" + async with Client(server) as client: + result = await client.session.experimental.get_task("test-task-123") + assert result.task_id == "test-task-123" + assert result.status == "working" -@pytest.mark.anyio async def test_get_task_result_handler() -> None: - """Test that experimental get_task_result handler works.""" - server = Server("test") + """Test that experimental get_task_result handler works via Client.""" - @server.experimental.get_task_result() - async def handle_get_task_result(request: GetTaskPayloadRequest) -> GetTaskPayloadResult: + async def handle_get_task_result( + ctx: ServerRequestContext, params: GetTaskPayloadRequestParams + ) -> GetTaskPayloadResult: return GetTaskPayloadResult() - handler = server.request_handlers[GetTaskPayloadRequest] - request = GetTaskPayloadRequest( - method="tasks/result", - params=GetTaskPayloadRequestParams(task_id="test-task-123"), - ) - result = await handler(request) + server = Server("test") + server.experimental.enable_tasks(on_task_result=handle_get_task_result) - assert isinstance(result, ServerResult) - assert isinstance(result, GetTaskPayloadResult) + async with Client(server) as client: + result = await client.session.send_request( + GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(task_id="test-task-123")), + GetTaskPayloadResult, + ) + assert isinstance(result, GetTaskPayloadResult) -@pytest.mark.anyio async def test_cancel_task_handler() -> None: - """Test that experimental cancel_task handler works.""" - server = Server("test") + """Test that experimental cancel_task handler works via Client.""" - @server.experimental.cancel_task() - async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + async def handle_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: now = datetime.now(timezone.utc) return CancelTaskResult( - task_id=request.params.task_id, + task_id=params.task_id, status="cancelled", created_at=now, last_updated_at=now, ttl=60000, ) - handler = server.request_handlers[CancelTaskRequest] - request = CancelTaskRequest( - method="tasks/cancel", - params=CancelTaskRequestParams(task_id="test-task-123"), - ) - result = await handler(request) + server = Server("test") + server.experimental.enable_tasks(on_cancel_task=handle_cancel_task) - assert isinstance(result, ServerResult) - assert isinstance(result, CancelTaskResult) - assert result.task_id == "test-task-123" - assert result.status == "cancelled" + async with Client(server) as client: + result = await client.session.experimental.cancel_task("test-task-123") + assert result.task_id == "test-task-123" + assert result.status == "cancelled" -@pytest.mark.anyio async def test_server_capabilities_include_tasks() -> None: """Test that server capabilities include tasks when handlers are registered.""" server = Server("test") - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: raise NotImplementedError - @server.experimental.cancel_task() - async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + async def noop_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: raise NotImplementedError - capabilities = server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ) + server.experimental.enable_tasks(on_list_tasks=noop_list_tasks, on_cancel_task=noop_cancel_task) + + capabilities = server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}) assert capabilities.tasks is not None assert capabilities.tasks.list is not None @@ -198,259 +159,164 @@ async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: assert capabilities.tasks.requests.tools is not None -@pytest.mark.anyio -async def test_server_capabilities_partial_tasks() -> None: +@pytest.mark.skip( + reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " + "so partial capabilities aren't possible yet. Low-level API should support " + "selectively enabling/disabling task capabilities." +) +async def test_server_capabilities_partial_tasks() -> None: # pragma: no cover """Test capabilities with only some task handlers registered.""" server = Server("test") - @server.experimental.list_tasks() - async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + async def noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: raise NotImplementedError # Only list_tasks registered, not cancel_task + server.experimental.enable_tasks(on_list_tasks=noop_list_tasks) - capabilities = server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ) + capabilities = server.get_capabilities(notification_options=NotificationOptions(), experimental_capabilities={}) assert capabilities.tasks is not None assert capabilities.tasks.list is not None assert capabilities.tasks.cancel is None # Not registered -@pytest.mark.anyio async def test_tool_with_task_execution_metadata() -> None: """Test that tools can declare task execution mode.""" - server = Server("test") - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="quick_tool", - description="Fast tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_FORBIDDEN), - ), - Tool( - name="long_tool", - description="Long running tool", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ), - Tool( - name="flexible_tool", - description="Can be either", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support=TASK_OPTIONAL), - ), - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="quick_tool", + description="Fast tool", + input_schema={"type": "object", "properties": {}}, + execution=ToolExecution(task_support=TASK_FORBIDDEN), + ), + Tool( + name="long_tool", + description="Long running tool", + input_schema={"type": "object", "properties": {}}, + execution=ToolExecution(task_support=TASK_REQUIRED), + ), + Tool( + name="flexible_tool", + description="Can be either", + input_schema={"type": "object", "properties": {}}, + execution=ToolExecution(task_support=TASK_OPTIONAL), + ), + ] + ) - tools_handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list") - result = await tools_handler(request) + server = Server("test", on_list_tools=handle_list_tools) - assert isinstance(result, ServerResult) - assert isinstance(result, ListToolsResult) - tools = result.tools + async with Client(server) as client: + result = await client.list_tools() + tools = result.tools - assert tools[0].execution is not None - assert tools[0].execution.task_support == TASK_FORBIDDEN - assert tools[1].execution is not None - assert tools[1].execution.task_support == TASK_REQUIRED - assert tools[2].execution is not None - assert tools[2].execution.task_support == TASK_OPTIONAL + assert tools[0].execution is not None + assert tools[0].execution.task_support == TASK_FORBIDDEN + assert tools[1].execution is not None + assert tools[1].execution.task_support == TASK_REQUIRED + assert tools[2].execution is not None + assert tools[2].execution.task_support == TASK_OPTIONAL -@pytest.mark.anyio async def test_task_metadata_in_call_tool_request() -> None: - """Test that task metadata is accessible via RequestContext when calling a tool.""" - server = Server("test") + """Test that task metadata is accessible via ctx when calling a tool.""" captured_task_metadata: TaskMetadata | None = None - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="long_task", - description="A long running task", - input_schema={"type": "object", "properties": {}}, - execution=ToolExecution(task_support="optional"), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: nonlocal captured_task_metadata - ctx = server.request_context captured_task_metadata = ctx.experimental.task_metadata - return [TextContent(type="text", text="done")] - - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch - - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, + return CallToolResult(content=[TextContent(type="text", text="done")]) + + server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + + async with Client(server) as client: + # Call tool with task metadata + await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams( + name="long_task", + arguments={}, + task=TaskMetadata(ttl=60000), ), ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Call tool with task metadata - await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams( - name="long_task", - arguments={}, - task=TaskMetadata(ttl=60000), - ), - ), - CallToolResult, - ) - - tg.cancel_scope.cancel() + CallToolResult, + ) assert captured_task_metadata is not None assert captured_task_metadata.ttl == 60000 -@pytest.mark.anyio async def test_task_metadata_is_task_property() -> None: - """Test that RequestContext.experimental.is_task works correctly.""" - server = Server("test") + """Test that ctx.experimental.is_task works correctly.""" is_task_values: list[bool] = [] - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="test_tool", - description="Test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: is_task_values.append(ctx.experimental.is_task) - return [TextContent(type="text", text="done")] + return CallToolResult(content=[TextContent(type="text", text="done")]) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: ... # pragma: no branch + async with Client(server) as client: + # Call without task metadata + await client.session.send_request( + CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), + CallToolResult, + ) - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), + # Call with task metadata + await client.session.send_request( + CallToolRequest( + params=CallToolRequestParams(name="test_tool", arguments={}, task=TaskMetadata(ttl=60000)), ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - await client_session.initialize() - - # Call without task metadata - await client_session.send_request( - CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), - CallToolResult, - ) - - # Call with task metadata - await client_session.send_request( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}, task=TaskMetadata(ttl=60000)), - ), - CallToolResult, - ) - - tg.cancel_scope.cancel() + CallToolResult, + ) assert len(is_task_values) == 2 assert is_task_values[0] is False # First call without task assert is_task_values[1] is True # Second call with task -@pytest.mark.anyio async def test_update_capabilities_no_handlers() -> None: """Test that update_capabilities returns early when no task handlers are registered.""" server = Server("test-no-handlers") - # Access experimental to initialize it, but don't register any task handlers _ = server.experimental caps = server.get_capabilities(NotificationOptions(), {}) - - # Without any task handlers registered, tasks capability should be None assert caps.tasks is None -@pytest.mark.anyio +async def test_update_capabilities_partial_handlers() -> None: + """Test that update_capabilities skips list/cancel when only tasks/get is registered.""" + server = Server("test-partial") + # Access .experimental to create the ExperimentalHandlers instance + exp = server.experimental + # Second access returns the same cached instance + assert server.experimental is exp + + async def noop_get(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: + raise NotImplementedError + + server._add_request_handler("tasks/get", noop_get) + + caps = server.get_capabilities(NotificationOptions(), {}) + assert caps.tasks is not None + assert caps.tasks.list is None + assert caps.tasks.cancel is None + + async def test_default_task_handlers_via_enable_tasks() -> None: - """Test that enable_tasks() auto-registers working default handlers. - - This exercises the default handlers in lowlevel/experimental.py: - - _default_get_task (task not found) - - _default_get_task_result - - _default_list_tasks - - _default_cancel_task - """ + """Test that enable_tasks() auto-registers working default handlers.""" server = Server("test-default-handlers") - # Enable tasks with default handlers (no custom handlers registered) task_support = server.experimental.enable_tasks() store = task_support.store @@ -493,24 +359,18 @@ async def run_server() -> None: task = await store.create_task(TaskMetadata(ttl=60000)) # Test list_tasks (default handler) - list_result = await client_session.send_request(ListTasksRequest(), ListTasksResult) + list_result = await client_session.experimental.list_tasks() assert len(list_result.tasks) == 1 assert list_result.tasks[0].task_id == task.task_id # Test get_task (default handler - found) - get_result = await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id=task.task_id)), - GetTaskResult, - ) + get_result = await client_session.experimental.get_task(task.task_id) assert get_result.task_id == task.task_id assert get_result.status == "working" # Test get_task (default handler - not found path) with pytest.raises(MCPError, match="not found"): - await client_session.send_request( - GetTaskRequest(params=GetTaskRequestParams(task_id="nonexistent-task")), - GetTaskResult, - ) + await client_session.experimental.get_task("nonexistent-task") # Create a completed task to test get_task_result completed_task = await store.create_task(TaskMetadata(ttl=60000)) @@ -529,9 +389,7 @@ async def run_server() -> None: assert "io.modelcontextprotocol/related-task" in payload_result.meta # Test cancel_task (default handler) - cancel_result = await client_session.send_request( - CancelTaskRequest(params=CancelTaskRequestParams(task_id=task.task_id)), CancelTaskResult - ) + cancel_result = await client_session.experimental.cancel_task(task.task_id) assert cancel_result.task_id == task.task_id assert cancel_result.status == "cancelled" diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py index 57122da7b..2d0378a9c 100644 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ b/tests/experimental/tasks/test_elicitation_scenarios.py @@ -17,7 +17,7 @@ from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.experimental.task_context import ServerTaskContext from mcp.server.lowlevel import NotificationOptions from mcp.shared._context import RequestContext @@ -26,6 +26,7 @@ from mcp.shared.message import SessionMessage from mcp.types import ( TASK_REQUIRED, + CallToolRequestParams, CallToolResult, CreateMessageRequestParams, CreateMessageResult, @@ -35,11 +36,12 @@ ErrorData, GetTaskPayloadResult, GetTaskResult, + ListToolsResult, + PaginatedRequestParams, SamplingMessage, TaskMetadata, TextContent, Tool, - ToolExecution, ) @@ -181,24 +183,21 @@ async def test_scenario1_normal_tool_normal_elicitation() -> None: Server calls session.elicit() directly, client responds immediately. """ - server = Server("test-scenario1") elicit_received = Event() tool_result: list[str] = [] - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + ) + ] + ) + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: # Normal elicitation - expects immediate response result = await ctx.session.elicit( message="Please confirm the action", @@ -209,6 +208,8 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append("confirmed" if confirmed else "cancelled") return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + server = Server("test-scenario1", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + # Elicitation callback for client async def elicitation_callback( context: RequestContext[ClientSession], @@ -262,27 +263,24 @@ async def test_scenario2_normal_tool_task_augmented_elicitation() -> None: Server calls session.experimental.elicit_as_task(), client creates a task for the elicitation and returns CreateTaskResult. Server polls client. """ - server = Server("test-scenario2") elicit_received = Event() tool_result: list[str] = [] # Client-side task store for handling task-augmented elicitation client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="confirm_action", + description="Confirm an action", + input_schema={"type": "object"}, + ) + ] + ) + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: # Task-augmented elicitation - server polls client result = await ctx.session.experimental.elicit_as_task( message="Please confirm the action", @@ -294,6 +292,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append("confirmed" if confirmed else "cancelled") return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + server = Server("test-scenario2", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) task_handlers = create_client_task_handlers(client_task_store, elicit_received) # Set up streams @@ -342,26 +341,13 @@ async def test_scenario3_task_augmented_tool_normal_elicitation() -> None: Client calls tool as task. Inside the task, server uses task.elicit() which queues the request and delivers via tasks/result. """ - server = Server("test-scenario3") - server.experimental.enable_tasks() - elicit_received = Event() work_completed = Event() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -377,6 +363,9 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario3", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + server.experimental.enable_tasks() + # Elicitation callback for client async def elicitation_callback( context: RequestContext[ClientSession], @@ -452,29 +441,16 @@ async def test_scenario4_task_augmented_tool_task_augmented_elicitation() -> Non 5. Server gets the ElicitResult and completes the tool task 6. Client's tasks/result returns with the CallToolResult """ - server = Server("test-scenario4") - server.experimental.enable_tasks() - elicit_received = Event() work_completed = Event() # Client-side task store for handling task-augmented elicitation client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="confirm_action", - description="Confirm an action", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -491,6 +467,8 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario4", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + server.experimental.enable_tasks() task_handlers = create_client_task_handlers(client_task_store, elicit_received) # Set up streams @@ -553,27 +531,24 @@ async def test_scenario2_sampling_normal_tool_task_augmented_sampling() -> None: Server calls session.experimental.create_message_as_task(), client creates a task for the sampling and returns CreateTaskResult. Server polls client. """ - server = Server("test-scenario2-sampling") sampling_received = Event() tool_result: list[str] = [] # Client-side task store for handling task-augmented sampling client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="generate_text", - description="Generate text using sampling", - input_schema={"type": "object"}, - ) - ] - - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: - ctx = server.request_context + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="generate_text", + description="Generate text using sampling", + input_schema={"type": "object"}, + ) + ] + ) + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: # Task-augmented sampling - server polls client result = await ctx.session.experimental.create_message_as_task( messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], @@ -587,6 +562,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResu tool_result.append(response_text) return CallToolResult(content=[TextContent(type="text", text=response_text)]) + server = Server("test-scenario2-sampling", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) # Set up streams @@ -636,29 +612,16 @@ async def test_scenario4_sampling_task_augmented_tool_task_augmented_sampling() which sends task-augmented sampling. Client creates its own task for the sampling, and server polls the client. """ - server = Server("test-scenario4-sampling") - server.experimental.enable_tasks() - sampling_received = Event() work_completed = Event() # Client-side task store for handling task-augmented sampling client_task_store = InMemoryTaskStore() - @server.list_tools() - async def list_tools() -> list[Tool]: - return [ - Tool( - name="generate_text", - description="Generate text using sampling", - input_schema={"type": "object"}, - execution=ToolExecution(task_support=TASK_REQUIRED), - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: - ctx = server.request_context + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CreateTaskResult: ctx.experimental.validate_task_mode(TASK_REQUIRED) async def work(task: ServerTaskContext) -> CallToolResult: @@ -677,6 +640,8 @@ async def work(task: ServerTaskContext) -> CallToolResult: return await ctx.experimental.run_task(work) + server = Server("test-scenario4-sampling", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + server.experimental.enable_tasks() task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) # Set up streams diff --git a/tests/experimental/tasks/test_spec_compliance.py b/tests/experimental/tasks/test_spec_compliance.py index d00ce40a4..38d7d0a66 100644 --- a/tests/experimental/tasks/test_spec_compliance.py +++ b/tests/experimental/tasks/test_spec_compliance.py @@ -10,17 +10,17 @@ import pytest -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY from mcp.types import ( - CancelTaskRequest, + CancelTaskRequestParams, CancelTaskResult, CreateTaskResult, - GetTaskRequest, + GetTaskRequestParams, GetTaskResult, - ListTasksRequest, ListTasksResult, + PaginatedRequestParams, ServerCapabilities, Task, ) @@ -44,13 +44,22 @@ def test_server_without_task_handlers_has_no_tasks_capability() -> None: assert caps.tasks is None +async def _noop_get_task(ctx: ServerRequestContext, params: GetTaskRequestParams) -> GetTaskResult: + raise NotImplementedError + + +async def _noop_list_tasks(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListTasksResult: + raise NotImplementedError + + +async def _noop_cancel_task(ctx: ServerRequestContext, params: CancelTaskRequestParams) -> CancelTaskResult: + raise NotImplementedError + + def test_server_with_list_tasks_handler_declares_list_capability() -> None: """Server with list_tasks handler declares tasks.list capability.""" server: Server = Server("test") - - @server.experimental.list_tasks() - async def handle_list(req: ListTasksRequest) -> ListTasksResult: - raise NotImplementedError + server.experimental.enable_tasks(on_list_tasks=_noop_list_tasks) caps = _get_capabilities(server) assert caps.tasks is not None @@ -60,10 +69,7 @@ async def handle_list(req: ListTasksRequest) -> ListTasksResult: def test_server_with_cancel_task_handler_declares_cancel_capability() -> None: """Server with cancel_task handler declares tasks.cancel capability.""" server: Server = Server("test") - - @server.experimental.cancel_task() - async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_cancel_task=_noop_cancel_task) caps = _get_capabilities(server) assert caps.tasks is not None @@ -75,10 +81,7 @@ def test_server_with_get_task_handler_declares_requests_tools_call_capability() (get_task is required for task-augmented tools/call support) """ server: Server = Server("test") - - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=_noop_get_task) caps = _get_capabilities(server) assert caps.tasks is not None @@ -86,28 +89,30 @@ async def handle_get(req: GetTaskRequest) -> GetTaskResult: assert caps.tasks.requests.tools is not None -def test_server_without_list_handler_has_no_list_capability() -> None: +@pytest.mark.skip( + reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " + "so partial capabilities aren't possible yet. Low-level API should support " + "selectively enabling/disabling task capabilities." +) +def test_server_without_list_handler_has_no_list_capability() -> None: # pragma: no cover """Server without list_tasks handler has no tasks.list capability.""" server: Server = Server("test") - - # Register only get_task (not list_tasks) - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=_noop_get_task) caps = _get_capabilities(server) assert caps.tasks is not None assert caps.tasks.list is None -def test_server_without_cancel_handler_has_no_cancel_capability() -> None: +@pytest.mark.skip( + reason="TODO(maxisbey): enable_tasks registers default handlers for all task methods, " + "so partial capabilities aren't possible yet. Low-level API should support " + "selectively enabling/disabling task capabilities." +) +def test_server_without_cancel_handler_has_no_cancel_capability() -> None: # pragma: no cover """Server without cancel_task handler has no tasks.cancel capability.""" server: Server = Server("test") - - # Register only get_task (not cancel_task) - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks(on_get_task=_noop_get_task) caps = _get_capabilities(server) assert caps.tasks is not None @@ -117,18 +122,11 @@ async def handle_get(req: GetTaskRequest) -> GetTaskResult: def test_server_with_all_task_handlers_has_full_capability() -> None: """Server with all task handlers declares complete tasks capability.""" server: Server = Server("test") - - @server.experimental.list_tasks() - async def handle_list(req: ListTasksRequest) -> ListTasksResult: - raise NotImplementedError - - @server.experimental.cancel_task() - async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: - raise NotImplementedError - - @server.experimental.get_task() - async def handle_get(req: GetTaskRequest) -> GetTaskResult: - raise NotImplementedError + server.experimental.enable_tasks( + on_list_tasks=_noop_list_tasks, + on_cancel_task=_noop_cancel_task, + on_get_task=_noop_get_task, + ) caps = _get_capabilities(server) assert caps.tasks is not None diff --git a/tests/issues/test_129_resource_templates.py b/tests/issues/test_129_resource_templates.py index 39e2c6f2a..bb4735121 100644 --- a/tests/issues/test_129_resource_templates.py +++ b/tests/issues/test_129_resource_templates.py @@ -1,15 +1,13 @@ import pytest -from mcp import types +from mcp import Client from mcp.server.mcpserver import MCPServer @pytest.mark.anyio async def test_resource_templates(): - # Create an MCP server mcp = MCPServer("Demo") - # Add a dynamic greeting resource @mcp.resource("greeting://{name}") def get_greeting(name: str) -> str: # pragma: no cover """Get a personalized greeting""" @@ -20,23 +18,16 @@ def get_user_profile(user_id: str) -> str: # pragma: no cover """Dynamic user data""" return f"Profile data for user {user_id}" - # Get the list of resource templates using the underlying server - # Note: list_resource_templates() returns a decorator that wraps the handler - # The handler returns a ServerResult with a ListResourceTemplatesResult inside - result = await mcp._lowlevel_server.request_handlers[types.ListResourceTemplatesRequest]( - types.ListResourceTemplatesRequest(params=None) - ) - assert isinstance(result, types.ListResourceTemplatesResult) - templates = result.resource_templates - - # Verify we get both templates back - assert len(templates) == 2 - - # Verify template details - greeting_template = next(t for t in templates if t.name == "get_greeting") - assert greeting_template.uri_template == "greeting://{name}" - assert greeting_template.description == "Get a personalized greeting" - - profile_template = next(t for t in templates if t.name == "get_user_profile") - assert profile_template.uri_template == "users://{user_id}/profile" - assert profile_template.description == "Dynamic user data" + async with Client(mcp) as client: + result = await client.list_resource_templates() + templates = result.resource_templates + + assert len(templates) == 2 + + greeting_template = next(t for t in templates if t.name == "get_greeting") + assert greeting_template.uri_template == "greeting://{name}" + assert greeting_template.description == "Get a personalized greeting" + + profile_template = next(t for t in templates if t.name == "get_user_profile") + assert profile_template.uri_template == "users://{user_id}/profile" + assert profile_template.description == "Dynamic user data" diff --git a/tests/issues/test_152_resource_mime_type.py b/tests/issues/test_152_resource_mime_type.py index e738017f8..851e89979 100644 --- a/tests/issues/test_152_resource_mime_type.py +++ b/tests/issues/test_152_resource_mime_type.py @@ -3,9 +3,16 @@ import pytest from mcp import Client, types -from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer +from mcp.types import ( + BlobResourceContents, + ListResourcesResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextResourceContents, +) pytestmark = pytest.mark.anyio @@ -58,7 +65,6 @@ def get_image_as_bytes() -> bytes: async def test_lowlevel_resource_mime_type(): """Test that mime_type parameter is respected for resources.""" - server = Server("test") # Create a small test image as bytes image_bytes = b"fake_image_data" @@ -74,17 +80,24 @@ async def test_lowlevel_resource_mime_type(): ), ] - @server.list_resources() - async def handle_list_resources(): - return test_resources - - @server.read_resource() - async def handle_read_resource(uri: str): - if str(uri) == "test://image": - return [ReadResourceContents(content=base64_string, mime_type="image/png")] - elif str(uri) == "test://image_bytes": - return [ReadResourceContents(content=bytes(image_bytes), mime_type="image/png")] - raise Exception(f"Resource not found: {uri}") # pragma: no cover + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=test_resources) + + resource_contents: dict[str, list[TextResourceContents | BlobResourceContents]] = { + "test://image": [TextResourceContents(uri="test://image", text=base64_string, mime_type="image/png")], + "test://image_bytes": [ + BlobResourceContents( + uri="test://image_bytes", blob=base64.b64encode(image_bytes).decode("utf-8"), mime_type="image/png" + ) + ], + } + + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult(contents=resource_contents[str(params.uri)]) + + server = Server("test", on_list_resources=handle_list_resources, on_read_resource=handle_read_resource) # Test that resources are listed with correct mime type async with Client(server) as client: diff --git a/tests/issues/test_1574_resource_uri_validation.py b/tests/issues/test_1574_resource_uri_validation.py index e6ff56877..c67708128 100644 --- a/tests/issues/test_1574_resource_uri_validation.py +++ b/tests/issues/test_1574_resource_uri_validation.py @@ -13,8 +13,14 @@ import pytest from mcp import Client, types -from mcp.server.lowlevel import Server -from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + ListResourcesResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextResourceContents, +) pytestmark = pytest.mark.anyio @@ -26,24 +32,24 @@ async def test_relative_uri_roundtrip(): the server would fail to serialize resources with relative URIs, or the URI would be transformed during the roundtrip. """ - server = Server("test") - - @server.list_resources() - async def list_resources(): - return [ - types.Resource(name="user", uri="users/me"), - types.Resource(name="config", uri="./config"), - types.Resource(name="parent", uri="../parent/resource"), - ] - - @server.read_resource() - async def read_resource(uri: str): - return [ - ReadResourceContents( - content=f"data for {uri}", - mime_type="text/plain", - ) - ] + + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[ + types.Resource(name="user", uri="users/me"), + types.Resource(name="config", uri="./config"), + types.Resource(name="parent", uri="../parent/resource"), + ] + ) + + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text=f"data for {params.uri}", mime_type="text/plain")] + ) + + server = Server("test", on_list_resources=handle_list_resources, on_read_resource=handle_read_resource) async with Client(server) as client: # List should return the exact URIs we specified @@ -67,18 +73,23 @@ async def test_custom_scheme_uri_roundtrip(): Some MCP servers use custom schemes like "custom://resource". These should work end-to-end. """ - server = Server("test") - - @server.list_resources() - async def list_resources(): - return [ - types.Resource(name="custom", uri="custom://my-resource"), - types.Resource(name="file", uri="file:///path/to/file"), - ] - - @server.read_resource() - async def read_resource(uri: str): - return [ReadResourceContents(content="data", mime_type="text/plain")] + + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[ + types.Resource(name="custom", uri="custom://my-resource"), + types.Resource(name="file", uri="file:///path/to/file"), + ] + ) + + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text="data", mime_type="text/plain")] + ) + + server = Server("test", on_list_resources=handle_list_resources, on_read_resource=handle_read_resource) async with Client(server) as client: resources = await client.list_resources() diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index fb4bb0101..5d5f8b8fc 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -35,6 +35,12 @@ async def test_progress_token_zero_first_call(): # Verify progress notifications assert mock_session.send_progress_notification.call_count == 3, "All progress notifications should be sent" - mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=0.0, total=10.0, message=None) - mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=5.0, total=10.0, message=None) - mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=10.0, total=10.0, message=None) + mock_session.send_progress_notification.assert_any_call( + progress_token=0, progress=0.0, total=10.0, message=None, related_request_id="test-request" + ) + mock_session.send_progress_notification.assert_any_call( + progress_token=0, progress=5.0, total=10.0, message=None, related_request_id="test-request" + ) + mock_session.send_progress_notification.assert_any_call( + progress_token=0, progress=10.0, total=10.0, message=None, related_request_id="test-request" + ) diff --git a/tests/issues/test_342_base64_encoding.py b/tests/issues/test_342_base64_encoding.py index 44b17d337..2bccedf8d 100644 --- a/tests/issues/test_342_base64_encoding.py +++ b/tests/issues/test_342_base64_encoding.py @@ -1,83 +1,52 @@ """Test for base64 encoding issue in MCP server. -This test demonstrates the issue in server.py where the server uses -urlsafe_b64encode but the BlobResourceContents validator expects standard -base64 encoding. - -The test should FAIL before fixing server.py to use b64encode instead of -urlsafe_b64encode. -After the fix, the test should PASS. +This test verifies that binary resource data is encoded with standard base64 +(not urlsafe_b64encode), so BlobResourceContents validation succeeds. """ import base64 -from typing import cast import pytest -from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.server import Server -from mcp.types import ( - BlobResourceContents, - ReadResourceRequest, - ReadResourceRequestParams, - ReadResourceResult, - ServerResult, -) +from mcp import Client +from mcp.server.mcpserver import MCPServer +from mcp.types import BlobResourceContents +pytestmark = pytest.mark.anyio -@pytest.mark.anyio -async def test_server_base64_encoding_issue(): - """Tests that server response can be validated by BlobResourceContents. - This test will: - 1. Set up a server that returns binary data - 2. Extract the base64-encoded blob from the server's response - 3. Verify the encoded data can be properly validated by BlobResourceContents +async def test_server_base64_encoding(): + """Tests that binary resource data round-trips correctly through base64 encoding. - BEFORE FIX: The test will fail because server uses urlsafe_b64encode - AFTER FIX: The test will pass because server uses standard b64encode + The test uses binary data that produces different results with urlsafe vs standard + base64, ensuring the server uses standard encoding. """ - server = Server("test") + mcp = MCPServer("test") # Create binary data that will definitely result in + and / characters # when encoded with standard base64 binary_data = bytes(list(range(255)) * 4) - # Register a resource handler that returns our test data - @server.read_resource() - async def read_resource(uri: str) -> list[ReadResourceContents]: - return [ReadResourceContents(content=binary_data, mime_type="application/octet-stream")] - - # Get the handler directly from the server - handler = server.request_handlers[ReadResourceRequest] - - # Create a request - request = ReadResourceRequest( - params=ReadResourceRequestParams(uri="test://resource"), - ) - - # Call the handler to get the response - result: ServerResult = await handler(request) - - # After (fixed code): - read_result: ReadResourceResult = cast(ReadResourceResult, result) - blob_content = read_result.contents[0] - - # First verify our test data actually produces different encodings + # Sanity check: our test data produces different encodings urlsafe_b64 = base64.urlsafe_b64encode(binary_data).decode() standard_b64 = base64.b64encode(binary_data).decode() - assert urlsafe_b64 != standard_b64, "Test data doesn't demonstrate" - " encoding difference" + assert urlsafe_b64 != standard_b64, "Test data doesn't demonstrate encoding difference" + + @mcp.resource("test://binary", mime_type="application/octet-stream") + def get_binary() -> bytes: + """Return binary test data.""" + return binary_data + + async with Client(mcp) as client: + result = await client.read_resource("test://binary") + assert len(result.contents) == 1 - # Now validate the server's output with BlobResourceContents.model_validate - # Before the fix: This should fail with "Invalid base64" because server - # uses urlsafe_b64encode - # After the fix: This should pass because server will use standard b64encode - model_dict = blob_content.model_dump() + blob_content = result.contents[0] + assert isinstance(blob_content, BlobResourceContents) - # Direct validation - this will fail before fix, pass after fix - blob_model = BlobResourceContents.model_validate(model_dict) + # Verify standard base64 was used (not urlsafe) + assert blob_content.blob == standard_b64 - # Verify we can decode the data back correctly - decoded = base64.b64decode(blob_model.blob) - assert decoded == binary_data + # Verify we can decode the data back correctly + decoded = base64.b64decode(blob_content.blob) + assert decoded == binary_data diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index cd27698e6..6b593d2a5 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -1,8 +1,6 @@ """Test to reproduce issue #88: Random error thrown on response.""" -from collections.abc import Sequence from pathlib import Path -from typing import Any import anyio import pytest @@ -11,10 +9,10 @@ from mcp import types from mcp.client.session import ClientSession -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage -from mcp.types import ContentBlock, TextContent +from mcp.types import CallToolRequestParams, CallToolResult, ListToolsResult, PaginatedRequestParams, TextContent @pytest.mark.anyio @@ -32,36 +30,38 @@ async def test_notification_validation_error(tmp_path: Path): - Slow operations use minimal timeout (10ms) for quick test execution """ - server = Server(name="test") request_count = 0 slow_request_lock = anyio.Event() - @server.list_tools() - async def list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="slow", - description="A slow tool", - input_schema={"type": "object"}, - ), - types.Tool( - name="fast", - description="A fast tool", - input_schema={"type": "object"}, - ), - ] - - @server.call_tool() - async def slow_tool(name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock]: + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + types.Tool( + name="slow", + description="A slow tool", + input_schema={"type": "object"}, + ), + types.Tool( + name="fast", + description="A fast tool", + input_schema={"type": "object"}, + ), + ] + ) + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: nonlocal request_count request_count += 1 + assert params.name in ("slow", "fast"), f"Unknown tool: {params.name}" - if name == "slow": + if params.name == "slow": await slow_request_lock.wait() # it should timeout here - return [TextContent(type="text", text=f"slow {request_count}")] - elif name == "fast": - return [TextContent(type="text", text=f"fast {request_count}")] - return [TextContent(type="text", text=f"unknown {request_count}")] # pragma: no cover + text = f"slow {request_count}" + else: + text = f"fast {request_count}" + return CallToolResult(content=[TextContent(type="text", text=text)]) + + server = Server(name="test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) async def server_handler( read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], diff --git a/tests/server/auth/test_routes.py b/tests/server/auth/test_routes.py new file mode 100644 index 000000000..3d13b5ba5 --- /dev/null +++ b/tests/server/auth/test_routes.py @@ -0,0 +1,47 @@ +import pytest +from pydantic import AnyHttpUrl + +from mcp.server.auth.routes import validate_issuer_url + + +def test_validate_issuer_url_https_allowed(): + validate_issuer_url(AnyHttpUrl("https://example.com/path")) + + +def test_validate_issuer_url_http_localhost_allowed(): + validate_issuer_url(AnyHttpUrl("http://localhost:8080/path")) + + +def test_validate_issuer_url_http_127_0_0_1_allowed(): + validate_issuer_url(AnyHttpUrl("http://127.0.0.1:8080/path")) + + +def test_validate_issuer_url_http_ipv6_loopback_allowed(): + validate_issuer_url(AnyHttpUrl("http://[::1]:8080/path")) + + +def test_validate_issuer_url_http_non_loopback_rejected(): + with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): + validate_issuer_url(AnyHttpUrl("http://evil.com/path")) + + +def test_validate_issuer_url_http_127_prefix_domain_rejected(): + """A domain like 127.0.0.1.evil.com is not loopback.""" + with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): + validate_issuer_url(AnyHttpUrl("http://127.0.0.1.evil.com/path")) + + +def test_validate_issuer_url_http_127_prefix_subdomain_rejected(): + """A domain like 127.0.0.1something.example.com is not loopback.""" + with pytest.raises(ValueError, match="Issuer URL must be HTTPS"): + validate_issuer_url(AnyHttpUrl("http://127.0.0.1something.example.com/path")) + + +def test_validate_issuer_url_fragment_rejected(): + with pytest.raises(ValueError, match="fragment"): + validate_issuer_url(AnyHttpUrl("https://example.com/path#frag")) + + +def test_validate_issuer_url_query_rejected(): + with pytest.raises(ValueError, match="query"): + validate_issuer_url(AnyHttpUrl("https://example.com/path?q=1")) diff --git a/tests/server/lowlevel/test_func_inspection.py b/tests/server/lowlevel/test_func_inspection.py deleted file mode 100644 index 9cb2b561a..000000000 --- a/tests/server/lowlevel/test_func_inspection.py +++ /dev/null @@ -1,292 +0,0 @@ -"""Unit tests for func_inspection module. - -Tests the create_call_wrapper function which determines how to call handler functions -with different parameter signatures and type hints. -""" - -from typing import Any, Generic, TypeVar - -import pytest - -from mcp.server.lowlevel.func_inspection import create_call_wrapper -from mcp.types import ListPromptsRequest, ListResourcesRequest, ListToolsRequest, PaginatedRequestParams - -T = TypeVar("T") - - -@pytest.mark.anyio -async def test_no_params_returns_deprecated_wrapper() -> None: - """Test: def foo() - should call without request.""" - called_without_request = False - - async def handler() -> list[str]: - nonlocal called_without_request - called_without_request = True - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request - request = ListPromptsRequest(method="prompts/list", params=None) - result = await wrapper(request) - assert called_without_request is True - assert result == ["test"] - - -@pytest.mark.anyio -async def test_param_with_default_returns_deprecated_wrapper() -> None: - """Test: def foo(thing: int = 1) - should call without request.""" - called_without_request = False - - async def handler(thing: int = 1) -> list[str]: - nonlocal called_without_request - called_without_request = True - return [f"test-{thing}"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request (uses default value) - request = ListPromptsRequest(method="prompts/list", params=None) - result = await wrapper(request) - assert called_without_request is True - assert result == ["test-1"] - - -@pytest.mark.anyio -async def test_typed_request_param_passes_request() -> None: - """Test: def foo(req: ListPromptsRequest) - should pass request through.""" - received_request = None - - async def handler(req: ListPromptsRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler - request = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor="test-cursor")) - await wrapper(request) - - assert received_request is not None - assert received_request is request - params = getattr(received_request, "params", None) - assert params is not None - assert params.cursor == "test-cursor" - - -@pytest.mark.anyio -async def test_typed_request_with_default_param_passes_request() -> None: - """Test: def foo(req: ListPromptsRequest, thing: int = 1) - should pass request through.""" - received_request = None - received_thing = None - - async def handler(req: ListPromptsRequest, thing: int = 1) -> list[str]: - nonlocal received_request, received_thing - received_request = req - received_thing = thing - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler - request = ListPromptsRequest(method="prompts/list", params=None) - await wrapper(request) - - assert received_request is request - assert received_thing == 1 # default value - - -@pytest.mark.anyio -async def test_optional_typed_request_with_default_none_is_deprecated() -> None: - """Test: def foo(thing: int = 1, req: ListPromptsRequest | None = None) - old style.""" - called_without_request = False - - async def handler(thing: int = 1, req: ListPromptsRequest | None = None) -> list[str]: - nonlocal called_without_request - called_without_request = True - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request - request = ListPromptsRequest(method="prompts/list", params=None) - result = await wrapper(request) - assert called_without_request is True - assert result == ["test"] - - -@pytest.mark.anyio -async def test_untyped_request_param_is_deprecated() -> None: - """Test: def foo(req) - should call without request.""" - called = False - - async def handler(req): # type: ignore[no-untyped-def] # pyright: ignore[reportMissingParameterType] # pragma: no cover - nonlocal called - called = True - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) # pyright: ignore[reportUnknownArgumentType] - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_any_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: Any) - should call without request.""" - - async def handler(req: Any) -> list[str]: # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_generic_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: Generic[T]) - should call without request.""" - - async def handler(req: Generic[T]) -> list[str]: # pyright: ignore[reportGeneralTypeIssues] # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_wrong_typed_request_param_is_deprecated() -> None: - """Test: def foo(req: str) - should call without request.""" - - async def handler(req: str) -> list[str]: # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should call handler without passing request, which will fail because req is required - request = ListPromptsRequest(method="prompts/list", params=None) - # This will raise TypeError because handler expects 'req' but wrapper doesn't provide it - with pytest.raises(TypeError, match="missing 1 required positional argument"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_required_param_before_typed_request_attempts_to_pass() -> None: - """Test: def foo(thing: int, req: ListPromptsRequest) - attempts to pass request (will fail at runtime).""" - received_request = None - - async def handler(thing: int, req: ListPromptsRequest) -> list[str]: # pragma: no cover - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper will attempt to pass request, but it will fail at runtime - # because 'thing' is required and has no default - request = ListPromptsRequest(method="prompts/list", params=None) - - # This will raise TypeError because 'thing' is missing - with pytest.raises(TypeError, match="missing 1 required positional argument: 'thing'"): - await wrapper(request) - - -@pytest.mark.anyio -async def test_positional_only_param_with_correct_type() -> None: - """Test: def foo(req: ListPromptsRequest, /) - should pass request through.""" - received_request = None - - async def handler(req: ListPromptsRequest, /) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler - request = ListPromptsRequest(method="prompts/list", params=None) - await wrapper(request) - - assert received_request is request - - -@pytest.mark.anyio -async def test_keyword_only_param_with_correct_type() -> None: - """Test: def foo(*, req: ListPromptsRequest) - should pass request through.""" - received_request = None - - async def handler(*, req: ListPromptsRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Wrapper should pass request to handler with keyword argument - request = ListPromptsRequest(method="prompts/list", params=None) - await wrapper(request) - - assert received_request is request - - -@pytest.mark.anyio -async def test_different_request_types() -> None: - """Test that wrapper works with different request types.""" - # Test with ListResourcesRequest - received_request = None - - async def handler(req: ListResourcesRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper = create_call_wrapper(handler, ListResourcesRequest) - - request = ListResourcesRequest(method="resources/list", params=None) - await wrapper(request) - - assert received_request is request - - # Test with ListToolsRequest - received_request = None - - async def handler2(req: ListToolsRequest) -> list[str]: - nonlocal received_request - received_request = req - return ["test"] - - wrapper2 = create_call_wrapper(handler2, ListToolsRequest) - - request2 = ListToolsRequest(method="tools/list", params=None) - await wrapper2(request2) - - assert received_request is request2 - - -@pytest.mark.anyio -async def test_mixed_params_with_typed_request() -> None: - """Test: def foo(a: str, req: ListPromptsRequest, b: int = 5) - attempts to pass request.""" - - async def handler(a: str, req: ListPromptsRequest, b: int = 5) -> list[str]: # pragma: no cover - return ["test"] - - wrapper = create_call_wrapper(handler, ListPromptsRequest) - - # Will fail at runtime due to missing 'a' - request = ListPromptsRequest(method="prompts/list", params=None) - - with pytest.raises(TypeError, match="missing 1 required positional argument: 'a'"): - await wrapper(request) diff --git a/tests/server/lowlevel/test_server_listing.py b/tests/server/lowlevel/test_server_listing.py index 6bf4cddb3..2c3d303a9 100644 --- a/tests/server/lowlevel/test_server_listing.py +++ b/tests/server/lowlevel/test_server_listing.py @@ -1,20 +1,16 @@ -"""Basic tests for list_prompts, list_resources, and list_tools decorators without pagination.""" - -import warnings +"""Basic tests for list_prompts, list_resources, and list_tools handlers without pagination.""" import pytest -from mcp.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.types import ( - ListPromptsRequest, ListPromptsResult, - ListResourcesRequest, ListResourcesResult, - ListToolsRequest, ListToolsResult, + PaginatedRequestParams, Prompt, Resource, - ServerResult, Tool, ) @@ -22,60 +18,44 @@ @pytest.mark.anyio async def test_list_prompts_basic() -> None: """Test basic prompt listing without pagination.""" - server = Server("test") - test_prompts = [ Prompt(name="prompt1", description="First prompt"), Prompt(name="prompt2", description="Second prompt"), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return test_prompts - - handler = server.request_handlers[ListPromptsRequest] - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) + async def handle_list_prompts( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: + return ListPromptsResult(prompts=test_prompts) - assert isinstance(result, ServerResult) - assert isinstance(result, ListPromptsResult) - assert result.prompts == test_prompts + server = Server("test", on_list_prompts=handle_list_prompts) + async with Client(server) as client: + result = await client.list_prompts() + assert result.prompts == test_prompts @pytest.mark.anyio async def test_list_resources_basic() -> None: """Test basic resource listing without pagination.""" - server = Server("test") - test_resources = [ Resource(uri="file:///test1.txt", name="Test 1"), Resource(uri="file:///test2.txt", name="Test 2"), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_resources() - async def handle_list_resources() -> list[Resource]: - return test_resources + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=test_resources) - handler = server.request_handlers[ListResourcesRequest] - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) - assert isinstance(result, ListResourcesResult) - assert result.resources == test_resources + server = Server("test", on_list_resources=handle_list_resources) + async with Client(server) as client: + result = await client.list_resources() + assert result.resources == test_resources @pytest.mark.anyio async def test_list_tools_basic() -> None: """Test basic tool listing without pagination.""" - server = Server("test") - test_tools = [ Tool( name="tool1", @@ -102,80 +82,53 @@ async def test_list_tools_basic() -> None: ), ] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=test_tools) - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return test_tools - - handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) - assert isinstance(result, ListToolsResult) - assert result.tools == test_tools + server = Server("test", on_list_tools=handle_list_tools) + async with Client(server) as client: + result = await client.list_tools() + assert result.tools == test_tools @pytest.mark.anyio async def test_list_prompts_empty() -> None: """Test listing with empty results.""" - server = Server("test") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_prompts() - async def handle_list_prompts() -> list[Prompt]: - return [] - handler = server.request_handlers[ListPromptsRequest] - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) + async def handle_list_prompts( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: + return ListPromptsResult(prompts=[]) - assert isinstance(result, ServerResult) - assert isinstance(result, ListPromptsResult) - assert result.prompts == [] + server = Server("test", on_list_prompts=handle_list_prompts) + async with Client(server) as client: + result = await client.list_prompts() + assert result.prompts == [] @pytest.mark.anyio async def test_list_resources_empty() -> None: """Test listing with empty results.""" - server = Server("test") - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult(resources=[]) - @server.list_resources() - async def handle_list_resources() -> list[Resource]: - return [] - - handler = server.request_handlers[ListResourcesRequest] - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - - assert isinstance(result, ServerResult) - assert isinstance(result, ListResourcesResult) - assert result.resources == [] + server = Server("test", on_list_resources=handle_list_resources) + async with Client(server) as client: + result = await client.list_resources() + assert result.resources == [] @pytest.mark.anyio async def test_list_tools_empty() -> None: """Test listing with empty results.""" - server = Server("test") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return [] - handler = server.request_handlers[ListToolsRequest] - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[]) - assert isinstance(result, ServerResult) - assert isinstance(result, ListToolsResult) - assert result.tools == [] + server = Server("test", on_list_tools=handle_list_tools) + async with Client(server) as client: + result = await client.list_tools() + assert result.tools == [] diff --git a/tests/server/lowlevel/test_server_pagination.py b/tests/server/lowlevel/test_server_pagination.py index 081fb262a..a4627b316 100644 --- a/tests/server/lowlevel/test_server_pagination.py +++ b/tests/server/lowlevel/test_server_pagination.py @@ -1,111 +1,83 @@ import pytest -from mcp.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.types import ( - ListPromptsRequest, ListPromptsResult, - ListResourcesRequest, ListResourcesResult, - ListToolsRequest, ListToolsResult, PaginatedRequestParams, - ServerResult, ) @pytest.mark.anyio async def test_list_prompts_pagination() -> None: - server = Server("test") test_cursor = "test-cursor-123" + received_params: PaginatedRequestParams | None = None - # Track what request was received - received_request: ListPromptsRequest | None = None - - @server.list_prompts() - async def handle_list_prompts(request: ListPromptsRequest) -> ListPromptsResult: - nonlocal received_request - received_request = request + async def handle_list_prompts( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListPromptsResult: + nonlocal received_params + received_params = params return ListPromptsResult(prompts=[], next_cursor="next") - handler = server.request_handlers[ListPromptsRequest] - - # Test: No cursor provided -> handler receives request with None params - request = ListPromptsRequest(method="prompts/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) + server = Server("test", on_list_prompts=handle_list_prompts) + async with Client(server) as client: + # No cursor provided + await client.list_prompts() + assert received_params is not None + assert received_params.cursor is None - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListPromptsRequest(method="prompts/list", params=PaginatedRequestParams(cursor=test_cursor)) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + # Cursor provided + await client.list_prompts(cursor=test_cursor) + assert received_params is not None + assert received_params.cursor == test_cursor @pytest.mark.anyio async def test_list_resources_pagination() -> None: - server = Server("test") test_cursor = "resource-cursor-456" + received_params: PaginatedRequestParams | None = None - # Track what request was received - received_request: ListResourcesRequest | None = None - - @server.list_resources() - async def handle_list_resources(request: ListResourcesRequest) -> ListResourcesResult: - nonlocal received_request - received_request = request + async def handle_list_resources( + ctx: ServerRequestContext, params: PaginatedRequestParams | None + ) -> ListResourcesResult: + nonlocal received_params + received_params = params return ListResourcesResult(resources=[], next_cursor="next") - handler = server.request_handlers[ListResourcesRequest] + server = Server("test", on_list_resources=handle_list_resources) + async with Client(server) as client: + # No cursor provided + await client.list_resources() + assert received_params is not None + assert received_params.cursor is None - # Test: No cursor provided -> handler receives request with None params - request = ListResourcesRequest(method="resources/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) - - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListResourcesRequest( - method="resources/list", params=PaginatedRequestParams(cursor=test_cursor) - ) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + # Cursor provided + await client.list_resources(cursor=test_cursor) + assert received_params is not None + assert received_params.cursor == test_cursor @pytest.mark.anyio async def test_list_tools_pagination() -> None: - server = Server("test") test_cursor = "tools-cursor-789" + received_params: PaginatedRequestParams | None = None - # Track what request was received - received_request: ListToolsRequest | None = None - - @server.list_tools() - async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: - nonlocal received_request - received_request = request + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + nonlocal received_params + received_params = params return ListToolsResult(tools=[], next_cursor="next") - handler = server.request_handlers[ListToolsRequest] - - # Test: No cursor provided -> handler receives request with None params - request = ListToolsRequest(method="tools/list", params=None) - result = await handler(request) - assert received_request is not None - assert received_request.params is None - assert isinstance(result, ServerResult) - - # Test: Cursor provided -> handler receives request with cursor in params - request_with_cursor = ListToolsRequest(method="tools/list", params=PaginatedRequestParams(cursor=test_cursor)) - result2 = await handler(request_with_cursor) - assert received_request is not None - assert received_request.params is not None - assert received_request.params.cursor == test_cursor - assert isinstance(result2, ServerResult) + server = Server("test", on_list_tools=handle_list_tools) + async with Client(server) as client: + # No cursor provided + await client.list_tools() + assert received_params is not None + assert received_params.cursor is None + + # Cursor provided + await client.list_tools(cursor=test_cursor) + assert received_params is not None + assert received_params.cursor == test_cursor diff --git a/tests/server/mcpserver/auth/test_auth_integration.py b/tests/server/mcpserver/auth/test_auth_integration.py index a78a86cf0..602f5cc75 100644 --- a/tests/server/mcpserver/auth/test_auth_integration.py +++ b/tests/server/mcpserver/auth/test_auth_integration.py @@ -21,7 +21,8 @@ RefreshToken, construct_redirect_uri, ) -from mcp.server.auth.routes import ClientRegistrationOptions, RevocationOptions, create_auth_routes +from mcp.server.auth.routes import create_auth_routes +from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions from mcp.shared.auth import OAuthClientInformationFull, OAuthToken diff --git a/tests/server/mcpserver/prompts/test_base.py b/tests/server/mcpserver/prompts/test_base.py index 035e1cc81..553e47363 100644 --- a/tests/server/mcpserver/prompts/test_base.py +++ b/tests/server/mcpserver/prompts/test_base.py @@ -2,8 +2,8 @@ import pytest -from mcp.server.mcpserver.prompts.base import AssistantMessage, Message, Prompt, TextContent, UserMessage -from mcp.types import EmbeddedResource, TextResourceContents +from mcp.server.mcpserver.prompts.base import AssistantMessage, Message, Prompt, UserMessage +from mcp.types import EmbeddedResource, TextContent, TextResourceContents class TestRenderPrompt: diff --git a/tests/server/mcpserver/prompts/test_manager.py b/tests/server/mcpserver/prompts/test_manager.py index 0e30b2e69..02f91c680 100644 --- a/tests/server/mcpserver/prompts/test_manager.py +++ b/tests/server/mcpserver/prompts/test_manager.py @@ -1,7 +1,8 @@ import pytest -from mcp.server.mcpserver.prompts.base import Prompt, TextContent, UserMessage +from mcp.server.mcpserver.prompts.base import Prompt, UserMessage from mcp.server.mcpserver.prompts.manager import PromptManager +from mcp.types import TextContent class TestPromptManager: diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 979dc580f..cfbe6587b 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1,7 +1,7 @@ import base64 from pathlib import Path from typing import Any -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from inline_snapshot import snapshot @@ -10,6 +10,8 @@ from starlette.routing import Mount, Route from mcp.client import Client +from mcp.server.context import ServerRequestContext +from mcp.server.experimental.request_context import Experimental from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.prompts.base import Message, UserMessage @@ -21,6 +23,9 @@ from mcp.types import ( AudioContent, BlobResourceContents, + Completion, + CompletionArgument, + CompletionContext, ContentBlock, EmbeddedResource, GetPromptResult, @@ -30,6 +35,7 @@ Prompt, PromptArgument, PromptMessage, + PromptReference, ReadResourceResult, Resource, ResourceTemplate, @@ -1401,6 +1407,23 @@ def prompt_fn(name: str) -> str: ... # pragma: no branch await client.get_prompt("prompt_fn") +async def test_completion_decorator() -> None: + """Test that the completion decorator registers a working handler.""" + mcp = MCPServer() + + @mcp.completion() + async def handle_completion( + ref: PromptReference, argument: CompletionArgument, context: CompletionContext | None + ) -> Completion: + assert argument.name == "style" + return Completion(values=["bold", "italic", "underline"]) + + async with Client(mcp) as client: + ref = PromptReference(type="ref/prompt", name="test") + result = await client.complete(ref=ref, argument={"name": "style", "value": "b"}) + assert result.completion.values == ["bold", "italic", "underline"] + + def test_streamable_http_no_redirect() -> None: """Test that streamable HTTP routes are correctly configured.""" mcp = MCPServer() @@ -1415,3 +1438,34 @@ def test_streamable_http_no_redirect() -> None: # Verify path values assert streamable_routes[0].path == "/mcp", "Streamable route path should be /mcp" + + +async def test_report_progress_passes_related_request_id(): + """Test that report_progress passes the request_id as related_request_id. + + Without related_request_id, the streamable HTTP transport cannot route + progress notifications to the correct SSE stream, causing them to be + silently dropped. See #953 and #2001. + """ + mock_session = AsyncMock() + mock_session.send_progress_notification = AsyncMock() + + request_context = ServerRequestContext( + request_id="req-abc-123", + session=mock_session, + meta={"progress_token": "tok-1"}, + lifespan_context=None, + experimental=Experimental(), + ) + + ctx = Context(request_context=request_context, mcp_server=MagicMock()) + + await ctx.report_progress(50, 100, message="halfway") + + mock_session.send_progress_notification.assert_awaited_once_with( + progress_token="tok-1", + progress=50, + total=100, + message="halfway", + related_request_id="req-abc-123", + ) diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 6d1634f2e..297f3d6a5 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -1,12 +1,10 @@ """Test that cancelled requests don't cause double responses.""" -from typing import Any - import anyio import pytest -from mcp import Client, types -from mcp.server.lowlevel.server import Server +from mcp import Client +from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.types import ( CallToolRequest, @@ -14,6 +12,9 @@ CallToolResult, CancelledNotification, CancelledNotificationParams, + ListToolsResult, + PaginatedRequestParams, + TextContent, Tool, ) @@ -22,34 +23,34 @@ async def test_server_remains_functional_after_cancel(): """Verify server can handle new requests after a cancellation.""" - server = Server("test-server") - # Track tool calls call_count = 0 ev_first_call = anyio.Event() first_request_id = None - @server.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="Tool for testing", - input_schema={}, - ) - ] + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="Tool for testing", + input_schema={}, + ) + ] + ) - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: nonlocal call_count, first_request_id - if name == "test_tool": + if params.name == "test_tool": call_count += 1 if call_count == 1: - first_request_id = server.request_context.request_id + first_request_id = ctx.request_id ev_first_call.set() await anyio.sleep(5) # First call is slow - return [types.TextContent(type="text", text=f"Call number: {call_count}")] - raise ValueError(f"Unknown tool: {name}") # pragma: no cover + return CallToolResult(content=[TextContent(type="text", text=f"Call number: {call_count}")]) + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + server = Server("test-server", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) async with Client(server) as client: # First request (will be cancelled) @@ -86,6 +87,6 @@ async def first_request(): # Type narrowing for pyright content = result.content[0] assert content.type == "text" - assert isinstance(content, types.TextContent) + assert isinstance(content, TextContent) assert content.text == "Call number: 2" assert call_count == 2 diff --git a/tests/server/test_completion_with_context.py b/tests/server/test_completion_with_context.py index 5a8d67f09..a01d0d4d7 100644 --- a/tests/server/test_completion_with_context.py +++ b/tests/server/test_completion_with_context.py @@ -1,15 +1,13 @@ """Tests for completion handler with context functionality.""" -from typing import Any - import pytest from mcp import Client -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext from mcp.types import ( + CompleteRequestParams, + CompleteResult, Completion, - CompletionArgument, - CompletionContext, PromptReference, ResourceTemplateReference, ) @@ -18,23 +16,15 @@ @pytest.mark.anyio async def test_completion_handler_receives_context(): """Test that the completion handler receives context correctly.""" - server = Server("test-server") - # Track what the handler receives - received_args: dict[str, Any] = {} + received_params: CompleteRequestParams | None = None - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - received_args["ref"] = ref - received_args["argument"] = argument - received_args["context"] = context + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: + nonlocal received_params + received_params = params + return CompleteResult(completion=Completion(values=["test-completion"], total=1, has_more=False)) - # Return test completion - return Completion(values=["test-completion"], total=1, has_more=False) + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Test with context @@ -45,28 +35,23 @@ async def handle_completion( ) # Verify handler received the context - assert received_args["context"] is not None - assert received_args["context"].arguments == {"previous": "value"} + assert received_params is not None + assert received_params.context is not None + assert received_params.context.arguments == {"previous": "value"} assert result.completion.values == ["test-completion"] @pytest.mark.anyio async def test_completion_backward_compatibility(): """Test that completion works without context (backward compatibility).""" - server = Server("test-server") - context_was_none = False - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: nonlocal context_was_none - context_was_none = context is None + context_was_none = params.context is None + return CompleteResult(completion=Completion(values=["no-context-completion"], total=1, has_more=False)) - return Completion(values=["no-context-completion"], total=1, has_more=False) + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Test without context @@ -82,30 +67,31 @@ async def handle_completion( @pytest.mark.anyio async def test_dependent_completion_scenario(): """Test a real-world scenario with dependent completions.""" - server = Server("test-server") - - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: + + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: # Simulate database/table completion scenario - if isinstance(ref, ResourceTemplateReference): - if ref.uri == "db://{database}/{table}": - if argument.name == "database": - # Complete database names - return Completion(values=["users_db", "products_db", "analytics_db"], total=3, has_more=False) - elif argument.name == "table": - # Complete table names based on selected database - if context and context.arguments: - db = context.arguments.get("database") - if db == "users_db": - return Completion(values=["users", "sessions", "permissions"], total=3, has_more=False) - elif db == "products_db": - return Completion(values=["products", "categories", "inventory"], total=3, has_more=False) - - return Completion(values=[], total=0, has_more=False) # pragma: no cover + assert isinstance(params.ref, ResourceTemplateReference) + assert params.ref.uri == "db://{database}/{table}" + + if params.argument.name == "database": + return CompleteResult( + completion=Completion(values=["users_db", "products_db", "analytics_db"], total=3, has_more=False) + ) + + assert params.argument.name == "table" + assert params.context and params.context.arguments + db = params.context.arguments.get("database") + if db == "users_db": + return CompleteResult( + completion=Completion(values=["users", "sessions", "permissions"], total=3, has_more=False) + ) + else: + assert db == "products_db" + return CompleteResult( + completion=Completion(values=["products", "categories", "inventory"], total=3, has_more=False) + ) + + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # First, complete database @@ -136,27 +122,20 @@ async def handle_completion( @pytest.mark.anyio async def test_completion_error_on_missing_context(): """Test that server can raise error when required context is missing.""" - server = Server("test-server") - - @server.completion() - async def handle_completion( - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - if isinstance(ref, ResourceTemplateReference): - if ref.uri == "db://{database}/{table}": - if argument.name == "table": - # Check if database context is provided - if not context or not context.arguments or "database" not in context.arguments: - # Raise an error instead of returning error as completion - raise ValueError("Please select a database first to see available tables") - # Normal completion if context is provided - db = context.arguments.get("database") - if db == "test_db": - return Completion(values=["users", "orders", "products"], total=3, has_more=False) - - return Completion(values=[], total=0, has_more=False) # pragma: no cover + + async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, ResourceTemplateReference) + assert params.ref.uri == "db://{database}/{table}" + assert params.argument.name == "table" + + if not params.context or not params.context.arguments or "database" not in params.context.arguments: + raise ValueError("Please select a database first to see available tables") + + db = params.context.arguments.get("database") + assert db == "test_db" + return CompleteResult(completion=Completion(values=["users", "orders", "products"], total=3, has_more=False)) + + server = Server("test-server", on_completion=handle_completion) async with Client(server) as client: # Try to complete table without database context - should raise error diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index a303664a5..0f8840d29 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -2,18 +2,20 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any import anyio import pytest from pydantic import TypeAdapter +from mcp.server import ServerRequestContext from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.mcpserver import Context, MCPServer from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage from mcp.types import ( + CallToolRequestParams, + CallToolResult, ClientCapabilities, Implementation, InitializeRequestParams, @@ -39,20 +41,20 @@ async def test_lifespan(server: Server) -> AsyncIterator[dict[str, bool]]: finally: context["shutdown"] = True - server = Server[dict[str, bool]]("test", lifespan=test_lifespan) - - # Create memory streams for testing - send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) - send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) - # Create a tool that accesses lifespan context - @server.call_tool() - async def check_lifespan(name: str, arguments: dict[str, Any]) -> list[TextContent]: - ctx = server.request_context + async def check_lifespan( + ctx: ServerRequestContext[dict[str, bool]], params: CallToolRequestParams + ) -> CallToolResult: assert isinstance(ctx.lifespan_context, dict) assert ctx.lifespan_context["started"] assert not ctx.lifespan_context["shutdown"] - return [TextContent(type="text", text="true")] + return CallToolResult(content=[TextContent(type="text", text="true")]) + + server = Server[dict[str, bool]]("test", lifespan=test_lifespan, on_call_tool=check_lifespan) + + # Create memory streams for testing + send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) + send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) # Run server in background task async with anyio.create_task_group() as tg, send_stream1, receive_stream1, send_stream2, receive_stream2: diff --git a/tests/server/test_lowlevel_input_validation.py b/tests/server/test_lowlevel_input_validation.py deleted file mode 100644 index 3f977bcc1..000000000 --- a/tests/server/test_lowlevel_input_validation.py +++ /dev/null @@ -1,311 +0,0 @@ -"""Test input schema validation for lowlevel server.""" - -import logging -from collections.abc import Awaitable, Callable -from typing import Any - -import anyio -import pytest - -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import CallToolResult, ClientResult, ServerNotification, ServerRequest, TextContent, Tool - - -async def run_tool_test( - tools: list[Tool], - call_tool_handler: Callable[[str, dict[str, Any]], Awaitable[list[TextContent]]], - test_callback: Callable[[ClientSession], Awaitable[CallToolResult]], -) -> CallToolResult | None: - """Helper to run a tool test with minimal boilerplate. - - Args: - tools: List of tools to register - call_tool_handler: Handler function for tool calls - test_callback: Async function that performs the test using the client session - - Returns: - The result of the tool call - """ - server = Server("test") - result = None - - @server.list_tools() - async def list_tools(): - return tools - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: - return await call_tool_handler(name, arguments) - - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Message handler for client - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): # pragma: no cover - raise message - - # Server task - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - # Run the test - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - # Initialize the session - await client_session.initialize() - - # Run the test callback - result = await test_callback(client_session) - - # Cancel the server task - tg.cancel_scope.cancel() - - return result - - -def create_add_tool() -> Tool: - """Create a standard 'add' tool for testing.""" - return Tool( - name="add", - description="Add two numbers", - input_schema={ - "type": "object", - "properties": { - "a": {"type": "number"}, - "b": {"type": "number"}, - }, - "required": ["a", "b"], - "additionalProperties": False, - }, - ) - - -@pytest.mark.anyio -async def test_valid_tool_call(): - """Test that valid arguments pass validation.""" - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - if name == "add": - result = arguments["a"] + arguments["b"] - return [TextContent(type="text", text=f"Result: {result}")] - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("add", {"a": 5, "b": 3}) - - result = await run_tool_test([create_add_tool()], call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Result: 8" - - -@pytest.mark.anyio -async def test_invalid_tool_call_missing_required(): - """Test that missing required arguments fail validation.""" - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover - # This should not be reached due to validation - raise RuntimeError("Should not reach here") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("add", {"a": 5}) # missing 'b' - - result = await run_tool_test([create_add_tool()], call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Input validation error" in result.content[0].text - assert "'b' is a required property" in result.content[0].text - - -@pytest.mark.anyio -async def test_invalid_tool_call_wrong_type(): - """Test that wrong argument types fail validation.""" - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover - # This should not be reached due to validation - raise RuntimeError("Should not reach here") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("add", {"a": "five", "b": 3}) # 'a' should be number - - result = await run_tool_test([create_add_tool()], call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Input validation error" in result.content[0].text - assert "'five' is not of type 'number'" in result.content[0].text - - -@pytest.mark.anyio -async def test_cache_refresh_on_missing_tool(): - """Test that tool cache is refreshed when tool is not found.""" - tools = [ - Tool( - name="multiply", - description="Multiply two numbers", - input_schema={ - "type": "object", - "properties": { - "x": {"type": "number"}, - "y": {"type": "number"}, - }, - "required": ["x", "y"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - if name == "multiply": - result = arguments["x"] * arguments["y"] - return [TextContent(type="text", text=f"Result: {result}")] - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - # Call tool without first listing tools (cache should be empty) - # The cache should be refreshed automatically - return await client_session.call_tool("multiply", {"x": 10, "y": 20}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - should work because cache will be refreshed - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Result: 200" - - -@pytest.mark.anyio -async def test_enum_constraint_validation(): - """Test that enum constraints are validated.""" - tools = [ - Tool( - name="greet", - description="Greet someone", - input_schema={ - "type": "object", - "properties": { - "name": {"type": "string"}, - "title": {"type": "string", "enum": ["Mr", "Ms", "Dr"]}, - }, - "required": ["name"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover - # This should not be reached due to validation failure - raise RuntimeError("Should not reach here") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("greet", {"name": "Smith", "title": "Prof"}) # Invalid title - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Input validation error" in result.content[0].text - assert "'Prof' is not one of" in result.content[0].text - - -@pytest.mark.anyio -async def test_tool_not_in_list_logs_warning(caplog: pytest.LogCaptureFixture): - """Test that calling a tool not in list_tools logs a warning and skips validation.""" - tools = [ - Tool( - name="add", - description="Add two numbers", - input_schema={ - "type": "object", - "properties": { - "a": {"type": "number"}, - "b": {"type": "number"}, - }, - "required": ["a", "b"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - # This should be reached since validation is skipped for unknown tools - if name == "unknown_tool": - # Even with invalid arguments, this should execute since validation is skipped - return [TextContent(type="text", text="Unknown tool executed without validation")] - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - # Call a tool that's not in the list with invalid arguments - # This should trigger the warning about validation not being performed - return await client_session.call_tool("unknown_tool", {"invalid": "args"}) - - with caplog.at_level(logging.WARNING): - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - should succeed because validation is skipped for unknown tools - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Unknown tool executed without validation" - - # Verify warning was logged - assert any( - "Tool 'unknown_tool' not listed, no validation will be performed" in record.message for record in caplog.records - ) diff --git a/tests/server/test_lowlevel_output_validation.py b/tests/server/test_lowlevel_output_validation.py deleted file mode 100644 index 92d9c047c..000000000 --- a/tests/server/test_lowlevel_output_validation.py +++ /dev/null @@ -1,476 +0,0 @@ -"""Test output schema validation for lowlevel server.""" - -import json -from collections.abc import Awaitable, Callable -from typing import Any - -import anyio -import pytest - -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import CallToolResult, ClientResult, ServerNotification, ServerRequest, TextContent, Tool - - -async def run_tool_test( - tools: list[Tool], - call_tool_handler: Callable[[str, dict[str, Any]], Awaitable[Any]], - test_callback: Callable[[ClientSession], Awaitable[CallToolResult]], -) -> CallToolResult | None: - """Helper to run a tool test with minimal boilerplate. - - Args: - tools: List of tools to register - call_tool_handler: Handler function for tool calls - test_callback: Async function that performs the test using the client session - - Returns: - The result of the tool call - """ - server = Server("test") - - result = None - - @server.list_tools() - async def list_tools(): - return tools - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - return await call_tool_handler(name, arguments) - - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Message handler for client - async def message_handler( # pragma: no cover - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): - raise message - - # Server task - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - # Run the test - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - # Initialize the session - await client_session.initialize() - - # Run the test callback - result = await test_callback(client_session) - - # Cancel the server task - tg.cancel_scope.cancel() - - return result - - -@pytest.mark.anyio -async def test_content_only_without_output_schema(): - """Test returning content only when no outputSchema is defined.""" - tools = [ - Tool( - name="echo", - description="Echo a message", - input_schema={ - "type": "object", - "properties": { - "message": {"type": "string"}, - }, - "required": ["message"], - }, - # No outputSchema defined - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - if name == "echo": - return [TextContent(type="text", text=f"Echo: {arguments['message']}")] - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("echo", {"message": "Hello"}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Echo: Hello" - assert result.structured_content is None - - -@pytest.mark.anyio -async def test_dict_only_without_output_schema(): - """Test returning dict only when no outputSchema is defined.""" - tools = [ - Tool( - name="get_info", - description="Get structured information", - input_schema={ - "type": "object", - "properties": {}, - }, - # No outputSchema defined - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "get_info": - return {"status": "ok", "data": {"value": 42}} - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("get_info", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - # Check that the content is the JSON serialization - assert json.loads(result.content[0].text) == {"status": "ok", "data": {"value": 42}} - assert result.structured_content == {"status": "ok", "data": {"value": 42}} - - -@pytest.mark.anyio -async def test_both_content_and_dict_without_output_schema(): - """Test returning both content and dict when no outputSchema is defined.""" - tools = [ - Tool( - name="process", - description="Process data", - input_schema={ - "type": "object", - "properties": {}, - }, - # No outputSchema defined - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> tuple[list[TextContent], dict[str, Any]]: - if name == "process": - content = [TextContent(type="text", text="Processing complete")] - data = {"result": "success", "count": 10} - return (content, data) - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("process", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Processing complete" - assert result.structured_content == {"result": "success", "count": 10} - - -@pytest.mark.anyio -async def test_content_only_with_output_schema_error(): - """Test error when outputSchema is defined but only content is returned.""" - tools = [ - Tool( - name="structured_tool", - description="Tool expecting structured output", - input_schema={ - "type": "object", - "properties": {}, - }, - output_schema={ - "type": "object", - "properties": { - "result": {"type": "string"}, - }, - "required": ["result"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - # This returns only content, but outputSchema expects structured data - return [TextContent(type="text", text="This is not structured")] - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("structured_tool", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify error - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Output validation error: outputSchema defined but no structured output returned" in result.content[0].text - - -@pytest.mark.anyio -async def test_valid_dict_with_output_schema(): - """Test valid dict output matching outputSchema.""" - tools = [ - Tool( - name="calc", - description="Calculate result", - input_schema={ - "type": "object", - "properties": { - "x": {"type": "number"}, - "y": {"type": "number"}, - }, - "required": ["x", "y"], - }, - output_schema={ - "type": "object", - "properties": { - "sum": {"type": "number"}, - "product": {"type": "number"}, - }, - "required": ["sum", "product"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "calc": - x = arguments["x"] - y = arguments["y"] - return {"sum": x + y, "product": x * y} - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("calc", {"x": 3, "y": 4}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - # Check JSON serialization - assert json.loads(result.content[0].text) == {"sum": 7, "product": 12} - assert result.structured_content == {"sum": 7, "product": 12} - - -@pytest.mark.anyio -async def test_invalid_dict_with_output_schema(): - """Test dict output that doesn't match outputSchema.""" - tools = [ - Tool( - name="user_info", - description="Get user information", - input_schema={ - "type": "object", - "properties": {}, - }, - output_schema={ - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"}, - }, - "required": ["name", "age"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "user_info": - # Missing required 'age' field - return {"name": "Alice"} - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("user_info", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify error - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert "Output validation error:" in result.content[0].text - assert "'age' is a required property" in result.content[0].text - - -@pytest.mark.anyio -async def test_both_content_and_valid_dict_with_output_schema(): - """Test returning both content and valid dict with outputSchema.""" - tools = [ - Tool( - name="analyze", - description="Analyze data", - input_schema={ - "type": "object", - "properties": { - "text": {"type": "string"}, - }, - "required": ["text"], - }, - output_schema={ - "type": "object", - "properties": { - "sentiment": {"type": "string", "enum": ["positive", "negative", "neutral"]}, - "confidence": {"type": "number", "minimum": 0, "maximum": 1}, - }, - "required": ["sentiment", "confidence"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> tuple[list[TextContent], dict[str, Any]]: - if name == "analyze": - content = [TextContent(type="text", text=f"Analysis of: {arguments['text']}")] - data = {"sentiment": "positive", "confidence": 0.95} - return (content, data) - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("analyze", {"text": "Great job!"}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert result.content[0].text == "Analysis of: Great job!" - assert result.structured_content == {"sentiment": "positive", "confidence": 0.95} - - -@pytest.mark.anyio -async def test_tool_call_result(): - """Test returning ToolCallResult when no outputSchema is defined.""" - tools = [ - Tool( - name="get_info", - description="Get structured information", - input_schema={ - "type": "object", - "properties": {}, - }, - # No outputSchema for direct return of tool call result - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> CallToolResult: - if name == "get_info": - return CallToolResult( - content=[TextContent(type="text", text="Results calculated")], - structured_content={"status": "ok", "data": {"value": 42}}, - _meta={"some": "metadata"}, - ) - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("get_info", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify results - assert result is not None - assert not result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert result.content[0].text == "Results calculated" - assert isinstance(result.content[0], TextContent) - assert result.structured_content == {"status": "ok", "data": {"value": 42}} - assert result.meta == {"some": "metadata"} - - -@pytest.mark.anyio -async def test_output_schema_type_validation(): - """Test outputSchema validates types correctly.""" - tools = [ - Tool( - name="stats", - description="Get statistics", - input_schema={ - "type": "object", - "properties": {}, - }, - output_schema={ - "type": "object", - "properties": { - "count": {"type": "integer"}, - "average": {"type": "number"}, - "items": {"type": "array", "items": {"type": "string"}}, - }, - "required": ["count", "average", "items"], - }, - ) - ] - - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - if name == "stats": - # Wrong type for 'count' - should be integer - return {"count": "five", "average": 2.5, "items": ["a", "b"]} - else: # pragma: no cover - raise ValueError(f"Unknown tool: {name}") - - async def test_callback(client_session: ClientSession) -> CallToolResult: - return await client_session.call_tool("stats", {}) - - result = await run_tool_test(tools, call_tool_handler, test_callback) - - # Verify error - assert result is not None - assert result.is_error - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert "Output validation error:" in result.content[0].text - assert "'five' is not of type 'integer'" in result.content[0].text diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index 68543136e..705abdfe8 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -1,100 +1,44 @@ """Tests for tool annotations in low-level server.""" -import anyio import pytest -from mcp.client.session import ClientSession -from mcp.server import Server -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import ClientResult, ServerNotification, ServerRequest, Tool, ToolAnnotations +from mcp import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ListToolsResult, PaginatedRequestParams, Tool, ToolAnnotations @pytest.mark.anyio async def test_lowlevel_server_tool_annotations(): """Test that tool annotations work in low-level server.""" - server = Server("test") - # Create a tool with annotations - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="echo", - description="Echo a message back", - input_schema={ - "type": "object", - "properties": { - "message": {"type": "string"}, + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo", + description="Echo a message back", + input_schema={ + "type": "object", + "properties": { + "message": {"type": "string"}, + }, + "required": ["message"], }, - "required": ["message"], - }, - annotations=ToolAnnotations( - title="Echo Tool", - read_only_hint=True, - ), - ) - ] - - tools_result = None - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) - - # Message handler for client - async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): # pragma: no cover - raise message - - # Server task - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) as server_session: - async with anyio.create_task_group() as tg: - - async def handle_messages(): - async for message in server_session.incoming_messages: # pragma: no branch - await server._handle_message(message, server_session, {}, False) - - tg.start_soon(handle_messages) - await anyio.sleep_forever() - - # Run the test - async with anyio.create_task_group() as tg: - tg.start_soon(run_server) - - async with ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session: - # Initialize the session - await client_session.initialize() - - # List tools - tools_result = await client_session.list_tools() - - # Cancel the server task - tg.cancel_scope.cancel() - - # Verify results - assert tools_result is not None - assert len(tools_result.tools) == 1 - assert tools_result.tools[0].name == "echo" - assert tools_result.tools[0].annotations is not None - assert tools_result.tools[0].annotations.title == "Echo Tool" - assert tools_result.tools[0].annotations.read_only_hint is True + annotations=ToolAnnotations( + title="Echo Tool", + read_only_hint=True, + ), + ) + ] + ) + + server = Server("test", on_list_tools=handle_list_tools) + + async with Client(server) as client: + tools_result = await client.list_tools() + + assert len(tools_result.tools) == 1 + assert tools_result.tools[0].name == "echo" + assert tools_result.tools[0].annotations is not None + assert tools_result.tools[0].annotations.title == "Echo Tool" + assert tools_result.tools[0].annotations.read_only_hint is True diff --git a/tests/server/test_read_resource.py b/tests/server/test_read_resource.py index 88fd1e38f..102a58d03 100644 --- a/tests/server/test_read_resource.py +++ b/tests/server/test_read_resource.py @@ -1,106 +1,58 @@ -from collections.abc import Iterable -from pathlib import Path -from tempfile import NamedTemporaryFile +import base64 import pytest -from mcp import types -from mcp.server.lowlevel.server import ReadResourceContents, Server +from mcp import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + BlobResourceContents, + ReadResourceRequestParams, + ReadResourceResult, + TextResourceContents, +) +pytestmark = pytest.mark.anyio -@pytest.fixture -def temp_file(): - """Create a temporary file for testing.""" - with NamedTemporaryFile(mode="w", delete=False) as f: - f.write("test content") - path = Path(f.name).resolve() - yield path - try: - path.unlink() - except FileNotFoundError: # pragma: no cover - pass +async def test_read_resource_text(): + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text="Hello World", mime_type="text/plain")] + ) -@pytest.mark.anyio -async def test_read_resource_text(temp_file: Path): - server = Server("test") + server = Server("test", on_read_resource=handle_read_resource) - @server.read_resource() - async def read_resource(uri: str) -> Iterable[ReadResourceContents]: - return [ReadResourceContents(content="Hello World", mime_type="text/plain")] + async with Client(server) as client: + result = await client.read_resource("test://resource") + assert len(result.contents) == 1 - # Get the handler directly from the server - handler = server.request_handlers[types.ReadResourceRequest] + content = result.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Hello World" + assert content.mime_type == "text/plain" - # Create a request - request = types.ReadResourceRequest( - params=types.ReadResourceRequestParams(uri=temp_file.as_uri()), - ) - # Call the handler - result = await handler(request) - assert isinstance(result, types.ReadResourceResult) - assert len(result.contents) == 1 +async def test_read_resource_binary(): + binary_data = b"Hello World" - content = result.contents[0] - assert isinstance(content, types.TextResourceContents) - assert content.text == "Hello World" - assert content.mime_type == "text/plain" + async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[ + BlobResourceContents( + uri=str(params.uri), + blob=base64.b64encode(binary_data).decode("utf-8"), + mime_type="application/octet-stream", + ) + ] + ) + server = Server("test", on_read_resource=handle_read_resource) -@pytest.mark.anyio -async def test_read_resource_binary(temp_file: Path): - server = Server("test") + async with Client(server) as client: + result = await client.read_resource("test://resource") + assert len(result.contents) == 1 - @server.read_resource() - async def read_resource(uri: str) -> Iterable[ReadResourceContents]: - return [ReadResourceContents(content=b"Hello World", mime_type="application/octet-stream")] - - # Get the handler directly from the server - handler = server.request_handlers[types.ReadResourceRequest] - - # Create a request - request = types.ReadResourceRequest( - params=types.ReadResourceRequestParams(uri=temp_file.as_uri()), - ) - - # Call the handler - result = await handler(request) - assert isinstance(result, types.ReadResourceResult) - assert len(result.contents) == 1 - - content = result.contents[0] - assert isinstance(content, types.BlobResourceContents) - assert content.mime_type == "application/octet-stream" - - -@pytest.mark.anyio -async def test_read_resource_default_mime(temp_file: Path): - server = Server("test") - - @server.read_resource() - async def read_resource(uri: str) -> Iterable[ReadResourceContents]: - return [ - ReadResourceContents( - content="Hello World", - # No mime_type specified, should default to text/plain - ) - ] - - # Get the handler directly from the server - handler = server.request_handlers[types.ReadResourceRequest] - - # Create a request - request = types.ReadResourceRequest( - params=types.ReadResourceRequestParams(uri=temp_file.as_uri()), - ) - - # Call the handler - result = await handler(request) - assert isinstance(result, types.ReadResourceResult) - assert len(result.contents) == 1 - - content = result.contents[0] - assert isinstance(content, types.TextResourceContents) - assert content.text == "Hello World" - assert content.mime_type == "text/plain" + content = result.contents[0] + assert isinstance(content, BlobResourceContents) + assert content.mime_type == "application/octet-stream" + assert base64.b64decode(content.blob) == binary_data diff --git a/tests/server/test_session.py b/tests/server/test_session.py index d353e46e4..a2786d865 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -5,7 +5,7 @@ from mcp import types from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -14,17 +14,10 @@ from mcp.shared.session import RequestResponder from mcp.types import ( ClientNotification, - Completion, - CompletionArgument, - CompletionContext, CompletionsCapability, InitializedNotification, - Prompt, - PromptReference, PromptsCapability, - Resource, ResourcesCapability, - ResourceTemplateReference, ServerCapabilities, ) @@ -85,47 +78,50 @@ async def run_server(): @pytest.mark.anyio async def test_server_capabilities(): - server = Server("test") notification_options = NotificationOptions() experimental_capabilities: dict[str, Any] = {} - # Initially no capabilities + async def noop_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + raise NotImplementedError + + async def noop_list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + raise NotImplementedError + + async def noop_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + raise NotImplementedError + + # No capabilities + server = Server("test") caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts is None assert caps.resources is None assert caps.completions is None - # Add a prompts handler - @server.list_prompts() - async def list_prompts() -> list[Prompt]: # pragma: no cover - return [] - + # With prompts handler + server = Server("test", on_list_prompts=noop_list_prompts) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources is None assert caps.completions is None - # Add a resources handler - @server.list_resources() - async def list_resources() -> list[Resource]: # pragma: no cover - return [] - + # With prompts + resources handlers + server = Server("test", on_list_prompts=noop_list_prompts, on_list_resources=noop_list_resources) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) assert caps.completions is None - # Add a complete handler - @server.completion() - async def complete( # pragma: no cover - ref: PromptReference | ResourceTemplateReference, - argument: CompletionArgument, - context: CompletionContext | None, - ) -> Completion | None: - return Completion( - values=["completion1", "completion2"], - ) - + # With prompts + resources + completion handlers + server = Server( + "test", + on_list_prompts=noop_list_prompts, + on_list_resources=noop_list_resources, + on_completion=noop_completion, + ) caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(list_changed=False) assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 51572baa9..54a898cc5 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -9,13 +9,12 @@ import pytest from starlette.types import Message -from mcp import Client, types +from mcp import Client from mcp.client.streamable_http import streamable_http_client -from mcp.server import streamable_http_manager -from mcp.server.lowlevel import Server +from mcp.server import Server, ServerRequestContext, streamable_http_manager from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from mcp.types import INVALID_REQUEST +from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams @pytest.mark.anyio @@ -218,7 +217,7 @@ async def test_stateless_requests_memory_cleanup(): # Patch StreamableHTTPServerTransport constructor to track instances - original_constructor = streamable_http_manager.StreamableHTTPServerTransport + original_constructor = StreamableHTTPServerTransport def track_transport(*args: Any, **kwargs: Any) -> StreamableHTTPServerTransport: transport = original_constructor(*args, **kwargs) @@ -313,7 +312,7 @@ async def mock_receive(): # Verify JSON-RPC error format error_data = json.loads(response_body) assert error_data["jsonrpc"] == "2.0" - assert error_data["id"] == "server-error" + assert error_data["id"] is None assert error_data["error"]["code"] == INVALID_REQUEST assert error_data["error"]["message"] == "Session not found" @@ -321,12 +320,11 @@ async def mock_receive(): @pytest.mark.anyio async def test_e2e_streamable_http_server_cleanup(): host = "testserver" - app = Server("test-server") - @app.list_tools() - async def list_tools(req: types.ListToolsRequest) -> types.ListToolsResult: - return types.ListToolsResult(tools=[]) + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[]) + app = Server("test-server", on_list_tools=handle_list_tools) mcp_app = app.streamable_http_app(host=host) async with ( mcp_app.router.lifespan_context(mcp_app), @@ -335,3 +333,80 @@ async def list_tools(req: types.ListToolsRequest) -> types.ListToolsResult: Client(streamable_http_client(f"http://{host}/mcp", http_client=http_client)) as client, ): await client.list_tools() + + +@pytest.mark.anyio +async def test_idle_session_is_reaped(): + """After idle timeout fires, the session returns 404.""" + app = Server("test-idle-reap") + manager = StreamableHTTPSessionManager(app=app, session_idle_timeout=0.05) + + async with manager.run(): + sent_messages: list[Message] = [] + + async def mock_send(message: Message): + sent_messages.append(message) + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [(b"content-type", b"application/json")], + } + + async def mock_receive(): # pragma: no cover + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request(scope, mock_receive, mock_send) + + session_id = None + for msg in sent_messages: # pragma: no branch + if msg["type"] == "http.response.start": # pragma: no branch + for header_name, header_value in msg.get("headers", []): # pragma: no branch + if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower(): + session_id = header_value.decode() + break + if session_id: # pragma: no branch + break + + assert session_id is not None, "Session ID not found in response headers" + + # Wait for the 50ms idle timeout to fire and cleanup to complete + await anyio.sleep(0.1) + + # Verify via public API: old session ID now returns 404 + response_messages: list[Message] = [] + + async def capture_send(message: Message): + response_messages.append(message) + + scope_with_session = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [ + (b"content-type", b"application/json"), + (b"mcp-session-id", session_id.encode()), + ], + } + + await manager.handle_request(scope_with_session, mock_receive, capture_send) + + response_start = next( + (msg for msg in response_messages if msg["type"] == "http.response.start"), + None, + ) + assert response_start is not None + assert response_start["status"] == 404 + + +def test_session_idle_timeout_rejects_non_positive(): + with pytest.raises(ValueError, match="positive number"): + StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=-1) + with pytest.raises(ValueError, match="positive number"): + StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=0) + + +def test_session_idle_timeout_rejects_stateless(): + with pytest.raises(RuntimeError, match="not supported in stateless"): + StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) diff --git a/tests/shared/test_auth_utils.py b/tests/shared/test_auth_utils.py index 2c1c16dc3..5ae0e22b0 100644 --- a/tests/shared/test_auth_utils.py +++ b/tests/shared/test_auth_utils.py @@ -104,7 +104,7 @@ def test_check_resource_allowed_trailing_slash_handling(): """Trailing slashes should be handled correctly.""" # With and without trailing slashes assert check_resource_allowed("https://example.com/api/", "https://example.com/api") is True - assert check_resource_allowed("https://example.com/api", "https://example.com/api/") is False + assert check_resource_allowed("https://example.com/api", "https://example.com/api/") is True assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api") is True assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api/") is True diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py deleted file mode 100644 index 31238b9ff..000000000 --- a/tests/shared/test_memory.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest - -from mcp import Client -from mcp.server import Server -from mcp.types import EmptyResult, Resource - - -@pytest.fixture -def mcp_server() -> Server: - server = Server(name="test_server") - - @server.list_resources() - async def handle_list_resources(): # pragma: no cover - return [ - Resource( - uri="memory://test", - name="Test Resource", - description="A test resource", - ) - ] - - return server - - -@pytest.mark.anyio -async def test_memory_server_and_client_connection(mcp_server: Server): - """Shows how a client and server can communicate over memory streams.""" - async with Client(mcp_server) as client: - response = await client.send_ping() - assert isinstance(response, EmptyResult) diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index ab117f1f0..aad9e5d43 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -6,13 +6,11 @@ from mcp import Client, types from mcp.client.session import ClientSession -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession -from mcp.shared._context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.progress import progress from mcp.shared.session import RequestResponder @@ -35,9 +33,6 @@ async def run_server(): capabilities=server.get_capabilities(NotificationOptions(), {}), ), ) as server_session: - global serv_sesh - - serv_sesh = server_session async for message in server_session.incoming_messages: try: await server._handle_message(message, server_session, {}) @@ -52,79 +47,73 @@ async def run_server(): server_progress_token = "server_token_123" client_progress_token = "client_token_456" - # Create a server with progress capability - server = Server(name="ProgressTestServer") - # Register progress handler - @server.progress_notification() - async def handle_progress( - progress_token: str | int, - progress: float, - total: float | None, - message: str | None, - ): + async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: server_progress_updates.append( { - "token": progress_token, - "progress": progress, - "total": total, - "message": message, + "token": params.progress_token, + "progress": params.progress, + "total": params.total, + "message": params.message, } ) # Register list tool handler - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="test_tool", - description="A tool that sends progress notifications types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="A tool that sends progress notifications list[types.TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: # Make sure we received a progress token - if name == "test_tool": - if arguments and "_meta" in arguments: - progressToken = arguments["_meta"]["progressToken"] - - if not progressToken: # pragma: no cover - raise ValueError("Empty progress token received") - - if progressToken != client_progress_token: # pragma: no cover - raise ValueError("Server sending back incorrect progressToken") - - # Send progress notifications - await serv_sesh.send_progress_notification( - progress_token=progressToken, - progress=0.25, - total=1.0, - message="Server progress 25%", - ) + if params.name == "test_tool": + assert params.meta is not None + progress_token = params.meta.get("progress_token") + assert progress_token is not None + assert progress_token == client_progress_token + + # Send progress notifications using ctx.session + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=0.25, + total=1.0, + message="Server progress 25%", + ) - await serv_sesh.send_progress_notification( - progress_token=progressToken, - progress=0.5, - total=1.0, - message="Server progress 50%", - ) + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=0.5, + total=1.0, + message="Server progress 50%", + ) - await serv_sesh.send_progress_notification( - progress_token=progressToken, - progress=1.0, - total=1.0, - message="Server progress 100%", - ) + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=1.0, + total=1.0, + message="Server progress 100%", + ) - else: # pragma: no cover - raise ValueError("Progress token not sent.") + return types.CallToolResult(content=[types.TextContent(type="text", text="Tool executed successfully")]) - return [types.TextContent(type="text", text="Tool executed successfully")] + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover - raise ValueError(f"Unknown tool: {name}") # pragma: no cover + # Create a server with progress capability + server = Server( + name="ProgressTestServer", + on_progress=handle_progress, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Client message handler to store progress notifications async def handle_client_message( @@ -164,7 +153,7 @@ async def handle_client_message( await client_session.list_tools() # Call test_tool with progress token - await client_session.call_tool("test_tool", {"_meta": {"progressToken": client_progress_token}}) + await client_session.call_tool("test_tool", meta={"progress_token": client_progress_token}) # Send progress notifications from client to server await client_session.send_progress_notification( @@ -207,118 +196,6 @@ async def handle_client_message( assert server_progress_updates[2]["progress"] == 1.0 -@pytest.mark.anyio -async def test_progress_context_manager(): - """Test client using progress context manager for sending progress notifications.""" - # Create memory streams for client/server - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) - - # Track progress updates - server_progress_updates: list[dict[str, Any]] = [] - - server = Server(name="ProgressContextTestServer") - - progress_token = None - - # Register progress handler - @server.progress_notification() - async def handle_progress( - progress_token: str | int, - progress: float, - total: float | None, - message: str | None, - ): - server_progress_updates.append( - {"token": progress_token, "progress": progress, "total": total, "message": message} - ) - - # Run server session to receive progress updates - async def run_server(): - # Create a server session - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="ProgressContextTestServer", - server_version="0.1.0", - capabilities=server.get_capabilities(NotificationOptions(), {}), - ), - ) as server_session: - async for message in server_session.incoming_messages: - try: - await server._handle_message(message, server_session, {}) - except Exception as e: # pragma: no cover - raise e - - # Client message handler - async def handle_client_message( - message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): # pragma: no cover - raise message - - # run client session - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=handle_client_message, - ) as client_session, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - - await client_session.initialize() - - progress_token = "client_token_456" - - # Create request context - request_context = RequestContext( - request_id="test-request", - session=client_session, - meta={"progress_token": progress_token}, - ) - - # Utilize progress context manager - with progress(request_context, total=100) as p: - await p.progress(10, message="Loading configuration...") - await p.progress(30, message="Connecting to database...") - await p.progress(40, message="Fetching data...") - await p.progress(20, message="Processing results...") - - # Wait for all messages to be processed - await anyio.sleep(0.5) - tg.cancel_scope.cancel() - - # Verify progress updates were received by server - assert len(server_progress_updates) == 4 - - # first update - assert server_progress_updates[0]["token"] == progress_token - assert server_progress_updates[0]["progress"] == 10 - assert server_progress_updates[0]["total"] == 100 - assert server_progress_updates[0]["message"] == "Loading configuration..." - - # second update - assert server_progress_updates[1]["token"] == progress_token - assert server_progress_updates[1]["progress"] == 40 - assert server_progress_updates[1]["total"] == 100 - assert server_progress_updates[1]["message"] == "Connecting to database..." - - # third update - assert server_progress_updates[2]["token"] == progress_token - assert server_progress_updates[2]["progress"] == 80 - assert server_progress_updates[2]["total"] == 100 - assert server_progress_updates[2]["message"] == "Fetching data..." - - # final update - assert server_progress_updates[3]["token"] == progress_token - assert server_progress_updates[3]["progress"] == 100 - assert server_progress_updates[3]["total"] == 100 - assert server_progress_updates[3]["message"] == "Processing results..." - - @pytest.mark.anyio async def test_progress_callback_exception_logging(): """Test that exceptions in progress callbacks are logged and \ @@ -334,30 +211,37 @@ async def failing_progress_callback(progress: float, total: float | None, messag raise ValueError("Progress callback failed!") # Create a server with a tool that sends progress notifications - server = Server(name="TestProgressServer") - - @server.call_tool() - async def handle_call_tool(name: str, arguments: Any) -> list[types.TextContent]: - if name == "progress_tool": + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name == "progress_tool": + assert ctx.request_id is not None # Send a progress notification - await server.request_context.session.send_progress_notification( - progress_token=server.request_context.request_id, + await ctx.session.send_progress_notification( + progress_token=ctx.request_id, progress=50.0, total=100.0, message="Halfway done", ) - return [types.TextContent(type="text", text="progress_result")] - raise ValueError(f"Unknown tool: {name}") # pragma: no cover - - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="progress_tool", - description="A tool that sends progress notifications", - input_schema={}, - ) - ] + return types.CallToolResult(content=[types.TextContent(type="text", text="progress_result")]) + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="progress_tool", + description="A tool that sends progress notifications", + input_schema={}, + ) + ] + ) + + server = Server( + name="TestProgressServer", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) # Test with mocked logging with patch("mcp.shared.session.logging.exception", side_effect=mock_log_exception): diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 182b4671d..d7c6cc3b5 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,23 +1,25 @@ -from typing import Any - import anyio import pytest from mcp import Client, types from mcp.client.session import ClientSession -from mcp.server.lowlevel.server import Server +from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.shared.memory import create_client_server_memory_streams from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder from mcp.types import ( + PARSE_ERROR, CancelledNotification, CancelledNotificationParams, + ClientResult, EmptyResult, ErrorData, JSONRPCError, JSONRPCRequest, JSONRPCResponse, - TextContent, + ServerNotification, + ServerRequest, ) @@ -42,29 +44,25 @@ async def test_request_cancellation(): request_id = None # Create a server with a slow tool - server = Server(name="TestSessionServer") - - # Register the tool handler - @server.call_tool() - async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: nonlocal request_id, ev_tool_called - if name == "slow_tool": - request_id = server.request_context.request_id + if params.name == "slow_tool": + request_id = ctx.request_id ev_tool_called.set() await anyio.sleep(10) # Long enough to ensure we can cancel - return [] # pragma: no cover - raise ValueError(f"Unknown tool: {name}") # pragma: no cover - - # Register the tool so it shows up in list_tools - @server.list_tools() - async def handle_list_tools() -> list[types.Tool]: - return [ - types.Tool( - name="slow_tool", - description="A slow tool that takes 10 seconds to complete", - input_schema={}, - ) - ] + return types.CallToolResult(content=[]) # pragma: no cover + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + raise NotImplementedError + + server = Server( + name="TestSessionServer", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) async def make_request(client: Client): nonlocal ev_cancelled @@ -304,3 +302,117 @@ async def mock_server(): await ev_closed.wait() with anyio.fail_after(1): # pragma: no branch await ev_response.wait() + + +@pytest.mark.anyio +async def test_null_id_error_surfaced_via_message_handler(): + """Test that a JSONRPCError with id=None is surfaced to the message handler. + + Per JSON-RPC 2.0, error responses use id=null when the request id could not + be determined (e.g., parse errors). These cannot be correlated to any pending + request, so they are forwarded to the message handler as MCPError. + """ + ev_error_received = anyio.Event() + error_holder: list[MCPError] = [] + + async def capture_errors( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + assert isinstance(message, MCPError) + error_holder.append(message) + ev_error_received.set() + + sent_error = ErrorData(code=PARSE_ERROR, message="Parse error") + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + _server_read, server_write = server_streams + + async def mock_server(): + """Send a null-id error (simulating a parse error).""" + error_response = JSONRPCError(jsonrpc="2.0", id=None, error=sent_error) + await server_write.send(SessionMessage(message=error_response)) + + async with ( + anyio.create_task_group() as tg, + ClientSession( + read_stream=client_read, + write_stream=client_write, + message_handler=capture_errors, + ) as _client_session, + ): + tg.start_soon(mock_server) + + with anyio.fail_after(2): # pragma: no branch + await ev_error_received.wait() + + assert len(error_holder) == 1 + assert error_holder[0].error == sent_error + + +@pytest.mark.anyio +async def test_null_id_error_does_not_affect_pending_request(): + """Test that a null-id error doesn't interfere with an in-flight request. + + When a null-id error arrives while a request is pending, the error should + go to the message handler and the pending request should still complete + normally with its own response. + """ + ev_error_received = anyio.Event() + ev_response_received = anyio.Event() + error_holder: list[MCPError] = [] + result_holder: list[EmptyResult] = [] + + async def capture_errors( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + assert isinstance(message, MCPError) + error_holder.append(message) + ev_error_received.set() + + sent_error = ErrorData(code=PARSE_ERROR, message="Parse error") + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Read a request, inject a null-id error, then respond normally.""" + message = await server_read.receive() + assert isinstance(message, SessionMessage) + assert isinstance(message.message, JSONRPCRequest) + request_id = message.message.id + + # First, send a null-id error (should go to message handler) + await server_write.send(SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=sent_error))) + + # Then, respond normally to the pending request + await server_write.send(SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}))) + + async def make_request(client_session: ClientSession): + result = await client_session.send_ping() + result_holder.append(result) + ev_response_received.set() + + async with ( + anyio.create_task_group() as tg, + ClientSession( + read_stream=client_read, + write_stream=client_write, + message_handler=capture_errors, + ) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(2): # pragma: no branch + await ev_error_received.wait() + await ev_response_received.wait() + + # Null-id error reached the message handler + assert len(error_holder) == 1 + assert error_holder[0].error == sent_error + + # Pending request completed successfully + assert len(result_holder) == 1 + assert isinstance(result_holder[0], EmptyResult) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index df2321ba1..7b2bc0a13 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,7 +1,6 @@ import json import multiprocessing import socket -import time from collections.abc import AsyncGenerator, Generator from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -22,15 +21,20 @@ from mcp import types from mcp.client.session import ClientSession from mcp.client.sse import _extract_session_id_from_endpoint, sse_client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import MCPError from mcp.types import ( + CallToolRequestParams, + CallToolResult, EmptyResult, Implementation, InitializeResult, JSONRPCResponse, + ListToolsResult, + PaginatedRequestParams, + ReadResourceRequestParams, ReadResourceResult, ServerCapabilities, TextContent, @@ -54,36 +58,48 @@ def server_url(server_port: int) -> str: return f"http://127.0.0.1:{server_port}" -# Test server implementation -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - @self.read_resource() - async def handle_read_resource(uri: str) -> str | bytes: - parsed = urlparse(uri) - if parsed.scheme == "foobar": - return f"Read {parsed.netloc}" - if parsed.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {parsed.netloc}" - - raise MCPError(code=404, message="OOPS! no resource with that URI was found") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] +async def _handle_read_resource( # pragma: no cover + ctx: ServerRequestContext, params: ReadResourceRequestParams +) -> ReadResourceResult: + uri = str(params.uri) + parsed = urlparse(uri) + if parsed.scheme == "foobar": + text = f"Read {parsed.netloc}" + elif parsed.scheme == "slow": + await anyio.sleep(2.0) + text = f"Slow response from {parsed.netloc}" + else: + raise MCPError(code=404, message="OOPS! no resource with that URI was found") + return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) + + +async def _handle_list_tools( # pragma: no cover + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="A test tool", + input_schema={"type": "object", "properties": {}}, + ) + ] + ) - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - return [TextContent(type="text", text=f"Called {name}")] + +async def _handle_call_tool( # pragma: no cover + ctx: ServerRequestContext, params: CallToolRequestParams +) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + + +def _create_server() -> Server: # pragma: no cover + return Server( + SERVER_NAME, + on_read_resource=_handle_read_resource, + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool, + ) # Test fixtures @@ -94,7 +110,7 @@ def make_server_app() -> Starlette: # pragma: no cover allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) sse = SseServerTransport("/messages/", security_settings=security_settings) - server = ServerTest() + server = _create_server() async def handle_sse(request: Request) -> Response: async with sse.connect_sse(request.scope, request.receive, request._send) as streams: @@ -117,11 +133,6 @@ def run_server(server_port: int) -> None: # pragma: no cover print(f"starting server on {server_port}") server.run() - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - @pytest.fixture() def server(server_port: int) -> Generator[None, None, None]: @@ -296,11 +307,6 @@ def run_mounted_server(server_port: int) -> None: # pragma: no cover print(f"starting server on {server_port}") server.run() - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - @pytest.fixture() def mounted_server(server_port: int) -> Generator[None, None, None]: @@ -336,47 +342,46 @@ async def test_sse_client_basic_connection_mounted_app(mounted_server: None, ser assert isinstance(ping_result, EmptyResult) -# Test server with request context that returns headers in the response -class RequestContextServer(Server[object, Request]): # pragma: no cover - def __init__(self): - super().__init__("request_context_server") - - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - headers_info = {} - context = self.request_context - if context.request: - headers_info = dict(context.request.headers) - - if name == "echo_headers": - return [TextContent(type="text", text=json.dumps(headers_info))] - elif name == "echo_context": - context_data = { - "request_id": args.get("request_id"), - "headers": headers_info, - } - return [TextContent(type="text", text=json.dumps(context_data))] - - return [TextContent(type="text", text=f"Called {name}")] - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="echo_headers", - description="Echoes request headers", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="echo_context", - description="Echoes request context", - input_schema={ - "type": "object", - "properties": {"request_id": {"type": "string"}}, - "required": ["request_id"], - }, - ), - ] +async def _handle_context_call_tool( # pragma: no cover + ctx: ServerRequestContext, params: CallToolRequestParams +) -> CallToolResult: + headers_info: dict[str, Any] = {} + if ctx.request: + headers_info = dict(ctx.request.headers) + + if params.name == "echo_headers": + return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) + elif params.name == "echo_context": + context_data = { + "request_id": (params.arguments or {}).get("request_id"), + "headers": headers_info, + } + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) + + return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + + +async def _handle_context_list_tools( # pragma: no cover + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo_headers", + description="Echoes request headers", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echoes request context", + input_schema={ + "type": "object", + "properties": {"request_id": {"type": "string"}}, + "required": ["request_id"], + }, + ), + ] + ) def run_context_server(server_port: int) -> None: # pragma: no cover @@ -386,7 +391,11 @@ def run_context_server(server_port: int) -> None: # pragma: no cover allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) sse = SseServerTransport("/messages/", security_settings=security_settings) - context_server = RequestContextServer() + context_server = Server( + "request_context_server", + on_call_tool=_handle_context_call_tool, + on_list_tools=_handle_context_list_tools, + ) async def handle_sse(request: Request) -> Response: async with sse.connect_sse(request.scope, request.receive, request._send) as streams: diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index b04b92026..42b1a3698 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -10,7 +10,9 @@ import socket import time import traceback -from collections.abc import Generator +from collections.abc import AsyncIterator, Generator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field from typing import Any from unittest.mock import MagicMock from urllib.parse import urlparse @@ -28,7 +30,7 @@ from mcp import MCPError, types from mcp.client.session import ClientSession from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, @@ -50,7 +52,19 @@ ) from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import InitializeResult, JSONRPCRequest, TextContent, TextResourceContents, Tool +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + InitializeResult, + JSONRPCRequest, + ListToolsResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextContent, + TextResourceContents, + Tool, +) from tests.test_helpers import wait_for_server # Test constants @@ -124,263 +138,258 @@ async def replay_events_after( # pragma: no cover return target_stream_id -# Test server implementation that follows MCP protocol -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - self._lock = None # Will be initialized in async context - - @self.read_resource() - async def handle_read_resource(uri: str) -> str | bytes: - parsed = urlparse(uri) - if parsed.scheme == "foobar": - return f"Read {parsed.netloc}" - if parsed.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {parsed.netloc}" - - raise ValueError(f"Unknown resource: {uri}") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="test_tool_with_standalone_notification", - description="A test tool that sends a notification", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="long_running_with_checkpoints", - description="A long-running tool that sends periodic notifications", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="test_sampling_tool", - description="A tool that triggers server-side sampling", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="wait_for_lock_with_notification", - description="A tool that sends a notification and waits for lock", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="release_lock", - description="A tool that releases the lock", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_stream_close", - description="A tool that closes SSE stream mid-operation", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_multiple_notifications_and_close", - description="Tool that sends notification1, closes stream, sends notification2, notification3", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="tool_with_multiple_stream_closes", - description="Tool that closes SSE stream multiple times during execution", - input_schema={ - "type": "object", - "properties": { - "checkpoints": {"type": "integer", "default": 3}, - "sleep_time": {"type": "number", "default": 0.2}, - }, +@dataclass +class ServerState: + lock: anyio.Event = field(default_factory=anyio.Event) + + +@asynccontextmanager +async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: # pragma: no cover + yield ServerState() + + +async def _handle_read_resource( # pragma: no cover + ctx: ServerRequestContext[ServerState], params: ReadResourceRequestParams +) -> ReadResourceResult: + uri = str(params.uri) + parsed = urlparse(uri) + if parsed.scheme == "foobar": + text = f"Read {parsed.netloc}" + elif parsed.scheme == "slow": + await anyio.sleep(2.0) + text = f"Slow response from {parsed.netloc}" + else: + raise ValueError(f"Unknown resource: {uri}") + return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) + + +async def _handle_list_tools( # pragma: no cover + ctx: ServerRequestContext[ServerState], params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="A test tool", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="test_tool_with_standalone_notification", + description="A test tool that sends a notification", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="long_running_with_checkpoints", + description="A long-running tool that sends periodic notifications", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="test_sampling_tool", + description="A tool that triggers server-side sampling", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="wait_for_lock_with_notification", + description="A tool that sends a notification and waits for lock", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="release_lock", + description="A tool that releases the lock", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_stream_close", + description="A tool that closes SSE stream mid-operation", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_multiple_notifications_and_close", + description="Tool that sends notification1, closes stream, sends notification2, notification3", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_multiple_stream_closes", + description="Tool that closes SSE stream multiple times during execution", + input_schema={ + "type": "object", + "properties": { + "checkpoints": {"type": "integer", "default": 3}, + "sleep_time": {"type": "number", "default": 0.2}, }, - ), - Tool( - name="tool_with_standalone_stream_close", - description="Tool that closes standalone GET stream mid-operation", - input_schema={"type": "object", "properties": {}}, - ), - ] - - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - ctx = self.request_context + }, + ), + Tool( + name="tool_with_standalone_stream_close", + description="Tool that closes standalone GET stream mid-operation", + input_schema={"type": "object", "properties": {}}, + ), + ] + ) - # When the tool is called, send a notification to test GET stream - if name == "test_tool_with_standalone_notification": - await ctx.session.send_resource_updated(uri="http://test_resource") - return [TextContent(type="text", text=f"Called {name}")] - elif name == "long_running_with_checkpoints": - # Send notifications that are part of the response stream - # This simulates a long-running tool that sends logs +async def _handle_call_tool( # pragma: no cover + ctx: ServerRequestContext[ServerState], params: CallToolRequestParams +) -> CallToolResult: + name = params.name + args = params.arguments or {} - await ctx.session.send_log_message( - level="info", - data="Tool started", - logger="tool", - related_request_id=ctx.request_id, # need for stream association - ) + # When the tool is called, send a notification to test GET stream + if name == "test_tool_with_standalone_notification": + await ctx.session.send_resource_updated(uri="http://test_resource") + return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) - await anyio.sleep(0.1) + elif name == "long_running_with_checkpoints": + await ctx.session.send_log_message( + level="info", + data="Tool started", + logger="tool", + related_request_id=ctx.request_id, + ) - await ctx.session.send_log_message( - level="info", - data="Tool is almost done", - logger="tool", - related_request_id=ctx.request_id, - ) + await anyio.sleep(0.1) - return [TextContent(type="text", text="Completed!")] + await ctx.session.send_log_message( + level="info", + data="Tool is almost done", + logger="tool", + related_request_id=ctx.request_id, + ) - elif name == "test_sampling_tool": - # Test sampling by requesting the client to sample a message - sampling_result = await ctx.session.create_message( - messages=[ - types.SamplingMessage( - role="user", - content=types.TextContent(type="text", text="Server needs client sampling"), - ) - ], - max_tokens=100, - related_request_id=ctx.request_id, - ) + return CallToolResult(content=[TextContent(type="text", text="Completed!")]) - # Return the sampling result in the tool response - # Since we're not passing tools param, result.content is single content - if sampling_result.content.type == "text": - response = sampling_result.content.text - else: - response = str(sampling_result.content) - return [ - TextContent( - type="text", - text=f"Response from sampling: {response}", - ) - ] - - elif name == "wait_for_lock_with_notification": - # Initialize lock if not already done - if self._lock is None: - self._lock = anyio.Event() - - # First send a notification - await ctx.session.send_log_message( - level="info", - data="First notification before lock", - logger="lock_tool", - related_request_id=ctx.request_id, + elif name == "test_sampling_tool": + sampling_result = await ctx.session.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent(type="text", text="Server needs client sampling"), ) + ], + max_tokens=100, + related_request_id=ctx.request_id, + ) - # Now wait for the lock to be released - await self._lock.wait() - - # Send second notification after lock is released - await ctx.session.send_log_message( - level="info", - data="Second notification after lock", - logger="lock_tool", - related_request_id=ctx.request_id, + if sampling_result.content.type == "text": + response = sampling_result.content.text + else: + response = str(sampling_result.content) + return CallToolResult( + content=[ + TextContent( + type="text", + text=f"Response from sampling: {response}", ) + ] + ) - return [TextContent(type="text", text="Completed")] + elif name == "wait_for_lock_with_notification": + await ctx.session.send_log_message( + level="info", + data="First notification before lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) - elif name == "release_lock": - assert self._lock is not None, "Lock must be initialized before releasing" + await ctx.lifespan_context.lock.wait() - # Release the lock - self._lock.set() - return [TextContent(type="text", text="Lock released")] + await ctx.session.send_log_message( + level="info", + data="Second notification after lock", + logger="lock_tool", + related_request_id=ctx.request_id, + ) - elif name == "tool_with_stream_close": - # Send notification before closing - await ctx.session.send_log_message( - level="info", - data="Before close", - logger="stream_close_tool", - related_request_id=ctx.request_id, - ) - # Close SSE stream (triggers client reconnect) - assert ctx.close_sse_stream is not None - await ctx.close_sse_stream() - # Continue processing (events stored in event_store) - await anyio.sleep(0.1) - await ctx.session.send_log_message( - level="info", - data="After close", - logger="stream_close_tool", - related_request_id=ctx.request_id, - ) - return [TextContent(type="text", text="Done")] - - elif name == "tool_with_multiple_notifications_and_close": - # Send notification1 - await ctx.session.send_log_message( - level="info", - data="notification1", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - # Close SSE stream - assert ctx.close_sse_stream is not None - await ctx.close_sse_stream() - # Send notification2, notification3 (stored in event_store) - await anyio.sleep(0.1) - await ctx.session.send_log_message( - level="info", - data="notification2", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - await ctx.session.send_log_message( - level="info", - data="notification3", - logger="multi_notif_tool", - related_request_id=ctx.request_id, - ) - return [TextContent(type="text", text="All notifications sent")] + return CallToolResult(content=[TextContent(type="text", text="Completed")]) - elif name == "tool_with_multiple_stream_closes": - num_checkpoints = args.get("checkpoints", 3) - sleep_time = args.get("sleep_time", 0.2) + elif name == "release_lock": + ctx.lifespan_context.lock.set() + return CallToolResult(content=[TextContent(type="text", text="Lock released")]) - for i in range(num_checkpoints): - await ctx.session.send_log_message( - level="info", - data=f"checkpoint_{i}", - logger="multi_close_tool", - related_request_id=ctx.request_id, - ) + elif name == "tool_with_stream_close": + await ctx.session.send_log_message( + level="info", + data="Before close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="After close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + return CallToolResult(content=[TextContent(type="text", text="Done")]) + + elif name == "tool_with_multiple_notifications_and_close": + await ctx.session.send_log_message( + level="info", + data="notification1", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="notification2", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + await ctx.session.send_log_message( + level="info", + data="notification3", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + return CallToolResult(content=[TextContent(type="text", text="All notifications sent")]) + + elif name == "tool_with_multiple_stream_closes": + num_checkpoints = args.get("checkpoints", 3) + sleep_time = args.get("sleep_time", 0.2) + + for i in range(num_checkpoints): + await ctx.session.send_log_message( + level="info", + data=f"checkpoint_{i}", + logger="multi_close_tool", + related_request_id=ctx.request_id, + ) - if ctx.close_sse_stream: - await ctx.close_sse_stream() + if ctx.close_sse_stream: + await ctx.close_sse_stream() - await anyio.sleep(sleep_time) + await anyio.sleep(sleep_time) - return [TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")] + return CallToolResult(content=[TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")]) - elif name == "tool_with_standalone_stream_close": - # Test for GET stream reconnection - # 1. Send unsolicited notification via GET stream (no related_request_id) - await ctx.session.send_resource_updated(uri="http://notification_1") + elif name == "tool_with_standalone_stream_close": + await ctx.session.send_resource_updated(uri="http://notification_1") + await anyio.sleep(0.1) - # Small delay to ensure notification is flushed before closing - await anyio.sleep(0.1) + if ctx.close_standalone_sse_stream: + await ctx.close_standalone_sse_stream() - # 2. Close the standalone GET stream - if ctx.close_standalone_sse_stream: - await ctx.close_standalone_sse_stream() + await anyio.sleep(1.5) + await ctx.session.send_resource_updated(uri="http://notification_2") - # 3. Wait for client to reconnect (uses retry_interval from server, default 1000ms) - await anyio.sleep(1.5) + return CallToolResult(content=[TextContent(type="text", text="Standalone stream close test done")]) - # 4. Send another notification on the new GET stream connection - await ctx.session.send_resource_updated(uri="http://notification_2") + return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) - return [TextContent(type="text", text="Standalone stream close test done")] - return [TextContent(type="text", text=f"Called {name}")] +def _create_server() -> Server[ServerState]: # pragma: no cover + return Server( + SERVER_NAME, + lifespan=_server_lifespan, + on_read_resource=_handle_read_resource, + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool, + ) def create_app( @@ -396,7 +405,7 @@ def create_app( retry_interval: Retry interval in milliseconds for SSE polling. """ # Create server instance - server = ServerTest() + server = _create_server() # Create the session manager security_settings = TransportSecuritySettings( @@ -1385,69 +1394,68 @@ async def sampling_callback( # Context-aware server implementation for testing request context propagation -class ContextAwareServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__("ContextAwareServer") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="echo_headers", - description="Echo request headers from context", - input_schema={"type": "object", "properties": {}}, - ), - Tool( - name="echo_context", - description="Echo request context with custom data", - input_schema={ - "type": "object", - "properties": { - "request_id": {"type": "string"}, - }, - "required": ["request_id"], +async def _handle_context_list_tools( # pragma: no cover + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo_headers", + description="Echo request headers from context", + input_schema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echo request context with custom data", + input_schema={ + "type": "object", + "properties": { + "request_id": {"type": "string"}, }, - ), - ] + "required": ["request_id"], + }, + ), + ] + ) - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - ctx = self.request_context - - if name == "echo_headers": - # Access the request object from context - headers_info = {} - if ctx.request and isinstance(ctx.request, Request): - headers_info = dict(ctx.request.headers) - return [TextContent(type="text", text=json.dumps(headers_info))] - - elif name == "echo_context": - # Return full context information - context_data: dict[str, Any] = { - "request_id": args.get("request_id"), - "headers": {}, - "method": None, - "path": None, - } - if ctx.request and isinstance(ctx.request, Request): - request = ctx.request - context_data["headers"] = dict(request.headers) - context_data["method"] = request.method - context_data["path"] = request.url.path - return [ - TextContent( - type="text", - text=json.dumps(context_data), - ) - ] - - return [TextContent(type="text", text=f"Unknown tool: {name}")] + +async def _handle_context_call_tool( # pragma: no cover + ctx: ServerRequestContext, params: CallToolRequestParams +) -> CallToolResult: + name = params.name + args = params.arguments or {} + + if name == "echo_headers": + headers_info: dict[str, Any] = {} + if ctx.request and isinstance(ctx.request, Request): + headers_info = dict(ctx.request.headers) + return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) + + elif name == "echo_context": + context_data: dict[str, Any] = { + "request_id": args.get("request_id"), + "headers": {}, + "method": None, + "path": None, + } + if ctx.request and isinstance(ctx.request, Request): + request = ctx.request + context_data["headers"] = dict(request.headers) + context_data["method"] = request.method + context_data["path"] = request.url.path + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) + + return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) # Server runner for context-aware testing def run_context_aware_server(port: int): # pragma: no cover """Run the context-aware test server.""" - server = ContextAwareServerTest() + server = Server( + "ContextAwareServer", + on_list_tools=_handle_context_list_tools, + on_call_tool=_handle_context_call_tool, + ) session_manager = StreamableHTTPSessionManager( app=server, diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 07e19195d..9addb661d 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -1,8 +1,6 @@ import multiprocessing import socket -import time from collections.abc import AsyncGenerator, Generator -from typing import Any from urllib.parse import urlparse import anyio @@ -15,9 +13,21 @@ from mcp import MCPError from mcp.client.session import ClientSession from mcp.client.websocket import websocket_client -from mcp.server import Server +from mcp.server import Server, ServerRequestContext from mcp.server.websocket import websocket_server -from mcp.types import EmptyResult, InitializeResult, ReadResourceResult, TextContent, TextResourceContents, Tool +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + EmptyResult, + InitializeResult, + ListToolsResult, + PaginatedRequestParams, + ReadResourceRequestParams, + ReadResourceResult, + TextContent, + TextResourceContents, + Tool, +) from tests.test_helpers import wait_for_server SERVER_NAME = "test_server_for_WS" @@ -35,42 +45,59 @@ def server_url(server_port: int) -> str: return f"ws://127.0.0.1:{server_port}" -# Test server implementation -class ServerTest(Server): # pragma: no cover - def __init__(self): - super().__init__(SERVER_NAME) - - @self.read_resource() - async def handle_read_resource(uri: str) -> str | bytes: - parsed = urlparse(uri) - if parsed.scheme == "foobar": - return f"Read {parsed.netloc}" - elif parsed.scheme == "slow": - # Simulate a slow resource - await anyio.sleep(2.0) - return f"Slow response from {parsed.netloc}" - - raise MCPError(code=404, message="OOPS! no resource with that URI was found") - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, +async def handle_read_resource( # pragma: no cover + ctx: ServerRequestContext, params: ReadResourceRequestParams +) -> ReadResourceResult: + parsed = urlparse(str(params.uri)) + if parsed.scheme == "foobar": + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")] + ) + elif parsed.scheme == "slow": + await anyio.sleep(2.0) + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=str(params.uri), text=f"Slow response from {parsed.netloc}", mime_type="text/plain" ) ] + ) + raise MCPError(code=404, message="OOPS! no resource with that URI was found") - @self.call_tool() - async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: - return [TextContent(type="text", text=f"Called {name}")] + +async def handle_list_tools( # pragma: no cover + ctx: ServerRequestContext, params: PaginatedRequestParams | None +) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="test_tool", + description="A test tool", + input_schema={"type": "object", "properties": {}}, + ) + ] + ) + + +async def handle_call_tool( # pragma: no cover + ctx: ServerRequestContext, params: CallToolRequestParams +) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + + +def _create_server() -> Server: # pragma: no cover + return Server( + SERVER_NAME, + on_read_resource=handle_read_resource, + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) # Test fixtures def make_server_app() -> Starlette: # pragma: no cover """Create test Starlette app with WebSocket transport""" - server = ServerTest() + server = _create_server() async def handle_ws(websocket: WebSocket): async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: @@ -86,11 +113,6 @@ def run_server(server_port: int) -> None: # pragma: no cover print(f"starting server on {server_port}") server.run() - # Give server time to start - while not server.started: - print("waiting for server to start") - time.sleep(0.5) - @pytest.fixture() def server(server_port: int) -> Generator[None, None, None]: diff --git a/uv.lock b/uv.lock index 5d3a83f37..d01d510f1 100644 --- a/uv.lock +++ b/uv.lock @@ -128,6 +128,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646, upload-time = "2025-01-29T04:15:38.082Z" }, ] +[[package]] +name = "cairocffi" +version = "1.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/70/c5/1a4dc131459e68a173cbdab5fad6b524f53f9c1ef7861b7698e998b837cc/cairocffi-1.7.1.tar.gz", hash = "sha256:2e48ee864884ec4a3a34bfa8c9ab9999f688286eb714a15a43ec9d068c36557b", size = 88096, upload-time = "2024-06-18T10:56:06.741Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/d8/ba13451aa6b745c49536e87b6bf8f629b950e84bd0e8308f7dc6883b67e2/cairocffi-1.7.1-py3-none-any.whl", hash = "sha256:9803a0e11f6c962f3b0ae2ec8ba6ae45e957a146a004697a1ac1bbf16b073b3f", size = 75611, upload-time = "2024-06-18T10:55:59.489Z" }, +] + +[[package]] +name = "cairosvg" +version = "2.8.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cairocffi" }, + { name = "cssselect2" }, + { name = "defusedxml" }, + { name = "pillow" }, + { name = "tinycss2" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/b9/5106168bd43d7cd8b7cc2a2ee465b385f14b63f4c092bb89eee2d48c8e67/cairosvg-2.8.2.tar.gz", hash = "sha256:07cbf4e86317b27a92318a4cac2a4bb37a5e9c1b8a27355d06874b22f85bef9f", size = 8398590, upload-time = "2025-05-15T06:56:32.653Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/48/816bd4aaae93dbf9e408c58598bc32f4a8c65f4b86ab560864cb3ee60adb/cairosvg-2.8.2-py3-none-any.whl", hash = "sha256:eab46dad4674f33267a671dce39b64be245911c901c70d65d2b7b0821e852bf5", size = 45773, upload-time = "2025-05-15T06:56:28.552Z" }, +] + [[package]] name = "certifi" version = "2025.8.3" @@ -468,6 +496,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bc/58/6b3d24e6b9bc474a2dcdee65dfd1f008867015408a271562e4b690561a4d/cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7", size = 3407605, upload-time = "2026-02-10T19:18:29.233Z" }, ] +[[package]] +name = "cssselect2" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tinycss2" }, + { name = "webencodings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e0/20/92eaa6b0aec7189fa4b75c890640e076e9e793095721db69c5c81142c2e1/cssselect2-0.9.0.tar.gz", hash = "sha256:759aa22c216326356f65e62e791d66160a0f9c91d1424e8d8adc5e74dddfc6fb", size = 35595, upload-time = "2026-02-12T17:16:39.614Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/0e/8459ca4413e1a21a06c97d134bfaf18adfd27cea068813dc0faae06cbf00/cssselect2-0.9.0-py3-none-any.whl", hash = "sha256:6a99e5f91f9a016a304dd929b0966ca464bcfda15177b6fb4a118fc0fb5d9563", size = 15453, upload-time = "2026-02-12T17:16:38.317Z" }, +] + +[[package]] +name = "defusedxml" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/d5/c66da9b79e5bdb124974bfe172b4daf3c984ebd9c2a06e2b8a4dc7331c72/defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69", size = 75520, upload-time = "2021-03-08T10:59:26.269Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, +] + [[package]] name = "dirty-equals" version = "0.9.0" @@ -781,7 +831,7 @@ dev = [ docs = [ { name = "mkdocs" }, { name = "mkdocs-glightbox" }, - { name = "mkdocs-material" }, + { name = "mkdocs-material", extra = ["imaging"] }, { name = "mkdocstrings-python" }, ] @@ -829,7 +879,7 @@ dev = [ docs = [ { name = "mkdocs", specifier = ">=1.6.1" }, { name = "mkdocs-glightbox", specifier = ">=0.4.0" }, - { name = "mkdocs-material", specifier = ">=9.5.45" }, + { name = "mkdocs-material", extras = ["imaging"], specifier = ">=9.5.45" }, { name = "mkdocstrings-python", specifier = ">=2.0.1" }, ] @@ -1469,7 +1519,7 @@ wheels = [ [[package]] name = "mkdocs-material" -version = "9.7.1" +version = "9.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "babel" }, @@ -1484,9 +1534,15 @@ dependencies = [ { name = "pymdown-extensions" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/27/e2/2ffc356cd72f1473d07c7719d82a8f2cbd261666828614ecb95b12169f41/mkdocs_material-9.7.1.tar.gz", hash = "sha256:89601b8f2c3e6c6ee0a918cc3566cb201d40bf37c3cd3c2067e26fadb8cce2b8", size = 4094392, upload-time = "2025-12-18T09:49:00.308Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/57/5d3c8c9e2ff9d66dc8f63aa052eb0bac5041fecff7761d8689fe65c39c13/mkdocs_material-9.7.2.tar.gz", hash = "sha256:6776256552290b9b7a7aa002780e25b1e04bc9c3a8516b6b153e82e16b8384bd", size = 4097818, upload-time = "2026-02-18T15:53:07.763Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3e/32/ed071cb721aca8c227718cffcf7bd539620e9799bbf2619e90c757bfd030/mkdocs_material-9.7.1-py3-none-any.whl", hash = "sha256:3f6100937d7d731f87f1e3e3b021c97f7239666b9ba1151ab476cabb96c60d5c", size = 9297166, upload-time = "2025-12-18T09:48:56.664Z" }, + { url = "https://files.pythonhosted.org/packages/cd/19/d194e75e82282b1d688f0720e21b5ac250ed64ddea333a228aaf83105f2e/mkdocs_material-9.7.2-py3-none-any.whl", hash = "sha256:9bf6f53452d4a4d527eac3cef3f92b7b6fc4931c55d57766a7d87890d47e1b92", size = 9305052, upload-time = "2026-02-18T15:53:05.221Z" }, +] + +[package.optional-dependencies] +imaging = [ + { name = "cairosvg" }, + { name = "pillow" }, ] [[package]] @@ -2406,6 +2462,18 @@ dependencies = [ { name = "pydantic" }, ] +[[package]] +name = "tinycss2" +version = "1.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "webencodings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/ae/2ca4913e5c0f09781d75482874c3a95db9105462a92ddd303c7d285d3df2/tinycss2-1.5.1.tar.gz", hash = "sha256:d339d2b616ba90ccce58da8495a78f46e55d4d25f9fd71dfd526f07e7d53f957", size = 88195, upload-time = "2025-11-23T10:29:10.082Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/45/c7b5c3168458db837e8ceab06dc77824e18202679d0463f0e8f002143a97/tinycss2-1.5.1-py3-none-any.whl", hash = "sha256:3415ba0f5839c062696996998176c4a3751d18b7edaaeeb658c9ce21ec150661", size = 28404, upload-time = "2025-11-23T10:29:08.676Z" }, +] + [[package]] name = "tomli" version = "2.2.1" @@ -2554,6 +2622,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] +[[package]] +name = "webencodings" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/02/ae6ceac1baeda530866a85075641cec12989bd8d31af6d5ab4a3e8c92f47/webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923", size = 9721, upload-time = "2017-04-05T20:21:34.189Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", size = 11774, upload-time = "2017-04-05T20:21:32.581Z" }, +] + [[package]] name = "websockets" version = "15.0.1"