-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Dynamic Authorization in Streamable HTTP Client #700
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7d9a84a
289c03a
c99d4f7
785964e
d3f0dea
c4fb621
28ae4f7
8ab1d66
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,7 @@ | |
from contextlib import asynccontextmanager | ||
from dataclasses import dataclass | ||
from datetime import timedelta | ||
from typing import Any | ||
from typing import Any, Protocol | ||
|
||
import anyio | ||
import httpx | ||
|
@@ -74,6 +74,20 @@ class RequestContext: | |
sse_read_timeout: timedelta | ||
|
||
|
||
class AuthClientProvider(Protocol): | ||
"""Base class that can be extended to implement custom client-to-server | ||
authentication""" | ||
|
||
async def get_headers(self) -> dict[str, str]: | ||
"""Gets auth headers for authenticating to an MCP server. | ||
Clients may call this API multiple times per request to an MCP server. | ||
|
||
Returns: | ||
dict[str, str]: The authentication headers. | ||
""" | ||
... | ||
|
||
|
||
class StreamableHTTPTransport: | ||
"""StreamableHTTP client transport implementation.""" | ||
|
||
|
@@ -83,6 +97,7 @@ def __init__( | |
headers: dict[str, Any] | None = None, | ||
timeout: timedelta = timedelta(seconds=30), | ||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5), | ||
auth_client_provider: AuthClientProvider | None = None, | ||
) -> None: | ||
"""Initialize the StreamableHTTP transport. | ||
|
||
|
@@ -102,6 +117,7 @@ def __init__( | |
CONTENT_TYPE: JSON, | ||
**self.headers, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BTW I would expect any auth headers passed in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good call, I added the behaviour to not override passed in headers, and add a test case as well |
||
} | ||
self.auth_client_provider = auth_client_provider | ||
|
||
def _update_headers_with_session( | ||
self, base_headers: dict[str, str] | ||
|
@@ -112,6 +128,24 @@ def _update_headers_with_session( | |
headers[MCP_SESSION_ID] = self.session_id | ||
return headers | ||
|
||
async def _update_headers_with_auth_headers( | ||
self, base_headers: dict[str, str] | ||
) -> dict[str, str]: | ||
"""Update headers with auth_headers if auth client provider is specified. | ||
The headers are merged, giving precedence to any headers already | ||
specified in base_headers""" | ||
if self.auth_client_provider is None: | ||
return base_headers | ||
|
||
auth_headers = await self.auth_client_provider.get_headers() | ||
return {**auth_headers, **base_headers} | ||
|
||
async def _update_headers(self, base_headers: dict[str, str]) -> dict[str, str]: | ||
"""Update headers with session ID and token if available.""" | ||
headers = self._update_headers_with_session(base_headers) | ||
headers = await self._update_headers_with_auth_headers(headers) | ||
return headers | ||
|
||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool: | ||
"""Check if the message is an initialization request.""" | ||
return ( | ||
|
@@ -184,7 +218,7 @@ async def handle_get_stream( | |
if not self.session_id: | ||
return | ||
|
||
headers = self._update_headers_with_session(self.request_headers) | ||
headers = await self._update_headers(self.request_headers) | ||
|
||
async with aconnect_sse( | ||
client, | ||
|
@@ -206,7 +240,7 @@ async def handle_get_stream( | |
|
||
async def _handle_resumption_request(self, ctx: RequestContext) -> None: | ||
"""Handle a resumption request using GET with SSE.""" | ||
headers = self._update_headers_with_session(ctx.headers) | ||
headers = await self._update_headers(ctx.headers) | ||
if ctx.metadata and ctx.metadata.resumption_token: | ||
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token | ||
else: | ||
|
@@ -216,7 +250,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: | |
original_request_id = None | ||
if isinstance(ctx.session_message.message.root, JSONRPCRequest): | ||
original_request_id = ctx.session_message.message.root.id | ||
|
||
async with aconnect_sse( | ||
ctx.client, | ||
"GET", | ||
|
@@ -241,7 +274,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: | |
|
||
async def _handle_post_request(self, ctx: RequestContext) -> None: | ||
"""Handle a POST request with response processing.""" | ||
headers = self._update_headers_with_session(ctx.headers) | ||
headers = await self._update_headers(ctx.headers) | ||
message = ctx.session_message.message | ||
is_initialization = self._is_initialization_request(message) | ||
|
||
|
@@ -268,7 +301,6 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: | |
self._maybe_extract_session_id_from_response(response) | ||
|
||
content_type = response.headers.get(CONTENT_TYPE, "").lower() | ||
|
||
if content_type.startswith(JSON): | ||
await self._handle_json_response(response, ctx.read_stream_writer) | ||
elif content_type.startswith(SSE): | ||
|
@@ -405,7 +437,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: | |
return | ||
|
||
try: | ||
headers = self._update_headers_with_session(self.request_headers) | ||
headers = await self._update_headers(self.request_headers) | ||
response = await client.delete(self.url, headers=headers) | ||
|
||
if response.status_code == 405: | ||
|
@@ -427,6 +459,7 @@ async def streamablehttp_client( | |
timeout: timedelta = timedelta(seconds=30), | ||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5), | ||
terminate_on_close: bool = True, | ||
auth_client_provider: AuthClientProvider | None = None, | ||
) -> AsyncGenerator[ | ||
tuple[ | ||
MemoryObjectReceiveStream[SessionMessage | Exception], | ||
|
@@ -441,13 +474,21 @@ async def streamablehttp_client( | |
`sse_read_timeout` determines how long (in seconds) the client will wait for a new | ||
event before disconnecting. All other HTTP operations are controlled by `timeout`. | ||
|
||
`auth_client_provider` instance of `AuthClientProvider` that can be passed to | ||
support client-to-server authentication. Before each request to the MCP Server, | ||
the auth_client_provider.get_headers() method is invoked to retrieve headers | ||
for authentication. Note that any headers already specified in `headers` | ||
will take precedence over headers returned by auth_client_provider.get_headers() | ||
|
||
Yields: | ||
Tuple containing: | ||
- read_stream: Stream for reading messages from the server | ||
- write_stream: Stream for sending messages to the server | ||
- get_session_id_callback: Function to retrieve the current session ID | ||
""" | ||
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout) | ||
transport = StreamableHTTPTransport( | ||
url, headers, timeout, sse_read_timeout, auth_client_provider | ||
) | ||
|
||
read_stream_writer, read_stream = anyio.create_memory_object_stream[ | ||
SessionMessage | Exception | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
import time | ||
from collections.abc import Generator | ||
from typing import Any | ||
from unittest.mock import AsyncMock | ||
|
||
import anyio | ||
import httpx | ||
|
@@ -1223,3 +1224,56 @@ async def sampling_callback( | |
captured_message_params.messages[0].content.text | ||
== "Server needs client sampling" | ||
) | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_auth_client_provider_headers(basic_server, basic_server_url): | ||
"""Test that auth token provider correctly sets Authorization header.""" | ||
# Create a mock token provider | ||
client_provider = AsyncMock() | ||
client_provider.get_headers.return_value = { | ||
"Authorization": "Bearer test-token-123" | ||
} | ||
|
||
# Create client with token provider | ||
async with streamablehttp_client( | ||
f"{basic_server_url}/mcp", auth_client_provider=client_provider | ||
) as (read_stream, write_stream, _): | ||
async with ClientSession(read_stream, write_stream) as session: | ||
# Initialize the session | ||
result = await session.initialize() | ||
assert isinstance(result, InitializeResult) | ||
|
||
# Make a request to verify headers | ||
tools = await session.list_tools() | ||
assert len(tools.tools) == 4 | ||
|
||
client_provider.get_headers.assert_called() | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_auth_client_provider_called_per_request(basic_server, basic_server_url): | ||
"""Test that auth token provider can return different tokens.""" | ||
# Create a dynamic token provider | ||
client_provider = AsyncMock() | ||
client_provider.get_headers.return_value = { | ||
"Authorization": "Bearer test-token-123" | ||
} | ||
|
||
# Create client with dynamic token provider | ||
async with streamablehttp_client( | ||
f"{basic_server_url}/mcp", auth_client_provider=client_provider | ||
) as (read_stream, write_stream, _): | ||
async with ClientSession(read_stream, write_stream) as session: | ||
# Initialize the session | ||
result = await session.initialize() | ||
assert isinstance(result, InitializeResult) | ||
|
||
# Make multiple requests to verify token updates | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dumb question, where do we verify the token is actually updated? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is really hard in this testing environment to get the headers and verify the implementation. There is a mock server which hosts a list and set tools. We create a session, and add our messages to the write stream. This is then read by our transport layer and a request is sent to the server. I could not find a way to intercept or inspect this request object to verify the headers. So I just ensured the method was being called There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could try to build a custom app, that looks at the header, then calls the server, and returns the auth headers in the response headers. I will wait for the maintainer to chime in if they have better ideas on how to test this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. off the cuff I suspect you'd want to patch:
In the appropriate places to catch all calls with headers, and assert your headers from the provider are there. I think the mocks could just pass through to the original function |
||
for i in range(3): | ||
tools = await session.list_tools() | ||
assert len(tools.tools) == 4 | ||
|
||
# list_tools is called 3 times, but get_auth_headers is also used during | ||
# session initialization and setup. Verify it's called at least 3 times. | ||
assert client_provider.get_headers.call_count > 3 |
Uh oh!
There was an error while loading. Please reload this page.