Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def __init__(
self._negotiated_version: str | None = None
self._stamp: Callable[[dict[str, Any], CallOptions], None] = _preconnect_stamp
self._task_group: anyio.abc.TaskGroup | None = None
self._entered = False
if dispatcher is not None:
if read_stream is not None or write_stream is not None:
raise ValueError("pass read_stream/write_stream or dispatcher, not both")
Expand All @@ -372,6 +373,8 @@ def __init__(
)

async def __aenter__(self) -> Self:
if self._entered:
raise RuntimeError("Session is already running")
self._task_group = anyio.create_task_group()
await self._task_group.__aenter__()
try:
Expand All @@ -398,6 +401,7 @@ async def __aenter__(self) -> Self:
finally:
self._close_binding_queues()
raise
self._entered = True
return self

async def __aexit__(
Expand All @@ -411,10 +415,11 @@ async def __aexit__(
self._task_group.cancel_scope.cancel()
try:
result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
await resync_tracer()
return result
finally:
self._entered = False
self._close_binding_queues()
await resync_tracer()
return result

def _close_binding_queues(self) -> None:
# Unclosed memory object streams warn at garbage collection; close is idempotent.
Expand Down Expand Up @@ -455,6 +460,11 @@ async def send_request(
pydantic.ValidationError: The server returned a result that does not
conform to the negotiated protocol version.
"""
if self._task_group is None:
raise RuntimeError(
"Session is not running. Use it as an async context manager "
"(e.g. `async with ClientSession(...) as session:`)."
)
data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
method: str = data["method"]
opts: CallOptions = {}
Expand Down
41 changes: 39 additions & 2 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,43 @@ async def message_handler( # pragma: no cover
assert isinstance(initialized_notification, InitializedNotification)


@pytest.mark.anyio
async def test_client_session_requires_context_manager():
client_to_server_send, _client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
_server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

async with (
client_to_server_send,
_client_to_server_receive,
_server_to_client_send,
server_to_client_receive,
):
session = ClientSession(server_to_client_receive, client_to_server_send)

with pytest.raises(RuntimeError, match="async context manager"):
await session.initialize()


@pytest.mark.anyio
async def test_client_session_reentry_raises_runtime_error():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

async with (
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
session = ClientSession(server_to_client_receive, client_to_server_send)
await session.__aenter__()
try:
with pytest.raises(RuntimeError, match="already running"):
await session.__aenter__()
finally:
await session.__aexit__(None, None, None)


@pytest.mark.anyio
async def test_client_session_custom_client_info():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
Expand Down Expand Up @@ -1263,12 +1300,12 @@ async def server_on_notify(

@pytest.mark.anyio
async def test_dispatcher_keyword_send_request_before_enter_raises_runtimeerror():
"""The documented pre-enter RuntimeError holds for dispatcher= sessions too."""
"""The documented pre-enter RuntimeError holds before any dispatcher call."""
client_side, _server_side = create_direct_dispatcher_pair()
session = ClientSession(dispatcher=client_side)
with anyio.fail_after(5), pytest.raises(RuntimeError) as exc:
await session.send_ping()
assert str(exc.value) == "DirectDispatcher.send_raw_request called before run()"
assert "async context manager" in str(exc.value)


@pytest.mark.anyio
Expand Down
Loading