diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index f28eb7c7a..47178ca66 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -13,6 +13,7 @@ from anyio.abc import TaskGroup from httpx_sse import EventSource, ServerSentEvent, aconnect_sse from mcp_types import ( + CONNECTION_CLOSED, INTERNAL_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, @@ -381,6 +382,17 @@ async def _handle_sse_response( if last_event_id is not None: # pragma: no branch logger.info("SSE stream disconnected, reconnecting...") await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) + else: + await self._send_connection_closed(ctx, original_request_id) + + async def _send_connection_closed(self, ctx: RequestContext, request_id: RequestId) -> None: + """Resolve a pending POST SSE request when the stream closes before a reply.""" + error_data = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") + error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=request_id, error=error_data)) + try: + await ctx.read_stream_writer.send(error_msg) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped connection-closed error for %r: read stream closed", request_id) async def _handle_reconnection( self, @@ -393,6 +405,8 @@ async def _handle_reconnection( # Bail if max retries exceeded if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded") + if isinstance(ctx.session_message.message, JSONRPCRequest): + await self._send_connection_closed(ctx, ctx.session_message.message.id) return # Always wait - use server value or default diff --git a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index 418a6bc54..634e49bea 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -6,13 +6,14 @@ import json +import anyio import httpx import mcp_types as types import pytest from mcp_types import RootsListChangedNotification from starlette.applications import Starlette from starlette.requests import Request -from starlette.responses import JSONResponse, Response +from starlette.responses import JSONResponse, Response, StreamingResponse from starlette.routing import Route from mcp import ClientSession, MCPError @@ -72,6 +73,24 @@ async def handle_mcp_request(request: Request) -> Response: return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])]) +def _create_empty_sse_response_app() -> Starlette: + """Create a server that closes a POST SSE response without a JSON-RPC reply.""" + + async def handle_mcp_request(request: Request) -> Response: + body = await request.body() + data = json.loads(body) + + if data.get("method") == "initialize": + return _init_json_response(data) + + if "id" not in data: + return Response(status_code=202) + + return StreamingResponse(iter(()), media_type="text/event-stream") + + return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])]) + + async def test_non_compliant_notification_response() -> None: """Verify the client ignores unexpected responses to notifications. @@ -117,6 +136,20 @@ async def test_unexpected_content_type_sends_jsonrpc_error() -> None: await session.list_tools() +async def test_empty_post_sse_response_unblocks_pending_tool_call() -> None: + """An SSE response that closes before a JSON-RPC reply raises instead of hanging.""" + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=_create_empty_sse_response_app())) as client: + async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + with pytest.raises(MCPError) as exc_info: + with anyio.fail_after(1): + await session.call_tool("greet", {}) + + assert exc_info.value.error.code == types.CONNECTION_CLOSED + + def _create_http_error_app(error_status: int, *, error_on_notifications: bool = False) -> Starlette: """Create a server that returns an HTTP error for non-init requests."""