diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index d6b05e066..6de2edb8d 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -16,6 +16,52 @@ from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER +def _split_www_authenticate_segments(header_value: str) -> list[str]: + """Split a WWW-Authenticate header on top-level commas.""" + segments: list[str] = [] + current: list[str] = [] + in_quotes = False + + for char in header_value: + if char == '"': + in_quotes = not in_quotes + if char == "," and not in_quotes: + segment = "".join(current).strip() + if segment: + segments.append(segment) + current = [] + continue + current.append(char) + + tail = "".join(current).strip() + if tail: + segments.append(tail) + return segments + + +def _extract_bearer_auth_params(www_auth_header: str) -> str | None: + """Return the auth-param portion of the first Bearer challenge.""" + segments = _split_www_authenticate_segments(www_auth_header) + collecting = False + auth_params: list[str] = [] + + for segment in segments: + scheme, separator, remainder = segment.partition(" ") + if scheme.lower() == "bearer" and separator: + collecting = True + auth_params = [remainder.strip()] + continue + + if collecting: + if separator and "=" not in scheme: + break + auth_params.append(segment) + + if not auth_params: + return None + return ", ".join(part for part in auth_params if part) + + def extract_field_from_www_auth(response: Response, field_name: str) -> str | None: """Extract field from WWW-Authenticate header. @@ -26,13 +72,16 @@ def extract_field_from_www_auth(response: Response, field_name: str) -> str | No if not www_auth_header: return None - # Pattern matches: field_name="value" or field_name=value (unquoted) - pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))' - match = re.search(pattern, www_auth_header) + auth_params = _extract_bearer_auth_params(www_auth_header) + if auth_params is None: + return None - if match: - # Return quoted value if present, otherwise unquoted value - return match.group(1) or match.group(2) + # Match comma-delimited auth-params while respecting quoted values. + pattern = re.compile(r'(?:^|,\s*)(?P[A-Za-z][A-Za-z0-9_-]*)=(?:"(?P[^"]+)"|(?P[^,\s]+))') + for match in pattern.finditer(auth_params): + if match.group("name") == field_name: + # Return quoted value if present, otherwise unquoted value + return match.group("quoted") or match.group("unquoted") return None diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 1ec38ccf6..f42d953bb 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -2016,6 +2016,17 @@ class TestWWWAuthenticate: "resource_metadata", "https://api.example.com/auth/metadata?version=1", ), + ('Bearer error_scope="decoy", scope="read write"', "scope", "read write"), + ( + 'Bearer error_description="missing scope=write permission", scope="read write"', + "scope", + "read write", + ), + ( + 'Basic realm="legacy", Bearer scope="read write", error="insufficient_scope"', + "scope", + "read write", + ), ], ) def test_extract_field_from_www_auth_valid_cases( @@ -2047,6 +2058,12 @@ def test_extract_field_from_www_auth_valid_cases( # Header without requested field ('Bearer realm="api", error="insufficient_scope"', "scope", "no scope parameter"), ('Bearer realm="api", scope="read write"', "resource_metadata", "no resource_metadata parameter"), + ('Bearer custom_scope="leaked"', "scope", "field name appears only as a substring"), + ( + 'Bearer x_resource_metadata="https://decoy.example.com"', + "resource_metadata", + "field name appears only as a substring", + ), # Malformed field (empty value) ("Bearer scope=", "scope", "malformed scope parameter"), ("Bearer resource_metadata=", "resource_metadata", "malformed resource_metadata parameter"),