diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 804180e05..aaacadca6 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -351,6 +351,7 @@ def __init__( self._negotiated_version: str | None = None self._stamp: Callable[[dict[str, Any], CallOptions], None] = _preconnect_stamp self._task_group: anyio.abc.TaskGroup | None = None + self._entered = False if dispatcher is not None: if read_stream is not None or write_stream is not None: raise ValueError("pass read_stream/write_stream or dispatcher, not both") @@ -372,6 +373,8 @@ def __init__( ) async def __aenter__(self) -> Self: + if self._entered: + raise RuntimeError("Session is already running") self._task_group = anyio.create_task_group() await self._task_group.__aenter__() try: @@ -398,6 +401,7 @@ async def __aenter__(self) -> Self: finally: self._close_binding_queues() raise + self._entered = True return self async def __aexit__( @@ -411,10 +415,11 @@ async def __aexit__( self._task_group.cancel_scope.cancel() try: result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + await resync_tracer() + return result finally: + self._entered = False self._close_binding_queues() - await resync_tracer() - return result def _close_binding_queues(self) -> None: # Unclosed memory object streams warn at garbage collection; close is idempotent. @@ -455,6 +460,11 @@ async def send_request( pydantic.ValidationError: The server returned a result that does not conform to the negotiated protocol version. """ + if self._task_group is None: + raise RuntimeError( + "Session is not running. Use it as an async context manager " + "(e.g. `async with ClientSession(...) as session:`)." + ) data = request.model_dump(by_alias=True, mode="json", exclude_none=True) method: str = data["method"] opts: CallOptions = {} diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f76991f65..519bf91ef 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -153,6 +153,43 @@ async def message_handler( # pragma: no cover assert isinstance(initialized_notification, InitializedNotification) +@pytest.mark.anyio +async def test_client_session_requires_context_manager(): + client_to_server_send, _client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + _server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + async with ( + client_to_server_send, + _client_to_server_receive, + _server_to_client_send, + server_to_client_receive, + ): + session = ClientSession(server_to_client_receive, client_to_server_send) + + with pytest.raises(RuntimeError, match="async context manager"): + await session.initialize() + + +@pytest.mark.anyio +async def test_client_session_reentry_raises_runtime_error(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + async with ( + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + session = ClientSession(server_to_client_receive, client_to_server_send) + await session.__aenter__() + try: + with pytest.raises(RuntimeError, match="already running"): + await session.__aenter__() + finally: + await session.__aexit__(None, None, None) + + @pytest.mark.anyio async def test_client_session_custom_client_info(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -1263,12 +1300,12 @@ async def server_on_notify( @pytest.mark.anyio async def test_dispatcher_keyword_send_request_before_enter_raises_runtimeerror(): - """The documented pre-enter RuntimeError holds for dispatcher= sessions too.""" + """The documented pre-enter RuntimeError holds before any dispatcher call.""" client_side, _server_side = create_direct_dispatcher_pair() session = ClientSession(dispatcher=client_side) with anyio.fail_after(5), pytest.raises(RuntimeError) as exc: await session.send_ping() - assert str(exc.value) == "DirectDispatcher.send_raw_request called before run()" + assert "async context manager" in str(exc.value) @pytest.mark.anyio