diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 09e5048cc..e4c32c03e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -356,6 +356,8 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: error_data = ErrorData(code=METHOD_NOT_FOUND, message="Not Found") else: error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated") + elif response.status_code == 401: + error_data = ErrorData(code=INTERNAL_ERROR, message="Unauthorized") else: error_data = ErrorData(code=INTERNAL_ERROR, message="Server returned an error response") session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index defda41f8..8e76caaf5 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -18,6 +18,7 @@ from mcp_types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, + INTERNAL_ERROR, METHOD_NOT_FOUND, PROTOCOL_VERSION_META_KEY, JSONRPCError, @@ -124,6 +125,32 @@ def handler(request: httpx.Request) -> httpx.Response: assert reply.message.error.code == METHOD_NOT_FOUND +@pytest.mark.anyio +async def test_bare_401_request_maps_to_unauthorized_jsonrpc_error() -> None: + """A bare HTTP 401 should reach the caller as a correlated JSON-RPC error. + + Authorization failures can be operation-specific. The client transport must + leave room for the agent/session layer to handle the denial instead of + collapsing it into an indistinguishable transport failure. + """ + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(401) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (read, write), + ): + await write.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={}))) + reply = await read.receive() + assert isinstance(reply, SessionMessage) + assert isinstance(reply.message, JSONRPCError) + assert reply.message.id == 1 + assert reply.message.error.code == INTERNAL_ERROR + assert reply.message.error.message == "Unauthorized" + + @pytest.mark.anyio async def test_initialize_post_clears_cached_pv_header_and_unstamped_posts_read_it() -> None: """``initialize`` discards the cached protocol-version header; every other POST reads it.