diff --git a/docs/advanced/caching.md b/docs/advanced/caching.md index f53a3096b..ba979ccc1 100644 --- a/docs/advanced/caching.md +++ b/docs/advanced/caching.md @@ -4,61 +4,114 @@ Every result a server returns for `tools/list`, `prompts/list`, `resources/list` The server doesn't cache anything. The fields are a *declaration*: "this tool list is the same for everyone and won't change for a minute." A client (or a gateway in front of you) may then skip the round trip. Honoring the hints is the client's choice; emitting them is the server's job, and the SDK does it for you. -Out of the box every result says `ttlMs: 0, cacheScope: "private"` — immediately stale, never shared. That is always safe and always conformant. If your lists really are stable and identical for all callers, say so at construction: +Out of the box every result says `ttlMs: 0, cacheScope: "private"`: immediately stale, never shared. That is always safe and always conformant. If your lists really are stable and identical for all callers, say so at construction: ```python title="server.py" hl_lines="5-8" --8<-- "docs_src/caching/tutorial001.py" ``` -* The map is keyed by **method name** — the six cacheable methods are the only legal keys. The parameter is typed `Mapping[CacheableMethod, CacheHint]`, so your editor autocompletes the keys and flags a typo before you run; anything that slips past the type checker raises at construction. +* The map is keyed by **method name**, and the six cacheable methods are the only legal keys. The parameter is typed `Mapping[CacheableMethod, CacheHint]`, so your editor autocompletes the keys and flags a typo before you run; anything that slips past the type checker raises at construction. * A method you don't mention keeps the defaults. The map is a set of overrides, not a manifest. * `CacheHint(ttl_ms=5_000)` left `scope` unset, so it stays `"private"`: five seconds of freshness, per caller. Scope and TTL are independent decisions. -* `"server/discover"` is a legal key too — the handshake result is cacheable like any list. +* `"server/discover"` is a legal key too, since the handshake result is cacheable like any list. !!! warning - `cacheScope: "public"` means *anyone* may be served your cached response — a shared + `cacheScope: "public"` means *anyone* may be served your cached response. A shared gateway will happily hand one user's result to another, even when the request was authenticated. Mark a result `"public"` only when it is identical for every caller, and never use `cacheScope` as access control: it is a label, not a lock. ## Per-handler override -On the low-level `Server`, handlers build their results by hand — and `ttl_ms` / `cache_scope` are just fields on the result models. A handler that sets them explicitly always wins over the constructor map, field by field: +On the low-level `Server`, handlers build their results by hand, and `ttl_ms` / `cache_scope` are just fields on the result models. A handler that sets them explicitly always wins over the constructor map, field by field: ```python title="server.py" hl_lines="11 17" --8<-- "docs_src/caching/tutorial002.py" ``` -The handler said `ttl_ms=1_000` and nothing about scope. On the wire: `ttlMs: 1000` (the handler's, not the map's `60_000`) and `cacheScope: "public"` (the map's — the handler left it unset). Explicit beats configured, configured beats default — per field, so a handler can pin one field and leave the other to the server-wide policy. +The handler said `ttl_ms=1_000` and nothing about scope. On the wire: `ttlMs: 1000` (the handler's, not the map's `60_000`) and `cacheScope: "public"` (the map's, because the handler left it unset). Explicit beats configured, and configured beats default. This holds per field, so a handler can pin one field and leave the other to the server-wide policy. This is also the escape hatch for dynamics the constructor can't know: a handler that filters `resources/read` per user can return `cache_scope="private"` for one URI from an otherwise-public server. -One caveat on paginated lists: the protocol requires the **same `cacheScope` on every page** of one list. The constructor map satisfies that by construction — it's keyed by method, not by page. But a handler that overrides the scope itself owns that consistency: override it on *every* page, never only when a cursor is present, or page one and page two will disagree. +One caveat on paginated lists: the protocol requires the **same `cacheScope` on every page** of one list. The constructor map satisfies that by construction, since it's keyed by method, not by page. But a handler that overrides the scope itself owns that consistency: override it on *every* page, never only when a cursor is present, or page one and page two will disagree. ## What the client sees -On the client, the hints arrive as plain fields on every cacheable result — `ttl_ms` and `cache_scope`, already parsed: +On a 2026-07-28 session, `Client` honors the hints for you: it has a built-in response cache, on by default. A result that arrives carrying a `ttlMs` is stored, and an identical call within that TTL is served from the cache with no round trip. A result that carries *no* hint is not cached: hint-less results get `CacheConfig.default_ttl_ms`, which defaults to `0` (immediately stale), so a server that declares nothing sees exactly the call-for-call traffic it always did. -```python title="client.py" hl_lines="15" +```python title="client.py" hl_lines="34 36 39" --8<-- "docs_src/caching/tutorial003.py" ``` -The SDK parses; it does not (yet) act. There is no built-in response cache: calling `list_tools()` twice makes two round trips, whatever the TTL said. The spec makes honoring optional — a client that ignores the hints entirely is fully conformant — so until the SDK grows a response cache, the supported path is to read the fields and do your own bookkeeping: +Four calls, three fetches. The second call found a fresh entry and never reached the server; advancing the (injected) clock past the TTL made the third fetch again; the fourth said `cache_mode="refresh"`. That kwarg exists on the five caching verbs (`list_tools`, `list_prompts`, `list_resources`, `list_resource_templates`, `read_resource`): -* **Freshness** is `now < t_received + ttl_ms / 1000`: record the clock when the response arrives, and treat the result as reusable until the TTL runs out. `ttl_ms == 0` means *immediately stale* — don't reuse it at all. -* **Scope is a sharing rule, not a suggestion.** A `"private"` result may be reused only within the same authorization context — same access token, same cache. Never put `"private"` results in a cache shared across users. -* **Notifications beat TTL.** If the server sends `list_changed` while your copy is still fresh, the copy is stale now — re-fetch. +* `"use"` (the default) serves a fresh entry if there is one, and stores the fetch if not. +* `"refresh"` never serves: it fetches and stores the result, replacing whatever was cached. +* `"bypass"` makes the round trip without touching the cache at all: no read, no write. -Against an **older server** (pre-2026 protocol), the fields are simply absent from the wire, and the models show their conservative defaults: `ttl_ms == 0`, `cache_scope == "private"` — stale and unshared, the right assumption for a server that declared nothing. If you need to distinguish "the server said 0" from "the server said nothing", check `"ttl_ms" in result.model_fields_set`: it's only set when the field actually arrived. +One rule sits above `"use"`: **calls carrying `meta` always reach the server.** A request with `meta` set (a progress token, tracing fields) expects a wire request, so under `cache_mode="use"` it is treated as `"refresh"`: the cache read is skipped, and the fetched result still replaces the cached entry. `"bypass"` and an explicit `"refresh"` behave as they always do. + +To turn caching off entirely, construct with `Client(server, cache=False)`: every call is a round trip again, and `cache_mode`, while still accepted, does nothing. + +Scope is honored automatically too: `"private"` entries are keyed to the cache's *partition* (below), while `"public"` ones may opt into wider sharing. And **notifications beat TTL** for the exact entries they name: a `list_changed` notification evicts the matching cached listing, and `resources/updated` evicts the cached read stored under exactly its URI, however fresh they were. + +One caveat on `resources/updated`: eviction is exact-URI only. The store contract has no enumerate or scan operation (same as the reference TypeScript implementation), so a notification carrying a *sub*-resource URI does not evict a cached read of its parent. If your server signals sub-resources this way, refetch the parent with `cache_mode="refresh"`. + +### Configuring it: `CacheConfig` + +```python +from mcp.client import CacheConfig + +client = Client("https://api.example.com/mcp", cache=CacheConfig(default_ttl_ms=5_000)) +``` + +* `store`: where entries live. The default is a fresh in-memory store per client; pass your own `ResponseCacheStore` implementation (Redis-backed, say) to share a cache across clients or processes. The contract types (`ResponseCacheStore`, `CacheKey`, `CacheEntry`, and the default `InMemoryResponseCacheStore`) are importable from `mcp.client`. A lookup may issue up to two sequential store `get`s (the private arm, then the public one), so size a remote store's latency expectations accordingly. A custom store **requires** an explicit `partition`. +* `partition`: the authorization-context label that keeps one principal's `"private"` entries from being served to another within a shared store. +* `target_id`: explicit server identity, for custom transports and in-process servers (below). +* `default_ttl_ms`: TTL applied to results that carry no `ttlMs` hint. The default `0` leaves hint-less results uncached. +* `share_public`: serve server-asserted-`"public"` entries across partitions (below). Off by default. +* `clock`: the wall-clock source, in epoch seconds. Inject one, as the example above does, and expiry tests need no sleeping. + +!!! warning "Partition = verified principal" + Derive `partition` from a **verified credential**, such as a validated token's subject. Never derive it from request-supplied data, and never from the server URL (server identity is a separate key axis). The SDK is a library with no authentication of its own: the trust anchor is whoever constructs the `CacheConfig`, which is the deployment, not the tenant. A multi-tenant gateway mints one `CacheConfig` per authenticated principal. + + The partition is also fixed for the `Client`'s lifetime. If the connection's authorization context changes mid-session (a re-authentication as a different principal, say), the cache does not follow; construct a new `Client` for the new principal. + +Cache keys also carry the **server's identity**: the URL string you dialed, with any `user:pass@` userinfo stripped and otherwise byte-exact. No case folding, no query reordering, no trailing-slash cleanup. Under-normalizing only costs sharing, while over-normalizing could merge two tenants (`?tenant=a` vs `?tenant=b`), so superficially different URLs simply don't share entries. When there is no URL (an in-process server, or a `Transport` instance), the client gets a random per-instance identity instead; set `CacheConfig.target_id` to name the server (with a custom store this is required, and construction says so). The identity is sha256-hashed before it enters key material, so a URL carrying secrets in its query string never appears in store keys. Don't log the pre-hash form yourself, either. + +!!! warning "`share_public` trusts the server, fleet-wide" + By default even `"public"` entries stay within their partition. `share_public=True` serves entries the server marked `cacheScope: "public"` to **every** partition using the store, trusting the server's classification on behalf of all of them. A server that stamps `"public"` on per-tenant data (by bug or by malice) then leaks one tenant's response to the others. The flag is deliberately constructor-level only: the per-call `cache_mode` can narrow caching, but nothing per-call can widen sharing. + +### What the cache never does + +* **Session-tier calls bypass it.** `client.session.list_tools()` and friends always make the round trip; the cache lives on the `Client` verbs. +* **`server/discover` stays out of it.** The discover result is delivered once, at connect, and never enters the response cache, even when it carries a `ttlMs`. If you persist one yourself to skip the reconnect probe ([`prior_discover`](../client/protocol-versions.md#reconnecting-with-prior_discover)), its freshness is your bookkeeping: `DiscoverResult` carries `ttl_ms` and `cache_scope`, already parsed, for exactly that purpose. +* **Continuation pages are never cached.** Only cursor-less calls participate. A continuation page rejected for an expired cursor does *evict* the cached listing, because the listing changed under it. +* **Multi-round-trip reads are never cached.** A `read_resource` seeded with `input_responses`/`request_state`, or one that resolves through input rounds, never enters the cache (a spec MUST). +* **Notification eviction needs notifications.** Eviction is only as good as the transport's delivery, and the modern in-process path (`Client(server)` with the default `mode="auto"`) does not deliver standalone notifications today. +* **Eviction is eventual, not instantaneous.** Wire-path notifications are dispatched from spawned tasks, so a call racing a notification's arrival may be served the pre-eviction entry once more; the window is bounded by dispatch latency, and the eviction still lands. +* **No stale-if-error.** An expired entry is never served because the refetch failed; the error propagates. +* **No early re-fetch.** A stored entry is served until its TTL expires and the next call after that pays the round trip; nothing refreshes in the background. +* **No coalescing.** Two concurrent identical calls are two fetches. +* **No TTL beyond 24 hours.** A larger `ttlMs`, whether server-sent or configured, is clamped down on store (`mcp.client.caching.MAX_TTL_MS`), bounding how long any entry, however generously hinted, can be served. +* On a **shared store**, clients race each other. Each client drops its own write when an eviction overtook the fetch in flight, but a *co-tenant* client can still write back an entry that an eviction it never saw had removed; and that race bookkeeping is itself bounded: past 4096 tracked keys the oldest key's guard is dropped first. Both windows are accepted, and closed by the TTL cap above. +* **No serving across protocol eras.** Entries are scoped to the negotiated protocol version: on a shared persistent store, a session never serves an entry written under a different negotiated version (the same listing genuinely differs by era, since the SDK strips the 2026 fields for older sessions). Eviction likewise touches only the current era's entries; another era's entries simply age out by TTL. + +### Reading the hints yourself + +The hints are also plain fields on every cacheable result (`result.ttl_ms` and `result.cache_scope`, already parsed), in case you want to layer your own bookkeeping on top of (or instead of) the built-in cache. + +Against an **older server** (pre-2026 protocol), the fields are simply absent from the wire, and the models show their conservative defaults: `ttl_ms == 0` and `cache_scope == "private"`, stale and unshared, the right assumption for a server that declared nothing. The cache treats a legacy session the same way: hints are never consulted there (whatever keys appear on the wire), only `default_ttl_ms` applies, and its default of `0` caches nothing, so a pre-2026 connection behaves exactly as it did before the cache existed. If you need to distinguish "the server said 0" from "the server said nothing", check `"ttl_ms" in result.model_fields_set`: it's only set when the field actually arrived. ## Older clients -Clients on pre-2026 protocol versions never see either field — the SDK strips them at serialization for those connections. Configure your hints once; there is nothing version-specific to write. +Clients on pre-2026 protocol versions never see either field; the SDK strips them at serialization for those connections. Configure your hints once; there is nothing version-specific to write. ## Recap -* Six methods carry `ttlMs`/`cacheScope`; the SDK defaults them to `0`/`"private"` — stale and unshared, always safe. +* Six methods carry `ttlMs`/`cacheScope`; the SDK defaults them to `0`/`"private"`, stale and unshared, always safe. * `cache_hints={method: CacheHint(...)}` at construction (both `MCPServer` and `Server`) sets server-wide values per method. * A handler that sets the fields on its result overrides the map, per field. * `"public"` is a promise that the result is identical for every caller. It is not access control. -* Clients read the hints as `result.ttl_ms` / `result.cache_scope` and own the caching decision themselves — the SDK has no built-in response cache yet. +* `Client` honors the hints automatically: its response cache is on by default, serves fresh entries instead of refetching, and caches nothing for servers (or sessions) that provide no hints. +* Per call, `cache_mode="refresh"` refetches and `"bypass"` skips the cache; `cache=False` at construction turns it off entirely. diff --git a/docs/migration.md b/docs/migration.md index 516cd8b18..047626ee2 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -427,6 +427,10 @@ On `ClientSession`, `call_tool` / `get_prompt` / `read_resource` still return th For protocol 2026-07-28 over Streamable HTTP, a tool's input-schema property may carry an `x-mcp-header` annotation. When a tool the client has listed is called, each annotated argument is mirrored into an `Mcp-Param-` request header (string verbatim, integer as decimal, boolean as `true`/`false`, base64-sentinel-wrapped when not header-safe; `null`/absent arguments are omitted). The argument is also left in the request body. `list_tools` caches a tool's annotations, so list a tool before calling it to enable mirroring; a tool the client never listed emits no `Mcp-Param-*` headers. Other transports ignore the annotation. +### `Client` verbs may serve cached responses ([SEP-2549](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/2549)) + +On protocol 2026-07-28, servers attach caching hints (`ttlMs`, `cacheScope`) to the cacheable results, and `Client` now honors them: `list_tools`, `list_prompts`, `list_resources`, `list_resource_templates`, and `read_resource` may serve a cached response instead of making a round trip, for as long as the server's `ttlMs` says the result is fresh. With the default configuration, servers that send no hints, including every pre-2026 server, see identical call-for-call behavior, because hint-less results are not cached (a `CacheConfig.default_ttl_ms` above zero caches them too). Pass `Client(..., cache=False)` to disable the cache and restore v1 behavior exactly; per-call control (`cache_mode`) and configuration (`CacheConfig`) are described in [Caching hints](advanced/caching.md). + ### Server extensions API ([SEP-2133](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/2133)) `MCPServer` now accepts opt-in extensions that bundle MCP behaviour behind a diff --git a/docs_src/caching/tutorial003.py b/docs_src/caching/tutorial003.py index 77ade546b..29c168c9f 100644 --- a/docs_src/caching/tutorial003.py +++ b/docs_src/caching/tutorial003.py @@ -1,15 +1,40 @@ +from dataclasses import dataclass +from typing import Any + +from mcp_types import ListToolsResult, PaginatedRequestParams, Tool + from mcp import Client -from mcp.server import CacheHint, MCPServer +from mcp.client import CacheConfig +from mcp.server import CacheHint, Server, ServerRequestContext + + +@dataclass +class DemoState: + fetches: int = 0 + now: float = 1_000_000.0 + + +state = DemoState() + -mcp = MCPServer("Weather", cache_hints={"tools/list": CacheHint(ttl_ms=60_000, scope="public")}) +async def list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + state.fetches += 1 + return ListToolsResult(tools=[Tool(name="forecast", input_schema={"type": "object"})]) -@mcp.tool() -def forecast(city: str) -> str: - return f"Sunny in {city}" +server = Server( + "Weather", + on_list_tools=list_tools, + cache_hints={"tools/list": CacheHint(ttl_ms=60_000, scope="public")}, +) async def main() -> None: - async with Client(mcp) as client: - tools = await client.list_tools() - print(f"{len(tools.tools)} tools, fresh for {tools.ttl_ms / 1000:.0f}s, scope={tools.cache_scope}") + start = state.fetches + async with Client(server, cache=CacheConfig(clock=lambda: state.now)) as client: + await client.list_tools() # fetch 1 + await client.list_tools() # fresh for 60s: served from the cache + state.now += 60.0 + await client.list_tools() # the TTL ran out: fetch 2 + await client.list_tools(cache_mode="refresh") # skip the cache read: fetch 3 + print(f"4 calls, {state.fetches - start} fetches") diff --git a/src/mcp-types/mcp_types/methods.py b/src/mcp-types/mcp_types/methods.py index 824dcfdfe..f49c158d9 100644 --- a/src/mcp-types/mcp_types/methods.py +++ b/src/mcp-types/mcp_types/methods.py @@ -13,7 +13,7 @@ from collections.abc import Mapping from functools import cache from types import MappingProxyType, UnionType -from typing import Any, Final, TypeVar +from typing import Any, Final, Literal, TypeVar, get_args from pydantic import BaseModel, TypeAdapter @@ -23,9 +23,11 @@ from mcp_types.version import KNOWN_PROTOCOL_VERSIONS __all__ = [ + "CACHEABLE_METHODS", "CLIENT_NOTIFICATIONS", "CLIENT_REQUESTS", "CLIENT_RESULTS", + "CacheableMethod", "MONOLITH_NOTIFICATIONS", "MONOLITH_REQUESTS", "MONOLITH_RESULTS", @@ -404,6 +406,24 @@ """Monolith result model (or two-arm union) per request method.""" +CacheableMethod = Literal[ + "prompts/list", + "resources/list", + "resources/read", + "resources/templates/list", + "server/discover", + "tools/list", +] +"""Methods whose results carry `ttlMs`/`cacheScope`; hand-written Literal, welded to `CACHEABLE_METHODS` by tests.""" + +CACHEABLE_METHODS: Final[frozenset[str]] = frozenset( + method + for method, row in MONOLITH_RESULTS.items() + if any(issubclass(arm, types.CacheableResult) for arm in (get_args(row) if isinstance(row, UnionType) else (row,))) +) +"""Runtime mirror of `CacheableMethod`, derived from `MONOLITH_RESULTS`.""" + + # --- Parse functions --- # Envelope stubs merged into bodies for surface validation (surface classes are full frames). diff --git a/src/mcp/client/__init__.py b/src/mcp/client/__init__.py index f9f732ad9..b7823f5ef 100644 --- a/src/mcp/client/__init__.py +++ b/src/mcp/client/__init__.py @@ -2,8 +2,28 @@ from mcp.client._input_required import InputRequiredRoundsExceededError from mcp.client._transport import Transport +from mcp.client.caching import ( + CacheConfig, + CacheEntry, + CacheKey, + CacheMode, + InMemoryResponseCacheStore, + ResponseCacheStore, +) from mcp.client.client import Client from mcp.client.context import ClientRequestContext from mcp.client.session import ClientSession -__all__ = ["Client", "ClientRequestContext", "ClientSession", "InputRequiredRoundsExceededError", "Transport"] +__all__ = [ + "CacheConfig", + "CacheEntry", + "CacheKey", + "CacheMode", + "Client", + "ClientRequestContext", + "ClientSession", + "InMemoryResponseCacheStore", + "InputRequiredRoundsExceededError", + "ResponseCacheStore", + "Transport", +] diff --git a/src/mcp/client/caching.py b/src/mcp/client/caching.py new file mode 100644 index 000000000..a464accd1 --- /dev/null +++ b/src/mcp/client/caching.py @@ -0,0 +1,387 @@ +"""Client-side response caching primitives (SEP-2549, protocol revision 2026-07-28).""" + +from __future__ import annotations + +import json +import logging +import time +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Final, Literal, Protocol + +import anyio +import anyio.lowlevel +from mcp_types import ( + CacheableResult, + PromptListChangedNotification, + ResourceListChangedNotification, + ResourceUpdatedNotification, + ServerNotification, + ToolListChangedNotification, +) +from mcp_types.version import MODERN_PROTOCOL_VERSIONS + +__all__ = [ + "MAX_TTL_MS", + "CacheConfig", + "CacheEntry", + "CacheKey", + "CacheMode", + "InMemoryResponseCacheStore", + "ResponseCacheStore", +] + +logger = logging.getLogger(__name__) + +CacheMode = Literal["use", "refresh", "bypass"] +"""Per-call cache behavior: `"use"` serves and stores, `"refresh"` stores +without serving, `"bypass"` skips the cache entirely.""" + +MAX_TTL_MS: Final[int] = 24 * 60 * 60 * 1000 +"""Cap on any entry's time-to-live (24 hours, in milliseconds); larger `ttlMs` values are clamped down.""" + + +@dataclass(frozen=True, slots=True) +class CacheKey: + """Identity of one cached response; compare as the field tuple, never a flattened string (collision hazard).""" + + method: str + + params_key: str = "" + """Result-affecting params discriminator: the uri for `resources/read`, `""` for the list methods.""" + + partition: str = "" + """Coordinator-computed arm identifier; opaque to stores.""" + + +@dataclass(frozen=True, slots=True) +class CacheEntry: + """One cached response with its freshness and sharing metadata.""" + + value: Any + """The cached result; the SDK deep-copies on write and on serve, so a store may hold it as-is.""" + + scope: Literal["public", "private"] + """Server-asserted `cacheScope`: only `"public"` entries may be shared across authorization contexts.""" + + expires_at: float | None + """Epoch seconds after which the entry is stale; `None` is never fresh.""" + + +class ResponseCacheStore(Protocol): + """Storage contract for the client response cache. + + Each `Client` calls its store from a single event loop; per-operation + atomicity is the implementation's responsibility. Operations may raise - + the SDK degrades to a miss rather than failing the call. A serializing + store must round-trip `value` back to the result model object (a + wrong-shape entry is a miss, never an error). A lookup may issue two + sequential `get` calls (private arm, then public). + """ + + async def get(self, key: CacheKey) -> CacheEntry | None: ... + + async def set(self, key: CacheKey, entry: CacheEntry) -> None: ... + + async def delete(self, key: CacheKey) -> None: ... + + async def clear(self) -> None: ... + + +@dataclass(frozen=True, slots=True) +class CacheConfig: + """Configuration for a `Client`'s response cache. + + Raises: + ValueError: On a custom `store` without `partition`, an empty `target_id`, or a negative `default_ttl_ms`. + """ + + store: ResponseCacheStore | None = None + """Backing store; `None` means a per-client `InMemoryResponseCacheStore`. + A custom store requires an explicit `partition`.""" + + partition: str = "" + """Authorization-context identifier isolating `"private"`-scoped entries + within a shared store. Derive it from a verified credential - never from + request-supplied data or the server URL. Fixed for the `Client`'s + lifetime: construct a new `Client` when the principal changes.""" + + target_id: str | None = None + """Server-identity override for custom transports and proxies where the + SDK cannot derive one from a URL; must be non-empty when provided.""" + + default_ttl_ms: int = 0 + """TTL in milliseconds for results carrying no `ttlMs` hint; the default `0` leaves them uncached.""" + + clock: Callable[[], float] = time.time + """Wall-clock source returning epoch seconds; injectable for expiry tests.""" + + share_public: bool = False + """Serve server-marked `"public"` entries across every partition in the store. + + WARNING: this trusts the server's `"public"` classification for every + principal sharing the store - a mislabeled response leaks across tenants. + Constructor-level only: the per-call `cache_mode` can never widen sharing.""" + + def __post_init__(self) -> None: + if self.store is not None and not self.partition: + raise ValueError("a custom store requires an explicit partition") + if self.target_id == "": + raise ValueError("target_id must be a non-empty string or omitted") + if self.default_ttl_ms < 0: + raise ValueError(f"default_ttl_ms must be >= 0, got {self.default_ttl_ms}") + + +class InMemoryResponseCacheStore: + """Default in-process `ResponseCacheStore`. + + Method bodies are synchronous, so concurrent tasks never observe a torn + write. `max_entries` caps the whole store, evicting least-recently-used + at the cap (`0` disables it); `get` and `set` both refresh recency, so a + hot entry survives churn from other keys. + + Raises: + ValueError: If `max_entries` is negative. + """ + + def __init__(self, *, max_entries: int = 1024) -> None: + if max_entries < 0: + raise ValueError(f"max_entries must be >= 0, got {max_entries}") + self._max_entries = max_entries + self._entries: dict[CacheKey, CacheEntry] = {} + + async def get(self, key: CacheKey) -> CacheEntry | None: + entry = self._entries.get(key) + if entry is not None: + # Pop-and-reinsert moves the key to the back: the dict's insertion order is the LRU ledger. + self._entries[key] = self._entries.pop(key) + return entry + + async def set(self, key: CacheKey, entry: CacheEntry) -> None: + self._entries.pop(key, None) + self._entries[key] = entry + if self._max_entries and len(self._entries) > self._max_entries: + del self._entries[next(iter(self._entries))] + + async def delete(self, key: CacheKey) -> None: + self._entries.pop(key, None) + + async def clear(self) -> None: + self._entries.clear() + + +_GENERATION_MAP_CAP: Final[int] = 4096 +"""Cap on the generation map; at the cap the oldest key's eviction-race guard is dropped (FIFO).""" + +_STORE_CLEANUP_TIMEOUT: Final[float] = 5 +"""Bound for must-complete store cleanup deletes (mirrors the dispatcher's final-write bound); +a wedged store delete must not hold client teardown uncancellably.""" + + +class ClientResponseCache: + """Coordinates the `Client` caching verbs with a `ResponseCacheStore`: keys, era gate, TTL/scope, eviction.""" + + def __init__( + self, + *, + store: ResponseCacheStore, + partition: str, + arm_id: str, + default_ttl_ms: int, + clock: Callable[[], float], + share_public: bool, + negotiated_version: Callable[[], str | None], + generation_map_cap: int = _GENERATION_MAP_CAP, + store_cleanup_timeout: float = _STORE_CLEANUP_TIMEOUT, + ) -> None: + self._store = store + self._partition = partition + self._arm_id = arm_id + self._share_public = share_public + self._default_ttl_ms = default_ttl_ms + self._clock = clock + self._negotiated_version = negotiated_version + # A key is eviction-race-guarded iff registered here. + self._generations: dict[tuple[str, str], int] = {} + self._generation_map_cap = generation_map_cap + self._store_cleanup_timeout = store_cleanup_timeout + self._warned_store_ops: set[str] = set() + + def _arm(self, scope: Literal["public", "private"]) -> str: + # JSON arrays so crafted arm_id/partition values cannot collide across field boundaries. + # The negotiated version era-scopes every arm: a session never serves an entry written + # under a different protocol era (its content differs - sieve-stripped fields, header + # filtering). Every caller runs post-connect; were that ever untrue, the supplier's + # None still partitions harmlessly. + fields: list[str | None] = [scope, self._negotiated_version(), self._arm_id] + if scope == "private" or not self._share_public: + fields.append(self._partition) + return json.dumps(fields) + + async def read(self, method: str, params_key: str) -> CacheableResult | None: + """Serve a fresh entry for the key, or `None`; the served result is a deep copy.""" + # A hit completes without any other yielding await, so checkpoint here: a poll + # loop over a fresh entry must not starve spawned tasks (eviction dispatch). + await anyio.lowlevel.checkpoint() + # A wrong-shape entry raises as late as the copy, so the boundary wraps the whole read path. + try: + entry = await self._get_fresh(CacheKey(method, params_key, self._arm("private"))) + if entry is None: + # After a scope flip, a stale private entry must not shadow a fresh public one. + entry = await self._get_fresh(CacheKey(method, params_key, self._arm("public"))) + if entry is not None and entry.scope != "public": + # Never serve an entry the server scoped "private" out of the shared arm. + entry = None + copied: CacheableResult | None = None if entry is None else entry.value.model_copy(deep=True) + except Exception: # boundary around user store code: any read-path failure is a miss, never a failed call + self._warn_store_failure("get") + return None + self._warned_store_ops.discard("get") + return copied + + async def _get_fresh(self, key: CacheKey) -> CacheEntry | None: + entry = await self._store.get(key) + if entry is None or entry.expires_at is None or entry.expires_at <= self._clock(): + return None + return entry + + def capture(self, method: str, params_key: str) -> int: + """Register the key for eviction-race detection before the fetch; `write` takes the returned generation.""" + gen_key = (method, params_key) + if gen_key not in self._generations: + if len(self._generations) >= self._generation_map_cap: + # FIFO overflow: the dropped key's race guard degrades to the accepted co-tenant class. + del self._generations[next(iter(self._generations))] + self._generations[gen_key] = 0 + return self._generations[gen_key] + + async def write( + self, + method: str, + params_key: str, + result: CacheableResult, + gen_at_capture: int, + mode: Literal["use", "refresh"], + ) -> None: + """Store a fetched result under the arm its resolved scope selects.""" + gen_key = (method, params_key) + if self._generation_moved(gen_key, gen_at_capture): + return # the key was evicted while the fetch was in flight + ttl_ms, scope = self._resolve(result) + private_key = CacheKey(method, params_key, self._arm("private")) + public_key = CacheKey(method, params_key, self._arm("public")) + if ttl_ms <= 0: + if mode == "refresh": + # The refetch superseded the warm entry, which a cancellation must not leave serving. + await self._cleanup_delete(private_key, public_key) + return + own, opposite = (public_key, private_key) if scope == "public" else (private_key, public_key) + # Opposite arm first: a failed delete aborts before the set - never two arms answering for one key. + if not await self._delete(opposite): + # The own arm's entry is superseded too: best-effort delete, degrading to a full miss. + await self._cleanup_delete(own) + return + entry = CacheEntry(value=result.model_copy(deep=True), scope=scope, expires_at=self._clock() + ttl_ms / 1000) + try: + if not await self._set(own, entry): + # The fetch superseded any pre-existing own-arm entry, and the failed set + # left it in place: purge it (mirrors the opposite-arm-failure path). + await self._cleanup_delete(own) + finally: + # An eviction can land while the set commits - even when the await + # is cancelled - so re-check on every exit; the delete must complete + # so the pending cancellation cannot resurrect the evicted entry. + if self._generation_moved(gen_key, gen_at_capture): + await self._cleanup_delete(own) + + async def evict_method(self, method: str) -> None: + """Evict the method's cursor-less entry.""" + await self.evict_key(method, "") + + async def evict_key(self, method: str, params_key: str) -> None: + """Evict one key from both arms. + + Only the current era's arms are touched; other-era entries in a persistent store age out by TTL. + """ + gen_key = (method, params_key) + # Bump first so an in-flight fetch cannot write the evicted entry back. + # Unregistered keys skip the bump (uris must not grow the map) but not + # the deletes - a persistent store may hold uncaptured entries. + if gen_key in self._generations: + self._generations[gen_key] += 1 + # Must complete: a cancellation between the deletes would leave one arm serving the evicted entry. + await self._cleanup_delete( + CacheKey(method, params_key, self._arm("private")), + CacheKey(method, params_key, self._arm("public")), + ) + + async def evict_for_notification(self, notification: ServerNotification) -> None: + """Map a server notification to the entries it makes stale. + + Eviction is eventual (spawned-task dispatch): the generation bump closes + the write-back race; a racing read may briefly serve the old entry. + """ + match notification: + case ToolListChangedNotification(): + await self.evict_method("tools/list") + case PromptListChangedNotification(): + await self.evict_method("prompts/list") + case ResourceListChangedNotification(): + # Templates enumerate the same changed resource space. + await self.evict_method("resources/list") + await self.evict_method("resources/templates/list") + case ResourceUpdatedNotification(): + await self.evict_key("resources/read", notification.params.uri) + case _: + pass + + def _resolve(self, result: CacheableResult) -> tuple[int, Literal["public", "private"]]: + # A legacy peer can also put `ttlMs`/`cacheScope` keys on the wire, so + # wire presence is not a peer-era signal - hints count only when modern. + modern = self._negotiated_version() in MODERN_PROTOCOL_VERSIONS + if modern and "ttl_ms" in result.model_fields_set: + # An explicit `ttlMs: 0` stays 0, and negatives are unconstructible + # upstream (model ge=0, parse-seam floor) - only the cap applies. + ttl_ms = result.ttl_ms + else: + ttl_ms = self._default_ttl_ms + scope: Literal["public", "private"] = "public" if modern and result.cache_scope == "public" else "private" + return min(ttl_ms, MAX_TTL_MS), scope + + def _generation_moved(self, gen_key: tuple[str, str], gen_at_capture: int) -> bool: + # A FIFO-dropped key fails open (the accepted co-tenant race) rather than discarding the fetch. + return self._generations.get(gen_key, gen_at_capture) != gen_at_capture + + async def _set(self, key: CacheKey, entry: CacheEntry) -> bool: + try: + await self._store.set(key, entry) + except Exception: # boundary around user store code: nothing cached, the fetch already succeeded + self._warn_store_failure("set") + return False + self._warned_store_ops.discard("set") + return True + + async def _cleanup_delete(self, *keys: CacheKey) -> None: + # Must-complete cleanup: shielded so a pending cancellation cannot skip the deletes, + # bounded so a wedged store delete cannot hold client teardown uncancellably. + with anyio.move_on_after(self._store_cleanup_timeout, shield=True) as scope: + for key in keys: + await self._delete(key) + if scope.cancelled_caught: + logger.warning("Response cache store delete timed out; the entry will age out by TTL") + + async def _delete(self, key: CacheKey) -> bool: + try: + await self._store.delete(key) + except Exception: # boundary around user store code: callers decide whether a failed delete aborts + self._warn_store_failure("delete") + return False + self._warned_store_ops.discard("delete") + return True + + def _warn_store_failure(self, kind: Literal["get", "set", "delete"]) -> None: + # One warning per failure burst, per op kind; re-armed only when that + # same kind succeeds, so a healthy delete cannot re-arm a broken set. + if kind not in self._warned_store_ops: + self._warned_store_ops.add(kind) + logger.warning("Response cache store operation failed; continuing without the cache", exc_info=True) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index d3290f308..638ea63a9 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -2,14 +2,20 @@ from __future__ import annotations +import hashlib +import logging +import uuid from collections.abc import Awaitable, Callable, Mapping from contextlib import AsyncExitStack from dataclasses import KW_ONLY, dataclass, field -from typing import Any, Literal, TypeVar +from typing import Any, Literal, TypeVar, cast import anyio +import anyio.lowlevel import mcp_types as types from mcp_types import ( + INVALID_PARAMS, + CacheableResult, CallToolResult, CompleteResult, EmptyResult, @@ -39,6 +45,7 @@ from mcp.client._memory import InMemoryTransport from mcp.client._probe import negotiate_auto from mcp.client._transport import Transport +from mcp.client.caching import CacheConfig, CacheMode, ClientResponseCache, InMemoryResponseCacheStore from mcp.client.session import ( ClientRequestContext, ClientSession, @@ -54,8 +61,11 @@ from mcp.server.runner import modern_on_request from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.dispatcher import Dispatcher, ProgressFnT -from mcp.shared.exceptions import MCPDeprecationWarning +from mcp.shared.exceptions import MCPDeprecationWarning, MCPError from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.session import RequestResponder + +logger = logging.getLogger(__name__) ConnectMode = Literal["legacy", "auto"] | str """``mode=`` value: ``"legacy"`` (initialize handshake), ``"auto"`` (discover, fall back to @@ -64,6 +74,7 @@ _T = TypeVar("_T") _ResultT = TypeVar("_ResultT") +_CacheableT = TypeVar("_CacheableT", bound=CacheableResult) _Connector = Callable[[AsyncExitStack, ConnectMode, bool], Awaitable["Dispatcher[Any]"]] """Resolved at ``__post_init__`` from the shape of ``server`` alone: enter whatever resources @@ -115,6 +126,46 @@ def _connected(value: _T | None) -> _T: return value +def _strip_userinfo(url: str) -> str: + """Drop any userinfo from the URL's authority component; byte-exact otherwise. + + Credentials must not enter cache-key material; any further normalization could merge distinct servers. + """ + # Pure text, no urlsplit: it strips embedded tab/CR/LF before parsing, which would misalign slices. + sep = url.find("//") + if sep == -1: + return url + start = sep + 2 + end = len(url) + for delimiter in "/?#": + if (found := url.find(delimiter, start)) != -1: + end = min(end, found) + authority = url[start:end] + if "@" not in authority: + return url + return url[:start] + authority.rpartition("@")[2] + url[end:] + + +def _evicting_message_handler(cache: ClientResponseCache, user_handler: MessageHandlerFnT | None) -> MessageHandlerFnT: + """Wrap the session message handler with cache eviction on server notifications.""" + + async def handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, types.ServerNotification): + try: + await cache.evict_for_notification(message) + except Exception: # boundary: eviction reaches user store code; a cache fault must not block delivery + logger.exception("Response cache eviction failed; the notification is still delivered") + if user_handler is not None: + await user_handler(message) + else: + # Mirrors ClientSession's default handler (session._default_message_handler). + await anyio.lowlevel.checkpoint() + + return handler + + def _synthesize_discover(protocol_version: str) -> types.DiscoverResult: return types.DiscoverResult( supported_versions=[protocol_version], @@ -221,10 +272,20 @@ async def main(): """SEP-2133 extension support to advertise under `ClientCapabilities.extensions` (identifier -> settings), e.g. `{"io.modelcontextprotocol/ui": {"mimeTypes": [...]}}`.""" + cache: CacheConfig | Literal[False] | None = None + """Client-side response caching for the SEP-2549 cacheable methods (2026-07-28). + + `None` (the default) honors server `ttlMs`/`cacheScope` hints with a per-client + in-memory store; pass a `CacheConfig` to customize, or `False` to disable. The + cacheable verbs take a per-call `cache_mode` (see `CacheMode`); calls carrying + `meta` always reach the server. A `CacheConfig` with a custom `store` requires + `target_id` when the server is not a URL (no identity can be derived).""" + _entered: bool = field(init=False, default=False) _session: ClientSession | None = field(init=False, default=None) _exit_stack: AsyncExitStack | None = field(init=False, default=None) _connect: _Connector = field(init=False, repr=False, compare=False) + _response_cache: ClientResponseCache | None = field(init=False, default=None, repr=False, compare=False) def __post_init__(self) -> None: if self.mode not in ("legacy", "auto") and self.mode not in MODERN_PROTOCOL_VERSIONS: @@ -247,16 +308,44 @@ def __post_init__(self) -> None: else: self._connect = _connect_transport(srv) + if self.cache is not False: + config = self.cache if self.cache is not None else CacheConfig() + # Only the hash below leaves this scope - the raw identity may carry credentials; never log or store it. + target_id = config.target_id + if target_id is None and isinstance(self.server, str): + target_id = _strip_userinfo(self.server) + if target_id is None: + if config.store is not None: + raise ValueError( + "a custom cache store requires CacheConfig.target_id when the server is not a URL: " + "in-process servers and Transport instances get a random per-client identity, so " + "their entries in a shared store could never be served to another client" + ) + target_id = uuid.uuid4().hex + self._response_cache = ClientResponseCache( + store=config.store if config.store is not None else InMemoryResponseCacheStore(), + partition=config.partition, + arm_id=hashlib.sha256(target_id.encode()).hexdigest(), + default_ttl_ms=config.default_ttl_ms, + clock=config.clock, + share_public=config.share_public, + # Lazy: the negotiated version is unknown until __aenter__'s handshake. + negotiated_version=lambda: self._session.protocol_version if self._session is not None else None, + ) + async def _build_session(self, exit_stack: AsyncExitStack) -> ClientSession: """Enter the resolved connector and return an un-entered ClientSession.""" dispatcher = await self._connect(exit_stack, self.mode, self.raise_exceptions) + message_handler = self.message_handler + if self._response_cache is not None: + message_handler = _evicting_message_handler(self._response_cache, self.message_handler) return ClientSession( dispatcher=dispatcher, read_timeout_seconds=self.read_timeout_seconds, sampling_callback=self.sampling_callback, list_roots_callback=self.list_roots_callback, logging_callback=self.logging_callback, - message_handler=self.message_handler, + message_handler=message_handler, client_info=self.client_info, elicitation_callback=self.elicitation_callback, extensions=self.extensions, @@ -361,23 +450,76 @@ async def set_logging_level(self, level: LoggingLevel, *, meta: RequestParamsMet """Set the logging level on the server.""" return await self.session.set_logging_level(level=level, meta=meta) # pyright: ignore[reportDeprecated] + async def _cached_fetch( + self, + method: str, + *, + cursor: str | None, + meta: RequestParamsMeta | None, + cache_mode: CacheMode, + send: Callable[[], Awaitable[_CacheableT]], + absorb: Callable[[_CacheableT], _CacheableT] | None = None, + ) -> _CacheableT: + """Serve one of the four list verbs through the response cache. + + `absorb` (tools/list only) re-applies session-side derived state to a served cache hit. + """ + cache = self._response_cache + if cache is None or cache_mode == "bypass": + return await send() + # A closed (or never-entered) client must raise, never serve cached entries. + _ = self.session + if meta is not None and cache_mode == "use": + # meta (a progress token, tracing fields) expects a wire request; fetch and replace the entry. + cache_mode = "refresh" + if cursor is not None: + # Continuation pages skip the cache, but an expired cursor means the listing changed (spec SHOULD evict). + try: + return await send() + except MCPError as e: + if e.code == INVALID_PARAMS: + await cache.evict_method(method) + raise + if cache_mode == "use" and (hit := await cache.read(method, "")) is not None: + # The hit is a private deep copy, so absorption may mutate it freely. + served = cast(_CacheableT, hit) + return served if absorb is None else absorb(served) + gen = cache.capture(method, "") + result = await send() + await cache.write(method, "", result, gen, cache_mode) + return result + async def list_resources( self, *, cursor: str | None = None, meta: RequestParamsMeta | None = None, + cache_mode: CacheMode = "use", ) -> ListResourcesResult: """List available resources from the server.""" - return await self.session.list_resources(params=PaginatedRequestParams(cursor=cursor, _meta=meta)) + return await self._cached_fetch( + "resources/list", + cursor=cursor, + meta=meta, + cache_mode=cache_mode, + send=lambda: self.session.list_resources(params=PaginatedRequestParams(cursor=cursor, _meta=meta)), + ) async def list_resource_templates( self, *, cursor: str | None = None, meta: RequestParamsMeta | None = None, + cache_mode: CacheMode = "use", ) -> ListResourceTemplatesResult: """List available resource templates from the server.""" - return await self.session.list_resource_templates(params=PaginatedRequestParams(cursor=cursor, _meta=meta)) + return await self._cached_fetch( + "resources/templates/list", + cursor=cursor, + meta=meta, + cache_mode=cache_mode, + send=lambda: self.session.list_resource_templates(params=PaginatedRequestParams(cursor=cursor, _meta=meta)), + ) async def read_resource( self, @@ -386,6 +528,7 @@ async def read_resource( input_responses: InputResponses | None = None, request_state: str | None = None, meta: RequestParamsMeta | None = None, + cache_mode: CacheMode = "use", ) -> ReadResourceResult: """Read a resource from the server. @@ -400,6 +543,8 @@ async def read_resource( resuming from a persisted `InputRequiredResult`). request_state: Opaque state to seed the first call with. meta: Additional metadata for the request. + cache_mode: Cache behavior for this call (see `CacheMode`); seeded + calls (`input_responses` or `request_state` set) ignore it. Returns: The resource content. @@ -414,7 +559,29 @@ async def retry(r: InputResponses | None, s: str | None) -> ReadResourceResult | uri, input_responses=r, request_state=s, meta=meta, allow_input_required=True ) - return await self._drive_input_required(await retry(input_responses, request_state), retry) + # Seeded calls resume a specific exchange and must never be cached (spec MUST). + seeded = input_responses is not None or request_state is not None + cache = None if seeded else self._response_cache + if cache is None or cache_mode == "bypass": + return await self._drive_input_required(await retry(input_responses, request_state), retry) + # A closed (or never-entered) client must raise, never serve cached entries. + _ = self.session + if meta is not None and cache_mode == "use": + # Calls carrying meta always reach the server (mirrors `_cached_fetch`). + cache_mode = "refresh" + if cache_mode == "use" and (hit := await cache.read("resources/read", uri)) is not None: + # Only terminal first-round results are stored, so a hit legitimately skips the driver. + return cast(ReadResourceResult, hit) + gen = cache.capture("resources/read", uri) + first = await retry(None, None) + if not isinstance(first, InputRequiredResult): + await cache.write("resources/read", uri, first, gen, cache_mode) + elif cache_mode == "refresh": + # The refresh superseded whatever was cached, but an input_required resolution + # cannot be stored: purge the warm entry so it cannot be served again. + await cache.evict_key("resources/read", uri) + # Driver rounds carry inputResponses, so a terminal result reached through them is never cached (spec MUST). + return await self._drive_input_required(first, retry) async def subscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> EmptyResult: """Subscribe to resource updates.""" @@ -481,9 +648,16 @@ async def list_prompts( *, cursor: str | None = None, meta: RequestParamsMeta | None = None, + cache_mode: CacheMode = "use", ) -> ListPromptsResult: """List available prompts from the server.""" - return await self.session.list_prompts(params=PaginatedRequestParams(cursor=cursor, _meta=meta)) + return await self._cached_fetch( + "prompts/list", + cursor=cursor, + meta=meta, + cache_mode=cache_mode, + send=lambda: self.session.list_prompts(params=PaginatedRequestParams(cursor=cursor, _meta=meta)), + ) async def get_prompt( self, @@ -565,9 +739,27 @@ async def complete( """ return await self.session.complete(ref=ref, argument=argument, context_arguments=context_arguments) - async def list_tools(self, *, cursor: str | None = None, meta: RequestParamsMeta | None = None) -> ListToolsResult: + async def list_tools( + self, + *, + cursor: str | None = None, + meta: RequestParamsMeta | None = None, + cache_mode: CacheMode = "use", + ) -> ListToolsResult: """List available tools from the server.""" - return await self.session.list_tools(params=PaginatedRequestParams(cursor=cursor, _meta=meta)) + return await self._cached_fetch( + "tools/list", + cursor=cursor, + meta=meta, + cache_mode=cache_mode, + send=lambda: self.session.list_tools(params=PaginatedRequestParams(cursor=cursor, _meta=meta)), + # A cache hit skips session.list_tools, so the session re-absorbs the served + # listing to rebuild its derived per-tool state. Hits are cursorless, but a + # cached page 1 can carry next_cursor - never prune on a partial listing. + absorb=lambda hit: self.session._absorb_tool_listing( # pyright: ignore[reportPrivateUsage] + hit, complete=hit.next_cursor is None + ), + ) @deprecated("The roots capability is deprecated as of 2026-07-28 (SEP-2577).", category=MCPDeprecationWarning) async def send_roots_list_changed(self) -> None: diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3cebb569e..6a2298ad9 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -55,6 +55,13 @@ logger = logging.getLogger("client") +def _clamp_inbound_ttl(raw: dict[str, Any]) -> None: + """Floor a negative inbound `ttlMs` to 0 before `ge=0` validation fails the call (2026-07-28 caching SHOULD).""" + ttl = raw.get("ttlMs") + if isinstance(ttl, int | float) and not isinstance(ttl, bool) and ttl < 0: + raw["ttlMs"] = 0 + + def _preconnect_stamp(data: dict[str, Any], opts: CallOptions) -> None: # initialize/discover forbid cancellation; other pre-handshake requests (lowlevel # ClientSession callers may skip the handshake entirely) keep the courtesy cancel. @@ -331,6 +338,7 @@ async def send_request( if metadata.on_resumption_token_update is not None: opts["on_resumption_token"] = metadata.on_resumption_token_update raw = await self._dispatcher.send_raw_request(method, data.get("params"), opts) + _clamp_inbound_ttl(raw) # Literal fallback covers pre-handshake and stateless; matches runner.py. version = self._negotiated_version or "2025-11-25" try: @@ -458,7 +466,10 @@ async def send_discover(self, version: str) -> dict[str, Any]: "cancel_on_abandon": False, "headers": {MCP_PROTOCOL_VERSION_HEADER: version, MCP_METHOD_HEADER: data["method"]}, } - return await self._dispatcher.send_raw_request(data["method"], data.get("params"), opts) + raw = await self._dispatcher.send_raw_request(data["method"], data.get("params"), opts) + # Un-floored, a negative ttl fails the mode='auto' probe's validation and silently downgrades the handshake. + _clamp_inbound_ttl(raw) + return raw async def discover(self) -> types.DiscoverResult: """Probe `server/discover` and adopt the result. @@ -895,7 +906,15 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None types.ListToolsRequest(params=params), types.ListToolsResult, ) + complete = (params is None or params.cursor is None) and result.next_cursor is None + return self._absorb_tool_listing(result, complete=complete) + + def _absorb_tool_listing(self, result: types.ListToolsResult, *, complete: bool) -> types.ListToolsResult: + """Filter the listing per the 2026 x-mcp-header MUST and rebuild derived per-tool state, in place. + Idempotent: cached values are already post-filter, so the response cache can re-absorb a served listing. + `complete` (an uncursored single-page listing) prunes per-tool state down to the listing's tools. + """ if self._negotiated_version in MODERN_PROTOCOL_VERSIONS: # 2026-07-28: clients MUST drop tools whose x-mcp-header annotations are invalid. kept: list[types.Tool] = [] @@ -911,11 +930,17 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None kept.append(tool) result.tools = kept - # Cache tool output schemas for future validation - # Note: don't clear the cache, as we may be using a cursor + # Cache tool output schemas for future validation; cursor pages only ever add. for tool in result.tools: self._tool_output_schemas[tool.name] = tool.output_schema + if complete: + # The listing is the full tool universe, so state for unlisted tools is stale + # (the server dropped them, or a shared-cache writer's filter did). + names = {tool.name for tool in result.tools} + self._x_mcp_header_maps = {k: v for k, v in self._x_mcp_header_maps.items() if k in names} + self._tool_output_schemas = {k: v for k, v in self._tool_output_schemas.items() if k in names} + return result @deprecated("The roots capability is deprecated as of 2026-07-28 (SEP-2577).", category=MCPDeprecationWarning) diff --git a/src/mcp/server/caching.py b/src/mcp/server/caching.py index a8a2a470c..5e9930315 100644 --- a/src/mcp/server/caching.py +++ b/src/mcp/server/caching.py @@ -11,27 +11,13 @@ from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Final, Literal, TypeVar, get_args +from typing import Any, Literal, TypeVar import mcp_types as types +from mcp_types.methods import CACHEABLE_METHODS, CacheableMethod __all__ = ["CACHEABLE_METHODS", "CacheHint", "CacheableMethod", "apply_cache_hint", "validate_cache_hints"] -CacheableMethod = Literal[ - "prompts/list", - "resources/list", - "resources/read", - "resources/templates/list", - "server/discover", - "tools/list", -] -"""The methods whose results carry `ttlMs`/`cacheScope`. Closed set: the spec -defines caching hints on exactly these six (tests pin it to which result models -mix in `CacheableResult`).""" - -CACHEABLE_METHODS: Final[frozenset[str]] = frozenset(get_args(CacheableMethod)) -"""Runtime mirror of `CacheableMethod`, for callers the type checker can't see.""" - @dataclass(frozen=True, slots=True) class CacheHint: @@ -87,7 +73,8 @@ def validate_cache_hints(cache_hints: Mapping[Any, Any] | None) -> dict[str, Cac """ if cache_hints is None: return {} - unknown = sorted(method for method in cache_hints if method not in CACHEABLE_METHODS) + # repr-format keys so a non-string key raises this ValueError, not a TypeError from sorted/join. + unknown = sorted(repr(method) for method in cache_hints if method not in CACHEABLE_METHODS) if unknown: raise ValueError(f"cache_hints keys must be cacheable methods (see CacheableMethod); got: {', '.join(unknown)}") validated: dict[str, CacheHint] = {} diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 4c25a8a5b..6773fd4de 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -198,12 +198,15 @@ async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> HandlerResult: if isinstance(result, ErrorData): # Raise inside the chain so middleware observes the failure. raise MCPError.from_error_data(result) - # Fill cache hints on the typed result, before the serialize sieve + # Fill cache hints on the handler result, before the serialize sieve # decides whether the negotiated version carries the fields at all. - # `input_required` interim results are not `CacheableResult` models, - # so the MRTR carve-out (no hints on them) holds by shape. - if isinstance(result, CacheableResult) and (hint := self.server.cache_hints.get(method)) is not None: - result = apply_cache_hint(result, hint) + # MRTR carve-out: `input_required` interim results, typed or mapping, never get hints. + if (hint := self.server.cache_hints.get(method)) is not None: + if isinstance(result, CacheableResult): + result = apply_cache_hint(result, hint) + elif isinstance(result, Mapping) and result.get("resultType") != "input_required": + # Hint keys first so wire keys the handler set win, matching `apply_cache_hint` precedence. + result = {"ttlMs": hint.ttl_ms, "cacheScope": hint.scope, **result} # Dump and serialize inside the chain so the OpenTelemetry span (the # outermost middleware) records a failing handler return shape too. return self._serialize(method, version, result) diff --git a/tests/client/test_caching.py b/tests/client/test_caching.py new file mode 100644 index 000000000..dc445a6ec --- /dev/null +++ b/tests/client/test_caching.py @@ -0,0 +1,1087 @@ +"""Tests for `mcp.client.caching`. The store-contract tests are parametrized +over `STORE_FACTORIES`; a third-party store can be run against the same +contract by adding its factory.""" + +import json +import logging +import time +from collections.abc import Awaitable, Callable +from typing import Any + +import anyio +import anyio.lowlevel +import pytest +from inline_snapshot import snapshot +from mcp_types import ( + ListPromptsResult, + ListToolsResult, + LoggingMessageNotification, + LoggingMessageNotificationParams, + PromptListChangedNotification, + ReadResourceResult, + ResourceListChangedNotification, + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, + ServerNotification, + ToolListChangedNotification, +) + +from mcp.client.caching import ( + MAX_TTL_MS, + CacheConfig, + CacheEntry, + CacheKey, + ClientResponseCache, + InMemoryResponseCacheStore, + ResponseCacheStore, +) + +pytestmark = pytest.mark.anyio + +STORE_FACTORIES: list[Callable[[], ResponseCacheStore]] = [InMemoryResponseCacheStore] + +store_contract = pytest.mark.parametrize("make_store", STORE_FACTORIES, ids=["InMemoryResponseCacheStore"]) + + +def _entry(value: Any = "cached") -> CacheEntry: + """Entries are opaque payloads at the store layer; only the key matters here.""" + return CacheEntry(value=value, scope="private", expires_at=None) + + +def _read_key(uri: str) -> CacheKey: + return CacheKey("resources/read", uri) + + +# --- Store contract --- + + +@store_contract +async def test_a_set_entry_round_trips_through_get(make_store: Callable[[], ResponseCacheStore]) -> None: + store = make_store() + key = CacheKey("tools/list", "", "partition-1") + entry = CacheEntry(value={"tools": []}, scope="public", expires_at=1700000000.0) + await store.set(key, entry) + assert await store.get(key) == entry + + +@store_contract +async def test_get_misses_for_a_key_never_set(make_store: Callable[[], ResponseCacheStore]) -> None: + store = make_store() + assert await store.get(CacheKey("tools/list")) is None + + +@store_contract +async def test_keys_differing_in_only_one_field_do_not_collide( + make_store: Callable[[], ResponseCacheStore], +) -> None: + """Spec-mandated: collapsing any key field would serve responses across method, params, or principal boundaries.""" + store = make_store() + base = CacheKey("resources/read", "file:///a", "partition-1") + keys = [ + base, + CacheKey("resources/list", base.params_key, base.partition), + CacheKey(base.method, "file:///b", base.partition), + CacheKey(base.method, base.params_key, "partition-2"), + ] + for i, key in enumerate(keys): + await store.set(key, _entry(i)) + for i, key in enumerate(keys): + assert await store.get(key) == _entry(i) + + +@store_contract +async def test_swapped_params_key_and_partition_values_are_distinct_keys( + make_store: Callable[[], ResponseCacheStore], +) -> None: + store = make_store() + await store.set(CacheKey("m", "a", "b"), _entry("params=a")) + await store.set(CacheKey("m", "b", "a"), _entry("params=b")) + assert await store.get(CacheKey("m", "a", "b")) == _entry("params=a") + assert await store.get(CacheKey("m", "b", "a")) == _entry("params=b") + + +@store_contract +async def test_keys_with_field_values_that_concatenate_identically_do_not_collide( + make_store: Callable[[], ResponseCacheStore], +) -> None: + """Keys compare as the field tuple - flattening would let crafted values collide across boundaries.""" + store = make_store() + keys = [ + CacheKey("a", "b.c", "p"), + CacheKey("a.b", "c", "p"), + CacheKey("m", "x", "y:z"), + CacheKey("m", "x:y", "z"), + CacheKey("m", "u/v", ""), + CacheKey("m/u", "v", ""), + CacheKey("ab", "", ""), + CacheKey("a", "b", ""), + CacheKey("", "ab", ""), + ] + for i, key in enumerate(keys): + await store.set(key, _entry(i)) + for i, key in enumerate(keys): + assert await store.get(key) == _entry(i) + + +@store_contract +async def test_set_replaces_the_entry_for_an_existing_key(make_store: Callable[[], ResponseCacheStore]) -> None: + store = make_store() + key = CacheKey("tools/list") + await store.set(key, _entry("first")) + await store.set(key, _entry("second")) + assert await store.get(key) == _entry("second") + + +@store_contract +async def test_delete_removes_only_the_given_key(make_store: Callable[[], ResponseCacheStore]) -> None: + store = make_store() + doomed = CacheKey("tools/list", "", "partition-1") + survivor = CacheKey("tools/list", "", "partition-2") + await store.set(doomed, _entry("doomed")) + await store.set(survivor, _entry("survivor")) + await store.delete(doomed) + assert await store.get(doomed) is None + assert await store.get(survivor) == _entry("survivor") + + +@store_contract +async def test_delete_is_idempotent(make_store: Callable[[], ResponseCacheStore]) -> None: + """The SDK issues unconditional deletes during eviction, so deleting an absent key must be a no-op.""" + store = make_store() + key = CacheKey("prompts/list") + await store.delete(key) + await store.set(key, _entry()) + await store.delete(key) + await store.delete(key) + assert await store.get(key) is None + + +@store_contract +async def test_clear_removes_every_entry_across_methods_and_partitions( + make_store: Callable[[], ResponseCacheStore], +) -> None: + store = make_store() + keys = [ + CacheKey("tools/list", "", "partition-1"), + CacheKey("prompts/list", "", "partition-2"), + CacheKey("resources/read", "file:///a", "partition-1"), + ] + for key in keys: + await store.set(key, _entry()) + await store.clear() + for key in keys: + assert await store.get(key) is None + + +# --- CacheConfig guards --- + + +def test_cache_config_defaults_construct_an_unshared_zero_ttl_config() -> None: + config = CacheConfig() + assert config.store is None + assert config.partition == "" + assert config.target_id is None + assert config.default_ttl_ms == 0 + assert config.clock is time.time + assert config.share_public is False + + +def test_a_custom_store_without_a_partition_is_rejected_at_construction() -> None: + """A custom store is shareable, so a missing partition would let private entries cross principals.""" + with pytest.raises(ValueError) as exc: + CacheConfig(store=InMemoryResponseCacheStore()) + assert str(exc.value) == snapshot("a custom store requires an explicit partition") + + +def test_a_custom_store_with_an_explicit_partition_constructs() -> None: + store = InMemoryResponseCacheStore() + config = CacheConfig(store=store, partition="token-subject-1") + assert config.store is store + assert config.partition == "token-subject-1" + + +def test_an_empty_target_id_is_rejected_at_construction() -> None: + """An empty target_id would collapse distinct servers onto the one shared sha256("") identity.""" + with pytest.raises(ValueError) as exc: + CacheConfig(target_id="") + assert str(exc.value) == snapshot("target_id must be a non-empty string or omitted") + + +def test_a_negative_default_ttl_is_rejected_at_construction() -> None: + """A configured negative TTL is a programming error; negative wire ttlMs is tolerated as 0 at the parse seam.""" + with pytest.raises(ValueError) as exc: + CacheConfig(default_ttl_ms=-1) + assert str(exc.value) == snapshot("default_ttl_ms must be >= 0, got -1") + + +# --- InMemoryResponseCacheStore LRU cap --- + + +async def test_a_new_entry_past_the_cap_evicts_the_least_recently_used_one() -> None: + store = InMemoryResponseCacheStore(max_entries=2) + await store.set(_read_key("file:///a"), _entry("a")) + await store.set(_read_key("file:///b"), _entry("b")) + await store.set(_read_key("file:///c"), _entry("c")) + assert await store.get(_read_key("file:///a")) is None + assert await store.get(_read_key("file:///b")) == _entry("b") + assert await store.get(_read_key("file:///c")) == _entry("c") + + +async def test_a_get_refreshes_an_entrys_recency() -> None: + """Eviction order is recency (LRU), not insertion order: serving an entry keeps it alive.""" + store = InMemoryResponseCacheStore(max_entries=2) + await store.set(_read_key("file:///a"), _entry("a")) + await store.set(_read_key("file:///b"), _entry("b")) + assert await store.get(_read_key("file:///a")) == _entry("a") # a is now the most recent + await store.set(_read_key("file:///c"), _entry("c")) # evicts b, not a + assert await store.get(_read_key("file:///a")) == _entry("a") + assert await store.get(_read_key("file:///b")) is None + assert await store.get(_read_key("file:///c")) == _entry("c") + + +async def test_replacing_an_entry_at_the_cap_refreshes_its_recency_without_evicting() -> None: + store = InMemoryResponseCacheStore(max_entries=2) + await store.set(_read_key("file:///a"), _entry("a")) + await store.set(_read_key("file:///b"), _entry("b")) + await store.set(_read_key("file:///a"), _entry("a-replaced")) # still two entries; a is now the most recent + await store.set(_read_key("file:///c"), _entry("c")) # evicts b + assert await store.get(_read_key("file:///a")) == _entry("a-replaced") + assert await store.get(_read_key("file:///b")) is None + assert await store.get(_read_key("file:///c")) == _entry("c") + + +async def test_a_touched_list_entry_survives_read_key_churn_through_the_cap() -> None: + """The reason the cap is LRU over all entries: a hot list singleton each principal + keeps re-reading must survive churn from per-uri resources/read keys.""" + store = InMemoryResponseCacheStore(max_entries=3) + await store.set(CacheKey("tools/list"), _entry("tools")) + for i in range(10): + assert await store.get(CacheKey("tools/list")) == _entry("tools") # each serve re-touches it + await store.set(_read_key(f"file:///{i}"), _entry(i)) + assert await store.get(CacheKey("tools/list")) == _entry("tools") + + +async def test_a_zero_cap_disables_eviction() -> None: + store = InMemoryResponseCacheStore(max_entries=0) + uris = [f"file:///{i}" for i in range(5)] + for uri in uris: + await store.set(_read_key(uri), _entry(uri)) + for uri in uris: + assert await store.get(_read_key(uri)) == _entry(uri) + + +async def test_deleting_an_entry_frees_its_cap_slot() -> None: + store = InMemoryResponseCacheStore(max_entries=1) + await store.set(_read_key("file:///a"), _entry("a")) + await store.delete(_read_key("file:///a")) + await store.set(_read_key("file:///b"), _entry("b")) + assert await store.get(_read_key("file:///b")) == _entry("b") + + +def test_a_negative_cap_is_rejected_at_construction() -> None: + with pytest.raises(ValueError) as exc: + InMemoryResponseCacheStore(max_entries=-1) + assert str(exc.value) == snapshot("max_entries must be >= 0, got -1") + + +# --- ClientResponseCache coordinator --- + +MODERN_VERSION = "2026-07-28" +LEGACY_VERSION = "2025-11-25" + + +class _ManualClock: + """Injected wall clock: tests advance `now` instead of sleeping.""" + + def __init__(self) -> None: + self.now = 1_000_000.0 + + def __call__(self) -> float: + return self.now + + +def _coordinator( + store: ResponseCacheStore, + *, + partition: str = "", + arm_id: str = "arm", + default_ttl_ms: int = 0, + clock: _ManualClock | None = None, + share_public: bool = False, + version: str | None = MODERN_VERSION, + generation_map_cap: int = 4096, + store_cleanup_timeout: float = 5, +) -> ClientResponseCache: + return ClientResponseCache( + store=store, + partition=partition, + arm_id=arm_id, + default_ttl_ms=default_ttl_ms, + clock=clock or _ManualClock(), + share_public=share_public, + negotiated_version=lambda: version, + generation_map_cap=generation_map_cap, + store_cleanup_timeout=store_cleanup_timeout, + ) + + +def _private_arm(arm_id: str = "arm", partition: str = "", era: str | None = MODERN_VERSION) -> str: + return json.dumps(["private", era, arm_id, partition]) + + +def _public_arm(arm_id: str = "arm", partition: str = "", era: str | None = MODERN_VERSION) -> str: + return json.dumps(["public", era, arm_id, partition]) + + +def _wire_result(ttl_ms: int | None = None, cache_scope: str | None = None) -> ListToolsResult: + """A wire-parsed `tools/list` result; `None` keeps the hint out of `model_fields_set`.""" + payload: dict[str, Any] = {"tools": []} + if ttl_ms is not None: + payload["ttlMs"] = ttl_ms + if cache_scope is not None: + payload["cacheScope"] = cache_scope + return ListToolsResult.model_validate(payload) + + +def _read_result(ttl_ms: int) -> ReadResourceResult: + return ReadResourceResult.model_validate({"contents": [], "ttlMs": ttl_ms}) + + +class _ScriptedStore: + """Logs `(op, key)` and awaits one-shot hooks around commits, modelling an async store mid-commit.""" + + def __init__(self) -> None: + self.inner = InMemoryResponseCacheStore() + self.ops: list[tuple[str, CacheKey]] = [] + self.before_set_commits: Callable[[], Awaitable[None]] | None = None + self.after_set_commits: Callable[[], Awaitable[None]] | None = None + self.after_delete_commits: Callable[[], Awaitable[None]] | None = None + + async def get(self, key: CacheKey) -> CacheEntry | None: + self.ops.append(("get", key)) + return await self.inner.get(key) + + async def set(self, key: CacheKey, entry: CacheEntry) -> None: + self.ops.append(("set", key)) + if self.before_set_commits is not None: + hook, self.before_set_commits = self.before_set_commits, None + await hook() + await self.inner.set(key, entry) + if self.after_set_commits is not None: + hook, self.after_set_commits = self.after_set_commits, None + await hook() + + async def delete(self, key: CacheKey) -> None: + self.ops.append(("delete", key)) + await self.inner.delete(key) + if self.after_delete_commits is not None: + hook, self.after_delete_commits = self.after_delete_commits, None + await hook() + + async def clear(self) -> None: + raise NotImplementedError + + +class _FailingStore: + """Operations raise while their flag is set; toggling a flag models recovery.""" + + def __init__(self, *, fail_get: bool = False, fail_set: bool = False, fail_delete: bool = False) -> None: + self.inner = InMemoryResponseCacheStore() + self.fail_get = fail_get + self.fail_set = fail_set + self.fail_delete = fail_delete + + async def get(self, key: CacheKey) -> CacheEntry | None: + if self.fail_get: + raise RuntimeError("store get failed") + return await self.inner.get(key) + + async def set(self, key: CacheKey, entry: CacheEntry) -> None: + if self.fail_set: + raise RuntimeError("store set failed") + await self.inner.set(key, entry) + + async def delete(self, key: CacheKey) -> None: + if self.fail_delete: + raise RuntimeError("store delete failed") + await self.inner.delete(key) + + async def clear(self) -> None: + raise NotImplementedError + + +class _ArmDeleteFailingStore: + """`delete` raises only for keys on the given arm, modelling a failed opposite-arm cleanup.""" + + def __init__(self, failing_arm: str) -> None: + self.inner = InMemoryResponseCacheStore() + self.failing_arm = failing_arm + + async def get(self, key: CacheKey) -> CacheEntry | None: + return await self.inner.get(key) + + async def set(self, key: CacheKey, entry: CacheEntry) -> None: + raise NotImplementedError + + async def delete(self, key: CacheKey) -> None: + if key.partition == self.failing_arm: + raise RuntimeError("store delete failed") + await self.inner.delete(key) + + async def clear(self) -> None: + raise NotImplementedError + + +# The lax pragmas here and in the wedged-store tests: 3.11's settrace-based coverage loses +# tracing in frames resumed after the coordinator's bounded-shield cleanup cancellation. +class _WedgingDeleteStore: + """Once `wedged` flips, every `delete` blocks forever (an Event nothing sets), + modelling a remote store with no socket timeout of its own.""" + + before_set_commits: Callable[[], Awaitable[None]] + """Awaited before `set` commits; assigned by the one test whose write reaches `set`.""" + + def __init__(self, *, wedged: bool = False) -> None: + self.inner = InMemoryResponseCacheStore() + self.wedged = wedged + self.deletes_started = 0 + + async def get(self, key: CacheKey) -> CacheEntry | None: + raise NotImplementedError + + async def set(self, key: CacheKey, entry: CacheEntry) -> None: + await self.before_set_commits() + await self.inner.set(key, entry) # pragma: lax no cover + + async def delete(self, key: CacheKey) -> None: + self.deletes_started += 1 + if self.wedged: + await anyio.Event().wait() + await self.inner.delete(key) + + async def clear(self) -> None: + raise NotImplementedError + + +class _RehydratingStore: + """`get` returns whatever a persistent store's deserializer produced - not necessarily what `set` received.""" + + def __init__(self, rehydrated: Any) -> None: + self.rehydrated = rehydrated + + async def get(self, key: CacheKey) -> CacheEntry | None: + return self.rehydrated + + async def set(self, key: CacheKey, entry: CacheEntry) -> None: + raise NotImplementedError + + async def delete(self, key: CacheKey) -> None: + raise NotImplementedError + + async def clear(self) -> None: + raise NotImplementedError + + +# --- Coordinator: era gate --- + + +@pytest.mark.parametrize("version", [LEGACY_VERSION, None], ids=["legacy", "pre-negotiation"]) +async def test_hints_from_a_non_modern_session_are_ignored(version: str | None) -> None: + """The hints are 2026-07-28 assertions a legacy peer can still inject onto the wire (unknown keys + reach `model_fields_set`), so on a non-modern session every result is treated as hint-absent.""" + store = InMemoryResponseCacheStore() + cache = _coordinator(store, version=version) + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000, cache_scope="public"), gen, "use") + assert await cache.read("tools/list", "") is None + assert await store.get(CacheKey("tools/list", "", _private_arm(era=version))) is None + assert await store.get(CacheKey("tools/list", "", _public_arm(era=version))) is None + + +async def test_a_legacy_session_with_a_default_ttl_caches_on_the_private_arm_only() -> None: + """The operator's default TTL still applies on legacy sessions; injected hints cannot promote or re-clock.""" + store = InMemoryResponseCacheStore() + clock = _ManualClock() + cache = _coordinator(store, version=LEGACY_VERSION, default_ttl_ms=60_000, clock=clock) + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=5, cache_scope="public"), gen, "use") + private_entry = await store.get(CacheKey("tools/list", "", _private_arm(era=LEGACY_VERSION))) + assert private_entry is not None + assert private_entry.scope == "private" + assert await store.get(CacheKey("tools/list", "", _public_arm(era=LEGACY_VERSION))) is None + clock.now += 1.0 # well past the injected 5ms; the default 60s governs + assert await cache.read("tools/list", "") == _wire_result(ttl_ms=5, cache_scope="public") + + +async def test_entries_never_cross_negotiated_eras_on_a_shared_store() -> None: + """Arms fold in the negotiated version: the same listing genuinely differs by era + (the SDK strips the 2026 fields for legacy sessions), so a 2025-negotiated session + is never served an entry a 2026 session wrote - on either arm - nor vice versa.""" + store = InMemoryResponseCacheStore() + modern = _coordinator(store, partition="p", default_ttl_ms=60_000) + legacy = _coordinator(store, partition="p", version=LEGACY_VERSION, default_ttl_ms=60_000) + + gen = modern.capture("tools/list", "") + await modern.write("tools/list", "", _wire_result(ttl_ms=60_000, cache_scope="public"), gen, "use") # public arm + private_result = ListPromptsResult.model_validate({"prompts": [], "ttlMs": 60_000}) + gen = modern.capture("prompts/list", "") + await modern.write("prompts/list", "", private_result, gen, "use") # private arm + assert await legacy.read("tools/list", "") is None + assert await legacy.read("prompts/list", "") is None + + gen = legacy.capture("resources/read", "file:///a") + await legacy.write("resources/read", "file:///a", _read_result(ttl_ms=60_000), gen, "use") + assert await legacy.read("resources/read", "file:///a") is not None # cached for legacy itself... + assert await modern.read("resources/read", "file:///a") is None # ...but invisible across the era boundary + + +async def test_coordinators_negotiating_the_same_era_share_entries_through_the_store() -> None: + """Era scoping splits eras only: same-era clients sharing a store still share both arms.""" + store = InMemoryResponseCacheStore() + writer = _coordinator(store, partition="p") + reader = _coordinator(store, partition="p") + + gen = writer.capture("tools/list", "") + await writer.write("tools/list", "", _wire_result(ttl_ms=60_000, cache_scope="public"), gen, "use") + private_result = ListPromptsResult.model_validate({"prompts": [], "ttlMs": 60_000}) + gen = writer.capture("prompts/list", "") + await writer.write("prompts/list", "", private_result, gen, "use") + + assert await reader.read("tools/list", "") == _wire_result(ttl_ms=60_000, cache_scope="public") + assert await reader.read("prompts/list", "") == private_result + + +# --- Coordinator: TTL and scope resolution --- + + +async def test_an_explicit_zero_ttl_is_not_overridden_by_the_default_ttl() -> None: + """Spec-mandated: ttlMs 0 means immediately stale; the default fills in only for hint-absent results.""" + store = InMemoryResponseCacheStore() + cache = _coordinator(store, default_ttl_ms=60_000) + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=0), gen, "use") + assert await store.get(CacheKey("tools/list", "", _private_arm())) is None + assert await store.get(CacheKey("tools/list", "", _public_arm())) is None + + +async def test_a_hint_absent_modern_result_uses_the_default_ttl_privately() -> None: + store = InMemoryResponseCacheStore() + clock = _ManualClock() + cache = _coordinator(store, default_ttl_ms=60_000, clock=clock) + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(), gen, "use") + entry = await store.get(CacheKey("tools/list", "", _private_arm())) + assert entry is not None + assert entry.scope == "private" + assert entry.expires_at == clock.now + 60.0 + assert await cache.read("tools/list", "") == _wire_result() + clock.now += 60.0 + assert await cache.read("tools/list", "") is None + + +async def test_a_ttl_above_24_hours_is_clamped_to_the_cap() -> None: + """SEP-2549 hardening: a server cannot pin an entry beyond `MAX_TTL_MS`.""" + store = InMemoryResponseCacheStore() + clock = _ManualClock() + cache = _coordinator(store, clock=clock) + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=7 * MAX_TTL_MS), gen, "use") + entry = await store.get(CacheKey("tools/list", "", _private_arm())) + assert entry is not None + assert entry.expires_at == clock.now + MAX_TTL_MS / 1000 + + +async def test_a_public_result_lands_on_the_public_arm_and_clears_the_private_arm() -> None: + """On a scope flip, writing the new arm deletes the other so the two arms never both answer.""" + store = InMemoryResponseCacheStore() + cache = _coordinator(store) + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") + assert await store.get(CacheKey("tools/list", "", _private_arm())) is not None + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000, cache_scope="public"), gen, "use") + public_entry = await store.get(CacheKey("tools/list", "", _public_arm())) + assert public_entry is not None + assert public_entry.scope == "public" + assert await store.get(CacheKey("tools/list", "", _private_arm())) is None + + +# --- Coordinator: partition arms and the scope guard --- + + +async def test_arm_key_layout_is_pinned_for_shared_store_compatibility() -> None: + """Arm strings are cross-process store key material; changing their layout breaks shared stores.""" + store = InMemoryResponseCacheStore() + cache = _coordinator(store, partition="tenant-a", arm_id="abc123", default_ttl_ms=60_000) + assert cache._arm("private") == snapshot('["private", "2026-07-28", "abc123", "tenant-a"]') + assert cache._arm("public") == snapshot('["public", "2026-07-28", "abc123", "tenant-a"]') + shared = _coordinator(store, partition="tenant-a", arm_id="abc123", share_public=True) + assert shared._arm("public") == snapshot('["public", "2026-07-28", "abc123"]') + # And entries genuinely land under those strings. + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(), gen, "use") + assert await store.get(CacheKey("tools/list", "", '["private", "2026-07-28", "abc123", "tenant-a"]')) is not None + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000, cache_scope="public"), gen, "use") + assert await store.get(CacheKey("tools/list", "", '["public", "2026-07-28", "abc123", "tenant-a"]')) is not None + gen = shared.capture("tools/list", "") + await shared.write("tools/list", "", _wire_result(ttl_ms=60_000, cache_scope="public"), gen, "use") + assert await store.get(CacheKey("tools/list", "", '["public", "2026-07-28", "abc123"]')) is not None + + +async def test_public_entries_do_not_cross_partitions_by_default() -> None: + """Security default (deviates from the TypeScript SDK): a server stamping per-tenant data public + (bug or malice) cannot leak one tenant's response to another through a shared store.""" + store = InMemoryResponseCacheStore() + tenant_a = _coordinator(store, partition="tenant-a") + tenant_b = _coordinator(store, partition="tenant-b") + gen = tenant_a.capture("tools/list", "") + await tenant_a.write("tools/list", "", _wire_result(ttl_ms=60_000, cache_scope="public"), gen, "use") + assert await tenant_a.read("tools/list", "") == _wire_result(ttl_ms=60_000, cache_scope="public") + assert await tenant_b.read("tools/list", "") is None + + +async def test_share_public_serves_public_entries_across_partitions_but_never_private_ones() -> None: + store = InMemoryResponseCacheStore() + tenant_a = _coordinator(store, partition="tenant-a", share_public=True) + tenant_b = _coordinator(store, partition="tenant-b", share_public=True) + gen = tenant_a.capture("tools/list", "") + await tenant_a.write("tools/list", "", _wire_result(ttl_ms=60_000, cache_scope="public"), gen, "use") + assert await tenant_b.read("tools/list", "") == _wire_result(ttl_ms=60_000, cache_scope="public") + private_result = ListPromptsResult.model_validate({"prompts": [], "ttlMs": 60_000}) + gen = tenant_a.capture("prompts/list", "") + await tenant_a.write("prompts/list", "", private_result, gen, "use") + assert await tenant_b.read("prompts/list", "") is None + + +async def test_a_private_scoped_entry_under_the_public_arm_is_not_served() -> None: + """Defense in depth against a corrupted or pre-seeded store: the arm routes, the entry's scope verifies.""" + store = InMemoryResponseCacheStore() + cache = _coordinator(store) + await store.set( + CacheKey("tools/list", "", _public_arm()), + CacheEntry(value=_wire_result(), scope="private", expires_at=2_000_000.0), + ) + assert await cache.read("tools/list", "") is None + + +async def test_a_stale_private_entry_does_not_shadow_a_fresh_public_one() -> None: + """A stale private entry is an arm-probe miss, so the fall-through finds a public entry seeded by + another client after a server scope flip.""" + store = InMemoryResponseCacheStore() + clock = _ManualClock() + cache = _coordinator(store, clock=clock) + await store.set( + CacheKey("tools/list", "", _private_arm()), + CacheEntry(value=_wire_result(), scope="private", expires_at=clock.now - 1.0), + ) + public_result = _wire_result(ttl_ms=60_000, cache_scope="public") + await store.set( + CacheKey("tools/list", "", _public_arm()), + CacheEntry(value=public_result, scope="public", expires_at=clock.now + 60.0), + ) + assert await cache.read("tools/list", "") == public_result + + +async def test_an_entry_without_an_expiry_is_never_fresh() -> None: + """Entries rehydrated without expiry metadata are misses, not immortal.""" + store = InMemoryResponseCacheStore() + cache = _coordinator(store) + await store.set( + CacheKey("tools/list", "", _private_arm()), + CacheEntry(value=_wire_result(), scope="private", expires_at=None), + ) + assert await cache.read("tools/list", "") is None + + +# --- Coordinator: write ordering --- + + +async def test_write_deletes_the_opposite_arm_before_setting_its_own() -> None: + """Delete-then-set: a cancellation between the two operations leaves a miss, never two answering arms.""" + store = _ScriptedStore() + cache = _coordinator(store) + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000, cache_scope="public"), gen, "use") + assert store.ops == [ + ("delete", CacheKey("tools/list", "", _private_arm())), + ("set", CacheKey("tools/list", "", _public_arm())), + ] + + +async def test_an_eviction_landing_during_an_async_set_is_compensated() -> None: + """TOCTOU re-check: the eviction's deletes see nothing (the set has not committed yet), so the + post-set generation re-check must fire a compensating delete.""" + store = _ScriptedStore() + cache = _coordinator(store) + gen = cache.capture("tools/list", "") + + async def evict_mid_commit() -> None: + await cache.evict_method("tools/list") + + store.before_set_commits = evict_mid_commit + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") + private_key = CacheKey("tools/list", "", _private_arm()) + public_key = CacheKey("tools/list", "", _public_arm()) + assert store.ops == [ + ("delete", public_key), # write: opposite arm first + ("set", private_key), # write: own arm, commit still pending + ("delete", private_key), # eviction (sees nothing - not committed yet) + ("delete", public_key), # eviction + ("delete", private_key), # post-set re-check compensation + ] + assert await store.inner.get(private_key) is None + assert await cache.read("tools/list", "") is None + + +async def test_a_cancellation_landing_as_the_set_commits_still_compensates_an_eviction() -> None: + """The compensating delete is shielded: a timeout firing while the store's set is already on the + wire must not resurrect the evicted entry for its full TTL.""" + store = _ScriptedStore() + cache = _coordinator(store) + gen = cache.capture("tools/list", "") + private_key = CacheKey("tools/list", "", _private_arm()) + public_key = CacheKey("tools/list", "", _public_arm()) + with anyio.CancelScope() as scope: + + async def evict_then_cancel() -> None: + await cache.evict_method("tools/list") + scope.cancel() + + store.before_set_commits = evict_then_cancel + store.after_set_commits = anyio.lowlevel.checkpoint # first checkpoint after the commit + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") + assert scope.cancelled_caught + assert store.ops == [ + ("delete", public_key), # write: opposite arm first + ("set", private_key), # write: own arm, commit still pending + ("delete", private_key), # eviction (sees nothing - not committed yet) + ("delete", public_key), # eviction + ("delete", private_key), # post-set re-check compensation, shielded + ] + assert await store.inner.get(private_key) is None + + +async def test_a_cancellation_during_the_refresh_purge_still_purges_both_arms() -> None: + """The refresh purge is shielded - a mid-purge cancellation must not leave the superseded opposite arm.""" + store = _ScriptedStore() + cache = _coordinator(store) + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000, cache_scope="public"), gen, "use") + public_key = CacheKey("tools/list", "", _public_arm()) + assert await store.inner.get(public_key) is not None + with anyio.CancelScope() as scope: + scope.cancel() + # Delivers at the first checkpoint after the private-arm delete commits. + store.after_delete_commits = anyio.lowlevel.checkpoint + await cache.write("tools/list", "", _wire_result(ttl_ms=0), gen, "refresh") + assert await store.inner.get(public_key) is None + + +async def test_a_cancellation_during_an_eviction_still_evicts_both_arms() -> None: + """Eviction's arm deletes are shielded - a notification task cancelled mid-eviction (session + teardown) must not leave one arm serving the evicted entry.""" + store = _ScriptedStore() + cache = _coordinator(store) + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000, cache_scope="public"), gen, "use") + public_key = CacheKey("tools/list", "", _public_arm()) + with anyio.CancelScope() as scope: + scope.cancel() + # Delivers at the first checkpoint after the private-arm delete commits. + store.after_delete_commits = anyio.lowlevel.checkpoint + await cache.evict_method("tools/list") + assert await store.inner.get(public_key) is None + + +# --- Coordinator: bounded must-complete cleanup --- +# These tests inject a tiny `store_cleanup_timeout` because the bound itself is the +# behavior under test; the wedged delete only ever blocks for that injected bound. + + +async def test_evict_key_with_a_wedged_store_delete_returns_at_the_cleanup_bound( + caplog: pytest.LogCaptureFixture, +) -> None: + """A store delete that never completes cannot make eviction - and with it client + teardown - hang uncancellably: the must-complete cleanup is bounded, the remaining + deletes are abandoned, and the unreaped entries age out by TTL.""" + store = _WedgingDeleteStore(wedged=True) + cache = _coordinator(store, store_cleanup_timeout=0.01) + with caplog.at_level(logging.WARNING, logger="mcp.client.caching"), anyio.fail_after(5): + await cache.evict_key("tools/list", "") + assert store.deletes_started == 1 # pragma: lax no cover # the second arm's delete was abandoned with the first + assert caplog.messages == snapshot( # pragma: lax no cover + ["Response cache store delete timed out; the entry will age out by TTL"] + ) + + +async def test_a_refresh_purge_with_a_wedged_store_delete_returns_at_the_cleanup_bound() -> None: + store = _WedgingDeleteStore(wedged=True) + cache = _coordinator(store, store_cleanup_timeout=0.01) + gen = cache.capture("tools/list", "") + with anyio.fail_after(5): + await cache.write("tools/list", "", _wire_result(ttl_ms=0), gen, "refresh") + assert store.deletes_started == 1 # pragma: lax no cover + + +async def test_an_eviction_mid_set_with_a_wedged_store_delete_returns_at_the_cleanup_bound() -> None: + """The post-set compensating delete is bounded like every other must-complete delete; + the entry it could not reap stays in the store and ages out by TTL.""" + store = _WedgingDeleteStore() + cache = _coordinator(store, store_cleanup_timeout=0.01) + gen = cache.capture("tools/list", "") + + async def wedge_then_evict() -> None: + store.wedged = True + await cache.evict_method("tools/list") # its own cleanup hits the bound too + + store.before_set_commits = wedge_then_evict + with anyio.fail_after(5): + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") + # Opposite-arm delete, the eviction's first delete, the compensating delete. + assert store.deletes_started == 3 # pragma: lax no cover + # The accepted degradation: the unreaped entry stays until its TTL expires. + assert await store.inner.get(CacheKey("tools/list", "", _private_arm())) is not None # pragma: lax no cover + + +# --- Coordinator: store error discipline --- + + +async def test_a_raising_store_get_is_a_cache_miss() -> None: + store = _FailingStore(fail_get=True) + cache = _coordinator(store) + assert await cache.read("tools/list", "") is None + + +@pytest.mark.parametrize( + "rehydrated", + [ + CacheEntry(value={"tools": []}, scope="private", expires_at=2_000_000.0), + {"value": {"tools": []}, "scope": "private", "expires_at": 2_000_000.0}, + ], + ids=["dict-value", "dict-entry"], +) +async def test_an_entry_rehydrated_into_the_wrong_shape_is_a_warned_miss( + rehydrated: Any, caplog: pytest.LogCaptureFixture +) -> None: + """A persistent store has no method-to-model mapping, so its `get` may return serialized shapes; + the warned miss is one burst, not one warning per cached read.""" + cache = _coordinator(_RehydratingStore(rehydrated)) + with caplog.at_level(logging.WARNING, logger="mcp.client.caching"): + assert await cache.read("tools/list", "") is None + assert await cache.read("tools/list", "") is None + assert len(caplog.records) == 1 + + +async def test_a_raising_opposite_arm_delete_aborts_the_write() -> None: + """Setting after a failed opposite-arm delete could leave both arms populated.""" + store = _FailingStore(fail_delete=True) + cache = _coordinator(store) + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") + assert await store.inner.get(CacheKey("tools/list", "", _private_arm())) is None + assert await store.inner.get(CacheKey("tools/list", "", _public_arm())) is None + + +async def test_a_failed_opposite_arm_delete_degrades_the_key_to_a_full_miss() -> None: + """The fetch superseded the warm own-arm entry, so it is best-effort deleted too; the write never raises.""" + store = _ArmDeleteFailingStore(failing_arm=_public_arm()) + cache = _coordinator(store) + await store.inner.set( + CacheKey("tools/list", "", _private_arm()), + CacheEntry(value=_wire_result(), scope="private", expires_at=2_000_000.0), + ) + assert await cache.read("tools/list", "") is not None # the warm own-arm entry + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") + assert await store.inner.get(CacheKey("tools/list", "", _private_arm())) is None + assert await store.inner.get(CacheKey("tools/list", "", _public_arm())) is None + assert await cache.read("tools/list", "") is None + + +async def test_a_raising_store_set_caches_nothing_and_does_not_raise() -> None: + store = _FailingStore(fail_set=True) + cache = _coordinator(store) + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") + assert await cache.read("tools/list", "") is None + + +async def test_a_failed_set_purges_the_pre_existing_own_arm_entry() -> None: + """The fetch superseded the warm own-arm entry, and the failed set left it in place: + without the purge it would keep serving the superseded value for its full TTL.""" + store = _FailingStore() + cache = _coordinator(store) + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") + assert await cache.read("tools/list", "") is not None # the warm own-arm entry + store.fail_set = True + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") # the caller's fetch is unaffected + assert await store.inner.get(CacheKey("tools/list", "", _private_arm())) is None + assert await store.inner.get(CacheKey("tools/list", "", _public_arm())) is None + assert await cache.read("tools/list", "") is None + + +async def test_eviction_with_a_raising_delete_still_bumps_the_generation() -> None: + """Bump-first: a fetch captured before the eviction cannot write back even when the deletes raise.""" + store = _FailingStore() + cache = _coordinator(store) + stale_gen = cache.capture("tools/list", "") # fetch in flight when the eviction lands + store.fail_delete = True + await cache.evict_method("tools/list") # deletes raise; the bump already happened + store.fail_delete = False + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), stale_gen, "use") + assert await store.inner.get(CacheKey("tools/list", "", _private_arm())) is None + fresh_gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), fresh_gen, "use") + assert await cache.read("tools/list", "") == _wire_result(ttl_ms=60_000) + + +async def test_store_failures_warn_once_per_burst(caplog: pytest.LogCaptureFixture) -> None: + store = _FailingStore(fail_get=True) + cache = _coordinator(store) + with caplog.at_level(logging.WARNING, logger="mcp.client.caching"): + await cache.read("tools/list", "") # consecutive failing reads, one burst + await cache.read("tools/list", "") + assert len(caplog.records) == 1 + store.fail_get = False + await cache.read("tools/list", "") # success re-arms the warning + store.fail_get = True + await cache.read("tools/list", "") + assert len(caplog.records) == 2 + assert caplog.messages[0] == snapshot("Response cache store operation failed; continuing without the cache") + + +async def test_a_set_only_store_failure_warns_once_across_write_cycles(caplog: pytest.LogCaptureFixture) -> None: + """Bursts are tracked per operation kind - the healthy deletes between failing sets never re-arm.""" + store = _FailingStore(fail_set=True) + cache = _coordinator(store) + with caplog.at_level(logging.WARNING, logger="mcp.client.caching"): + for _ in range(3): # each cycle: opposite-arm delete succeeds, then the set fails + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") + assert len(caplog.records) == 1 + store.fail_set = False + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") # set succeeds, re-arms + store.fail_set = True + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") + assert len(caplog.records) == 2 + + +# --- Coordinator: generation discipline --- + + +async def test_an_eviction_between_capture_and_write_discards_the_write() -> None: + """Spec-aligned: a fetch in flight when its key is evicted must not write the evicted entry back.""" + store = InMemoryResponseCacheStore() + cache = _coordinator(store) + gen = cache.capture("tools/list", "") + await cache.evict_method("tools/list") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") + assert await store.get(CacheKey("tools/list", "", _private_arm())) is None + assert await store.get(CacheKey("tools/list", "", _public_arm())) is None + + +async def test_recapturing_a_registered_key_returns_its_current_generation() -> None: + store = InMemoryResponseCacheStore() + cache = _coordinator(store) + gen_before = cache.capture("tools/list", "") + await cache.evict_method("tools/list") + gen_after = cache.capture("tools/list", "") + assert gen_after != gen_before + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen_after, "use") + assert await cache.read("tools/list", "") == _wire_result(ttl_ms=60_000) + + +async def test_the_generation_map_drops_the_oldest_key_at_its_cap() -> None: + """A dropped key's race guard degrades to the accepted co-tenant class - an eviction racing its + in-flight fetch goes undetected (cap is 4096 in production, parametrized small here).""" + store = InMemoryResponseCacheStore() + cache = _coordinator(store, generation_map_cap=2) + gen_a = cache.capture("resources/read", "file:///a") + gen_b = cache.capture("resources/read", "file:///b") + cache.capture("resources/read", "file:///c") # at the cap: drops file:///a + await cache.evict_key("resources/read", "file:///a") # unregistered: no bump + await cache.evict_key("resources/read", "file:///b") # registered: bump + await cache.write("resources/read", "file:///a", _read_result(ttl_ms=60_000), gen_a, "use") + await cache.write("resources/read", "file:///b", _read_result(ttl_ms=60_000), gen_b, "use") + assert await cache.read("resources/read", "file:///a") is not None # degraded guard fails open + assert await cache.read("resources/read", "file:///b") is None # guard held + + +# --- Coordinator: eviction --- + + +async def test_a_refresh_resolving_uncacheable_purges_the_warm_entry() -> None: + """The refetch superseded the warm entry, which must not be served again.""" + store = InMemoryResponseCacheStore() + cache = _coordinator(store) + gen = cache.capture("tools/list", "") + await cache.write("tools/list", "", _wire_result(ttl_ms=60_000), gen, "use") + assert await cache.read("tools/list", "") is not None + await cache.write("tools/list", "", _wire_result(ttl_ms=0), gen, "refresh") + assert await store.get(CacheKey("tools/list", "", _private_arm())) is None + assert await store.get(CacheKey("tools/list", "", _public_arm())) is None + + +async def test_evict_key_on_an_unregistered_key_still_deletes_both_arms() -> None: + """A persistent store may hold warm entries from a prior process this coordinator never captured.""" + store = InMemoryResponseCacheStore() + await store.set( + CacheKey("resources/read", "file:///warm", _private_arm()), + CacheEntry(value=_read_result(ttl_ms=60_000), scope="private", expires_at=2_000_000.0), + ) + await store.set( + CacheKey("resources/read", "file:///warm", _public_arm()), + CacheEntry(value=_read_result(ttl_ms=60_000), scope="public", expires_at=2_000_000.0), + ) + cache = _coordinator(store) + await cache.evict_key("resources/read", "file:///warm") + assert await store.get(CacheKey("resources/read", "file:///warm", _private_arm())) is None + assert await store.get(CacheKey("resources/read", "file:///warm", _public_arm())) is None + + +@pytest.mark.parametrize( + ("notification", "evicted"), + [ + (ToolListChangedNotification(), {("tools/list", "")}), + (PromptListChangedNotification(), {("prompts/list", "")}), + (ResourceListChangedNotification(), {("resources/list", ""), ("resources/templates/list", "")}), + ( + ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///a")), + {("resources/read", "file:///a")}, + ), + ( + LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="x")), + set[tuple[str, str]](), + ), + ], + ids=["tools-list-changed", "prompts-list-changed", "resources-list-changed", "resource-updated", "unrelated"], +) +async def test_notifications_evict_exactly_their_mapped_entries( + notification: ServerNotification, evicted: set[tuple[str, str]] +) -> None: + """Spec SHOULD: notifications invalidate - and nothing beyond their mapped entries.""" + store = InMemoryResponseCacheStore() + cache = _coordinator(store) + seeded = [ + ("tools/list", ""), + ("prompts/list", ""), + ("resources/list", ""), + ("resources/templates/list", ""), + ("resources/read", "file:///a"), + ("resources/read", "file:///b"), + ] + for method, params_key in seeded: + # The value's content is irrelevant to eviction; any cacheable model serves. + await store.set( + CacheKey(method, params_key, _private_arm()), + CacheEntry(value=_wire_result(), scope="private", expires_at=2_000_000.0), + ) + await cache.evict_for_notification(notification) + for method, params_key in seeded: + if (method, params_key) in evicted: + assert await cache.read(method, params_key) is None + else: + assert await cache.read(method, params_key) is not None diff --git a/tests/client/test_client.py b/tests/client/test_client.py index a6a9ac6ea..820478f3f 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -506,6 +506,100 @@ async def on_list_tools( assert [t.name for t in result.tools] == ["ok", "dropme"] +_RETIRED_TOOL = Tool( + name="retired", + input_schema={"type": "object", "properties": {"region": {"type": "string", "x-mcp-header": "Region"}}}, + output_schema={"type": "object"}, +) +_SURVIVOR_TOOL = Tool(name="survivor", input_schema={"type": "object"}) + + +def _scripted_listing_server(listings: list[ListToolsResult]) -> Server: + """Serves the given listings in order, one per tools/list request.""" + + async def on_list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return listings.pop(0) + + return Server("test", on_list_tools=on_list_tools) + + +async def test_a_complete_listing_prunes_per_tool_state_for_tools_it_no_longer_contains() -> None: + """SDK-defined: a complete (uncursored, cursorless) listing is the full tool universe, so the + header map and output schema derived from an earlier listing of a now-absent tool are dropped.""" + server = _scripted_listing_server( + [ + ListToolsResult(tools=[_RETIRED_TOOL, _SURVIVOR_TOOL]), + ListToolsResult(tools=[_SURVIVOR_TOOL]), + ] + ) + + with anyio.fail_after(5): + async with Client(server) as client: + await client.session.list_tools() + assert set(client.session._x_mcp_header_maps) == {"retired", "survivor"} + assert set(client.session._tool_output_schemas) == {"retired", "survivor"} + + await client.session.list_tools() + assert set(client.session._x_mcp_header_maps) == {"survivor"} + assert set(client.session._tool_output_schemas) == {"survivor"} + + +async def test_a_complete_listing_prunes_output_schemas_on_a_legacy_session_too() -> None: + """SDK-defined: the prune is era-independent -- legacy sessions cache output schemas the same + way (their header-map dict just stays empty, since the x-mcp-header filter is 2026-only).""" + server = _scripted_listing_server( + [ + ListToolsResult(tools=[_RETIRED_TOOL, _SURVIVOR_TOOL]), + ListToolsResult(tools=[_SURVIVOR_TOOL]), + ] + ) + + with anyio.fail_after(5): + async with Client(server, mode="legacy") as client: + await client.session.list_tools() + assert set(client.session._tool_output_schemas) == {"retired", "survivor"} + assert client.session._x_mcp_header_maps == {} + + await client.session.list_tools() + assert set(client.session._tool_output_schemas) == {"survivor"} + + +async def test_a_listing_with_a_next_cursor_prunes_no_per_tool_state() -> None: + """SDK-defined: a first page carrying next_cursor is not the full universe -- state for tools + expected on later pages must survive it.""" + server = _scripted_listing_server( + [ + ListToolsResult(tools=[_RETIRED_TOOL, _SURVIVOR_TOOL]), + ListToolsResult(tools=[_SURVIVOR_TOOL], next_cursor="2"), + ] + ) + + with anyio.fail_after(5): + async with Client(server) as client: + await client.session.list_tools() + await client.session.list_tools() + assert set(client.session._x_mcp_header_maps) == {"retired", "survivor"} + assert set(client.session._tool_output_schemas) == {"retired", "survivor"} + + +async def test_a_cursor_page_fetch_prunes_no_per_tool_state() -> None: + """SDK-defined: a continuation page is partial even when it ends the pagination (no + next_cursor) -- only an uncursored single-page listing prunes.""" + server = _scripted_listing_server( + [ + ListToolsResult(tools=[_RETIRED_TOOL, _SURVIVOR_TOOL]), + ListToolsResult(tools=[_SURVIVOR_TOOL]), + ] + ) + + with anyio.fail_after(5): + async with Client(server) as client: + await client.session.list_tools() + await client.session.list_tools(params=types.PaginatedRequestParams(cursor="2")) + assert set(client.session._x_mcp_header_maps) == {"retired", "survivor"} + assert set(client.session._tool_output_schemas) == {"retired", "survivor"} + + def test_client_rejects_handshake_era_mode_at_construction() -> None: """A handshake-era protocol-version string passed as `mode=` is rejected by `__post_init__` with a hint to use `mode='legacy'` — the version-pin path is diff --git a/tests/client/test_client_caching.py b/tests/client/test_client_caching.py new file mode 100644 index 000000000..708d83db4 --- /dev/null +++ b/tests/client/test_client_caching.py @@ -0,0 +1,1579 @@ +"""`Client` wiring for the response cache: the `cache=` kwarg, server identity +resolution, the custom-store guard, notification eviction, and the five cacheable +verbs. The coordinator's own behavior is covered in `test_caching.py`.""" + +import hashlib +import json +import time +from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import asynccontextmanager +from types import TracebackType +from typing import Any, Literal + +import anyio +import anyio.lowlevel +import httpx +import mcp_types as types +import pytest +from inline_snapshot import snapshot +from mcp_types import ( + INTERNAL_ERROR, + INVALID_PARAMS, + CallToolResult, + DiscoverResult, + ElicitRequest, + ElicitRequestFormParams, + ElicitResult, + Implementation, + InputRequiredResult, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ListToolsResult, + ReadResourceResult, + ResourceListChangedNotification, + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, + ServerCapabilities, + ServerNotification, + TextContent, + TextResourceContents, + Tool, + ToolListChangedNotification, +) +from mcp_types.version import LATEST_MODERN_VERSION + +from mcp.client import Client +from mcp.client._transport import TransportStreams +from mcp.client.caching import ( + CacheConfig, + CacheEntry, + CacheKey, + ClientResponseCache, + InMemoryResponseCacheStore, +) +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.server.caching import CacheHint +from mcp.shared.exceptions import MCPError +from mcp.shared.memory import MessageStream, create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from tests.interaction._connect import BASE_URL, mounted_app + +pytestmark = pytest.mark.anyio + +IncomingMessage = RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception + + +def _coordinator(client: Client) -> ClientResponseCache: + cache = client._response_cache + assert cache is not None + return cache + + +def _private_arm(client: Client) -> str: + """The identity arm stamped into store keys; only equality between clients matters here.""" + return _coordinator(client)._arm("private") + + +def _tools_list_key(client: Client) -> CacheKey: + return CacheKey("tools/list", "", _private_arm(client)) + + +class _OpaqueTransport: + """Shape-only `Transport`: identity resolution happens at construction, so tests never enter it.""" + + async def __aenter__(self) -> TransportStreams: + raise NotImplementedError + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> None: + raise NotImplementedError + + +def _list_changed_server() -> Server[Any]: + """Server whose `touch` tool emits tools/list_changed; connect with `mode="legacy"` + because the modern in-process path drops standalone server notifications.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[types.Tool(name="touch", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "touch" + await ctx.session.send_tool_list_changed() + return CallToolResult(content=[TextContent(text="touched")]) + + return Server("notifier", on_list_tools=list_tools, on_call_tool=call_tool) + + +async def _warm_tools_list_entry(client: Client) -> CacheKey: + """Seed a private-arm tools/list entry directly in the store; payload and expiry are inert to eviction.""" + key = _tools_list_key(client) + await _coordinator(client)._store.set(key, CacheEntry(value="warm", scope="private", expires_at=None)) + return key + + +def test_an_explicit_target_id_overrides_both_url_and_in_process_identity() -> None: + by_target_url = Client("https://example.com/mcp", cache=CacheConfig(target_id="svc")) + by_target_inproc = Client(Server("plain"), cache=CacheConfig(target_id="svc")) + by_url = Client("https://example.com/mcp") + + assert _private_arm(by_target_url) == _private_arm(by_target_inproc) + assert _private_arm(by_target_url) != _private_arm(by_url) + + +def test_userinfo_variants_of_a_server_url_share_one_cache_identity() -> None: + """Stripping userinfo is the single permitted URL rewrite.""" + bare = Client("https://example.com/mcp") + with_password = Client("https://user:secret@example.com/mcp") + with_token = Client("https://token@example.com/mcp") + + assert _private_arm(bare) == _private_arm(with_password) == _private_arm(with_token) + + +@pytest.mark.parametrize( + ("with_userinfo", "bare"), + [ + ("HTTPS://a@X.example/mcp", "HTTPS://X.example/mcp"), + ("https://u@h/p?", "https://h/p?"), + ("https://u@h/p#", "https://h/p#"), + ("https://u\tser:p@h.example/p", "https://h.example/p"), + ("https://u:p@h.example/pa\tth", "https://h.example/pa\tth"), + ], + ids=["scheme-case", "empty-query", "empty-fragment", "tab-in-userinfo", "tab-in-path"], +) +def test_stripping_userinfo_changes_no_other_byte_of_the_url(with_userinfo: str, bare: str) -> None: + """The removed `userinfo@` is the only byte difference: no scheme case-folding, no dropped + empty `?`/`#` delimiters, and control characters - which urlsplit would silently strip, + misaligning any parser-derived slice - stay byte-exact outside the removed span. A + userinfo-free URL passes through untouched, so arm equality proves the stripped form is + byte-identical to the bare URL.""" + assert _private_arm(Client(with_userinfo)) == _private_arm(Client(bare)) + + +def test_a_url_without_an_authority_passes_through_unchanged() -> None: + """No `//` means no authority span, so an `@` elsewhere strips nothing.""" + arm_id = hashlib.sha256(b"mailto:a@b").hexdigest() + assert _private_arm(Client("mailto:a@b")) == json.dumps(["private", None, arm_id, ""]) + + +def test_the_server_url_is_sha256_hashed_before_it_enters_key_material() -> None: + """Pins the docs' secrets-never-in-keys claim: a query-string secret never appears in store keys.""" + client = Client("https://user:pass@example.com/mcp?api_key=SECRET") + + arm_id = hashlib.sha256(b"https://example.com/mcp?api_key=SECRET").hexdigest() + # The era slot is None pre-connect; only the identity hash matters here. + assert _private_arm(client) == json.dumps(["private", None, arm_id, ""]) + + +def test_urls_differing_only_in_query_have_distinct_cache_identities() -> None: + """URL identity is byte-exact outside userinfo; over-normalization would merge tenants.""" + tenant_a = Client("https://example.com/mcp?tenant=a") + tenant_b = Client("https://example.com/mcp?tenant=b") + + assert _private_arm(tenant_a) != _private_arm(tenant_b) + + +def test_two_clients_on_one_in_process_server_get_distinct_cache_identities() -> None: + server = Server("plain") + + assert _private_arm(Client(server)) != _private_arm(Client(server)) + + +def test_a_transport_object_gets_a_per_client_cache_identity() -> None: + transport = _OpaqueTransport() + + assert _private_arm(Client(transport)) != _private_arm(Client(transport)) + + +@pytest.mark.parametrize("make_server", [lambda: Server("plain"), _OpaqueTransport], ids=["in-process", "transport"]) +def test_a_custom_store_without_a_url_or_target_id_is_rejected(make_server: Any) -> None: + with pytest.raises(ValueError) as exc_info: + Client(make_server(), cache=CacheConfig(store=InMemoryResponseCacheStore(), partition="p")) + assert str(exc_info.value) == snapshot( + "a custom cache store requires CacheConfig.target_id when the server is not a URL: in-process servers " + "and Transport instances get a random per-client identity, so their entries in a shared store could " + "never be served to another client" + ) + + +def test_a_custom_store_with_a_url_server_constructs_and_is_used() -> None: + store = InMemoryResponseCacheStore() + client = Client("https://example.com/mcp", cache=CacheConfig(store=store, partition="p")) + + assert _coordinator(client)._store is store + + +def test_a_custom_store_with_an_explicit_target_id_constructs_for_any_server() -> None: + store = InMemoryResponseCacheStore() + client = Client(Server("plain"), cache=CacheConfig(store=store, partition="p", target_id="svc")) + + assert _coordinator(client)._store is store + + +async def test_cache_false_disables_the_cache_and_the_handler_wrap() -> None: + async def handler(message: IncomingMessage) -> None: + raise NotImplementedError + + client = Client(_list_changed_server(), cache=False, message_handler=handler) + assert client._response_cache is None + + async with client: + assert client.session._message_handler is handler + + +def test_the_default_cache_uses_a_per_client_in_memory_store() -> None: + """`cache=None` (the default) is cache-on.""" + server = Server("plain") + first = Client(server) + second = Client(server) + + assert isinstance(_coordinator(first)._store, InMemoryResponseCacheStore) + assert _coordinator(first)._store is not _coordinator(second)._store + + +async def test_the_negotiated_version_supplier_tracks_the_session_lifecycle() -> None: + """The era gate must never read a stale or raising source.""" + client = Client(_list_changed_server()) + supplier = _coordinator(client)._negotiated_version + + assert supplier() is None + async with client: + assert supplier() == client.protocol_version + assert supplier() is None + + +async def test_a_list_changed_notification_evicts_without_a_user_handler() -> None: + """Spec SHOULD (notifications invalidate): the entry is deleted from both arms.""" + + class _EventedStore(InMemoryResponseCacheStore): + """Signals once both arms of an eviction have been deleted.""" + + def __init__(self) -> None: + super().__init__() + self._deletes = 0 + self.both_arms_deleted = anyio.Event() + + async def delete(self, key: CacheKey) -> None: + await super().delete(key) + self._deletes += 1 + if self._deletes == 2: + self.both_arms_deleted.set() + + store = _EventedStore() + client = Client( + _list_changed_server(), mode="legacy", cache=CacheConfig(store=store, partition="p", target_id="svc") + ) + + async with client: + key = await _warm_tools_list_entry(client) + await client.call_tool("touch", {}) + with anyio.fail_after(5): + await store.both_arms_deleted.wait() + assert await store.get(key) is None + + +async def test_a_user_handler_receives_the_notification_the_eviction_consumed() -> None: + """Eviction is a tee, not a filter.""" + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + client = Client(_list_changed_server(), mode="legacy", message_handler=collect) + + async with client: + key = await _warm_tools_list_entry(client) + await client.call_tool("touch", {}) + with anyio.fail_after(5): + await seen.wait() + # The wrap evicts before delegating: delivery implies the entry is gone. + assert await _coordinator(client)._store.get(key) is None + + assert received == snapshot([ToolListChangedNotification()]) + + +async def test_non_notification_items_pass_through_to_the_user_handler_untouched() -> None: + """Transport `Exception` items can't occur in-process, so the installed handler is invoked directly.""" + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + client = Client(_list_changed_server(), message_handler=collect) + + async with client: + installed = client.session._message_handler + assert installed is not collect # the wrap, not the bare user handler + key = await _warm_tools_list_entry(client) + fault = RuntimeError("stream broke") + await installed(fault) + assert received == [fault] + assert await _coordinator(client)._store.get(key) is not None + + +async def test_a_raising_eviction_does_not_block_notification_delivery(caplog: pytest.LogCaptureFixture) -> None: + class _ExplodingCache(ClientResponseCache): + async def evict_for_notification(self, notification: ServerNotification) -> None: + raise RuntimeError("cache bug") + + received: list[IncomingMessage] = [] + seen = anyio.Event() + + async def collect(message: IncomingMessage) -> None: + received.append(message) + seen.set() + + client = Client(_list_changed_server(), mode="legacy", message_handler=collect) + # The wrap reads `_response_cache` at session build, so the swap must happen pre-enter. + client._response_cache = _ExplodingCache( + store=InMemoryResponseCacheStore(), + partition="", + arm_id="arm", + default_ttl_ms=0, + clock=time.time, + share_public=False, + negotiated_version=lambda: None, + ) + + async with client: + await client.call_tool("touch", {}) + with anyio.fail_after(5): + await seen.wait() + + assert received == snapshot([ToolListChangedNotification()]) + assert "Response cache eviction failed; the notification is still delivered" in [ + record.message for record in caplog.records + ] + + +# --- The cacheable verbs --- + + +class _ManualClock: + """Injected wall clock: tests advance `now` instead of sleeping.""" + + def __init__(self) -> None: + self.now = 1_000_000.0 + + def __call__(self) -> float: + return self.now + + +def _varying_tools_server( + *, ttl_ms: int = 60_000, scope: Literal["public", "private"] = "private" +) -> tuple[Server[Any], list[str | None]]: + """Server whose every tools/list fetch returns a distinct tool name `t`, + so a served entry is distinguishable from a refetch by payload.""" + fetches: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + fetches.append(params.cursor if params is not None else None) + return ListToolsResult(tools=[Tool(name=f"t{len(fetches) - 1}", input_schema={"type": "object"})]) + + server = Server( + "varying", on_list_tools=list_tools, cache_hints={"tools/list": CacheHint(ttl_ms=ttl_ms, scope=scope)} + ) + return server, fetches + + +def _tool_names(result: ListToolsResult) -> list[str]: + return [tool.name for tool in result.tools] + + +async def test_a_second_list_tools_within_the_ttl_is_served_from_the_cache() -> None: + """SEP-2549: a result carrying a `ttlMs` hint is reusable until it expires.""" + server, fetches = _varying_tools_server() + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + first = await client.list_tools() + second = await client.list_tools() + + assert fetches == [None] + assert second == first + + +async def test_an_expired_entry_is_refetched() -> None: + """Freshness is strict: at exactly `ttlMs` the entry is expired.""" + clock = _ManualClock() + server, fetches = _varying_tools_server(ttl_ms=60_000) + + async with Client(server, cache=CacheConfig(clock=clock)) as client: + assert _tool_names(await client.list_tools()) == ["t0"] + clock.now += 60.0 + assert _tool_names(await client.list_tools()) == ["t1"] + + assert fetches == [None, None] + + +async def test_each_list_verb_caches_independently_under_its_own_method() -> None: + """Cache keys discriminate by method (spec MUST).""" + fetched: list[str] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + fetched.append("tools/list") + return ListToolsResult(tools=[]) + + async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: + fetched.append("prompts/list") + return ListPromptsResult(prompts=[]) + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + fetched.append("resources/list") + return ListResourcesResult(resources=[]) + + async def list_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + fetched.append("resources/templates/list") + return ListResourceTemplatesResult(resource_templates=[]) + + hint = CacheHint(ttl_ms=60_000) + server = Server( + "all-lists", + on_list_tools=list_tools, + on_list_prompts=list_prompts, + on_list_resources=list_resources, + on_list_resource_templates=list_templates, + cache_hints={ + "tools/list": hint, + "prompts/list": hint, + "resources/list": hint, + "resources/templates/list": hint, + }, + ) + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + await client.list_tools() + await client.list_prompts() + await client.list_resources() + await client.list_resource_templates() + await client.list_tools() + await client.list_prompts() + await client.list_resources() + await client.list_resource_templates() + + assert fetched == ["tools/list", "prompts/list", "resources/list", "resources/templates/list"] + + +async def test_read_resource_caches_per_uri() -> None: + """Cache keys discriminate by result-affecting params (spec MUST).""" + reads: list[str] = [] + + async def read(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + reads.append(params.uri) + return ReadResourceResult(contents=[TextResourceContents(uri=params.uri, text=params.uri)]) + + server = Server("res", on_read_resource=read, cache_hints={"resources/read": CacheHint(ttl_ms=60_000)}) + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + first_a = await client.read_resource("memo://a") + first_b = await client.read_resource("memo://b") + assert await client.read_resource("memo://a") == first_a + assert await client.read_resource("memo://b") == first_b + + assert reads == ["memo://a", "memo://b"] + + +def _paginated_tools_server() -> tuple[Server[Any], list[str | None]]: + """Cacheable first page; cursor "expired" -> INVALID_PARAMS (the spec's expired-cursor + signal), "fail" -> INTERNAL_ERROR.""" + fetches: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + cursor = params.cursor if params is not None else None + fetches.append(cursor) + if cursor is None: + first_page = Tool(name="first-page", input_schema={"type": "object"}) + return ListToolsResult(tools=[first_page], next_cursor="page-2") + if cursor == "page-2": + return ListToolsResult(tools=[Tool(name="second-page", input_schema={"type": "object"})]) + if cursor == "fail": + raise MCPError(code=INTERNAL_ERROR, message="transient failure") + raise MCPError(code=INVALID_PARAMS, message=f"Unknown cursor: {cursor!r}") + + server = Server("paginated", on_list_tools=list_tools, cache_hints={"tools/list": CacheHint(ttl_ms=60_000)}) + return server, fetches + + +async def test_cursor_continuations_neither_read_nor_write_the_cache() -> None: + """Only cursor-less calls participate in caching (SDK-defined single-page entry).""" + server, fetches = _paginated_tools_server() + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + assert _tool_names(await client.list_tools()) == ["first-page"] + assert _tool_names(await client.list_tools(cursor="page-2")) == ["second-page"] + assert _tool_names(await client.list_tools()) == ["first-page"] # not overwritten by the continuation + + assert fetches == [None, "page-2"] + + +async def test_an_expired_cursor_rejection_evicts_the_methods_entry() -> None: + """Spec SHOULD: INVALID_PARAMS on a continuation cursor means the listing changed.""" + server, fetches = _paginated_tools_server() + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + await client.list_tools() + with pytest.raises(MCPError) as exc_info: + await client.list_tools(cursor="expired") + assert exc_info.value.code == INVALID_PARAMS + await client.list_tools() + + assert fetches == [None, "expired", None] + + +async def test_an_expired_cursor_rejection_under_bypass_does_not_evict() -> None: + """Bypass means no cache side-effects at all, eviction included.""" + server, fetches = _paginated_tools_server() + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + await client.list_tools() + with pytest.raises(MCPError) as exc_info: + await client.list_tools(cursor="expired", cache_mode="bypass") + assert exc_info.value.code == INVALID_PARAMS + await client.list_tools() # still served from the warm entry + + assert fetches == [None, "expired"] + + +async def test_a_non_cursor_error_on_a_continuation_does_not_evict() -> None: + """Only INVALID_PARAMS signals cursor expiry.""" + server, fetches = _paginated_tools_server() + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + await client.list_tools() + with pytest.raises(MCPError) as exc_info: + await client.list_tools(cursor="fail") + assert exc_info.value.code == INTERNAL_ERROR + await client.list_tools() # still served from the warm entry + + assert fetches == [None, "fail"] + + +async def test_bypass_neither_serves_nor_disturbs_a_warm_entry() -> None: + server, fetches = _varying_tools_server() + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + assert _tool_names(await client.list_tools()) == ["t0"] + assert _tool_names(await client.list_tools(cache_mode="bypass")) == ["t1"] + assert _tool_names(await client.list_tools()) == ["t0"] # warm entry intact + + assert fetches == [None, None] + + +async def test_refresh_skips_the_read_and_stores_the_refetched_result() -> None: + server, fetches = _varying_tools_server() + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + assert _tool_names(await client.list_tools()) == ["t0"] + assert _tool_names(await client.list_tools(cache_mode="refresh")) == ["t1"] + assert _tool_names(await client.list_tools()) == ["t1"] + + assert fetches == [None, None] + + +async def test_refresh_storing_a_ttl_zero_result_purges_the_warm_entry() -> None: + """An uncacheable refetch still supersedes the warm entry.""" + fetches: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + fetches.append(params.cursor if params is not None else None) + ttl_ms = 60_000 if len(fetches) == 1 else 0 + tool = Tool(name=f"t{len(fetches) - 1}", input_schema={"type": "object"}) + return ListToolsResult(tools=[tool], ttl_ms=ttl_ms) + + server = Server("flip", on_list_tools=list_tools) + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + assert _tool_names(await client.list_tools()) == ["t0"] + assert _tool_names(await client.list_tools(cache_mode="refresh")) == ["t1"] + assert _tool_names(await client.list_tools()) == ["t2"] # t0 purged, t1 (ttl 0) never stored + + assert fetches == [None, None, None] + + +async def test_a_list_call_carrying_meta_is_fetched_and_replaces_the_warm_entry() -> None: + """SDK-defined: `meta` (a progress token, tracing fields) expects a wire request, + so under the default "use" the call behaves as a refresh.""" + server, fetches = _varying_tools_server() + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + assert _tool_names(await client.list_tools()) == ["t0"] + assert _tool_names(await client.list_tools()) == ["t0"] # warm, meta-less: served + assert _tool_names(await client.list_tools(meta={"progress_token": "tok"})) == ["t1"] # meta: fetched + assert _tool_names(await client.list_tools()) == ["t1"] # the fresh result replaced the entry + + assert fetches == [None, None] + + +async def test_a_read_resource_carrying_meta_is_fetched_and_replaces_the_warm_entry() -> None: + reads: list[str] = [] + + async def read(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + reads.append(params.uri) + return ReadResourceResult(contents=[TextResourceContents(uri=params.uri, text=f"v{len(reads)}")], ttl_ms=60_000) + + server = Server("versioned-reads", on_read_resource=read) + + def text(result: ReadResourceResult) -> str: + content = result.contents[0] + assert isinstance(content, TextResourceContents) + return content.text + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + assert text(await client.read_resource("memo://a")) == "v1" + assert text(await client.read_resource("memo://a")) == "v1" # warm, meta-less: served + assert text(await client.read_resource("memo://a", meta={"progress_token": "tok"})) == "v2" # meta: fetched + assert text(await client.read_resource("memo://a")) == "v2" # the fresh result replaced the entry + + assert reads == ["memo://a", "memo://a"] + + +async def test_cache_mode_is_inert_when_caching_is_disabled() -> None: + server, fetches = _varying_tools_server() + + async with Client(server, cache=False) as client: + await client.list_tools() + await client.list_tools(cache_mode="use") + await client.list_tools(cache_mode="refresh") + + assert fetches == [None, None, None] + + +@pytest.mark.parametrize( + "seed", + [{"request_state": "round-2"}, {"input_responses": {"ask": ElicitResult(action="decline")}}], + ids=["request_state", "input_responses"], +) +async def test_a_seeded_read_resource_skips_the_cache_and_ignores_cache_mode(seed: dict[str, Any]) -> None: + """Spec MUST: results of requests carrying `inputResponses` or `requestState` are never cached.""" + reads = 0 + + async def read(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + nonlocal reads + reads += 1 + return ReadResourceResult(contents=[TextResourceContents(uri=params.uri, text=f"v{reads}")], ttl_ms=60_000) + + server = Server("res", on_read_resource=read) + + def text(result: ReadResourceResult) -> str: + content = result.contents[0] + assert isinstance(content, TextResourceContents) + return content.text + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + assert text(await client.read_resource("memo://a")) == "v1" + assert text(await client.read_resource("memo://a", **seed)) == "v2" + assert text(await client.read_resource("memo://a", **seed, cache_mode="refresh")) == "v3" + assert text(await client.read_resource("memo://a")) == "v1" # nothing read, written, or purged + + assert reads == 3 + + +async def test_a_terminal_read_reached_through_driver_rounds_is_never_cached() -> None: + """Spec MUST: the driver's retry rounds carry `inputResponses`, so their terminal result is not cached.""" + seeded_rounds: list[bool] = [] + ask = ElicitRequest( + params=ElicitRequestFormParams( + message="What is your name?", + requested_schema={"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}, + ) + ) + + async def read( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams + ) -> ReadResourceResult | InputRequiredResult: + seeded_rounds.append(params.input_responses is not None) + if params.input_responses is not None: + return ReadResourceResult(contents=[TextResourceContents(uri=params.uri, text="terminal")], ttl_ms=60_000) + return InputRequiredResult(input_requests={"ask": ask}) + + async def elicitation_callback( + context: Any, params: types.ElicitRequestParams + ) -> types.ElicitResult | types.ErrorData: + return ElicitResult(action="accept", content={"name": "Ada"}) + + server = Server("gated", on_read_resource=read) + + with anyio.fail_after(5): + async with Client( + server, elicitation_callback=elicitation_callback, cache=CacheConfig(clock=_ManualClock()) + ) as client: + first = await client.read_resource("memo://gated") + second = await client.read_resource("memo://gated") + + assert isinstance(first.contents[0], TextResourceContents) and first.contents[0].text == "terminal" + assert second == first + assert seeded_rounds == [False, True, False, True] # two wire rounds per call: never served + + +async def test_a_refresh_that_resolves_to_input_required_purges_the_warm_entry() -> None: + """The refresh cannot store its driven terminal result (the rounds carry + `inputResponses`, a spec MUST), but it still purges the warm entry.""" + reads = 0 + ask = ElicitRequest( + params=ElicitRequestFormParams( + message="What is your name?", + requested_schema={"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}, + ) + ) + + async def read( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams + ) -> ReadResourceResult | InputRequiredResult: + nonlocal reads + reads += 1 + # Starts plain, then flips to requiring input. + if reads > 1 and params.input_responses is None: + return InputRequiredResult(input_requests={"ask": ask}) + return ReadResourceResult(contents=[TextResourceContents(uri=params.uri, text=f"v{reads}")], ttl_ms=60_000) + + async def elicitation_callback( + context: Any, params: types.ElicitRequestParams + ) -> types.ElicitResult | types.ErrorData: + return ElicitResult(action="accept", content={"name": "Ada"}) + + server = Server("flipping", on_read_resource=read) + + def text(result: ReadResourceResult) -> str: + content = result.contents[0] + assert isinstance(content, TextResourceContents) + return content.text + + with anyio.fail_after(5): + async with Client( + server, elicitation_callback=elicitation_callback, cache=CacheConfig(clock=_ManualClock()) + ) as client: + assert text(await client.read_resource("memo://a")) == "v1" # cached for 60s + assert text(await client.read_resource("memo://a", cache_mode="refresh")) == "v3" + # v1 purged and v3 never stored: the plain read drives fresh rounds. + assert text(await client.read_resource("memo://a")) == "v5" + + assert reads == 5 + + +def _output_schema_server(call_result: CallToolResult) -> tuple[Server[Any], list[str | None]]: + """One tool declaring an output schema; `call_tool` returns the canned `call_result`.""" + fetches: list[str | None] = [] + tool = Tool( + name="run", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"n": {"type": "integer"}}, "required": ["n"]}, + ) + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + fetches.append(params.cursor if params is not None else None) + return ListToolsResult(tools=[tool]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "run" + return call_result + + server = Server( + "schemas", + on_list_tools=list_tools, + on_call_tool=call_tool, + cache_hints={"tools/list": CacheHint(ttl_ms=60_000)}, + ) + return server, fetches + + +async def test_a_listing_served_from_a_shared_store_rebuilds_output_schemas() -> None: + """A served listing is absorbed into the session: output validation works without a wire fetch.""" + call_result = CallToolResult(content=[TextContent(text="ok")], structured_content={"n": 1}) + server, fetches = _output_schema_server(call_result) + config = CacheConfig(store=InMemoryResponseCacheStore(), partition="p", target_id="svc", clock=_ManualClock()) + + async with Client(server, cache=config) as warming: + listing = await warming.list_tools() + + async with Client(server, cache=config) as fresh: + assert await fresh.list_tools() == listing # served from the shared store + result = await fresh.call_tool("run", {}) + + assert result.structured_content == {"n": 1} + # A starved schema cache would have re-listed here. + assert fetches == [None] + + +async def test_validation_from_a_served_listing_rejects_missing_structured_content() -> None: + """The schema absorbed from a served listing is enforced, not just present.""" + server, fetches = _output_schema_server(CallToolResult(content=[TextContent(text="ok")])) + config = CacheConfig(store=InMemoryResponseCacheStore(), partition="p", target_id="svc", clock=_ManualClock()) + + async with Client(server, cache=config) as warming: + await warming.list_tools() + + async with Client(server, cache=config) as fresh: + await fresh.list_tools() + with pytest.raises(RuntimeError) as exc_info: + await fresh.call_tool("run", {}) + + assert str(exc_info.value) == snapshot("Tool run has an output schema but did not return structured content") + assert fetches == [None] + + +async def test_a_cache_hit_listing_still_mirrors_x_mcp_headers_on_tools_call() -> None: + """The arg-to-header maps are rebuilt from a served listing. Asserted at the wire + because the client never surfaces outgoing headers.""" + tool = Tool( + name="run", + input_schema={"type": "object", "properties": {"region": {"type": "string", "x-mcp-header": "Region"}}}, + ) + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[tool], ttl_ms=60_000) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "run" + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("headers", on_list_tools=list_tools, on_call_tool=call_tool) + + posts: list[httpx.Request] = [] + + async def on_request(request: httpx.Request) -> None: + posts.append(request) + + config = CacheConfig(store=InMemoryResponseCacheStore(), partition="p", target_id="svc") + discover = DiscoverResult( + supported_versions=[LATEST_MODERN_VERSION], + capabilities=ServerCapabilities(), + server_info=Implementation(name="srv", version="0"), + ) + + with anyio.fail_after(5): + async with mounted_app(server, on_request=on_request) as (http, _): + warming = Client( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http), + mode=LATEST_MODERN_VERSION, + prior_discover=discover, + cache=config, + ) + async with warming: + await warming.list_tools() + fresh = Client( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http), + mode=LATEST_MODERN_VERSION, + prior_discover=discover, + cache=config, + ) + async with fresh: + await fresh.list_tools() + await fresh.call_tool("run", {"region": "us-west1"}) + + # One tools/list on the wire: the fresh client served from the store. + assert [json.loads(request.content)["method"] for request in posts] == ["tools/list", "tools/call"] + assert posts[-1].headers["mcp-param-region"] == "us-west1" + + +async def test_a_shared_store_hit_prunes_a_header_map_the_writers_filter_dropped() -> None: + """Cached listings are post-filter: when another client's refresh wrote a listing whose + filter dropped tool `x` (its annotation went invalid), a hit on that entry must prune the + reader's stale arg-to-header map, or it would keep emitting Mcp-Param-* headers for `x`.""" + valid = {"type": "object", "properties": {"region": {"type": "string", "x-mcp-header": "Region"}}} + invalid = {"type": "object", "properties": {"region": {"type": "string", "x-mcp-header": "bad name"}}} + schema = valid + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="x", input_schema=schema)]) + + server = Server("filtering", on_list_tools=list_tools, cache_hints={"tools/list": CacheHint(ttl_ms=60_000)}) + config = CacheConfig(store=InMemoryResponseCacheStore(), partition="p", target_id="svc", clock=_ManualClock()) + + with anyio.fail_after(5): + async with Client(server, cache=config) as reader, Client(server, cache=config) as writer: + await reader.list_tools() # fetches while `x` is valid; the reader holds its header map + assert "x" in reader.session._x_mcp_header_maps + + schema = invalid + await writer.list_tools(cache_mode="refresh") # the writer's filter drops `x`; the entry is replaced + + served = await reader.list_tools() # hit on the writer's entry + assert served.tools == [] + assert "x" not in reader.session._x_mcp_header_maps + + +async def test_a_tools_list_changed_notification_makes_the_next_list_refetch() -> None: + """Spec SHOULD: list_changed invalidates the cached listing. Legacy session + + `default_ttl_ms` entry: eviction is era-independent.""" + fetches: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + fetches.append(params.cursor if params is not None else None) + return ListToolsResult(tools=[Tool(name="touch", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "touch" + await ctx.session.send_tool_list_changed() + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("notify", on_list_tools=list_tools, on_call_tool=call_tool) + + # The wrap evicts before delegating: delivery implies eviction completed. + delivered = anyio.Event() + + async def on_message(message: IncomingMessage) -> None: + assert isinstance(message, ToolListChangedNotification) # the only message this server emits + delivered.set() + + client = Client(server, mode="legacy", cache=CacheConfig(default_ttl_ms=60_000), message_handler=on_message) + async with client: + await client.list_tools() + await client.list_tools() + assert fetches == [None] # cached via default_ttl_ms + await client.call_tool("touch", {}) + with anyio.fail_after(5): + await delivered.wait() + await client.list_tools() + + assert fetches == [None, None] + + +async def test_a_resource_updated_notification_evicts_that_uris_read_entry() -> None: + """Spec SHOULD: resources/updated invalidates the cached read for its uri, + and the notification's `params.uri` must match the stored key's uri form.""" + uri = "memo://cached" + reads: list[str] = [] + + async def read(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + reads.append(params.uri) + return ReadResourceResult(contents=[TextResourceContents(uri=params.uri, text=f"v{len(reads)}")]) + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="poke", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "poke" + await ctx.session.send_resource_updated(uri) + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("updates", on_read_resource=read, on_list_tools=list_tools, on_call_tool=call_tool) + + delivered: list[str] = [] + seen = anyio.Event() + + async def on_message(message: IncomingMessage) -> None: + assert isinstance(message, ResourceUpdatedNotification) # the only message this server emits + delivered.append(message.params.uri) + seen.set() + + client = Client(server, mode="legacy", cache=CacheConfig(default_ttl_ms=60_000), message_handler=on_message) + async with client: + await client.read_resource(uri) + await client.read_resource(uri) + assert reads == [uri] # cached via default_ttl_ms + await client.call_tool("poke", {}) + with anyio.fail_after(5): + await seen.wait() + await client.read_resource(uri) + + assert delivered == [uri] # the exact string the entry was stored under + assert reads == [uri, uri] + + +async def test_the_modern_in_process_path_drops_the_eviction_notification() -> None: + """Pins the documented gap: the default in-process path (DirectDispatcher) drops + standalone notifications, so the warm entry survives. If this starts failing the + path gained delivery: flip the `docs/advanced/caching.md` caveat and the legacy-mode tests.""" + fetches: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + fetches.append(params.cursor if params is not None else None) + return ListToolsResult(tools=[Tool(name="touch", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "touch" + await ctx.session.send_tool_list_changed() + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server( + "notify", + on_list_tools=list_tools, + on_call_tool=call_tool, + cache_hints={"tools/list": CacheHint(ttl_ms=60_000)}, + ) + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + await client.list_tools() + await client.call_tool("touch", {}) + await client.list_tools() # still served from the warm entry: no eviction arrived + + assert fetches == [None] + + +async def test_a_discover_result_never_enters_the_response_cache() -> None: + """SDK ruling (documented): the cache covers the five verbs only; a persisted + `prior_discover`'s freshness is the user's bookkeeping.""" + server = Server("hinted", cache_hints={"server/discover": CacheHint(ttl_ms=60_000)}) + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + discover = client.session.discover_result + assert discover is not None + assert discover.ttl_ms == 60_000 # the hint arrived with the probe result... + store = _coordinator(client)._store + assert isinstance(store, InMemoryResponseCacheStore) + assert store._entries == {} # ...and nothing entered the cache + + +# --- The inbound ttlMs clamp (parse seam) --- + + +@pytest.mark.parametrize("wire_ttl", [-5, -5.0]) +async def test_a_negative_inbound_ttl_is_served_as_zero_and_never_cached(wire_ttl: int | float) -> None: + """Spec SHOULD: a negative `ttlMs` is treated as 0, not a wire-validation failure. + Scripted peer: an SDK server enforces `ge=0` and cannot emit one.""" + listings_served = 0 + + async def scripted_server(streams: MessageStream) -> None: + nonlocal listings_served + server_read, server_write = streams + async for message in server_read: + assert isinstance(message, SessionMessage) + frame = message.message + assert isinstance(frame, types.JSONRPCRequest) + if frame.method == "server/discover": + result: dict[str, Any] = { + "supportedVersions": [LATEST_MODERN_VERSION], + "capabilities": {}, + "serverInfo": {"name": "negative-ttl", "version": "0.0.1"}, + "resultType": "complete", + "ttlMs": 0, + } + else: + assert frame.method == "tools/list" + listings_served += 1 + result = {"resultType": "complete", "tools": [], "ttlMs": wire_ttl, "cacheScope": "private"} + await server_write.send(SessionMessage(types.JSONRPCResponse(jsonrpc="2.0", id=frame.id, result=result))) + + @asynccontextmanager + async def scripted_transport() -> AsyncIterator[TransportStreams]: + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ): + tg.start_soon(scripted_server, server_streams) + yield client_read, client_write + tg.cancel_scope.cancel() + + with anyio.fail_after(5): + async with Client(scripted_transport(), mode="auto") as client: + first = await client.list_tools() + second = await client.list_tools() + + assert first.ttl_ms == 0 + assert second.ttl_ms == 0 + assert listings_served == 2 # the clamped-to-zero ttl was never stored + + +@pytest.mark.parametrize("wire_ttl", [-5, -5.0]) +async def test_a_negative_discover_ttl_still_connects_modern_in_auto_mode(wire_ttl: int | float) -> None: + """Regression: pre-clamp, a negative discover `ttlMs` failed validation inside the + mode="auto" probe and silently downgraded to the legacy handshake.""" + methods_seen: list[str] = [] + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + async for message in server_read: + assert isinstance(message, SessionMessage) + frame = message.message + assert isinstance(frame, types.JSONRPCRequest) + methods_seen.append(frame.method) + # A legacy downgrade would send `initialize`; fail loudly instead. + assert frame.method == "server/discover" + result: dict[str, Any] = { + "supportedVersions": [LATEST_MODERN_VERSION], + "capabilities": {}, + "serverInfo": {"name": "negative-ttl", "version": "0.0.1"}, + "resultType": "complete", + "ttlMs": wire_ttl, + } + await server_write.send(SessionMessage(types.JSONRPCResponse(jsonrpc="2.0", id=frame.id, result=result))) + + @asynccontextmanager + async def scripted_transport() -> AsyncIterator[TransportStreams]: + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ): + tg.start_soon(scripted_server, server_streams) + yield client_read, client_write + tg.cancel_scope.cancel() + + with anyio.fail_after(5): + async with Client(scripted_transport(), mode="auto") as client: + assert client.protocol_version == LATEST_MODERN_VERSION + discover = client.session.discover_result + assert discover is not None + assert discover.ttl_ms == 0 + + assert methods_seen == ["server/discover"] + + +# --- Hardening e2e --- + + +def _versioned_read_server(*, ttl_ms: int = 60_000) -> tuple[Server[Any], list[str]]: + """Server whose every read returns a distinct payload `v`, + so a served entry is distinguishable from a refetch.""" + reads: list[str] = [] + + async def read(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + reads.append(params.uri) + return ReadResourceResult(contents=[TextResourceContents(uri=params.uri, text=f"v{len(reads)}")], ttl_ms=ttl_ms) + + return Server("versioned-reads", on_read_resource=read), reads + + +def _resource_text(result: ReadResourceResult) -> str: + content = result.contents[0] + assert isinstance(content, TextResourceContents) + return content.text + + +async def test_each_notification_evicts_exactly_its_entries_end_to_end() -> None: + """Spec SHOULD (notifications invalidate) plus its negative space: each notification + refetches exactly its own entries, and resources/list_changed also covers templates.""" + uri_x, uri_y = "memo://x", "memo://y" + fetched: list[str] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + fetched.append("tools/list") + return ListToolsResult(tools=[Tool(name="notify", input_schema={"type": "object"})]) + + async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: + fetched.append("prompts/list") + return ListPromptsResult(prompts=[]) + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + fetched.append("resources/list") + return ListResourcesResult(resources=[]) + + async def list_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + fetched.append("resources/templates/list") + return ListResourceTemplatesResult(resource_templates=[]) + + async def read(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + fetched.append(f"resources/read {params.uri}") + return ReadResourceResult(contents=[TextResourceContents(uri=params.uri, text="body")]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "notify" + kind = (params.arguments or {})["kind"] + if kind == "tools": + await ctx.session.send_tool_list_changed() + elif kind == "resources": + await ctx.session.send_resource_list_changed() + else: + assert kind == "updated-x" + await ctx.session.send_resource_updated(uri_x) + return CallToolResult(content=[TextContent(text="sent")]) + + server = Server( + "notifier", + on_list_tools=list_tools, + on_list_prompts=list_prompts, + on_list_resources=list_resources, + on_list_resource_templates=list_templates, + on_read_resource=read, + on_call_tool=call_tool, + ) + + delivered: list[IncomingMessage] = [] + eviction_done = [anyio.Event() for _ in range(3)] + + async def on_message(message: IncomingMessage) -> None: + # The wrap evicts before delegating: each event implies its eviction completed. + delivered.append(message) + eviction_done[len(delivered) - 1].set() + + client = Client( + server, + mode="legacy", + cache=CacheConfig(default_ttl_ms=60_000, clock=_ManualClock()), + message_handler=on_message, + ) + + async with client: + + async def served_round() -> list[str]: + """Call every cacheable verb once; return the calls that reached the server.""" + before = len(fetched) + await client.list_tools() + await client.list_prompts() + await client.list_resources() + await client.list_resource_templates() + await client.read_resource(uri_x) + await client.read_resource(uri_y) + return fetched[before:] + + assert await served_round() == [ + "tools/list", + "prompts/list", + "resources/list", + "resources/templates/list", + f"resources/read {uri_x}", + f"resources/read {uri_y}", + ] + assert await served_round() == [] # everything primed and served + + await client.call_tool("notify", {"kind": "tools"}) + with anyio.fail_after(5): + await eviction_done[0].wait() + assert await served_round() == ["tools/list"] + + await client.call_tool("notify", {"kind": "resources"}) + with anyio.fail_after(5): + await eviction_done[1].wait() + assert await served_round() == ["resources/list", "resources/templates/list"] + + await client.call_tool("notify", {"kind": "updated-x"}) + with anyio.fail_after(5): + await eviction_done[2].wait() + assert await served_round() == [f"resources/read {uri_x}"] + + assert delivered == [ + ToolListChangedNotification(), + ResourceListChangedNotification(), + ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri=uri_x)), + ] + + +async def test_private_entries_never_cross_partitions_between_clients_sharing_a_store() -> None: + """Spec MUST: "private" never crosses authorization contexts.""" + server, fetches = _varying_tools_server() + store = InMemoryResponseCacheStore() + + def config(partition: str) -> CacheConfig: + return CacheConfig(store=store, partition=partition, target_id="svc", clock=_ManualClock()) + + async with Client(server, cache=config("tenant-a")) as tenant_a: + assert _tool_names(await tenant_a.list_tools()) == ["t0"] + async with Client(server, cache=config("tenant-b")) as tenant_b: + assert _tool_names(await tenant_b.list_tools()) == ["t1"] # fetched, not tenant-a's entry + + assert fetches == [None, None] + + +async def test_a_server_stamped_public_entry_does_not_cross_partitions_by_default() -> None: + """SDK security default (deviates from the ts SDK): the public arm is still keyed by partition.""" + server, fetches = _varying_tools_server(scope="public") + store = InMemoryResponseCacheStore() + + def config(partition: str) -> CacheConfig: + return CacheConfig(store=store, partition=partition, target_id="svc", clock=_ManualClock()) + + async with Client(server, cache=config("tenant-a")) as tenant_a: + assert _tool_names(await tenant_a.list_tools()) == ["t0"] + async with Client(server, cache=config("tenant-a")) as same_partition: + assert _tool_names(await same_partition.list_tools()) == ["t0"] # served from the store + async with Client(server, cache=config("tenant-b")) as tenant_b: + assert _tool_names(await tenant_b.list_tools()) == ["t1"] # fetched + + assert fetches == [None, None] + + +async def test_share_public_serves_a_server_stamped_public_entry_across_partitions() -> None: + """With `share_public=True` the public arm drops the partition.""" + server, fetches = _varying_tools_server(scope="public") + store = InMemoryResponseCacheStore() + + def config(partition: str) -> CacheConfig: + return CacheConfig(store=store, partition=partition, target_id="svc", share_public=True, clock=_ManualClock()) + + async with Client(server, cache=config("tenant-a")) as tenant_a: + assert _tool_names(await tenant_a.list_tools()) == ["t0"] + async with Client(server, cache=config("tenant-b")) as tenant_b: + assert _tool_names(await tenant_b.list_tools()) == ["t0"] # served across partitions + + assert fetches == [None] + + +async def test_same_partition_clients_share_read_entries_through_the_store() -> None: + server, reads = _versioned_read_server() + store = InMemoryResponseCacheStore() + + def config() -> CacheConfig: + return CacheConfig(store=store, partition="p", target_id="svc", clock=_ManualClock()) + + async with Client(server, cache=config()) as first: + first_result = await first.read_resource("memo://a") + async with Client(server, cache=config()) as second: + assert await second.read_resource("memo://a") == first_result + + assert reads == ["memo://a"] + + +async def test_mutating_returned_results_never_corrupts_the_cached_entry() -> None: + """Deep-copy isolation in both directions: write-side (the fetched result) and + serve-side (the served hit).""" + server, fetches = _varying_tools_server() + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + first = await client.list_tools() + first.tools[0].name = "tampered-after-fetch" + second = await client.list_tools() # cache hit, unaffected by the mutation + assert _tool_names(second) == ["t0"] + second.tools[0].name = "tampered-after-serve" + assert _tool_names(await client.list_tools()) == ["t0"] # still pristine + + assert fetches == [None] + + +async def test_a_cache_hit_still_yields_to_the_event_loop() -> None: + """A hit completes without a wire await, so the verb checkpoints explicitly: a poll + loop over a fresh entry would otherwise starve spawned tasks (eviction dispatch). + Pinned by calling a warm verb inside an already-cancelled scope: only a yield can + observe the cancellation.""" + server, fetches = _varying_tools_server() + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + assert _tool_names(await client.list_tools()) == ["t0"] # warm the entry + with anyio.CancelScope() as scope: + scope.cancel() + await client.list_tools() # would be a hit; must yield and observe the cancellation + assert scope.cancelled_caught + + assert fetches == [None] # the cancelled call neither fetched nor served + + +async def test_a_legacy_peer_injecting_cache_hints_caches_nothing() -> None: + """Era gate: hint keys a 2025 peer puts on the wire cache nothing. Scripted peer: + an SDK server strips the hint fields when serializing for a 2025 session.""" + listings_served = 0 + + async def scripted_server(streams: MessageStream) -> None: + nonlocal listings_served + server_read, server_write = streams + async for message in server_read: + assert isinstance(message, SessionMessage) + frame = message.message + if isinstance(frame, types.JSONRPCNotification): + assert frame.method == "notifications/initialized" + continue + assert isinstance(frame, types.JSONRPCRequest) + if frame.method == "initialize": + result: dict[str, Any] = { + "protocolVersion": "2025-11-25", + "capabilities": {}, + "serverInfo": {"name": "legacy-injector", "version": "0.0.1"}, + } + else: + assert frame.method == "tools/list" + listings_served += 1 + result = {"tools": [], "ttlMs": 60_000, "cacheScope": "public"} + await server_write.send(SessionMessage(types.JSONRPCResponse(jsonrpc="2.0", id=frame.id, result=result))) + + @asynccontextmanager + async def scripted_transport() -> AsyncIterator[TransportStreams]: + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ): + tg.start_soon(scripted_server, server_streams) + yield client_read, client_write + tg.cancel_scope.cancel() + + with anyio.fail_after(5): + async with Client(scripted_transport(), mode="legacy", cache=CacheConfig(clock=_ManualClock())) as client: + await client.list_tools() + await client.list_tools() + store = _coordinator(client)._store + assert isinstance(store, InMemoryResponseCacheStore) + assert store._entries == {} # neither arm holds an entry + + assert listings_served == 2 + + +class _CancelOnSetStore(InMemoryResponseCacheStore): + """Store whose next `set` awaits a one-shot hook before committing.""" + + def __init__(self) -> None: + super().__init__() + self.before_set: Callable[[], Awaitable[None]] | None = None + + async def set(self, key: CacheKey, entry: CacheEntry) -> None: + if self.before_set is not None: + hook, self.before_set = self.before_set, None + await hook() + await super().set(key, entry) + + +async def test_a_verb_cancelled_mid_write_leaves_no_stale_arm_pair() -> None: + """No-stale-pair invariant: a cancellation between the opposite-arm delete and the + `set` commit leaves at most one entry per key, so the superseded entry cannot be served.""" + fetches: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + fetches.append(params.cursor if params is not None else None) + scope: Literal["public", "private"] = "public" if len(fetches) == 1 else "private" + tool = Tool(name=f"t{len(fetches) - 1}", input_schema={"type": "object"}) + return ListToolsResult(tools=[tool], ttl_ms=60_000, cache_scope=scope) + + server = Server("scope-flip", on_list_tools=list_tools) + store = _CancelOnSetStore() + client = Client(server, cache=CacheConfig(store=store, partition="p", target_id="svc", clock=_ManualClock())) + + async with client: + assert _tool_names(await client.list_tools()) == ["t0"] + assert len(store._entries) == 1 # the public-arm entry + + with anyio.CancelScope() as scope: + + async def cancel_mid_commit() -> None: + scope.cancel() + await anyio.lowlevel.checkpoint() # the cancellation is delivered here, inside `set` + + store.before_set = cancel_mid_commit + await client.list_tools(cache_mode="refresh") + assert scope.cancelled_caught + + # The opposite (public) arm was deleted before the cancelled set could commit. + assert store._entries == {} + assert _tool_names(await client.list_tools()) == ["t2"] # nothing cached: refetched + + assert fetches == [None, None, None] + + +async def test_an_eviction_landing_mid_fetch_discards_that_fetchs_write() -> None: + """Spec-aligned race rule: an eviction landing mid-fetch discards that fetch's write. + The server waits for the client-side eviction before responding, so the interleaving + is deterministic, not scheduler-dependent.""" + fetches: list[str | None] = [] + evicted = anyio.Event() + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + fetches.append(params.cursor if params is not None else None) + if len(fetches) == 1: + await ctx.session.send_tool_list_changed() + with anyio.fail_after(5): + await evicted.wait() + return ListToolsResult(tools=[Tool(name=f"t{len(fetches) - 1}", input_schema={"type": "object"})]) + + async def on_message(message: IncomingMessage) -> None: + assert isinstance(message, ToolListChangedNotification) # the only message this server emits + evicted.set() + + server = Server("racer", on_list_tools=list_tools) + client = Client( + server, + mode="legacy", + cache=CacheConfig(default_ttl_ms=60_000, clock=_ManualClock()), + message_handler=on_message, + ) + + async with client: + assert _tool_names(await client.list_tools()) == ["t0"] + # Empty proves the write was skipped, not stored-then-evicted: the eviction + # completed strictly before the response, the write strictly after. + store = _coordinator(client)._store + assert isinstance(store, InMemoryResponseCacheStore) + assert store._entries == {} + assert _tool_names(await client.list_tools()) == ["t1"] # refetched... + assert _tool_names(await client.list_tools()) == ["t1"] # ...and that fetch cached normally + + assert fetches == [None, None] + + +async def test_read_resource_bypass_neither_serves_nor_disturbs_a_warm_entry() -> None: + server, reads = _versioned_read_server() + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + assert _resource_text(await client.read_resource("memo://a")) == "v1" + assert _resource_text(await client.read_resource("memo://a", cache_mode="bypass")) == "v2" + assert _resource_text(await client.read_resource("memo://a")) == "v1" # warm entry intact + + assert reads == ["memo://a", "memo://a"] + + +async def test_read_resource_refresh_refetches_and_restores() -> None: + server, reads = _versioned_read_server() + + async with Client(server, cache=CacheConfig(clock=_ManualClock())) as client: + assert _resource_text(await client.read_resource("memo://a")) == "v1" + assert _resource_text(await client.read_resource("memo://a", cache_mode="refresh")) == "v2" + assert _resource_text(await client.read_resource("memo://a")) == "v2" # the refreshed value re-stored + + assert reads == ["memo://a", "memo://a"] + + +async def test_a_closed_client_raises_on_every_cacheable_verb_instead_of_serving_the_cache() -> None: + """Cache participation requires a live session.""" + fetched: list[str] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + fetched.append("tools/list") + return ListToolsResult(tools=[]) + + async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: + fetched.append("prompts/list") + return ListPromptsResult(prompts=[]) + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + fetched.append("resources/list") + return ListResourcesResult(resources=[]) + + async def list_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + fetched.append("resources/templates/list") + return ListResourceTemplatesResult(resource_templates=[]) + + async def read(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + fetched.append(f"resources/read {params.uri}") + return ReadResourceResult(contents=[TextResourceContents(uri=params.uri, text="body")]) + + hint = CacheHint(ttl_ms=60_000) + server = Server( + "warm", + on_list_tools=list_tools, + on_list_prompts=list_prompts, + on_list_resources=list_resources, + on_list_resource_templates=list_templates, + on_read_resource=read, + cache_hints={ + "tools/list": hint, + "prompts/list": hint, + "resources/list": hint, + "resources/templates/list": hint, + "resources/read": hint, + }, + ) + + client = Client(server, cache=CacheConfig(clock=_ManualClock())) + async with client: + await client.list_tools() + await client.list_prompts() + await client.list_resources() + await client.list_resource_templates() + await client.read_resource("memo://a") + # A repeat round is served entirely from the warm entries. + await client.list_tools() + await client.read_resource("memo://a") + assert len(fetched) == 5 + + with pytest.raises(RuntimeError) as exc_info: + await client.list_tools() + assert str(exc_info.value) == snapshot("Client must be used within an async context manager") + with pytest.raises(RuntimeError): + await client.list_prompts() + with pytest.raises(RuntimeError): + await client.list_resources() + with pytest.raises(RuntimeError): + await client.list_resource_templates() + with pytest.raises(RuntimeError): + await client.read_resource("memo://a") + + assert len(fetched) == 5 # nothing was served from the cache and nothing reached the server diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 83893e36f..f76991f65 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1661,6 +1661,33 @@ async def test_discover_reraises_unsupported_version_with_malformed_error_data() assert [m for m, _ in dispatcher.calls] == ["server/discover"] +# --- inbound ttlMs clamp --- + + +@pytest.mark.anyio +async def test_a_positive_inbound_ttl_reaches_the_result_unchanged() -> None: + listing: dict[str, Any] = {"resultType": "complete", "tools": [], "ttlMs": 60_000, "cacheScope": "private"} + dispatcher = _ScriptedDispatcher(_discover_result_dict(), listing) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + await session.discover() + result = await session.list_tools() + assert result.ttl_ms == 60_000 + + +@pytest.mark.anyio +@pytest.mark.parametrize("wire_ttl", [True, False]) +async def test_a_boolean_inbound_ttl_is_not_clamped_only_coerced_by_validation(wire_ttl: bool) -> None: + """SDK-defined: `bool` is an `int` subclass; the clamp skips it and pydantic's lax mode coerces it instead.""" + listing: dict[str, Any] = {"resultType": "complete", "tools": [], "ttlMs": wire_ttl, "cacheScope": "private"} + dispatcher = _ScriptedDispatcher(_discover_result_dict(), listing) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + await session.discover() + result = await session.list_tools() + assert result.ttl_ms == int(wire_ttl) + + @pytest.mark.anyio async def test_session_call_tool_returns_input_required_result_when_opted_in() -> None: """`ClientSession.call_tool(..., allow_input_required=True)` surfaces the diff --git a/tests/docs_src/test_caching.py b/tests/docs_src/test_caching.py index bc2feb9ac..58014879c 100644 --- a/tests/docs_src/test_caching.py +++ b/tests/docs_src/test_caching.py @@ -1,13 +1,19 @@ """`docs/advanced/caching.md`: every claim the page makes, proved against the real SDK.""" +from collections.abc import Mapping from typing import Any, cast +import anyio import pytest from inline_snapshot import snapshot +from mcp_types import INTERNAL_ERROR, ListToolsResult, PaginatedRequestParams, Tool from docs_src.caching import tutorial001, tutorial002, tutorial003 -from mcp import Client -from mcp.server import CacheHint, MCPServer +from mcp import Client, MCPError +from mcp.client import CacheConfig +from mcp.client.caching import InMemoryResponseCacheStore +from mcp.server import CacheHint, MCPServer, Server, ServerRequestContext +from mcp.server.caching import CacheableMethod # See test_index.py for why this is a per-module mark and not a conftest hook. pytestmark = [pytest.mark.anyio, pytest.mark.filterwarnings("error::mcp.MCPDeprecationWarning")] @@ -42,7 +48,7 @@ async def test_a_non_cacheable_method_is_rejected_at_construction() -> None: with pytest.raises(ValueError) as exc: MCPServer("Weather", cache_hints=cast(Any, {"tools/call": CacheHint(ttl_ms=1_000)})) assert str(exc.value) == snapshot( - "cache_hints keys must be cacheable methods (see CacheableMethod); got: tools/call" + "cache_hints keys must be cacheable methods (see CacheableMethod); got: 'tools/call'" ) @@ -55,16 +61,149 @@ async def test_the_handler_value_wins_over_the_map_per_field() -> None: assert tools.cache_scope == "public" -async def test_the_client_program_on_the_page_reads_the_hints(capsys: pytest.CaptureFixture[str]) -> None: - """tutorial003: `main()` is the literal client program on the page - the hints - arrive as parsed fields on the result.""" +async def test_the_client_program_on_the_page_makes_three_fetches_for_four_calls( + capsys: pytest.CaptureFixture[str], +) -> None: + """tutorial003: a cache hit, an expiry, and `cache_mode="refresh"` make four calls cost three fetches.""" await tutorial003.main() - assert capsys.readouterr().out == "1 tools, fresh for 60s, scope=public\n" + assert capsys.readouterr().out == "4 calls, 3 fetches\n" + + +def _counting_tools_server(*, ttl_ms: int | None = 60_000) -> tuple[Server[Any], list[str | None]]: + """Each tools/list fetch returns a distinct tool name, so a cache hit is + payload-distinguishable from a refetch; `ttl_ms=None` sends no hints.""" + fetches: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + fetches.append(params.cursor if params is not None else None) + return ListToolsResult(tools=[Tool(name=f"t{len(fetches) - 1}", input_schema={"type": "object"})]) + + hints: Mapping[CacheableMethod, CacheHint] | None = None + if ttl_ms is not None: + hints = {"tools/list": CacheHint(ttl_ms=ttl_ms)} + return Server("counting", on_list_tools=list_tools, cache_hints=hints), fetches + + +async def test_caching_is_on_by_default_the_second_call_makes_no_fetch() -> None: + server, fetches = _counting_tools_server() + async with Client(server) as client: + first = await client.list_tools() + second = await client.list_tools() + assert fetches == [None] + assert second == first + + +async def test_a_hintless_result_is_not_cached_by_default() -> None: + """`default_ttl_ms` defaults to 0, so a hintless server sees its usual call-for-call traffic.""" + server, fetches = _counting_tools_server(ttl_ms=None) + async with Client(server) as client: + await client.list_tools() + await client.list_tools() + assert fetches == [None, None] + + +async def test_cache_false_makes_every_call_a_round_trip() -> None: + server, fetches = _counting_tools_server() + async with Client(server, cache=False) as client: + await client.list_tools() + await client.list_tools() + assert fetches == [None, None] + + +async def test_refresh_refetches_and_replaces_the_cached_entry() -> None: + server, fetches = _counting_tools_server() + async with Client(server) as client: + await client.list_tools() + refreshed = await client.list_tools(cache_mode="refresh") + served = await client.list_tools() + assert fetches == [None, None] + assert [tool.name for tool in refreshed.tools] == ["t1"] + assert served == refreshed + + +async def test_bypass_fetches_without_reading_or_writing_the_cache() -> None: + server, fetches = _counting_tools_server() + async with Client(server) as client: + first = await client.list_tools() + bypassed = await client.list_tools(cache_mode="bypass") + served = await client.list_tools() + assert fetches == [None, None] + assert [tool.name for tool in bypassed.tools] == ["t1"] + assert served == first + + +async def test_an_expired_entry_is_not_revived_when_the_refetch_fails() -> None: + """SDK ruling: no stale-if-error - the refetch failure propagates.""" + now = 1_000_000.0 + fetches: list[None] = [] + + async def list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + fetches.append(None) + if len(fetches) > 1: + raise MCPError(code=INTERNAL_ERROR, message="backend down") + return ListToolsResult(tools=[Tool(name="t0", input_schema={"type": "object"})]) + + server = Server("flaky", on_list_tools=list_tools, cache_hints={"tools/list": CacheHint(ttl_ms=60_000)}) + async with Client(server, cache=CacheConfig(clock=lambda: now)) as client: + await client.list_tools() + now += 60.0 # past the 60s TTL + with pytest.raises(MCPError) as exc: + await client.list_tools() + assert exc.value.code == INTERNAL_ERROR + assert len(fetches) == 2 + + +async def test_two_concurrent_identical_calls_are_two_fetches() -> None: + """SDK ruling: no coalescing. The handler barrier releases only once both + calls are inside it, so the test passes only if the fetches were concurrent.""" + both_fetching = anyio.Event() + fetches: list[None] = [] + + async def list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + fetches.append(None) + if len(fetches) == 2: + both_fetching.set() + with anyio.fail_after(5): + await both_fetching.wait() + return ListToolsResult(tools=[Tool(name="t", input_schema={"type": "object"})]) + + server = Server("concurrent", on_list_tools=list_tools, cache_hints={"tools/list": CacheHint(ttl_ms=60_000)}) + async with Client(server) as client: + async with anyio.create_task_group() as tg: + tg.start_soon(client.list_tools) + tg.start_soon(client.list_tools) + assert len(fetches) == 2 + + +async def test_a_session_tier_call_always_makes_the_round_trip() -> None: + """The cache lives on the `Client` verbs; `client.session` sits below it.""" + server, fetches = _counting_tools_server() + async with Client(server) as client: + await client.list_tools() + await client.session.list_tools() + assert fetches == [None, None] + + +async def test_a_custom_store_requires_a_partition() -> None: + with pytest.raises(ValueError) as exc: + CacheConfig(store=InMemoryResponseCacheStore()) + assert str(exc.value) == snapshot("a custom store requires an explicit partition") + + +async def test_a_custom_store_with_an_in_process_server_requires_target_id() -> None: + server, _ = _counting_tools_server() + with pytest.raises(ValueError) as exc: + Client(server, cache=CacheConfig(store=InMemoryResponseCacheStore(), partition="user-1")) + assert str(exc.value) == snapshot( + "a custom cache store requires CacheConfig.target_id when the server is not a URL: in-process servers " + "and Transport instances get a random per-client identity, so their entries in a shared store could " + "never be served to another client" + ) async def test_the_wire_presence_check_the_page_recommends_works() -> None: """The page's claim: `"ttl_ms" in result.model_fields_set` distinguishes a server that sent the field from one that said nothing (model defaults).""" - async with Client(tutorial003.mcp) as client: + async with Client(tutorial001.mcp) as client: tools = await client.list_tools() assert "ttl_ms" in tools.model_fields_set diff --git a/tests/interaction/transports/test_hosting_http_modern.py b/tests/interaction/transports/test_hosting_http_modern.py index a8f1f53c7..3feed4fed 100644 --- a/tests/interaction/transports/test_hosting_http_modern.py +++ b/tests/interaction/transports/test_hosting_http_modern.py @@ -511,7 +511,8 @@ async def test_modern_client_stops_mirroring_after_a_re_list_drops_the_tool() -> bad_schema = {"type": "object", "properties": {"a": {"type": "string", "x-mcp-header": "bad name"}}} valid = Tool(name="run", input_schema=schema) invalid = Tool(name="run", input_schema=bad_schema) - listings = iter([valid, invalid]) + # Three pages: the call after the drop re-lists once because the prune also cleared `run`'s schema entry. + listings = iter([valid, invalid, invalid]) async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: return ListToolsResult(tools=[next(listings)], ttl_ms=0, cache_scope="public") diff --git a/tests/server/test_caching.py b/tests/server/test_caching.py index 46701d659..abfcfba97 100644 --- a/tests/server/test_caching.py +++ b/tests/server/test_caching.py @@ -1,40 +1,27 @@ """`mcp.server.caching`: `CacheHint` validation, per-field fills, and the `cache_hints` constructor map reaching the wire on both server tiers.""" -from types import UnionType -from typing import Any, cast, get_args +from typing import Any, cast import pytest from inline_snapshot import snapshot from mcp_types import ( - CacheableResult, + InputRequiredResult, ListResourcesResult, ListToolsResult, PaginatedRequestParams, + ReadResourceRequestParams, Resource, Tool, - methods, ) from mcp import Client from mcp.server import CacheHint, MCPServer, Server, ServerRequestContext -from mcp.server.caching import CACHEABLE_METHODS, apply_cache_hint +from mcp.server.caching import apply_cache_hint pytestmark = pytest.mark.anyio -def test_cacheable_methods_match_the_result_models() -> None: - """Spec-mandated set (SEP-2549): `CACHEABLE_METHODS` mirrors exactly the - methods whose monolith result models mix in `CacheableResult` - if the - schema gains or loses a cacheable result, this weld breaks.""" - derived: set[str] = set() - for method, model in methods.MONOLITH_RESULTS.items(): - arms = get_args(model) if isinstance(model, UnionType) else (model,) - if any(isinstance(arm, type) and issubclass(arm, CacheableResult) for arm in arms): - derived.add(method) - assert CACHEABLE_METHODS == derived - - def test_cache_hint_defaults_match_the_conservative_model_defaults() -> None: """SDK-defined: an unconfigured hint fills the same values the result models already default to - immediately stale, not shared - so stamping it is @@ -83,7 +70,7 @@ def test_a_non_cacheable_method_in_cache_hints_is_rejected_at_server_constructio with pytest.raises(ValueError) as exc: Server("srv", cache_hints=cast(Any, {"tools/call": CacheHint()})) assert str(exc.value) == snapshot( - "cache_hints keys must be cacheable methods (see CacheableMethod); got: tools/call" + "cache_hints keys must be cacheable methods (see CacheableMethod); got: 'tools/call'" ) @@ -96,6 +83,72 @@ def test_a_non_cache_hint_value_is_rejected_at_server_construction() -> None: assert str(exc.value) == snapshot("cache_hints['tools/list'] must be a CacheHint, got dict") +def test_a_non_string_cache_hints_key_is_rejected_with_the_unknown_key_error() -> None: + """A non-string key takes the same unknown-key ValueError as a typo, not a TypeError from message formatting.""" + with pytest.raises(ValueError) as exc: + Server("srv", cache_hints=cast(Any, {42: CacheHint()})) + assert str(exc.value) == snapshot("cache_hints keys must be cacheable methods (see CacheableMethod); got: 42") + + +async def test_a_dict_returning_handler_takes_the_configured_hint() -> None: + """The stamp covers raw-dict results too - 2026-07-28 requires both fields on the wire.""" + hint = CacheHint(ttl_ms=60_000, scope="public") + + async def list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams) -> dict[str, Any]: + return {"tools": [], "resultType": "complete"} + + server = Server("srv", cache_hints={"tools/list": hint}) + server.add_request_handler("tools/list", PaginatedRequestParams, list_tools) + async with Client(server) as client: + result = await client.list_tools() + assert result.ttl_ms == hint.ttl_ms + assert result.cache_scope == hint.scope + + +async def test_a_dict_provided_ttl_wins_and_the_hint_fills_only_the_missing_scope() -> None: + """Dict path mirrors the model path's `model_fields_set` precedence: present wire keys win.""" + + async def list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams) -> dict[str, Any]: + return {"tools": [], "resultType": "complete", "ttlMs": 25} + + server = Server("srv", cache_hints={"tools/list": CacheHint(ttl_ms=60_000, scope="public")}) + server.add_request_handler("tools/list", PaginatedRequestParams, list_tools) + async with Client(server) as client: + result = await client.list_tools() + assert result.ttl_ms == 25 + assert result.cache_scope == "public" + + +async def test_a_dict_returning_handler_leaks_no_hint_fields_to_a_2025_session() -> None: + """The stamp runs version-independently; the 2025 serialize sieve strips the fields.""" + + async def list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams) -> dict[str, Any]: + return {"tools": []} + + server = Server("srv", cache_hints={"tools/list": CacheHint(ttl_ms=60_000, scope="public")}) + server.add_request_handler("tools/list", PaginatedRequestParams, list_tools) + async with Client(server, mode="legacy") as client: + result = await client.list_tools() + assert "ttl_ms" not in result.model_fields_set + assert "cache_scope" not in result.model_fields_set + + +async def test_an_input_required_shaped_dict_is_never_stamped() -> None: + """Spec carve-out: interim `input_required` results carry no cache hints, even on a hinted method.""" + + async def read_resource(ctx: ServerRequestContext[Any], params: ReadResourceRequestParams) -> dict[str, Any]: + return {"resultType": "input_required", "requestState": "s1"} + + server = Server("srv", cache_hints={"resources/read": CacheHint(ttl_ms=60_000, scope="public")}) + server.add_request_handler("resources/read", ReadResourceRequestParams, read_resource) + async with Client(server) as client: + result = await client.session.read_resource("res://x", allow_input_required=True) + assert isinstance(result, InputRequiredResult) + assert result.model_dump(by_alias=True, exclude_none=True) == snapshot( + {"resultType": "input_required", "requestState": "s1"} + ) + + async def test_server_cache_hints_reach_the_wire_for_a_bare_handler_result() -> None: """SDK-defined: a lowlevel handler that never thinks about caching emits the server-wide hint configured at construction.""" diff --git a/tests/types/test_methods.py b/tests/types/test_methods.py index 79ea067c6..342720c32 100644 --- a/tests/types/test_methods.py +++ b/tests/types/test_methods.py @@ -548,6 +548,11 @@ def test_built_in_maps_are_immutable(): _assign_item(built_in) +def test_cacheable_methods_mirror_the_cacheable_method_literal(): + """SEP-2549 weld: the hand-written Literal and the set derived from `MONOLITH_RESULTS` must agree.""" + assert methods.CACHEABLE_METHODS == frozenset(get_args(methods.CacheableMethod)) + + def test_minimal_request_bodies_parse_through_every_request_row(): for (method, version), surface_type in methods.CLIENT_REQUESTS.items(): parsed = methods.parse_client_request(method, version, REQUEST_PARAMS_FIXTURES[surface_type])