Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/mcp/server/auth/handlers/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class AuthorizationCodeRequest(BaseModel):
grant_type: Literal["authorization_code"]
code: str = Field(..., description="The authorization code")
redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize")
client_id: str
client_id: str | None = Field(None, description="If none, client_id must be provided via basic auth header")
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
client_secret: str | None = None
# See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
Expand All @@ -33,7 +33,7 @@ class RefreshTokenRequest(BaseModel):
grant_type: Literal["refresh_token"]
refresh_token: str = Field(..., description="The refresh token")
scope: str | None = Field(None, description="Optional scope parameter")
client_id: str
client_id: str | None = Field(None, description="If none, client_id must be provided via basic auth header")
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
client_secret: str | None = None
# RFC 8707 resource indicator
Expand Down Expand Up @@ -131,7 +131,7 @@ async def handle(self, request: Request):
match token_request:
case AuthorizationCodeRequest():
auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
if auth_code is None or auth_code.client_id != token_request.client_id:
if auth_code is None or auth_code.client_id != client_info.client_id:
# if code belongs to different client, pretend it doesn't exist
return self.response(
TokenErrorResponse(
Expand Down Expand Up @@ -197,7 +197,7 @@ async def handle(self, request: Request):

case RefreshTokenRequest(): # pragma: no cover
refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
if refresh_token is None or refresh_token.client_id != token_request.client_id:
if refresh_token is None or refresh_token.client_id != client_info.client_id:
# if token belongs to different client, pretend it doesn't exist
return self.response(
TokenErrorResponse(
Expand Down
110 changes: 70 additions & 40 deletions src/mcp/server/auth/middleware/client_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import binascii
import hmac
import time
from typing import Any
from dataclasses import dataclass
from typing import Any, Literal
from urllib.parse import unquote

from starlette.requests import Request
Expand All @@ -16,6 +17,13 @@ def __init__(self, message: str):
self.message = message # pragma: no cover


@dataclass
class ClientCredentials:
auth_method: Literal["client_secret_basic", "client_secret_post"]
client_id: str
client_secret: str | None = None


class ClientAuthenticator:
"""
ClientAuthenticator is a callable which validates requests from a client
Expand Down Expand Up @@ -52,64 +60,86 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation
Raises:
AuthenticationError: If authentication fails
"""
form_data = await request.form()
client_id = form_data.get("client_id")
if not client_id:
raise AuthenticationError("Missing client_id")
client_credentials = await self._get_credentials(request)
client = await self.provider.get_client(client_credentials.client_id)

client = await self.provider.get_client(str(client_id))
if not client:
raise AuthenticationError("Invalid client_id") # pragma: no cover

request_client_secret: str | None = None
auth_header = request.headers.get("Authorization", "")

if client.token_endpoint_auth_method == "client_secret_basic":
if not auth_header.startswith("Basic "):
raise AuthenticationError("Missing or invalid Basic authentication in Authorization header")

try:
encoded_credentials = auth_header[6:] # Remove "Basic " prefix
decoded = base64.b64decode(encoded_credentials).decode("utf-8")
if ":" not in decoded:
raise ValueError("Invalid Basic auth format")
basic_client_id, request_client_secret = decoded.split(":", 1)

# URL-decode both parts per RFC 6749 Section 2.3.1
basic_client_id = unquote(basic_client_id)
request_client_secret = unquote(request_client_secret)

if basic_client_id != client_id:
raise AuthenticationError("Client ID mismatch in Basic auth")
except (ValueError, UnicodeDecodeError, binascii.Error):
raise AuthenticationError("Invalid Basic authentication header")

if client_credentials.auth_method != "client_secret_basic":
raise AuthenticationError(
f"Expected client_secret_basic authentication method, but got {client_credentials.auth_method}"
)
elif client.token_endpoint_auth_method == "client_secret_post":
raw_form_data = form_data.get("client_secret")
# form_data.get() can return a UploadFile or None, so we need to check if it's a string
if isinstance(raw_form_data, str):
request_client_secret = str(raw_form_data)

if client_credentials.auth_method != "client_secret_post":
raise AuthenticationError(
f"Expected client_secret_post authentication method, but got {client_credentials.auth_method}"
)
elif client.token_endpoint_auth_method == "none":
request_client_secret = None
else:
raise AuthenticationError( # pragma: no cover
f"Unsupported auth method: {client.token_endpoint_auth_method}"
)
pass
else: # pragma: no cover
raise AuthenticationError(f"Unsupported auth method: {client.token_endpoint_auth_method}")

# If client from the store expects a secret, validate that the request provides
# that secret
if client.client_secret: # pragma: no branch
if not request_client_secret:
if not client_credentials.client_secret:
raise AuthenticationError("Client secret is required") # pragma: no cover

# hmac.compare_digest requires that both arguments are either bytes or a `str` containing
# only ASCII characters. Since we do not control `request_client_secret`, we encode both
# arguments to bytes.
if not hmac.compare_digest(client.client_secret.encode(), request_client_secret.encode()):
if not hmac.compare_digest(client.client_secret.encode(), client_credentials.client_secret.encode()):
raise AuthenticationError("Invalid client_secret") # pragma: no cover

if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):
raise AuthenticationError("Client secret has expired") # pragma: no cover

return client

async def _get_credentials(self, request: Request) -> ClientCredentials:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: the core extraction logic of client_id and client_secret is as same as before.

"""
Extract client credentials from request, either from form data or Basic auth header.

Basic auth header takes precedence over form data.

Args:
request: The HTTP request containing client credentials
Returns:
The extracted client credentials
"""
# First, check for Basic auth header
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
try:
encoded_credentials = auth_header[6:] # Remove "Basic " prefix
decoded = base64.b64decode(encoded_credentials).decode("utf-8")
if ":" not in decoded:
raise ValueError("Invalid Basic auth format")
client_id, client_secret = decoded.split(":", 1)

# URL-decode the client_id per RFC 6749 Section 2.3.1
client_id = unquote(client_id)
client_secret = unquote(client_secret)
return ClientCredentials(
auth_method="client_secret_basic",
client_id=client_id,
client_secret=client_secret,
)
except (ValueError, UnicodeDecodeError, binascii.Error):
raise AuthenticationError("Invalid Basic authentication header")

# If not, check for client_id and client_secret in form data
form_data = await request.form()
client_id = form_data.get("client_id")
if not client_id:
raise AuthenticationError("Missing client_id")

raw_client_secret = form_data.get("client_secret")
client_secret = str(raw_client_secret) if isinstance(raw_client_secret, str) else None
return ClientCredentials(
auth_method="client_secret_post",
client_id=str(client_id),
client_secret=client_secret,
)
88 changes: 74 additions & 14 deletions tests/server/fastmcp/auth/test_auth_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,7 @@ async def test_client_secret_basic_authentication(
assert "access_token" in token_response

@pytest.mark.anyio
async def test_wrong_auth_method_without_valid_credentials_fails(
async def test_wrong_auth_method_fails(
Copy link
Author

@challenger71498 challenger71498 Jan 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: test fails as same as before, only exception message has been changed, due to the auth method mismatch.

self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
):
"""Test that using the wrong authentication method fails when credentials are missing."""
Expand Down Expand Up @@ -1117,7 +1117,7 @@ async def test_wrong_auth_method_without_valid_credentials_fails(
)

# Try to use Basic auth when client_secret_post is registered (without secret in body)
# This should fail because the secret is missing from the expected location
# This should fail despite that credentials are provided via Basic auth, because the method is wrong

credentials = f"{client_info['client_id']}:{client_info['client_secret']}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
Expand All @@ -1138,7 +1138,7 @@ async def test_wrong_auth_method_without_valid_credentials_fails(
error_response = response.json()
# RFC 6749: authentication failures return "invalid_client"
assert error_response["error"] == "invalid_client"
assert "Client secret is required" in error_response["error_description"]
assert "Expected client_secret_post authentication method" in error_response["error_description"]

@pytest.mark.anyio
async def test_basic_auth_without_header_fails(
Expand Down Expand Up @@ -1183,7 +1183,7 @@ async def test_basic_auth_without_header_fails(
error_response = response.json()
# RFC 6749: authentication failures return "invalid_client"
assert error_response["error"] == "invalid_client"
assert "Missing or invalid Basic authentication" in error_response["error_description"]
assert "Expected client_secret_basic authentication method" in error_response["error_description"]
Copy link
Author

@challenger71498 challenger71498 Jan 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: test fails as same as before, only exception message has been changed, due to the auth method mismatch.


@pytest.mark.anyio
async def test_basic_auth_invalid_base64_fails(
Expand Down Expand Up @@ -1279,10 +1279,10 @@ async def test_basic_auth_no_colon_fails(
assert "Invalid Basic authentication header" in error_response["error_description"]

@pytest.mark.anyio
async def test_basic_auth_client_id_mismatch_fails(
async def test_basic_auth_takes_precedence(
Copy link
Author

@challenger71498 challenger71498 Jan 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: the behaviour of the test has been changed. After the changes, basic auth takes precedence over form data, so even if there is a mismatch, the middleware retrieves the client_id from the header.

This behaviour is intended due to RFC 6749 Section 2.3.1:

Including the client credentials in the request-body using the two parameters is NOT RECOMMENDED and SHOULD be limited to clients unable to directly utilize the HTTP Basic authentication scheme (or other password-based HTTP authentication schemes).

If this is inappropriate, alternatives are:

  • behave AS-IS (throw mismatch exception)
  • disallow dup auth method
    • throw exception which indicates that more than one authentication method was used.

self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
):
"""Test that client_id mismatch between body and Basic auth fails."""
"""Test that even client_id at body is invalid, Basic auth passes because of the priority."""
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"client_name": "Basic Auth Client",
Expand All @@ -1308,23 +1308,83 @@ async def test_basic_auth_client_id_mismatch_fails(
# Send different client_id in Basic auth header
import base64

wrong_creds = base64.b64encode(f"wrong-client-id:{client_info['client_secret']}".encode()).decode()
creds = base64.b64encode(f"{client_info['client_id']}:{client_info['client_secret']}".encode()).decode()
response = await test_client.post(
"/token",
headers={"Authorization": f"Basic {wrong_creds}"},
headers={"Authorization": f"Basic {creds}"},
data={
"grant_type": "authorization_code",
"client_id": client_info["client_id"], # Correct client_id in body
"client_id": "wrong-client-id", # Wrong client_id in body
"code": auth_code,
"code_verifier": pkce_challenge["code_verifier"],
"redirect_uri": "https://client.example.com/callback",
},
)
assert response.status_code == 401
error_response = response.json()
# RFC 6749: authentication failures return "invalid_client"
assert error_response["error"] == "invalid_client"
assert "Client ID mismatch" in error_response["error_description"]

# Header takes precedence, so this should succeed
assert response.status_code == 200

@pytest.mark.anyio
async def test_basic_auth_without_client_id_at_body(
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
):
"""Test that Basic auth works even if client_id is missing from body."""
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"client_name": "Basic Auth Client",
"token_endpoint_auth_method": "client_secret_basic",
"grant_types": ["authorization_code", "refresh_token"],
}

response = await test_client.post("/register", json=client_metadata)
assert response.status_code == 201
client_info = response.json()

auth_code = f"code_{int(time.time())}"
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
code=auth_code,
client_id=client_info["client_id"],
code_challenge=pkce_challenge["code_challenge"],
redirect_uri=AnyUrl("https://client.example.com/callback"),
redirect_uri_provided_explicitly=True,
scopes=["read", "write"],
expires_at=time.time() + 600,
)

credentials = f"{client_info['client_id']}:{client_info['client_secret']}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()

response = await test_client.post(
"/token",
headers={"Authorization": f"Basic {encoded_credentials}"},
data={
"grant_type": "authorization_code",
# client_id omitted from body
"code": auth_code,
"code_verifier": pkce_challenge["code_verifier"],
"redirect_uri": "https://client.example.com/callback",
},
)
assert response.status_code == 200
token_response = response.json()
assert "access_token" in token_response
assert "refresh_token" in token_response

refresh_token = token_response["refresh_token"]

# Now, use the refresh token without client_id in body
response = await test_client.post(
"/token",
headers={"Authorization": f"Basic {encoded_credentials}"},
data={
"grant_type": "refresh_token",
# client_id omitted from body
"refresh_token": refresh_token,
},
)
assert response.status_code == 200
new_token_response = response.json()
assert "access_token" in new_token_response

@pytest.mark.anyio
async def test_none_auth_method_public_client(
Expand Down