Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import java.net.http.HttpResponse;
import java.net.http.HttpResponse.BodySubscriber;
import java.net.http.HttpResponse.ResponseInfo;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -56,7 +59,7 @@ record AggregateResponseEvent(ResponseInfo responseInfo, String data) implements

static BodySubscriber<Void> sseToBodySubscriber(ResponseInfo responseInfo, FluxSink<ResponseEvent> sink) {
return HttpResponse.BodySubscribers
.fromLineSubscriber(FlowAdapters.toFlowSubscriber(new SseLineSubscriber(responseInfo, sink)));
.fromSubscriber(FlowAdapters.toFlowSubscriber(new SseByteSubscriber(responseInfo, sink)));
}

static BodySubscriber<Void> aggregateBodySubscriber(ResponseInfo responseInfo, FluxSink<ResponseEvent> sink) {
Expand All @@ -69,56 +72,33 @@ static BodySubscriber<Void> bodilessBodySubscriber(ResponseInfo responseInfo, Fl
.fromLineSubscriber(FlowAdapters.toFlowSubscriber(new BodilessResponseLineSubscriber(responseInfo, sink)));
}

static class SseLineSubscriber extends BaseSubscriber<String> {
static class SseByteSubscriber extends BaseSubscriber<List<ByteBuffer>> {

/**
* Pattern to extract data content from SSE "data:" lines.
*/
private static final Pattern EVENT_DATA_PATTERN = Pattern.compile("^data:(.+)$", Pattern.MULTILINE);

/**
* Pattern to extract event ID from SSE "id:" lines.
*/
private static final Pattern EVENT_ID_PATTERN = Pattern.compile("^id:(.+)$", Pattern.MULTILINE);

/**
* Pattern to extract event type from SSE "event:" lines.
*/
private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE);

/**
* The sink for emitting parsed response events.
*/
private final FluxSink<ResponseEvent> sink;

/**
* StringBuilder for accumulating multi-line event data.
*/
private final StringBuilder eventBuilder;

/**
* Current event's ID, if specified.
*/
private final AtomicReference<String> currentEventId;

/**
* Current event's type, if specified.
*/
private final AtomicReference<String> currentEventType;

/**
* The response information from the HTTP response. Send with each event to
* provide context.
*/
private ResponseInfo responseInfo;
private final ResponseInfo responseInfo;

/**
* Creates a new LineSubscriber that will emit parsed SSE events to the provided
* sink.
* @param sink the {@link FluxSink} to emit parsed {@link ResponseEvent} objects
* to
*/
public SseLineSubscriber(ResponseInfo responseInfo, FluxSink<ResponseEvent> sink) {
private final SseByteBuffer buffer = new SseByteBuffer();

private volatile boolean hasRequestedDemand = false;

private int scanIndex = 0;

private int start = 0;

public SseByteSubscriber(ResponseInfo responseInfo, FluxSink<ResponseEvent> sink) {
this.sink = sink;
this.eventBuilder = new StringBuilder();
this.currentEventId = new AtomicReference<>();
Expand All @@ -128,21 +108,71 @@ public SseLineSubscriber(ResponseInfo responseInfo, FluxSink<ResponseEvent> sink

@Override
protected void hookOnSubscribe(Subscription subscription) {

sink.onRequest(n -> {
subscription.request(n);
if (!hasRequestedDemand) {
subscription.request(Long.MAX_VALUE);
}
hasRequestedDemand = true;
});

// Register disposal callback to cancel subscription when Flux is disposed
sink.onDispose(() -> {
subscription.cancel();
});
}

@Override
protected void hookOnNext(String line) {
protected void hookOnNext(List<ByteBuffer> buffers) {
for (ByteBuffer b : buffers) {
int remaining = b.remaining();
if (remaining > 0) {
byte[] bytes = new byte[remaining];
b.get(bytes);
buffer.append(bytes, 0, remaining);
}
}
parseBuffer();
}

private void parseBuffer() {
byte[] buf = buffer.getBuf();
int count = buffer.getCount();

while (scanIndex < count) {
byte b = buf[scanIndex];
if (b == '\n') {
int lineEnd = scanIndex;
int terminatorLen = 1;
processLine(buf, start, lineEnd);
start = lineEnd + terminatorLen;
scanIndex = start;
}
else if (b == '\r') {
if (scanIndex + 1 < count) {
int lineEnd = scanIndex;
int terminatorLen = (buf[scanIndex + 1] == '\n') ? 2 : 1;
processLine(buf, start, lineEnd);
start = lineEnd + terminatorLen;
scanIndex = start;
}
else {
break;
}
}
else {
scanIndex++;
}
}

if (start > 0) {
buffer.shift(start);
scanIndex -= start;
start = 0;
}
}

private void processLine(byte[] buf, int start, int end) {
String line = new String(buf, start, end - start, StandardCharsets.UTF_8);
if (line.isEmpty()) {
// Empty line means end of event
if (this.eventBuilder.length() > 0) {
String eventData = this.eventBuilder.toString();
SseEvent sseEvent = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim());
Expand All @@ -157,39 +187,47 @@ protected void hookOnNext(String line) {
if (matcher.find()) {
this.eventBuilder.append(matcher.group(1).trim()).append("\n");
}
upstream().request(1);
}
else if (line.startsWith("id:")) {
var matcher = EVENT_ID_PATTERN.matcher(line);
if (matcher.find()) {
this.currentEventId.set(matcher.group(1).trim());
}
upstream().request(1);
}
else if (line.startsWith("event:")) {
var matcher = EVENT_TYPE_PATTERN.matcher(line);
if (matcher.find()) {
this.currentEventType.set(matcher.group(1).trim());
}
upstream().request(1);
}
else if (line.startsWith(":")) {
// Ignore comment lines starting with ":"
// This is a no-op, just to skip comments
logger.debug("Ignoring comment line: {}", line);
upstream().request(1);
}
else {
// If the response is not successful, emit an error
this.sink.error(new McpTransportException(
"Invalid SSE response. Status code: " + this.responseInfo.statusCode() + " Line: " + line));

}
}
}

@Override
protected void hookOnComplete() {
byte[] buf = buffer.getBuf();
int count = buffer.getCount();

// If we broke out of the loop because of a trailing '\r' at the end of the
// stream,
// treat it as a bare '\r' line terminator now.
if (scanIndex < count && buf[scanIndex] == '\r') {
int lineEnd = scanIndex;
int terminatorLen = 1;
processLine(buf, start, lineEnd);
start = lineEnd + terminatorLen;
}

if (start < count) {
processLine(buf, start, count);
}
if (this.eventBuilder.length() > 0) {
String eventData = this.eventBuilder.toString();
SseEvent sseEvent = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim());
Expand All @@ -205,6 +243,52 @@ protected void hookOnError(Throwable throwable) {

}

private static class SseByteBuffer {

private byte[] buf = new byte[4096];

private int count = 0;

public void append(byte[] b, int off, int len) {
ensureCapacity(count + len);
System.arraycopy(b, off, buf, count, len);
count += len;
}

private void ensureCapacity(int minCapacity) {
if (minCapacity - buf.length > 0) {
int newCapacity = buf.length * 2;
if (newCapacity - minCapacity < 0) {
newCapacity = minCapacity;
}
byte[] newBuf = new byte[newCapacity];
System.arraycopy(buf, 0, newBuf, 0, count);
buf = newBuf;
}
}

public byte[] getBuf() {
return buf;
}

public int getCount() {
return count;
}

public void shift(int bytesToShift) {
if (bytesToShift <= 0) {
return;
}
if (bytesToShift >= count) {
count = 0;
return;
}
System.arraycopy(buf, bytesToShift, buf, 0, count - bytesToShift);
count -= bytesToShift;
}

}

static class AggregateSubscriber extends BaseSubscriber<String> {

/**
Expand Down
Loading