diff --git a/src/mcp/server/auth/client_credentials.py b/src/mcp/server/auth/client_credentials.py new file mode 100644 index 000000000..6ea066794 --- /dev/null +++ b/src/mcp/server/auth/client_credentials.py @@ -0,0 +1,85 @@ +"""Utilities for extracting client credentials from requests.""" + +import base64 +from typing import Tuple + +from starlette.requests import Request + +from mcp.shared.auth import OAuthClientInformationFull + + +class ClientCredentialError(Exception): + """Error extracting client credentials.""" + def __init__(self, error: str, error_description: str): + self.error = error + self.error_description = error_description + super().__init__(error_description) + + +def extract_client_credentials( + request: Request, + client: OAuthClientInformationFull, + client_id_from_body: str, + client_secret_from_body: str | None = None, +) -> Tuple[str, str | None]: + """ + Extract client credentials based on the client's registered authentication method. + + Args: + request: The HTTP request + client: The client information with registered auth method + client_id_from_body: The client_id from the request body + client_secret_from_body: The client_secret from the request body (if present) + + Returns: + Tuple of (client_id, client_secret) + + Raises: + ClientCredentialError: If credentials are missing or invalid + """ + client_id = client_id_from_body + client_secret = None + + if client.token_endpoint_auth_method == "client_secret_basic": + # Must use Basic auth header + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Basic "): + raise ClientCredentialError( + "invalid_client", + "Client must use Basic authentication" + ) + try: + # Decode Basic auth header + encoded_credentials = auth_header[6:] # Remove "Basic " prefix + decoded = base64.b64decode(encoded_credentials).decode("utf-8") + if ":" not in decoded: + raise ValueError("Invalid Basic auth format") + basic_client_id, client_secret = decoded.split(":", 1) + # Verify client_id matches + if basic_client_id != client_id: + raise ClientCredentialError( + "invalid_client", + "Client ID mismatch" + ) + except ClientCredentialError: + raise + except Exception: + raise ClientCredentialError( + "invalid_client", + "Invalid Basic authentication header" + ) + + elif client.token_endpoint_auth_method == "client_secret_post": + # Must use POST body + client_secret = client_secret_from_body + if client.client_secret and not client_secret: + raise ClientCredentialError( + "invalid_client", + "Client secret required in request body" + ) + + elif client.token_endpoint_auth_method == "none": + # Public client, no secret expected + client_secret = None + + return client_id, client_secret \ No newline at end of file diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index e6d99e66d..815da5d3a 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -76,6 +76,17 @@ async def handle(self, request: Request) -> Response: ), status_code=400, ) + + # The MCP spec requires servers to use the authorization `code` flow + # with PKCE + if "code" not in client_metadata.response_types: + return PydanticJSONResponse( + content=RegistrationErrorResponse( + error="invalid_client_metadata", + error_description="response_types must include 'code' for authorization_code grant", + ), + status_code=400, + ) client_id_issued_at = int(time.time()) client_secret_expires_at = ( diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 478ad7a01..4a60dd301 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -6,6 +6,7 @@ from starlette.requests import Request from starlette.responses import Response +from mcp.server.auth.client_credentials import ClientCredentialError, extract_client_credentials from mcp.server.auth.errors import ( stringify_pydantic_error, ) @@ -51,10 +52,38 @@ async def handle(self, request: Request) -> Response: ), ) + # First, look up the client to determine expected auth method + client = await self.provider.get_client(revocation_request.client_id) + if not client: + return PydanticJSONResponse( + status_code=401, + content=RevocationErrorResponse( + error="unauthorized_client", + error_description="Invalid client_id", + ), + ) + + # Extract client credentials based on the registered auth method + try: + client_id, client_secret = extract_client_credentials( + request, + client, + revocation_request.client_id, + revocation_request.client_secret, + ) + except ClientCredentialError as e: + return PydanticJSONResponse( + status_code=401, + content=RevocationErrorResponse( + error=e.error, # type: ignore + error_description=e.error_description, + ), + ) + # Authenticate client try: client = await self.client_authenticator.authenticate( - revocation_request.client_id, revocation_request.client_secret + client_id, client_secret ) except AuthenticationError as e: return PydanticJSONResponse( diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 4e15e6265..1320389ee 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -1,12 +1,13 @@ -import base64 import hashlib import time from dataclasses import dataclass from typing import Annotated, Any, Literal +import base64 from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError from starlette.requests import Request +from mcp.server.auth.client_credentials import ClientCredentialError, extract_client_credentials from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator @@ -102,10 +103,36 @@ async def handle(self, request: Request): ) ) + # First, look up the client to determine expected auth method + client = await self.provider.get_client(token_request.client_id) + if not client: + return self.response( + TokenErrorResponse( + error="invalid_client", + error_description="Invalid client_id", + ) + ) + + # Extract client credentials based on the registered auth method + try: + client_id, client_secret = extract_client_credentials( + request, + client, + token_request.client_id, + token_request.client_secret, + ) + except ClientCredentialError as e: + return self.response( + TokenErrorResponse( + error=e.error, # type: ignore + error_description=e.error_description, + ) + ) + try: client_info = await self.client_authenticator.authenticate( - client_id=token_request.client_id, - client_secret=token_request.client_secret, + client_id=client_id, + client_secret=client_secret, ) except AuthenticationError as e: return self.response( diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index bce32df52..e20db5a54 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -164,7 +164,7 @@ def build_metadata( response_types_supported=["code"], response_modes_supported=None, grant_types_supported=["authorization_code", "refresh_token"], - token_endpoint_auth_methods_supported=["client_secret_post"], + token_endpoint_auth_methods_supported=["client_secret_post", "client_secret_basic"], token_endpoint_auth_signing_alg_values_supported=None, service_documentation=service_documentation_url, ui_locales_supported=None, @@ -181,7 +181,7 @@ def build_metadata( # Add revocation endpoint if supported if revocation_options.enabled: metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH) - metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"] + metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post", "client_secret_basic"] return metadata diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 6bf15b531..607d0eb25 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -42,17 +42,15 @@ class OAuthClientMetadata(BaseModel): """ redirect_uris: list[AnyUrl] = Field(..., min_length=1) - # token_endpoint_auth_method: this implementation only supports none & - # client_secret_post; - # ie: we do not support client_secret_basic - token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post" + token_endpoint_auth_method: Literal["none", "client_secret_post", "client_secret_basic"] = "client_secret_post" # grant_types: this implementation only supports authorization_code & refresh_token grant_types: list[Literal["authorization_code", "refresh_token"]] = [ "authorization_code", "refresh_token", ] - # this implementation only supports code; ie: it does not support implicit grants - response_types: list[Literal["code"]] = ["code"] + # The MCP spec requires the "code" response type, but OAuth + # servers may also return additional types they support + response_types: list[str] = ["code"] scope: str | None = None # these fields are currently unused, but we support & store them for potential