Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from anyio.abc import TaskGroup
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from mcp_types import (
CONNECTION_CLOSED,
INTERNAL_ERROR,
INVALID_REQUEST,
METHOD_NOT_FOUND,
Expand Down Expand Up @@ -381,6 +382,17 @@ async def _handle_sse_response(
if last_event_id is not None: # pragma: no branch
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)
else:
await self._send_connection_closed(ctx, original_request_id)

async def _send_connection_closed(self, ctx: RequestContext, request_id: RequestId) -> None:
"""Resolve a pending POST SSE request when the stream closes before a reply."""
error_data = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=request_id, error=error_data))
try:
await ctx.read_stream_writer.send(error_msg)
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
logger.debug("dropped connection-closed error for %r: read stream closed", request_id)

async def _handle_reconnection(
self,
Expand All @@ -393,6 +405,8 @@ async def _handle_reconnection(
# Bail if max retries exceeded
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
if isinstance(ctx.session_message.message, JSONRPCRequest):
await self._send_connection_closed(ctx, ctx.session_message.message.id)
return

# Always wait - use server value or default
Expand Down
35 changes: 34 additions & 1 deletion tests/client/test_notification_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@

import json

import anyio
import httpx
import mcp_types as types
import pytest
from mcp_types import RootsListChangedNotification
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import Route

from mcp import ClientSession, MCPError
Expand Down Expand Up @@ -72,6 +73,24 @@ async def handle_mcp_request(request: Request) -> Response:
return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])])


def _create_empty_sse_response_app() -> Starlette:
"""Create a server that closes a POST SSE response without a JSON-RPC reply."""

async def handle_mcp_request(request: Request) -> Response:
body = await request.body()
data = json.loads(body)

if data.get("method") == "initialize":
return _init_json_response(data)

if "id" not in data:
return Response(status_code=202)

return StreamingResponse(iter(()), media_type="text/event-stream")

return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])])


async def test_non_compliant_notification_response() -> None:
"""Verify the client ignores unexpected responses to notifications.

Expand Down Expand Up @@ -117,6 +136,20 @@ async def test_unexpected_content_type_sends_jsonrpc_error() -> None:
await session.list_tools()


async def test_empty_post_sse_response_unblocks_pending_tool_call() -> None:
"""An SSE response that closes before a JSON-RPC reply raises instead of hanging."""
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=_create_empty_sse_response_app())) as client:
async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
await session.initialize()

with pytest.raises(MCPError) as exc_info:
with anyio.fail_after(1):
await session.call_tool("greet", {})

assert exc_info.value.error.code == types.CONNECTION_CLOSED


def _create_http_error_app(error_status: int, *, error_on_notifications: bool = False) -> Starlette:
"""Create a server that returns an HTTP error for non-init requests."""

Expand Down
Loading