Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion java/src/main/java/com/github/copilot/CopilotSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ public final class CopilotSession implements AutoCloseable {
private final Set<Consumer<SessionEvent>> eventHandlers = ConcurrentHashMap.newKeySet();
private final Map<String, ToolDefinition> toolHandlers = new ConcurrentHashMap<>();
private final Map<String, CommandHandler> commandHandlers = new ConcurrentHashMap<>();
private final Map<String, com.github.copilot.rpc.AbortSignal> activeToolSignals = new ConcurrentHashMap<>();
private final AtomicReference<PermissionHandler> permissionHandler = new AtomicReference<>();
private final AtomicReference<UserInputHandler> userInputHandler = new AtomicReference<>();
private final AtomicReference<ElicitationHandler> elicitationHandler = new AtomicReference<>();
Expand Down Expand Up @@ -882,15 +883,19 @@ private void handleBroadcastEventAsync(SessionEvent event) {
*/
private void executeToolAndRespondAsync(String requestId, String toolName, String toolCallId, Object arguments,
ToolDefinition tool) {
var signal = new com.github.copilot.rpc.AbortSignal();
activeToolSignals.put(requestId, signal);
Runnable task = () -> {
try {
JsonNode argumentsNode = arguments instanceof JsonNode jn
? jn
: (arguments != null ? MAPPER.valueToTree(arguments) : null);
var invocation = new com.github.copilot.rpc.ToolInvocation().setSessionId(sessionId)
.setToolCallId(toolCallId).setToolName(toolName).setArguments(argumentsNode);
.setToolCallId(toolCallId).setToolName(toolName).setArguments(argumentsNode)
.setAbortSignal(signal);

tool.handler().invoke(invocation).thenAccept(result -> {
activeToolSignals.remove(requestId);
try {
ToolResultObject toolResult;
if (result instanceof ToolResultObject tr) {
Expand All @@ -905,6 +910,7 @@ private void executeToolAndRespondAsync(String requestId, String toolName, Strin
LOG.log(Level.WARNING, "Error sending tool result for requestId=" + requestId, e);
}
}).exceptionally(ex -> {
activeToolSignals.remove(requestId);
try {
getRpc().tools.handlePendingToolCall(new SessionToolsHandlePendingToolCallParams(sessionId,
requestId, null, ex.getMessage() != null ? ex.getMessage() : ex.toString()));
Expand All @@ -914,6 +920,7 @@ private void executeToolAndRespondAsync(String requestId, String toolName, Strin
return null;
});
} catch (Exception e) {
activeToolSignals.remove(requestId);
LOG.log(Level.WARNING, "Error executing tool for requestId=" + requestId, e);
try {
getRpc().tools.handlePendingToolCall(new SessionToolsHandlePendingToolCallParams(sessionId,
Expand Down Expand Up @@ -1796,6 +1803,9 @@ public CompletableFuture<List<SessionEvent>> getMessages() {
*/
public CompletableFuture<Void> abort() {
ensureNotTerminated();
for (com.github.copilot.rpc.AbortSignal signal : activeToolSignals.values()) {
signal.abort();
}
return rpc.invoke("session.abort", Map.of("sessionId", sessionId), Void.class);
}

Expand Down Expand Up @@ -2136,6 +2146,7 @@ public void close() {
eventHandlers.clear();
toolHandlers.clear();
commandHandlers.clear();
activeToolSignals.clear();
permissionHandler.set(null);
userInputHandler.set(null);
elicitationHandler.set(null);
Expand Down
113 changes: 113 additions & 0 deletions java/src/main/java/com/github/copilot/rpc/AbortSignal.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
*--------------------------------------------------------------------------------------------*/

package com.github.copilot.rpc;

import java.util.List;
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* A signal that indicates whether a tool invocation has been aborted.
* <p>
* An {@code AbortSignal} is passed to tool handlers via
* {@link ToolInvocation#getAbortSignal()} and is triggered when
* {@link com.github.copilot.CopilotSession#abort()} is called while the tool is
* executing. Tool handlers can use this to implement cooperative cancellation,
* allowing them to stop long-running work gracefully when the session is aborted.
*
* <h2>Example Usage</h2>
*
* <pre>{@code
* ToolHandler handler = invocation -> {
* AbortSignal signal = invocation.getAbortSignal();
* return CompletableFuture.supplyAsync(() -> {
* while (!signal.isAborted()) {
* // do incremental work here
* }
* throw new CancellationException("Tool aborted");
* });
* };
* }</pre>
*
* <h2>Callback Registration</h2>
*
* <pre>{@code
* ToolHandler handler = invocation -> {
* AbortSignal signal = invocation.getAbortSignal();
* signal.onAborted(() -> System.out.println("Aborting tool!"));
* // ... perform work ...
* return CompletableFuture.completedFuture("done");
* };
* }</pre>
*
* @see ToolInvocation#getAbortSignal()
* @see com.github.copilot.CopilotSession#abort()
* @since 1.6.0
*/
public final class AbortSignal {

private final AtomicBoolean aborted = new AtomicBoolean(false);
private final List<Runnable> listeners = new CopyOnWriteArrayList<>();

/**
* Returns whether this signal has been aborted.
*
* @return {@code true} if {@link com.github.copilot.CopilotSession#abort()} was
* called while this tool invocation was in progress; {@code false}
* otherwise
*/
public boolean isAborted() {
return aborted.get();
}

/**
* Registers a callback to be invoked when this signal is aborted.
* <p>
* If the signal is already aborted at the time of registration, the callback is
* invoked immediately on the calling thread.
* <p>
* Exceptions thrown by the callback are silently ignored.
Comment thread
gimenete marked this conversation as resolved.
Outdated
*
* @param listener
* the callback to invoke on abort
* @throws NullPointerException
* if listener is null
*/
public void onAborted(Runnable listener) {
Objects.requireNonNull(listener, "listener must not be null");
listeners.add(listener);
if (aborted.get()) {
try {
listener.run();
} catch (Exception ignored) {
// Exceptions from listeners are silently ignored
}
Comment thread
gimenete marked this conversation as resolved.
Outdated
}
}
Comment thread
gimenete marked this conversation as resolved.

/**
* Triggers this abort signal, notifying all registered listeners.
* <p>
* <strong>Note:</strong> This method is intended for internal SDK use only.
* It is called by the SDK when
* {@link com.github.copilot.CopilotSession#abort()} is invoked while this tool
* invocation is in progress.
* <p>
* Calling this method more than once has no effect — the signal fires exactly
* once.
*/
public void abort() {
if (aborted.compareAndSet(false, true)) {
for (Runnable listener : listeners) {
try {
listener.run();
} catch (Exception ignored) {
// Exceptions from listeners are silently ignored
}
}
}
}
Comment thread
gimenete marked this conversation as resolved.
}
67 changes: 65 additions & 2 deletions java/src/main/java/com/github/copilot/rpc/ToolInvocation.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.util.Map;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonSetter;
import com.fasterxml.jackson.core.type.TypeReference;
Expand All @@ -16,11 +17,28 @@
* Represents a tool invocation request from the AI assistant.
* <p>
* When the assistant invokes a tool, this object contains the context including
* the session ID, tool call ID, tool name, and arguments parsed from the
* assistant's request.
* the session ID, tool call ID, tool name, arguments parsed from the
* assistant's request, and an {@link AbortSignal} that is triggered when
* {@link com.github.copilot.CopilotSession#abort()} is called while the tool is
* executing.
*
* <h2>Cooperative Cancellation</h2>
*
* <pre>{@code
* ToolHandler handler = invocation -> {
* AbortSignal signal = invocation.getAbortSignal();
* return CompletableFuture.supplyAsync(() -> {
* while (!signal.isAborted()) {
* // do incremental work here
* }
* throw new CancellationException("Tool aborted");
* });
* };
* }</pre>
*
* @see ToolHandler
* @see ToolDefinition
* @see AbortSignal
* @since 1.0.0
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
Expand All @@ -34,6 +52,7 @@ public final class ToolInvocation {
private String toolCallId;
private String toolName;
private JsonNode argumentsNode;
private AbortSignal abortSignal = new AbortSignal();
Comment thread
gimenete marked this conversation as resolved.

/**
* Gets the session ID where the tool was invoked.
Expand Down Expand Up @@ -168,4 +187,48 @@ public ToolInvocation setArguments(JsonNode arguments) {
this.argumentsNode = arguments;
return this;
}

/**
* Returns the abort signal for this tool invocation.
* <p>
* The signal is triggered when
* {@link com.github.copilot.CopilotSession#abort()} is called while this tool
* is executing. Use it to implement cooperative cancellation in your tool
* handler.
*
* <pre>{@code
* ToolHandler handler = invocation -> {
* AbortSignal signal = invocation.getAbortSignal();
* return CompletableFuture.supplyAsync(() -> {
* while (!signal.isAborted()) {
* // do incremental work here
* }
* throw new CancellationException("Tool aborted");
* });
* };
* }</pre>
*
* @return the abort signal; never {@code null}
* @see AbortSignal
* @since 1.6.0
*/
@JsonIgnore
public AbortSignal getAbortSignal() {
return abortSignal;
}
Comment thread
gimenete marked this conversation as resolved.

/**
* Sets the abort signal for this tool invocation.
* <p>
* <strong>Note:</strong> This method is intended for internal SDK use only.
* Users do not need to call this method directly.
*
* @param abortSignal
* the abort signal to associate with this invocation
* @return this invocation for method chaining
*/
public ToolInvocation setAbortSignal(AbortSignal abortSignal) {
this.abortSignal = abortSignal;
return this;
}
Comment thread
gimenete marked this conversation as resolved.
Comment thread
gimenete marked this conversation as resolved.
}
4 changes: 4 additions & 0 deletions java/src/main/java/com/github/copilot/rpc/package-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
* tool that can be invoked by the assistant.</li>
* <li>{@link com.github.copilot.rpc.ToolInvocation} - Represents a tool
* invocation request from the assistant.</li>
* <li>{@link com.github.copilot.rpc.AbortSignal} - Cancellation signal passed
* to tool handlers via {@link com.github.copilot.rpc.ToolInvocation#getAbortSignal()},
* triggered when {@link com.github.copilot.CopilotSession#abort()} is
* called.</li>
* <li>{@link com.github.copilot.rpc.Attachment} - File attachment for
* messages.</li>
* </ul>
Expand Down
73 changes: 73 additions & 0 deletions java/src/test/java/com/github/copilot/ToolInvocationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@

import static org.junit.jupiter.api.Assertions.*;

import java.util.concurrent.atomic.AtomicBoolean;

import org.junit.jupiter.api.Test;

import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.github.copilot.rpc.AbortSignal;
import com.github.copilot.rpc.ToolInvocation;

/**
Expand Down Expand Up @@ -127,6 +130,76 @@ void testGetArgumentsAsThrowsOnInvalidType() {
assertTrue(exception.getMessage().contains("StrictType"));
}

/**
* Test that getAbortSignal returns a non-null signal by default.
*/
@Test
void testGetAbortSignalReturnedByDefault() {
ToolInvocation invocation = new ToolInvocation().setSessionId("s1").setToolCallId("c1")
.setToolName("my_tool");
assertNotNull(invocation.getAbortSignal(), "getAbortSignal should not return null");
assertFalse(invocation.getAbortSignal().isAborted(), "signal should not be aborted by default");
}

/**
* Test that isAborted returns true after the signal is aborted.
*/
@Test
void testAbortSignalIsAbortedAfterAbort() {
AbortSignal signal = new AbortSignal();
assertFalse(signal.isAborted());
signal.abort();
assertTrue(signal.isAborted());
}

/**
* Test that onAborted callback is invoked when signal is aborted.
*/
@Test
void testAbortSignalOnAbortedCallbackInvoked() {
AbortSignal signal = new AbortSignal();
var called = new AtomicBoolean(false);
signal.onAborted(() -> called.set(true));
assertFalse(called.get());
signal.abort();
assertTrue(called.get());
}

/**
* Test that onAborted callback is invoked immediately if signal is already
* aborted.
*/
@Test
void testAbortSignalOnAbortedCallbackInvokedImmediatelyIfAlreadyAborted() {
AbortSignal signal = new AbortSignal();
signal.abort();
var called = new AtomicBoolean(false);
signal.onAborted(() -> called.set(true));
assertTrue(called.get(), "callback should be invoked immediately when signal is already aborted");
}

/**
* Test that abort() is idempotent — callbacks fire only once.
*/
@Test
void testAbortSignalAbortIsIdempotent() {
AbortSignal signal = new AbortSignal();
var count = new java.util.concurrent.atomic.AtomicInteger(0);
signal.onAborted(count::incrementAndGet);
signal.abort();
signal.abort();
assertEquals(1, count.get(), "callback should be invoked exactly once even if abort() called twice");
}

/**
* Test that onAborted throws NullPointerException for null listener.
*/
@Test
void testAbortSignalOnAbortedRejectsNullListener() {
AbortSignal signal = new AbortSignal();
assertThrows(NullPointerException.class, () -> signal.onAborted(null));
}

/**
* Record for testing type-safe argument deserialization.
*/
Expand Down