Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 1 addition & 22 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,26 +456,6 @@ async def _handle_session_message(message: SessionMessage) -> None:
pass
self._response_streams.clear()

def _normalize_request_id(self, response_id: RequestId) -> RequestId:
"""Normalize a response ID to match how request IDs are stored.

Since the client always sends integer IDs, we normalize string IDs
to integers when possible. This matches the TypeScript SDK approach:
https://github.com/modelcontextprotocol/typescript-sdk/blob/a606fb17909ea454e83aab14c73f14ea45c04448/src/shared/protocol.ts#L861

Args:
response_id: The response ID from the incoming message.

Returns:
The normalized ID (int if possible, otherwise original value).
"""
if isinstance(response_id, str):
try:
return int(response_id)
except ValueError:
logging.warning(f"Response ID {response_id!r} cannot be normalized to match pending requests")
return response_id

async def _handle_response(self, message: SessionMessage) -> None:
"""Handle an incoming response or error message.

Expand All @@ -495,8 +475,7 @@ async def _handle_response(self, message: SessionMessage) -> None:
logging.warning(f"Received error with null ID: {error.message}")
await self._handle_incoming(MCPError(error.code, error.message, error.data))
return
# Normalize response ID to handle type mismatches (e.g., "0" vs 0)
response_id = self._normalize_request_id(message.message.id)
response_id = message.message.id

# First, check response routers (e.g., TaskResultHandler)
if isinstance(message.message, JSONRPCError):
Expand Down
81 changes: 39 additions & 42 deletions tests/shared/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,44 +99,47 @@ async def make_request(client: Client):


@pytest.mark.anyio
async def test_response_id_type_mismatch_string_to_int():
"""Test that responses with string IDs are correctly matched to requests sent with
integer IDs.
async def test_response_id_type_mismatch_string_to_int_rejected():
"""Verify that a response with a string ID does not match a request sent with
an integer ID.

This handles the case where a server returns "id": "0" (string) but the client
sent "id": 0 (integer). Without ID type normalization, this would cause a timeout.
Per JSON-RPC 2.0, the response ID "MUST be the same as the value of the id
member in the Request Object". Since Python treats 0 != "0", a server that
echoes back "0" instead of 0 is non-compliant and the request should time out.
"""
ev_response_received = anyio.Event()
result_holder: list[types.EmptyResult] = []
ev_timeout = anyio.Event()

async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, server_write = server_streams

async def mock_server():
"""Receive a request and respond with a string ID instead of integer."""
message = await server_read.receive()
assert isinstance(message, SessionMessage)
root = message.message
assert isinstance(root, JSONRPCRequest)
# Get the original request ID (which is an integer)
request_id = root.id
assert isinstance(request_id, int), f"Expected int, got {type(request_id)}"
assert isinstance(request_id, int)

# Respond with the ID as a string (simulating a buggy server)
# Respond with the ID as a string (non-compliant server)
response = JSONRPCResponse(
jsonrpc="2.0",
id=str(request_id), # Convert to string to simulate mismatch
id=str(request_id),
result={},
)
await server_write.send(SessionMessage(message=response))

async def make_request(client_session: ClientSession):
nonlocal result_holder
# Send a ping request (uses integer ID internally)
result = await client_session.send_ping()
result_holder.append(result)
ev_response_received.set()
try:
await client_session.send_request(
types.PingRequest(),
types.EmptyResult,
request_read_timeout_seconds=0.5,
)
pytest.fail("Expected timeout") # pragma: no cover
except MCPError as e:
assert "Timed out" in str(e)
ev_timeout.set()

async with (
anyio.create_task_group() as tg,
Expand All @@ -146,52 +149,49 @@ async def make_request(client_session: ClientSession):
tg.start_soon(make_request, client_session)

with anyio.fail_after(2): # pragma: no branch
await ev_response_received.wait()

assert len(result_holder) == 1
assert isinstance(result_holder[0], EmptyResult)
await ev_timeout.wait()


@pytest.mark.anyio
async def test_error_response_id_type_mismatch_string_to_int():
"""Test that error responses with string IDs are correctly matched to requests
sent with integer IDs.
async def test_error_response_id_type_mismatch_string_to_int_rejected():
"""Verify that an error response with a string ID does not match a request
sent with an integer ID.

This handles the case where a server returns an error with "id": "0" (string)
but the client sent "id": 0 (integer).
The JSON-RPC spec requires exact ID matching including type.
"""
ev_error_received = anyio.Event()
error_holder: list[MCPError | Exception] = []
ev_timeout = anyio.Event()

async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, server_write = server_streams

async def mock_server():
"""Receive a request and respond with an error using a string ID."""
message = await server_read.receive()
assert isinstance(message, SessionMessage)
root = message.message
assert isinstance(root, JSONRPCRequest)
request_id = root.id
assert isinstance(request_id, int)

# Respond with an error, using the ID as a string
# Respond with an error using the ID as a string (non-compliant)
error_response = JSONRPCError(
jsonrpc="2.0",
id=str(request_id), # Convert to string to simulate mismatch
id=str(request_id),
error=ErrorData(code=-32600, message="Test error"),
)
await server_write.send(SessionMessage(message=error_response))

async def make_request(client_session: ClientSession):
nonlocal error_holder
try:
await client_session.send_ping()
pytest.fail("Expected MCPError to be raised") # pragma: no cover
await client_session.send_request(
types.PingRequest(),
types.EmptyResult,
request_read_timeout_seconds=0.5,
)
pytest.fail("Expected timeout") # pragma: no cover
except MCPError as e:
error_holder.append(e)
ev_error_received.set()
assert "Timed out" in str(e)
ev_timeout.set()

async with (
anyio.create_task_group() as tg,
Expand All @@ -201,16 +201,13 @@ async def make_request(client_session: ClientSession):
tg.start_soon(make_request, client_session)

with anyio.fail_after(2): # pragma: no branch
await ev_error_received.wait()

assert len(error_holder) == 1
assert "Test error" in str(error_holder[0])
await ev_timeout.wait()


@pytest.mark.anyio
async def test_response_id_non_numeric_string_no_match():
"""Test that responses with non-numeric string IDs don't incorrectly match
integer request IDs.
"""Test that responses with non-numeric string IDs don't match integer
request IDs.

If a server returns "id": "abc" (non-numeric string), it should not match
a request sent with "id": 0 (integer).
Expand Down
Loading