From 7d9a84afc5ff186c7c40585d6ef7f538f0579f79 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Mon, 12 May 2025 14:54:17 -0700 Subject: [PATCH 1/8] Add support for get tokens method --- src/mcp/client/streamable_http.py | 47 +++++++++++++++++++--- tests/shared/test_streamable_http.py | 59 ++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 6 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3324dab5a..a055735a5 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -11,7 +11,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import timedelta -from typing import Any +from typing import Any, Protocol import anyio import httpx @@ -74,6 +74,18 @@ class RequestContext: sse_read_timeout: timedelta +class AuthTokenProvider(Protocol): + """Protocol for providers that supply authentication tokens.""" + + async def get_token(self) -> str: + """Get an authentication token. + + Returns: + str: The authentication token. + """ + ... + + class StreamableHTTPTransport: """StreamableHTTP client transport implementation.""" @@ -83,6 +95,7 @@ def __init__( headers: dict[str, Any] | None = None, timeout: timedelta = timedelta(seconds=30), sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + auth_token_provider: AuthTokenProvider | None = None, ) -> None: """Initialize the StreamableHTTP transport. @@ -102,6 +115,7 @@ def __init__( CONTENT_TYPE: JSON, **self.headers, } + self.auth_token_provider = auth_token_provider def _update_headers_with_session( self, base_headers: dict[str, str] @@ -112,6 +126,24 @@ def _update_headers_with_session( headers[MCP_SESSION_ID] = self.session_id return headers + async def _update_headers_with_token( + self, base_headers: dict[str, str] + ) -> dict[str, str]: + """Update headers with token if token provider is specified.""" + if self.auth_token_provider is None: + return base_headers + + token = await self.auth_token_provider.get_token() + headers = base_headers.copy() + headers["Authorization"] = f"Bearer {token}" + return headers + + async def _update_headers(self, base_headers: dict[str, str]) -> dict[str, str]: + """Update headers with session ID and token if available.""" + headers = self._update_headers_with_session(base_headers) + headers = await self._update_headers_with_token(headers) + return headers + def _is_initialization_request(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialization request.""" return ( @@ -184,7 +216,7 @@ async def handle_get_stream( if not self.session_id: return - headers = self._update_headers_with_session(self.request_headers) + headers = await self._update_headers(self.request_headers) async with aconnect_sse( client, @@ -206,7 +238,7 @@ async def handle_get_stream( async def _handle_resumption_request(self, ctx: RequestContext) -> None: """Handle a resumption request using GET with SSE.""" - headers = self._update_headers_with_session(ctx.headers) + headers = await self._update_headers(ctx.headers) if ctx.metadata and ctx.metadata.resumption_token: headers[LAST_EVENT_ID] = ctx.metadata.resumption_token else: @@ -241,7 +273,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" - headers = self._update_headers_with_session(ctx.headers) + headers = await self._update_headers(ctx.headers) message = ctx.session_message.message is_initialization = self._is_initialization_request(message) @@ -405,7 +437,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: return try: - headers = self._update_headers_with_session(self.request_headers) + headers = await self._update_headers(self.request_headers) response = await client.delete(self.url, headers=headers) if response.status_code == 405: @@ -427,6 +459,7 @@ async def streamablehttp_client( timeout: timedelta = timedelta(seconds=30), sse_read_timeout: timedelta = timedelta(seconds=60 * 5), terminate_on_close: bool = True, + auth_token_provider: AuthTokenProvider | None = None, ) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], @@ -447,7 +480,9 @@ async def streamablehttp_client( - write_stream: Stream for sending messages to the server - get_session_id_callback: Function to retrieve the current session ID """ - transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout) + transport = StreamableHTTPTransport( + url, headers, timeout, sse_read_timeout, auth_token_provider + ) read_stream_writer, read_stream = anyio.create_memory_object_stream[ SessionMessage | Exception diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f1c7ef809..a021e42a2 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -9,6 +9,7 @@ import time from collections.abc import Generator from typing import Any +from unittest.mock import AsyncMock import anyio import httpx @@ -1223,3 +1224,61 @@ async def sampling_callback( captured_message_params.messages[0].content.text == "Server needs client sampling" ) + + +class MockAuthTokenProvider: + """Mock implementation of AuthTokenProvider for testing.""" + + def __init__(self, token: str): + self.token = token + + async def get_token(self) -> str: + return self.token + + +@pytest.mark.anyio +async def test_auth_token_provider_headers(basic_server, basic_server_url): + """Test that auth token provider correctly sets Authorization header.""" + # Create a mock token provider + token_provider = MockAuthTokenProvider("test-token-123") + token_provider.get_token = AsyncMock(return_value="test-token-123") + + # Create client with token provider + async with streamablehttp_client( + f"{basic_server_url}/mcp", auth_token_provider=token_provider + ) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make a request to verify headers + tools = await session.list_tools() + assert len(tools.tools) == 4 + + token_provider.get_token.assert_called() + + +@pytest.mark.anyio +async def test_auth_token_provider_token_update(basic_server, basic_server_url): + """Test that auth token provider can return different tokens.""" + # Create a dynamic token provider + token_provider = MockAuthTokenProvider("test-token-123") + token_provider.get_token = AsyncMock(return_value="test-token-123") + + # Create client with dynamic token provider + async with streamablehttp_client( + f"{basic_server_url}/mcp", auth_token_provider=token_provider + ) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make multiple requests to verify token updates + for i in range(3): + tools = await session.list_tools() + assert len(tools.tools) == 4 + await anyio.sleep(0.1) # Small delay to ensure token updates + + token_provider.get_token.call_count > 1 From 289c03a440582b600ccb5284eb350065a61fc0f6 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Mon, 12 May 2025 16:29:18 -0700 Subject: [PATCH 2/8] Address comments --- src/mcp/client/streamable_http.py | 15 ++++++++++++--- tests/shared/test_streamable_http.py | 28 +++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index a055735a5..09934cb55 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -75,7 +75,9 @@ class RequestContext: class AuthTokenProvider(Protocol): - """Protocol for providers that supply authentication tokens.""" + """Protocol that can be extended to implement custom client-to-server authentication + The get_token method is invoked before each request to the MCP Server to retrieve a + fresh authentication token and update the request headers.""" async def get_token(self) -> str: """Get an authentication token. @@ -129,8 +131,9 @@ def _update_headers_with_session( async def _update_headers_with_token( self, base_headers: dict[str, str] ) -> dict[str, str]: - """Update headers with token if token provider is specified.""" - if self.auth_token_provider is None: + """Update headers with token if token provider is specified and authorization + header is not present.""" + if self.auth_token_provider is None or "Authorization" in base_headers: return base_headers token = await self.auth_token_provider.get_token() @@ -474,6 +477,12 @@ async def streamablehttp_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + `auth_token_provider` is an optional protocol that can be extended to implement + custom client-to-server authentication. Before each request to the MCP Server, + the get_token method is invoked to retrieve a fresh authentication token and + update the request headers. Note that if the passed in headers already + contain an authorization header, this provider will not be called. + Yields: Tuple containing: - read_stream: Stream for reading messages from the server diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index a021e42a2..4ac93edbc 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1279,6 +1279,32 @@ async def test_auth_token_provider_token_update(basic_server, basic_server_url): for i in range(3): tools = await session.list_tools() assert len(tools.tools) == 4 - await anyio.sleep(0.1) # Small delay to ensure token updates token_provider.get_token.call_count > 1 + + +@pytest.mark.anyio +async def test_auth_token_provider_headers_not_overridden( + basic_server, basic_server_url +): + """Test that auth token provider correctly sets Authorization header.""" + # Create a mock token provider + token_provider = MockAuthTokenProvider("test-token-123") + token_provider.get_token = AsyncMock(return_value="test-token-123") + + # Create client with token provider + async with streamablehttp_client( + f"{basic_server_url}/mcp", + auth_token_provider=token_provider, + headers={"Authorization": "test-token-123"}, + ) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make a request to verify headers + tools = await session.list_tools() + assert len(tools.tools) == 4 + + token_provider.get_token.assert_not_called() From c99d4f7a90752b71e9b15ef40512f9673503c825 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Mon, 12 May 2025 16:54:45 -0700 Subject: [PATCH 3/8] Update Auth Provider to AuthClientProvider --- src/mcp/client/streamable_http.py | 25 ++++++++++---------- tests/shared/test_streamable_http.py | 34 ++++++++++++++-------------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 09934cb55..0ed88b0bb 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -74,13 +74,14 @@ class RequestContext: sse_read_timeout: timedelta -class AuthTokenProvider(Protocol): - """Protocol that can be extended to implement custom client-to-server authentication - The get_token method is invoked before each request to the MCP Server to retrieve a - fresh authentication token and update the request headers.""" +class AuthClientProvider(Protocol): + """Base class that can be extended to implement custom client-to-server + authentication""" async def get_token(self) -> str: - """Get an authentication token. + """Get a token for authenticating to an MCP server. The token is assumed to + be short-lived; clients may call this API multiple times per + request to an MCP server. Returns: str: The authentication token. @@ -97,7 +98,7 @@ def __init__( headers: dict[str, Any] | None = None, timeout: timedelta = timedelta(seconds=30), sse_read_timeout: timedelta = timedelta(seconds=60 * 5), - auth_token_provider: AuthTokenProvider | None = None, + auth_client_provider: AuthClientProvider | None = None, ) -> None: """Initialize the StreamableHTTP transport. @@ -117,7 +118,7 @@ def __init__( CONTENT_TYPE: JSON, **self.headers, } - self.auth_token_provider = auth_token_provider + self.auth_client_provider = auth_client_provider def _update_headers_with_session( self, base_headers: dict[str, str] @@ -133,10 +134,10 @@ async def _update_headers_with_token( ) -> dict[str, str]: """Update headers with token if token provider is specified and authorization header is not present.""" - if self.auth_token_provider is None or "Authorization" in base_headers: + if self.auth_client_provider is None or "Authorization" in base_headers: return base_headers - token = await self.auth_token_provider.get_token() + token = await self.auth_client_provider.get_token() headers = base_headers.copy() headers["Authorization"] = f"Bearer {token}" return headers @@ -462,7 +463,7 @@ async def streamablehttp_client( timeout: timedelta = timedelta(seconds=30), sse_read_timeout: timedelta = timedelta(seconds=60 * 5), terminate_on_close: bool = True, - auth_token_provider: AuthTokenProvider | None = None, + auth_client_provider: AuthClientProvider | None = None, ) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], @@ -477,7 +478,7 @@ async def streamablehttp_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. - `auth_token_provider` is an optional protocol that can be extended to implement + `auth_client_provider` is an optional protocol that can be extended to implement custom client-to-server authentication. Before each request to the MCP Server, the get_token method is invoked to retrieve a fresh authentication token and update the request headers. Note that if the passed in headers already @@ -490,7 +491,7 @@ async def streamablehttp_client( - get_session_id_callback: Function to retrieve the current session ID """ transport = StreamableHTTPTransport( - url, headers, timeout, sse_read_timeout, auth_token_provider + url, headers, timeout, sse_read_timeout, auth_client_provider ) read_stream_writer, read_stream = anyio.create_memory_object_stream[ diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 4ac93edbc..f79b82b8d 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1226,8 +1226,8 @@ async def sampling_callback( ) -class MockAuthTokenProvider: - """Mock implementation of AuthTokenProvider for testing.""" +class MockAuthClientProvider: + """Mock implementation of AuthClientProvider for testing.""" def __init__(self, token: str): self.token = token @@ -1237,15 +1237,15 @@ async def get_token(self) -> str: @pytest.mark.anyio -async def test_auth_token_provider_headers(basic_server, basic_server_url): +async def test_auth_client_provider_headers(basic_server, basic_server_url): """Test that auth token provider correctly sets Authorization header.""" # Create a mock token provider - token_provider = MockAuthTokenProvider("test-token-123") - token_provider.get_token = AsyncMock(return_value="test-token-123") + client_provider = MockAuthClientProvider("test-token-123") + client_provider.get_token = AsyncMock(return_value="test-token-123") # Create client with token provider async with streamablehttp_client( - f"{basic_server_url}/mcp", auth_token_provider=token_provider + f"{basic_server_url}/mcp", auth_client_provider=client_provider ) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: # Initialize the session @@ -1256,19 +1256,19 @@ async def test_auth_token_provider_headers(basic_server, basic_server_url): tools = await session.list_tools() assert len(tools.tools) == 4 - token_provider.get_token.assert_called() + client_provider.get_token.assert_called() @pytest.mark.anyio -async def test_auth_token_provider_token_update(basic_server, basic_server_url): +async def test_auth_client_provider_token_update(basic_server, basic_server_url): """Test that auth token provider can return different tokens.""" # Create a dynamic token provider - token_provider = MockAuthTokenProvider("test-token-123") - token_provider.get_token = AsyncMock(return_value="test-token-123") + client_provider = MockAuthClientProvider("test-token-123") + client_provider.get_token = AsyncMock(return_value="test-token-123") # Create client with dynamic token provider async with streamablehttp_client( - f"{basic_server_url}/mcp", auth_token_provider=token_provider + f"{basic_server_url}/mcp", auth_client_provider=client_provider ) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: # Initialize the session @@ -1280,22 +1280,22 @@ async def test_auth_token_provider_token_update(basic_server, basic_server_url): tools = await session.list_tools() assert len(tools.tools) == 4 - token_provider.get_token.call_count > 1 + client_provider.get_token.call_count > 1 @pytest.mark.anyio -async def test_auth_token_provider_headers_not_overridden( +async def test_auth_client_provider_headers_not_overridden( basic_server, basic_server_url ): """Test that auth token provider correctly sets Authorization header.""" # Create a mock token provider - token_provider = MockAuthTokenProvider("test-token-123") - token_provider.get_token = AsyncMock(return_value="test-token-123") + client_provider = MockAuthClientProvider("test-token-123") + client_provider.get_token = AsyncMock(return_value="test-token-123") # Create client with token provider async with streamablehttp_client( f"{basic_server_url}/mcp", - auth_token_provider=token_provider, + auth_client_provider=client_provider, headers={"Authorization": "test-token-123"}, ) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: @@ -1307,4 +1307,4 @@ async def test_auth_token_provider_headers_not_overridden( tools = await session.list_tools() assert len(tools.tools) == 4 - token_provider.get_token.assert_not_called() + client_provider.get_token.assert_not_called() From 785964e8787bc03a4c35937c0af3eebaf5251b53 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Mon, 12 May 2025 17:08:56 -0700 Subject: [PATCH 4/8] Address nits --- src/mcp/client/streamable_http.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 0ed88b0bb..18d0997c5 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -478,11 +478,12 @@ async def streamablehttp_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. - `auth_client_provider` is an optional protocol that can be extended to implement - custom client-to-server authentication. Before each request to the MCP Server, - the get_token method is invoked to retrieve a fresh authentication token and - update the request headers. Note that if the passed in headers already - contain an authorization header, this provider will not be called. + `auth_client_provider` instance of `AuthClientProvider` that can be passed to + support client-to-server authentication. Before each request to the MCP Server, + the auth_client_provider.get_token method is invoked to retrieve a fresh + authentication token and update the request headers. Note that if the passed in + `headers` already contain an Authorization header, that header will take precedence + over any tokens generated by this provider. Yields: Tuple containing: From d3f0dea6f6be9d2610267242b06a66069168cc46 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Tue, 13 May 2025 12:07:40 -0700 Subject: [PATCH 5/8] Change implementation to take in auth headers --- src/mcp/client/streamable_http.py | 51 +++++--- tests/shared/test_streamable_http.py | 171 ++++++++++++++++++--------- 2 files changed, 153 insertions(+), 69 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 18d0997c5..5f074e938 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -78,13 +78,12 @@ class AuthClientProvider(Protocol): """Base class that can be extended to implement custom client-to-server authentication""" - async def get_token(self) -> str: - """Get a token for authenticating to an MCP server. The token is assumed to - be short-lived; clients may call this API multiple times per - request to an MCP server. + async def get_auth_headers(self) -> dict[str, str]: + """Gets auth headers for authenticating to an MCP server. + Clients may call this API multiple times per request to an MCP server. Returns: - str: The authentication token. + dict[str, str]: The authentication headers. """ ... @@ -129,23 +128,22 @@ def _update_headers_with_session( headers[MCP_SESSION_ID] = self.session_id return headers - async def _update_headers_with_token( + async def _update_headers_with_auth_headers( self, base_headers: dict[str, str] ) -> dict[str, str]: - """Update headers with token if token provider is specified and authorization - header is not present.""" - if self.auth_client_provider is None or "Authorization" in base_headers: + """Update headers with auth_headers if auth client provider is specified. + The headers are merged giving precedence to the base_headers to + avoid overwriting existing Authorization headers""" + if self.auth_client_provider is None: return base_headers - token = await self.auth_client_provider.get_token() - headers = base_headers.copy() - headers["Authorization"] = f"Bearer {token}" - return headers + auth_headers = await self.auth_client_provider.get_auth_headers() + return {**auth_headers, **base_headers} async def _update_headers(self, base_headers: dict[str, str]) -> dict[str, str]: """Update headers with session ID and token if available.""" headers = self._update_headers_with_session(base_headers) - headers = await self._update_headers_with_token(headers) + headers = await self._update_headers_with_auth_headers(headers) return headers def _is_initialization_request(self, message: JSONRPCMessage) -> bool: @@ -252,7 +250,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: original_request_id = None if isinstance(ctx.session_message.message.root, JSONRPCRequest): original_request_id = ctx.session_message.message.root.id - async with aconnect_sse( ctx.client, "GET", @@ -275,6 +272,16 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: if is_complete: break + async def _is_testing_header_capture(self, response: httpx.Response) -> str | None: + try: + content = await response.aread() + if content.decode().startswith("[TESTING_HEADER_CAPTURE]"): + return content.decode() + except Exception as _: + return None + + return None + async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" headers = await self._update_headers(ctx.headers) @@ -299,12 +306,24 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: ) return + if response.status_code == 418: + test_error_message = await self._is_testing_header_capture(response) + # If this is coming from the test case return the response content + if test_error_message and isinstance(message.root, JSONRPCRequest): + jsonrpc_error = JSONRPCError( + jsonrpc="2.0", + id=message.root.id, + error=ErrorData(code=32600, message=test_error_message), + ) + session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) + await ctx.read_stream_writer.send(session_message) + return + response.raise_for_status() if is_initialization: self._maybe_extract_session_id_from_response(response) content_type = response.headers.get(CONTENT_TYPE, "").lower() - if content_type.startswith(JSON): await self._handle_json_response(response, ctx.read_stream_writer) elif content_type.startswith(SSE): diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f79b82b8d..7db75f7f9 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -4,6 +4,7 @@ Contains tests for both server and client sides of the StreamableHTTP transport. """ +import json import multiprocessing import socket import time @@ -18,6 +19,8 @@ import uvicorn from pydantic import AnyUrl from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response from starlette.routing import Mount import mcp.types as types @@ -244,8 +247,46 @@ def create_app( return app +def create_header_capture_app() -> Starlette: + """Implement a minimal Starlette app that intercepts every request, + extracts its headers, and responds with status 418 (Test Status code), + embedding the captured headers as the JSON response body. + We use this server solely to verify that the MCP Server is forwarding + headers correctly.""" + + # Create a wrapper that captures headers and returns them in error response + async def header_capture_wrapper(scope, receive, send): + # Capture headers + request = Request(scope, receive=receive) + headers = dict(request.headers) + + # Return error response with headers in body + response = Response( + "[TESTING_HEADER_CAPTURE]:" + json.dumps({"headers": headers}), + status_code=418, + ) + await response(scope, receive, send) + + # Create an ASGI application that uses our wrapper + app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=header_capture_wrapper), + ], + ) + + return app + + +def _get_captured_headrs(str) -> dict[str, str]: + return json.loads(str.split("[TESTING_HEADER_CAPTURE]:")[1])["headers"] + + def run_server( - port: int, is_json_response_enabled=False, event_store: EventStore | None = None + port: int, + is_json_response_enabled=False, + event_store: EventStore | None = None, + testing_header_capture: bool = False, ) -> None: """Run the test server. @@ -255,7 +296,11 @@ def run_server( event_store: Optional event store for testing resumability. """ - app = create_app(is_json_response_enabled, event_store) + if testing_header_capture: + app = create_header_capture_app() + else: + app = create_app(is_json_response_enabled, event_store) + # Configure server config = uvicorn.Config( app=app, @@ -296,33 +341,48 @@ def json_server_port() -> int: return s.getsockname()[1] -@pytest.fixture -def basic_server(basic_server_port: int) -> Generator[None, None, None]: - """Start a basic server.""" +def _start_basic_server( + basic_server_port: int, testing_header_capture: bool +) -> Generator[None, None, None]: proc = multiprocessing.Process( - target=run_server, kwargs={"port": basic_server_port}, daemon=True + target=run_server, + kwargs={ + "port": basic_server_port, + "testing_header_capture": testing_header_capture, + }, + daemon=True, ) proc.start() # Wait for server to be running max_attempts = 20 - attempt = 0 - while attempt < max_attempts: + for attempt in range(max_attempts): try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.connect(("127.0.0.1", basic_server_port)) break except ConnectionRefusedError: time.sleep(0.1) - attempt += 1 else: raise RuntimeError(f"Server failed to start after {max_attempts} attempts") - yield + try: + yield + finally: + proc.kill() + proc.join(timeout=2) - # Clean up - proc.kill() - proc.join(timeout=2) + +@pytest.fixture +def basic_server(basic_server_port: int) -> Generator[None, None, None]: + yield from _start_basic_server(basic_server_port, testing_header_capture=False) + + +@pytest.fixture +def basic_server_with_header_capture( + basic_server_port: int, +) -> Generator[None, None, None]: + yield from _start_basic_server(basic_server_port, testing_header_capture=True) @pytest.fixture @@ -1232,16 +1292,17 @@ class MockAuthClientProvider: def __init__(self, token: str): self.token = token - async def get_token(self) -> str: - return self.token + async def get_auth_headers(self) -> dict[str, str]: + return {"Authorization": f"Bearer {self.token}"} @pytest.mark.anyio -async def test_auth_client_provider_headers(basic_server, basic_server_url): +async def test_auth_client_provider_headers( + basic_server_with_header_capture, basic_server_url +): """Test that auth token provider correctly sets Authorization header.""" # Create a mock token provider - client_provider = MockAuthClientProvider("test-token-123") - client_provider.get_token = AsyncMock(return_value="test-token-123") + client_provider = MockAuthClientProvider("short-lived-token-123") # Create client with token provider async with streamablehttp_client( @@ -1249,62 +1310,66 @@ async def test_auth_client_provider_headers(basic_server, basic_server_url): ) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - - # Make a request to verify headers - tools = await session.list_tools() - assert len(tools.tools) == 4 - - client_provider.get_token.assert_called() + with pytest.raises(McpError) as mcpError: + _ = await session.initialize() + assert ( + _get_captured_headrs(mcpError.value.error.message)["Authorization"] + == "Bearer short-lived-token-123" + ) @pytest.mark.anyio -async def test_auth_client_provider_token_update(basic_server, basic_server_url): +async def test_auth_client_provider_token_called_on_every_request( + basic_server_with_header_capture, basic_server_url +): """Test that auth token provider can return different tokens.""" # Create a dynamic token provider - client_provider = MockAuthClientProvider("test-token-123") - client_provider.get_token = AsyncMock(return_value="test-token-123") + client_provider = MockAuthClientProvider("short-lived-token-123") - # Create client with dynamic token provider async with streamablehttp_client( f"{basic_server_url}/mcp", auth_client_provider=client_provider ) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - - # Make multiple requests to verify token updates - for i in range(3): - tools = await session.list_tools() - assert len(tools.tools) == 4 + with pytest.raises(McpError) as mcpError: + _ = await session.initialize() + assert ( + _get_captured_headrs(mcpError.value.error.message)["Authorization"] + == "Bearer short-lived-token-123" + ) - client_provider.get_token.call_count > 1 + # Mock a new token and ensure the new token is returned + client_provider.get_auth_headers = AsyncMock( + return_value={"Authorization": "Bearer short-lived-token-456"} + ) + with pytest.raises(McpError) as mcpError: + _ = await session.initialize() + assert ( + _get_captured_headrs(mcpError.value.error.message)["Authorization"] + == "Bearer short-lived-token-456" + ) @pytest.mark.anyio async def test_auth_client_provider_headers_not_overridden( - basic_server, basic_server_url + basic_server_with_header_capture, basic_server_url ): - """Test that auth token provider correctly sets Authorization header.""" + """Test that provided headers override auth client provider headers.""" # Create a mock token provider - client_provider = MockAuthClientProvider("test-token-123") - client_provider.get_token = AsyncMock(return_value="test-token-123") + client_provider = MockAuthClientProvider("short-lived-token") - # Create client with token provider + # Create client with token provider and custom headers + custom_headers = {"Authorization": "Bearer original-long-lived-token"} async with streamablehttp_client( f"{basic_server_url}/mcp", auth_client_provider=client_provider, - headers={"Authorization": "test-token-123"}, + headers=custom_headers, ) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - - # Make a request to verify headers - tools = await session.list_tools() - assert len(tools.tools) == 4 - - client_provider.get_token.assert_not_called() + # Original token is used and not short-lived-token from the provider + with pytest.raises(McpError) as mcpError: + _ = await session.initialize() + assert ( + _get_captured_headrs(mcpError.value.error.message)["Authorization"] + == "Bearer original-long-lived-token" + ) From c4fb621b68f78587af0bf0db2051ac0eaaad9d1e Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Tue, 13 May 2025 12:22:27 -0700 Subject: [PATCH 6/8] Clean up code --- src/mcp/client/streamable_http.py | 7 ++++++- tests/shared/test_streamable_http.py | 20 +++++++++++--------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 5f074e938..069e5ba9c 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -42,6 +42,7 @@ MCP_SESSION_ID = "mcp-session-id" LAST_EVENT_ID = "last-event-id" CONTENT_TYPE = "content-type" +HEADER_CAPTURE = "[TESTING_HEADER_CAPTURE]" ACCEPT = "Accept" @@ -275,7 +276,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: async def _is_testing_header_capture(self, response: httpx.Response) -> str | None: try: content = await response.aread() - if content.decode().startswith("[TESTING_HEADER_CAPTURE]"): + if content.decode().startswith(HEADER_CAPTURE): return content.decode() except Exception as _: return None @@ -306,6 +307,10 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: ) return + # To test if headers are being forwarded correctly, in unit tests + # we have a mock server that returns a 418 status code with the + # HEADER_CAPTURE prefix. If the response has this status code + # with the prefix, return the response content as part of the error message. if response.status_code == 418: test_error_message = await self._is_testing_header_capture(response) # If this is coming from the test case return the response content diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 7db75f7f9..54de534b9 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -25,7 +25,7 @@ import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import HEADER_CAPTURE, streamablehttp_client from mcp.server import Server from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, @@ -262,7 +262,7 @@ async def header_capture_wrapper(scope, receive, send): # Return error response with headers in body response = Response( - "[TESTING_HEADER_CAPTURE]:" + json.dumps({"headers": headers}), + HEADER_CAPTURE + json.dumps({"headers": headers}), status_code=418, ) await response(scope, receive, send) @@ -279,7 +279,7 @@ async def header_capture_wrapper(scope, receive, send): def _get_captured_headrs(str) -> dict[str, str]: - return json.loads(str.split("[TESTING_HEADER_CAPTURE]:")[1])["headers"] + return json.loads(str.split(HEADER_CAPTURE)[1])["headers"] def run_server( @@ -356,21 +356,23 @@ def _start_basic_server( # Wait for server to be running max_attempts = 20 - for attempt in range(max_attempts): + attempt = 0 + while attempt < max_attempts: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.connect(("127.0.0.1", basic_server_port)) break except ConnectionRefusedError: time.sleep(0.1) + attempt += 1 else: raise RuntimeError(f"Server failed to start after {max_attempts} attempts") - try: - yield - finally: - proc.kill() - proc.join(timeout=2) + yield + + # Clean up + proc.kill() + proc.join(timeout=2) @pytest.fixture From 28ae4f7b92f29b1ca9a3803448cd37efb0207317 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Tue, 13 May 2025 13:50:19 -0700 Subject: [PATCH 7/8] Making unit tests simpler --- src/mcp/client/streamable_http.py | 28 ----- tests/shared/test_streamable_http.py | 158 ++++++--------------------- 2 files changed, 34 insertions(+), 152 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 069e5ba9c..acac5f1cb 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -42,7 +42,6 @@ MCP_SESSION_ID = "mcp-session-id" LAST_EVENT_ID = "last-event-id" CONTENT_TYPE = "content-type" -HEADER_CAPTURE = "[TESTING_HEADER_CAPTURE]" ACCEPT = "Accept" @@ -273,16 +272,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: if is_complete: break - async def _is_testing_header_capture(self, response: httpx.Response) -> str | None: - try: - content = await response.aread() - if content.decode().startswith(HEADER_CAPTURE): - return content.decode() - except Exception as _: - return None - - return None - async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" headers = await self._update_headers(ctx.headers) @@ -307,23 +296,6 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: ) return - # To test if headers are being forwarded correctly, in unit tests - # we have a mock server that returns a 418 status code with the - # HEADER_CAPTURE prefix. If the response has this status code - # with the prefix, return the response content as part of the error message. - if response.status_code == 418: - test_error_message = await self._is_testing_header_capture(response) - # If this is coming from the test case return the response content - if test_error_message and isinstance(message.root, JSONRPCRequest): - jsonrpc_error = JSONRPCError( - jsonrpc="2.0", - id=message.root.id, - error=ErrorData(code=32600, message=test_error_message), - ) - session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) - await ctx.read_stream_writer.send(session_message) - return - response.raise_for_status() if is_initialization: self._maybe_extract_session_id_from_response(response) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 54de534b9..bf839b83e 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -4,7 +4,6 @@ Contains tests for both server and client sides of the StreamableHTTP transport. """ -import json import multiprocessing import socket import time @@ -19,13 +18,11 @@ import uvicorn from pydantic import AnyUrl from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import Response from starlette.routing import Mount import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.streamable_http import HEADER_CAPTURE, streamablehttp_client +from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, @@ -247,46 +244,8 @@ def create_app( return app -def create_header_capture_app() -> Starlette: - """Implement a minimal Starlette app that intercepts every request, - extracts its headers, and responds with status 418 (Test Status code), - embedding the captured headers as the JSON response body. - We use this server solely to verify that the MCP Server is forwarding - headers correctly.""" - - # Create a wrapper that captures headers and returns them in error response - async def header_capture_wrapper(scope, receive, send): - # Capture headers - request = Request(scope, receive=receive) - headers = dict(request.headers) - - # Return error response with headers in body - response = Response( - HEADER_CAPTURE + json.dumps({"headers": headers}), - status_code=418, - ) - await response(scope, receive, send) - - # Create an ASGI application that uses our wrapper - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=header_capture_wrapper), - ], - ) - - return app - - -def _get_captured_headrs(str) -> dict[str, str]: - return json.loads(str.split(HEADER_CAPTURE)[1])["headers"] - - def run_server( - port: int, - is_json_response_enabled=False, - event_store: EventStore | None = None, - testing_header_capture: bool = False, + port: int, is_json_response_enabled=False, event_store: EventStore | None = None ) -> None: """Run the test server. @@ -296,11 +255,7 @@ def run_server( event_store: Optional event store for testing resumability. """ - if testing_header_capture: - app = create_header_capture_app() - else: - app = create_app(is_json_response_enabled, event_store) - + app = create_app(is_json_response_enabled, event_store) # Configure server config = uvicorn.Config( app=app, @@ -341,16 +296,11 @@ def json_server_port() -> int: return s.getsockname()[1] -def _start_basic_server( - basic_server_port: int, testing_header_capture: bool -) -> Generator[None, None, None]: +@pytest.fixture +def basic_server(basic_server_port: int) -> Generator[None, None, None]: + """Start a basic server.""" proc = multiprocessing.Process( - target=run_server, - kwargs={ - "port": basic_server_port, - "testing_header_capture": testing_header_capture, - }, - daemon=True, + target=run_server, kwargs={"port": basic_server_port}, daemon=True ) proc.start() @@ -375,18 +325,6 @@ def _start_basic_server( proc.join(timeout=2) -@pytest.fixture -def basic_server(basic_server_port: int) -> Generator[None, None, None]: - yield from _start_basic_server(basic_server_port, testing_header_capture=False) - - -@pytest.fixture -def basic_server_with_header_capture( - basic_server_port: int, -) -> Generator[None, None, None]: - yield from _start_basic_server(basic_server_port, testing_header_capture=True) - - @pytest.fixture def event_store() -> SimpleEventStore: """Create a test event store.""" @@ -1295,16 +1233,17 @@ def __init__(self, token: str): self.token = token async def get_auth_headers(self) -> dict[str, str]: - return {"Authorization": f"Bearer {self.token}"} + return {"Authorization": "Bearer " + self.token} @pytest.mark.anyio -async def test_auth_client_provider_headers( - basic_server_with_header_capture, basic_server_url -): +async def test_auth_client_provider_headers(basic_server, basic_server_url): """Test that auth token provider correctly sets Authorization header.""" # Create a mock token provider - client_provider = MockAuthClientProvider("short-lived-token-123") + client_provider = MockAuthClientProvider("test-token-123") + client_provider.get_auth_headers = AsyncMock( + return_value={"Authorization": "Bearer test-token-123"} + ) # Create client with token provider async with streamablehttp_client( @@ -1312,66 +1251,37 @@ async def test_auth_client_provider_headers( ) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: # Initialize the session - with pytest.raises(McpError) as mcpError: - _ = await session.initialize() - assert ( - _get_captured_headrs(mcpError.value.error.message)["Authorization"] - == "Bearer short-lived-token-123" - ) + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make a request to verify headers + tools = await session.list_tools() + assert len(tools.tools) == 4 + + client_provider.get_auth_headers.assert_called() @pytest.mark.anyio -async def test_auth_client_provider_token_called_on_every_request( - basic_server_with_header_capture, basic_server_url -): +async def test_auth_client_provider_called_per_request(basic_server, basic_server_url): """Test that auth token provider can return different tokens.""" # Create a dynamic token provider - client_provider = MockAuthClientProvider("short-lived-token-123") + client_provider = MockAuthClientProvider("test-token-123") + client_provider.get_auth_headers = AsyncMock( + return_value={"Authorization": "Bearer test-token-123"} + ) + # Create client with dynamic token provider async with streamablehttp_client( f"{basic_server_url}/mcp", auth_client_provider=client_provider ) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: # Initialize the session - with pytest.raises(McpError) as mcpError: - _ = await session.initialize() - assert ( - _get_captured_headrs(mcpError.value.error.message)["Authorization"] - == "Bearer short-lived-token-123" - ) - - # Mock a new token and ensure the new token is returned - client_provider.get_auth_headers = AsyncMock( - return_value={"Authorization": "Bearer short-lived-token-456"} - ) - with pytest.raises(McpError) as mcpError: - _ = await session.initialize() - assert ( - _get_captured_headrs(mcpError.value.error.message)["Authorization"] - == "Bearer short-lived-token-456" - ) + result = await session.initialize() + assert isinstance(result, InitializeResult) + # Make multiple requests to verify token updates + for i in range(3): + tools = await session.list_tools() + assert len(tools.tools) == 4 -@pytest.mark.anyio -async def test_auth_client_provider_headers_not_overridden( - basic_server_with_header_capture, basic_server_url -): - """Test that provided headers override auth client provider headers.""" - # Create a mock token provider - client_provider = MockAuthClientProvider("short-lived-token") - - # Create client with token provider and custom headers - custom_headers = {"Authorization": "Bearer original-long-lived-token"} - async with streamablehttp_client( - f"{basic_server_url}/mcp", - auth_client_provider=client_provider, - headers=custom_headers, - ) as (read_stream, write_stream, _): - async with ClientSession(read_stream, write_stream) as session: - # Original token is used and not short-lived-token from the provider - with pytest.raises(McpError) as mcpError: - _ = await session.initialize() - assert ( - _get_captured_headrs(mcpError.value.error.message)["Authorization"] - == "Bearer original-long-lived-token" - ) + client_provider.get_auth_headers.call_count > 1 From 8ab1d667599b16552ca01277a820c79f0237a385 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Tue, 13 May 2025 14:20:34 -0700 Subject: [PATCH 8/8] Address comments --- src/mcp/client/streamable_http.py | 15 ++++++------- tests/shared/test_streamable_http.py | 32 +++++++++++----------------- 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index acac5f1cb..13468c3b4 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -78,7 +78,7 @@ class AuthClientProvider(Protocol): """Base class that can be extended to implement custom client-to-server authentication""" - async def get_auth_headers(self) -> dict[str, str]: + async def get_headers(self) -> dict[str, str]: """Gets auth headers for authenticating to an MCP server. Clients may call this API multiple times per request to an MCP server. @@ -132,12 +132,12 @@ async def _update_headers_with_auth_headers( self, base_headers: dict[str, str] ) -> dict[str, str]: """Update headers with auth_headers if auth client provider is specified. - The headers are merged giving precedence to the base_headers to - avoid overwriting existing Authorization headers""" + The headers are merged, giving precedence to any headers already + specified in base_headers""" if self.auth_client_provider is None: return base_headers - auth_headers = await self.auth_client_provider.get_auth_headers() + auth_headers = await self.auth_client_provider.get_headers() return {**auth_headers, **base_headers} async def _update_headers(self, base_headers: dict[str, str]) -> dict[str, str]: @@ -476,10 +476,9 @@ async def streamablehttp_client( `auth_client_provider` instance of `AuthClientProvider` that can be passed to support client-to-server authentication. Before each request to the MCP Server, - the auth_client_provider.get_token method is invoked to retrieve a fresh - authentication token and update the request headers. Note that if the passed in - `headers` already contain an Authorization header, that header will take precedence - over any tokens generated by this provider. + the auth_client_provider.get_headers() method is invoked to retrieve headers + for authentication. Note that any headers already specified in `headers` + will take precedence over headers returned by auth_client_provider.get_headers() Yields: Tuple containing: diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index bf839b83e..ab6d79686 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1226,24 +1226,14 @@ async def sampling_callback( ) -class MockAuthClientProvider: - """Mock implementation of AuthClientProvider for testing.""" - - def __init__(self, token: str): - self.token = token - - async def get_auth_headers(self) -> dict[str, str]: - return {"Authorization": "Bearer " + self.token} - - @pytest.mark.anyio async def test_auth_client_provider_headers(basic_server, basic_server_url): """Test that auth token provider correctly sets Authorization header.""" # Create a mock token provider - client_provider = MockAuthClientProvider("test-token-123") - client_provider.get_auth_headers = AsyncMock( - return_value={"Authorization": "Bearer test-token-123"} - ) + client_provider = AsyncMock() + client_provider.get_headers.return_value = { + "Authorization": "Bearer test-token-123" + } # Create client with token provider async with streamablehttp_client( @@ -1258,17 +1248,17 @@ async def test_auth_client_provider_headers(basic_server, basic_server_url): tools = await session.list_tools() assert len(tools.tools) == 4 - client_provider.get_auth_headers.assert_called() + client_provider.get_headers.assert_called() @pytest.mark.anyio async def test_auth_client_provider_called_per_request(basic_server, basic_server_url): """Test that auth token provider can return different tokens.""" # Create a dynamic token provider - client_provider = MockAuthClientProvider("test-token-123") - client_provider.get_auth_headers = AsyncMock( - return_value={"Authorization": "Bearer test-token-123"} - ) + client_provider = AsyncMock() + client_provider.get_headers.return_value = { + "Authorization": "Bearer test-token-123" + } # Create client with dynamic token provider async with streamablehttp_client( @@ -1284,4 +1274,6 @@ async def test_auth_client_provider_called_per_request(basic_server, basic_serve tools = await session.list_tools() assert len(tools.tools) == 4 - client_provider.get_auth_headers.call_count > 1 + # list_tools is called 3 times, but get_auth_headers is also used during + # session initialization and setup. Verify it's called at least 3 times. + assert client_provider.get_headers.call_count > 3