From 8baba5e3502e25b56a94591a7672ece8210c65a5 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 15 Jan 2026 16:16:44 -0800 Subject: [PATCH] feat: implement support for SEP-1686 Tasks --- .../client/McpAsyncClient.java | 792 ++++- .../client/McpClient.java | 156 +- .../client/McpClientFeatures.java | 93 +- .../client/McpSyncClient.java | 268 +- .../HttpClientStreamableHttpTransport.java | 18 +- ...ractTaskAwareToolSpecificationBuilder.java | 164 + .../experimental/tasks/CreateTaskExtra.java | 210 ++ .../experimental/tasks/CreateTaskHandler.java | 60 + .../experimental/tasks/CreateTaskOptions.java | 190 ++ .../tasks/DefaultCreateTaskExtra.java | 94 + .../tasks/DefaultSyncCreateTaskExtra.java | 94 + .../tasks/DefaultTaskContext.java | 151 + .../tasks/GetTaskFromStoreResult.java | 56 + .../experimental/tasks/GetTaskHandler.java | 49 + .../tasks/GetTaskResultHandler.java | 51 + .../tasks/InMemoryTaskMessageQueue.java | 126 + .../experimental/tasks/InMemoryTaskStore.java | 684 +++++ .../experimental/tasks/QueuedMessage.java | 50 + .../tasks/SyncCreateTaskExtra.java | 204 ++ .../tasks/SyncCreateTaskHandler.java | 67 + .../tasks/SyncGetTaskHandler.java | 36 + .../tasks/SyncGetTaskResultHandler.java | 37 + .../TaskAwareAsyncToolSpecification.java | 336 ++ .../tasks/TaskAwareSyncToolSpecification.java | 258 ++ .../experimental/tasks/TaskContext.java | 134 + .../experimental/tasks/TaskDefaults.java | 95 + .../experimental/tasks/TaskHelper.java | 110 + .../experimental/tasks/TaskMessageQueue.java | 74 + .../experimental/tasks/TaskStore.java | 264 ++ .../experimental/tasks/package-info.java | 121 + .../server/McpAsyncServer.java | 637 +++- .../server/McpAsyncServerExchange.java | 749 ++++- .../server/McpServer.java | 369 ++- .../server/McpServerFeatures.java | 81 +- .../server/McpStatelessServerFeatures.java | 8 +- .../server/McpSyncServer.java | 71 + .../server/McpSyncServerExchange.java | 254 +- .../modelcontextprotocol/spec/McpSchema.java | 2717 ++++++++++++++++- ...bstractTaskAwareToolSpecificationTest.java | 191 ++ .../tasks/InMemoryTaskMessageQueueTests.java | 288 ++ .../tasks/InMemoryTaskStoreTests.java | 1277 ++++++++ .../TaskAwareAsyncToolSpecificationTest.java | 290 ++ .../TaskAwareSyncToolSpecificationTest.java | 232 ++ .../experimental/tasks/TaskHelperTests.java | 87 + .../experimental/tasks/TaskTestUtils.java | 126 + .../server/AbstractMcpAsyncServerTests.java | 365 +++ ...stractMcpClientServerIntegrationTests.java | 869 ++++++ .../server/AbstractMcpSyncServerTests.java | 266 ++ .../server/McpAsyncServerExchangeTests.java | 297 ++ .../WebClientStreamableHttpTransport.java | 8 +- ...stractMcpClientServerIntegrationTests.java | 687 ++++- .../experimental/tasks/TaskTestUtils.java | 330 ++ .../server/AbstractMcpAsyncServerTests.java | 375 ++- .../server/AbstractMcpSyncServerTests.java | 273 +- 54 files changed, 15686 insertions(+), 203 deletions(-) create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/AbstractTaskAwareToolSpecificationBuilder.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskExtra.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskHandler.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskOptions.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/DefaultCreateTaskExtra.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/DefaultSyncCreateTaskExtra.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/DefaultTaskContext.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/GetTaskFromStoreResult.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/GetTaskHandler.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/GetTaskResultHandler.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskMessageQueue.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskStore.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/QueuedMessage.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncCreateTaskExtra.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncCreateTaskHandler.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncGetTaskHandler.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncGetTaskResultHandler.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskAwareAsyncToolSpecification.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskAwareSyncToolSpecification.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskContext.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskDefaults.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskHelper.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskMessageQueue.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskStore.java create mode 100644 mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/package-info.java create mode 100644 mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/AbstractTaskAwareToolSpecificationTest.java create mode 100644 mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskMessageQueueTests.java create mode 100644 mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskStoreTests.java create mode 100644 mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskAwareAsyncToolSpecificationTest.java create mode 100644 mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskAwareSyncToolSpecificationTest.java create mode 100644 mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskHelperTests.java create mode 100644 mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskTestUtils.java create mode 100644 mcp-test/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskTestUtils.java diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index e6a09cd08..281d18120 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -5,23 +5,30 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.time.Instant; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeParseException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; +import io.modelcontextprotocol.experimental.tasks.CreateTaskOptions; +import io.modelcontextprotocol.experimental.tasks.TaskDefaults; +import io.modelcontextprotocol.experimental.tasks.TaskStore; import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.json.schema.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; @@ -39,6 +46,8 @@ import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -106,6 +115,9 @@ public class McpAsyncClient { public static final TypeRef PROGRESS_NOTIFICATION_TYPE_REF = new TypeRef<>() { }; + public static final TypeRef TASK_STATUS_NOTIFICATION_TYPE_REF = new TypeRef<>() { + }; + public static final String NEGOTIATED_PROTOCOL_VERSION = "io.modelcontextprotocol.client.negotiated-protocol-version"; /** @@ -170,18 +182,33 @@ public class McpAsyncClient { */ private final boolean enableCallToolSchemaCaching; + /** + * Task store for client-side task hosting (experimental). When set, the client can + * host tasks for task-augmented sampling and elicitation requests from the server. + */ + private final TaskStore taskStore; + + /** + * Maximum duration to poll for task completion in callToolStream(). Null means use + * the default timeout (5 minutes) to prevent infinite polling. + */ + private final Duration taskPollTimeout; + /** * Create a new McpAsyncClient with the given transport and session request-response * timeout. * @param transport the transport to use. * @param requestTimeout the session request-response timeout. * @param initializationTimeout the max timeout to await for the client-server + * connection. * @param jsonSchemaValidator the JSON schema validator to use for validating tool - * @param features the MCP Client supported features. responses against output - * schemas. + * responses against output schemas. + * @param features the MCP Client supported features. + * @param taskStore the task store for managing task state. */ McpAsyncClient(McpClientTransport transport, Duration requestTimeout, Duration initializationTimeout, - JsonSchemaValidator jsonSchemaValidator, McpClientFeatures.Async features) { + JsonSchemaValidator jsonSchemaValidator, McpClientFeatures.Async features, + TaskStore taskStore) { Assert.notNull(transport, "Transport must not be null"); Assert.notNull(requestTimeout, "Request timeout must not be null"); @@ -194,6 +221,8 @@ public class McpAsyncClient { this.jsonSchemaValidator = jsonSchemaValidator; this.toolsOutputSchemaCache = new ConcurrentHashMap<>(); this.enableCallToolSchemaCaching = features.enableCallToolSchemaCaching(); + this.taskStore = taskStore; + this.taskPollTimeout = features.taskPollTimeout(); // Request Handlers Map> requestHandlers = new HashMap<>(); @@ -229,6 +258,18 @@ public class McpAsyncClient { requestHandlers.put(McpSchema.METHOD_ELICITATION_CREATE, elicitationCreateHandler()); } + // Task Handlers (for client-side task hosting) + if (this.taskStore != null && this.clientCapabilities.tasks() != null) { + requestHandlers.put(McpSchema.METHOD_TASKS_GET, clientTasksGetHandler()); + requestHandlers.put(McpSchema.METHOD_TASKS_RESULT, clientTasksResultHandler()); + if (this.clientCapabilities.tasks().list() != null) { + requestHandlers.put(McpSchema.METHOD_TASKS_LIST, clientTasksListHandler()); + } + if (this.clientCapabilities.tasks().cancel() != null) { + requestHandlers.put(McpSchema.METHOD_TASKS_CANCEL, clientTasksCancelHandler()); + } + } + // Notification Handlers Map notificationHandlers = new HashMap<>(); @@ -296,6 +337,16 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS, asyncProgressNotificationHandler(progressConsumersFinal)); + // Task Status Notification + List>> taskStatusConsumersFinal = new ArrayList<>(); + taskStatusConsumersFinal + .add((notification) -> Mono.fromRunnable(() -> logger.debug("Task status: {}", notification))); + if (!Utils.isEmpty(features.taskStatusConsumers())) { + taskStatusConsumersFinal.addAll(features.taskStatusConsumers()); + } + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_TASKS_STATUS, + asyncTaskStatusNotificationHandler(taskStatusConsumersFinal)); + Function> postInitializationHook = init -> { if (init.initializeResult().capabilities().tools() == null || !enableCallToolSchemaCaching) { @@ -387,6 +438,10 @@ public McpSchema.Implementation getClientInfo() { * Closes the client connection immediately. */ public void close() { + // Shutdown task store to clean up any scheduled cleanup tasks + if (this.taskStore != null) { + this.taskStore.shutdown().block(Duration.ofSeconds(5)); + } this.initializer.close(); this.transport.close(); } @@ -397,7 +452,10 @@ public void close() { */ public Mono closeGracefully() { return Mono.defer(() -> { - return this.initializer.closeGracefully().then(transport.closeGracefully()); + Mono taskStoreShutdown = this.taskStore != null ? this.taskStore.shutdown() : Mono.empty(); + return taskStoreShutdown.timeout(Duration.ofSeconds(5), Mono.empty()) + .then(this.initializer.closeGracefully()) + .then(transport.closeGracefully()); }); } @@ -541,29 +599,153 @@ private RequestHandler rootsListRequestHandler() { }; } + // -------------------------- + // Task-Augmented Request Support + // -------------------------- + + /** + * Executes a task-augmented request, creating a background task and returning + * immediately with a CreateTaskResult. This is a helper method that extracts the + * common logic for task-augmented sampling and elicitation requests. + * @param The result type (must implement ClientTaskPayloadResult) + * @param originatingRequest The original MCP request that triggered task creation + * @param taskMetadata The task metadata from the request + * @param handlerMono The handler execution that produces the result + * @param operationType Name of the operation for logging (e.g., "sampling", + * "elicitation") + * @return A Mono that emits CreateTaskResult immediately + */ + private Mono executeTaskAugmentedRequest( + McpSchema.Request originatingRequest, McpSchema.TaskMetadata taskMetadata, Mono handlerMono, + String operationType) { + return this.taskStore + .createTask(CreateTaskOptions.builder(originatingRequest).requestedTtl(taskMetadata.ttl()).build()) + .flatMap(task -> { + // Execute the handler in the background (fire-and-forget). + // The subscription completes naturally when the handler finishes. + // We don't track subscriptions because: + // 1. Disposing mid-flight causes data loss (result never stored) + // 2. Completed subscriptions are no-ops anyway + // 3. The taskStore handles cleanup via TTL expiration + handlerMono + .flatMap(result -> this.taskStore.storeTaskResult(task.taskId(), null, + McpSchema.TaskStatus.COMPLETED, result)) + .onErrorResume(error -> this.taskStore + .updateTaskStatus(task.taskId(), null, McpSchema.TaskStatus.FAILED, error.getMessage()) + .onErrorResume(storeError -> { + logger.error("Failed to update {} task status for {}: {}", operationType, task.taskId(), + storeError.getMessage()); + return Mono.empty(); + })) + .subscribe(unused -> { + // Background task completed successfully + }, error -> logger.error("Unexpected error in {} task {}: {}", operationType, task.taskId(), + error.getMessage())); + // Return CreateTaskResult immediately + return Mono.just((McpSchema.Result) McpSchema.CreateTaskResult.builder().task(task).build()); + }); + } + // -------------------------- // Sampling // -------------------------- - private RequestHandler samplingCreateMessageHandler() { + + private RequestHandler samplingCreateMessageHandler() { return params -> { McpSchema.CreateMessageRequest request = transport.unmarshalFrom(params, CREATE_MESSAGE_REQUEST_TYPE_REF); - return this.samplingHandler.apply(request); + // Check for task-augmented request + if (request.task() != null && this.taskStore != null) { + return executeTaskAugmentedRequest(request, request.task(), this.samplingHandler.apply(request), + "sampling"); + } + + // Non-task-augmented request - execute directly + return this.samplingHandler.apply(request).map(result -> (McpSchema.Result) result); }; } // -------------------------- // Elicitation // -------------------------- - private RequestHandler elicitationCreateHandler() { + + private RequestHandler elicitationCreateHandler() { return params -> { ElicitRequest request = transport.unmarshalFrom(params, new TypeRef<>() { }); - return this.elicitationHandler.apply(request); + // Check for task-augmented request + if (request.task() != null && this.taskStore != null) { + return executeTaskAugmentedRequest(request, request.task(), this.elicitationHandler.apply(request), + "elicitation"); + } + + // Non-task-augmented request - execute directly + return this.elicitationHandler.apply(request).map(result -> (McpSchema.Result) result); + }; + } + + // -------------------------- + // Client-Side Task Hosting + // -------------------------- + + /** + * Handler for tasks/get requests from the server (client-hosted tasks). + */ + private RequestHandler clientTasksGetHandler() { + return params -> { + McpSchema.GetTaskRequest request = transport.unmarshalFrom(params, new TypeRef<>() { + }); + return this.taskStore.getTask(request.taskId(), null) + .map(result -> result.task()) + .map(McpSchema.GetTaskResult::fromTask); + }; + } + + /** + * Handler for tasks/result requests from the server (client-hosted tasks). + */ + private RequestHandler clientTasksResultHandler() { + return params -> { + McpSchema.GetTaskPayloadRequest request = transport.unmarshalFrom(params, new TypeRef<>() { + }); + return this.taskStore.getTaskResult(request.taskId(), null).cast(McpSchema.Result.class); + }; + } + + /** + * Handler for tasks/list requests from the server (client-hosted tasks). + */ + private RequestHandler clientTasksListHandler() { + return params -> { + McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, new TypeRef<>() { + }); + return this.taskStore.listTasks(request != null ? request.cursor() : null, null); + }; + } + + /** + * Handler for tasks/cancel requests from the server (client-hosted tasks). + */ + private RequestHandler clientTasksCancelHandler() { + return params -> { + McpSchema.CancelTaskRequest request = transport.unmarshalFrom(params, new TypeRef<>() { + }); + return this.taskStore.requestCancellation(request.taskId(), null).map(McpSchema.CancelTaskResult::fromTask); }; } + /** + * Returns the task store used for client-side task hosting. + *

+ * Warning: This is an experimental API that may change in future + * releases. Use with caution in production environments. + * @return the task store, or null if client-side task hosting is not configured + */ + public TaskStore getTaskStore() { + return this.taskStore; + } + // -------------------------- // Tools // -------------------------- @@ -596,6 +778,370 @@ public Mono callTool(McpSchema.CallToolRequest callToo }); } + /** + * Low-level method that invokes a tool with task augmentation, creating a background + * task for long-running operations. + * + *

+ * Recommendation: For most use cases, prefer {@link #callToolStream} + * which provides a unified streaming interface that handles both regular and + * task-augmented tool calls automatically, including polling and result retrieval. + * + *

+ * When calling a tool with task augmentation, the server creates a task and returns + * immediately with a {@link McpSchema.CreateTaskResult} containing the task ID. The + * actual tool execution happens asynchronously. Use {@link #getTask} to poll for task + * status and {@link #getTaskResult} to retrieve the result once completed. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * + *

+ * Example usage (manual polling): + * + *

{@code
+	 * var request = new CallToolRequest("slow-operation", args,
+	 *     new TaskMetadata(60000L), null);  // 60s TTL
+	 * var createResult = client.callToolTask(request).block();
+	 * String taskId = createResult.task().taskId();
+	 *
+	 * // Poll until complete
+	 * while (true) {
+	 *     var task = client.getTask(taskId).block();
+	 *     if (task.isTerminal()) break;
+	 *     Thread.sleep(task.pollInterval());
+	 * }
+	 *
+	 * // Get result
+	 * var result = client.getTaskResult(taskId, new TypeRef<CallToolResult>(){}).block();
+	 * }
+ * @param callToolRequest The request containing the tool name, parameters, and task + * metadata. The {@code task} field must be non-null. + * @return A Mono that emits the task creation result containing the task ID and + * initial status. + * @throws IllegalArgumentException if the request does not include task metadata + * @see #callToolStream + * @see McpSchema.CallToolRequest + * @see McpSchema.CreateTaskResult + * @see McpSchema.TaskMetadata + * @see #getTask + * @see #getTaskResult + */ + public Mono callToolTask(McpSchema.CallToolRequest callToolRequest) { + if (callToolRequest.task() == null) { + return Mono.error(new IllegalArgumentException( + "Task metadata is required for task-augmented tool calls. Use callTool() for regular tool calls.")); + } + return this.initializer.withInitialization("calling tool with task", init -> { + if (init.initializeResult().capabilities().tools() == null) { + return Mono.error(new IllegalStateException("Server does not provide tools capability")); + } + var caps = init.initializeResult().capabilities(); + if (caps.tasks() == null || caps.tasks().requests() == null || caps.tasks().requests().tools() == null + || caps.tasks().requests().tools().call() == null) { + return Mono.error( + new IllegalStateException("Server does not provide tasks capability with tools.call support")); + } + return init.mcpSession() + .sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CREATE_TASK_RESULT_TYPE_REF); + }); + } + + /** + * The default poll interval in milliseconds for task status polling when the server + * does not specify one. + */ + private static final long DEFAULT_TASK_POLL_INTERVAL_MS = TaskDefaults.DEFAULT_POLL_INTERVAL_MS; + + /** + * Default timeout for task polling operations. If not explicitly configured via the + * builder, polling will timeout after this duration to prevent infinite loops. + */ + private static final Duration DEFAULT_TASK_POLL_TIMEOUT = Duration.ofMinutes(5); + + /** + * Calls a tool and returns a stream of response messages, handling both regular and + * task-augmented tool calls automatically. + * + *

+ * This method provides a unified streaming interface for tool execution: + *

    + *
  • For non-task requests (when {@code task} field is null): + * yields a single {@link McpSchema.ResultMessage} or {@link McpSchema.ErrorMessage} + *
  • For task-augmented requests: yields + * {@link McpSchema.TaskCreatedMessage} → zero or more + * {@link McpSchema.TaskStatusMessage} → {@link McpSchema.ResultMessage} or + * {@link McpSchema.ErrorMessage} + *
+ * + *

+ * This is the recommended way to call tools when you want to support both regular and + * long-running tool executions without having to handle the decision logic yourself. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * + *

+ * Example usage (Java 21+ with pattern matching for switch): + * + *

{@code
+	 * var request = new CallToolRequest("my-tool", Map.of("arg", "value"),
+	 *     new TaskMetadata(60000L), null);  // Optional task metadata
+	 *
+	 * client.callToolStream(request)
+	 *     .subscribe(message -> {
+	 *         switch (message) {
+	 *             case TaskCreatedMessage tc ->
+	 *                 log.info("Task created: {}", tc.task().taskId());
+	 *             case TaskStatusMessage ts ->
+	 *                 log.info("Status: {} - {}", ts.task().status(), ts.task().statusMessage());
+	 *             case ResultMessage r ->
+	 *                 handleResult(r.result());
+	 *             case ErrorMessage e ->
+	 *                 handleError(e.error());
+	 *         }
+	 *     });
+	 * }
+ * @param callToolRequest The request containing the tool name and arguments. If the + * {@code task} field is set, the call will be task-augmented. + * @return A Flux that emits {@link McpSchema.ResponseMessage} instances representing + * the progress and result of the tool call. + * @see McpSchema.ResponseMessage + * @see McpSchema.TaskCreatedMessage + * @see McpSchema.TaskStatusMessage + * @see McpSchema.ResultMessage + * @see McpSchema.ErrorMessage + */ + public Flux> callToolStream( + McpSchema.CallToolRequest callToolRequest) { + // For non-task requests, just wrap the result in a single message + if (callToolRequest.task() == null) { + return this + .callTool(callToolRequest).>map(McpSchema.ResultMessage::of) + .onErrorResume(error -> { + McpError mcpError = (error instanceof McpError) ? (McpError) error + : McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message(error.getMessage() != null ? error.getMessage() : "Unknown error") + .build(); + return Mono.just(McpSchema.ErrorMessage.of(mcpError)); + }) + .flux(); + } + + // For task-augmented requests, handle the full lifecycle + return Flux.create(sink -> { + // Cancellation flag to stop polling when subscriber cancels + AtomicBoolean cancelled = new AtomicBoolean(false); + + // Composite disposable to track all active subscriptions and prevent memory + // leaks + Disposable.Composite disposables = Disposables.composite(); + + sink.onCancel(() -> { + cancelled.set(true); + disposables.dispose(); + }); + sink.onDispose(() -> { + cancelled.set(true); + disposables.dispose(); + }); + + // Step 1: Create the task + Disposable createTaskSub = this.callToolTask(callToolRequest).subscribe(createResult -> { + // Check if cancelled before proceeding + if (cancelled.get()) { + return; + } + + McpSchema.Task task = createResult.task(); + if (task == null) { + sink.error(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Task creation did not return a task") + .build()); + return; + } + + // Emit taskCreated message + sink.next(McpSchema.TaskCreatedMessage.of(task)); + + // Step 2: Start polling for task status (record start time for timeout) + pollTaskUntilTerminal(task.taskId(), sink, Instant.now(), cancelled, disposables); + }, error -> { + McpError mcpError = (error instanceof McpError) ? (McpError) error + : McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message(error.getMessage() != null ? error.getMessage() : "Unknown error") + .build(); + sink.next(McpSchema.ErrorMessage.of(mcpError)); + sink.complete(); + }); + disposables.add(createTaskSub); + }); + } + + /** + * Polls task status until it reaches a terminal state, emitting status updates and + * final result. + * @param taskId the task ID to poll + * @param sink the sink to emit messages to + * @param startTime the time when polling started, used for timeout calculation + * @param cancelled cancellation flag - when true, polling should stop + * @param disposables composite disposable to track active subscriptions for cleanup + */ + private void pollTaskUntilTerminal(String taskId, + reactor.core.publisher.FluxSink> sink, + Instant startTime, AtomicBoolean cancelled, Disposable.Composite disposables) { + + // Check if cancelled before proceeding + if (cancelled.get()) { + return; + } + + // Check timeout (use configured timeout or default to prevent infinite loops) + Duration effectiveTimeout = this.taskPollTimeout != null ? this.taskPollTimeout : DEFAULT_TASK_POLL_TIMEOUT; + Duration elapsed = Duration.between(startTime, Instant.now()); + if (elapsed.compareTo(effectiveTimeout) > 0) { + sink.next(McpSchema.ErrorMessage.of(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Task polling timed out after " + effectiveTimeout) + .build())); + sink.complete(); + return; + } + + Disposable getTaskSub = this.getTask(McpSchema.GetTaskRequest.builder().taskId(taskId).build()) + .subscribe(taskResult -> { + // Check if cancelled before processing result + if (cancelled.get()) { + return; + } + + McpSchema.Task task = taskResult.toTask(); + + // Emit status update + sink.next(McpSchema.TaskStatusMessage.of(task)); + + // Check TTL enforcement - if task has expired based on server's TTL + if (task.ttl() != null && task.createdAt() != null) { + try { + Instant createdAt = Instant.parse(task.createdAt()); + Instant expiresAt = createdAt.plusMillis(task.ttl()); + if (Instant.now().isAfter(expiresAt)) { + sink.next(McpSchema.ErrorMessage.of(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Task TTL expired after " + task.ttl() + "ms") + .build())); + sink.complete(); + return; + } + } + catch (DateTimeParseException e) { + // Ignore TTL check errors and continue polling + } + } + + // Check if terminal + if (task.isTerminal()) { + handleTerminalTask(taskId, task, sink, cancelled, disposables); + } + else if (task.status() == McpSchema.TaskStatus.INPUT_REQUIRED) { + // For input_required, call tasks/result which blocks until terminal + // (This handles elicitation/sampling that may happen server-side) + fetchTaskResultAndComplete(taskId, sink, cancelled, disposables); + } + else { + // Schedule next poll (only if not cancelled) + if (!cancelled.get()) { + long pollInterval = task.pollInterval() != null ? task.pollInterval() + : DEFAULT_TASK_POLL_INTERVAL_MS; + Disposable delaySub = Mono.delay(Duration.ofMillis(pollInterval)) + .subscribe( + ignored -> pollTaskUntilTerminal(taskId, sink, startTime, cancelled, disposables)); + disposables.add(delaySub); + } + } + }, error -> { + McpError mcpError = (error instanceof McpError) ? (McpError) error + : McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message(error.getMessage() != null ? error.getMessage() : "Unknown error") + .build(); + sink.next(McpSchema.ErrorMessage.of(mcpError)); + sink.complete(); + }); + disposables.add(getTaskSub); + } + + /** + * Handles a task that has reached a terminal state. + * @param taskId the task ID + * @param task the task in terminal state + * @param sink the sink to emit messages to + * @param cancelled cancellation flag - when true, should stop processing + * @param disposables composite disposable to track active subscriptions for cleanup + */ + private void handleTerminalTask(String taskId, McpSchema.Task task, + reactor.core.publisher.FluxSink> sink, + AtomicBoolean cancelled, Disposable.Composite disposables) { + // Check if cancelled before proceeding + if (cancelled.get()) { + return; + } + + if (task.status() == McpSchema.TaskStatus.COMPLETED) { + fetchTaskResultAndComplete(taskId, sink, cancelled, disposables); + } + else if (task.status() == McpSchema.TaskStatus.FAILED) { + String message = task.statusMessage() != null ? task.statusMessage() : "Task " + taskId + " failed"; + sink.next(McpSchema.ErrorMessage + .of(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR).message(message).build())); + sink.complete(); + } + else if (task.status() == McpSchema.TaskStatus.CANCELLED) { + sink.next(McpSchema.ErrorMessage.of(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Task " + taskId + " was cancelled") + .build())); + sink.complete(); + } + else { + sink.complete(); + } + } + + /** + * Fetches the task result and completes the stream. + * @param taskId the task ID + * @param sink the sink to emit messages to + * @param cancelled cancellation flag - when true, should stop processing + * @param disposables composite disposable to track active subscriptions for cleanup + */ + private void fetchTaskResultAndComplete(String taskId, + reactor.core.publisher.FluxSink> sink, + AtomicBoolean cancelled, Disposable.Composite disposables) { + // Check if cancelled before proceeding + if (cancelled.get()) { + return; + } + + Disposable fetchResultSub = this + .getTaskResult(McpSchema.GetTaskPayloadRequest.builder().taskId(taskId).build(), CALL_TOOL_RESULT_TYPE_REF) + .subscribe(result -> { + // Check if cancelled before emitting result + if (cancelled.get()) { + return; + } + sink.next(McpSchema.ResultMessage.of(result)); + sink.complete(); + }, error -> { + McpError mcpError = (error instanceof McpError) ? (McpError) error + : McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message(error.getMessage() != null ? error.getMessage() : "Unknown error") + .build(); + sink.next(McpSchema.ErrorMessage.of(mcpError)); + sink.complete(); + }); + disposables.add(fetchResultSub); + } + private McpSchema.CallToolResult validateToolResult(String toolName, McpSchema.CallToolResult result) { if (!this.enableCallToolSchemaCaching || result == null || result.isError() == Boolean.TRUE) { @@ -969,6 +1515,23 @@ private NotificationHandler asyncProgressNotificationHandler( }; } + private NotificationHandler asyncTaskStatusNotificationHandler( + List>> taskStatusConsumers) { + + return params -> { + McpSchema.TaskStatusNotification taskStatusNotification = transport.unmarshalFrom(params, + TASK_STATUS_NOTIFICATION_TYPE_REF); + + return Flux.fromIterable(taskStatusConsumers) + .flatMap(consumer -> consumer.apply(taskStatusNotification)) + .onErrorResume(error -> { + logger.error("Error handling task status notification", error); + return Mono.empty(); + }) + .then(); + }; + } + /** * This method is package-private and used for test only. Should not be called by user * code. @@ -999,4 +1562,217 @@ public Mono completeCompletion(McpSchema.CompleteReque .sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF)); } + // --------------------------------------- + // Tasks (Experimental) + // --------------------------------------- + + private static final TypeRef GET_TASK_RESULT_TYPE_REF = new TypeRef<>() { + }; + + private static final TypeRef LIST_TASKS_RESULT_TYPE_REF = new TypeRef<>() { + }; + + private static final TypeRef CANCEL_TASK_RESULT_TYPE_REF = new TypeRef<>() { + }; + + private static final TypeRef CREATE_TASK_RESULT_TYPE_REF = new TypeRef<>() { + }; + + /** + * Get the status and metadata of a task. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param getTaskRequest The request containing the task ID. + * @return A Mono that completes with the task status and metadata. + * @see McpSchema.GetTaskRequest + * @see McpSchema.GetTaskResult + */ + public Mono getTask(McpSchema.GetTaskRequest getTaskRequest) { + return this.initializer.withInitialization("get task", init -> { + if (init.initializeResult().capabilities().tasks() == null) { + return Mono.error(new IllegalStateException("Server does not provide tasks capability")); + } + return init.mcpSession().sendRequest(McpSchema.METHOD_TASKS_GET, getTaskRequest, GET_TASK_RESULT_TYPE_REF); + }); + } + + /** + * Get the status and metadata of a task by ID. + * + *

+ * This is a convenience overload that creates a {@link McpSchema.GetTaskRequest} with + * the given task ID. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param taskId The task identifier to query. + * @return A Mono that completes with the task status and metadata. + */ + public Mono getTask(String taskId) { + Assert.hasText(taskId, "Task ID must not be null or empty"); + return Mono.defer(() -> getTask(McpSchema.GetTaskRequest.builder().taskId(taskId).build())); + } + + /** + * Get the result of a completed task. + * + *

+ * The result type depends on the original request that created the task. For tool + * calls, use {@code new TypeRef(){}}. For sampling + * requests, use {@code new TypeRef(){}}. + * + *

+ * Example usage: + * + *

{@code
+	 * // For tool task results:
+	 * var result = client.getTaskResult(
+	 *     new GetTaskPayloadRequest(taskId, null),
+	 *     new TypeRef(){})
+	 *     .block();
+	 * }
+ * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param The expected result type, must extend {@link McpSchema.Result} + * @param getTaskPayloadRequest The request containing the task ID. + * @param resultTypeRef Type reference for deserializing the result. + * @return A Mono that completes with the task result. + * @see McpSchema.GetTaskPayloadRequest + * @see McpSchema.CallToolResult + * @see McpSchema.CreateMessageResult + */ + public Mono getTaskResult( + McpSchema.GetTaskPayloadRequest getTaskPayloadRequest, TypeRef resultTypeRef) { + return this.initializer.withInitialization("get task result", init -> { + if (init.initializeResult().capabilities().tasks() == null) { + return Mono.error(new IllegalStateException("Server does not provide tasks capability")); + } + return init.mcpSession().sendRequest(McpSchema.METHOD_TASKS_RESULT, getTaskPayloadRequest, resultTypeRef); + }); + } + + /** + * Retrieves the result of a completed task by task ID. + * + *

+ * This is a convenience overload that creates a + * {@link McpSchema.GetTaskPayloadRequest} from the task ID. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param The expected result type, must extend {@link McpSchema.Result} + * @param taskId The task identifier. + * @param resultTypeRef Type reference for deserializing the result. + * @return A Mono that completes with the task result. + */ + public Mono getTaskResult(String taskId, + TypeRef resultTypeRef) { + Assert.hasText(taskId, "Task ID must not be null or empty"); + Assert.notNull(resultTypeRef, "Result type reference must not be null"); + return Mono.defer( + () -> getTaskResult(McpSchema.GetTaskPayloadRequest.builder().taskId(taskId).build(), resultTypeRef)); + } + + /** + * List all tasks known by the server. + * + *

+ * This method automatically handles pagination, fetching all pages and combining them + * into a single result with an unmodifiable list. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @return A Mono that completes with the list of all tasks. + * @see McpSchema.ListTasksResult + */ + public Mono listTasks() { + return this.listTasks(McpSchema.FIRST_PAGE).expand(result -> { + String next = result.nextCursor(); + return (next != null && !next.isEmpty()) ? this.listTasks(next) : Mono.empty(); + }).reduce(McpSchema.ListTasksResult.builder().tasks(new ArrayList<>()).build(), (allTasksResult, result) -> { + allTasksResult.tasks().addAll(result.tasks()); + return allTasksResult; + }) + .map(result -> McpSchema.ListTasksResult.builder() + .tasks(Collections.unmodifiableList(result.tasks())) + .build()); + } + + /** + * List tasks known by the server with pagination support. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param cursor Pagination cursor from a previous list request. + * @return A Mono that completes with a page of tasks. + * @see McpSchema.ListTasksResult + */ + public Mono listTasks(String cursor) { + return this.initializer.withInitialization("list tasks", init -> { + if (init.initializeResult().capabilities().tasks() == null) { + return Mono.error(new IllegalStateException("Server does not provide tasks capability")); + } + if (init.initializeResult().capabilities().tasks().list() == null) { + return Mono.error(new IllegalStateException("Server does not provide tasks.list capability")); + } + return init.mcpSession() + .sendRequest(McpSchema.METHOD_TASKS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_TASKS_RESULT_TYPE_REF); + }); + } + + /** + * Request cancellation of a task. + * + *

+ * Note that cancellation is cooperative - the server may not honor the cancellation + * request, or may take some time to cancel the task. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param cancelTaskRequest The request containing the task ID. + * @return A Mono that completes with the updated task status. + * @see McpSchema.CancelTaskRequest + * @see McpSchema.CancelTaskResult + */ + public Mono cancelTask(McpSchema.CancelTaskRequest cancelTaskRequest) { + return this.initializer.withInitialization("cancel task", init -> { + if (init.initializeResult().capabilities().tasks() == null) { + return Mono.error(new IllegalStateException("Server does not provide tasks capability")); + } + if (init.initializeResult().capabilities().tasks().cancel() == null) { + return Mono.error(new IllegalStateException("Server does not provide tasks.cancel capability")); + } + return init.mcpSession() + .sendRequest(McpSchema.METHOD_TASKS_CANCEL, cancelTaskRequest, CANCEL_TASK_RESULT_TYPE_REF); + }); + } + + /** + * Request cancellation of a task by ID. + * + *

+ * This is a convenience overload that creates a {@link McpSchema.CancelTaskRequest} + * with the given task ID. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param taskId The task identifier to cancel. + * @return A Mono that completes with the updated task status. + */ + public Mono cancelTask(String taskId) { + Assert.hasText(taskId, "Task ID must not be null or empty"); + return Mono.defer(() -> cancelTask(McpSchema.CancelTaskRequest.builder().taskId(taskId).build())); + } + } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java index c9989f832..974a36d2d 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.client; import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.experimental.tasks.TaskStore; import io.modelcontextprotocol.json.schema.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; @@ -184,6 +185,8 @@ class SyncSpec { private final List> progressConsumers = new ArrayList<>(); + private final List> taskStatusConsumers = new ArrayList<>(); + private Function samplingHandler; private Function elicitationHandler; @@ -194,6 +197,10 @@ class SyncSpec { private boolean enableCallToolSchemaCaching = false; // Default to false + private TaskStore taskStore; + + private Duration taskPollTimeout; // null = use default (5 minutes) + private SyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -317,6 +324,44 @@ public SyncSpec elicitation(Function elicitationHan return this; } + /** + * Sets the task store for client-side task hosting. When set, the client can host + * tasks for task-augmented sampling and elicitation requests from the server. + * + *

+ * This is an experimental feature that may change in future releases. + * @param taskStore The task store implementation. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskStore is null + */ + public SyncSpec taskStore(TaskStore taskStore) { + Assert.notNull(taskStore, "Task store must not be null"); + this.taskStore = taskStore; + return this; + } + + /** + * Sets the maximum time to wait for a task to reach a terminal state during task + * result polling. + * + *

+ * When using task-augmented requests (e.g., long-running tool calls), the client + * polls the server for task status updates. This timeout limits how long the + * client will wait for the task to complete, fail, or be cancelled. + * + *

+ * If not set, defaults to 5 minutes to prevent infinite polling loops. + * + *

+ * This is an experimental feature that may change in future releases. + * @param timeout maximum poll duration, or null to use the default (5 minutes) + * @return This builder instance for method chaining + */ + public SyncSpec taskPollTimeout(Duration timeout) { + this.taskPollTimeout = timeout; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -428,7 +473,7 @@ public SyncSpec progressConsumer(Consumer progre * @param progressConsumers A list of consumers that receives progress * notifications. Must not be null. * @return This builder instance for method chaining - * @throws IllegalArgumentException if progressConsumer is null + * @throws IllegalArgumentException if progressConsumers is null */ public SyncSpec progressConsumers(List> progressConsumers) { Assert.notNull(progressConsumers, "Progress consumers must not be null"); @@ -436,6 +481,34 @@ public SyncSpec progressConsumers(List> return this; } + /** + * Adds a consumer to be notified of task status notifications from the server. + * This enables clients to receive updates about task progress and status changes. + * @param taskStatusConsumer A consumer that receives task status notifications. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskStatusConsumer is null + */ + public SyncSpec taskStatusConsumer(Consumer taskStatusConsumer) { + Assert.notNull(taskStatusConsumer, "Task status consumer must not be null"); + this.taskStatusConsumers.add(taskStatusConsumer); + return this; + } + + /** + * Adds multiple consumers to be notified of task status notifications from the + * server. + * @param taskStatusConsumers A list of consumers that receive task status + * notifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskStatusConsumers is null + */ + public SyncSpec taskStatusConsumers(List> taskStatusConsumers) { + Assert.notNull(taskStatusConsumers, "Task status consumers must not be null"); + this.taskStatusConsumers.addAll(taskStatusConsumers); + return this; + } + /** * Add a provider of {@link McpTransportContext}, providing a context before * calling any client operation. This allows to extract thread-locals and hand @@ -486,14 +559,15 @@ public SyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching) public McpSyncClient build() { McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, - this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, this.samplingHandler, - this.elicitationHandler, this.enableCallToolSchemaCaching); + this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, + this.taskStatusConsumers, this.samplingHandler, this.elicitationHandler, + this.enableCallToolSchemaCaching, this.taskPollTimeout, this.taskStore != null); McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures); return new McpSyncClient(new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, - jsonSchemaValidator != null ? jsonSchemaValidator : JsonSchemaValidator.getDefault(), - asyncFeatures), this.contextProvider); + jsonSchemaValidator != null ? jsonSchemaValidator : JsonSchemaValidator.getDefault(), asyncFeatures, + this.taskStore), this.contextProvider); } } @@ -540,6 +614,8 @@ class AsyncSpec { private final List>> progressConsumers = new ArrayList<>(); + private final List>> taskStatusConsumers = new ArrayList<>(); + private Function> samplingHandler; private Function> elicitationHandler; @@ -548,6 +624,10 @@ class AsyncSpec { private boolean enableCallToolSchemaCaching = false; // Default to false + private TaskStore taskStore; + + private Duration taskPollTimeout; // null = use default (5 minutes) + private AsyncSpec(McpClientTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; @@ -671,6 +751,22 @@ public AsyncSpec elicitation(Function> elicita return this; } + /** + * Sets the task store for client-side task hosting. When set, the client can host + * tasks for task-augmented sampling and elicitation requests from the server. + * + *

+ * This is an experimental feature that may change in future releases. + * @param taskStore The task store implementation. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskStore is null + */ + public AsyncSpec taskStore(TaskStore taskStore) { + Assert.notNull(taskStore, "Task store must not be null"); + this.taskStore = taskStore; + return this; + } + /** * Adds a consumer to be notified when the available tools change. This allows the * client to react to changes in the server's tool capabilities, such as tools @@ -785,7 +881,7 @@ public AsyncSpec progressConsumer(Function>> progressConsumers) { @@ -794,6 +890,35 @@ public AsyncSpec progressConsumers( return this; } + /** + * Adds a consumer to be notified of task status notifications from the server. + * This enables clients to receive updates about task progress and status changes. + * @param taskStatusConsumer A consumer that receives task status notifications. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskStatusConsumer is null + */ + public AsyncSpec taskStatusConsumer(Function> taskStatusConsumer) { + Assert.notNull(taskStatusConsumer, "Task status consumer must not be null"); + this.taskStatusConsumers.add(taskStatusConsumer); + return this; + } + + /** + * Adds multiple consumers to be notified of task status notifications from the + * server. + * @param taskStatusConsumers A list of consumers that receive task status + * notifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskStatusConsumers is null + */ + public AsyncSpec taskStatusConsumers( + List>> taskStatusConsumers) { + Assert.notNull(taskStatusConsumers, "Task status consumers must not be null"); + this.taskStatusConsumers.addAll(taskStatusConsumers); + return this; + } + /** * Sets the JSON schema validator to use for validating tool responses against * output schemas. @@ -819,6 +944,21 @@ public AsyncSpec enableCallToolSchemaCaching(boolean enableCallToolSchemaCaching return this; } + /** + * Sets the maximum duration to poll for task completion in + * {@code callToolStream()}. If not set, defaults to 5 minutes to prevent infinite + * polling loops. + * + *

+ * This is an experimental feature that may change in future releases. + * @param timeout maximum poll duration, or null to use the default (5 minutes) + * @return This builder instance for method chaining + */ + public AsyncSpec taskPollTimeout(Duration timeout) { + this.taskPollTimeout = timeout; + return this; + } + /** * Create an instance of {@link McpAsyncClient} with the provided configurations * or sensible defaults. @@ -832,7 +972,9 @@ public McpAsyncClient build() { new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, - this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching)); + this.taskStatusConsumers, this.samplingHandler, this.elicitationHandler, + this.enableCallToolSchemaCaching, this.taskPollTimeout, this.taskStore != null), + this.taskStore); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index 127d53337..18da97795 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.client; +import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -71,9 +72,10 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List, Mono>> promptsChangeConsumers, List>> loggingConsumers, List>> progressConsumers, + List>> taskStatusConsumers, Function> samplingHandler, Function> elicitationHandler, - boolean enableCallToolSchemaCaching) { + boolean enableCallToolSchemaCaching, Duration taskPollTimeout, boolean taskStorePresent) { /** * Create an instance and validate the arguments. @@ -96,9 +98,10 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List, Mono>> promptsChangeConsumers, List>> loggingConsumers, List>> progressConsumers, + List>> taskStatusConsumers, Function> samplingHandler, Function> elicitationHandler, - boolean enableCallToolSchemaCaching) { + boolean enableCallToolSchemaCaching, Duration taskPollTimeout, boolean taskStorePresent) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; @@ -106,7 +109,8 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, - elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null, + taskStorePresent ? buildTaskCapabilities(samplingHandler, elicitationHandler) : null); this.roots = roots != null ? new ConcurrentHashMap<>(roots) : new ConcurrentHashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); @@ -115,9 +119,29 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); + this.taskStatusConsumers = taskStatusConsumers != null ? taskStatusConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + this.taskPollTimeout = taskPollTimeout; + this.taskStorePresent = taskStorePresent; + } + + /** + * Build the task capabilities based on the presence of sampling and elicitation + * handlers. + */ + private static McpSchema.ClientCapabilities.ClientTaskCapabilities buildTaskCapabilities( + Function> samplingHandler, + Function> elicitationHandler) { + var builder = McpSchema.ClientCapabilities.ClientTaskCapabilities.builder().list().cancel(); + if (samplingHandler != null) { + builder.samplingCreateMessage(); + } + if (elicitationHandler != null) { + builder.elicitationCreate(); + } + return builder.build(); } /** @@ -133,8 +157,8 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c Function> samplingHandler, Function> elicitationHandler) { this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, - resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, - elicitationHandler, false); + resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), List.of(), + samplingHandler, elicitationHandler, false, null, false); } /** @@ -182,18 +206,28 @@ public static Async fromSync(Sync syncSpec) { .subscribeOn(Schedulers.boundedElastic())); } - Function> samplingHandler = r -> Mono - .fromCallable(() -> syncSpec.samplingHandler().apply(r)) - .subscribeOn(Schedulers.boundedElastic()); + List>> taskStatusConsumers = new ArrayList<>(); + for (Consumer consumer : syncSpec.taskStatusConsumers()) { + taskStatusConsumers.add(n -> Mono.fromRunnable(() -> consumer.accept(n)) + .subscribeOn(Schedulers.boundedElastic())); + } - Function> elicitationHandler = r -> Mono - .fromCallable(() -> syncSpec.elicitationHandler().apply(r)) - .subscribeOn(Schedulers.boundedElastic()); + Function> samplingHandler = syncSpec + .samplingHandler() != null + ? r -> Mono.fromCallable(() -> syncSpec.samplingHandler().apply(r)) + .subscribeOn(Schedulers.boundedElastic()) + : null; + + Function> elicitationHandler = syncSpec + .elicitationHandler() != null + ? r -> Mono.fromCallable(() -> syncSpec.elicitationHandler().apply(r)) + .subscribeOn(Schedulers.boundedElastic()) + : null; return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, - loggingConsumers, progressConsumers, samplingHandler, elicitationHandler, - syncSpec.enableCallToolSchemaCaching); + loggingConsumers, progressConsumers, taskStatusConsumers, samplingHandler, elicitationHandler, + syncSpec.enableCallToolSchemaCaching(), syncSpec.taskPollTimeout(), syncSpec.taskStorePresent()); } } @@ -220,9 +254,10 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili List>> promptsChangeConsumers, List> loggingConsumers, List> progressConsumers, + List> taskStatusConsumers, Function samplingHandler, Function elicitationHandler, - boolean enableCallToolSchemaCaching) { + boolean enableCallToolSchemaCaching, Duration taskPollTimeout, boolean taskStorePresent) { /** * Create an instance and validate the arguments. @@ -246,9 +281,10 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl List>> promptsChangeConsumers, List> loggingConsumers, List> progressConsumers, + List> taskStatusConsumers, Function samplingHandler, Function elicitationHandler, - boolean enableCallToolSchemaCaching) { + boolean enableCallToolSchemaCaching, Duration taskPollTimeout, boolean taskStorePresent) { Assert.notNull(clientInfo, "Client info must not be null"); this.clientInfo = clientInfo; @@ -256,7 +292,8 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl : new McpSchema.ClientCapabilities(null, !Utils.isEmpty(roots) ? new McpSchema.ClientCapabilities.RootCapabilities(false) : null, samplingHandler != null ? new McpSchema.ClientCapabilities.Sampling() : null, - elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null); + elicitationHandler != null ? new McpSchema.ClientCapabilities.Elicitation() : null, + taskStorePresent ? buildTaskCapabilities(samplingHandler, elicitationHandler) : null); this.roots = roots != null ? new HashMap<>(roots) : new HashMap<>(); this.toolsChangeConsumers = toolsChangeConsumers != null ? toolsChangeConsumers : List.of(); @@ -265,9 +302,29 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); + this.taskStatusConsumers = taskStatusConsumers != null ? taskStatusConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; + this.taskPollTimeout = taskPollTimeout; + this.taskStorePresent = taskStorePresent; + } + + /** + * Build the task capabilities based on the presence of sampling and elicitation + * handlers. + */ + private static McpSchema.ClientCapabilities.ClientTaskCapabilities buildTaskCapabilities( + Function samplingHandler, + Function elicitationHandler) { + var builder = McpSchema.ClientCapabilities.ClientTaskCapabilities.builder().list().cancel(); + if (samplingHandler != null) { + builder.samplingCreateMessage(); + } + if (elicitationHandler != null) { + builder.elicitationCreate(); + } + return builder.build(); } /** @@ -282,8 +339,8 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl Function samplingHandler, Function elicitationHandler) { this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, - resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, - elicitationHandler, false); + resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), List.of(), + samplingHandler, elicitationHandler, false, null, false); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 7fdaa8941..fac8625ae 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -5,18 +5,22 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.List; import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.experimental.tasks.TaskStore; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult; import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; /** @@ -200,14 +204,14 @@ public void rootsListChangedNotification() { * Add a roots dynamically. */ public void addRoot(McpSchema.Root root) { - this.delegate.addRoot(root).block(); + withProvidedContext(this.delegate.addRoot(root)).block(); } /** * Remove a root dynamically. */ public void removeRoot(String rootUri) { - this.delegate.removeRoot(rootUri).block(); + withProvidedContext(this.delegate.removeRoot(rootUri)).block(); } /** @@ -237,6 +241,104 @@ public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolReque } + /** + * Low-level method that invokes a tool with task augmentation, creating a background + * task for long-running operations. + * + *

+ * Recommendation: For most use cases, prefer {@link #callToolStream} + * which provides a unified interface that handles both regular and task-augmented + * tool calls automatically, including polling and result retrieval. + * + *

+ * When calling a tool with task augmentation, the server creates a task and returns + * immediately with a {@link McpSchema.CreateTaskResult} containing the task ID. The + * actual tool execution happens asynchronously. Use {@link #getTask} to poll for task + * status and {@link #getTaskResult} to retrieve the result once completed. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param callToolRequest The request containing the tool name, parameters, and task + * metadata. The {@code task} field must be non-null. + * @return The task creation result containing the task ID and initial status. + * @throws IllegalArgumentException if the request does not include task metadata + * @see #callToolStream + * @see McpSchema.CallToolRequest + * @see McpSchema.CreateTaskResult + * @see McpSchema.TaskMetadata + * @see #getTask + * @see #getTaskResult + */ + public McpSchema.CreateTaskResult callToolTask(McpSchema.CallToolRequest callToolRequest) { + return withProvidedContext(this.delegate.callToolTask(callToolRequest)).block(); + } + + /** + * Calls a tool and returns a list of response messages, handling both regular and + * task-augmented tool calls automatically. + * + *

+ * This method provides a unified interface for tool execution: + *

    + *
  • For non-task requests (when {@code task} field is null): + * returns a list with a single {@link McpSchema.ResultMessage} or + * {@link McpSchema.ErrorMessage} + *
  • For task-augmented requests: returns a list containing + * {@link McpSchema.TaskCreatedMessage} followed by zero or more + * {@link McpSchema.TaskStatusMessage} and ending with {@link McpSchema.ResultMessage} + * or {@link McpSchema.ErrorMessage} + *
+ * + *

+ * This is the recommended way to call tools when you want to support both regular and + * long-running tool executions without having to handle the decision logic yourself. + * + *

+ * Note: This method blocks until the tool execution completes. For + * non-blocking streaming, use the async client's + * {@link McpAsyncClient#callToolStream(McpSchema.CallToolRequest)} method. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * + *

+ * Example usage: + * + *

{@code
+	 * var request = new CallToolRequest("my-tool", Map.of("arg", "value"),
+	 *     new TaskMetadata(60000L), null);  // Optional task metadata
+	 *
+	 * List> messages = client.callToolStream(request);
+	 * for (var message : messages) {
+	 *     switch (message) {
+	 *         case TaskCreatedMessage tc ->
+	 *             System.out.println("Task created: " + tc.task().taskId());
+	 *         case TaskStatusMessage ts ->
+	 *             System.out.println("Status: " + ts.task().status());
+	 *         case ResultMessage r ->
+	 *             System.out.println("Result: " + r.result());
+	 *         case ErrorMessage e ->
+	 *             System.err.println("Error: " + e.error().getMessage());
+	 *     }
+	 * }
+	 * }
+ * @param callToolRequest The request containing the tool name and arguments. If the + * {@code task} field is set, the call will be task-augmented. + * @return A list of {@link McpSchema.ResponseMessage} instances representing the + * progress and result of the tool call. + * @see McpSchema.ResponseMessage + * @see McpSchema.TaskCreatedMessage + * @see McpSchema.TaskStatusMessage + * @see McpSchema.ResultMessage + * @see McpSchema.ErrorMessage + */ + public List> callToolStream( + McpSchema.CallToolRequest callToolRequest) { + return withProvidedContextFlux(this.delegate.callToolStream(callToolRequest)).collectList().block(); + } + /** * Retrieves the list of all tools provided by the server. * @return The list of all tools result containing: - tools: List of available tools, @@ -394,6 +496,158 @@ public McpSchema.CompleteResult completeCompletion(McpSchema.CompleteRequest com } + // --------------------------------------- + // Tasks (Experimental) + // --------------------------------------- + + /** + * Returns the task store used for client-side task hosting. + * @return the task store, or null if client-side task hosting is not configured + */ + public TaskStore getTaskStore() { + return this.delegate.getTaskStore(); + } + + /** + * Get the status and metadata of a task. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param getTaskRequest The request containing the task ID. + * @return The task status and metadata. + * @see McpSchema.GetTaskRequest + * @see McpSchema.GetTaskResult + */ + public McpSchema.GetTaskResult getTask(McpSchema.GetTaskRequest getTaskRequest) { + return withProvidedContext(this.delegate.getTask(getTaskRequest)).block(); + } + + /** + * Get the status and metadata of a task by ID. + * + *

+ * This is a convenience overload that creates a {@link McpSchema.GetTaskRequest} with + * the given task ID. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param taskId The task identifier to query. + * @return The task status and metadata. + */ + public McpSchema.GetTaskResult getTask(String taskId) { + Assert.hasText(taskId, "Task ID must not be null or empty"); + return getTask(McpSchema.GetTaskRequest.builder().taskId(taskId).build()); + } + + /** + * Get the result of a completed task. + * + *

+ * The result type depends on the original request that created the task. For tool + * calls, use {@code new TypeRef(){}}. For sampling + * requests, use {@code new TypeRef(){}}. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param The expected result type, must extend {@link McpSchema.Result} + * @param getTaskPayloadRequest The request containing the task ID. + * @param resultTypeRef Type reference for deserializing the result. + * @return The task result. + * @see McpSchema.GetTaskPayloadRequest + * @see McpSchema.CallToolResult + * @see McpSchema.CreateMessageResult + */ + public T getTaskResult( + McpSchema.GetTaskPayloadRequest getTaskPayloadRequest, TypeRef resultTypeRef) { + return withProvidedContext(this.delegate.getTaskResult(getTaskPayloadRequest, resultTypeRef)).block(); + } + + /** + * Retrieves the result of a completed task by task ID. + * + *

+ * This is a convenience overload that creates a + * {@link McpSchema.GetTaskPayloadRequest} from the task ID. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param The expected result type, must extend {@link McpSchema.Result} + * @param taskId The task identifier. + * @param resultTypeRef Type reference for deserializing the result. + * @return The task result. + */ + public T getTaskResult(String taskId, TypeRef resultTypeRef) { + return withProvidedContext(this.delegate.getTaskResult(taskId, resultTypeRef)).block(); + } + + /** + * List all tasks known by the server. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @return The list of all tasks. + * @see McpSchema.ListTasksResult + */ + public McpSchema.ListTasksResult listTasks() { + return withProvidedContext(this.delegate.listTasks()).block(); + } + + /** + * List tasks known by the server with pagination support. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param cursor Pagination cursor from a previous list request. + * @return A page of tasks. + * @see McpSchema.ListTasksResult + */ + public McpSchema.ListTasksResult listTasks(String cursor) { + return withProvidedContext(this.delegate.listTasks(cursor)).block(); + } + + /** + * Request cancellation of a task. + * + *

+ * Note that cancellation is cooperative - the server may not honor the cancellation + * request, or may take some time to cancel the task. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param cancelTaskRequest The request containing the task ID. + * @return The updated task status. + * @see McpSchema.CancelTaskRequest + * @see McpSchema.CancelTaskResult + */ + public McpSchema.CancelTaskResult cancelTask(McpSchema.CancelTaskRequest cancelTaskRequest) { + return withProvidedContext(this.delegate.cancelTask(cancelTaskRequest)).block(); + } + + /** + * Request cancellation of a task by ID. + * + *

+ * This is a convenience overload that creates a {@link McpSchema.CancelTaskRequest} + * with the given task ID. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param taskId The task identifier to cancel. + * @return The updated task status. + */ + public McpSchema.CancelTaskResult cancelTask(String taskId) { + Assert.hasText(taskId, "Task ID must not be null or empty"); + return cancelTask(McpSchema.CancelTaskRequest.builder().taskId(taskId).build()); + } + /** * For a given action, on assembly, capture the "context" via the * {@link #contextProvider} and store it in the Reactor context. @@ -404,4 +658,14 @@ private Mono withProvidedContext(Mono action) { return action.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, this.contextProvider.get())); } + /** + * For a given Flux action, on assembly, capture the "context" via the + * {@link #contextProvider} and store it in the Reactor context. + * @param action the flux action to perform + * @return the flux with context applied + */ + private Flux withProvidedContextFlux(Flux action) { + return action.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, this.contextProvider.get())); + } + } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index 0a8dff363..b521d467a 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -100,11 +100,13 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private static final String TEXT_EVENT_STREAM = "text/event-stream"; - public static int NOT_FOUND = 404; + public static final int ACCEPTED = 202; - public static int METHOD_NOT_ALLOWED = 405; + public static final int NOT_FOUND = 404; - public static int BAD_REQUEST = 400; + public static final int METHOD_NOT_ALLOWED = 405; + + public static final int BAD_REQUEST = 400; private final McpJsonMapper jsonMapper; @@ -492,11 +494,13 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { .orElse(null); // For empty content or HTTP code 202 (ACCEPTED), assume success - if (contentType.isBlank() || "0".equals(contentLength) || statusCode == 202) { - // if (contentType.isBlank() || "0".equals(contentLength)) { + if (contentType.isBlank() || "0".equals(contentLength) || statusCode == ACCEPTED) { logger.debug("No body returned for POST in session {}", sessionRepresentation); - // No content type means no response body, so we can just - // return an empty stream + // No content type, zero content length, or 202 Accepted means + // no response body expected, so we return an empty stream. + // Per the spec, 202 is used to acknowledge + // notifications/responses + // where no reply is expected. deliveredSink.success(); return Flux.empty(); } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/AbstractTaskAwareToolSpecificationBuilder.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/AbstractTaskAwareToolSpecificationBuilder.java new file mode 100644 index 000000000..da9ebca8e --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/AbstractTaskAwareToolSpecificationBuilder.java @@ -0,0 +1,164 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.spec.McpSchema.JsonSchema; +import io.modelcontextprotocol.spec.McpSchema.TaskSupportMode; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.ToolAnnotations; +import io.modelcontextprotocol.spec.McpSchema.ToolExecution; +import io.modelcontextprotocol.util.Assert; + +/** + * Abstract base builder for task-aware tool specifications. + * + *

+ * This class provides common builder functionality shared between + * {@link TaskAwareAsyncToolSpecification.Builder} and + * {@link TaskAwareSyncToolSpecification.Builder}. + * + *

+ * Subclasses must implement: + *

    + *
  • Handler-specific setter methods (createTask, getTask, getTaskResult)
  • + *
  • The {@link #build()} method to construct the final specification
  • + *
+ * + * @param the concrete builder type for fluent method chaining + * @see TaskAwareAsyncToolSpecification.Builder + * @see TaskAwareSyncToolSpecification.Builder + */ +public abstract class AbstractTaskAwareToolSpecificationBuilder> { + + protected String name; + + protected String description; + + protected JsonSchema inputSchema; + + protected TaskSupportMode taskSupportMode = TaskSupportMode.REQUIRED; + + protected ToolAnnotations annotations; + + /** + * Returns this builder cast to the concrete type for fluent chaining. + * @return this builder as the concrete type B + */ + @SuppressWarnings("unchecked") + protected B self() { + return (B) this; + } + + /** + * Sets the tool name (required). + * @param name the unique name for this tool + * @return this builder + */ + public B name(String name) { + this.name = name; + return self(); + } + + /** + * Sets the tool description. + * @param description a human-readable description of what the tool does + * @return this builder + */ + public B description(String description) { + this.description = description; + return self(); + } + + /** + * Sets the JSON Schema for the tool's input parameters. + * + *

+ * If not set, defaults to an empty object schema (no parameters). + * @param inputSchema the JSON Schema defining the expected input structure + * @return this builder + */ + public B inputSchema(JsonSchema inputSchema) { + this.inputSchema = inputSchema; + return self(); + } + + /** + * Sets the task support mode for this tool. + * + *

+ * Defaults to {@link TaskSupportMode#REQUIRED} if not set. + * @param mode the task support mode (FORBIDDEN, OPTIONAL, or REQUIRED) + * @return this builder + */ + public B taskSupportMode(TaskSupportMode mode) { + this.taskSupportMode = mode != null ? mode : TaskSupportMode.REQUIRED; + return self(); + } + + /** + * Sets optional annotations for the tool. + * @param annotations additional metadata for the tool + * @return this builder + */ + public B annotations(ToolAnnotations annotations) { + this.annotations = annotations; + return self(); + } + + /** + * Validates that required fields are set. + * + *

+ * Subclasses should call this method at the start of their {@link #build()} method + * and add any additional validation specific to their handler types. + * @throws IllegalArgumentException if name is empty + */ + protected void validateCommonFields() { + Assert.hasText(name, "name must not be empty"); + } + + /** + * Builds the {@link Tool} object from the configured properties. + * + *

+ * This method applies defaults for missing optional fields: + *

    + *
  • inputSchema defaults to an empty object schema
  • + *
  • description defaults to the tool name
  • + *
+ * @return a configured Tool instance + */ + protected Tool buildTool() { + // Use default empty schema if not provided + JsonSchema schema = inputSchema != null ? inputSchema : TaskDefaults.EMPTY_INPUT_SCHEMA; + + // Use description or default to name + String desc = description != null ? description : name; + + // Build execution property with configured task support mode + ToolExecution execution = ToolExecution.builder().taskSupport(taskSupportMode).build(); + + // Build the tool + return Tool.builder() + .name(name) + .description(desc) + .inputSchema(schema) + .execution(execution) + .annotations(annotations) + .build(); + } + + /** + * Builds the final tool specification. + * + *

+ * Implementations must validate their handler-specific fields and construct the + * appropriate specification type. + * @return the built specification + * @throws IllegalArgumentException if required fields are not set + */ + public abstract Object build(); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskExtra.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskExtra.java new file mode 100644 index 000000000..5ee270868 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskExtra.java @@ -0,0 +1,210 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import java.util.function.Consumer; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; + +/** + * Context passed to {@link CreateTaskHandler} providing access to task infrastructure and + * request metadata. + *

+ * Example usage: + * + *

{@code
+ * CreateTaskHandler handler = (args, extra) -> {
+ *     // Decide TTL based on request or use a default
+ *     long ttl = extra.requestTtl() != null
+ *         ? Math.min(extra.requestTtl(), Duration.ofMinutes(30).toMillis())
+ *         : Duration.ofMinutes(5).toMillis();
+ *
+ *     return extra.taskStore()
+ *         .createTask(CreateTaskOptions.builder()
+ *             .requestedTtl(ttl)
+ *             .sessionId(extra.sessionId())
+ *             .build())
+ *         .flatMap(task -> {
+ *             // Use exchange for client communication
+ *             startBackgroundWork(task.taskId(), args, extra.exchange()).subscribe();
+ *             return Mono.just(new McpSchema.CreateTaskResult(task, null));
+ *         });
+ * };
+ * }
+ * + *

+ * Design Note: This interface mirrors {@link SyncCreateTaskExtra} for + * the synchronous API. The duplication is intentional because async methods return + * {@link Mono} while sync methods return values directly. This separation allows for + * proper reactive and blocking semantics without forcing one paradigm on the other. + * + *

+ * This is an experimental API that may change in future releases. + * + * @see CreateTaskHandler + * @see SyncCreateTaskExtra + * @see TaskStore + * @see TaskMessageQueue + */ +public interface CreateTaskExtra { + + /** + * The task store for creating and managing tasks. + * + *

+ * Tools use this to create tasks with their desired configuration: + * + *

{@code
+	 * extra.taskStore().createTask(CreateTaskOptions.builder()
+	 *     .requestedTtl(Duration.ofMinutes(5).toMillis())
+	 *     .pollInterval(Duration.ofSeconds(1).toMillis())
+	 *     .sessionId(extra.sessionId())
+	 *     .build());
+	 * }
+ * @return the TaskStore instance + */ + TaskStore taskStore(); + + /** + * The message queue for task communication during INPUT_REQUIRED state. + * + *

+ * Use this for interactive tasks that need to communicate with the client during + * execution. + * @return the TaskMessageQueue instance, or null if not configured + */ + TaskMessageQueue taskMessageQueue(); + + /** + * The server exchange for client interaction. + * + *

+ * Provides access to session-scoped operations like sending notifications to the + * client. + * @return the McpAsyncServerExchange instance + */ + McpAsyncServerExchange exchange(); + + /** + * Session ID for task isolation. + * + *

+ * Tasks created with this session ID will only be visible to the same session, + * enabling proper multi-client isolation. + * @return the session ID string + */ + String sessionId(); + + /** + * Request-specified TTL from client (may be null). + * + *

+ * If the client specified a TTL in the task metadata of their request, it will be + * available here. Tools can use this to implement client-controlled TTL policies: + * + *

{@code
+	 * // Client can lower but not raise TTL
+	 * long maxTtl = Duration.ofMinutes(30).toMillis();
+	 * long ttl = extra.requestTtl() != null
+	 *     ? Math.min(extra.requestTtl(), maxTtl)
+	 *     : maxTtl;
+	 * }
+ * @return the TTL in milliseconds from the client request, or null if not specified + */ + Long requestTtl(); + + /** + * The original MCP request that triggered this task creation. + * + *

+ * For tool calls, this will be a {@link McpSchema.CallToolRequest} containing the + * tool name, arguments, and any task metadata. This request is stored alongside the + * task and can be retrieved later via the task store, eliminating the need for + * separate task-to-tool mapping. + * @return the original request that triggered task creation + */ + McpSchema.Request originatingRequest(); + + // -------------------------- + // Convenience Methods + // -------------------------- + + /** + * Convenience method to create a task with default options derived from this context. + * + *

+ * This method automatically uses {@link #originatingRequest()}, {@link #sessionId()}, + * and {@link #requestTtl()} from this context, eliminating common boilerplate: + * + *

{@code
+	 * // Instead of:
+	 * extra.taskStore().createTask(CreateTaskOptions.builder(extra.originatingRequest())
+	 *     .sessionId(extra.sessionId())
+	 *     .requestedTtl(extra.requestTtl())
+	 *     .build())
+	 *
+	 * // You can simply use:
+	 * extra.createTask()
+	 * }
+ * @return Mono that completes with the created Task + */ + default Mono createTask() { + return taskStore().createTask(CreateTaskOptions.builder(originatingRequest()) + .sessionId(sessionId()) + .requestedTtl(requestTtl()) + .build()); + } + + /** + * Convenience method to create a task with custom options, but inheriting session + * context. + * + *

+ * This method pre-populates the builder with {@link #originatingRequest()}, + * {@link #sessionId()}, and {@link #requestTtl()}, then allows customization: + * + *

{@code
+	 * // Create a task with custom poll interval:
+	 * extra.createTask(opts -> opts.pollInterval(500L))
+	 *
+	 * // Create a task with custom TTL (ignoring client request):
+	 * extra.createTask(opts -> opts.requestedTtl(Duration.ofMinutes(10).toMillis()))
+	 * }
+ * @param customizer function to customize options beyond the defaults + * @return Mono that completes with the created Task + */ + default Mono createTask(Consumer customizer) { + CreateTaskOptions.Builder builder = CreateTaskOptions.builder(originatingRequest()) + .sessionId(sessionId()) + .requestedTtl(requestTtl()); + customizer.accept(builder); + return taskStore().createTask(builder.build()); + } + + /** + * Create a TaskContext for managing the given task's lifecycle. + * + *

+ * This convenience method creates a TaskContext that uses this extra's task store and + * message queue, reducing boilerplate in task handlers: + * + *

{@code
+	 * extra.createTask()
+	 *     .map(task -> extra.createTaskContext(task))
+	 *     .flatMap(ctx -> {
+	 *         // Use ctx to update status, send messages, etc.
+	 *         return ctx.complete(result);
+	 *     });
+	 * }
+ * @param task the task to create a context for + * @return a TaskContext bound to the given task and this extra's infrastructure + */ + default TaskContext createTaskContext(McpSchema.Task task) { + return new DefaultTaskContext<>(task.taskId(), sessionId(), taskStore(), taskMessageQueue()); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskHandler.java new file mode 100644 index 000000000..b5e74e29b --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskHandler.java @@ -0,0 +1,60 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; + +/** + * Functional interface for handling task creation for tool calls. + *

+ * Example usage: + * + *

{@code
+ * CreateTaskHandler handler = (args, extra) -> {
+ *     // Tool decides TTL directly
+ *     long ttl = Duration.ofMinutes(5).toMillis();
+ *
+ *     return extra.taskStore()
+ *         .createTask(CreateTaskOptions.builder()
+ *             .requestedTtl(ttl)
+ *             .sessionId(extra.sessionId())
+ *             .build())
+ *         .flatMap(task -> {
+ *             // Start background work
+ *             doWork(task.taskId(), args, extra.exchange()).subscribe();
+ *             return Mono.just(new McpSchema.CreateTaskResult(task, null));
+ *         });
+ * };
+ * }
+ * + *

+ * This is an experimental API that may change in future releases. + * + * @see CreateTaskExtra + * @see TaskAwareToolSpec + */ +@FunctionalInterface +public interface CreateTaskHandler { + + /** + * Handles task creation for a tool call. + * + *

+ * The handler is responsible for: + *

    + *
  • Creating the task with desired TTL and poll interval
  • + *
  • Starting any background work needed
  • + *
  • Returning the created task wrapped in a CreateTaskResult
  • + *
+ * @param args The parsed tool arguments from the CallToolRequest + * @param extra Context providing taskStore, exchange, and request metadata + * @return a Mono emitting the CreateTaskResult containing the created Task + */ + Mono createTask(Map args, CreateTaskExtra extra); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskOptions.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskOptions.java new file mode 100644 index 000000000..65e01efc4 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskOptions.java @@ -0,0 +1,190 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.spec.McpSchema; + +/** + * Options for creating a new task. + * + *

+ * Recommended Usage: The {@link #builder(McpSchema.Request)} method + * provides the preferred way to create options, offering a fluent API that's more + * readable and maintainable than the record constructor: + * + *

{@code
+ * CreateTaskOptions options = CreateTaskOptions.builder(callToolRequest)
+ *     .requestedTtl(60000L)
+ *     .pollInterval(1000L)
+ *     .sessionId(sessionId)
+ *     .build();
+ * }
+ * + *

+ * The {@code originatingRequest} field is required and stores the original MCP request + * that triggered task creation. This allows the task store to provide full context when + * retrieving tasks, eliminating the need for separate task-to-tool mapping. + * + * This is an experimental API that may change in future releases. + * + * @see Builder + */ +public record CreateTaskOptions(McpSchema.Request originatingRequest, String taskId, Long requestedTtl, + Long pollInterval, Object context, String sessionId) { + + /** + * Compact constructor that validates options. + * + *

+ * Validation rules: + *

    + *
  • originatingRequest: required (must not be null)
  • + *
  • TTL: 0 to {@link TaskDefaults#MAX_TTL_MS} (24 hours)
  • + *
  • Poll interval: 0 or {@link TaskDefaults#MIN_POLL_INTERVAL_MS} to + * {@link TaskDefaults#MAX_POLL_INTERVAL_MS}
  • + *
+ * @throws IllegalArgumentException if originatingRequest is null, or TTL or + * pollInterval is out of bounds + */ + public CreateTaskOptions { + if (originatingRequest == null) { + throw new IllegalArgumentException("originatingRequest must not be null"); + } + if (requestedTtl != null) { + if (requestedTtl < 0) { + throw new IllegalArgumentException("requestedTtl must be non-negative, got: " + requestedTtl); + } + if (requestedTtl > TaskDefaults.MAX_TTL_MS) { + throw new IllegalArgumentException("requestedTtl must not exceed " + TaskDefaults.MAX_TTL_MS + + "ms (24 hours), got: " + requestedTtl); + } + } + if (pollInterval != null) { + if (pollInterval < 0) { + throw new IllegalArgumentException("pollInterval must be non-negative, got: " + pollInterval); + } + if (pollInterval > 0 && pollInterval < TaskDefaults.MIN_POLL_INTERVAL_MS) { + throw new IllegalArgumentException("pollInterval must be at least " + TaskDefaults.MIN_POLL_INTERVAL_MS + + "ms when non-zero, got: " + pollInterval); + } + if (pollInterval > TaskDefaults.MAX_POLL_INTERVAL_MS) { + throw new IllegalArgumentException("pollInterval must not exceed " + TaskDefaults.MAX_POLL_INTERVAL_MS + + "ms (1 hour), got: " + pollInterval); + } + } + } + + /** + * Creates a new builder for CreateTaskOptions. + * @param originatingRequest the original MCP request that triggered task creation + * (required) + * @return a new builder instance + * @throws IllegalArgumentException if originatingRequest is null + */ + public static Builder builder(McpSchema.Request originatingRequest) { + return new Builder(originatingRequest); + } + + public static class Builder { + + private final McpSchema.Request originatingRequest; + + private String taskId; + + private Long requestedTtl; + + private Long pollInterval; + + private Object context; + + private String sessionId; + + /** + * Creates a new builder with the required originating request. + * @param originatingRequest the original MCP request that triggered task creation + * @throws IllegalArgumentException if originatingRequest is null + */ + Builder(McpSchema.Request originatingRequest) { + if (originatingRequest == null) { + throw new IllegalArgumentException("originatingRequest must not be null"); + } + this.originatingRequest = originatingRequest; + } + + /** + * Sets a custom task ID. If null or not called, the TaskStore will auto-generate + * a unique task ID. + * + *

+ * Custom task IDs are useful for correlating MCP tasks with external systems + * (e.g., job queues, workflow engines). When wrapping external async APIs that + * have their own job IDs, you can use the external ID directly as the MCP task + * ID. + * @param taskId custom task ID, or null for auto-generation + * @return this builder + */ + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + /** + * Sets the requested TTL in milliseconds. + * @param requestedTtl the requested TTL + * @return this builder + */ + public Builder requestedTtl(Long requestedTtl) { + this.requestedTtl = requestedTtl; + return this; + } + + /** + * Sets the suggested poll interval in milliseconds. + * @param pollInterval the poll interval + * @return this builder + */ + public Builder pollInterval(Long pollInterval) { + this.pollInterval = pollInterval; + return this; + } + + /** + * Sets optional context data to associate with the task. + * @param context the context data + * @return this builder + */ + public Builder context(Object context) { + this.context = context; + return this; + } + + /** + * Sets the session ID for session-scoped task isolation. + * @param sessionId the session ID + * @return this builder + */ + public Builder sessionId(String sessionId) { + this.sessionId = sessionId; + return this; + } + + /** + * Builds the {@link CreateTaskOptions} instance. + * + *

+ * Validation is performed by the record's compact constructor. + * @return the built CreateTaskOptions + * @throws IllegalArgumentException if TTL or pollInterval is out of bounds + * @see CreateTaskOptions#CreateTaskOptions(McpSchema.Request, String, Long, Long, + * Object, String) + */ + public CreateTaskOptions build() { + // Validation is handled by the compact constructor + return new CreateTaskOptions(originatingRequest, taskId, requestedTtl, pollInterval, context, sessionId); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/DefaultCreateTaskExtra.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/DefaultCreateTaskExtra.java new file mode 100644 index 000000000..c0820ab20 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/DefaultCreateTaskExtra.java @@ -0,0 +1,94 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.util.Assert; + +/** + * Default implementation of {@link CreateTaskExtra}. + * + *

+ * This implementation is created by {@link io.modelcontextprotocol.server.McpAsyncServer} + * when delegating to a tool's {@link CreateTaskHandler}. + * + *

+ * This is an experimental API that may change in future releases. + * + * @see CreateTaskExtra + * @see CreateTaskHandler + */ +public class DefaultCreateTaskExtra implements CreateTaskExtra { + + private final TaskStore taskStore; + + private final TaskMessageQueue taskMessageQueue; + + private final McpAsyncServerExchange exchange; + + private final String sessionId; + + private final Long requestTtl; + + private final McpSchema.Request originatingRequest; + + /** + * Creates a new DefaultCreateTaskExtra instance. + * @param taskStore the task store for creating tasks (required) + * @param taskMessageQueue the message queue for task communication (may be null) + * @param exchange the server exchange for client interaction (required) + * @param sessionId the session ID for task isolation (required) + * @param requestTtl the TTL from the client request (may be null) + * @param originatingRequest the original MCP request that triggered task creation + * (required) + */ + public DefaultCreateTaskExtra(TaskStore taskStore, + TaskMessageQueue taskMessageQueue, McpAsyncServerExchange exchange, String sessionId, Long requestTtl, + McpSchema.Request originatingRequest) { + Assert.notNull(taskStore, "taskStore must not be null"); + Assert.notNull(exchange, "exchange must not be null"); + Assert.notNull(sessionId, "sessionId must not be null"); + Assert.notNull(originatingRequest, "originatingRequest must not be null"); + + this.taskStore = taskStore; + this.taskMessageQueue = taskMessageQueue; + this.exchange = exchange; + this.sessionId = sessionId; + this.requestTtl = requestTtl; + this.originatingRequest = originatingRequest; + } + + @Override + public TaskStore taskStore() { + return this.taskStore; + } + + @Override + public TaskMessageQueue taskMessageQueue() { + return this.taskMessageQueue; + } + + @Override + public McpAsyncServerExchange exchange() { + return this.exchange; + } + + @Override + public String sessionId() { + return this.sessionId; + } + + @Override + public Long requestTtl() { + return this.requestTtl; + } + + @Override + public McpSchema.Request originatingRequest() { + return this.originatingRequest; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/DefaultSyncCreateTaskExtra.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/DefaultSyncCreateTaskExtra.java new file mode 100644 index 000000000..81e76a109 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/DefaultSyncCreateTaskExtra.java @@ -0,0 +1,94 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.util.Assert; + +/** + * Default implementation of {@link SyncCreateTaskExtra}. + * + *

+ * This implementation is created by {@link io.modelcontextprotocol.server.McpSyncServer} + * when delegating to a tool's {@link SyncCreateTaskHandler}. + * + *

+ * This is an experimental API that may change in future releases. + * + * @see SyncCreateTaskExtra + * @see SyncCreateTaskHandler + */ +public class DefaultSyncCreateTaskExtra implements SyncCreateTaskExtra { + + private final TaskStore taskStore; + + private final TaskMessageQueue taskMessageQueue; + + private final McpSyncServerExchange exchange; + + private final String sessionId; + + private final Long requestTtl; + + private final McpSchema.Request originatingRequest; + + /** + * Creates a new DefaultSyncCreateTaskExtra instance. + * @param taskStore the task store for creating tasks (required) + * @param taskMessageQueue the message queue for task communication (may be null) + * @param exchange the server exchange for client interaction (required) + * @param sessionId the session ID for task isolation (required) + * @param requestTtl the TTL from the client request (may be null) + * @param originatingRequest the original MCP request that triggered task creation + * (required) + */ + public DefaultSyncCreateTaskExtra(TaskStore taskStore, + TaskMessageQueue taskMessageQueue, McpSyncServerExchange exchange, String sessionId, Long requestTtl, + McpSchema.Request originatingRequest) { + Assert.notNull(taskStore, "taskStore must not be null"); + Assert.notNull(exchange, "exchange must not be null"); + Assert.notNull(sessionId, "sessionId must not be null"); + Assert.notNull(originatingRequest, "originatingRequest must not be null"); + + this.taskStore = taskStore; + this.taskMessageQueue = taskMessageQueue; + this.exchange = exchange; + this.sessionId = sessionId; + this.requestTtl = requestTtl; + this.originatingRequest = originatingRequest; + } + + @Override + public TaskStore taskStore() { + return this.taskStore; + } + + @Override + public TaskMessageQueue taskMessageQueue() { + return this.taskMessageQueue; + } + + @Override + public McpSyncServerExchange exchange() { + return this.exchange; + } + + @Override + public String sessionId() { + return this.sessionId; + } + + @Override + public Long requestTtl() { + return this.requestTtl; + } + + @Override + public McpSchema.Request originatingRequest() { + return this.originatingRequest; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/DefaultTaskContext.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/DefaultTaskContext.java new file mode 100644 index 000000000..e0e616269 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/DefaultTaskContext.java @@ -0,0 +1,151 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.Task; +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; + +/** + * Default implementation of {@link TaskContext} that delegates to a {@link TaskStore}. + * + *

+ * This implementation provides the standard task context functionality for task handlers, + * including status updates, cancellation checking, and completion signaling. + * + *

+ * The type parameter {@code R} specifies the result type that this context handles, which + * must match the result type of the underlying {@link TaskStore}. + * + *

+ * This is an experimental API that may change in future releases. + * + * @param the type of result this context handles + */ +public class DefaultTaskContext implements TaskContext { + + private final String taskId; + + private final String sessionId; + + private final TaskStore taskStore; + + private final TaskMessageQueue taskMessageQueue; + + /** + * Creates a new DefaultTaskContext. + * @param taskId the task identifier + * @param sessionId the session ID for session validation, or null for single-tenant + * mode + * @param taskStore the task store to delegate to + */ + public DefaultTaskContext(String taskId, String sessionId, TaskStore taskStore) { + this(taskId, sessionId, taskStore, null); + } + + /** + * Creates a new DefaultTaskContext with an optional message queue. + * @param taskId the task identifier + * @param sessionId the session ID for session validation, or null for single-tenant + * mode + * @param taskStore the task store to delegate to + * @param taskMessageQueue the message queue for INPUT_REQUIRED state communication, + * may be null + */ + public DefaultTaskContext(String taskId, String sessionId, TaskStore taskStore, + TaskMessageQueue taskMessageQueue) { + Assert.hasText(taskId, "Task ID must not be empty"); + Assert.notNull(taskStore, "TaskStore must not be null"); + this.taskId = taskId; + this.sessionId = sessionId; + this.taskStore = taskStore; + this.taskMessageQueue = taskMessageQueue; + } + + @Override + public String getTaskId() { + return this.taskId; + } + + /** + * Returns the session ID associated with this task context. + * @return the session ID, or null for single-tenant mode + */ + public String getSessionId() { + return this.sessionId; + } + + @Override + public Mono getTask() { + return this.taskStore.getTask(this.taskId, this.sessionId).map(GetTaskFromStoreResult::task); + } + + @Override + public Mono isCancelled() { + return this.taskStore.isCancellationRequested(this.taskId, this.sessionId); + } + + @Override + public Mono requestCancellation() { + return this.taskStore.requestCancellation(this.taskId, this.sessionId).then(); + } + + @Override + public Mono updateStatus(String statusMessage) { + return this.taskStore.updateTaskStatus(this.taskId, this.sessionId, TaskStatus.WORKING, statusMessage); + } + + @Override + public Mono requireInput(String statusMessage) { + return this.taskStore.updateTaskStatus(this.taskId, this.sessionId, TaskStatus.INPUT_REQUIRED, statusMessage); + } + + /** + * {@inheritDoc} + * + *

+ * Type Safety Note: Due to Java type erasure, the generic type + * parameter {@code R} cannot be validated at runtime. This method performs a runtime + * check that the result is either a {@link McpSchema.ServerTaskPayloadResult} or + * {@link McpSchema.ClientTaskPayloadResult}, but it cannot verify that the result + * type matches the specific {@code R} type parameter of this + * {@code DefaultTaskContext}. + * + *

+ * Callers must ensure type consistency: if this context was created with + * {@code TaskStore}, only {@code ServerTaskPayloadResult} + * values (like {@code CallToolResult}) should be passed. Passing the wrong type will + * not cause an immediate error but may result in {@code ClassCastException} when the + * result is retrieved from the TaskStore. + */ + @SuppressWarnings("unchecked") + @Override + public Mono complete(McpSchema.Result result) { + Assert.notNull(result, "Result must not be null"); + if (!(result instanceof McpSchema.ServerTaskPayloadResult) + && !(result instanceof McpSchema.ClientTaskPayloadResult)) { + return Mono.error(new IllegalArgumentException( + "Result must be a ServerTaskPayloadResult or ClientTaskPayloadResult, got: " + + result.getClass().getName())); + } + return this.taskStore.storeTaskResult(this.taskId, this.sessionId, TaskStatus.COMPLETED, (R) result); + } + + @Override + public Mono fail(String errorMessage) { + return this.taskStore.updateTaskStatus(this.taskId, this.sessionId, TaskStatus.FAILED, errorMessage); + } + + /** + * Returns the message queue for INPUT_REQUIRED state communication. + * @return the task message queue, or null if not configured + */ + public TaskMessageQueue getMessageQueue() { + return this.taskMessageQueue; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/GetTaskFromStoreResult.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/GetTaskFromStoreResult.java new file mode 100644 index 000000000..999a8db73 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/GetTaskFromStoreResult.java @@ -0,0 +1,56 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.spec.McpSchema; + +/** + * Result type returned by {@link TaskStore#getTask(String, String)} containing both the + * task and the original request that created it. + * + *

+ * This record encapsulates the task along with its originating request, enabling callers + * to access context about how the task was created without requiring a separate lookup. + * For tool calls, the originating request will be a {@link McpSchema.CallToolRequest} + * containing the tool name, arguments, and any task metadata. + * + *

+ * Example usage: + * + *

{@code
+ * taskStore.getTask(taskId, sessionId)
+ *     .map(result -> {
+ *         McpSchema.Task task = result.task();
+ *         if (result.originatingRequest() instanceof McpSchema.CallToolRequest ctr) {
+ *             String toolName = ctr.name();
+ *             // dispatch to tool-specific handler
+ *         }
+ *         return task;
+ *     });
+ * }
+ * + *

+ * This is an experimental API that may change in future releases. + * + * @param task the retrieved task + * @param originatingRequest the original MCP request that triggered task creation + * @see TaskStore#getTask(String, String) + */ +public record GetTaskFromStoreResult(McpSchema.Task task, McpSchema.Request originatingRequest) { + + /** + * Creates a GetTaskFromStoreResult with validation. + * @throws IllegalArgumentException if task or originatingRequest is null + */ + public GetTaskFromStoreResult { + if (task == null) { + throw new IllegalArgumentException("task must not be null"); + } + if (originatingRequest == null) { + throw new IllegalArgumentException("originatingRequest must not be null"); + } + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/GetTaskHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/GetTaskHandler.java new file mode 100644 index 000000000..fe7cb6d56 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/GetTaskHandler.java @@ -0,0 +1,49 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; + +/** + * Functional interface for handling custom task retrieval logic. + * + *

+ * When a tool registers a custom {@code GetTaskHandler}, it will be called instead of the + * default task store lookup when {@code tasks/get} requests are received for tasks + * created by that tool. + * + *

+ * This enables tools to: + *

    + *
  • Fetch task state from external storage (Redis, database, etc.)
  • + *
  • Transform or enrich task data before returning
  • + *
  • Implement custom task lifecycle logic
  • + *
+ * + *

+ * Full override pattern: Custom handlers do NOT receive the stored task + * as input - they are expected to fetch everything independently for maximum flexibility. + * + *

+ * This is an experimental API that may change in future releases. + * + * @see SyncGetTaskHandler + * @see GetTaskResultHandler + */ +@FunctionalInterface +public interface GetTaskHandler { + + /** + * Handles a {@code tasks/get} request for a task created by the associated tool. + * @param exchange the server exchange providing access to the client session + * @param request the task retrieval request containing the task ID + * @return a Mono emitting the task result, or an error if the task cannot be + * retrieved + */ + Mono handle(McpAsyncServerExchange exchange, McpSchema.GetTaskRequest request); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/GetTaskResultHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/GetTaskResultHandler.java new file mode 100644 index 000000000..4e5da831a --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/GetTaskResultHandler.java @@ -0,0 +1,51 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; + +/** + * Functional interface for handling custom task result retrieval logic. + * + *

+ * When a tool registers a custom {@code GetTaskResultHandler}, it will be called instead + * of the default task store lookup when {@code tasks/result} requests are received for + * tasks created by that tool. + * + *

+ * This enables tools to: + *

    + *
  • Fetch task results from external storage
  • + *
  • Transform or enrich results before returning
  • + *
  • Implement lazy result computation
  • + *
+ * + *

+ * Full override pattern: Custom handlers do NOT receive the stored + * result as input - they are expected to fetch everything independently for maximum + * flexibility. + * + *

+ * This is an experimental API that may change in future releases. + * + * @see SyncGetTaskResultHandler + * @see GetTaskHandler + */ +@FunctionalInterface +public interface GetTaskResultHandler { + + /** + * Handles a {@code tasks/result} request for a task created by the associated tool. + * @param exchange the server exchange providing access to the client session + * @param request the task result request containing the task ID + * @return a Mono emitting the task payload result, or an error if the result cannot + * be retrieved + */ + Mono handle(McpAsyncServerExchange exchange, + McpSchema.GetTaskPayloadRequest request); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskMessageQueue.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskMessageQueue.java new file mode 100644 index 000000000..289eac255 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskMessageQueue.java @@ -0,0 +1,126 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; + +import reactor.core.publisher.Mono; + +/** + * In-memory implementation of {@link TaskMessageQueue}. + * + *

+ * This implementation stores messages in memory using thread-safe concurrent queues. Each + * task has its own isolated queue. + * + *

Thread Safety

+ *

+ * This implementation is thread-safe. Queue operations use internal synchronization to + * ensure atomicity of size checks and modifications during enqueue operations. The + * internal queue references are managed by this class and should not be accessed directly + * by external code. + * + *

+ * This is an experimental API that may change in future releases. + */ +public class InMemoryTaskMessageQueue implements TaskMessageQueue { + + // Use centralized default from TaskDefaults + private static final int DEFAULT_MAX_SIZE = TaskDefaults.DEFAULT_MAX_QUEUE_SIZE; + + private final Map> queues = new ConcurrentHashMap<>(); + + @Override + public Mono enqueue(String taskId, QueuedMessage message, Integer maxSize) { + // Validate maxSize bounds to prevent unbounded memory growth + if (maxSize != null) { + if (maxSize < 1) { + return Mono.error(new IllegalArgumentException("maxSize must be at least 1, got: " + maxSize)); + } + if (maxSize > TaskDefaults.MAX_ALLOWED_QUEUE_SIZE) { + return Mono.error(new IllegalArgumentException( + "maxSize must not exceed " + TaskDefaults.MAX_ALLOWED_QUEUE_SIZE + ", got: " + maxSize)); + } + } + + return Mono.fromRunnable(() -> { + Queue queue = queues.computeIfAbsent(taskId, k -> new ConcurrentLinkedQueue<>()); + + int effectiveMaxSize = maxSize != null ? maxSize : DEFAULT_MAX_SIZE; + + // Synchronize to make size check + poll + offer atomic + // This prevents race conditions where concurrent enqueues could exceed + // maxSize + synchronized (queue) { + while (queue.size() >= effectiveMaxSize) { + queue.poll(); + } + queue.offer(message); + } + }); + } + + @Override + public Mono dequeue(String taskId) { + return Mono.fromCallable(() -> { + Queue queue = queues.get(taskId); + if (queue == null) { + return null; + } + return queue.poll(); + }); + } + + @Override + public Mono> dequeueAll(String taskId) { + return Mono.fromCallable(() -> { + Queue queue = queues.get(taskId); + if (queue == null) { + return List.of(); + } + + // Synchronize to prevent race with enqueue() which also syncs on queue + synchronized (queue) { + List messages = new ArrayList<>(queue); + queue.clear(); + return messages; + } + }); + } + + /** + * Clears all messages for a specific task. + * @param taskId the task identifier + */ + public void clear(String taskId) { + queues.remove(taskId); + } + + @Override + public Mono clearTask(String taskId) { + return Mono.fromRunnable(() -> queues.remove(taskId)); + } + + /** + * Clears all queues. Use with caution. + */ + public void clearAll() { + queues.clear(); + } + + @Override + public Mono getQueueSize(String taskId) { + return Mono.fromCallable(() -> { + Queue queue = queues.get(taskId); + return queue != null ? queue.size() : 0; + }); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskStore.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskStore.java new file mode 100644 index 000000000..df9f9ad4d --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskStore.java @@ -0,0 +1,684 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NavigableMap; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.Task; +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * In-memory implementation of {@link TaskStore}. + * + *

+ * This implementation stores tasks in memory using thread-safe concurrent data + * structures. It supports TTL-based cleanup of expired tasks. + * + *

+ * The type parameter {@code R} specifies the result type that this store handles: + *

    + *
  • For server-side stores (handling tool calls), use + * {@link McpSchema.ServerTaskPayloadResult} + *
  • For client-side stores (handling sampling/elicitation), use + * {@link McpSchema.ClientTaskPayloadResult} + *
  • For stores that can handle any result type, use {@link McpSchema.Result} + *
+ * + *

+ * This is an experimental API that may change in future releases. + * + * @param the type of result this store handles + */ +public class InMemoryTaskStore implements TaskStore { + + private static final Logger logger = LoggerFactory.getLogger(InMemoryTaskStore.class); + + // Use centralized defaults from TaskDefaults + private static final long DEFAULT_TTL_MS = TaskDefaults.DEFAULT_TTL_MS; + + private static final long DEFAULT_POLL_INTERVAL_MS = TaskDefaults.DEFAULT_POLL_INTERVAL_MS; + + private static final int DEFAULT_PAGE_SIZE = TaskDefaults.DEFAULT_PAGE_SIZE; + + // Use centralized max tasks default from TaskDefaults + private static final int DEFAULT_MAX_TASKS = TaskDefaults.DEFAULT_MAX_TASKS; + + // Use ConcurrentSkipListMap for O(log n) sorted access and efficient + // cursor-based pagination via tailMap() + private final NavigableMap tasks = new ConcurrentSkipListMap<>(); + + private final Map results = new ConcurrentHashMap<>(); + + private final Set cancellationRequests = ConcurrentHashMap.newKeySet(); + + private final ScheduledExecutorService cleanupExecutor; + + private final long defaultTtl; + + private final long defaultPollInterval; + + // Counter for unique instance IDs to distinguish multiple stores in thread names + private static final AtomicLong INSTANCE_COUNTER = new AtomicLong(0); + + private final long instanceId; + + // Optional message queue for coordinated cleanup + private final TaskMessageQueue messageQueue; + + // Maximum number of concurrent tasks + private final int maxTasks; + + /** + * Creates a new InMemoryTaskStore with default settings. + */ + public InMemoryTaskStore() { + this(DEFAULT_TTL_MS, DEFAULT_POLL_INTERVAL_MS, null, DEFAULT_MAX_TASKS); + } + + /** + * Creates a new InMemoryTaskStore with custom TTL and poll interval. + * @param defaultTtl the default TTL in milliseconds + * @param defaultPollInterval the default poll interval in milliseconds + */ + public InMemoryTaskStore(long defaultTtl, long defaultPollInterval) { + this(defaultTtl, defaultPollInterval, null, DEFAULT_MAX_TASKS); + } + + /** + * Creates a new InMemoryTaskStore with custom settings and optional message queue for + * coordinated cleanup. + * @param defaultTtl the default TTL in milliseconds + * @param defaultPollInterval the default poll interval in milliseconds + * @param messageQueue optional message queue to clean up when tasks expire (may be + * null) + */ + public InMemoryTaskStore(long defaultTtl, long defaultPollInterval, TaskMessageQueue messageQueue) { + this(defaultTtl, defaultPollInterval, messageQueue, DEFAULT_MAX_TASKS); + } + + /** + * Creates a new InMemoryTaskStore with custom settings, optional message queue, and + * maximum task limit. + * @param defaultTtl the default TTL in milliseconds (must be positive) + * @param defaultPollInterval the default poll interval in milliseconds (must be + * positive) + * @param messageQueue optional message queue to clean up when tasks expire (may be + * null) + * @param maxTasks maximum number of concurrent tasks (must be positive) + * @throws IllegalArgumentException if defaultTtl, defaultPollInterval, or maxTasks is + * not positive + */ + public InMemoryTaskStore(long defaultTtl, long defaultPollInterval, TaskMessageQueue messageQueue, int maxTasks) { + if (defaultTtl <= 0) { + throw new IllegalArgumentException("defaultTtl must be positive"); + } + if (defaultPollInterval <= 0) { + throw new IllegalArgumentException("defaultPollInterval must be positive"); + } + if (maxTasks <= 0) { + throw new IllegalArgumentException("maxTasks must be positive"); + } + this.instanceId = INSTANCE_COUNTER.incrementAndGet(); + this.defaultTtl = defaultTtl; + this.defaultPollInterval = defaultPollInterval; + this.messageQueue = messageQueue; + this.maxTasks = maxTasks; + this.cleanupExecutor = Executors.newSingleThreadScheduledExecutor(r -> { + // Include instance ID in thread name for debugging with multiple stores + Thread t = new Thread(r, "mcp-task-cleanup-" + instanceId); + t.setDaemon(true); + return t; + }); + this.cleanupExecutor.scheduleAtFixedRate(this::cleanupExpiredTasks, 1, 1, TimeUnit.MINUTES); + } + + /** + * Creates a new builder for InMemoryTaskStore with default settings. + * + *

+ * The builder provides a fluent API for configuring the store: + * + *

{@code
+	 * InMemoryTaskStore store = InMemoryTaskStore.builder()
+	 *     .defaultTtl(Duration.ofMinutes(30))
+	 *     .defaultPollInterval(Duration.ofSeconds(2))
+	 *     .maxTasks(5000)
+	 *     .messageQueue(messageQueue)
+	 *     .build();
+	 * }
+ * @param the result type for this store + * @return a new builder instance + */ + public static Builder builder() { + return new Builder<>(); + } + + /** + * Builder for creating {@link InMemoryTaskStore} instances with custom configuration. + * + *

+ * All parameters are optional; defaults are used for any unset values: + *

    + *
  • {@code defaultTtl}: {@link TaskDefaults#DEFAULT_TTL_MS} (1 minute)
  • + *
  • {@code defaultPollInterval}: {@link TaskDefaults#DEFAULT_POLL_INTERVAL_MS} (1 + * second)
  • + *
  • {@code maxTasks}: {@link TaskDefaults#DEFAULT_MAX_TASKS} (10,000)
  • + *
  • {@code messageQueue}: null (no coordinated cleanup)
  • + *
+ * + * @param the result type for stores created by this builder + */ + public static class Builder { + + private long defaultTtl = DEFAULT_TTL_MS; + + private long defaultPollInterval = DEFAULT_POLL_INTERVAL_MS; + + private TaskMessageQueue messageQueue = null; + + private int maxTasks = DEFAULT_MAX_TASKS; + + /** + * Sets the default TTL for tasks when not specified in CreateTaskOptions. + * @param ttl the default TTL duration (must be positive) + * @return this builder for chaining + */ + public Builder defaultTtl(Duration ttl) { + this.defaultTtl = ttl.toMillis(); + return this; + } + + /** + * Sets the default TTL for tasks in milliseconds. + * @param ttlMs the default TTL in milliseconds (must be positive) + * @return this builder for chaining + */ + public Builder defaultTtlMs(long ttlMs) { + this.defaultTtl = ttlMs; + return this; + } + + /** + * Sets the default poll interval for task status checking. + * @param interval the default poll interval duration (must be positive) + * @return this builder for chaining + */ + public Builder defaultPollInterval(Duration interval) { + this.defaultPollInterval = interval.toMillis(); + return this; + } + + /** + * Sets the default poll interval in milliseconds. + * @param intervalMs the default poll interval in milliseconds (must be positive) + * @return this builder for chaining + */ + public Builder defaultPollIntervalMs(long intervalMs) { + this.defaultPollInterval = intervalMs; + return this; + } + + /** + * Sets the message queue for coordinated cleanup of task messages. + * @param queue the message queue (may be null) + * @return this builder for chaining + */ + public Builder messageQueue(TaskMessageQueue queue) { + this.messageQueue = queue; + return this; + } + + /** + * Sets the maximum number of concurrent tasks. + * @param max the maximum task count (must be positive) + * @return this builder for chaining + */ + public Builder maxTasks(int max) { + this.maxTasks = max; + return this; + } + + /** + * Builds the InMemoryTaskStore with the configured settings. + * @return a new InMemoryTaskStore instance + * @throws IllegalArgumentException if any configured value is invalid + */ + public InMemoryTaskStore build() { + return new InMemoryTaskStore<>(defaultTtl, defaultPollInterval, messageQueue, maxTasks); + } + + } + + // Lock object for task creation to prevent race condition in max task check + private final Object createTaskLock = new Object(); + + /** + * Validates session access for a task entry. + * + *

+ * Session validation rules: + *

    + *
  • If requestSessionId is null, access is allowed (single-tenant mode)
  • + *
  • If task has no session (created with null sessionId), access is allowed from + * any session
  • + *
  • If both task and request have session IDs, they must match for access
  • + *
+ * @param entry the task entry to validate + * @param requestSessionId the session ID from the request, or null for single-tenant + * @return true if access is allowed, false otherwise + */ + private boolean isSessionValid(TaskEntry entry, String requestSessionId) { + // Null request session = single-tenant mode, allow all + if (requestSessionId == null) { + return true; + } + // Task has no session = allow access from any session + String taskSessionId = entry.sessionId(); + if (taskSessionId == null || taskSessionId.isEmpty()) { + return true; + } + // Both have session IDs, they must match + return requestSessionId.equals(taskSessionId); + } + + @Override + public Mono createTask(CreateTaskOptions options) { + return Mono.fromCallable(() -> { + // Synchronize to make max task check and task creation atomic + // This prevents multiple concurrent calls from exceeding maxTasks + synchronized (createTaskLock) { + if (tasks.size() >= maxTasks) { + throw McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Maximum task limit reached (" + maxTasks + ")") + .build(); + } + + // Use provided taskId if present, otherwise generate one + String taskId = options.taskId() != null ? options.taskId() : UUID.randomUUID().toString(); + String now = Instant.now().toString(); + + Long ttl = options.requestedTtl() != null ? options.requestedTtl() : defaultTtl; + + Long pollInterval = options.pollInterval() != null ? options.pollInterval() : defaultPollInterval; + + String sessionId = options.sessionId(); + + Task task = Task.builder() + .taskId(taskId) + .status(TaskStatus.WORKING) + .createdAt(now) + .lastUpdatedAt(now) + .ttl(ttl) + .pollInterval(pollInterval) + .build(); + + tasks.put(taskId, new TaskEntry(task, options.originatingRequest(), options.context(), sessionId)); + + return task; + } + }); + } + + @Override + public Mono getTask(String taskId, String sessionId) { + return Mono.fromCallable(() -> { + TaskEntry entry = tasks.get(taskId); + if (entry == null) { + return null; + } + // Validate session access atomically + if (!isSessionValid(entry, sessionId)) { + return null; + } + return new GetTaskFromStoreResult(entry.task(), entry.originatingRequest()); + }); + } + + @Override + public Mono updateTaskStatus(String taskId, String sessionId, TaskStatus status, String statusMessage) { + return Mono.fromRunnable(() -> { + tasks.computeIfPresent(taskId, (id, entry) -> { + // Validate session access + if (!isSessionValid(entry, sessionId)) { + return entry; // Silently ignore session mismatch + } + Task oldTask = entry.task(); + // Skip update if task is already in terminal state + if (oldTask.isTerminal()) { + return entry; + } + String now = Instant.now().toString(); + Task newTask = Task.builder() + .taskId(oldTask.taskId()) + .status(status) + .statusMessage(statusMessage) + .createdAt(oldTask.createdAt()) + .lastUpdatedAt(now) + .ttl(oldTask.ttl()) + .pollInterval(oldTask.pollInterval()) + .build(); + return new TaskEntry(newTask, entry.originatingRequest(), entry.context(), entry.sessionId()); + }); + }); + } + + @Override + public Mono storeTaskResult(String taskId, String sessionId, TaskStatus status, R result) { + return Mono.fromRunnable(() -> { + // Update task to terminal status, only storing result if task exists + // Throws McpError if task not found or session mismatch to avoid silent data + // loss + AtomicBoolean taskFound = new AtomicBoolean(false); + AtomicBoolean sessionValid = new AtomicBoolean(true); + AtomicBoolean wasTerminal = new AtomicBoolean(false); + + tasks.computeIfPresent(taskId, (id, entry) -> { + taskFound.set(true); + + // Validate session access + if (!isSessionValid(entry, sessionId)) { + sessionValid.set(false); + return entry; + } + + Task oldTask = entry.task(); + + // Don't overwrite if task is already in terminal state + if (oldTask.isTerminal()) { + wasTerminal.set(true); + return entry; + } + + results.put(taskId, result); + String now = Instant.now().toString(); + Task newTask = Task.builder() + .taskId(oldTask.taskId()) + .status(status) + .createdAt(oldTask.createdAt()) + .lastUpdatedAt(now) + .ttl(oldTask.ttl()) + .pollInterval(oldTask.pollInterval()) + .build(); + return new TaskEntry(newTask, entry.originatingRequest(), entry.context(), entry.sessionId()); + }); + + if (!taskFound.get()) { + throw McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Task not found (may have expired after TTL): " + taskId) + .build(); + } + if (!sessionValid.get()) { + throw McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Task not found (may have expired after TTL): " + taskId) + .build(); + } + // Log if we skipped storing because task was already terminal + if (wasTerminal.get()) { + logger.debug("Skipped storing result for task {} - already in terminal state", taskId); + } + }); + } + + @Override + public Mono getTaskResult(String taskId, String sessionId) { + return Mono.fromCallable(() -> { + // First validate session access + TaskEntry entry = tasks.get(taskId); + if (entry == null || !isSessionValid(entry, sessionId)) { + return null; + } + return results.get(taskId); + }); + } + + @Override + public Mono listTasks(String cursor, String sessionId) { + return Mono.fromCallable(() -> { + List taskList = new ArrayList<>(); + String nextCursor = null; + + // Use tailMap for O(log n) cursor lookup instead of O(n) indexOf() + // tailMap gracefully handles missing cursors (e.g., expired tasks) by + // returning entries that come after where the cursor would be + // lexicographically + NavigableMap view; + if (cursor != null) { + // Get entries strictly after cursor (exclusive) + // If cursor doesn't exist (e.g., task expired), this returns entries + // after where the cursor would be, providing graceful degradation + view = tasks.tailMap(cursor, false); + } + else { + view = tasks; + } + + // Iterate through the view, collecting up to PAGE_SIZE entries + // Filter by sessionId if provided + // Note: When filtering by sessionId, pages may contain fewer than + // DEFAULT_PAGE_SIZE entries. This is intentional - it ensures consistent + // cursor behavior while allowing session-scoped views of the task list. + Iterator> iterator = view.entrySet().iterator(); + int count = 0; + String lastKey = null; + + while (iterator.hasNext() && count < DEFAULT_PAGE_SIZE) { + Map.Entry entry = iterator.next(); + TaskEntry taskEntry = entry.getValue(); + + // Filter by session if sessionId is provided + if (sessionId != null && !sessionId.equals(taskEntry.sessionId())) { + continue; + } + + taskList.add(taskEntry.task()); + lastKey = entry.getKey(); + count++; + } + + // Check if there are more entries (that match the session filter) + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + if (sessionId == null || sessionId.equals(entry.getValue().sessionId())) { + nextCursor = lastKey; + break; + } + } + + return McpSchema.ListTasksResult.builder().tasks(taskList).nextCursor(nextCursor).build(); + }); + } + + /** + * {@inheritDoc} + * + *

+ * Per the MCP specification, cancellation of tasks in terminal status MUST be + * rejected with error code {@code -32602} (Invalid params). This implementation + * throws {@link McpError} with the appropriate error code when attempting to cancel a + * task that is already in COMPLETED, FAILED, or CANCELLED status. + * + *

+ * Return Value: This method returns the updated Task (now with + * status CANCELLED) rather than {@code Mono} to allow callers to immediately + * verify the cancellation was applied and obtain the updated task state without + * making a separate {@code getTask()} call. This is especially useful for returning + * the cancelled task in the {@code tasks/cancel} response. + * @throws McpError with code {@link McpSchema.ErrorCodes#INVALID_PARAMS} if the task + * is in a terminal state + */ + @Override + public Mono requestCancellation(String taskId, String sessionId) { + return Mono.fromCallable(() -> { + AtomicReference resultRef = new AtomicReference<>(); + AtomicReference terminalStatusRef = new AtomicReference<>(); + AtomicBoolean sessionValid = new AtomicBoolean(true); + + // Use computeIfPresent for atomic update to avoid race condition + tasks.computeIfPresent(taskId, (id, entry) -> { + // Validate session access + if (!isSessionValid(entry, sessionId)) { + sessionValid.set(false); + return entry; + } + + Task oldTask = entry.task(); + // Per MCP spec: MUST reject cancellation of tasks in terminal status + if (oldTask.isTerminal()) { + terminalStatusRef.set(oldTask.status()); + resultRef.set(oldTask); + return entry; + } + cancellationRequests.add(taskId); + String now = Instant.now().toString(); + Task newTask = Task.builder() + .taskId(oldTask.taskId()) + .status(TaskStatus.CANCELLED) + .statusMessage("Cancellation requested") + .createdAt(oldTask.createdAt()) + .lastUpdatedAt(now) + .ttl(oldTask.ttl()) + .pollInterval(oldTask.pollInterval()) + .build(); + resultRef.set(newTask); + return new TaskEntry(newTask, entry.originatingRequest(), entry.context(), entry.sessionId()); + }); + + // Session mismatch returns empty (task not found from caller's perspective) + if (!sessionValid.get()) { + return null; + } + + // Check if we encountered a terminal task and throw appropriate error + TaskStatus terminalStatus = terminalStatusRef.get(); + if (terminalStatus != null) { + throw McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Cannot cancel task: already in terminal status '" + terminalStatus + "'") + .data("taskId: " + taskId) + .build(); + } + + return resultRef.get(); + }); + } + + @Override + public Mono isCancellationRequested(String taskId, String sessionId) { + return Mono.fromCallable(() -> { + // Validate session access first + TaskEntry entry = tasks.get(taskId); + if (entry == null || !isSessionValid(entry, sessionId)) { + return false; + } + return cancellationRequests.contains(taskId); + }); + } + + @Override + public Flux watchTaskUntilTerminal(String taskId, String sessionId, Duration timeout) { + // Use the task's poll interval if available, otherwise default + return getTask(taskId, sessionId).map(GetTaskFromStoreResult::task).flatMapMany(initialTask -> { + long pollInterval = initialTask != null && initialTask.pollInterval() != null ? initialTask.pollInterval() + : DEFAULT_POLL_INTERVAL_MS; + + return Flux.interval(Duration.ofMillis(pollInterval)) + .concatMap(tick -> getTask(taskId, sessionId).map(GetTaskFromStoreResult::task)) + .filter(task -> task != null) + .takeUntil(Task::isTerminal) + .timeout(timeout); + }) + .switchIfEmpty(Flux.error(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Task not found (may have expired after TTL): " + taskId) + .build())); + } + + /** + * Cleans up expired tasks based on TTL. Package-private for testing. + */ + void cleanupExpiredTasks() { + Instant now = Instant.now(); + List expiredTaskIds = new ArrayList<>(); + + // Collect expired tasks and remove from maps (non-blocking) + tasks.entrySet().removeIf(entry -> { + Task task = entry.getValue().task(); + if (task.ttl() == null) { + return false; // Null TTL means unlimited lifetime + } + Instant createdAt = Instant.parse(task.createdAt()); + Instant expiresAt = createdAt.plusMillis(task.ttl()); + if (now.isAfter(expiresAt)) { + String taskId = entry.getKey(); + // Clean up related data BEFORE removing task entry for atomicity + results.remove(taskId); + cancellationRequests.remove(taskId); + // Collect for parallel message queue cleanup + expiredTaskIds.add(taskId); + return true; + } + return false; + }); + + // Clean up message queues asynchronously to avoid blocking the cleanup thread + if (messageQueue != null && !expiredTaskIds.isEmpty()) { + Flux.fromIterable(expiredTaskIds) + .flatMap(taskId -> messageQueue.clearTask(taskId).timeout(Duration.ofSeconds(1)).onErrorResume(e -> { + logger.warn("Failed to clear task queue for {}", taskId, e); + return Mono.empty(); + })) + .subscribe(null, // no per-item handling needed + error -> logger.warn("Error during message queue cleanup", error), + () -> logger.debug("Completed cleanup of {} message queues", expiredTaskIds.size())); + } + } + + /** + * Shuts down the cleanup executor. Call this when the store is no longer needed. + */ + @Override + public Mono shutdown() { + return Mono.fromRunnable(() -> { + cleanupExecutor.shutdown(); + try { + if (!cleanupExecutor.awaitTermination(5, TimeUnit.SECONDS)) { + cleanupExecutor.shutdownNow(); + } + } + catch (InterruptedException e) { + cleanupExecutor.shutdownNow(); + Thread.currentThread().interrupt(); + } + }); + } + + /** + * Internal entry holding task data, originating request, optional context, and + * session ID. + */ + private record TaskEntry(Task task, McpSchema.Request originatingRequest, Object context, String sessionId) { + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/QueuedMessage.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/QueuedMessage.java new file mode 100644 index 000000000..4f53831cf --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/QueuedMessage.java @@ -0,0 +1,50 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.spec.McpSchema; + +/** + * A message that can be queued for bidirectional communication during task execution. + * + *

+ * This is an experimental API that may change in future releases. + * + */ +public sealed interface QueuedMessage + permits QueuedMessage.Request, QueuedMessage.Response, QueuedMessage.Notification { + + /** + * A request message (e.g., sampling or elicitation request during a task). + * + * @param requestId the request identifier for correlation + * @param method the method name + * @param request the request payload + */ + record Request(Object requestId, String method, McpSchema.Request request) implements QueuedMessage { + + } + + /** + * A response message (e.g., the result of a sampling or elicitation request). + * + * @param requestId the request identifier for correlation + * @param result the result payload + */ + record Response(Object requestId, McpSchema.Result result) implements QueuedMessage { + + } + + /** + * A notification message (e.g., progress updates). + * + * @param method the notification method name + * @param notification the notification payload + */ + record Notification(String method, McpSchema.Notification notification) implements QueuedMessage { + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncCreateTaskExtra.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncCreateTaskExtra.java new file mode 100644 index 000000000..97568bedb --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncCreateTaskExtra.java @@ -0,0 +1,204 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import java.util.function.Consumer; + +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; + +/** + * Synchronous context passed to {@link SyncCreateTaskHandler} providing access to task + * infrastructure and request metadata. + * + *

+ * This is the synchronous variant of {@link CreateTaskExtra}. It gives tool handlers + * access to everything needed to create and manage tasks. + * + *

+ * Example usage: + * + *

{@code
+ * SyncCreateTaskHandler handler = (args, extra) -> {
+ *     // Decide TTL based on request or use a default
+ *     long ttl = extra.requestTtl() != null
+ *         ? Math.min(extra.requestTtl(), Duration.ofMinutes(30).toMillis())
+ *         : Duration.ofMinutes(5).toMillis();
+ *
+ *     Task task = extra.taskStore()
+ *         .createTask(CreateTaskOptions.builder()
+ *             .requestedTtl(ttl)
+ *             .sessionId(extra.sessionId())
+ *             .build())
+ *         .block();
+ *
+ *     // Use exchange for client communication
+ *     startBackgroundWork(task.taskId(), args, extra.exchange());
+ *
+ *     return new McpSchema.CreateTaskResult(task, null);
+ * };
+ * }
+ * + *

+ * Design Note: This interface mirrors {@link CreateTaskExtra} for the + * asynchronous API. The duplication is intentional because async methods return + * {@code Mono} while sync methods return values directly. This separation allows for + * proper blocking semantics without requiring reactive programming knowledge. + * + *

+ * This is an experimental API that may change in future releases. + * + * @see SyncCreateTaskHandler + * @see CreateTaskExtra + * @see TaskStore + * @see TaskMessageQueue + */ +public interface SyncCreateTaskExtra { + + /** + * The task store for creating and managing tasks. + * + *

+ * Tools use this to create tasks with their desired configuration: + * + *

{@code
+	 * Task task = extra.taskStore().createTask(CreateTaskOptions.builder()
+	 *     .requestedTtl(Duration.ofMinutes(5).toMillis())
+	 *     .pollInterval(Duration.ofSeconds(1).toMillis())
+	 *     .sessionId(extra.sessionId())
+	 *     .build()).block();
+	 * }
+ * @return the TaskStore instance + */ + TaskStore taskStore(); + + /** + * The message queue for task communication during INPUT_REQUIRED state. + * + *

+ * Use this for interactive tasks that need to communicate with the client during + * execution. + * @return the TaskMessageQueue instance, or null if not configured + */ + TaskMessageQueue taskMessageQueue(); + + /** + * The server exchange for client interaction. + * + *

+ * Provides access to session-scoped operations like sending notifications to the + * client. + * @return the McpSyncServerExchange instance + */ + McpSyncServerExchange exchange(); + + /** + * Session ID for task isolation. + * + *

+ * Tasks created with this session ID will only be visible to the same session, + * enabling proper multi-client isolation. + * @return the session ID string + */ + String sessionId(); + + /** + * Request-specified TTL from client (may be null). + * + *

+ * If the client specified a TTL in the task metadata of their request, it will be + * available here. Tools can use this to implement client-controlled TTL policies: + * + *

{@code
+	 * // Client can lower but not raise TTL
+	 * long maxTtl = Duration.ofMinutes(30).toMillis();
+	 * long ttl = extra.requestTtl() != null
+	 *     ? Math.min(extra.requestTtl(), maxTtl)
+	 *     : maxTtl;
+	 * }
+ * @return the TTL in milliseconds from the client request, or null if not specified + */ + Long requestTtl(); + + /** + * The original MCP request that triggered this task creation. + * + *

+ * For tool calls, this will be a {@link McpSchema.CallToolRequest} containing the + * tool name, arguments, and any task metadata. This request is stored alongside the + * task and can be retrieved later via the task store, eliminating the need for + * separate task-to-tool mapping. + * @return the original request that triggered task creation + */ + McpSchema.Request originatingRequest(); + + // -------------------------- + // Convenience Methods + // -------------------------- + + /** + * Convenience method to create a task with default options derived from this context. + * + *

+ * This method automatically uses {@link #originatingRequest()}, {@link #sessionId()}, + * and {@link #requestTtl()} from this context, eliminating common boilerplate: + * + *

{@code
+	 * // Instead of:
+	 * Task task = extra.taskStore().createTask(CreateTaskOptions.builder(extra.originatingRequest())
+	 *     .sessionId(extra.sessionId())
+	 *     .requestedTtl(extra.requestTtl())
+	 *     .build()).block();
+	 *
+	 * // You can simply use:
+	 * Task task = extra.createTask();
+	 * }
+ * @return the created Task + */ + default McpSchema.Task createTask() { + return taskStore() + .createTask(CreateTaskOptions.builder(originatingRequest()) + .sessionId(sessionId()) + .requestedTtl(requestTtl()) + .build()) + .block(); + } + + /** + * Convenience method to create a task with custom options, but inheriting session + * context. + * + *

+ * This method pre-populates the builder with {@link #originatingRequest()}, + * {@link #sessionId()}, and {@link #requestTtl()}, then allows customization: + * + *

{@code
+	 * // Create a task with custom poll interval:
+	 * Task task = extra.createTask(opts -> opts.pollInterval(500L));
+	 *
+	 * // Create a task with custom TTL (ignoring client request):
+	 * Task task = extra.createTask(opts -> opts.requestedTtl(Duration.ofMinutes(10).toMillis()));
+	 * }
+ * @param customizer function to customize options beyond the defaults + * @return the created Task + */ + default McpSchema.Task createTask(Consumer customizer) { + CreateTaskOptions.Builder builder = CreateTaskOptions.builder(originatingRequest()) + .sessionId(sessionId()) + .requestedTtl(requestTtl()); + customizer.accept(builder); + return taskStore().createTask(builder.build()).block(); + } + + /** + * Create a TaskContext for managing the given task's lifecycle. + * @param task the task to create a context for + * @return a TaskContext bound to the given task and this extra's infrastructure + */ + default TaskContext createTaskContext(McpSchema.Task task) { + return new DefaultTaskContext<>(task.taskId(), sessionId(), taskStore(), taskMessageQueue()); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncCreateTaskHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncCreateTaskHandler.java new file mode 100644 index 000000000..6f2754072 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncCreateTaskHandler.java @@ -0,0 +1,67 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema; + +/** + * Synchronous functional interface for handling task creation for tool calls. + * + *

+ * This is the synchronous variant of {@link CreateTaskHandler}. It gives tool + * implementers full control over task creation including TTL configuration, poll + * interval, and any background work initiation. + * + *

+ * Example usage: + * + *

{@code
+ * SyncCreateTaskHandler handler = (args, extra) -> {
+ *     // Tool decides TTL directly
+ *     long ttl = Duration.ofMinutes(5).toMillis();
+ *
+ *     Task task = extra.taskStore()
+ *         .createTask(CreateTaskOptions.builder()
+ *             .requestedTtl(ttl)
+ *             .sessionId(extra.sessionId())
+ *             .build())
+ *         .block();
+ *
+ *     // Start background work (blocking or async)
+ *     startBackgroundWork(task.taskId(), args);
+ *
+ *     return new McpSchema.CreateTaskResult(task, null);
+ * };
+ * }
+ * + *

+ * This is an experimental API that may change in future releases. + * + * @see CreateTaskHandler + * @see SyncCreateTaskExtra + * @see TaskAwareSyncToolSpecification + */ +@FunctionalInterface +public interface SyncCreateTaskHandler { + + /** + * Handles task creation for a tool call. + * + *

+ * The handler is responsible for: + *

    + *
  • Creating the task with desired TTL and poll interval
  • + *
  • Starting any background work needed
  • + *
  • Returning the created task wrapped in a CreateTaskResult
  • + *
+ * @param args The parsed tool arguments from the CallToolRequest + * @param extra Context providing taskStore, exchange, and request metadata + * @return the CreateTaskResult containing the created Task + */ + McpSchema.CreateTaskResult createTask(Map args, SyncCreateTaskExtra extra); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncGetTaskHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncGetTaskHandler.java new file mode 100644 index 000000000..b1e8cc6ea --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncGetTaskHandler.java @@ -0,0 +1,36 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; + +/** + * Synchronous functional interface for handling custom task retrieval logic. + * + *

+ * This is the synchronous variant of {@link GetTaskHandler}. When a tool registers a + * custom {@code SyncGetTaskHandler}, it will be called instead of the default task store + * lookup when {@code tasks/get} requests are received for tasks created by that tool. + * + *

+ * This is an experimental API that may change in future releases. + * + * @see GetTaskHandler + * @see SyncGetTaskResultHandler + */ +@FunctionalInterface +public interface SyncGetTaskHandler { + + /** + * Handles a {@code tasks/get} request for a task created by the associated tool. + * @param exchange the server exchange providing access to the client session + * @param request the task retrieval request containing the task ID + * @return the task result + * @throws io.modelcontextprotocol.spec.McpError if the task cannot be retrieved + */ + McpSchema.GetTaskResult handle(McpSyncServerExchange exchange, McpSchema.GetTaskRequest request); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncGetTaskResultHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncGetTaskResultHandler.java new file mode 100644 index 000000000..e578d3a35 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/SyncGetTaskResultHandler.java @@ -0,0 +1,37 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; + +/** + * Synchronous functional interface for handling custom task result retrieval logic. + * + *

+ * This is the synchronous variant of {@link GetTaskResultHandler}. When a tool registers + * a custom {@code SyncGetTaskResultHandler}, it will be called instead of the default + * task store lookup when {@code tasks/result} requests are received for tasks created by + * that tool. + * + *

+ * This is an experimental API that may change in future releases. + * + * @see GetTaskResultHandler + * @see SyncGetTaskHandler + */ +@FunctionalInterface +public interface SyncGetTaskResultHandler { + + /** + * Handles a {@code tasks/result} request for a task created by the associated tool. + * @param exchange the server exchange providing access to the client session + * @param request the task result request containing the task ID + * @return the task payload result + * @throws io.modelcontextprotocol.spec.McpError if the result cannot be retrieved + */ + McpSchema.ServerTaskPayloadResult handle(McpSyncServerExchange exchange, McpSchema.GetTaskPayloadRequest request); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskAwareAsyncToolSpecification.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskAwareAsyncToolSpecification.java new file mode 100644 index 000000000..fd085a361 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskAwareAsyncToolSpecification.java @@ -0,0 +1,336 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import java.util.concurrent.Executor; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TaskSupportMode; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +/** + * Specification for a task-aware asynchronous tool. + * + *

+ * This class encapsulates all information needed to define an MCP tool that supports + * task-augmented execution (SEP-1686). Unlike regular tools + * ({@link io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification}), + * task-aware tools have handlers for managing task lifecycle. + * + *

+ * Use {@link #builder()} to create instances. + * + *

Usage

+ * + *
{@code
+ * TaskAwareAsyncToolSpecification spec = TaskAwareAsyncToolSpecification.builder()
+ *     .name("long-computation")
+ *     .description("A long-running computation task")
+ *     .inputSchema(new JsonSchema("object", Map.of("input", Map.of("type", "string")), null, null, null, null))
+ *     .createTaskHandler((args, extra) -> {
+ *         long ttl = Duration.ofMinutes(5).toMillis();
+ *         return extra.taskStore()
+ *             .createTask(CreateTaskOptions.builder()
+ *                 .requestedTtl(ttl)
+ *                 .sessionId(extra.sessionId())
+ *                 .build())
+ *             .flatMap(task -> {
+ *                 doExpensiveComputation(task.taskId(), args).subscribe();
+ *                 return Mono.just(new McpSchema.CreateTaskResult(task, null));
+ *             });
+ *     })
+ *     .build();
+ *
+ * // Register with server
+ * McpServer.async(transport)
+ *     .taskTools(spec)
+ *     .build();
+ * }
+ * + *

How It Works

+ *
    + *
  1. When the tool is called with task metadata, the {@link CreateTaskHandler} is + * invoked
  2. + *
  3. The handler creates a task with desired TTL and poll interval
  4. + *
  5. The handler starts any background work and returns immediately with the task
  6. + *
  7. Callers can poll {@code tasks/get} for status and retrieve results via + * {@code tasks/result}
  8. + *
+ * + *

Task Support Modes

+ *

+ * The {@link TaskSupportMode} controls how the tool responds to calls with or without + * task metadata: + *

    + *
  • {@link TaskSupportMode#OPTIONAL} (default): Tool can be called + * with OR without task metadata. When called without metadata, the server automatically + * creates an internal task and polls it to completion, returning the result as if it were + * a synchronous call. This provides backward compatibility for clients that don't support + * tasks.
  • + *
  • {@link TaskSupportMode#REQUIRED}: Tool MUST be called with task + * metadata. Calls without metadata return error {@code -32601} (METHOD_NOT_FOUND). Use + * this for tools where callers must explicitly handle task lifecycle (e.g., very + * long-running operations, tasks requiring user input).
  • + *
  • {@link TaskSupportMode#FORBIDDEN}: Tool cannot use tasks. This is + * the default for normal (non-task-aware) tools. Task-aware tools should not use this + * mode.
  • + *
+ * + *

+ * This is an experimental API that may change in future releases. + * + * @see CreateTaskHandler + * @see GetTaskHandler + * @see GetTaskResultHandler + * @see TaskAwareSyncToolSpecification + * @see Builder + */ +public final class TaskAwareAsyncToolSpecification { + + private final Tool tool; + + private final BiFunction> callHandler; + + private final CreateTaskHandler createTaskHandler; + + private final GetTaskHandler getTaskHandler; + + private final GetTaskResultHandler getTaskResultHandler; + + private TaskAwareAsyncToolSpecification(Tool tool, + BiFunction> callHandler, + CreateTaskHandler createTaskHandler, GetTaskHandler getTaskHandler, + GetTaskResultHandler getTaskResultHandler) { + this.tool = tool; + this.callHandler = callHandler; + this.createTaskHandler = createTaskHandler; + this.getTaskHandler = getTaskHandler; + this.getTaskResultHandler = getTaskResultHandler; + } + + /** + * Returns the tool definition. + * @return the tool definition including name, description, and schema + */ + public Tool tool() { + return this.tool; + } + + /** + * Returns the handler for direct (non-task) tool calls. + * @return the call handler + */ + public BiFunction> callHandler() { + return this.callHandler; + } + + /** + * Returns the handler for task creation. + * @return the create task handler + */ + public CreateTaskHandler createTaskHandler() { + return this.createTaskHandler; + } + + /** + * Returns the optional custom handler for tasks/get requests. + * @return the get task handler, or null if not set + */ + public GetTaskHandler getTaskHandler() { + return this.getTaskHandler; + } + + /** + * Returns the optional custom handler for tasks/result requests. + * @return the get task result handler, or null if not set + */ + public GetTaskResultHandler getTaskResultHandler() { + return this.getTaskResultHandler; + } + + /** + * Creates a new builder for constructing a task-aware async tool specification. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Converts a synchronous task-aware tool specification to an asynchronous one. + * + *

+ * The sync handlers are wrapped to execute on the provided executor. + * + *

+ * Important: The provided executor should be bounded to prevent + * thread exhaustion under high task load. Unbounded executors like + * {@link java.util.concurrent.ForkJoinPool#commonPool()} are not recommended for + * production use. Consider using a bounded thread pool such as + * {@link java.util.concurrent.Executors#newFixedThreadPool(int)}. + * + *

+ * Note: This conversion creates a new + * {@link DefaultSyncCreateTaskExtra} internally. Custom {@link SyncCreateTaskExtra} + * implementations from the original sync specification will not be preserved. + * @param sync the synchronous specification to convert + * @param executor the executor for running sync handlers (should be bounded) + * @return an asynchronous task-aware tool specification + */ + public static TaskAwareAsyncToolSpecification fromSync(TaskAwareSyncToolSpecification sync, Executor executor) { + Assert.notNull(sync, "sync specification must not be null"); + Assert.notNull(executor, "executor must not be null"); + + // Wrap sync callHandler + BiFunction> asyncCallHandler = (exchange, + request) -> Mono + .fromCallable(() -> sync.callHandler().apply(sync.createSyncExchange(exchange), request)) + .subscribeOn(Schedulers.fromExecutor(executor)); + + // Wrap sync createTaskHandler + CreateTaskHandler asyncCreateTaskHandler = (args, extra) -> Mono.fromCallable(() -> { + SyncCreateTaskExtra syncExtra = new DefaultSyncCreateTaskExtra(extra.taskStore(), extra.taskMessageQueue(), + sync.createSyncExchange(extra.exchange()), extra.sessionId(), extra.requestTtl(), + extra.originatingRequest()); + return sync.createTaskHandler().createTask(args, syncExtra); + }).subscribeOn(Schedulers.fromExecutor(executor)); + + // Wrap sync getTask handler if present + GetTaskHandler asyncGetTaskHandler = null; + if (sync.getTaskHandler() != null) { + asyncGetTaskHandler = (exchange, request) -> Mono + .fromCallable(() -> sync.getTaskHandler().handle(sync.createSyncExchange(exchange), request)) + .subscribeOn(Schedulers.fromExecutor(executor)); + } + + // Wrap sync getTaskResult handler if present + GetTaskResultHandler asyncGetTaskResultHandler = null; + if (sync.getTaskResultHandler() != null) { + asyncGetTaskResultHandler = (exchange, request) -> Mono + .fromCallable(() -> sync.getTaskResultHandler().handle(sync.createSyncExchange(exchange), request)) + .subscribeOn(Schedulers.fromExecutor(executor)); + } + + return new TaskAwareAsyncToolSpecification(sync.tool(), asyncCallHandler, asyncCreateTaskHandler, + asyncGetTaskHandler, asyncGetTaskResultHandler); + } + + /** + * Builder for creating task-aware async tool specifications. + * + *

+ * This builder provides full control over task creation through the + * {@link CreateTaskHandler}. Tools decide their own TTL, poll interval, and how to + * start background work. + */ + public static class Builder extends AbstractTaskAwareToolSpecificationBuilder { + + private CreateTaskHandler createTaskHandler; + + private GetTaskHandler getTaskHandler; + + private GetTaskResultHandler getTaskResultHandler; + + /** + * Sets the handler for task creation (required). + * + *

+ * This handler is called when the tool is invoked with task metadata. The handler + * has full control over task creation including TTL configuration. + * + *

+ * Example: + * + *

{@code
+		 * .createTaskHandler((args, extra) -> {
+		 *     long ttl = Duration.ofMinutes(5).toMillis();
+		 *     return extra.taskStore()
+		 *         .createTask(CreateTaskOptions.builder()
+		 *             .requestedTtl(ttl)
+		 *             .sessionId(extra.sessionId())
+		 *             .build())
+		 *         .flatMap(task -> {
+		 *             doWork(task.taskId(), args).subscribe();
+		 *             return Mono.just(new McpSchema.CreateTaskResult(task, null));
+		 *         });
+		 * })
+		 * }
+ * @param createTaskHandler the task creation handler + * @return this builder + */ + public Builder createTaskHandler(CreateTaskHandler createTaskHandler) { + this.createTaskHandler = createTaskHandler; + return this; + } + + /** + * Sets a custom handler for {@code tasks/get} requests. + * + *

+ * When set, this handler will be called instead of the default task store lookup + * when retrieving task status. This enables fetching from external storage or + * custom task lifecycle logic. + * @param getTaskHandler the custom task retrieval handler + * @return this builder + */ + public Builder getTaskHandler(GetTaskHandler getTaskHandler) { + this.getTaskHandler = getTaskHandler; + return this; + } + + /** + * Sets a custom handler for {@code tasks/result} requests. + * + *

+ * When set, this handler will be called instead of the default task store lookup + * when retrieving task results. This enables fetching from external storage or + * lazy result computation. + * @param getTaskResultHandler the custom task result retrieval handler + * @return this builder + */ + public Builder getTaskResultHandler(GetTaskResultHandler getTaskResultHandler) { + this.getTaskResultHandler = getTaskResultHandler; + return this; + } + + /** + * Builds the {@link TaskAwareAsyncToolSpecification}. + * + *

+ * The returned specification handles task-augmented tool calls by delegating to + * the createTaskHandler. For non-task calls, the server uses an automatic polling + * shim. + * @return a new TaskAwareAsyncToolSpecification instance + * @throws IllegalArgumentException if required fields (name, createTask) are not + * set + */ + @Override + public TaskAwareAsyncToolSpecification build() { + validateCommonFields(); + Assert.notNull(createTaskHandler, "createTaskHandler must not be null"); + + Tool tool = buildTool(); + + // Create a placeholder callHandler for non-task calls + // (will be handled by automatic polling shim in McpAsyncServer) + BiFunction> callHandler = (exchange, + request) -> Mono.error(new UnsupportedOperationException("Tool '" + name + + "' requires task-augmented execution. Either provide TaskMetadata in the request, " + + "or ensure the server has a TaskStore configured for automatic polling. " + + "Direct tool calls without task support are not available for this tool.")); + + return new TaskAwareAsyncToolSpecification(tool, callHandler, createTaskHandler, getTaskHandler, + getTaskResultHandler); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskAwareSyncToolSpecification.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskAwareSyncToolSpecification.java new file mode 100644 index 000000000..6885421a9 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskAwareSyncToolSpecification.java @@ -0,0 +1,258 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import java.util.function.BiFunction; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.util.Assert; + +/** + * Specification for a task-aware synchronous tool. + * + *

+ * This is the synchronous variant of {@link TaskAwareAsyncToolSpecification}. It + * encapsulates all information needed to define an MCP tool that supports task-augmented + * execution (SEP-1686) using blocking handlers. + * + *

+ * Use {@link #builder()} to create instances. + * + *

Usage

+ * + *
{@code
+ * TaskAwareSyncToolSpecification spec = TaskAwareSyncToolSpecification.builder()
+ *     .name("long-computation")
+ *     .description("A long-running computation task")
+ *     .createTaskHandler((args, extra) -> {
+ *         long ttl = Duration.ofMinutes(5).toMillis();
+ *         Task task = extra.taskStore()
+ *             .createTask(CreateTaskOptions.builder()
+ *                 .requestedTtl(ttl)
+ *                 .sessionId(extra.sessionId())
+ *                 .build())
+ *             .block();
+ *
+ *         // Start background work
+ *         startBackgroundComputation(task.taskId(), args);
+ *
+ *         return new McpSchema.CreateTaskResult(task, null);
+ *     })
+ *     .build();
+ *
+ * // Register with server
+ * McpServer.sync(transport)
+ *     .taskTools(spec)
+ *     .build();
+ * }
+ * + *

+ * This is an experimental API that may change in future releases. + * + * @see SyncCreateTaskHandler + * @see SyncGetTaskHandler + * @see SyncGetTaskResultHandler + * @see TaskAwareAsyncToolSpecification + * @see Builder + */ +public final class TaskAwareSyncToolSpecification { + + private final Tool tool; + + private final BiFunction callHandler; + + private final SyncCreateTaskHandler createTaskHandler; + + private final SyncGetTaskHandler getTaskHandler; + + private final SyncGetTaskResultHandler getTaskResultHandler; + + private TaskAwareSyncToolSpecification(Tool tool, + BiFunction callHandler, + SyncCreateTaskHandler createTaskHandler, SyncGetTaskHandler getTaskHandler, + SyncGetTaskResultHandler getTaskResultHandler) { + this.tool = tool; + this.callHandler = callHandler; + this.createTaskHandler = createTaskHandler; + this.getTaskHandler = getTaskHandler; + this.getTaskResultHandler = getTaskResultHandler; + } + + /** + * Returns the tool definition. + * @return the tool definition including name, description, and schema + */ + public Tool tool() { + return this.tool; + } + + /** + * Returns the handler for direct (non-task) tool calls. + * @return the call handler + */ + public BiFunction callHandler() { + return this.callHandler; + } + + /** + * Returns the handler for task creation. + * @return the create task handler + */ + public SyncCreateTaskHandler createTaskHandler() { + return this.createTaskHandler; + } + + /** + * Returns the optional custom handler for tasks/get requests. + * @return the get task handler, or null if not set + */ + public SyncGetTaskHandler getTaskHandler() { + return this.getTaskHandler; + } + + /** + * Returns the optional custom handler for tasks/result requests. + * @return the get task result handler, or null if not set + */ + public SyncGetTaskResultHandler getTaskResultHandler() { + return this.getTaskResultHandler; + } + + /** + * Creates a new builder for constructing a task-aware sync tool specification. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Creates a synchronous server exchange from an asynchronous one. + * + *

+ * This is used internally when converting sync tools to async for server execution. + * @param asyncExchange the asynchronous exchange + * @return a synchronous exchange wrapping the async exchange + */ + McpSyncServerExchange createSyncExchange(McpAsyncServerExchange asyncExchange) { + return new McpSyncServerExchange(asyncExchange); + } + + /** + * Builder for creating task-aware sync tool specifications. + * + *

+ * This builder provides full control over task creation through the + * {@link SyncCreateTaskHandler}. Tools decide their own TTL, poll interval, and how + * to start background work. + */ + public static class Builder extends AbstractTaskAwareToolSpecificationBuilder { + + private SyncCreateTaskHandler createTaskHandler; + + private SyncGetTaskHandler getTaskHandler; + + private SyncGetTaskResultHandler getTaskResultHandler; + + /** + * Sets the handler for task creation (required). + * + *

+ * This handler is called when the tool is invoked with task metadata. The handler + * has full control over task creation including TTL configuration. + * + *

+ * Example: + * + *

{@code
+		 * .createTaskHandler((args, extra) -> {
+		 *     long ttl = Duration.ofMinutes(5).toMillis();
+		 *     Task task = extra.taskStore()
+		 *         .createTask(CreateTaskOptions.builder()
+		 *             .requestedTtl(ttl)
+		 *             .sessionId(extra.sessionId())
+		 *             .build())
+		 *         .block();
+		 *
+		 *     startBackgroundWork(task.taskId(), args);
+		 *     return new McpSchema.CreateTaskResult(task, null);
+		 * })
+		 * }
+ * @param createTaskHandler the task creation handler + * @return this builder + */ + public Builder createTaskHandler(SyncCreateTaskHandler createTaskHandler) { + this.createTaskHandler = createTaskHandler; + return this; + } + + /** + * Sets a custom handler for {@code tasks/get} requests. + * + *

+ * When set, this handler will be called instead of the default task store lookup + * when retrieving task status. This enables fetching from external storage or + * custom task lifecycle logic. + * @param getTaskHandler the custom task retrieval handler + * @return this builder + */ + public Builder getTaskHandler(SyncGetTaskHandler getTaskHandler) { + this.getTaskHandler = getTaskHandler; + return this; + } + + /** + * Sets a custom handler for {@code tasks/result} requests. + * + *

+ * When set, this handler will be called instead of the default task store lookup + * when retrieving task results. This enables fetching from external storage or + * lazy result computation. + * @param getTaskResultHandler the custom task result retrieval handler + * @return this builder + */ + public Builder getTaskResultHandler(SyncGetTaskResultHandler getTaskResultHandler) { + this.getTaskResultHandler = getTaskResultHandler; + return this; + } + + /** + * Builds the {@link TaskAwareSyncToolSpecification}. + * + *

+ * The returned specification handles task-augmented tool calls by delegating to + * the createTaskHandler. For non-task calls, the server uses an automatic polling + * shim. + * @return a new TaskAwareSyncToolSpecification instance + * @throws IllegalArgumentException if required fields (name, createTask) are not + * set + */ + @Override + public TaskAwareSyncToolSpecification build() { + validateCommonFields(); + Assert.notNull(createTaskHandler, "createTaskHandler must not be null"); + + Tool tool = buildTool(); + + // Create a placeholder callHandler for non-task calls + // (will be handled by automatic polling shim in McpSyncServer) + BiFunction callHandler = (exchange, request) -> { + throw new UnsupportedOperationException("Tool '" + name + + "' requires task-augmented execution. Either provide TaskMetadata in the request, " + + "or ensure the server has a TaskStore configured for automatic polling. " + + "Direct tool calls without task support are not available for this tool."); + }; + + return new TaskAwareSyncToolSpecification(tool, callHandler, createTaskHandler, getTaskHandler, + getTaskResultHandler); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskContext.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskContext.java new file mode 100644 index 000000000..069f10f2b --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskContext.java @@ -0,0 +1,134 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.Task; +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; +import reactor.core.publisher.Mono; + +/** + * Context for executing a task, providing methods for status updates, cancellation + * checking, and completion signaling. + * + *

+ * TaskContext is passed to task handlers to allow them to: + *

    + *
  • Check if cancellation has been requested via {@link #isCancelled()} + *
  • Update the task status with progress information via {@link #updateStatus(String)} + *
  • Complete the task with a result via {@link #complete(McpSchema.Result)} + *
  • Fail the task with an error message via {@link #fail(String)} + *
+ * + *

+ * Example usage in a task handler: + * + *

{@code
+ * public Mono handleLongRunningTask(TaskContext context, Object args) {
+ *     return Mono.fromCallable(() -> {
+ *         for (int i = 0; i < 100; i++) {
+ *             // Check for cancellation periodically
+ *             if (context.isCancelled().block()) {
+ *                 return null; // Task was cancelled
+ *             }
+ *
+ *             // Update progress
+ *             context.updateStatus("Processing item " + i + "/100").block();
+ *
+ *             // Do actual work...
+ *             processItem(i);
+ *         }
+ *
+ *         // Complete with result
+ *         context.complete(new CallToolResult(...)).block();
+ *         return new CreateTaskResult(context.getTask(), null);
+ *     });
+ * }
+ * }
+ * + *

+ * This is an experimental API that may change in future releases. + * + */ +public interface TaskContext { + + /** + * Returns the task ID for this context. + * @return the task identifier + */ + String getTaskId(); + + /** + * Returns the current task object with the latest status. + * @return a Mono emitting the current task state + */ + Mono getTask(); + + /** + * Checks if cancellation has been requested for this task. + * + *

+ * Task handlers should check this periodically (e.g., before starting expensive + * operations) to support cooperative cancellation. + * @return a Mono emitting true if cancellation was requested, false otherwise + */ + Mono isCancelled(); + + /** + * Requests cancellation of this task. + * + *

+ * This signals that cancellation is desired. The actual cancellation is cooperative - + * the task handler must check {@link #isCancelled()} and respond appropriately. + * @return a Mono that completes when the cancellation request is recorded + */ + Mono requestCancellation(); + + /** + * Updates the task status with a progress message. + * + *

+ * This keeps the task in WORKING status but updates the statusMessage field to + * provide progress information to clients polling the task. + * @param statusMessage human-readable progress message + * @return a Mono that completes when the status is updated + */ + Mono updateStatus(String statusMessage); + + /** + * Transitions the task to INPUT_REQUIRED status. + * + *

+ * Use this when the task needs additional input (e.g., via elicitation) before it can + * continue. The task will transition back to WORKING when input is received. + * @param statusMessage description of what input is required + * @return a Mono that completes when the status is updated + */ + Mono requireInput(String statusMessage); + + /** + * Completes the task successfully with the given result. + * + *

+ * This transitions the task to COMPLETED status and stores the result for retrieval + * via tasks/result. After calling this method, the task cannot transition to any + * other state. + * @param result the task result to store + * @return a Mono that completes when the task is completed + */ + Mono complete(McpSchema.Result result); + + /** + * Fails the task with an error message. + * + *

+ * This transitions the task to FAILED status. After calling this method, the task + * cannot transition to any other state. + * @param errorMessage description of what went wrong + * @return a Mono that completes when the task is failed + */ + Mono fail(String errorMessage); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskDefaults.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskDefaults.java new file mode 100644 index 000000000..e21841c3d --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskDefaults.java @@ -0,0 +1,95 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.spec.McpSchema.JsonSchema; + +/** + * Default constants for task-related operations. + * + *

+ * This class centralizes task-related default values to ensure consistency across the + * SDK. All task-related components (stores, clients, servers) should reference these + * constants instead of defining their own. + * + *

+ * This is an experimental API that may change in future releases. + * + */ +public final class TaskDefaults { + + private TaskDefaults() { + // Utility class - no instantiation + } + + /** + * Default poll interval in milliseconds for task status polling. Clients will poll + * the server for task status updates at this interval unless the task specifies a + * different interval. + */ + public static final long DEFAULT_POLL_INTERVAL_MS = 1_000L; + + /** + * Default time-to-live in milliseconds for tasks. Tasks that exceed this TTL may be + * cleaned up by the task store. + */ + public static final long DEFAULT_TTL_MS = 60_000L; + + /** + * Default page size for task listing operations. + */ + public static final int DEFAULT_PAGE_SIZE = 100; + + /** + * Default maximum queue size for task message queues. + */ + public static final int DEFAULT_MAX_QUEUE_SIZE = 1000; + + /** + * Maximum allowed queue size for task message queues. Values above this limit will be + * rejected to prevent unbounded memory growth. + */ + public static final int MAX_ALLOWED_QUEUE_SIZE = 10_000; + + /** + * Maximum allowed TTL for tasks (24 hours). Setting a TTL higher than this will be + * rejected to prevent tasks from lingering indefinitely. + */ + public static final long MAX_TTL_MS = 24 * 60 * 60 * 1000L; // 24 hours + + /** + * Minimum allowed poll interval (100ms). Setting an interval lower than this will be + * rejected to prevent excessive polling. + */ + public static final long MIN_POLL_INTERVAL_MS = 100L; + + /** + * Maximum allowed poll interval (1 hour). Setting an interval higher than this will + * be rejected. + */ + public static final long MAX_POLL_INTERVAL_MS = 60 * 60 * 1000L; // 1 hour + + /** + * Default timeout for automatic task polling when a task-enabled tool is called + * without task metadata. The server will poll the task until it completes or this + * timeout is reached. Default is 30 minutes. + */ + public static final long DEFAULT_AUTOMATIC_POLLING_TIMEOUT_MS = 30 * 60 * 1000L; // 30 + // minutes + + /** + * Default maximum number of concurrent tasks for in-memory task stores. This provides + * protection against resource exhaustion while being generous enough for typical use + * cases. + */ + public static final int DEFAULT_MAX_TASKS = 10_000; + + /** + * Empty JSON schema representing an object with no properties. Used as the default + * input schema for task-aware tools that don't require input parameters. + */ + public static final JsonSchema EMPTY_INPUT_SCHEMA = new JsonSchema("object", null, null, null, null, null); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskHelper.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskHelper.java new file mode 100644 index 000000000..6bc5b1950 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskHelper.java @@ -0,0 +1,110 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; + +/** + * Utility methods for working with MCP tasks. + * + *

+ * This class provides helper methods for common task operations like checking terminal + * states and determining valid state transitions. + * + *

+ * This is an experimental API that may change in future releases. + * + */ +public final class TaskHelper { + + private TaskHelper() { + // Utility class, no instantiation + } + + /** + * Checks if a task status is a terminal state. + * + *

+ * Terminal states are: COMPLETED, FAILED, CANCELLED. Once a task reaches a terminal + * state, it cannot transition to any other state. + * + *

+ * Note: INPUT_REQUIRED is NOT a terminal state - it can transition back to WORKING. + * @param status the task status to check + * @return true if the status is terminal, false otherwise + */ + public static boolean isTerminal(TaskStatus status) { + if (status == null) { + return false; + } + return status == TaskStatus.COMPLETED || status == TaskStatus.FAILED || status == TaskStatus.CANCELLED; + } + + /** + * Checks if a state transition is valid according to the MCP task state machine. + * + *

+ * Valid transitions: + *

    + *
  • WORKING → COMPLETED, FAILED, CANCELLED, INPUT_REQUIRED + *
  • INPUT_REQUIRED → WORKING, COMPLETED, FAILED, CANCELLED + *
  • COMPLETED → (none - terminal) + *
  • FAILED → (none - terminal) + *
  • CANCELLED → (none - terminal) + *
+ * @param from the current status + * @param to the desired new status + * @return true if the transition is valid, false otherwise + */ + public static boolean isValidTransition(TaskStatus from, TaskStatus to) { + if (from == null || to == null) { + return false; + } + + // Terminal states cannot transition to any state (including themselves) + if (isTerminal(from)) { + return false; + } + + // From WORKING, can go to any state (including WORKING again). + // Note: WORKING → WORKING is valid - this represents a status update without + // actual state change, which may occur when updating the statusMessage field + // while the task continues running. + if (from == TaskStatus.WORKING) { + return true; + } + + // From INPUT_REQUIRED, can go back to WORKING or to terminal states. + // Note: INPUT_REQUIRED → INPUT_REQUIRED is NOT valid. If the task needs to + // remain in INPUT_REQUIRED state with an updated message, callers should + // update the statusMessage directly rather than calling a transition method. + // This ensures that INPUT_REQUIRED always represents a clear "waiting for + // input" → "got input, resuming work" lifecycle. + if (from == TaskStatus.INPUT_REQUIRED) { + return to == TaskStatus.WORKING || isTerminal(to); + } + + return false; + } + + /** + * Gets a human-readable description of a task status. + * @param status the task status + * @return a human-readable description + */ + public static String getStatusDescription(TaskStatus status) { + if (status == null) { + return "Unknown"; + } + return switch (status) { + case WORKING -> "Task is in progress"; + case INPUT_REQUIRED -> "Task requires additional input"; + case COMPLETED -> "Task completed successfully"; + case FAILED -> "Task failed"; + case CANCELLED -> "Task was cancelled"; + }; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskMessageQueue.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskMessageQueue.java new file mode 100644 index 000000000..21ab3f95a --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskMessageQueue.java @@ -0,0 +1,74 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import java.util.List; + +import reactor.core.publisher.Mono; + +/** + * Interface for queueing messages associated with tasks. + * + *

+ * The TaskMessageQueue enables bidirectional communication during task execution, + * supporting scenarios like elicitation or sampling during a long-running task. + * + *

+ * This is an experimental API that may change in future releases. + * + */ +public interface TaskMessageQueue { + + /** + * Enqueues a message for a task. + * @param taskId the task identifier + * @param message the message to enqueue + * @param maxSize maximum queue size (older messages are dropped if exceeded) + * @return a Mono that completes when the message is enqueued + */ + Mono enqueue(String taskId, QueuedMessage message, Integer maxSize); + + /** + * Dequeues the next message for a task. + * @param taskId the task identifier + * @return a Mono emitting the next message, or empty if queue is empty + */ + Mono dequeue(String taskId); + + /** + * Dequeues all messages for a task. + * @param taskId the task identifier + * @return a Mono emitting a list of all queued messages + */ + Mono> dequeueAll(String taskId); + + /** + * Clears all messages for a task. Called during task cleanup/expiration. + * @param taskId the task identifier + * @return a Mono that completes when cleanup is done + */ + default Mono clearTask(String taskId) { + return Mono.empty(); + } + + /** + * Gets the current queue size for a task. + * + *

+ * This method is useful for monitoring and debugging queue depth during task + * execution. Note that the returned size is a snapshot and may change immediately + * after the call returns in concurrent scenarios. + * + *

+ * Default implementation returns 0 (no monitoring support). Implementations that + * support monitoring should override this method. + * @param taskId the task identifier + * @return a Mono emitting the current queue size, or 0 if task has no queue + */ + default Mono getQueueSize(String taskId) { + return Mono.just(0); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskStore.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskStore.java new file mode 100644 index 000000000..2bfead314 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskStore.java @@ -0,0 +1,264 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import java.time.Duration; +import java.util.Objects; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.Task; +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Interface for storing and managing MCP tasks. + * + *

+ * The TaskStore provides persistence for long-running tasks, enabling task state + * management, result storage, and task listing. Implementations may store tasks in + * memory, a database, or other backing stores. + * + *

+ * The type parameter {@code R} specifies the result type that this store handles: + *

    + *
  • For server-side stores (handling tool calls), use + * {@link McpSchema.ServerTaskPayloadResult} + *
  • For client-side stores (handling sampling/elicitation), use + * {@link McpSchema.ClientTaskPayloadResult} + *
  • For stores that can handle any result type, use {@link McpSchema.Result} + *
+ * + *

Error Handling Contract

+ *

+ * Methods in this interface follow these error handling conventions: + *

    + *
  • {@link #getTask}, {@link #getTaskResult}: Return empty Mono if task not found or + * session mismatch (null-safe pattern for optional lookups)
  • + *
  • {@link #storeTaskResult}: Throws {@link io.modelcontextprotocol.spec.McpError} if + * task not found or session mismatch (fail-fast to prevent data loss)
  • + *
  • {@link #requestCancellation}: Returns empty Mono if task not found or session + * mismatch (idempotent); throws {@link io.modelcontextprotocol.spec.McpError} with code + * {@code -32602} if task is in terminal status (per MCP specification requirement)
  • + *
  • {@link #updateTaskStatus}: Completes silently if task not found or session mismatch + * (idempotent)
  • + *
+ * + *

Session Isolation Model (Defense-in-Depth)

+ *

+ * All task operations require a {@code sessionId} parameter for defense-in-depth session + * isolation. This ensures that even if higher layers (e.g., request handlers) have bugs, + * the TaskStore itself enforces session boundaries. + * + *

+ * Session validation rules: + *

    + *
  • If {@code sessionId} is {@code null}, access is allowed (single-tenant mode)
  • + *
  • If task has no session (created with null sessionId), access is allowed from any + * session
  • + *
  • If both task and request have session IDs, they must match for access
  • + *
+ * + *

+ * For implementers: + *

    + *
  1. Store the session ID from {@link CreateTaskOptions#sessionId()} when creating + * tasks
  2. + *
  3. Validate session ID on ALL operations using the rules above
  4. + *
  5. Use atomic operations to prevent TOCTOU race conditions between session check and + * data access
  6. + *
+ * + *

+ * Durability Note: The default {@link InMemoryTaskStore} does not + * provide durability guarantees - task results may be lost if the server crashes during + * execution. For production use cases requiring durability, implement a custom TaskStore + * backed by persistent storage (database, Redis, etc.). + * + *

+ * This is an experimental API that may change in future releases. + * + * @param the type of result this store handles + */ +public interface TaskStore { + + /** + * Creates a new task with the given options. + * + *

+ * The session ID for the task is captured from {@link CreateTaskOptions#sessionId()}. + * @param options the task creation options + * @return a Mono emitting the created Task + */ + Mono createTask(CreateTaskOptions options); + + /** + * Retrieves a task by its ID with session validation. + * + *

+ * This method performs atomic session validation - the task is only returned if the + * session ID matches (or if either is null for single-tenant mode). + * + *

+ * The returned {@link GetTaskFromStoreResult} contains both the task and the original + * request that created it, enabling callers to access full context without separate + * lookups. For tool calls, the originating request will be a + * {@link McpSchema.CallToolRequest} containing the tool name. + * @param taskId the task identifier + * @param sessionId the session ID for validation, or null for single-tenant mode + * @return a Mono emitting the GetTaskFromStoreResult, or empty if not found or + * session mismatch + */ + Mono getTask(String taskId, String sessionId); + + /** + * Updates the status of a task with session validation. + * + *

+ * Terminal state behavior: If the task is already in a terminal + * state (COMPLETED, FAILED, or CANCELLED), the update will be silently ignored and + * the Mono will complete successfully without making any changes. This is intentional + * - once a task reaches a terminal state, it cannot transition to any other state. + * + *

+ * Valid state transitions: + *

    + *
  • WORKING → INPUT_REQUIRED, COMPLETED, FAILED, CANCELLED
  • + *
  • INPUT_REQUIRED → WORKING, COMPLETED, FAILED, CANCELLED
  • + *
  • COMPLETED, FAILED, CANCELLED → (no further transitions allowed)
  • + *
+ * @param taskId the task identifier + * @param sessionId the session ID for validation, or null for single-tenant mode + * @param status the new status + * @param statusMessage optional human-readable status message + * @return a Mono that completes when the update is done (or silently ignored for + * terminal tasks or session mismatch) + */ + Mono updateTaskStatus(String taskId, String sessionId, TaskStatus status, String statusMessage); + + /** + * Stores the result of a completed task with session validation. + * + *

+ * Implementations should throw {@link io.modelcontextprotocol.spec.McpError} with + * {@link McpSchema.ErrorCodes#INVALID_PARAMS} if the task is not found or session + * validation fails. This ensures callers are notified of race conditions (e.g., task + * expired and was cleaned up between checking existence and storing result) rather + * than silently losing data. + * @param taskId the task identifier + * @param sessionId the session ID for validation, or null for single-tenant mode + * @param status the terminal status (completed, failed, or cancelled) + * @param result the result to store + * @return a Mono that completes when the result is stored + * @throws io.modelcontextprotocol.spec.McpError if task not found or session mismatch + */ + Mono storeTaskResult(String taskId, String sessionId, TaskStatus status, R result); + + /** + * Retrieves the stored result of a task with session validation. + * @param taskId the task identifier + * @param sessionId the session ID for validation, or null for single-tenant mode + * @return a Mono emitting the Result, or empty if not available or session mismatch + */ + Mono getTaskResult(String taskId, String sessionId); + + /** + * Lists tasks with pagination and session filtering. + * + *

+ * When sessionId is provided, only tasks belonging to that session are returned. When + * sessionId is null, all tasks are returned (single-tenant mode). + * + *

+ * Implementation note: When filtering by sessionId, pages may + * contain fewer than the configured page size entries. This is intentional - it + * ensures consistent cursor behavior while allowing session-scoped views of the task + * list. Implementations should NOT attempt to "fill" pages by fetching additional + * entries. + * @param cursor optional pagination cursor + * @param sessionId the session ID to filter tasks by, or null for all tasks + * @return a Mono emitting the ListTasksResult + */ + Mono listTasks(String cursor, String sessionId); + + /** + * Requests cancellation of a task with session validation. This is cooperative - the + * task handler must periodically check for cancellation. + * + *

+ * Per the MCP specification, cancellation of tasks in terminal status (COMPLETED, + * FAILED, or CANCELLED) MUST be rejected with error code {@code -32602} (Invalid + * params). Implementations must throw {@link io.modelcontextprotocol.spec.McpError} + * with the appropriate error code when this occurs. + * @param taskId the task identifier + * @param sessionId the session ID for validation, or null for single-tenant mode + * @return a Mono emitting the updated Task after cancellation is requested, or empty + * if task not found or session mismatch + * @throws io.modelcontextprotocol.spec.McpError with code {@code -32602} if the task + * is in a terminal state + */ + Mono requestCancellation(String taskId, String sessionId); + + /** + * Checks if cancellation has been requested for a task with session validation. + * @param taskId the task identifier + * @param sessionId the session ID for validation, or null for single-tenant mode + * @return a Mono emitting true if cancellation was requested, false if not canceled + * or task not found/session mismatch + */ + Mono isCancellationRequested(String taskId, String sessionId); + + /** + * Watches a task until it reaches a terminal state (COMPLETED, FAILED, or CANCELLED), + * emitting status updates along the way. + * + *

+ * This method is used for implementing blocking behavior in tasks/result requests. It + * polls the task status at regular intervals and emits each status update until the + * task reaches a terminal state or the timeout is reached. + * + *

+ * Default Implementation Note: The default implementation uses + * {@link Flux#interval} for polling, which creates periodic emissions. This is a + * basic approach suitable for development and testing. For production deployments + * with many concurrent tasks, consider overriding this method to use more efficient + * mechanisms: + *

    + *
  • Event-based watching with notifications
  • + *
  • Callback-based approaches using CompletableFuture
  • + *
  • Long polling or server-sent events
  • + *
+ * + *

+ * The default implementation also uses {@code concatMap} (sequential), meaning only + * one {@link #getTask} call is in-flight at a time per watching stream. + * @param taskId the task identifier + * @param sessionId the session ID for validation, or null for single-tenant mode + * @param timeout maximum duration to wait for the task to reach terminal state + * @return a Flux emitting Task status updates, completing when terminal or timing out + */ + default Flux watchTaskUntilTerminal(String taskId, String sessionId, Duration timeout) { + long pollIntervalMs = TaskDefaults.DEFAULT_POLL_INTERVAL_MS; + return Flux.interval(Duration.ofMillis(pollIntervalMs)) + .concatMap(tick -> getTask(taskId, sessionId).map(GetTaskFromStoreResult::task)) + .filter(Objects::nonNull) + .takeUntil(Task::isTerminal) + .timeout(timeout); + } + + /** + * Shuts down the task store, releasing any resources such as background threads or + * connections. + * + *

+ * Default implementation is a no-op. Implementations with cleanup requirements should + * override this method. + * @return a Mono that completes when shutdown is complete + */ + default Mono shutdown() { + return Mono.empty(); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/package-info.java b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/package-info.java new file mode 100644 index 000000000..3aeb30d31 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/package-info.java @@ -0,0 +1,121 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +/** + * Experimental support for MCP Tasks (SEP-1686). + * + *

Core Types

+ *
    + *
  • {@link io.modelcontextprotocol.experimental.tasks.TaskStore} - Interface for task + * state persistence
  • + *
  • {@link io.modelcontextprotocol.experimental.tasks.InMemoryTaskStore} - In-memory + * TaskStore implementation
  • + *
  • {@link io.modelcontextprotocol.experimental.tasks.TaskContext} - Runtime context + * for task lifecycle management
  • + *
  • {@link io.modelcontextprotocol.experimental.tasks.DefaultTaskContext} - Default + * TaskContext implementation
  • + *
  • {@link io.modelcontextprotocol.experimental.tasks.TaskMessageQueue} - Queue for + * INPUT_REQUIRED state communication
  • + *
+ * + *

Handler Interfaces

+ * + *

+ * Task-aware tools use three handler interfaces for different lifecycle phases: + * + *

CreateTaskHandler / SyncCreateTaskHandler (Required)

+ *

+ * Invoked when a task-augmented request is received. Responsible for: + *

    + *
  • Creating the task in the TaskStore
  • + *
  • Starting any background work
  • + *
  • Returning a {@link io.modelcontextprotocol.spec.McpSchema.CreateTaskResult + * CreateTaskResult} with the task details
  • + *
+ * + *

GetTaskHandler / SyncGetTaskHandler (Optional)

+ *

+ * Custom handler for {@code tasks/get} requests. Use when: + *

    + *
  • Mapping external job IDs to MCP task status
  • + *
  • Implementing custom status derivation
  • + *
+ *

+ * If not provided, the default implementation uses + * {@link io.modelcontextprotocol.experimental.tasks.TaskStore#getTask(String, String) + * TaskStore.getTask()}. + * + *

GetTaskResultHandler / SyncGetTaskResultHandler (Optional)

+ *

+ * Custom handler for {@code tasks/result} requests. Use when: + *

    + *
  • Fetching results from external systems
  • + *
  • Transforming stored results before returning
  • + *
+ *

+ * If not provided, the default implementation uses + * {@link io.modelcontextprotocol.experimental.tasks.TaskStore#getTaskResult(String, String) + * TaskStore.getTaskResult()}. + * + *

Automatic Polling Behavior

+ * + *

+ * When a task-aware tool with + * {@link io.modelcontextprotocol.spec.McpSchema.TaskSupportMode#OPTIONAL OPTIONAL} mode + * is called without task metadata, the server automatically handles the + * task lifecycle: + * + *

    + *
  1. Creates an internal task via the tool's createTask handler
  2. + *
  3. Polls the task status using the configured poll interval
  4. + *
  5. When the task reaches a terminal state (COMPLETED, FAILED, CANCELLED), retrieves + * the result
  6. + *
  7. Returns the result to the caller as if it were a synchronous operation
  8. + *
+ * + *

+ * This behavior provides backward compatibility for clients that don't support the tasks + * protocol extension. + * + *

+ * Note: Automatic polling does NOT work for tasks that enter + * {@link io.modelcontextprotocol.spec.McpSchema.TaskStatus#INPUT_REQUIRED INPUT_REQUIRED} + * state, as this requires explicit client interaction. Such tasks will timeout or fail. + * + *

+ * For tools where automatic polling is not appropriate (e.g., very long operations, tasks + * requiring user input), use + * {@link io.modelcontextprotocol.spec.McpSchema.TaskSupportMode#REQUIRED REQUIRED} mode + * instead. + * + *

Tool Specifications

+ *
    + *
  • {@link io.modelcontextprotocol.experimental.tasks.TaskAwareAsyncToolSpecification} + * - Async task-aware tool definition
  • + *
  • {@link io.modelcontextprotocol.experimental.tasks.TaskAwareSyncToolSpecification} - + * Sync task-aware tool definition
  • + *
+ * + *

Context Types

+ *
    + *
  • {@link io.modelcontextprotocol.experimental.tasks.CreateTaskExtra} / + * {@link io.modelcontextprotocol.experimental.tasks.SyncCreateTaskExtra} - Handler + * context
  • + *
  • {@link io.modelcontextprotocol.experimental.tasks.CreateTaskOptions} - Task + * creation configuration
  • + *
+ * + *

Utilities

+ *
    + *
  • {@link io.modelcontextprotocol.experimental.tasks.TaskDefaults} - Default + * constants
  • + *
  • {@link io.modelcontextprotocol.experimental.tasks.TaskHelper} - State transition + * utilities
  • + *
+ * + *

+ * WARNING: This is an experimental API that may change in future + * releases. + */ +package io.modelcontextprotocol.experimental.tasks; diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 23285d514..297252470 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -14,6 +14,12 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.BiFunction; +import io.modelcontextprotocol.experimental.tasks.CreateTaskExtra; +import io.modelcontextprotocol.experimental.tasks.DefaultCreateTaskExtra; +import io.modelcontextprotocol.experimental.tasks.TaskAwareAsyncToolSpecification; +import io.modelcontextprotocol.experimental.tasks.TaskDefaults; +import io.modelcontextprotocol.experimental.tasks.TaskMessageQueue; +import io.modelcontextprotocol.experimental.tasks.TaskStore; import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.json.schema.JsonSchemaValidator; @@ -105,18 +111,37 @@ public class McpAsyncServer { private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + // Index for fast lookup of tools by name + private final ConcurrentHashMap toolsByName = new ConcurrentHashMap<>(); + + // Task-aware tools that support long-running operations + private final CopyOnWriteArrayList taskTools = new CopyOnWriteArrayList<>(); + + // Index for fast lookup of task tools by name (for task handler dispatch) + private final ConcurrentHashMap taskToolsByName = new ConcurrentHashMap<>(); + + // Lock for atomic tool registration to prevent race conditions between normal tools + // and task tools + private final Object toolRegistrationLock = new Object(); + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); private final ConcurrentHashMap resourceTemplates = new ConcurrentHashMap<>(); private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); - // FIXME: this field is deprecated and should be remvoed together with the + // FIXME: this field is deprecated and should be removed together with the // broadcasting loggingNotification. private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + private final TaskStore taskStore; + + private final TaskMessageQueue taskMessageQueue; + + private final Duration automaticPollingTimeout; + private List protocolVersions; private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DefaultMcpUriTemplateManagerFactory(); @@ -130,19 +155,34 @@ public class McpAsyncServer { */ McpAsyncServer(McpServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, McpServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator, + Duration automaticPollingTimeout) { this.mcpTransportProvider = mcpTransportProvider; this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); this.serverCapabilities = features.serverCapabilities().mutate().logging().build(); this.instructions = features.instructions(); - this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); + List wrappedTools = withStructuredOutputHandling(jsonSchemaValidator, + features.tools()); + this.tools.addAll(wrappedTools); + // Populate toolsByName index for fast lookup + for (McpServerFeatures.AsyncToolSpecification tool : wrappedTools) { + this.toolsByName.put(tool.tool().name(), tool); + } + // Populate task tools from features + this.taskTools.addAll(features.taskTools()); + for (TaskAwareAsyncToolSpecification taskTool : features.taskTools()) { + this.taskToolsByName.put(taskTool.tool().name(), taskTool); + } this.resources.putAll(features.resources()); this.resourceTemplates.putAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; this.jsonSchemaValidator = jsonSchemaValidator; + this.taskStore = features.taskStore(); + this.taskMessageQueue = features.taskMessageQueue(); + this.automaticPollingTimeout = automaticPollingTimeout; Map> requestHandlers = prepareRequestHandlers(); Map notificationHandlers = prepareNotificationHandlers(features); @@ -155,19 +195,34 @@ public class McpAsyncServer { McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, McpServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator, + Duration automaticPollingTimeout) { this.mcpTransportProvider = mcpTransportProvider; this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); this.serverCapabilities = features.serverCapabilities().mutate().logging().build(); this.instructions = features.instructions(); - this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); + List wrappedTools = withStructuredOutputHandling(jsonSchemaValidator, + features.tools()); + this.tools.addAll(wrappedTools); + // Populate toolsByName index for fast lookup + for (McpServerFeatures.AsyncToolSpecification tool : wrappedTools) { + this.toolsByName.put(tool.tool().name(), tool); + } + // Populate task tools from features + this.taskTools.addAll(features.taskTools()); + for (TaskAwareAsyncToolSpecification taskTool : features.taskTools()) { + this.taskToolsByName.put(taskTool.tool().name(), taskTool); + } this.resources.putAll(features.resources()); this.resourceTemplates.putAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; this.jsonSchemaValidator = jsonSchemaValidator; + this.taskStore = features.taskStore(); + this.taskMessageQueue = features.taskMessageQueue(); + this.automaticPollingTimeout = automaticPollingTimeout; Map> requestHandlers = prepareRequestHandlers(); Map notificationHandlers = prepareNotificationHandlers(features); @@ -232,6 +287,29 @@ private Map> prepareRequestHandlers() { if (this.serverCapabilities.completions() != null) { requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); } + + // Add tasks API handlers if the tasks capability is enabled + // Warn about capability/implementation mismatches + if (this.serverCapabilities.tasks() != null && this.taskStore == null) { + logger.warn("Server has tasks capability enabled but no TaskStore configured. " + + "Task operations will not be available. Either provide a TaskStore or " + + "remove the tasks capability."); + } + if (this.taskStore != null && this.serverCapabilities.tasks() == null) { + logger.warn("Server has TaskStore configured but tasks capability is not enabled. " + + "Task operations will not be available. Either enable the tasks capability " + + "or remove the TaskStore configuration."); + } + if (this.serverCapabilities.tasks() != null && this.taskStore != null) { + requestHandlers.put(McpSchema.METHOD_TASKS_GET, tasksGetRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TASKS_RESULT, tasksResultRequestHandler()); + if (this.serverCapabilities.tasks().list() != null) { + requestHandlers.put(McpSchema.METHOD_TASKS_LIST, tasksListRequestHandler()); + } + if (this.serverCapabilities.tasks().cancel() != null) { + requestHandlers.put(McpSchema.METHOD_TASKS_CANCEL, tasksCancelRequestHandler()); + } + } return requestHandlers; } @@ -288,13 +366,17 @@ public McpSchema.Implementation getServerInfo() { * @return A Mono that completes when the server has been closed */ public Mono closeGracefully() { - return this.mcpTransportProvider.closeGracefully(); + Mono taskStoreShutdown = this.taskStore != null ? this.taskStore.shutdown() : Mono.empty(); + return taskStoreShutdown.then(this.mcpTransportProvider.closeGracefully()); } /** * Close the server immediately. */ public void close() { + if (this.taskStore != null) { + this.taskStore.shutdown().block(Duration.ofSeconds(5)); + } this.mcpTransportProvider.close(); } @@ -336,13 +418,23 @@ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecifica var wrappedToolSpecification = withStructuredOutputHandling(this.jsonSchemaValidator, toolSpecification); return Mono.defer(() -> { - // Remove tools with duplicate tool names first - if (this.tools.removeIf(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { - logger.warn("Replace existing Tool with name '{}'", wrappedToolSpecification.tool().name()); - } + String toolName = wrappedToolSpecification.tool().name(); + synchronized (this.toolRegistrationLock) { + // Check for name collision with task tools + if (this.taskToolsByName.containsKey(toolName)) { + return Mono + .error(new IllegalArgumentException("A task tool with name '" + toolName + "' already exists")); + } + + // Remove tools with duplicate tool names first + if (this.tools.removeIf(th -> th.tool().name().equals(toolName))) { + logger.warn("Replace existing Tool with name '{}'", toolName); + } - this.tools.add(wrappedToolSpecification); - logger.debug("Added tool handler: {}", wrappedToolSpecification.tool().name()); + this.tools.add(wrappedToolSpecification); + this.toolsByName.put(toolName, wrappedToolSpecification); + } + logger.debug("Added tool handler: {}", toolName); if (this.serverCapabilities.tools().listChanged()) { return notifyToolsListChanged(); @@ -489,7 +581,7 @@ public Mono removeTool(String toolName) { return Mono.defer(() -> { if (this.tools.removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName))) { - + this.toolsByName.remove(toolName); logger.debug("Removed tool handler: {}", toolName); if (this.serverCapabilities.tools().listChanged()) { return notifyToolsListChanged(); @@ -503,6 +595,93 @@ public Mono removeTool(String toolName) { }); } + /** + * Add a new task-aware tool at runtime. + * + *

+ * Task-aware tools support long-running operations with task lifecycle management + * (SEP-1686). They differ from normal tools in that they can return tasks instead of + * direct results. + * @param taskToolSpecification The task-aware tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTaskTool(TaskAwareAsyncToolSpecification taskToolSpecification) { + if (taskToolSpecification == null) { + return Mono.error(new IllegalArgumentException("Task tool specification must not be null")); + } + if (taskToolSpecification.tool() == null) { + return Mono.error(new IllegalArgumentException("Tool must not be null")); + } + if (taskToolSpecification.createTaskHandler() == null) { + return Mono.error(new IllegalArgumentException("createTask handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + String toolName = taskToolSpecification.tool().name(); + synchronized (this.toolRegistrationLock) { + // Check for name collision with normal tools + if (this.toolsByName.containsKey(toolName)) { + return Mono.error( + new IllegalArgumentException("A normal tool with name '" + toolName + "' already exists")); + } + + // Remove existing task tool with same name if present + if (this.taskTools.removeIf(th -> th.tool().name().equals(toolName))) { + logger.warn("Replace existing TaskTool with name '{}'", toolName); + } + + this.taskTools.add(taskToolSpecification); + this.taskToolsByName.put(toolName, taskToolSpecification); + } + logger.debug("Added task tool handler: {}", toolName); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a task-aware tool at runtime. + * @param toolName The name of the task-aware tool to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeTaskTool(String toolName) { + if (toolName == null) { + return Mono.error(new IllegalArgumentException("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new IllegalStateException("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + if (this.taskTools.removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName))) { + this.taskToolsByName.remove(toolName); + logger.debug("Removed task tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + } + else { + logger.warn("Ignore as a TaskTool with name '{}' not found", toolName); + } + + return Mono.empty(); + }); + } + + /** + * List all registered task-aware tools. + * @return A Flux stream of all registered task-aware tools + */ + public Flux listTaskTools() { + return Flux.fromIterable(this.taskTools).map(TaskAwareAsyncToolSpecification::tool); + } + /** * Notifies clients that the list of available tools has changed. * @return A Mono that completes when all clients have been notified @@ -513,33 +692,229 @@ public Mono notifyToolsListChanged() { private McpRequestHandler toolsListRequestHandler() { return (exchange, params) -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + // Combine normal tools and task tools into a single list + List allTools = new java.util.ArrayList<>(); + allTools.addAll(this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList()); + allTools.addAll(this.taskTools.stream().map(TaskAwareAsyncToolSpecification::tool).toList()); - return Mono.just(new McpSchema.ListToolsResult(tools, null)); + return Mono.just(new McpSchema.ListToolsResult(allTools, null)); }; } - private McpRequestHandler toolsCallRequestHandler() { + private McpRequestHandler toolsCallRequestHandler() { return (exchange, params) -> { McpSchema.CallToolRequest callToolRequest = jsonMapper.convertValue(params, new TypeRef() { }); - Optional toolSpecification = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); + String toolName = callToolRequest.name(); - if (toolSpecification.isEmpty()) { - return Mono.error(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) - .message("Unknown tool: invalid_tool_name") - .data("Tool not found: " + callToolRequest.name()) - .build()); + // First, check if it's a normal tool + McpServerFeatures.AsyncToolSpecification normalTool = this.toolsByName.get(toolName); + if (normalTool != null) { + // Normal tools do not support task-augmented requests + if (callToolRequest.task() != null) { + return Mono.error(McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND) + .message("Tool '" + toolName + "' does not support task-augmented requests") + .data("Remove the 'task' parameter or use a task-aware tool") + .build()); + } + return normalTool.callHandler().apply(exchange, callToolRequest).cast(Object.class); + } + + // Second, check if it's a task-aware tool + TaskAwareAsyncToolSpecification taskTool = this.taskToolsByName.get(toolName); + if (taskTool != null) { + return handleTaskToolCall(exchange, callToolRequest, taskTool).cast(Object.class); } - return toolSpecification.get().callHandler().apply(exchange, callToolRequest); + // Tool not found + return Mono.error(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Unknown tool: " + callToolRequest.name()) + .data("Tool not found: " + callToolRequest.name()) + .build()); }; } + /** + * Handles a call to a task-aware tool. Task-aware tools always support tasks and use + * the createTaskHandler for task creation. + */ + private Mono handleTaskToolCall(McpAsyncServerExchange exchange, McpSchema.CallToolRequest request, + TaskAwareAsyncToolSpecification taskTool) { + + McpSchema.ToolExecution execution = taskTool.tool().execution(); + McpSchema.TaskSupportMode taskSupportMode = execution != null ? execution.taskSupport() : null; + + // Handle task-augmented calls + if (request.task() != null) { + // Check if server has task capability + if (this.taskStore == null) { + return Mono.error(McpError.builder(McpSchema.ErrorCodes.INVALID_REQUEST) + .message("Server does not support tasks") + .data("Task store not configured") + .build()); + } + return handleTaskToolCreateTask(exchange, request, taskTool); + } + + // Check if tool REQUIRES task augmentation + if (taskSupportMode == McpSchema.TaskSupportMode.REQUIRED) { + return Mono.error(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("This tool requires task-augmented execution") + .data("Tool '" + request.name() + "' requires task metadata in the request") + .build()); + } + + // No task metadata - use automatic polling shim if taskStore is configured + if (this.taskStore != null) { + return handleAutomaticTaskPolling(exchange, request, taskTool); + } + + // Fall back to direct call if no task store + return taskTool.callHandler().apply(exchange, request); + } + + /** + * Handles task creation for a task-aware tool. The tool's createTaskHandler has full + * control over task creation including TTL configuration. + */ + private Mono handleTaskToolCreateTask(McpAsyncServerExchange exchange, + McpSchema.CallToolRequest request, TaskAwareAsyncToolSpecification taskTool) { + + // Extract request TTL from task metadata. + // If null (no task metadata or TTL not specified), the TaskStore's default TTL + // will be used when creating the task. This allows clients to optionally request + // shorter TTLs but does not require them to specify one. + Long requestTtl = request.task() != null ? request.task().ttl() : null; + + // Create the extra context for the handler + CreateTaskExtra extra = new DefaultCreateTaskExtra(this.taskStore, this.taskMessageQueue, exchange, + exchange.sessionId(), requestTtl, request); + + // Delegate to the tool's createTaskHandler + Map args = request.arguments() != null ? request.arguments() : Map.of(); + + return taskTool.createTaskHandler() + .createTask(args, extra) + // Wrap non-McpError exceptions in McpError for consistent error handling + .onErrorMap(e -> !(e instanceof McpError), + e -> new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + "Task creation failed: " + e.getMessage(), null))); + } + + /** + * Handles automatic task polling for task-aware tools called without task metadata. + * + *

When This Occurs

+ *

+ * This behavior only occurs when ALL of the following are true: + *

    + *
  • The tool has {@link McpSchema.TaskSupportMode#OPTIONAL}
  • + *
  • The request does NOT include task metadata
  • + *
  • A {@link TaskStore} is configured on the server
  • + *
+ * + *

The Flow

+ *

+ * When a task-aware tool is called without task metadata, we: + *

    + *
  1. Call createTaskHandler internally to create an internal task
  2. + *
  3. Poll the task at the configured poll interval until it reaches a terminal + * state
  4. + *
  5. Retrieve the final result and return it directly as a CallToolResult
  6. + *
+ * + *

+ * This makes the call appear synchronous to the caller - they receive the final + * result without needing to manage task polling themselves. + * + *

Contrast with REQUIRED Mode

+ *

+ * Tools with {@link McpSchema.TaskSupportMode#REQUIRED} will return error -32601 + * (METHOD_NOT_FOUND) if called without task metadata. This automatic polling behavior + * is only available for OPTIONAL mode tools. + * + *

Limitations

+ *
    + *
  • Tasks requiring interactive input (INPUT_REQUIRED) will fail with an error + * since automatic polling cannot support bidirectional communication
  • + *
  • Long-running tasks may timeout based on the server's automatic polling timeout + * configuration
  • + *
+ * @param exchange the server exchange context + * @param request the original tool call request (without task metadata) + * @param taskTool the task-aware tool specification + * @return a Mono that completes with the final CallToolResult + */ + private Mono handleAutomaticTaskPolling(McpAsyncServerExchange exchange, + McpSchema.CallToolRequest request, TaskAwareAsyncToolSpecification taskTool) { + + // Create the extra context for the handler (no request TTL since no task + // metadata) + CreateTaskExtra extra = new DefaultCreateTaskExtra(this.taskStore, this.taskMessageQueue, exchange, + exchange.sessionId(), null, request); + + Map args = request.arguments() != null ? request.arguments() : Map.of(); + + // 1. Call createTask handler internally + return taskTool.createTaskHandler().createTask(args, extra).flatMap(createResult -> { + McpSchema.Task task = createResult.task(); + if (task == null) { + return Mono.error(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("createTaskHandler did not return a task") + .build()); + } + + String taskId = task.taskId(); + long pollInterval = task.pollInterval() != null ? task.pollInterval() + : TaskDefaults.DEFAULT_POLL_INTERVAL_MS; + + // 2. Poll until terminal state or INPUT_REQUIRED + // Note: INPUT_REQUIRED is not terminal but needs special handling for + // automatic polling + String sessionId = exchange.sessionId(); + return Flux.interval(Duration.ofMillis(pollInterval)).flatMap(tick -> { + // Use getTaskHandler or default + if (taskTool.getTaskHandler() != null) { + return taskTool.getTaskHandler() + .handle(exchange, McpSchema.GetTaskRequest.builder().taskId(taskId).build()) + .map(McpSchema.GetTaskResult::toTask); + } + return this.taskStore.getTask(taskId, sessionId) + .map(io.modelcontextprotocol.experimental.tasks.GetTaskFromStoreResult::task); + }) + .filter(t -> t != null) + .takeUntil(t -> t.isTerminal() || t.status() == McpSchema.TaskStatus.INPUT_REQUIRED) + .last() + .timeout(getEffectiveAutomaticPollingTimeout()) + .onErrorResume(java.util.concurrent.TimeoutException.class, + e -> Mono.error(new McpError( + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + "Task timed out waiting for completion: " + taskId, null)))) + .flatMap(finalTask -> { + // Handle INPUT_REQUIRED - automatic polling cannot support + // interactive input + if (finalTask.status() == McpSchema.TaskStatus.INPUT_REQUIRED) { + return Mono.error(new McpError(new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.INTERNAL_ERROR, + "Task requires interactive input which is not supported in automatic polling mode. " + + "Use task-augmented requests (with TaskMetadata) to enable interactive input. " + + "Task ID: " + taskId, + null))); + } + // 3. Get final result + if (taskTool.getTaskResultHandler() != null) { + return taskTool.getTaskResultHandler() + .handle(exchange, McpSchema.GetTaskPayloadRequest.builder().taskId(taskId).build()) + .map(result -> (McpSchema.CallToolResult) result); + } + return this.taskStore.getTaskResult(taskId, sessionId) + .map(result -> (McpSchema.CallToolResult) result); + }); + }); + } + // --------------------------------------- // Resource Management // --------------------------------------- @@ -1073,4 +1448,216 @@ void setProtocolVersions(List protocolVersions) { this.protocolVersions = protocolVersions; } + // --------------------------------------- + // Task Management (Experimental) + // --------------------------------------- + + /** + * Get the task store used for managing long-running tasks. + *

+ * Warning: This is an experimental API that may change in future + * releases. Use with caution in production environments. + * @return The task store, or null if tasks are not enabled + */ + public TaskStore getTaskStore() { + return this.taskStore; + } + + /** + * Get the task message queue used for task communication during input_required state. + * + *

+ * Warning: This is an experimental API that may change in future + * releases. Use with caution in production environments. + * @return The task message queue, or null if not configured + */ + public TaskMessageQueue getTaskMessageQueue() { + return this.taskMessageQueue; + } + + private McpRequestHandler tasksGetRequestHandler() { + return (exchange, params) -> { + McpSchema.GetTaskRequest request = jsonMapper.convertValue(params, new TypeRef() { + }); + + String sessionId = exchange.sessionId(); + + // Validate session ownership before any processing + return getTaskWithSessionValidation(request.taskId(), sessionId).flatMap(storeResult -> { + McpSchema.Task task = storeResult.task(); + // Extract tool name from originating request (if it was a tool call) + String toolName = null; + if (storeResult.originatingRequest() instanceof McpSchema.CallToolRequest ctr) { + toolName = ctr.name(); + } + // Check for custom handler + TaskAwareAsyncToolSpecification taskTool = toolName != null ? this.taskToolsByName.get(toolName) : null; + var handler = taskTool != null ? taskTool.getTaskHandler() : null; + if (handler != null) { + // Use custom handler - full override pattern (fetches everything + // independently) + return handler.handle(exchange, request); + } + // Fallback to default: already validated, just convert to result + return Mono.just(McpSchema.GetTaskResult.fromTask(task)); + }); + }; + } + + /** + * Validates that the requesting session has permission to access the specified task. + * + *

+ * This implements session isolation for multi-client server scenarios. A task can be + * accessed if: + *

    + *
  • The task has no associated session (single-client mode)
  • + *
  • The requesting session ID matches the task's session ID
  • + *
+ * + *

+ * If access is denied, the error message is intentionally vague ("Task not found") to + * avoid revealing the existence of tasks belonging to other sessions. + * @param taskId the ID of the task to access + * @param exchangeSessionId the session ID of the requesting client + * @return a Mono emitting the GetTaskFromStoreResult if access is allowed + * @throws McpError with INVALID_PARAMS if task not found or access denied + */ + private Mono getTaskWithSessionValidation( + String taskId, String exchangeSessionId) { + // TaskStore.getTask performs session validation and returns empty if access + // denied + return this.taskStore.getTask(taskId, exchangeSessionId) + .switchIfEmpty(Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Task not found (may have expired after TTL)") + .data("Task ID: " + taskId) + .build())); + } + + private McpRequestHandler tasksResultRequestHandler() { + return (exchange, params) -> { + McpSchema.GetTaskPayloadRequest request = jsonMapper.convertValue(params, + new TypeRef() { + }); + + String sessionId = exchange.sessionId(); + + // Validate session ownership before any processing + return getTaskWithSessionValidation(request.taskId(), sessionId).flatMap(storeResult -> { + McpSchema.Task task = storeResult.task(); + // Extract tool name from originating request (if it was a tool call) + String toolName = null; + if (storeResult.originatingRequest() instanceof McpSchema.CallToolRequest ctr) { + toolName = ctr.name(); + } + // Check for custom handler + TaskAwareAsyncToolSpecification taskTool = toolName != null ? this.taskToolsByName.get(toolName) : null; + var handler = taskTool != null ? taskTool.getTaskResultHandler() : null; + if (handler != null) { + // Use custom handler - full override pattern (fetches everything + // independently) + return handler.handle(exchange, request); + } + // Fallback to default task store lookup (session already validated) + return defaultGetTaskResult(exchange, request, task); + }); + }; + } + + /** + * Default implementation for tasks/result that uses the task store. Uses a + * pre-validated task to avoid redundant lookups and ensure session validation has + * already occurred. + */ + private Mono defaultGetTaskResult(McpAsyncServerExchange exchange, + McpSchema.GetTaskPayloadRequest request, McpSchema.Task task) { + String sessionId = exchange.sessionId(); + + // If already terminal, return result immediately + if (task.isTerminal()) { + return fetchTaskResult(request.taskId(), sessionId); + } + + // Block until task reaches terminal state (per SEP-1686 spec) + return this.taskStore.watchTaskUntilTerminal(request.taskId(), sessionId, getEffectiveAutomaticPollingTimeout()) + .last() + .onErrorResume(java.util.concurrent.TimeoutException.class, + e -> Mono.error(McpError.builder(ErrorCodes.INTERNAL_ERROR) + .message("Task did not complete within timeout") + .data("Task ID: " + request.taskId()) + .build())) + .flatMap(terminalTask -> fetchTaskResult(request.taskId(), sessionId)); + } + + /** + * Fetches the result for a task that is in terminal state. + */ + // Safe: TaskStore where ServerTaskPayloadResult is sealed to + // CallToolResult, + // which implements Result. The cast from ServerTaskPayloadResult to Result is always + // valid. + @SuppressWarnings("unchecked") + private Mono fetchTaskResult(String taskId, String sessionId) { + return this.taskStore.getTaskResult(taskId, sessionId) + .map(result -> (McpSchema.Result) result) + .switchIfEmpty(Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Task result not available") + .data("Task ID: " + taskId) + .build())); + } + + /** + * Returns the effective automatic polling timeout, using the configured value or the + * default if not configured. + */ + private Duration getEffectiveAutomaticPollingTimeout() { + return this.automaticPollingTimeout != null ? this.automaticPollingTimeout + : Duration.ofMillis(TaskDefaults.DEFAULT_AUTOMATIC_POLLING_TIMEOUT_MS); + } + + private McpRequestHandler tasksListRequestHandler() { + return (exchange, params) -> { + McpSchema.PaginatedRequest request = jsonMapper.convertValue(params, + new TypeRef() { + }); + + // Use session-filtered listing for proper isolation + return this.taskStore.listTasks(request != null ? request.cursor() : null, exchange.sessionId()); + }; + } + + private McpRequestHandler tasksCancelRequestHandler() { + return (exchange, params) -> { + McpSchema.CancelTaskRequest request = jsonMapper.convertValue(params, + new TypeRef() { + }); + + String sessionId = exchange.sessionId(); + + // Validate session ownership before allowing cancellation + return getTaskWithSessionValidation(request.taskId(), sessionId) + .flatMap(task -> this.taskStore.requestCancellation(request.taskId(), sessionId)) + .switchIfEmpty(Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS) + .message("Task not found (may have expired after TTL)") + .data("Task ID: " + request.taskId()) + .build())) + .map(McpSchema.CancelTaskResult::fromTask); + }; + } + + /** + * Sends a task status notification to all connected clients. + * @param taskStatusNotification The task status notification to send + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyTaskStatus(McpSchema.TaskStatusNotification taskStatusNotification) { + if (taskStatusNotification == null) { + return Mono.error(McpError.builder(ErrorCodes.INVALID_REQUEST) + .message("Task status notification must not be null") + .build()); + } + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TASKS_STATUS, + taskStatusNotification); + } + } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index a15c58cd5..5a9a2ed29 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -5,8 +5,15 @@ package io.modelcontextprotocol.server; import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.experimental.tasks.QueuedMessage; +import io.modelcontextprotocol.experimental.tasks.TaskDefaults; +import io.modelcontextprotocol.experimental.tasks.TaskMessageQueue; +import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Collections; +import java.util.Map; +import java.util.UUID; import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpError; @@ -16,6 +23,9 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSession; import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; /** @@ -27,6 +37,8 @@ */ public class McpAsyncServerExchange { + private static final Logger logger = LoggerFactory.getLogger(McpAsyncServerExchange.class); + private final String sessionId; private final McpLoggableSession session; @@ -37,6 +49,18 @@ public class McpAsyncServerExchange { private final McpTransportContext transportContext; + /** + * The current task ID if this exchange is executing within a task context. This is + * set when handling task-augmented tool calls and allows the tool handler to access + * the task ID for operations like elicitation-during-task. + */ + private final String currentTaskId; + + /** + * Optional message queue for enqueuing messages during task execution. + */ + private final TaskMessageQueue taskMessageQueue; + private static final TypeRef CREATE_MESSAGE_RESULT_TYPE_REF = new TypeRef<>() { }; @@ -49,6 +73,50 @@ public class McpAsyncServerExchange { public static final TypeRef OBJECT_TYPE_REF = new TypeRef<>() { }; + private static final TypeRef GET_TASK_RESULT_TYPE_REF = new TypeRef<>() { + }; + + private static final TypeRef CREATE_TASK_RESULT_TYPE_REF = new TypeRef<>() { + }; + + private static final TypeRef LIST_TASKS_RESULT_TYPE_REF = new TypeRef<>() { + }; + + private static final TypeRef CANCEL_TASK_RESULT_TYPE_REF = new TypeRef<>() { + }; + + /** + * The default poll interval in milliseconds for task status polling when the client + * does not specify one. + */ + private static final long DEFAULT_TASK_POLL_INTERVAL_MS = TaskDefaults.DEFAULT_POLL_INTERVAL_MS; + + /** + * Default number of maximum poll attempts before timing out. Used with poll interval + * to calculate dynamic timeouts. + */ + private static final int DEFAULT_MAX_POLL_ATTEMPTS = 60; + + /** + * Maximum timeout in milliseconds (1 hour). This prevents unbounded timeouts when + * tasks specify very large poll intervals. + */ + private static final long MAX_TIMEOUT_MS = 3600000L; + + /** + * Calculates timeout based on poll interval. This provides reasonable timeouts that + * scale with the polling frequency: 500ms poll interval = 30s timeout, 5000ms = 5 min + * timeout. The result is capped at {@link #MAX_TIMEOUT_MS} to prevent unbounded + * timeouts. + * @param pollInterval the poll interval in milliseconds + * @return the calculated timeout duration, capped at 1 hour + */ + private static Duration calculateTimeout(Long pollInterval) { + long interval = pollInterval != null ? pollInterval : DEFAULT_TASK_POLL_INTERVAL_MS; + long calculatedMs = interval * DEFAULT_MAX_POLL_ATTEMPTS; + return Duration.ofMillis(Math.min(calculatedMs, MAX_TIMEOUT_MS)); + } + /** * Create a new asynchronous exchange with the client. * @param session The server session representing a 1-1 interaction. @@ -69,6 +137,8 @@ public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities c this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; this.transportContext = McpTransportContext.EMPTY; + this.currentTaskId = null; + this.taskMessageQueue = null; } /** @@ -83,11 +153,76 @@ public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities c public McpAsyncServerExchange(String sessionId, McpLoggableSession session, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, McpTransportContext transportContext) { + this(sessionId, session, clientCapabilities, clientInfo, transportContext, null, null); + } + + /** + * Create a new asynchronous exchange with the client and a task message queue. + * @param sessionId The session ID. + * @param session The server session representing a 1-1 interaction. + * @param clientCapabilities The client capabilities that define the supported + * features and functionality. + * @param clientInfo The client implementation information. + * @param transportContext context associated with the client as extracted from the + * transport + * @param taskMessageQueue Optional message queue for task message enqueuing + */ + public McpAsyncServerExchange(String sessionId, McpLoggableSession session, + McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, + McpTransportContext transportContext, TaskMessageQueue taskMessageQueue) { + this(sessionId, session, clientCapabilities, clientInfo, transportContext, null, taskMessageQueue); + } + + /** + * Create a new asynchronous exchange with the client, optionally within a task + * context. + * @param sessionId The session ID. + * @param session The server session representing a 1-1 interaction. + * @param clientCapabilities The client capabilities that define the supported + * features and functionality. + * @param clientInfo The client implementation information. + * @param transportContext context associated with the client as extracted from the + * transport + * @param currentTaskId The current task ID if executing within a task context, null + * otherwise + * @param taskMessageQueue Optional message queue for task message enqueuing + */ + private McpAsyncServerExchange(String sessionId, McpLoggableSession session, + McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, + McpTransportContext transportContext, String currentTaskId, TaskMessageQueue taskMessageQueue) { this.sessionId = sessionId; this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; this.transportContext = transportContext; + this.currentTaskId = currentTaskId; + this.taskMessageQueue = taskMessageQueue; + } + + /** + * Create a new exchange that is scoped to a specific task. This is used when + * executing tool handlers in a task context, allowing them to access the task ID for + * operations like elicitation-during-task. + * @param taskId The task ID to scope this exchange to + * @return A new exchange instance with the task context set + */ + public McpAsyncServerExchange withTaskContext(String taskId) { + Assert.notNull(taskId, "Task ID must not be null"); + return new McpAsyncServerExchange(this.sessionId, this.session, this.clientCapabilities, this.clientInfo, + this.transportContext, taskId, this.taskMessageQueue); + } + + /** + * Create a new exchange that is scoped to a specific task with an explicit message + * queue. + * @param taskId The task ID to scope this exchange to + * @param queue The task message queue for enqueuing messages + * @return A new exchange instance with the task context and queue set + */ + public McpAsyncServerExchange withTaskContext(String taskId, TaskMessageQueue queue) { + Assert.notNull(taskId, "Task ID must not be null"); + return new McpAsyncServerExchange(this.sessionId, this.session, this.clientCapabilities, this.clientInfo, + this.transportContext, taskId, queue); } /** @@ -124,10 +259,19 @@ public String sessionId() { return this.sessionId; } + /** + * Get the current task ID if this exchange is executing within a task context. This + * is set when handling task-augmented tool calls. + * @return the current task ID, or null if not executing within a task context + */ + public String getCurrentTaskId() { + return this.currentTaskId; + } + /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM - * sampling (“completions” or “generations”) from language models via clients. This + * sampling ("completions" or "generations") from language models via clients. This * flow allows clients to maintain control over model access, selection, and * permissions while enabling servers to leverage AI capabilities—with no server API * keys necessary. Servers can request text or image-based interactions and optionally @@ -147,8 +291,60 @@ public Mono createMessage(McpSchema.CreateMessage if (this.clientCapabilities.sampling() == null) { return Mono.error(new McpError("Client must be configured with sampling capabilities")); } - return this.session.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, - CREATE_MESSAGE_RESULT_TYPE_REF); + + // Generate a request ID for tracking + String requestId = UUID.randomUUID().toString(); + + // Add related task metadata to request if within task context + McpSchema.CreateMessageRequest requestWithMeta = createMessageRequest; + if (this.currentTaskId != null) { + Map meta = new java.util.HashMap<>(); + if (createMessageRequest.meta() != null) { + meta.putAll(createMessageRequest.meta()); + } + meta.put(McpSchema.RELATED_TASK_META_KEY, + McpSchema.RelatedTaskMetadata.builder().taskId(this.currentTaskId).build()); + requestWithMeta = new McpSchema.CreateMessageRequest(createMessageRequest.messages(), + createMessageRequest.modelPreferences(), createMessageRequest.systemPrompt(), + createMessageRequest.includeContext(), createMessageRequest.temperature(), + createMessageRequest.maxTokens(), createMessageRequest.stopSequences(), + createMessageRequest.metadata(), createMessageRequest.task(), meta); + } + + // Enqueue request if within task context + Mono enqueueRequest = Mono.empty(); + final McpSchema.CreateMessageRequest finalRequest = requestWithMeta; + if (this.currentTaskId != null && this.taskMessageQueue != null) { + QueuedMessage.Request queuedRequest = new QueuedMessage.Request(requestId, + McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, finalRequest); + enqueueRequest = this.taskMessageQueue.enqueue(this.currentTaskId, queuedRequest, null) + .doOnError(e -> logger.error("Failed to enqueue sampling request for task {}: {}", this.currentTaskId, + e.getMessage())) + // Message queue failures should not fail the main sampling operation. + // Errors are logged above; swallowing ensures the primary request + // succeeds. + .onErrorResume(e -> Mono.empty()); + } + + return enqueueRequest + .then(this.session.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, finalRequest, + CREATE_MESSAGE_RESULT_TYPE_REF)) + .flatMap(result -> { + // Enqueue response if within task context + if (this.currentTaskId != null && this.taskMessageQueue != null) { + QueuedMessage.Response queuedResponse = new QueuedMessage.Response(requestId, result); + return this.taskMessageQueue.enqueue(this.currentTaskId, queuedResponse, null) + .doOnError(e -> logger.error("Failed to enqueue sampling response for task {}: {}", + this.currentTaskId, e.getMessage())) + // Message queue failures should not fail the main sampling + // operation. + // Errors are logged above; swallowing ensures the primary request + // succeeds. + .onErrorResume(e -> Mono.empty()) + .thenReturn(result); + } + return Mono.just(result); + }); } /** @@ -172,8 +368,57 @@ public Mono createElicitation(McpSchema.ElicitRequest el if (this.clientCapabilities.elicitation() == null) { return Mono.error(new McpError("Client must be configured with elicitation capabilities")); } - return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, - ELICITATION_RESULT_TYPE_REF); + + // Generate a request ID for tracking + String requestId = UUID.randomUUID().toString(); + + // Add related task metadata to request if within task context + McpSchema.ElicitRequest requestWithMeta = elicitRequest; + if (this.currentTaskId != null) { + Map meta = new java.util.HashMap<>(); + if (elicitRequest.meta() != null) { + meta.putAll(elicitRequest.meta()); + } + meta.put(McpSchema.RELATED_TASK_META_KEY, + McpSchema.RelatedTaskMetadata.builder().taskId(this.currentTaskId).build()); + requestWithMeta = new McpSchema.ElicitRequest(elicitRequest.message(), elicitRequest.requestedSchema(), + elicitRequest.task(), meta); + } + + // Enqueue request if within task context + Mono enqueueRequest = Mono.empty(); + final McpSchema.ElicitRequest finalRequest = requestWithMeta; + if (this.currentTaskId != null && this.taskMessageQueue != null) { + QueuedMessage.Request queuedRequest = new QueuedMessage.Request(requestId, + McpSchema.METHOD_ELICITATION_CREATE, finalRequest); + enqueueRequest = this.taskMessageQueue.enqueue(this.currentTaskId, queuedRequest, null) + .doOnError(e -> logger.error("Failed to enqueue elicitation request for task {}: {}", + this.currentTaskId, e.getMessage())) + // Message queue failures should not fail the main elicitation operation. + // Errors are logged above; swallowing ensures the primary request + // succeeds. + .onErrorResume(e -> Mono.empty()); + } + + return enqueueRequest + .then(this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, finalRequest, + ELICITATION_RESULT_TYPE_REF)) + .flatMap(result -> { + // Enqueue response if within task context + if (this.currentTaskId != null && this.taskMessageQueue != null) { + QueuedMessage.Response queuedResponse = new QueuedMessage.Response(requestId, result); + return this.taskMessageQueue.enqueue(this.currentTaskId, queuedResponse, null) + .doOnError(e -> logger.error("Failed to enqueue elicitation response for task {}: {}", + this.currentTaskId, e.getMessage())) + // Message queue failures should not fail the main elicitation + // operation. + // Errors are logged above; swallowing ensures the primary request + // succeeds. + .onErrorResume(e -> Mono.empty()) + .thenReturn(result); + } + return Mono.just(result); + }); } /** @@ -239,9 +484,32 @@ public Mono progressNotification(McpSchema.ProgressNotification progressNo return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_PROGRESS, progressNotification); } + /** + * Sends a task status notification to THIS client only. + * + *

+ * This method sends a notification to the specific client associated with this + * exchange. Use this for targeted notifications when a tool handler needs to update a + * specific client about task progress. + * + *

+ * For broadcasting task status to ALL connected clients, use + * {@link McpAsyncServer#notifyTaskStatus(McpSchema.TaskStatusNotification)} instead. + * @param notification The task status notification to send + * @return A Mono that completes when the notification has been sent + * @see McpAsyncServer#notifyTaskStatus(McpSchema.TaskStatusNotification) for + * broadcasting to all clients + */ + public Mono notifyTaskStatus(McpSchema.TaskStatusNotification notification) { + if (notification == null) { + return Mono.error(new IllegalStateException("Task status notification must not be null")); + } + return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_TASKS_STATUS, notification); + } + /** * Sends a ping request to the client. - * @return A Mono that completes with clients's ping response + * @return A Mono that completes with the client's ping response */ public Mono ping() { return this.session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF); @@ -257,4 +525,473 @@ void setMinLoggingLevel(LoggingLevel minLoggingLevel) { this.session.setMinLoggingLevel(minLoggingLevel); } + // -------------------------- + // Client Task Operations + // -------------------------- + + /** + * Get the status of a task hosted by the client. This is used when the server has + * sent a task-augmented request to the client and needs to poll for status updates. + * @param getTaskRequest The request containing the task ID + * @return A Mono that emits the task status + */ + public Mono getTask(McpSchema.GetTaskRequest getTaskRequest) { + if (this.clientCapabilities == null) { + return Mono + .error(new IllegalStateException("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.tasks() == null) { + return Mono.error(new IllegalStateException("Client must be configured with tasks capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_TASKS_GET, getTaskRequest, GET_TASK_RESULT_TYPE_REF); + } + + /** + * Get the result of a completed task hosted by the client. + * @param The expected result type + * @param getTaskPayloadRequest The request containing the task ID + * @param resultTypeRef Type reference for deserializing the result + * @return A Mono that emits the task result + */ + public Mono getTaskResult( + McpSchema.GetTaskPayloadRequest getTaskPayloadRequest, TypeRef resultTypeRef) { + if (this.clientCapabilities == null) { + return Mono + .error(new IllegalStateException("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.tasks() == null) { + return Mono.error(new IllegalStateException("Client must be configured with tasks capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_TASKS_RESULT, getTaskPayloadRequest, resultTypeRef); + } + + /** + * List all tasks hosted by the client. + * + *

+ * This method automatically handles pagination, fetching all pages and combining them + * into a single result with an unmodifiable list. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @return A Mono that emits the list of all client tasks + */ + public Mono listTasks() { + return this.listTasks(McpSchema.FIRST_PAGE).expand(result -> { + String next = result.nextCursor(); + return (next != null && !next.isEmpty()) ? this.listTasks(next) : Mono.empty(); + }).reduce(McpSchema.ListTasksResult.builder().tasks(new ArrayList<>()).build(), (allTasksResult, result) -> { + allTasksResult.tasks().addAll(result.tasks()); + return allTasksResult; + }) + .map(result -> McpSchema.ListTasksResult.builder() + .tasks(Collections.unmodifiableList(result.tasks())) + .build()); + } + + /** + * List tasks hosted by the client with pagination support. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param cursor Pagination cursor from a previous list request + * @return A Mono that emits a page of client tasks + */ + public Mono listTasks(String cursor) { + if (this.clientCapabilities == null) { + return Mono + .error(new IllegalStateException("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.tasks() == null) { + return Mono.error(new IllegalStateException("Client must be configured with tasks capabilities")); + } + if (this.clientCapabilities.tasks().list() == null) { + return Mono.error(new IllegalStateException("Client must be configured with tasks.list capability")); + } + return this.session.sendRequest(McpSchema.METHOD_TASKS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_TASKS_RESULT_TYPE_REF); + } + + /** + * Request cancellation of a task hosted by the client. + * + *

+ * Note that cancellation is cooperative - the client may not honor the cancellation + * request, or may take some time to cancel the task. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param cancelTaskRequest The request containing the task ID + * @return A Mono that emits the updated task status + */ + public Mono cancelTask(McpSchema.CancelTaskRequest cancelTaskRequest) { + if (this.clientCapabilities == null) { + return Mono + .error(new IllegalStateException("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.tasks() == null) { + return Mono.error(new IllegalStateException("Client must be configured with tasks capabilities")); + } + if (this.clientCapabilities.tasks().cancel() == null) { + return Mono.error(new IllegalStateException("Client must be configured with tasks.cancel capability")); + } + return this.session.sendRequest(McpSchema.METHOD_TASKS_CANCEL, cancelTaskRequest, CANCEL_TASK_RESULT_TYPE_REF); + } + + /** + * Request cancellation of a task hosted by the client by task ID. + * + *

+ * This is a convenience overload that creates a {@link McpSchema.CancelTaskRequest} + * with the given task ID. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param taskId The task identifier to cancel + * @return A Mono that emits the updated task status + */ + public Mono cancelTask(String taskId) { + Assert.hasText(taskId, "Task ID must not be null or empty"); + return cancelTask(McpSchema.CancelTaskRequest.builder().taskId(taskId).build()); + } + + // -------------------------- + // Task-Augmented Sampling + // -------------------------- + + /** + * Low-level method to create a new message using task-augmented sampling. The client + * will process the request as a long-running task, allowing the server to poll for + * status updates. + * + *

+ * Recommendation: For most use cases, prefer + * {@link #createMessageStream} which provides a unified streaming interface that + * handles both regular and task-augmented sampling automatically, including polling + * and result retrieval. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param createMessageRequest The request to create a new message (must have task + * metadata) + * @return A Mono that emits the task creation result + * @see #createMessageStream + */ + public Mono createMessageTask(McpSchema.CreateMessageRequest createMessageRequest) { + if (createMessageRequest.task() == null) { + return Mono.error(new IllegalArgumentException( + "Task metadata is required for task-augmented sampling. Use createMessage() for regular requests.")); + } + if (this.clientCapabilities == null) { + return Mono + .error(new IllegalStateException("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.sampling() == null) { + return Mono.error(new IllegalStateException("Client must be configured with sampling capabilities")); + } + if (this.clientCapabilities.tasks() == null) { + return Mono.error(new IllegalStateException("Client must be configured with tasks capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, + CREATE_TASK_RESULT_TYPE_REF); + } + + /** + * Create a message and return a stream of response messages, handling both regular + * and task-augmented requests automatically. + * + *

+ * This method provides a unified streaming interface for sampling: + *

    + *
  • For non-task requests (when {@code task} field is null): + * yields a single {@link McpSchema.ResultMessage} or {@link McpSchema.ErrorMessage} + *
  • For task-augmented requests: yields + * {@link McpSchema.TaskCreatedMessage} → zero or more + * {@link McpSchema.TaskStatusMessage} → {@link McpSchema.ResultMessage} or + * {@link McpSchema.ErrorMessage} + *
+ * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param createMessageRequest The request containing the sampling parameters. If the + * {@code task} field is set, the call will be task-augmented. + * @return A Flux that emits {@link McpSchema.ResponseMessage} instances + */ + public Flux> createMessageStream( + McpSchema.CreateMessageRequest createMessageRequest) { + // For non-task requests, just wrap the result in a single message + if (createMessageRequest.task() == null) { + return this + .createMessage(createMessageRequest).>map(McpSchema.ResultMessage::of) + .onErrorResume(error -> { + McpError mcpError = (error instanceof McpError) ? (McpError) error + : McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR).message(error.getMessage()).build(); + return Mono.just(McpSchema.ErrorMessage.of(mcpError)); + }) + .flux(); + } + + // For task-augmented requests, handle the full lifecycle + return Flux.create(sink -> { + this.createMessageTask(createMessageRequest).subscribe(createResult -> { + McpSchema.Task task = createResult.task(); + if (task == null) { + sink.error(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Task creation did not return a task") + .build()); + return; + } + + // Emit taskCreated message + sink.next(McpSchema.TaskCreatedMessage.of(task)); + + // Start polling for task status + pollTaskUntilTerminal(task.taskId(), sink, Instant.now(), CREATE_MESSAGE_RESULT_TYPE_REF); + }, error -> { + McpError mcpError = (error instanceof McpError) ? (McpError) error + : McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR).message(error.getMessage()).build(); + sink.next(McpSchema.ErrorMessage.of(mcpError)); + sink.complete(); + }); + }); + } + + // -------------------------- + // Task-Augmented Elicitation + // -------------------------- + + /** + * Low-level method to create a new elicitation using task-augmented processing. The + * client will process the request as a long-running task, allowing the server to poll + * for status updates. + * + *

+ * Recommendation: For most use cases, prefer + * {@link #createElicitationStream} which provides a unified streaming interface that + * handles both regular and task-augmented elicitation automatically, including + * polling and result retrieval. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param elicitRequest The elicitation request (must have task metadata) + * @return A Mono that emits the task creation result + * @see #createElicitationStream + */ + public Mono createElicitationTask(McpSchema.ElicitRequest elicitRequest) { + if (elicitRequest.task() == null) { + return Mono.error(new IllegalArgumentException( + "Task metadata is required for task-augmented elicitation. Use createElicitation() for regular requests.")); + } + if (this.clientCapabilities == null) { + return Mono + .error(new IllegalStateException("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.elicitation() == null) { + return Mono.error(new IllegalStateException("Client must be configured with elicitation capabilities")); + } + if (this.clientCapabilities.tasks() == null) { + return Mono.error(new IllegalStateException("Client must be configured with tasks capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, + CREATE_TASK_RESULT_TYPE_REF); + } + + /** + * Create an elicitation and return a stream of response messages, handling both + * regular and task-augmented requests automatically. + * + *

+ * This method provides a unified streaming interface for elicitation: + *

    + *
  • For non-task requests (when {@code task} field is null): + * yields a single {@link McpSchema.ResultMessage} or {@link McpSchema.ErrorMessage} + *
  • For task-augmented requests: yields + * {@link McpSchema.TaskCreatedMessage} → zero or more + * {@link McpSchema.TaskStatusMessage} → {@link McpSchema.ResultMessage} or + * {@link McpSchema.ErrorMessage} + *
+ * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param elicitRequest The request containing the elicitation parameters. If the + * {@code task} field is set, the call will be task-augmented. + * @return A Flux that emits {@link McpSchema.ResponseMessage} instances + */ + public Flux> createElicitationStream( + McpSchema.ElicitRequest elicitRequest) { + // For non-task requests, just wrap the result in a single message + if (elicitRequest.task() == null) { + return this + .createElicitation(elicitRequest).>map(McpSchema.ResultMessage::of) + .onErrorResume(error -> { + McpError mcpError = (error instanceof McpError) ? (McpError) error + : McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR).message(error.getMessage()).build(); + return Mono.just(McpSchema.ErrorMessage.of(mcpError)); + }) + .flux(); + } + + // For task-augmented requests, handle the full lifecycle + return Flux.create(sink -> { + this.createElicitationTask(elicitRequest).subscribe(createResult -> { + McpSchema.Task task = createResult.task(); + if (task == null) { + sink.error(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Task creation did not return a task") + .build()); + return; + } + + // Emit taskCreated message + sink.next(McpSchema.TaskCreatedMessage.of(task)); + + // Start polling for task status + pollTaskUntilTerminal(task.taskId(), sink, Instant.now(), ELICITATION_RESULT_TYPE_REF); + }, error -> { + McpError mcpError = (error instanceof McpError) ? (McpError) error + : McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR).message(error.getMessage()).build(); + sink.next(McpSchema.ErrorMessage.of(mcpError)); + sink.complete(); + }); + }); + } + + // -------------------------- + // Task Polling Helpers + // -------------------------- + + /** + * Polls client task status until it reaches a terminal state, emitting status updates + * and final result. + * + *

+ * Uses proper reactive composition with {@link Flux#interval} and + * {@link Flux#takeUntil} to avoid unbounded subscription chains from recursive + * subscribe patterns. The timeout is dynamically calculated based on the task's poll + * interval. + */ + private void pollTaskUntilTerminal(String taskId, + reactor.core.publisher.FluxSink> sink, Instant startTime, + TypeRef resultTypeRef) { + + // First fetch to get initial state and poll interval + this.getTask(McpSchema.GetTaskRequest.builder().taskId(taskId).build()).subscribe(initialResult -> { + McpSchema.Task initialTask = initialResult.toTask(); + + // Emit initial status + sink.next(McpSchema.TaskStatusMessage.of(initialTask)); + + // Handle already terminal task + if (initialTask.isTerminal()) { + handleTerminalTask(taskId, initialTask, sink, resultTypeRef); + return; + } + + // Handle INPUT_REQUIRED - fetch result which blocks until terminal + if (initialTask.status() == McpSchema.TaskStatus.INPUT_REQUIRED) { + fetchTaskResultAndComplete(taskId, sink, resultTypeRef); + return; + } + + // Set up polling using proper reactive composition + long pollInterval = initialTask.pollInterval() != null ? initialTask.pollInterval() + : DEFAULT_TASK_POLL_INTERVAL_MS; + Duration timeout = calculateTimeout(pollInterval); + + // Use Flux.interval + takeUntil instead of recursive subscribe + reactor.core.Disposable pollSubscription = Flux.interval(Duration.ofMillis(pollInterval)) + .flatMap(tick -> getTask(McpSchema.GetTaskRequest.builder().taskId(taskId).build())) + .takeUntil(taskResult -> { + McpSchema.Task task = taskResult.toTask(); + // Emit status update for each poll + sink.next(McpSchema.TaskStatusMessage.of(task)); + // Stop when terminal or input_required + return task.isTerminal() || task.status() == McpSchema.TaskStatus.INPUT_REQUIRED; + }) + .timeout(timeout) + .last() + .subscribe(finalResult -> { + McpSchema.Task task = finalResult.toTask(); + if (task.isTerminal()) { + handleTerminalTask(taskId, task, sink, resultTypeRef); + } + else if (task.status() == McpSchema.TaskStatus.INPUT_REQUIRED) { + fetchTaskResultAndComplete(taskId, sink, resultTypeRef); + } + }, error -> { + String errorMsg = error.getMessage() != null ? error.getMessage() + : "Task polling failed: " + error.getClass().getSimpleName(); + if (error instanceof java.util.concurrent.TimeoutException) { + errorMsg = "Task polling timed out after " + timeout; + } + sink.next(McpSchema.ErrorMessage + .of(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR).message(errorMsg).build())); + sink.complete(); + }); + + // Register disposal handler for proper cleanup when sink is cancelled + sink.onDispose(pollSubscription); + + }, error -> { + String errorMsg = error.getMessage() != null ? error.getMessage() + : "Failed to get task: " + error.getClass().getSimpleName(); + McpError mcpError = (error instanceof McpError) ? (McpError) error + : McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR).message(errorMsg).build(); + sink.next(McpSchema.ErrorMessage.of(mcpError)); + sink.complete(); + }); + } + + /** + * Handles a client task that has reached a terminal state. + */ + private void handleTerminalTask(String taskId, McpSchema.Task task, + reactor.core.publisher.FluxSink> sink, TypeRef resultTypeRef) { + if (task.status() == McpSchema.TaskStatus.COMPLETED) { + fetchTaskResultAndComplete(taskId, sink, resultTypeRef); + } + else if (task.status() == McpSchema.TaskStatus.FAILED) { + String message = task.statusMessage() != null ? task.statusMessage() : "Task " + taskId + " failed"; + sink.next(McpSchema.ErrorMessage + .of(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR).message(message).build())); + sink.complete(); + } + else if (task.status() == McpSchema.TaskStatus.CANCELLED) { + sink.next(McpSchema.ErrorMessage.of(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Task " + taskId + " was cancelled") + .build())); + sink.complete(); + } + else { + sink.complete(); + } + } + + /** + * Fetches the client task result and completes the stream. + */ + private void fetchTaskResultAndComplete(String taskId, + reactor.core.publisher.FluxSink> sink, TypeRef resultTypeRef) { + this.getTaskResult(McpSchema.GetTaskPayloadRequest.builder().taskId(taskId).build(), resultTypeRef) + .subscribe(result -> { + sink.next(McpSchema.ResultMessage.of(result)); + sink.complete(); + }, error -> { + McpError mcpError = (error instanceof McpError) ? (McpError) error + : McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR).message(error.getMessage()).build(); + sink.next(McpSchema.ErrorMessage.of(mcpError)); + sink.complete(); + }); + } + } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java index fe3125271..e78df76f1 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -14,6 +14,10 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; +import io.modelcontextprotocol.experimental.tasks.TaskAwareAsyncToolSpecification; +import io.modelcontextprotocol.experimental.tasks.TaskAwareSyncToolSpecification; +import io.modelcontextprotocol.experimental.tasks.TaskMessageQueue; +import io.modelcontextprotocol.experimental.tasks.TaskStore; import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.json.schema.JsonSchemaValidator; @@ -235,15 +239,17 @@ private SingleSessionAsyncSpecification(McpServerTransportProvider transportProv */ @Override public McpAsyncServer build() { + validateTaskConfiguration(); + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, - this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, - this.instructions); + this.taskTools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions, this.taskStore, this.taskMessageQueue); var jsonSchemaValidator = (this.jsonSchemaValidator != null) ? this.jsonSchemaValidator : JsonSchemaValidator.getDefault(); return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, - features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator, automaticPollingTimeout); } } @@ -263,13 +269,15 @@ public StreamableServerAsyncSpecification(McpStreamableServerTransportProvider t */ @Override public McpAsyncServer build() { + validateTaskConfiguration(); + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, - this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, - this.instructions); + this.taskTools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions, this.taskStore, this.taskMessageQueue); var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator : JsonSchemaValidator.getDefault(); return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, - features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator, automaticPollingTimeout); } } @@ -300,6 +308,12 @@ abstract class AsyncSpecification> { */ final List tools = new ArrayList<>(); + /** + * Task-aware tools that support long-running operations via task-augmented + * execution (SEP-1686). This is an experimental feature. + */ + final List taskTools = new ArrayList<>(); + /** * The Model Context Protocol (MCP) provides a standardized way for servers to * expose resources to clients. Resources allow servers to share data that @@ -331,10 +345,44 @@ abstract class AsyncSpecification> { final List, Mono>> rootsChangeHandlers = new ArrayList<>(); - Duration requestTimeout = Duration.ofHours(10); // Default timeout + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + + /** + * The task store for managing long-running tasks. This is an experimental + * feature. + */ + TaskStore taskStore; + + /** + * The message queue for task communication. This is an experimental feature. + */ + TaskMessageQueue taskMessageQueue; + + /** + * The timeout for automatic task polling. This is an experimental feature. + */ + Duration automaticPollingTimeout; public abstract McpAsyncServer build(); + /** + * Validates task configuration. Task-aware tools require a TaskStore to be + * configured for task lifecycle management. + * @throws IllegalStateException if task-aware tools are registered without a + * TaskStore + */ + protected void validateTaskConfiguration() { + boolean hasTaskTools = !this.taskTools.isEmpty(); + boolean hasTaskStore = this.taskStore != null; + + if (hasTaskTools && !hasTaskStore) { + throw new IllegalStateException("Task-aware tools registered but no TaskStore configured. " + + "Add a TaskStore via .taskStore(store) or remove task tools."); + } + // Note: Having taskStore without taskTools is allowed (for future dynamic + // registration) + } + /** * Sets the URI template manager factory to use for creating URI templates. This * allows for custom URI template parsing and variable extraction. @@ -427,6 +475,63 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCap return this; } + /** + * Sets the task store for managing long-running tasks. This enables support for + * MCP Tasks, allowing servers to create tasks that clients can poll for status + * and results. + * + *

+ * Note: This is an experimental feature that may change in + * future releases. + * @param taskStore The task store implementation. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskStore is null + */ + public AsyncSpecification taskStore(TaskStore taskStore) { + Assert.notNull(taskStore, "Task store must not be null"); + this.taskStore = taskStore; + return this; + } + + /** + * Sets the message queue for task communication. This enables servers to queue + * messages (like elicitation requests) during task execution that clients can + * retrieve via streaming. + * + *

+ * Note: This is an experimental feature that may change in + * future releases. + * @param taskMessageQueue The message queue implementation. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskMessageQueue is null + */ + public AsyncSpecification taskMessageQueue(TaskMessageQueue taskMessageQueue) { + Assert.notNull(taskMessageQueue, "Task message queue must not be null"); + this.taskMessageQueue = taskMessageQueue; + return this; + } + + /** + * Sets the maximum timeout for automatic task polling. When a task-enabled tool + * is called without task metadata, the server creates a task internally and polls + * until completion. This timeout prevents indefinite polling. + * + *

+ * If not set, defaults to 30 minutes. + * + *

+ * Note: This is an experimental feature that may change in + * future releases. + * @param automaticPollingTimeout The maximum polling timeout. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if automaticPollingTimeout is null + */ + public AsyncSpecification automaticPollingTimeout(Duration automaticPollingTimeout) { + Assert.notNull(automaticPollingTimeout, "Automatic polling timeout must not be null"); + this.automaticPollingTimeout = automaticPollingTimeout; + return this; + } + /** * Adds a single tool with its implementation handler to the server. This is a * convenience method for registering individual tools without creating a @@ -539,10 +644,79 @@ public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... t return this; } + /** + * Adds multiple task-aware tools with their handlers to the server using a List. + * Task-aware tools support long-running operations via task-augmented execution + * (SEP-1686). + * + *

+ * Note: This is an experimental feature that may change in + * future releases. + * @param taskToolSpecifications The list of task-aware tool specifications to + * add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskToolSpecifications is null + * @see #taskTools(TaskAwareAsyncToolSpecification...) + */ + public AsyncSpecification taskTools(List taskToolSpecifications) { + Assert.notNull(taskToolSpecifications, "Task tool specifications list must not be null"); + + for (var taskTool : taskToolSpecifications) { + assertNoDuplicateTool(taskTool.tool().name()); + this.taskTools.add(taskTool); + } + return this; + } + + /** + * Adds multiple task-aware tools with their handlers to the server using varargs. + * Task-aware tools support long-running operations via task-augmented execution + * (SEP-1686). + * + *

+ * Note: This is an experimental feature that may change in + * future releases. + * + *

+ * Example usage: + * + *

{@code
+		 * .taskTools(
+		 *     TaskAwareAsyncToolSpecification.builder()
+		 *         .name("long-computation")
+		 *         .description("A long-running computation")
+		 *         .createTask((args, extra) -> {
+		 *             return extra.taskStore().createTask(...)
+		 *                 .flatMap(task -> {
+		 *                     doWork(task.taskId(), args).subscribe();
+		 *                     return Mono.just(new CreateTaskResult(task, null));
+		 *                 });
+		 *         })
+		 *         .build()
+		 * )
+		 * }
+ * @param taskToolSpecifications The task-aware tool specifications to add. Must + * not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskToolSpecifications is null + */ + public AsyncSpecification taskTools(TaskAwareAsyncToolSpecification... taskToolSpecifications) { + Assert.notNull(taskToolSpecifications, "Task tool specifications must not be null"); + + for (var taskTool : taskToolSpecifications) { + assertNoDuplicateTool(taskTool.tool().name()); + this.taskTools.add(taskTool); + } + return this; + } + private void assertNoDuplicateTool(String toolName) { if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); } + if (this.taskTools.stream().anyMatch(taskToolSpec -> taskToolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } } /** @@ -827,16 +1001,19 @@ private SingleSessionSyncSpecification(McpServerTransportProvider transportProvi */ @Override public McpSyncServer build() { + validateTaskConfiguration(); + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, - this.rootsChangeHandlers, this.instructions); + this.tools, this.taskTools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions, this.taskStore, this.taskMessageQueue); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); var asyncServer = new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, asyncFeatures, requestTimeout, uriTemplateManagerFactory, - jsonSchemaValidator != null ? jsonSchemaValidator : JsonSchemaValidator.getDefault()); + jsonSchemaValidator != null ? jsonSchemaValidator : JsonSchemaValidator.getDefault(), + automaticPollingTimeout); return new McpSyncServer(asyncServer, this.immediateExecution); } @@ -858,16 +1035,18 @@ private StreamableSyncSpecification(McpStreamableServerTransportProvider transpo */ @Override public McpSyncServer build() { + validateTaskConfiguration(); + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, - this.rootsChangeHandlers, this.instructions); + this.tools, this.taskTools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions, this.taskStore, this.taskMessageQueue); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator : JsonSchemaValidator.getDefault(); var asyncServer = new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, asyncFeatures, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + this.uriTemplateManagerFactory, jsonSchemaValidator, automaticPollingTimeout); return new McpSyncServer(asyncServer, this.immediateExecution); } @@ -897,6 +1076,12 @@ abstract class SyncSpecification> { */ final List tools = new ArrayList<>(); + /** + * Task-aware tools that support long-running operations via task-augmented + * execution (SEP-1686). This is an experimental feature. + */ + final List taskTools = new ArrayList<>(); + /** * The Model Context Protocol (MCP) provides a standardized way for servers to * expose resources to clients. Resources allow servers to share data that @@ -934,8 +1119,42 @@ abstract class SyncSpecification> { boolean immediateExecution = false; + /** + * The task store for managing long-running tasks. This is an experimental + * feature. + */ + TaskStore taskStore; + + /** + * The message queue for task communication. This is an experimental feature. + */ + TaskMessageQueue taskMessageQueue; + + /** + * The timeout for automatic task polling. This is an experimental feature. + */ + Duration automaticPollingTimeout; + public abstract McpSyncServer build(); + /** + * Validates task configuration. Task-aware tools require a TaskStore to be + * configured for task lifecycle management. + * @throws IllegalStateException if task-aware tools are registered without a + * TaskStore + */ + protected void validateTaskConfiguration() { + boolean hasTaskTools = !this.taskTools.isEmpty(); + boolean hasTaskStore = this.taskStore != null; + + if (hasTaskTools && !hasTaskStore) { + throw new IllegalStateException("Task-aware tools registered but no TaskStore configured. " + + "Add a TaskStore via .taskStore(store) or remove task tools."); + } + // Note: Having taskStore without taskTools is allowed (for future dynamic + // registration) + } + /** * Sets the URI template manager factory to use for creating URI templates. This * allows for custom URI template parsing and variable extraction. @@ -1028,6 +1247,63 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapa return this; } + /** + * Sets the task store for managing long-running tasks. This enables support for + * MCP Tasks, allowing servers to create tasks that clients can poll for status + * and results. + * + *

+ * Note: This is an experimental feature that may change in + * future releases. + * @param taskStore The task store implementation. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskStore is null + */ + public SyncSpecification taskStore(TaskStore taskStore) { + Assert.notNull(taskStore, "Task store must not be null"); + this.taskStore = taskStore; + return this; + } + + /** + * Sets the message queue for task communication. This enables support for + * input_required task state, allowing tasks to queue messages while waiting for + * user input (elicitation). + * + *

+ * Note: This is an experimental feature that may change in + * future releases. + * @param taskMessageQueue The message queue implementation. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskMessageQueue is null + */ + public SyncSpecification taskMessageQueue(TaskMessageQueue taskMessageQueue) { + Assert.notNull(taskMessageQueue, "Task message queue must not be null"); + this.taskMessageQueue = taskMessageQueue; + return this; + } + + /** + * Sets the maximum timeout for automatic task polling. When a task-enabled tool + * is called without task metadata, the server creates a task internally and polls + * until completion. This timeout prevents indefinite polling. + * + *

+ * If not set, defaults to 30 minutes. + * + *

+ * Note: This is an experimental feature that may change in + * future releases. + * @param automaticPollingTimeout The maximum polling timeout. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if automaticPollingTimeout is null + */ + public SyncSpecification automaticPollingTimeout(Duration automaticPollingTimeout) { + Assert.notNull(automaticPollingTimeout, "Automatic polling timeout must not be null"); + this.automaticPollingTimeout = automaticPollingTimeout; + return this; + } + /** * Adds a single tool with its implementation handler to the server. This is a * convenience method for registering individual tools without creating a @@ -1139,10 +1415,77 @@ public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... too return this; } + /** + * Adds multiple task-aware tools with their handlers to the server using a List. + * Task-aware tools support long-running operations via task-augmented execution + * (SEP-1686). + * + *

+ * Note: This is an experimental feature that may change in + * future releases. + * @param taskToolSpecifications The list of task-aware tool specifications to + * add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskToolSpecifications is null + * @see #taskTools(TaskAwareSyncToolSpecification...) + */ + public SyncSpecification taskTools(List taskToolSpecifications) { + Assert.notNull(taskToolSpecifications, "Task tool specifications list must not be null"); + + for (var taskTool : taskToolSpecifications) { + assertNoDuplicateTool(taskTool.tool().name()); + this.taskTools.add(taskTool); + } + return this; + } + + /** + * Adds multiple task-aware tools with their handlers to the server using varargs. + * Task-aware tools support long-running operations via task-augmented execution + * (SEP-1686). + * + *

+ * Note: This is an experimental feature that may change in + * future releases. + * + *

+ * Example usage: + * + *

{@code
+		 * .taskTools(
+		 *     TaskAwareSyncToolSpecification.builder()
+		 *         .name("long-computation")
+		 *         .description("A long-running computation")
+		 *         .createTask((args, extra) -> {
+		 *             Task task = extra.taskStore().createTask(...).block();
+		 *             startBackgroundWork(task.taskId(), args);
+		 *             return new CreateTaskResult(task, null);
+		 *         })
+		 *         .build()
+		 * )
+		 * }
+ * @param taskToolSpecifications The task-aware tool specifications to add. Must + * not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if taskToolSpecifications is null + */ + public SyncSpecification taskTools(TaskAwareSyncToolSpecification... taskToolSpecifications) { + Assert.notNull(taskToolSpecifications, "Task tool specifications must not be null"); + + for (var taskTool : taskToolSpecifications) { + assertNoDuplicateTool(taskTool.tool().name()); + this.taskTools.add(taskTool); + } + return this; + } + private void assertNoDuplicateTool(String toolName) { if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); } + if (this.taskTools.stream().anyMatch(taskToolSpec -> taskToolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } } /** diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index fe0608b1c..01086de1d 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -11,6 +11,10 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; +import io.modelcontextprotocol.experimental.tasks.TaskAwareAsyncToolSpecification; +import io.modelcontextprotocol.experimental.tasks.TaskAwareSyncToolSpecification; +import io.modelcontextprotocol.experimental.tasks.TaskMessageQueue; +import io.modelcontextprotocol.experimental.tasks.TaskStore; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.util.Assert; @@ -32,40 +36,50 @@ public class McpServerFeatures { * @param serverInfo The server implementation details * @param serverCapabilities The server capabilities * @param tools The list of tool specifications + * @param taskTools The list of task-aware tool specifications (experimental) * @param resources The map of resource specifications * @param resourceTemplates The list of resource templates * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when the * roots list changes * @param instructions The server instructions text + * @param taskStore The task store for managing long-running tasks (experimental) + * @param taskMessageQueue The message queue for task communication (experimental) */ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, Map resources, + List tools, List taskTools, + Map resources, Map resourceTemplates, Map prompts, Map completions, List, Mono>> rootsChangeConsumers, - String instructions) { + String instructions, TaskStore taskStore, + TaskMessageQueue taskMessageQueue) { /** * Create an instance and validate the arguments. * @param serverInfo The server implementation details * @param serverCapabilities The server capabilities * @param tools The list of tool specifications + * @param taskTools The list of task-aware tool specifications (experimental) * @param resources The map of resource specifications * @param resourceTemplates The map of resource templates * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when * the roots list changes * @param instructions The server instructions text + * @param taskStore The task store for managing long-running tasks (experimental) + * @param taskMessageQueue The message queue for task communication (experimental) */ Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, Map resources, + List tools, List taskTools, + Map resources, Map resourceTemplates, Map prompts, Map completions, List, Mono>> rootsChangeConsumers, - String instructions) { + String instructions, TaskStore taskStore, + TaskMessageQueue taskMessageQueue) { Assert.notNull(serverInfo, "Server info must not be null"); @@ -80,15 +94,24 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, !Utils.isEmpty(resources) ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, - !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + (!Utils.isEmpty(tools) || !Utils.isEmpty(taskTools)) + ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null, + taskStore != null ? McpSchema.ServerCapabilities.ServerTaskCapabilities.builder() + .list() + .cancel() + .toolsCall() + .build() : null); this.tools = (tools != null) ? tools : List.of(); + this.taskTools = (taskTools != null) ? taskTools : List.of(); this.resources = (resources != null) ? resources : Map.of(); this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); this.prompts = (prompts != null) ? prompts : Map.of(); this.completions = (completions != null) ? completions : Map.of(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); this.instructions = instructions; + this.taskStore = taskStore; + this.taskMessageQueue = taskMessageQueue; } /** @@ -107,6 +130,13 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { tools.add(AsyncToolSpecification.fromSync(tool, immediateExecution)); } + // Convert sync task tools to async + List taskTools = new ArrayList<>(); + for (var taskTool : syncSpec.taskTools()) { + taskTools.add(TaskAwareAsyncToolSpecification.fromSync(taskTool, + immediateExecution ? Runnable::run : Schedulers.boundedElastic()::schedule)); + } + Map resources = new HashMap<>(); syncSpec.resources().forEach((key, resource) -> { resources.put(key, AsyncResourceSpecification.fromSync(resource, immediateExecution)); @@ -135,8 +165,9 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { .subscribeOn(Schedulers.boundedElastic())); } - return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, resourceTemplates, - prompts, completions, rootChangeConsumers, syncSpec.instructions()); + return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, taskTools, resources, + resourceTemplates, prompts, completions, rootChangeConsumers, syncSpec.instructions(), + syncSpec.taskStore(), syncSpec.taskMessageQueue()); } } @@ -146,41 +177,48 @@ static Async fromSync(Sync syncSpec, boolean immediateExecution) { * @param serverInfo The server implementation details * @param serverCapabilities The server capabilities * @param tools The list of tool specifications + * @param taskTools The list of task-aware tool specifications (experimental) * @param resources The map of resource specifications * @param resourceTemplates The list of resource templates * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when the * roots list changes * @param instructions The server instructions text + * @param taskStore The task store for managing long-running tasks (experimental) + * @param taskMessageQueue The message queue for task communication (experimental) */ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, + List tools, List taskTools, Map resources, Map resourceTemplates, Map prompts, Map completions, - List>> rootsChangeConsumers, String instructions) { + List>> rootsChangeConsumers, String instructions, + TaskStore taskStore, TaskMessageQueue taskMessageQueue) { /** * Create an instance and validate the arguments. * @param serverInfo The server implementation details * @param serverCapabilities The server capabilities * @param tools The list of tool specifications + * @param taskTools The list of task-aware tool specifications (experimental) * @param resources The map of resource specifications * @param resourceTemplates The list of resource templates * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when * the roots list changes * @param instructions The server instructions text + * @param taskStore The task store for managing long-running tasks (experimental) + * @param taskMessageQueue The message queue for task communication (experimental) */ Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, + List tools, List taskTools, Map resources, Map resourceTemplates, Map prompts, Map completions, - List>> rootsChangeConsumers, - String instructions) { + List>> rootsChangeConsumers, String instructions, + TaskStore taskStore, TaskMessageQueue taskMessageQueue) { Assert.notNull(serverInfo, "Server info must not be null"); @@ -195,15 +233,24 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, !Utils.isEmpty(resources) ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, - !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + (!Utils.isEmpty(tools) || !Utils.isEmpty(taskTools)) + ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null, + taskStore != null ? McpSchema.ServerCapabilities.ServerTaskCapabilities.builder() + .list() + .cancel() + .toolsCall() + .build() : null); this.tools = (tools != null) ? tools : new ArrayList<>(); + this.taskTools = (taskTools != null) ? taskTools : new ArrayList<>(); this.resources = (resources != null) ? resources : new HashMap<>(); this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Map.of(); this.prompts = (prompts != null) ? prompts : new HashMap<>(); this.completions = (completions != null) ? completions : new HashMap<>(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : new ArrayList<>(); this.instructions = instructions; + this.taskStore = taskStore; + this.taskMessageQueue = taskMessageQueue; } } @@ -213,6 +260,10 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * primary way for MCP servers to expose functionality to AI models. Each tool * represents a specific capability. * + *

+ * For task-aware tools that support long-running operations, see + * {@link io.modelcontextprotocol.experimental.tasks.TaskAwareAsyncToolSpecification}. + * * @param tool The tool definition including name, description, and parameter schema * @param call Deprecated. Use the {@link AsyncToolSpecification#callHandler} instead. * @param callHandler The function that implements the tool's logic, receiving a @@ -501,6 +552,10 @@ static AsyncCompletionSpecification fromSync(SyncCompletionSpecification complet * primary way for MCP servers to expose functionality to AI models. * *

+ * For task-aware tools that support long-running operations, see + * {@link io.modelcontextprotocol.experimental.tasks.TaskAwareSyncToolSpecification}. + * + *

* Example tool specification: * *

{@code
diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java
index a15681ba5..8714637c4 100644
--- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java
+++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java
@@ -74,7 +74,9 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s
 							!Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null,
 							!Utils.isEmpty(resources)
 									? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null,
-							!Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null);
+							!Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null,
+							null // tasks
+					);
 
 			this.tools = (tools != null) ? tools : List.of();
 			this.resources = (resources != null) ? resources : Map.of();
@@ -175,7 +177,9 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se
 							!Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null,
 							!Utils.isEmpty(resources)
 									? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null,
-							!Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null);
+							!Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null,
+							null // tasks
+					);
 
 			this.tools = (tools != null) ? tools : new ArrayList<>();
 			this.resources = (resources != null) ? resources : new HashMap<>();
diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java
index 10f0e5a31..9971d0320 100644
--- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java
+++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java
@@ -5,7 +5,13 @@
 package io.modelcontextprotocol.server;
 
 import java.util.List;
+import java.util.concurrent.Executor;
 
+import io.modelcontextprotocol.experimental.tasks.TaskAwareAsyncToolSpecification;
+import io.modelcontextprotocol.experimental.tasks.TaskAwareSyncToolSpecification;
+import io.modelcontextprotocol.experimental.tasks.TaskMessageQueue;
+import io.modelcontextprotocol.experimental.tasks.TaskStore;
+import reactor.core.scheduler.Schedulers;
 import io.modelcontextprotocol.spec.McpSchema;
 import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification;
 import io.modelcontextprotocol.util.Assert;
@@ -105,6 +111,71 @@ public void removeTool(String toolName) {
 		this.asyncServer.removeTool(toolName).block();
 	}
 
+	/**
+	 * Add a new task-aware tool at runtime.
+	 *
+	 * 

+ * Task-aware tools support long-running operations with task lifecycle management + * (SEP-1686). The sync specification is converted to async and delegated to the + * underlying async server. + * @param taskToolSpecification The task-aware tool specification to add + */ + public void addTaskTool(TaskAwareSyncToolSpecification taskToolSpecification) { + Executor executor = this.immediateExecution ? Runnable::run : Schedulers.boundedElastic()::schedule; + TaskAwareAsyncToolSpecification asyncSpec = TaskAwareAsyncToolSpecification.fromSync(taskToolSpecification, + executor); + this.asyncServer.addTaskTool(asyncSpec).block(); + } + + /** + * List all registered task-aware tools. + * @return A list of all registered task-aware tools + */ + public List listTaskTools() { + return this.asyncServer.listTaskTools().collectList().block(); + } + + /** + * Remove a task-aware tool. + * @param toolName The name of the task-aware tool to remove + */ + public void removeTaskTool(String toolName) { + this.asyncServer.removeTaskTool(toolName).block(); + } + + /** + * Sends a task status notification to the client. + *

+ * Warning: This is an experimental API that may change in future + * releases. Use with caution in production environments. + * @param notification The task status notification to send + */ + public void notifyTaskStatus(McpSchema.TaskStatusNotification notification) { + this.asyncServer.notifyTaskStatus(notification).block(); + } + + /** + * Get the task store used for managing long-running tasks. + *

+ * Warning: This is an experimental API that may change in future + * releases. Use with caution in production environments. + * @return The task store, or null if tasks are not enabled + */ + public TaskStore getTaskStore() { + return this.asyncServer.getTaskStore(); + } + + /** + * Get the task message queue used for task communication during input_required state. + *

+ * Warning: This is an experimental API that may change in future + * releases. Use with caution in production environments. + * @return The task message queue, or null if not configured + */ + public TaskMessageQueue getTaskMessageQueue() { + return this.asyncServer.getTaskMessageQueue(); + } + /** * Add a new resource handler. * @param resourceSpecification The resource specification to add diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 0b9115b79..5bb9268df 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -5,9 +5,13 @@ package io.modelcontextprotocol.server; import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.experimental.tasks.TaskMessageQueue; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import java.util.List; + /** * Represents a synchronous exchange with a Model Context Protocol (MCP) client. The * exchange provides methods to interact with the client and query its capabilities. @@ -135,12 +139,260 @@ public void progressNotification(McpSchema.ProgressNotification progressNotifica this.exchange.progressNotification(progressNotification).block(); } + /** + * Sends a task status notification to THIS client only. + * + *

+ * This method sends a notification to the specific client associated with this + * exchange. Use this for targeted notifications when a tool handler needs to update a + * specific client about task progress. + * + *

+ * For broadcasting task status to ALL connected clients, use + * {@link McpSyncServer#notifyTaskStatus(McpSchema.TaskStatusNotification)} instead. + * @param notification The task status notification to send + * @see McpSyncServer#notifyTaskStatus(McpSchema.TaskStatusNotification) for + * broadcasting to all clients + */ + public void notifyTaskStatus(McpSchema.TaskStatusNotification notification) { + this.exchange.notifyTaskStatus(notification).block(); + } + + /** + * Gets the current task ID if this exchange is operating within a task context. + *

+ * Warning: This is an experimental API that may change in future + * releases. Use with caution in production environments. + * @return The current task ID, or null if not in a task context + */ + public String getCurrentTaskId() { + return this.exchange.getCurrentTaskId(); + } + + /** + * Creates a new exchange scoped to the specified task ID. + *

+ * The returned exchange will have its task context set, making + * {@link #getCurrentTaskId()} return the specified task ID. + *

+ * Warning: This is an experimental API that may change in future + * releases. Use with caution in production environments. + * @param taskId The task ID to scope this exchange to + * @return A new exchange with task context set + */ + public McpSyncServerExchange withTaskContext(String taskId) { + return new McpSyncServerExchange(this.exchange.withTaskContext(taskId)); + } + + /** + * Creates a new exchange scoped to the specified task ID and message queue. + *

+ * The returned exchange will have its task context and message queue set. + *

+ * Warning: This is an experimental API that may change in future + * releases. Use with caution in production environments. + * @param taskId The task ID to scope this exchange to + * @param queue The message queue for task communication + * @return A new exchange with task context and message queue set + */ + public McpSyncServerExchange withTaskContext(String taskId, TaskMessageQueue queue) { + return new McpSyncServerExchange(this.exchange.withTaskContext(taskId, queue)); + } + /** * Sends a synchronous ping request to the client. - * @return + * @return The ping response from the client */ public Object ping() { return this.exchange.ping().block(); } + // -------------------------- + // Client Task Operations + // -------------------------- + + /** + * Get the status of a task hosted by the client. This is used when the server has + * sent a task-augmented request to the client and needs to poll for status updates. + * @param getTaskRequest The request containing the task ID + * @return The task status + */ + public McpSchema.GetTaskResult getTask(McpSchema.GetTaskRequest getTaskRequest) { + return this.exchange.getTask(getTaskRequest).block(); + } + + /** + * Get the result of a completed task hosted by the client. + * @param The expected result type + * @param getTaskPayloadRequest The request containing the task ID + * @param resultTypeRef Type reference for deserializing the result + * @return The task result + */ + public T getTaskResult( + McpSchema.GetTaskPayloadRequest getTaskPayloadRequest, TypeRef resultTypeRef) { + return this.exchange.getTaskResult(getTaskPayloadRequest, resultTypeRef).block(); + } + + /** + * List all tasks hosted by the client. + * + *

+ * This method automatically handles pagination, fetching all pages and combining them + * into a single result. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @return The list of all client tasks + */ + public McpSchema.ListTasksResult listTasks() { + return this.exchange.listTasks().block(); + } + + /** + * List tasks hosted by the client with pagination support. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param cursor Pagination cursor from a previous list request + * @return A page of client tasks + */ + public McpSchema.ListTasksResult listTasks(String cursor) { + return this.exchange.listTasks(cursor).block(); + } + + /** + * Request cancellation of a task hosted by the client. + * + *

+ * Note that cancellation is cooperative - the client may not honor the cancellation + * request, or may take some time to cancel the task. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param cancelTaskRequest The request containing the task ID + * @return The updated task status + */ + public McpSchema.CancelTaskResult cancelTask(McpSchema.CancelTaskRequest cancelTaskRequest) { + return this.exchange.cancelTask(cancelTaskRequest).block(); + } + + /** + * Request cancellation of a task hosted by the client by task ID. + * + *

+ * This is a convenience overload that creates a {@link McpSchema.CancelTaskRequest} + * with the given task ID. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param taskId The task identifier to cancel + * @return The updated task status + */ + public McpSchema.CancelTaskResult cancelTask(String taskId) { + return this.exchange.cancelTask(taskId).block(); + } + + // -------------------------- + // Task-Augmented Sampling + // -------------------------- + + /** + * Low-level method to create a new message using task-augmented sampling. The client + * will process the request as a long-running task, allowing the server to poll for + * status updates. + * + *

+ * Recommendation: For most use cases, prefer + * {@link #createMessageStream} which provides a unified interface that handles both + * regular and task-augmented sampling automatically, including polling and result + * retrieval. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param createMessageRequest The request to create a new message (must have task + * metadata) + * @return The task creation result + * @see #createMessageStream + */ + public McpSchema.CreateTaskResult createMessageTask(McpSchema.CreateMessageRequest createMessageRequest) { + return this.exchange.createMessageTask(createMessageRequest).block(); + } + + /** + * Create a message and return a list of response messages, handling both regular and + * task-augmented requests automatically. + * + *

+ * This method blocks until the sampling completes. For non-blocking streaming, use + * the async exchange's + * {@link McpAsyncServerExchange#createMessageStream(McpSchema.CreateMessageRequest)} + * method. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param createMessageRequest The request containing the sampling parameters. If the + * {@code task} field is set, the call will be task-augmented. + * @return A list of {@link McpSchema.ResponseMessage} instances representing the + * progress and result + */ + public List> createMessageStream( + McpSchema.CreateMessageRequest createMessageRequest) { + return this.exchange.createMessageStream(createMessageRequest).collectList().block(); + } + + // -------------------------- + // Task-Augmented Elicitation + // -------------------------- + + /** + * Low-level method to create a new elicitation using task-augmented processing. The + * client will process the request as a long-running task, allowing the server to poll + * for status updates. + * + *

+ * Recommendation: For most use cases, prefer + * {@link #createElicitationStream} which provides a unified interface that handles + * both regular and task-augmented elicitation automatically, including polling and + * result retrieval. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param elicitRequest The elicitation request (must have task metadata) + * @return The task creation result + * @see #createElicitationStream + */ + public McpSchema.CreateTaskResult createElicitationTask(McpSchema.ElicitRequest elicitRequest) { + return this.exchange.createElicitationTask(elicitRequest).block(); + } + + /** + * Create an elicitation and return a list of response messages, handling both regular + * and task-augmented requests automatically. + * + *

+ * This method blocks until the elicitation completes. For non-blocking streaming, use + * the async exchange's + * {@link McpAsyncServerExchange#createElicitationStream(McpSchema.ElicitRequest)} + * method. + * + *

+ * Note: This is an experimental feature that may change in future + * releases. + * @param elicitRequest The request containing the elicitation parameters. If the + * {@code task} field is set, the call will be task-augmented. + * @return A list of {@link McpSchema.ResponseMessage} instances representing the + * progress and result + */ + public List> createElicitationStream( + McpSchema.ElicitRequest elicitRequest) { + return this.exchange.createElicitationStream(elicitRequest).collectList().block(); + } + } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index b58f1c552..e9450adb2 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -5,13 +5,16 @@ package io.modelcontextprotocol.spec; import java.io.IOException; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; @@ -21,6 +24,7 @@ import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; +import reactor.util.annotation.Nullable; import org.slf4j.LoggerFactory; /** @@ -108,6 +112,17 @@ private McpSchema() { // Elicitation Methods public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; + // Tasks Methods + public static final String METHOD_TASKS_GET = "tasks/get"; + + public static final String METHOD_TASKS_RESULT = "tasks/result"; + + public static final String METHOD_TASKS_CANCEL = "tasks/cancel"; + + public static final String METHOD_TASKS_LIST = "tasks/list"; + + public static final String METHOD_NOTIFICATION_TASKS_STATUS = "notifications/tasks/status"; + // --------------------------- // JSON-RPC Error Codes // --------------------------- @@ -163,9 +178,9 @@ public interface Meta { } - public sealed interface Request extends Meta - permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, CompleteRequest, - GetPromptRequest, ReadResourceRequest, SubscribeRequest, UnsubscribeRequest, PaginatedRequest { + public sealed interface Request extends Meta permits InitializeRequest, CallToolRequest, CreateMessageRequest, + ElicitRequest, CompleteRequest, GetPromptRequest, ReadResourceRequest, SubscribeRequest, UnsubscribeRequest, + PaginatedRequest, GetTaskRequest, GetTaskPayloadRequest, CancelTaskRequest { default Object progressToken() { if (meta() != null && meta().containsKey("progressToken")) { @@ -176,14 +191,39 @@ default Object progressToken() { } - public sealed interface Result extends Meta permits InitializeResult, ListResourcesResult, - ListResourceTemplatesResult, ReadResourceResult, ListPromptsResult, GetPromptResult, ListToolsResult, - CallToolResult, CreateMessageResult, ElicitResult, CompleteResult, ListRootsResult { + public sealed interface Result extends Meta + permits InitializeResult, ListResourcesResult, ListResourceTemplatesResult, ReadResourceResult, + ListPromptsResult, GetPromptResult, ListToolsResult, CompleteResult, ListRootsResult, GetTaskResult, + CancelTaskResult, ListTasksResult, CreateTaskResult, ServerTaskPayloadResult, ClientTaskPayloadResult { + + } + + /** + * Sealed interface for results that servers produce from task-augmented operations. + * When a client calls a server's tool in task mode, the server produces a + * {@link CallToolResult}. + * + *

+ * This interface provides type safety for server-side task result handling. + */ + public sealed interface ServerTaskPayloadResult extends Result permits CallToolResult { + + } + + /** + * Sealed interface for results that clients produce from task-augmented operations. + * When a server requests sampling or elicitation from a client in task mode, the + * client produces either a {@link CreateMessageResult} or {@link ElicitResult}. + * + *

+ * This interface provides type safety for client-side task result handling. + */ + public sealed interface ClientTaskPayloadResult extends Result permits CreateMessageResult, ElicitResult { } - public sealed interface Notification extends Meta - permits ProgressNotification, LoggingMessageNotification, ResourcesUpdatedNotification { + public sealed interface Notification extends Meta permits ProgressNotification, LoggingMessageNotification, + ResourcesUpdatedNotification, TaskStatusNotification { } @@ -385,7 +425,8 @@ public record ClientCapabilities( // @formatter:off @JsonProperty("experimental") Map experimental, @JsonProperty("roots") RootCapabilities roots, @JsonProperty("sampling") Sampling sampling, - @JsonProperty("elicitation") Elicitation elicitation) { // @formatter:on + @JsonProperty("elicitation") Elicitation elicitation, + @JsonProperty("tasks") ClientTaskCapabilities tasks) { // @formatter:on /** * Present if the client supports listing roots. @@ -459,6 +500,242 @@ public Elicitation() { } } + /** + * Present if the client supports task-augmented requests. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class ClientTaskCapabilities { + + private final ListTaskCapability list; + + private final CancelTaskCapability cancel; + + private final ClientTaskRequestCapabilities requests; + + @JsonCreator + private ClientTaskCapabilities(@JsonProperty("list") ListTaskCapability list, + @JsonProperty("cancel") CancelTaskCapability cancel, + @JsonProperty("requests") ClientTaskRequestCapabilities requests) { + this.list = list; + this.cancel = cancel; + this.requests = requests; + } + + /** + * Returns whether the client supports tasks/list requests. + * @return the list capability, or null + */ + @JsonProperty("list") + public ListTaskCapability list() { + return this.list; + } + + /** + * Returns whether the client supports tasks/cancel requests. + * @return the cancel capability, or null + */ + @JsonProperty("cancel") + public CancelTaskCapability cancel() { + return this.cancel; + } + + /** + * Returns which request types can be augmented with tasks. + * @return the request capabilities, or null + */ + @JsonProperty("requests") + public ClientTaskRequestCapabilities requests() { + return this.requests; + } + + /** + * Marker class indicating support for tasks/list. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class ListTaskCapability { + + @JsonCreator + ListTaskCapability() { + } + + } + + /** + * Marker class indicating support for tasks/cancel. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class CancelTaskCapability { + + @JsonCreator + CancelTaskCapability() { + } + + } + + /** + * Specifies which request types can be augmented with tasks. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class ClientTaskRequestCapabilities { + + private final SamplingTaskCapabilities sampling; + + private final ElicitationTaskCapabilities elicitation; + + @JsonCreator + ClientTaskRequestCapabilities(@JsonProperty("sampling") SamplingTaskCapabilities sampling, + @JsonProperty("elicitation") ElicitationTaskCapabilities elicitation) { + this.sampling = sampling; + this.elicitation = elicitation; + } + + /** + * Returns the task support for sampling-related requests. + * @return the sampling capabilities, or null + */ + @JsonProperty("sampling") + public SamplingTaskCapabilities sampling() { + return this.sampling; + } + + /** + * Returns the task support for elicitation-related requests. + * @return the elicitation capabilities, or null + */ + @JsonProperty("elicitation") + public ElicitationTaskCapabilities elicitation() { + return this.elicitation; + } + + /** + * Task support for sampling-related requests. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class SamplingTaskCapabilities { + + private final CreateMessageTaskCapability createMessage; + + @JsonCreator + SamplingTaskCapabilities(@JsonProperty("createMessage") CreateMessageTaskCapability createMessage) { + this.createMessage = createMessage; + } + + /** + * Returns whether the client supports task-augmented + * sampling/createMessage requests. + * @return the createMessage capability, or null + */ + @JsonProperty("createMessage") + public CreateMessageTaskCapability createMessage() { + return this.createMessage; + } + + /** + * Marker class indicating support for task-augmented + * sampling/createMessage. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class CreateMessageTaskCapability { + + @JsonCreator + CreateMessageTaskCapability() { + } + + } + + } + + /** + * Task support for elicitation-related requests. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class ElicitationTaskCapabilities { + + private final CreateTaskCapability create; + + @JsonCreator + ElicitationTaskCapabilities(@JsonProperty("create") CreateTaskCapability create) { + this.create = create; + } + + /** + * Returns whether the client supports task-augmented + * elicitation/create requests. + * @return the create capability, or null + */ + @JsonProperty("create") + public CreateTaskCapability create() { + return this.create; + } + + /** + * Marker class indicating support for task-augmented + * elicitation/create. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class CreateTaskCapability { + + @JsonCreator + CreateTaskCapability() { + } + + } + + } + + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ListTaskCapability list; + + private CancelTaskCapability cancel; + + private ClientTaskRequestCapabilities requests; + + public Builder list() { + this.list = new ListTaskCapability(); + return this; + } + + public Builder cancel() { + this.cancel = new CancelTaskCapability(); + return this; + } + + public Builder samplingCreateMessage() { + if (this.requests == null) { + this.requests = new ClientTaskRequestCapabilities(null, null); + } + this.requests = new ClientTaskRequestCapabilities( + new ClientTaskRequestCapabilities.SamplingTaskCapabilities( + new ClientTaskRequestCapabilities.SamplingTaskCapabilities.CreateMessageTaskCapability()), + this.requests.elicitation()); + return this; + } + + public Builder elicitationCreate() { + if (this.requests == null) { + this.requests = new ClientTaskRequestCapabilities(null, null); + } + this.requests = new ClientTaskRequestCapabilities(this.requests.sampling(), + new ClientTaskRequestCapabilities.ElicitationTaskCapabilities( + new ClientTaskRequestCapabilities.ElicitationTaskCapabilities.CreateTaskCapability())); + return this; + } + + public ClientTaskCapabilities build() { + return new ClientTaskCapabilities(list, cancel, requests); + } + + } + + } + public static Builder builder() { return new Builder(); } @@ -473,6 +750,8 @@ public static class Builder { private Elicitation elicitation; + private ClientTaskCapabilities tasks; + public Builder experimental(Map experimental) { this.experimental = experimental; return this; @@ -510,8 +789,18 @@ public Builder elicitation(boolean form, boolean url) { return this; } + /** + * Enables task capabilities with the provided configuration. + * @param tasks the task capabilities + * @return this builder + */ + public Builder tasks(ClientTaskCapabilities tasks) { + this.tasks = tasks; + return this; + } + public ClientCapabilities build() { - return new ClientCapabilities(experimental, roots, sampling, elicitation); + return new ClientCapabilities(experimental, roots, sampling, elicitation, tasks); } } @@ -539,7 +828,8 @@ public record ServerCapabilities( // @formatter:off @JsonProperty("logging") LoggingCapabilities logging, @JsonProperty("prompts") PromptCapabilities prompts, @JsonProperty("resources") ResourceCapabilities resources, - @JsonProperty("tools") ToolCapabilities tools) { // @formatter:on + @JsonProperty("tools") ToolCapabilities tools, + @JsonProperty("tasks") ServerTaskCapabilities tasks) { // @formatter:on /** * Present if the server supports argument autocompletion suggestions. @@ -587,6 +877,176 @@ public record ResourceCapabilities(@JsonProperty("subscribe") Boolean subscribe, public record ToolCapabilities(@JsonProperty("listChanged") Boolean listChanged) { } + /** + * Present if the server supports task-augmented requests. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class ServerTaskCapabilities { + + private final ListTaskCapability list; + + private final CancelTaskCapability cancel; + + private final ServerTaskRequestCapabilities requests; + + @JsonCreator + private ServerTaskCapabilities(@JsonProperty("list") ListTaskCapability list, + @JsonProperty("cancel") CancelTaskCapability cancel, + @JsonProperty("requests") ServerTaskRequestCapabilities requests) { + this.list = list; + this.cancel = cancel; + this.requests = requests; + } + + /** + * Returns whether the server supports tasks/list requests. + * @return the list capability, or null + */ + @JsonProperty("list") + public ListTaskCapability list() { + return this.list; + } + + /** + * Returns whether the server supports tasks/cancel requests. + * @return the cancel capability, or null + */ + @JsonProperty("cancel") + public CancelTaskCapability cancel() { + return this.cancel; + } + + /** + * Returns which request types can be augmented with tasks. + * @return the request capabilities, or null + */ + @JsonProperty("requests") + public ServerTaskRequestCapabilities requests() { + return this.requests; + } + + /** + * Marker class indicating support for tasks/list. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class ListTaskCapability { + + @JsonCreator + ListTaskCapability() { + } + + } + + /** + * Marker class indicating support for tasks/cancel. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class CancelTaskCapability { + + @JsonCreator + CancelTaskCapability() { + } + + } + + /** + * Specifies which request types can be augmented with tasks. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class ServerTaskRequestCapabilities { + + private final ToolsTaskCapabilities tools; + + @JsonCreator + ServerTaskRequestCapabilities(@JsonProperty("tools") ToolsTaskCapabilities tools) { + this.tools = tools; + } + + /** + * Returns the task support for tool-related requests. + * @return the tools capabilities, or null + */ + @JsonProperty("tools") + public ToolsTaskCapabilities tools() { + return this.tools; + } + + /** + * Task support for tool-related requests. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class ToolsTaskCapabilities { + + private final CallTaskCapability call; + + @JsonCreator + ToolsTaskCapabilities(@JsonProperty("call") CallTaskCapability call) { + this.call = call; + } + + /** + * Returns whether the server supports task-augmented tools/call + * requests. + * @return the call capability, or null + */ + @JsonProperty("call") + public CallTaskCapability call() { + return this.call; + } + + /** + * Marker class indicating support for task-augmented tools/call. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class CallTaskCapability { + + @JsonCreator + CallTaskCapability() { + } + + } + + } + + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ListTaskCapability list; + + private CancelTaskCapability cancel; + + private ServerTaskRequestCapabilities requests; + + public Builder list() { + this.list = new ListTaskCapability(); + return this; + } + + public Builder cancel() { + this.cancel = new CancelTaskCapability(); + return this; + } + + public Builder toolsCall() { + this.requests = new ServerTaskRequestCapabilities( + new ServerTaskRequestCapabilities.ToolsTaskCapabilities( + new ServerTaskRequestCapabilities.ToolsTaskCapabilities.CallTaskCapability())); + return this; + } + + public ServerTaskCapabilities build() { + return new ServerTaskCapabilities(list, cancel, requests); + } + + } + + } + /** * Create a mutated copy of this object with the specified changes. * @return A new Builder instance with the same values as this object. @@ -599,6 +1059,7 @@ public Builder mutate() { builder.prompts = this.prompts; builder.resources = this.resources; builder.tools = this.tools; + builder.tasks = this.tasks; return builder; } @@ -620,6 +1081,8 @@ public static class Builder { private ToolCapabilities tools; + private ServerTaskCapabilities tasks; + public Builder completions() { this.completions = new CompletionCapabilities(); return this; @@ -650,8 +1113,18 @@ public Builder tools(Boolean listChanged) { return this; } + /** + * Enables task capabilities with the provided configuration. + * @param tasks the task capabilities + * @return this builder + */ + public Builder tasks(ServerTaskCapabilities tasks) { + this.tasks = tasks; + return this; + } + public ServerCapabilities build() { - return new ServerCapabilities(completions, experimental, logging, prompts, resources, tools); + return new ServerCapabilities(completions, experimental, logging, prompts, resources, tools, tasks); } } @@ -686,66 +1159,529 @@ public enum Role { @JsonProperty("assistant") ASSISTANT } // @formatter:on - // --------------------------- - // Resource Interfaces - // --------------------------- /** - * Base for objects that include optional annotations for the client. The client can - * use annotations to inform how objects are used or displayed + * The status of a task. */ - public interface Annotated { + public enum TaskStatus { - Annotations annotations(); + // @formatter:off + /** + * The request is currently being processed. + */ + @JsonProperty("working") WORKING, + /** + * The task is waiting for input (e.g., elicitation or sampling). + */ + @JsonProperty("input_required") INPUT_REQUIRED, + /** + * The request completed successfully and results are available. + */ + @JsonProperty("completed") COMPLETED, + /** + * The associated request did not complete successfully. For tool calls specifically, + * this includes cases where the tool call result has isError set to true. + */ + @JsonProperty("failed") FAILED, + /** + * The request was cancelled before completion. + */ + @JsonProperty("cancelled") CANCELLED; + // @formatter:on + + /** + * Checks if this status represents a terminal state. + *

+ * Terminal states are those where the task has finished processing and will not + * change further: COMPLETED, FAILED, or CANCELLED. + * @return true if this status is a terminal state + */ + public boolean isTerminal() { + return this == COMPLETED || this == FAILED || this == CANCELLED; + } } /** - * Optional annotations for the client. The client can use annotations to inform how - * objects are used or displayed. + * Represents the state and metadata of an asynchronous operation tracked by the MCP + * task system. Tasks are created when a client requests task-augmented execution of a + * tool, sampling, or elicitation operation. * - * @param audience Describes who the intended customer of this object or data is. It - * can include multiple entries to indicate content useful for multiple audiences - * (e.g., `["user", "assistant"]`). - * @param priority Describes how important this data is for operating the server. A - * value of 1 means "most important," and indicates that the data is effectively - * required, while 0 means "least important," and indicates that the data is entirely - * optional. It is a number between 0 and 1. + *

+ * A task progresses through various states ({@link TaskStatus}) and can include + * optional metadata such as TTL, poll interval, and status messages. + * + *

+ * Use {@link #builder()} to create instances. + * + * @see TaskStatus + * @see Builder */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - public record Annotations( // @formatter:off - @JsonProperty("audience") List audience, - @JsonProperty("priority") Double priority, - @JsonProperty("lastModified") String lastModified - ) { // @formatter:on - - public Annotations(List audience, Double priority) { - this(audience, priority, null); + public static final class Task { + + private final String taskId; + + private final TaskStatus status; + + private final String statusMessage; + + private final String createdAt; + + private final String lastUpdatedAt; + + private final Long ttl; + + private final Long pollInterval; + + @JsonCreator + private Task( // @formatter:off + @JsonProperty("taskId") String taskId, + @JsonProperty("status") TaskStatus status, + @JsonProperty("statusMessage") @Nullable String statusMessage, + @JsonProperty("createdAt") String createdAt, + @JsonProperty("lastUpdatedAt") String lastUpdatedAt, + @JsonProperty("ttl") @Nullable Long ttl, + @JsonProperty("pollInterval") @Nullable Long pollInterval) { // @formatter:on + Assert.hasText(taskId, "taskId must not be empty"); + Assert.notNull(status, "status must not be null"); + Assert.hasText(createdAt, "createdAt must not be empty"); + Assert.hasText(lastUpdatedAt, "lastUpdatedAt must not be empty"); + // ttl and pollInterval can be null (unlimited/default) + this.taskId = taskId; + this.status = status; + this.statusMessage = statusMessage; + this.createdAt = createdAt; + this.lastUpdatedAt = lastUpdatedAt; + this.ttl = ttl; + this.pollInterval = pollInterval; } - } - /** - * A common interface for resource content, which includes metadata about the resource - * such as its URI, name, description, MIME type, size, and annotations. This - * interface is implemented by both {@link Resource} and {@link ResourceLink} to - * provide a consistent way to access resource metadata. - */ - public interface ResourceContent extends Identifier, Annotated, Meta { + /** + * Returns the task identifier. + * @return the task identifier + */ + @JsonProperty("taskId") + public String taskId() { + return this.taskId; + } - // name & title from Identifier + /** + * Returns the task status. + * @return the task status + */ + @JsonProperty("status") + public TaskStatus status() { + return this.status; + } - String uri(); + /** + * Returns the optional status message. + * @return the status message, or null + */ + @JsonProperty("statusMessage") + @Nullable + public String statusMessage() { + return this.statusMessage; + } - String description(); + /** + * Returns the creation timestamp. + * @return the ISO 8601 creation timestamp + */ + @JsonProperty("createdAt") + public String createdAt() { + return this.createdAt; + } - String mimeType(); + /** + * Returns the last updated timestamp. + * @return the ISO 8601 last updated timestamp + */ + @JsonProperty("lastUpdatedAt") + public String lastUpdatedAt() { + return this.lastUpdatedAt; + } - Long size(); + /** + * Returns the TTL (time-to-live) in milliseconds. + * @return the TTL, or null for unlimited + */ + @JsonProperty("ttl") + @Nullable + public Long ttl() { + return this.ttl; + } - // annotations from Annotated - // meta from Meta + /** + * Returns the suggested polling interval in milliseconds. + * @return the polling interval, or null + */ + @JsonProperty("pollInterval") + @Nullable + public Long pollInterval() { + return this.pollInterval; + } - } + /** + * Checks if the task is in a terminal state (completed, failed, or cancelled). + * @return true if the task is in a terminal state + */ + public boolean isTerminal() { + return status.isTerminal(); + } + + /** + * Creates a new builder for Task. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link Task}. + */ + public static class Builder { + + private String taskId; + + private TaskStatus status; + + private String statusMessage; + + private String createdAt; + + private String lastUpdatedAt; + + private Long ttl; + + private Long pollInterval; + + /** + * Sets the task identifier. + * @param taskId the task identifier + * @return this builder + */ + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + /** + * Sets the task status. + * @param status the task status + * @return this builder + */ + public Builder status(TaskStatus status) { + this.status = status; + return this; + } + + /** + * Sets the optional status message. + * @param statusMessage human-readable status message + * @return this builder + */ + public Builder statusMessage(String statusMessage) { + this.statusMessage = statusMessage; + return this; + } + + /** + * Sets the creation timestamp. + * @param createdAt ISO 8601 timestamp when the task was created + * @return this builder + */ + public Builder createdAt(String createdAt) { + this.createdAt = createdAt; + return this; + } + + /** + * Sets the last updated timestamp. + * @param lastUpdatedAt ISO 8601 timestamp when the task was last updated + * @return this builder + */ + public Builder lastUpdatedAt(String lastUpdatedAt) { + this.lastUpdatedAt = lastUpdatedAt; + return this; + } + + /** + * Sets both createdAt and lastUpdatedAt to the current time in ISO 8601 + * format. + * + *

+ * Note: Timestamps must be valid ISO 8601 format strings. + * This method uses {@code Instant.now().toString()} which produces compliant + * output. + * @return this builder + */ + public Builder timestamps() { + String now = java.time.Instant.now().toString(); + this.createdAt = now; + this.lastUpdatedAt = now; + return this; + } + + /** + * Sets the TTL (time-to-live) in milliseconds. + * @param ttl retention duration from creation in milliseconds, null for + * unlimited + * @return this builder + */ + public Builder ttl(Long ttl) { + this.ttl = ttl; + return this; + } + + /** + * Sets the suggested polling interval in milliseconds. + * @param pollInterval polling interval in milliseconds + * @return this builder + */ + public Builder pollInterval(Long pollInterval) { + this.pollInterval = pollInterval; + return this; + } + + /** + * Builds a new {@link Task} instance. + * @return a new Task instance + */ + public Task build() { + return new Task(taskId, status, statusMessage, createdAt, lastUpdatedAt, ttl, pollInterval); + } + + } + + } + + /** + * Metadata for augmenting a request with task execution. Include this in the + * {@code task} field of the request parameters to indicate that the operation should + * be executed as a background task rather than synchronously. + * + *

+ * When present, the server creates a task and returns immediately with task + * information, allowing the client to poll for status and retrieve results later. + * + *

+ * Use {@link #builder()} to create instances. + * + * @see Task + * @see Builder + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class TaskMetadata { + + private final Long ttl; + + @JsonCreator + private TaskMetadata(@JsonProperty("ttl") Long ttl) { + if (ttl != null && ttl < 0) { + throw new IllegalArgumentException("ttl must not be negative"); + } + this.ttl = ttl; + } + + /** + * Returns the TTL (time-to-live) in milliseconds. + * @return the TTL, or null for no specific retention + */ + @JsonProperty("ttl") + public Long ttl() { + return this.ttl; + } + + /** + * Returns the TTL as a Duration, or null if not set. + * @return the TTL duration, or null + */ + @JsonIgnore + public Duration ttlAsDuration() { + return ttl != null ? Duration.ofMillis(ttl) : null; + } + + /** + * Creates a new builder for TaskMetadata. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating TaskMetadata instances with Duration-based TTL. + */ + public static class Builder { + + private Duration ttl; + + /** + * Sets the TTL (time-to-live) for the task. + * @param ttl the duration to retain the task, converted to milliseconds + * @return this builder + */ + public Builder ttl(Duration ttl) { + this.ttl = ttl; + return this; + } + + /** + * Builds the TaskMetadata instance. + * @return a new TaskMetadata + * @throws IllegalArgumentException if TTL is negative + */ + public TaskMetadata build() { + Long ttlMs = this.ttl != null ? this.ttl.toMillis() : null; + return new TaskMetadata(ttlMs); + } + + } + + } + + /** + * The well-known key for related task metadata in the _meta field. + */ + public static final String RELATED_TASK_META_KEY = "io.modelcontextprotocol/related-task"; + + /** + * Metadata for associating messages with a task. Include this in the {@code _meta} + * field under the key {@link #RELATED_TASK_META_KEY} to indicate that a notification + * or other message is related to a specific task. + * + *

+ * This enables correlation of progress notifications, logging messages, and other + * communications with their originating task context. + * + *

+ * Use {@link #builder()} to create instances. + * + * @see Task + * @see Builder + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class RelatedTaskMetadata { + + private final String taskId; + + @JsonCreator + private RelatedTaskMetadata(@JsonProperty("taskId") String taskId) { + Assert.hasText(taskId, "taskId must not be empty"); + this.taskId = taskId; + } + + /** + * Returns the task identifier. + * @return the task identifier + */ + @JsonProperty("taskId") + public String taskId() { + return this.taskId; + } + + /** + * Creates a new builder for RelatedTaskMetadata. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link RelatedTaskMetadata}. + */ + public static class Builder { + + private String taskId; + + /** + * Sets the task identifier. + * @param taskId the task identifier + * @return this builder + */ + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + /** + * Builds a new {@link RelatedTaskMetadata} instance. + * @return a new RelatedTaskMetadata instance + */ + public RelatedTaskMetadata build() { + return new RelatedTaskMetadata(taskId); + } + + } + + } + + // --------------------------- + // Resource Interfaces + // --------------------------- + /** + * Base for objects that include optional annotations for the client. The client can + * use annotations to inform how objects are used or displayed + */ + public interface Annotated { + + Annotations annotations(); + + } + + /** + * Optional annotations for the client. The client can use annotations to inform how + * objects are used or displayed. + * + * @param audience Describes who the intended customer of this object or data is. It + * can include multiple entries to indicate content useful for multiple audiences + * (e.g., `["user", "assistant"]`). + * @param priority Describes how important this data is for operating the server. A + * value of 1 means "most important," and indicates that the data is effectively + * required, while 0 means "least important," and indicates that the data is entirely + * optional. It is a number between 0 and 1. + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Annotations( // @formatter:off + @JsonProperty("audience") List audience, + @JsonProperty("priority") Double priority, + @JsonProperty("lastModified") String lastModified + ) { // @formatter:on + + public Annotations(List audience, Double priority) { + this(audience, priority, null); + } + } + + /** + * A common interface for resource content, which includes metadata about the resource + * such as its URI, name, description, MIME type, size, and annotations. This + * interface is implemented by both {@link Resource} and {@link ResourceLink} to + * provide a consistent way to access resource metadata. + */ + public interface ResourceContent extends Identifier, Annotated, Meta { + + // name & title from Identifier + + String uri(); + + String description(); + + String mimeType(); + + Long size(); + + // annotations from Annotated + // meta from Meta + + } /** * Base interface with name (identifier) and title (display name) properties. @@ -1373,30 +2309,123 @@ public record ToolAnnotations( // @formatter:off } /** - * Represents a tool that the server provides. Tools enable servers to expose - * executable functionality to the system. Through these tools, you can interact with - * external systems, perform computations, and take actions in the real world. + * Indicates whether a tool supports task-augmented execution. + */ + public enum TaskSupportMode { + + // @formatter:off + /** + * Tool does not support task-augmented execution. This is the default when absent. + */ + @JsonProperty("forbidden") FORBIDDEN, + /** + * Tool may support task-augmented execution. + */ + @JsonProperty("optional") OPTIONAL, + /** + * Tool requires task-augmented execution. + */ + @JsonProperty("required") REQUIRED + // @formatter:on + + } + + /** + * Execution-related properties for a tool. * - * @param name A unique identifier for the tool. This name is used when calling the - * tool. - * @param title A human-readable title for the tool. - * @param description A human-readable description of what the tool does. This can be - * used by clients to improve the LLM's understanding of available tools. - * @param inputSchema A JSON Schema object that describes the expected structure of - * the arguments when calling this tool. This allows clients to validate tool - * @param outputSchema An optional JSON Schema object defining the structure of the - * tool's output returned in the structuredContent field of a CallToolResult. - * @param annotations Optional additional tool information. - * @param meta See specification for notes on _meta usage + *

+ * Use {@link #builder()} to create instances. + * + * @see Builder */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - public record Tool( // @formatter:off - @JsonProperty("name") String name, - @JsonProperty("title") String title, - @JsonProperty("description") String description, - @JsonProperty("inputSchema") JsonSchema inputSchema, + public static final class ToolExecution { + + private final TaskSupportMode taskSupport; + + @JsonCreator + private ToolExecution(@JsonProperty("taskSupport") TaskSupportMode taskSupport) { + this.taskSupport = taskSupport; + } + + /** + * Returns the task support mode for this tool. + * @return the task support mode + */ + public TaskSupportMode taskSupport() { + return this.taskSupport; + } + + /** + * Creates a new builder for ToolExecution. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link ToolExecution}. + */ + public static class Builder { + + private TaskSupportMode taskSupport; + + /** + * Sets the task support mode. Indicates whether this tool supports + * task-augmented execution. This allows clients to handle long-running + * operations through polling the task system. Default is + * {@link TaskSupportMode#FORBIDDEN} when absent. + * @param taskSupport the task support mode + * @return this builder + */ + public Builder taskSupport(TaskSupportMode taskSupport) { + this.taskSupport = taskSupport; + return this; + } + + /** + * Builds a new {@link ToolExecution} instance. + * @return a new ToolExecution instance + */ + public ToolExecution build() { + return new ToolExecution(taskSupport); + } + + } + + } + + /** + * Represents a tool that the server provides. Tools enable servers to expose + * executable functionality to the system. Through these tools, you can interact with + * external systems, perform computations, and take actions in the real world. + * + * @param name A unique identifier for the tool. This name is used when calling the + * tool. + * @param title A human-readable title for the tool. + * @param description A human-readable description of what the tool does. This can be + * used by clients to improve the LLM's understanding of available tools. + * @param inputSchema A JSON Schema object that describes the expected structure of + * the arguments when calling this tool. This allows clients to validate tool + * arguments before sending them to the server. + * @param outputSchema An optional JSON Schema object defining the structure of the + * tool's output returned in the structuredContent field of a CallToolResult. + * @param execution Execution-related properties for the tool, including task support + * mode which indicates whether this tool supports task-augmented execution. + * @param annotations Optional additional tool information. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record Tool( // @formatter:off + @JsonProperty("name") String name, + @JsonProperty("title") String title, + @JsonProperty("description") String description, + @JsonProperty("inputSchema") JsonSchema inputSchema, @JsonProperty("outputSchema") Map outputSchema, + @JsonProperty("execution") ToolExecution execution, @JsonProperty("annotations") ToolAnnotations annotations, @JsonProperty("_meta") Map meta) { // @formatter:on @@ -1416,6 +2445,8 @@ public static class Builder { private Map outputSchema; + private ToolExecution execution; + private ToolAnnotations annotations; private Map meta; @@ -1455,6 +2486,11 @@ public Builder outputSchema(McpJsonMapper jsonMapper, String outputSchema) { return this; } + public Builder execution(ToolExecution execution) { + this.execution = execution; + return this; + } + public Builder annotations(ToolAnnotations annotations) { this.annotations = annotations; return this; @@ -1467,7 +2503,7 @@ public Builder meta(Map meta) { public Tool build() { Assert.hasText(name, "name must not be empty"); - return new Tool(name, title, description, inputSchema, outputSchema, annotations, meta); + return new Tool(name, title, description, inputSchema, outputSchema, execution, annotations, meta); } } @@ -1498,6 +2534,9 @@ private static JsonSchema parseSchema(McpJsonMapper jsonMapper, String schema) { * tools/list. * @param arguments Arguments to pass to the tool. These must conform to the tool's * input schema. + * @param task If specified, the caller is requesting task-augmented execution for + * this request. The request will return a CreateTaskResult immediately, and the + * actual result can be retrieved later via tasks/result. * @param meta Optional metadata about the request. This can include additional * information like `progressToken` */ @@ -1506,14 +2545,19 @@ private static JsonSchema parseSchema(McpJsonMapper jsonMapper, String schema) { public record CallToolRequest( // @formatter:off @JsonProperty("name") String name, @JsonProperty("arguments") Map arguments, + @JsonProperty("task") TaskMetadata task, @JsonProperty("_meta") Map meta) implements Request { // @formatter:on public CallToolRequest(McpJsonMapper jsonMapper, String name, String jsonArguments) { - this(name, parseJsonArguments(jsonMapper, jsonArguments), null); + this(name, parseJsonArguments(jsonMapper, jsonArguments), null, null); } public CallToolRequest(String name, Map arguments) { - this(name, arguments, null); + this(name, arguments, null, null); + } + + public CallToolRequest(String name, Map arguments, Map meta) { + this(name, arguments, null, meta); } private static Map parseJsonArguments(McpJsonMapper jsonMapper, String jsonArguments) { @@ -1535,6 +2579,8 @@ public static class Builder { private Map arguments; + private TaskMetadata task; + private Map meta; public Builder name(String name) { @@ -1552,6 +2598,26 @@ public Builder arguments(McpJsonMapper jsonMapper, String jsonArguments) { return this; } + /** + * Sets task metadata for task-augmented execution. + * @param task the task metadata + * @return this builder + */ + public Builder task(TaskMetadata task) { + this.task = task; + return this; + } + + /** + * Sets task metadata for task-augmented execution with the specified TTL. + * @param ttl requested duration in milliseconds to retain task from creation + * @return this builder + */ + public Builder task(Long ttl) { + this.task = new TaskMetadata(ttl); + return this; + } + public Builder meta(Map meta) { this.meta = meta; return this; @@ -1567,7 +2633,7 @@ public Builder progressToken(Object progressToken) { public CallToolRequest build() { Assert.hasText(name, "name must not be empty"); - return new CallToolRequest(name, arguments, meta); + return new CallToolRequest(name, arguments, task, meta); } } @@ -1590,7 +2656,7 @@ public record CallToolResult( // @formatter:off @JsonProperty("content") List content, @JsonProperty("isError") Boolean isError, @JsonProperty("structuredContent") Object structuredContent, - @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + @JsonProperty("_meta") Map meta) implements ServerTaskPayloadResult { // @formatter:on /** * @deprecated use the builder instead. @@ -1830,9 +2896,6 @@ public ModelPreferences build() { @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record ModelHint(@JsonProperty("name") String name) { - public static ModelHint of(String name) { - return new ModelHint(name); - } } /** @@ -1868,6 +2931,9 @@ public record SamplingMessage( // @formatter:off * @param stopSequences Optional stop sequences for sampling * @param metadata Optional metadata to pass through to the LLM provider. The format * of this metadata is provider-specific + * @param task If specified, the caller is requesting task-augmented execution for + * this request. The request will return a CreateTaskResult immediately, and the + * actual result can be retrieved later via tasks/result. * @param meta See specification for notes on _meta usage */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -1881,6 +2947,7 @@ public record CreateMessageRequest( // @formatter:off @JsonProperty("maxTokens") Integer maxTokens, @JsonProperty("stopSequences") List stopSequences, @JsonProperty("metadata") Map metadata, + @JsonProperty("task") TaskMetadata task, @JsonProperty("_meta") Map meta) implements Request { // @formatter:on // backwards compatibility constructor @@ -1888,7 +2955,15 @@ public CreateMessageRequest(List messages, ModelPreferences mod String systemPrompt, ContextInclusionStrategy includeContext, Double temperature, Integer maxTokens, List stopSequences, Map metadata) { this(messages, modelPreferences, systemPrompt, includeContext, temperature, maxTokens, stopSequences, - metadata, null); + metadata, null, null); + } + + // backwards compatibility constructor with _meta + public CreateMessageRequest(List messages, ModelPreferences modelPreferences, + String systemPrompt, ContextInclusionStrategy includeContext, Double temperature, Integer maxTokens, + List stopSequences, Map metadata, Map meta) { + this(messages, modelPreferences, systemPrompt, includeContext, temperature, maxTokens, stopSequences, + metadata, null, meta); } public enum ContextInclusionStrategy { @@ -1921,6 +2996,8 @@ public static class Builder { private Map metadata; + private TaskMetadata task; + private Map meta; public Builder messages(List messages) { @@ -1963,6 +3040,26 @@ public Builder metadata(Map metadata) { return this; } + /** + * Sets task metadata for task-augmented execution. + * @param task the task metadata + * @return this builder + */ + public Builder task(TaskMetadata task) { + this.task = task; + return this; + } + + /** + * Sets task metadata for task-augmented execution with the specified TTL. + * @param ttl requested duration in milliseconds to retain task from creation + * @return this builder + */ + public Builder task(Long ttl) { + this.task = new TaskMetadata(ttl); + return this; + } + public Builder meta(Map meta) { this.meta = meta; return this; @@ -1978,7 +3075,7 @@ public Builder progressToken(Object progressToken) { public CreateMessageRequest build() { return new CreateMessageRequest(messages, modelPreferences, systemPrompt, includeContext, temperature, - maxTokens, stopSequences, metadata, meta); + maxTokens, stopSequences, metadata, task, meta); } } @@ -2003,7 +3100,7 @@ public record CreateMessageResult( // @formatter:off @JsonProperty("content") Content content, @JsonProperty("model") String model, @JsonProperty("stopReason") StopReason stopReason, - @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + @JsonProperty("_meta") Map meta) implements ClientTaskPayloadResult { // @formatter:on public enum StopReason { @@ -2095,6 +3192,9 @@ public CreateMessageResult build() { * @param message The message to present to the user * @param requestedSchema A restricted subset of JSON Schema. Only top-level * properties are allowed, without nesting + * @param task If specified, the caller is requesting task-augmented execution for + * this request. The request will return a CreateTaskResult immediately, and the + * actual result can be retrieved later via tasks/result. * @param meta See specification for notes on _meta usage */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -2102,11 +3202,17 @@ public CreateMessageResult build() { public record ElicitRequest( // @formatter:off @JsonProperty("message") String message, @JsonProperty("requestedSchema") Map requestedSchema, + @JsonProperty("task") TaskMetadata task, @JsonProperty("_meta") Map meta) implements Request { // @formatter:on // backwards compatibility constructor public ElicitRequest(String message, Map requestedSchema) { - this(message, requestedSchema, null); + this(message, requestedSchema, null, null); + } + + // backwards compatibility constructor with _meta + public ElicitRequest(String message, Map requestedSchema, Map meta) { + this(message, requestedSchema, null, meta); } public static Builder builder() { @@ -2119,6 +3225,8 @@ public static class Builder { private Map requestedSchema; + private TaskMetadata task; + private Map meta; public Builder message(String message) { @@ -2131,6 +3239,26 @@ public Builder requestedSchema(Map requestedSchema) { return this; } + /** + * Sets task metadata for task-augmented execution. + * @param task the task metadata + * @return this builder + */ + public Builder task(TaskMetadata task) { + this.task = task; + return this; + } + + /** + * Sets task metadata for task-augmented execution with the specified TTL. + * @param ttl requested duration in milliseconds to retain task from creation + * @return this builder + */ + public Builder task(Long ttl) { + this.task = new TaskMetadata(ttl); + return this; + } + public Builder meta(Map meta) { this.meta = meta; return this; @@ -2145,7 +3273,7 @@ public Builder progressToken(Object progressToken) { } public ElicitRequest build() { - return new ElicitRequest(message, requestedSchema, meta); + return new ElicitRequest(message, requestedSchema, task, meta); } } @@ -2166,7 +3294,7 @@ public ElicitRequest build() { public record ElicitResult( // @formatter:off @JsonProperty("action") Action action, @JsonProperty("content") Map content, - @JsonProperty("_meta") Map meta) implements Result { // @formatter:on + @JsonProperty("_meta") Map meta) implements ClientTaskPayloadResult { // @formatter:on public enum Action { @@ -2928,4 +4056,1423 @@ public ListRootsResult(List roots, String nextCursor) { } } + // --------------------------- + // Tasks + // --------------------------- + + /* + * Note on meta fields in task types: + * + * All task-related types (GetTaskRequest, GetTaskResult, CancelTaskRequest, + * CancelTaskResult, CreateTaskResult, etc.) include optional "_meta" fields that may + * be null. This is intentional - the MCP specification defines these as optional + * extension points for protocol-level metadata. Callers should always check for null + * before accessing meta fields. When not using metadata extensions, simply pass null. + */ + + /** + * A request to retrieve the state of a task. + * + *

+ * Use {@link #builder()} to create instances. + * + * @see Builder + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class GetTaskRequest implements Request { + + private final String taskId; + + private final Map meta; + + @JsonCreator + private GetTaskRequest( // @formatter:off + @JsonProperty("taskId") String taskId, + @JsonProperty("_meta") Map meta) { // @formatter:on + Assert.hasText(taskId, "taskId must not be empty"); + this.taskId = taskId; + this.meta = meta; + } + + /** + * Returns the task identifier. + * @return the task identifier + */ + @JsonProperty("taskId") + public String taskId() { + return this.taskId; + } + + /** + * Returns the metadata. + * @return the metadata map, or null + */ + @JsonProperty("_meta") + public Map meta() { + return this.meta; + } + + /** + * Creates a new builder for GetTaskRequest. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link GetTaskRequest}. + */ + public static class Builder { + + private String taskId; + + private Map meta; + + /** + * Sets the task identifier. + * @param taskId the task identifier + * @return this builder + */ + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + /** + * Sets the metadata. + * @param meta the metadata map + * @return this builder + */ + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + /** + * Builds a new {@link GetTaskRequest} instance. + * @return a new GetTaskRequest instance + */ + public GetTaskRequest build() { + return new GetTaskRequest(taskId, meta); + } + + } + + } + + /** + * The response to a tasks/get request. Contains all Task fields plus Result metadata. + * + *

+ * Design Note: This type is structurally identical to + * {@link CancelTaskResult} but kept as a separate type for compile-time type safety. + * This ensures that code expecting a {@code GetTaskResult} cannot accidentally + * receive a {@code CancelTaskResult} and vice versa, making API boundaries explicit + * in method signatures. + * + *

+ * Use {@link #builder()} to create instances. + * + * @see Builder + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class GetTaskResult implements Result { + + private final String taskId; + + private final TaskStatus status; + + private final String statusMessage; + + private final String createdAt; + + private final String lastUpdatedAt; + + private final Long ttl; + + private final Long pollInterval; + + private final Map meta; + + @JsonCreator + private GetTaskResult( // @formatter:off + @JsonProperty("taskId") String taskId, + @JsonProperty("status") TaskStatus status, + @JsonProperty("statusMessage") @Nullable String statusMessage, + @JsonProperty("createdAt") String createdAt, + @JsonProperty("lastUpdatedAt") String lastUpdatedAt, + @JsonProperty("ttl") @Nullable Long ttl, + @JsonProperty("pollInterval") @Nullable Long pollInterval, + @JsonProperty("_meta") Map meta) { // @formatter:on + Assert.hasText(taskId, "taskId must not be empty"); + Assert.notNull(status, "status must not be null"); + Assert.hasText(createdAt, "createdAt must not be empty"); + Assert.hasText(lastUpdatedAt, "lastUpdatedAt must not be empty"); + this.taskId = taskId; + this.status = status; + this.statusMessage = statusMessage; + this.createdAt = createdAt; + this.lastUpdatedAt = lastUpdatedAt; + this.ttl = ttl; + this.pollInterval = pollInterval; + this.meta = meta; + } + + /** + * Returns the task identifier. + * @return the task identifier + */ + @JsonProperty("taskId") + public String taskId() { + return this.taskId; + } + + /** + * Returns the task status. + * @return the task status + */ + @JsonProperty("status") + public TaskStatus status() { + return this.status; + } + + /** + * Returns the optional status message. + * @return the status message, or null + */ + @JsonProperty("statusMessage") + @Nullable + public String statusMessage() { + return this.statusMessage; + } + + /** + * Returns the creation timestamp. + * @return the ISO 8601 creation timestamp + */ + @JsonProperty("createdAt") + public String createdAt() { + return this.createdAt; + } + + /** + * Returns the last updated timestamp. + * @return the ISO 8601 last updated timestamp + */ + @JsonProperty("lastUpdatedAt") + public String lastUpdatedAt() { + return this.lastUpdatedAt; + } + + /** + * Returns the TTL (time-to-live) in milliseconds. + * @return the TTL, or null for unlimited + */ + @JsonProperty("ttl") + @Nullable + public Long ttl() { + return this.ttl; + } + + /** + * Returns the suggested polling interval in milliseconds. + * @return the polling interval, or null + */ + @JsonProperty("pollInterval") + @Nullable + public Long pollInterval() { + return this.pollInterval; + } + + /** + * Returns the metadata. + * @return the metadata map, or null + */ + @JsonProperty("_meta") + public Map meta() { + return this.meta; + } + + /** + * Creates a GetTaskResult from a Task. + * @param task the task to convert + * @return a new GetTaskResult + */ + public static GetTaskResult fromTask(Task task) { + return new GetTaskResult(task.taskId(), task.status(), task.statusMessage(), task.createdAt(), + task.lastUpdatedAt(), task.ttl(), task.pollInterval(), null); + } + + /** + * Converts this result to a Task. + * @return a Task with the same field values + */ + public Task toTask() { + return Task.builder() + .taskId(taskId) + .status(status) + .statusMessage(statusMessage) + .createdAt(createdAt) + .lastUpdatedAt(lastUpdatedAt) + .ttl(ttl) + .pollInterval(pollInterval) + .build(); + } + + /** + * Checks if the task is in a terminal state (completed, failed, or cancelled). + * @return true if the task is in a terminal state + */ + public boolean isTerminal() { + return status.isTerminal(); + } + + /** + * Creates a new builder for GetTaskResult. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link GetTaskResult}. + */ + public static class Builder { + + private String taskId; + + private TaskStatus status; + + private String statusMessage; + + private String createdAt; + + private String lastUpdatedAt; + + private Long ttl; + + private Long pollInterval; + + private Map meta; + + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + public Builder status(TaskStatus status) { + this.status = status; + return this; + } + + public Builder statusMessage(String statusMessage) { + this.statusMessage = statusMessage; + return this; + } + + public Builder createdAt(String createdAt) { + this.createdAt = createdAt; + return this; + } + + public Builder lastUpdatedAt(String lastUpdatedAt) { + this.lastUpdatedAt = lastUpdatedAt; + return this; + } + + public Builder ttl(Long ttl) { + this.ttl = ttl; + return this; + } + + public Builder pollInterval(Long pollInterval) { + this.pollInterval = pollInterval; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public GetTaskResult build() { + return new GetTaskResult(taskId, status, statusMessage, createdAt, lastUpdatedAt, ttl, pollInterval, + meta); + } + + } + + } + + /** + * A request to retrieve the result payload of a completed task. + * + *

+ * This corresponds to the {@code tasks/result} method in the MCP protocol. The name + * "Payload" distinguishes the actual result data (e.g., {@link CallToolResult}, + * {@link CreateMessageResult}) from the task status information returned by + * {@link GetTaskResult}. + * + *

+ * The response type depends on what created the task: + *

    + *
  • Tool calls: {@link CallToolResult}
  • + *
  • Sampling requests: {@link CreateMessageResult}
  • + *
  • Elicitation requests: {@link ElicitResult}
  • + *
+ * + *

+ * Use {@link #builder()} to create instances. + * + * @see Builder + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class GetTaskPayloadRequest implements Request { + + private final String taskId; + + private final Map meta; + + @JsonCreator + private GetTaskPayloadRequest( // @formatter:off + @JsonProperty("taskId") String taskId, + @JsonProperty("_meta") Map meta) { // @formatter:on + Assert.hasText(taskId, "taskId must not be empty"); + this.taskId = taskId; + this.meta = meta; + } + + /** + * Returns the task identifier. + * @return the task identifier + */ + @JsonProperty("taskId") + public String taskId() { + return this.taskId; + } + + /** + * Returns the metadata. + * @return the metadata map, or null + */ + @JsonProperty("_meta") + public Map meta() { + return this.meta; + } + + /** + * Creates a new builder for GetTaskPayloadRequest. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link GetTaskPayloadRequest}. + */ + public static class Builder { + + private String taskId; + + private Map meta; + + /** + * Sets the task identifier. + * @param taskId the task identifier + * @return this builder + */ + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + /** + * Sets the metadata. + * @param meta the metadata map + * @return this builder + */ + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + /** + * Builds a new {@link GetTaskPayloadRequest} instance. + * @return a new GetTaskPayloadRequest instance + */ + public GetTaskPayloadRequest build() { + return new GetTaskPayloadRequest(taskId, meta); + } + + } + + } + + /** + * A request to cancel a task. + * + *

+ * Use {@link #builder()} to create instances. + * + * @see Builder + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class CancelTaskRequest implements Request { + + private final String taskId; + + private final Map meta; + + @JsonCreator + private CancelTaskRequest( // @formatter:off + @JsonProperty("taskId") String taskId, + @JsonProperty("_meta") Map meta) { // @formatter:on + Assert.hasText(taskId, "taskId must not be empty"); + this.taskId = taskId; + this.meta = meta; + } + + /** + * Returns the task identifier. + * @return the task identifier + */ + @JsonProperty("taskId") + public String taskId() { + return this.taskId; + } + + /** + * Returns the metadata. + * @return the metadata map, or null + */ + @JsonProperty("_meta") + public Map meta() { + return this.meta; + } + + /** + * Creates a new builder for CancelTaskRequest. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link CancelTaskRequest}. + */ + public static class Builder { + + private String taskId; + + private Map meta; + + /** + * Sets the task identifier. + * @param taskId the task identifier + * @return this builder + */ + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + /** + * Sets the metadata. + * @param meta the metadata map + * @return this builder + */ + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + /** + * Builds a new {@link CancelTaskRequest} instance. + * @return a new CancelTaskRequest instance + */ + public CancelTaskRequest build() { + return new CancelTaskRequest(taskId, meta); + } + + } + + } + + /** + * The response to a tasks/cancel request. Contains all Task fields plus Result + * metadata. + * + *

+ * Design Note: This type is structurally identical to + * {@link GetTaskResult} but kept as a separate type for compile-time type safety. + * This ensures that code expecting a {@code CancelTaskResult} cannot accidentally + * receive a {@code GetTaskResult} and vice versa, making API boundaries explicit in + * method signatures. + * + *

+ * Use {@link #builder()} to create instances. + * + * @see Builder + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class CancelTaskResult implements Result { + + private final String taskId; + + private final TaskStatus status; + + private final String statusMessage; + + private final String createdAt; + + private final String lastUpdatedAt; + + private final Long ttl; + + private final Long pollInterval; + + private final Map meta; + + @JsonCreator + private CancelTaskResult( // @formatter:off + @JsonProperty("taskId") String taskId, + @JsonProperty("status") TaskStatus status, + @JsonProperty("statusMessage") @Nullable String statusMessage, + @JsonProperty("createdAt") String createdAt, + @JsonProperty("lastUpdatedAt") String lastUpdatedAt, + @JsonProperty("ttl") @Nullable Long ttl, + @JsonProperty("pollInterval") @Nullable Long pollInterval, + @JsonProperty("_meta") Map meta) { // @formatter:on + Assert.hasText(taskId, "taskId must not be empty"); + Assert.notNull(status, "status must not be null"); + Assert.hasText(createdAt, "createdAt must not be empty"); + Assert.hasText(lastUpdatedAt, "lastUpdatedAt must not be empty"); + this.taskId = taskId; + this.status = status; + this.statusMessage = statusMessage; + this.createdAt = createdAt; + this.lastUpdatedAt = lastUpdatedAt; + this.ttl = ttl; + this.pollInterval = pollInterval; + this.meta = meta; + } + + /** + * Returns the task identifier. + * @return the task identifier + */ + @JsonProperty("taskId") + public String taskId() { + return this.taskId; + } + + /** + * Returns the task status. + * @return the task status + */ + @JsonProperty("status") + public TaskStatus status() { + return this.status; + } + + /** + * Returns the optional status message. + * @return the status message, or null + */ + @JsonProperty("statusMessage") + @Nullable + public String statusMessage() { + return this.statusMessage; + } + + /** + * Returns the creation timestamp. + * @return the ISO 8601 creation timestamp + */ + @JsonProperty("createdAt") + public String createdAt() { + return this.createdAt; + } + + /** + * Returns the last updated timestamp. + * @return the ISO 8601 last updated timestamp + */ + @JsonProperty("lastUpdatedAt") + public String lastUpdatedAt() { + return this.lastUpdatedAt; + } + + /** + * Returns the TTL (time-to-live) in milliseconds. + * @return the TTL, or null for unlimited + */ + @JsonProperty("ttl") + @Nullable + public Long ttl() { + return this.ttl; + } + + /** + * Returns the suggested polling interval in milliseconds. + * @return the polling interval, or null + */ + @JsonProperty("pollInterval") + @Nullable + public Long pollInterval() { + return this.pollInterval; + } + + /** + * Returns the metadata. + * @return the metadata map, or null + */ + @JsonProperty("_meta") + public Map meta() { + return this.meta; + } + + /** + * Creates a CancelTaskResult from a Task. + * @param task the task to convert + * @return a new CancelTaskResult + */ + public static CancelTaskResult fromTask(Task task) { + return new CancelTaskResult(task.taskId(), task.status(), task.statusMessage(), task.createdAt(), + task.lastUpdatedAt(), task.ttl(), task.pollInterval(), null); + } + + /** + * Checks if the task is in a terminal state (completed, failed, or cancelled). + * @return true if the task is in a terminal state + */ + public boolean isTerminal() { + return status.isTerminal(); + } + + /** + * Creates a new builder for CancelTaskResult. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link CancelTaskResult}. + */ + public static class Builder { + + private String taskId; + + private TaskStatus status; + + private String statusMessage; + + private String createdAt; + + private String lastUpdatedAt; + + private Long ttl; + + private Long pollInterval; + + private Map meta; + + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + public Builder status(TaskStatus status) { + this.status = status; + return this; + } + + public Builder statusMessage(String statusMessage) { + this.statusMessage = statusMessage; + return this; + } + + public Builder createdAt(String createdAt) { + this.createdAt = createdAt; + return this; + } + + public Builder lastUpdatedAt(String lastUpdatedAt) { + this.lastUpdatedAt = lastUpdatedAt; + return this; + } + + public Builder ttl(Long ttl) { + this.ttl = ttl; + return this; + } + + public Builder pollInterval(Long pollInterval) { + this.pollInterval = pollInterval; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public CancelTaskResult build() { + return new CancelTaskResult(taskId, status, statusMessage, createdAt, lastUpdatedAt, ttl, pollInterval, + meta); + } + + } + + } + + /** + * The response to a tasks/list request. + * + *

+ * Use {@link #builder()} to create instances. + * + * @see Builder + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class ListTasksResult implements Result { + + private final List tasks; + + private final String nextCursor; + + private final Map meta; + + @JsonCreator + private ListTasksResult( // @formatter:off + @JsonProperty("tasks") List tasks, + @JsonProperty("nextCursor") @Nullable String nextCursor, + @JsonProperty("_meta") Map meta) { // @formatter:on + this.tasks = tasks != null ? tasks : List.of(); + this.nextCursor = nextCursor; + this.meta = meta; + } + + /** + * Returns the list of tasks. + * @return the tasks list, never null + */ + @JsonProperty("tasks") + public List tasks() { + return this.tasks; + } + + /** + * Returns the next cursor for pagination. + * @return the next cursor, or null if no more results + */ + @JsonProperty("nextCursor") + @Nullable + public String nextCursor() { + return this.nextCursor; + } + + /** + * Returns the metadata. + * @return the metadata map, or null + */ + @JsonProperty("_meta") + public Map meta() { + return this.meta; + } + + /** + * Creates a new builder for ListTasksResult. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link ListTasksResult}. + */ + public static class Builder { + + private List tasks; + + private String nextCursor; + + private Map meta; + + /** + * Sets the list of tasks. + * @param tasks the tasks list + * @return this builder + */ + public Builder tasks(List tasks) { + this.tasks = tasks; + return this; + } + + /** + * Sets the next cursor for pagination. + * @param nextCursor the next cursor + * @return this builder + */ + public Builder nextCursor(String nextCursor) { + this.nextCursor = nextCursor; + return this; + } + + /** + * Sets the metadata. + * @param meta the metadata map + * @return this builder + */ + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + /** + * Builds a new {@link ListTasksResult} instance. + * @return a new ListTasksResult instance + */ + public ListTasksResult build() { + return new ListTasksResult(tasks, nextCursor, meta); + } + + } + + } + + /** + * A response to a task-augmented request, indicating that a task has been created. + * + *

+ * Use {@link #builder()} to create instances. + * + * @see Builder + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class CreateTaskResult implements Result { + + private final Task task; + + private final Map meta; + + @JsonCreator + private CreateTaskResult( // @formatter:off + @JsonProperty("task") Task task, + @JsonProperty("_meta") Map meta) { // @formatter:on + this.task = task; + this.meta = meta; + } + + /** + * Returns the created task. + * @return the task + */ + @JsonProperty("task") + public Task task() { + return this.task; + } + + /** + * Returns the metadata. + * @return the metadata map, or null + */ + @JsonProperty("_meta") + public Map meta() { + return this.meta; + } + + /** + * Creates a new builder for CreateTaskResult. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link CreateTaskResult}. + */ + public static class Builder { + + private Task task; + + private Map meta; + + /** + * Sets the task. + * @param task the task + * @return this builder + */ + public Builder task(Task task) { + this.task = task; + return this; + } + + /** + * Sets the metadata. + * @param meta the metadata map + * @return this builder + */ + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + /** + * Builds a new {@link CreateTaskResult} instance. + * @return a new CreateTaskResult instance + */ + public CreateTaskResult build() { + return new CreateTaskResult(task, meta); + } + + } + + } + + /** + * An optional notification from the receiver to the requestor, informing them that a + * task's status has changed. Receivers are not required to send these notifications. + * + *

+ * Use {@link #builder()} to create instances. + * + * @see Builder + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static final class TaskStatusNotification implements Notification { + + private final String taskId; + + private final TaskStatus status; + + private final String statusMessage; + + private final String createdAt; + + private final String lastUpdatedAt; + + private final Long ttl; + + private final Long pollInterval; + + private final Map meta; + + @JsonCreator + private TaskStatusNotification( // @formatter:off + @JsonProperty("taskId") String taskId, + @JsonProperty("status") TaskStatus status, + @JsonProperty("statusMessage") @Nullable String statusMessage, + @JsonProperty("createdAt") String createdAt, + @JsonProperty("lastUpdatedAt") String lastUpdatedAt, + @JsonProperty("ttl") @Nullable Long ttl, + @JsonProperty("pollInterval") @Nullable Long pollInterval, + @JsonProperty("_meta") Map meta) { // @formatter:on + Assert.hasText(taskId, "taskId must not be empty"); + Assert.notNull(status, "status must not be null"); + Assert.hasText(createdAt, "createdAt must not be empty"); + Assert.hasText(lastUpdatedAt, "lastUpdatedAt must not be empty"); + this.taskId = taskId; + this.status = status; + this.statusMessage = statusMessage; + this.createdAt = createdAt; + this.lastUpdatedAt = lastUpdatedAt; + this.ttl = ttl; + this.pollInterval = pollInterval; + this.meta = meta; + } + + /** + * Returns the task identifier. + * @return the task identifier + */ + @JsonProperty("taskId") + public String taskId() { + return this.taskId; + } + + /** + * Returns the task status. + * @return the task status + */ + @JsonProperty("status") + public TaskStatus status() { + return this.status; + } + + /** + * Returns the optional status message. + * @return the status message, or null + */ + @JsonProperty("statusMessage") + @Nullable + public String statusMessage() { + return this.statusMessage; + } + + /** + * Returns the creation timestamp. + * @return the ISO 8601 creation timestamp + */ + @JsonProperty("createdAt") + public String createdAt() { + return this.createdAt; + } + + /** + * Returns the last updated timestamp. + * @return the ISO 8601 last updated timestamp + */ + @JsonProperty("lastUpdatedAt") + public String lastUpdatedAt() { + return this.lastUpdatedAt; + } + + /** + * Returns the TTL (time-to-live) in milliseconds. + * @return the TTL, or null for unlimited + */ + @JsonProperty("ttl") + @Nullable + public Long ttl() { + return this.ttl; + } + + /** + * Returns the suggested polling interval in milliseconds. + * @return the polling interval, or null + */ + @JsonProperty("pollInterval") + @Nullable + public Long pollInterval() { + return this.pollInterval; + } + + /** + * Returns the metadata. + * @return the metadata map, or null + */ + @JsonProperty("_meta") + public Map meta() { + return this.meta; + } + + /** + * Creates a TaskStatusNotification from a Task. + * @param task the task to convert + * @return a new TaskStatusNotification + */ + public static TaskStatusNotification fromTask(Task task) { + return new TaskStatusNotification(task.taskId(), task.status(), task.statusMessage(), task.createdAt(), + task.lastUpdatedAt(), task.ttl(), task.pollInterval(), null); + } + + /** + * Checks if the task is in a terminal state (completed, failed, or cancelled). + * @return true if the task is in a terminal state + */ + public boolean isTerminal() { + return status.isTerminal(); + } + + /** + * Creates a new builder for TaskStatusNotification. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link TaskStatusNotification}. + */ + public static class Builder { + + private String taskId; + + private TaskStatus status; + + private String statusMessage; + + private String createdAt; + + private String lastUpdatedAt; + + private Long ttl; + + private Long pollInterval; + + private Map meta; + + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + public Builder status(TaskStatus status) { + this.status = status; + return this; + } + + public Builder statusMessage(String statusMessage) { + this.statusMessage = statusMessage; + return this; + } + + public Builder createdAt(String createdAt) { + this.createdAt = createdAt; + return this; + } + + public Builder lastUpdatedAt(String lastUpdatedAt) { + this.lastUpdatedAt = lastUpdatedAt; + return this; + } + + public Builder ttl(Long ttl) { + this.ttl = ttl; + return this; + } + + public Builder pollInterval(Long pollInterval) { + this.pollInterval = pollInterval; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public TaskStatusNotification build() { + return new TaskStatusNotification(taskId, status, statusMessage, createdAt, lastUpdatedAt, ttl, + pollInterval, meta); + } + + } + + } + + // -------------------------- + // Streaming Response Messages + // -------------------------- + + /** + * Sealed interface representing messages yielded during streaming request processing. + * Used by {@code callToolStream()} and other streaming APIs to provide real-time + * updates about task execution progress. + * + *

+ * The message types are: + *

    + *
  • {@link TaskCreatedMessage} - First message for task-augmented requests, + * contains the created task + *
  • {@link TaskStatusMessage} - Status update during task polling + *
  • {@link ResultMessage} - Final successful result (terminal) + *
  • {@link ErrorMessage} - Error occurred (terminal) + *
+ * + *

Streaming Order for Task-Augmented Requests

+ *

+ * For task-augmented requests (those with {@code TaskMetadata}), messages are yielded + * in this order: + *

    + *
  1. One {@link TaskCreatedMessage} - immediately after task creation
  2. + *
  3. Zero or more {@link TaskStatusMessage} - during polling while task is + * running
  4. + *
  5. One terminal message: either {@link ResultMessage} (success) or + * {@link ErrorMessage} (failure)
  6. + *
+ * + *

+ * For non-task requests, the stream yields only a single {@link ResultMessage} or + * {@link ErrorMessage}. + * + * @param The type of result expected from the request + */ + public sealed interface ResponseMessage + permits TaskCreatedMessage, TaskStatusMessage, ResultMessage, ErrorMessage { + + /** + * Returns the message type identifier. + * @return the type string ("taskCreated", "taskStatus", "result", or "error") + */ + String type(); + + } + + /** + * Message indicating a task has been created for a task-augmented request. This is + * the first message yielded for task-augmented requests. + * + * @param The type of result expected from the request + */ + public static final class TaskCreatedMessage implements ResponseMessage { + + private final Task task; + + private TaskCreatedMessage(Task task) { + this.task = task; + } + + /** + * Returns the created task. + * @return the task + */ + public Task task() { + return this.task; + } + + @Override + public String type() { + return "taskCreated"; + } + + /** + * Creates a new TaskCreatedMessage with the given task. + * @param the result type + * @param task the task + * @return a new TaskCreatedMessage + */ + public static TaskCreatedMessage of(Task task) { + return new TaskCreatedMessage<>(task); + } + + } + + /** + * Message indicating a task status update during polling. Yielded periodically while + * waiting for a task to reach a terminal state. + * + * @param The type of result expected from the request + */ + public static final class TaskStatusMessage implements ResponseMessage { + + private final Task task; + + private TaskStatusMessage(Task task) { + this.task = task; + } + + /** + * Returns the task with updated status. + * @return the task + */ + public Task task() { + return this.task; + } + + @Override + public String type() { + return "taskStatus"; + } + + /** + * Creates a new TaskStatusMessage with the given task. + * @param the result type + * @param task the task + * @return a new TaskStatusMessage + */ + public static TaskStatusMessage of(Task task) { + return new TaskStatusMessage<>(task); + } + + } + + /** + * Message containing the final successful result. This is a terminal message - no + * more messages will be yielded after this. + * + * @param The type of result + * @param result The final result + */ + public static final class ResultMessage implements ResponseMessage { + + private final T result; + + private ResultMessage(T result) { + this.result = result; + } + + /** + * Returns the final result. + * @return the result + */ + public T result() { + return this.result; + } + + @Override + public String type() { + return "result"; + } + + /** + * Creates a new ResultMessage with the given result. + * @param the result type + * @param result the result + * @return a new ResultMessage + */ + public static ResultMessage of(T result) { + return new ResultMessage<>(result); + } + + } + + /** + * Message indicating an error occurred. This is a terminal message - no more messages + * will be yielded after this. + * + * @param The type of result expected from the request + * @param error The error that occurred + */ + public static final class ErrorMessage implements ResponseMessage { + + private final McpError error; + + private ErrorMessage(McpError error) { + this.error = error; + } + + /** + * Returns the error that occurred. + * @return the error + */ + public McpError error() { + return this.error; + } + + @Override + public String type() { + return "error"; + } + + /** + * Creates a new ErrorMessage with the given error. + * @param the result type + * @param error the error + * @return a new ErrorMessage + */ + public static ErrorMessage of(McpError error) { + return new ErrorMessage<>(error); + } + + } + } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/AbstractTaskAwareToolSpecificationTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/AbstractTaskAwareToolSpecificationTest.java new file mode 100644 index 000000000..19184cf80 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/AbstractTaskAwareToolSpecificationTest.java @@ -0,0 +1,191 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JsonSchema; +import io.modelcontextprotocol.spec.McpSchema.TaskSupportMode; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.ToolAnnotations; + +/** + * Abstract base class for testing task-aware tool specification builders. + * + *

+ * This class provides common test cases for builder validation and defaults that are + * shared between {@link TaskAwareAsyncToolSpecificationTest} and + * {@link TaskAwareSyncToolSpecificationTest}. + * + * @param the specification type (e.g., TaskAwareAsyncToolSpecification) + * @param the builder type + */ +abstract class AbstractTaskAwareToolSpecificationTest> { + + protected static final JsonSchema TEST_SCHEMA = new JsonSchema("object", Map.of("input", Map.of("type", "string")), + null, null, null, null); + + /** + * Creates a new builder instance. + * @return a new builder + */ + protected abstract B createBuilder(); + + /** + * Configures the builder with a minimal valid createTask handler. + * @param builder the builder to configure + * @return the same builder for chaining + */ + protected abstract B withMinimalCreateTaskHandler(B builder); + + /** + * Builds the specification from the builder. + * @param builder the configured builder + * @return the built specification + */ + protected abstract S build(B builder); + + /** + * Gets the tool from the specification. + * @param spec the specification + * @return the tool definition + */ + protected abstract Tool getTool(S spec); + + /** + * Gets the EMPTY_INPUT_SCHEMA constant from the specification class. + * @return the empty input schema + */ + protected abstract JsonSchema getEmptyInputSchema(); + + // ------------------------------------------ + // Builder Validation Tests + // ------------------------------------------ + + @Test + void builderShouldThrowExceptionWhenNameIsNull() { + B builder = createBuilder(); + withMinimalCreateTaskHandler(builder); + + assertThatThrownBy(() -> build(builder)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("name must not be empty"); + } + + @Test + void builderShouldThrowExceptionWhenNameIsEmpty() { + B builder = createBuilder(); + builder.name(""); + withMinimalCreateTaskHandler(builder); + + assertThatThrownBy(() -> build(builder)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("name must not be empty"); + } + + @Test + void builderShouldThrowExceptionWhenCreateTaskHandlerIsNull() { + B builder = createBuilder(); + builder.name("test-tool"); + // Don't set createTask handler + + assertThatThrownBy(() -> build(builder)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("createTaskHandler must not be null"); + } + + // ------------------------------------------ + // Default Values Tests + // ------------------------------------------ + + @Test + void builderShouldUseDefaultValuesWhenOptionalFieldsNotSet() { + B builder = createBuilder(); + builder.name("minimal-tool"); + withMinimalCreateTaskHandler(builder); + + S spec = build(builder); + Tool tool = getTool(spec); + + assertThat(tool.name()).isEqualTo("minimal-tool"); + // Description defaults to name + assertThat(tool.description()).isEqualTo("minimal-tool"); + // InputSchema defaults to empty object schema + assertThat(tool.inputSchema()).isEqualTo(getEmptyInputSchema()); + // TaskSupportMode defaults to REQUIRED + assertThat(tool.execution().taskSupport()).isEqualTo(TaskSupportMode.REQUIRED); + } + + @Test + void builderShouldSetAnnotationsCorrectly() { + ToolAnnotations annotations = new ToolAnnotations("Test Tool", true, false, true, false, true); + + B builder = createBuilder(); + builder.name("annotated-tool").annotations(annotations); + withMinimalCreateTaskHandler(builder); + + S spec = build(builder); + Tool tool = getTool(spec); + + assertThat(tool.annotations()).isEqualTo(annotations); + } + + @Test + void emptyInputSchemaShouldBeObjectType() { + JsonSchema emptySchema = getEmptyInputSchema(); + assertThat(emptySchema.type()).isEqualTo("object"); + assertThat(emptySchema.properties()).isNull(); + } + + @Test + void taskSupportModeNullShouldDefaultToRequired() { + B builder = createBuilder(); + builder.name("null-mode-tool").taskSupportMode(null); + withMinimalCreateTaskHandler(builder); + + S spec = build(builder); + Tool tool = getTool(spec); + + assertThat(tool.execution().taskSupport()).isEqualTo(TaskSupportMode.REQUIRED); + } + + @Test + void builderShouldSupportAllTaskSupportModes() { + for (TaskSupportMode mode : TaskSupportMode.values()) { + B builder = createBuilder(); + builder.name("mode-test-tool").taskSupportMode(mode); + withMinimalCreateTaskHandler(builder); + + S spec = build(builder); + Tool tool = getTool(spec); + + assertThat(tool.execution().taskSupport()).isEqualTo(mode); + } + } + + @Test + void builderShouldCreateValidSpecification() { + B builder = createBuilder(); + builder.name("test-task-tool") + .description("A test task tool") + .inputSchema(TEST_SCHEMA) + .taskSupportMode(TaskSupportMode.REQUIRED); + withMinimalCreateTaskHandler(builder); + + S spec = build(builder); + Tool tool = getTool(spec); + + assertThat(spec).isNotNull(); + assertThat(tool).isNotNull(); + assertThat(tool.name()).isEqualTo("test-task-tool"); + assertThat(tool.description()).isEqualTo("A test task tool"); + assertThat(tool.inputSchema()).isEqualTo(TEST_SCHEMA); + assertThat(tool.execution().taskSupport()).isEqualTo(TaskSupportMode.REQUIRED); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskMessageQueueTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskMessageQueueTests.java new file mode 100644 index 000000000..d1e495a77 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskMessageQueueTests.java @@ -0,0 +1,288 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import static io.modelcontextprotocol.experimental.tasks.TaskTestUtils.runConcurrent; +import static org.assertj.core.api.Assertions.assertThat; + +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.test.StepVerifier; + +/** + * Unit tests for {@link InMemoryTaskMessageQueue}. + */ +class InMemoryTaskMessageQueueTests { + + private InMemoryTaskMessageQueue messageQueue; + + @BeforeEach + void setUp() { + messageQueue = new InMemoryTaskMessageQueue(); + } + + // Helper to create test notifications + private QueuedMessage.Notification notification(String method) { + return new QueuedMessage.Notification(method, null); + } + + // Helper to assert notification method + private void assertNotificationMethod(QueuedMessage msg, String expectedMethod) { + assertThat(msg).isInstanceOf(QueuedMessage.Notification.class); + assertThat(((QueuedMessage.Notification) msg).method()).isEqualTo(expectedMethod); + } + + @Test + void testEnqueueThenDequeueReturnsSameMessage() { + StepVerifier + .create(messageQueue.enqueue("task-1", notification("test/method"), null) + .then(messageQueue.dequeue("task-1"))) + .consumeNextWith(msg -> assertNotificationMethod(msg, "test/method")) + .verifyComplete(); + } + + @Test + void testMultipleEnqueuesRespectFifoOrdering() { + StepVerifier + .create(messageQueue.enqueue("task-1", notification("method-1"), null) + .then(messageQueue.enqueue("task-1", notification("method-2"), null)) + .then(messageQueue.enqueue("task-1", notification("method-3"), null)) + .then(messageQueue.dequeue("task-1"))) + .consumeNextWith(msg -> assertNotificationMethod(msg, "method-1")) + .verifyComplete(); + + StepVerifier.create(messageQueue.dequeue("task-1")) + .consumeNextWith(msg -> assertNotificationMethod(msg, "method-2")) + .verifyComplete(); + + StepVerifier.create(messageQueue.dequeue("task-1")) + .consumeNextWith(msg -> assertNotificationMethod(msg, "method-3")) + .verifyComplete(); + } + + @Test + void testDequeueAllReturnsAllMessagesInOrder() { + StepVerifier + .create(messageQueue.enqueue("task-1", notification("method-1"), null) + .then(messageQueue.enqueue("task-1", notification("method-2"), null)) + .then(messageQueue.enqueue("task-1", notification("method-3"), null)) + .then(messageQueue.dequeueAll("task-1"))) + .consumeNextWith(messages -> { + assertThat(messages).hasSize(3); + assertNotificationMethod(messages.get(0), "method-1"); + assertNotificationMethod(messages.get(1), "method-2"); + assertNotificationMethod(messages.get(2), "method-3"); + }) + .verifyComplete(); + + // Queue should be empty after dequeueAll + StepVerifier.create(messageQueue.dequeue("task-1")).verifyComplete(); + } + + @Test + void testMaxSizeEnforcementDropsOldestMessage() { + // Max size of 2 - oldest should be dropped + StepVerifier + .create(messageQueue.enqueue("task-1", notification("method-1"), 2) + .then(messageQueue.enqueue("task-1", notification("method-2"), 2)) + .then(messageQueue.enqueue("task-1", notification("method-3"), 2)) + .then(messageQueue.dequeueAll("task-1"))) + .consumeNextWith(messages -> { + assertThat(messages).hasSize(2); + assertNotificationMethod(messages.get(0), "method-2"); + assertNotificationMethod(messages.get(1), "method-3"); + }) + .verifyComplete(); + } + + @Test + void testQueueIsolationPerTaskId() { + StepVerifier + .create(messageQueue.enqueue("task-1", notification("task1-method"), null) + .then(messageQueue.enqueue("task-2", notification("task2-method"), null)) + .then(messageQueue.dequeue("task-1"))) + .consumeNextWith(msg -> assertNotificationMethod(msg, "task1-method")) + .verifyComplete(); + + StepVerifier.create(messageQueue.dequeue("task-2")) + .consumeNextWith(msg -> assertNotificationMethod(msg, "task2-method")) + .verifyComplete(); + } + + @Test + void testDequeueFromEmptyQueue() { + // Dequeue from non-existent task should complete empty + StepVerifier.create(messageQueue.dequeue("nonexistent")).verifyComplete(); + } + + @Test + void testDequeueAllFromEmptyQueue() { + // dequeueAll from non-existent task should return empty list + StepVerifier.create(messageQueue.dequeueAll("nonexistent")).consumeNextWith(messages -> { + assertThat(messages).isEmpty(); + }).verifyComplete(); + } + + @Test + void testClearRemovesAllMessagesForTask() { + messageQueue.enqueue("task-1", notification("method-1"), null) + .then(messageQueue.enqueue("task-1", notification("method-2"), null)) + .block(); + + messageQueue.clear("task-1"); + StepVerifier.create(messageQueue.dequeue("task-1")).verifyComplete(); + } + + @Test + void testClearAllRemovesAllQueues() { + messageQueue.enqueue("task-1", notification("task1-method"), null) + .then(messageQueue.enqueue("task-2", notification("task2-method"), null)) + .block(); + + messageQueue.clearAll(); + StepVerifier.create(messageQueue.dequeue("task-1")).verifyComplete(); + StepVerifier.create(messageQueue.dequeue("task-2")).verifyComplete(); + } + + @Test + void testDifferentMessageTypes() { + var request = new QueuedMessage.Request("req-1", "sampling/createMessage", + new McpSchema.CreateMessageRequest(null, null, null, null, null, null, null, null)); + var response = new QueuedMessage.Response("req-1", + new McpSchema.CreateMessageResult(null, null, null, null, null)); + var notification = new QueuedMessage.Notification("notifications/progress", null); + + StepVerifier + .create(messageQueue.enqueue("task-1", request, null) + .then(messageQueue.enqueue("task-1", response, null)) + .then(messageQueue.enqueue("task-1", notification, null)) + .then(messageQueue.dequeueAll("task-1"))) + .consumeNextWith(messages -> { + assertThat(messages).hasSize(3); + assertThat(messages.get(0)).isInstanceOf(QueuedMessage.Request.class); + assertThat(messages.get(1)).isInstanceOf(QueuedMessage.Response.class); + assertThat(messages.get(2)).isInstanceOf(QueuedMessage.Notification.class); + }) + .verifyComplete(); + } + + // ------------------------------------------ + // Concurrency Tests + // ------------------------------------------ + + @Test + void testConcurrentEnqueueDequeuePreservesConsistency() throws InterruptedException { + String taskId = "concurrent-task"; + int totalOps = 150; // 100 enqueues + 50 dequeues + + runConcurrent(totalOps, 30, i -> { + if (i < 100) { + messageQueue.enqueue(taskId, notification("method-" + i), null).block(); + } + else { + messageQueue.dequeue(taskId).block(); + } + }); + + // Remaining messages = enqueues - dequeues = 100 - 50 = 50 + var remaining = messageQueue.dequeueAll(taskId).block(); + assertThat(remaining).isNotNull(); + assertThat(remaining).hasSize(50); + } + + @Test + void testConcurrentEnqueueToMultipleQueues() throws InterruptedException { + int numTasks = 10; + int messagesPerTask = 50; + + runConcurrent(numTasks * messagesPerTask, 20, i -> { + String taskId = "task-" + (i / messagesPerTask); + messageQueue.enqueue(taskId, notification("method-" + i), null).block(); + }); + + // Each task queue should have the right number of messages + for (int t = 0; t < numTasks; t++) { + var messages = messageQueue.dequeueAll("task-" + t).block(); + assertThat(messages).hasSize(messagesPerTask); + } + } + + // ------------------------------------------ + // Edge Case Tests + // ------------------------------------------ + + @Test + void testMaxSizeOne() { + // Queue with max size 1 should only keep the most recent message + StepVerifier + .create(messageQueue.enqueue("task-1", notification("method-1"), 1) + .then(messageQueue.enqueue("task-1", notification("method-2"), 1)) + .then(messageQueue.dequeueAll("task-1"))) + .consumeNextWith(messages -> { + assertThat(messages).hasSize(1); + assertNotificationMethod(messages.get(0), "method-2"); + }) + .verifyComplete(); + } + + @Test + void testEnqueueAfterClear() { + messageQueue.enqueue("task-1", notification("method-1"), null).block(); + messageQueue.clear("task-1"); + messageQueue.enqueue("task-1", notification("method-2"), null).block(); + + StepVerifier.create(messageQueue.dequeue("task-1")) + .consumeNextWith(msg -> assertNotificationMethod(msg, "method-2")) + .verifyComplete(); + } + + // ------------------------------------------ + // Validation Tests + // ------------------------------------------ + + @Test + void testEnqueueWithMaxSizeBelowMinimumThrows() { + // maxSize of 0 should throw + StepVerifier.create(messageQueue.enqueue("task-1", notification("test"), 0)) + .expectError(IllegalArgumentException.class) + .verify(); + } + + @Test + void testEnqueueWithMaxSizeAboveMaximumThrows() { + // maxSize above MAX_ALLOWED_QUEUE_SIZE should throw + StepVerifier + .create(messageQueue.enqueue("task-1", notification("test"), TaskDefaults.MAX_ALLOWED_QUEUE_SIZE + 1)) + .expectError(IllegalArgumentException.class) + .verify(); + } + + @Test + void testEnqueueWithMaxSizeAtBoundary() { + // maxSize at exactly MAX_ALLOWED_QUEUE_SIZE should work + StepVerifier.create(messageQueue.enqueue("task-1", notification("test"), TaskDefaults.MAX_ALLOWED_QUEUE_SIZE)) + .verifyComplete(); + + // Verify it was added + StepVerifier.create(messageQueue.dequeue("task-1")) + .consumeNextWith(msg -> assertNotificationMethod(msg, "test")) + .verifyComplete(); + } + + @Test + void testClearTaskReactive() { + // Add some messages + messageQueue.enqueue("task-1", notification("method-1"), null).block(); + messageQueue.enqueue("task-1", notification("method-2"), null).block(); + + // Clear using the reactive method + StepVerifier.create(messageQueue.clearTask("task-1")).verifyComplete(); + + // Verify queue is empty + StepVerifier.create(messageQueue.dequeue("task-1")).verifyComplete(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskStoreTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskStoreTests.java new file mode 100644 index 000000000..53f7b8243 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/InMemoryTaskStoreTests.java @@ -0,0 +1,1277 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import static io.modelcontextprotocol.experimental.tasks.TaskTestUtils.runConcurrent; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.awaitility.Awaitility.await; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.test.StepVerifier; + +/** + * Unit tests for {@link InMemoryTaskStore}. + */ +class InMemoryTaskStoreTests { + + private InMemoryTaskStore taskStore; + + @BeforeEach + void setUp() { + taskStore = new InMemoryTaskStore<>(); + } + + @AfterEach + void tearDown() { + taskStore.shutdown().block(); + } + + // ------------------------------------------ + // Helper Methods + // ------------------------------------------ + + private CallToolResult createTestResult(String text) { + return CallToolResult.builder().content(List.of(new TextContent(null, null, text))).isError(false).build(); + } + + /** + * Creates a test request for use in CreateTaskOptions. Using CallToolRequest as the + * standard test request type. Static to support @MethodSource methods. + */ + private static McpSchema.CallToolRequest createTestRequest(String toolName) { + return new McpSchema.CallToolRequest(toolName, null); + } + + /** + * Creates default CreateTaskOptions with a test request. + */ + private CreateTaskOptions createDefaultOptions() { + return CreateTaskOptions.builder(createTestRequest("test-tool")).build(); + } + + /** + * Creates CreateTaskOptions with the specified session ID and a test request. + */ + private CreateTaskOptions createOptionsWithSession(String sessionId) { + return CreateTaskOptions.builder(createTestRequest("test-tool")).sessionId(sessionId).build(); + } + + // ------------------------------------------ + // Basic Tests + // ------------------------------------------ + + @Test + void testCreateTaskWithCustomTtlAndPollInterval() { + var options = CreateTaskOptions.builder(createTestRequest("test-tool")) + .requestedTtl(60000L) + .pollInterval(1000L) + .build(); + + StepVerifier.create(taskStore.createTask(options)).consumeNextWith(task -> { + assertThat(task.taskId()).isNotNull().isNotEmpty(); + assertThat(task.status()).isEqualTo(TaskStatus.WORKING); + assertThat(task.statusMessage()).isNull(); + assertThat(task.createdAt()).isNotNull(); + assertThat(task.lastUpdatedAt()).isNotNull(); + assertThat(task.ttl()).isEqualTo(60000L); + assertThat(task.pollInterval()).isEqualTo(1000L); + }).verifyComplete(); + } + + @Test + void testCreateTaskWithDefaults() { + StepVerifier.create(taskStore.createTask(createDefaultOptions())).consumeNextWith(task -> { + assertThat(task.taskId()).isNotNull().isNotEmpty(); + assertThat(task.status()).isEqualTo(TaskStatus.WORKING); + assertThat(task.ttl()).isEqualTo(60000L); // Default TTL + assertThat(task.pollInterval()).isEqualTo(1000L); // Default poll interval + }).verifyComplete(); + } + + @Test + void testGetTaskReturnsCreatedTask() { + var createMono = taskStore.createTask(createDefaultOptions()); + + StepVerifier + .create(createMono + .flatMap(task -> taskStore.getTask(task.taskId(), null).map(GetTaskFromStoreResult::task))) + .consumeNextWith(task -> { + assertThat(task).isNotNull(); + assertThat(task.status()).isEqualTo(TaskStatus.WORKING); + }) + .verifyComplete(); + } + + @Test + void testGetTaskNotFound() { + // When a task is not found, the Mono completes without emitting a value + // (Mono.fromCallable with null result completes empty) + StepVerifier.create(taskStore.getTask("nonexistent", null)).verifyComplete(); + } + + @Test + void testUpdateTaskStatus() { + var createMono = taskStore.createTask(createDefaultOptions()); + + StepVerifier + .create(createMono.flatMap(task -> taskStore + .updateTaskStatus(task.taskId(), null, TaskStatus.INPUT_REQUIRED, "Waiting for input") + .then(taskStore.getTask(task.taskId(), null).map(GetTaskFromStoreResult::task)))) + .consumeNextWith(task -> { + assertThat(task.status()).isEqualTo(TaskStatus.INPUT_REQUIRED); + assertThat(task.statusMessage()).isEqualTo("Waiting for input"); + }) + .verifyComplete(); + } + + @Test + void testUpdateTaskStatusDoesNotUpdateTerminalState() { + var createMono = taskStore.createTask(createDefaultOptions()); + + StepVerifier + .create(createMono + .flatMap(task -> taskStore.updateTaskStatus(task.taskId(), null, TaskStatus.COMPLETED, null) + .then(taskStore.updateTaskStatus(task.taskId(), null, TaskStatus.WORKING, "Should not update")) + .then(taskStore.getTask(task.taskId(), null).map(GetTaskFromStoreResult::task)))) + .consumeNextWith(task -> { + // Status should remain COMPLETED, not change back to WORKING + assertThat(task.status()).isEqualTo(TaskStatus.COMPLETED); + }) + .verifyComplete(); + } + + @Test + void testStoreTaskResultUpdatesTaskToTerminalStatus() { + var createMono = taskStore.createTask(createDefaultOptions()); + var toolResult = createTestResult("Success"); + + StepVerifier + .create(createMono + .flatMap(task -> taskStore.storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, toolResult) + .then(taskStore.getTask(task.taskId(), null).map(GetTaskFromStoreResult::task)))) + .consumeNextWith(task -> assertThat(task.status()).isEqualTo(TaskStatus.COMPLETED)) + .verifyComplete(); + } + + @Test + void testGetTaskResultReturnsStoredPayload() { + var createMono = taskStore.createTask(createDefaultOptions()); + var toolResult = createTestResult("Success"); + + StepVerifier.create(createMono + .flatMap(task -> taskStore.storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, toolResult) + .then(taskStore.getTaskResult(task.taskId(), null)))) + .consumeNextWith(result -> { + assertThat(result).isInstanceOf(CallToolResult.class); + assertThat(((CallToolResult) result).content()).hasSize(1); + }) + .verifyComplete(); + } + + @Test + void testListTasksReturnsPaginatedResults() { + // Create multiple tasks + var create1 = taskStore.createTask(createDefaultOptions()); + var create2 = taskStore.createTask(createDefaultOptions()); + var create3 = taskStore.createTask(createDefaultOptions()); + + StepVerifier.create(create1.then(create2).then(create3).then(taskStore.listTasks(null, null))) + .consumeNextWith(result -> { + assertThat(result.tasks()).hasSize(3); + assertThat(result.nextCursor()).isNull(); // No pagination needed for 3 + // tasks + }) + .verifyComplete(); + } + + @Test + void testRequestCancellation() { + var createMono = taskStore.createTask(createDefaultOptions()); + + StepVerifier.create(createMono.flatMap(task -> taskStore.requestCancellation(task.taskId(), null))) + .consumeNextWith(task -> { + assertThat(task.status()).isEqualTo(TaskStatus.CANCELLED); + assertThat(task.statusMessage()).isEqualTo("Cancellation requested"); + }) + .verifyComplete(); + } + + @Test + void testIsCancellationRequested() { + var createMono = taskStore.createTask(createDefaultOptions()); + + StepVerifier.create(createMono.flatMap(task -> taskStore.requestCancellation(task.taskId(), null) + .then(taskStore.isCancellationRequested(task.taskId(), null)))).consumeNextWith(isCancelled -> { + assertThat(isCancelled).isTrue(); + }).verifyComplete(); + } + + @Test + void testIsCancellationRequestedReturnsFalseForNonCancelledTask() { + var createMono = taskStore.createTask(createDefaultOptions()); + + StepVerifier.create(createMono.flatMap(task -> taskStore.isCancellationRequested(task.taskId(), null))) + .consumeNextWith(isCancelled -> { + assertThat(isCancelled).isFalse(); + }) + .verifyComplete(); + } + + // ------------------------------------------ + // Concurrency Tests + // ------------------------------------------ + + @Test + void testConcurrentTaskCreation() throws InterruptedException { + int numTasks = 100; + List taskIds = Collections.synchronizedList(new ArrayList<>()); + + runConcurrent(numTasks, numTasks, i -> { + var task = taskStore.createTask(createDefaultOptions()).block(); + if (task != null) { + taskIds.add(task.taskId()); + } + }); + + // Verify all tasks were created with unique IDs + assertThat(taskIds).hasSize(numTasks); + assertThat(new HashSet<>(taskIds)).hasSize(numTasks); + } + + @Test + void testConcurrentUpdateAndRead() throws InterruptedException { + var task = taskStore.createTask(createDefaultOptions()).block(); + assertThat(task).isNotNull(); + String taskId = task.taskId(); + + // Run concurrent updates and reads + runConcurrent(100, 20, i -> { + if (i % 2 == 0) { + taskStore.updateTaskStatus(taskId, null, TaskStatus.WORKING, "Update " + i).block(); + } + else { + taskStore.getTask(taskId, null).map(GetTaskFromStoreResult::task).block(); + } + }); + + // Task should still be valid + assertThat(taskStore.getTask(taskId, null).map(GetTaskFromStoreResult::task).block()).isNotNull(); + } + + @Test + void testConcurrentListTasksWhileCreating() throws InterruptedException { + int totalOps = 70; // 50 creates + 20 lists + + runConcurrent(totalOps, 30, i -> { + if (i < 50) { + taskStore.createTask(createDefaultOptions()).block(); + } + else { + taskStore.listTasks(null, null).block(); + } + }); + + // Final list should show all created tasks + var finalResult = taskStore.listTasks(null, null).block(); + assertThat(finalResult).isNotNull(); + assertThat(finalResult.tasks()).hasSize(50); + } + + // ------------------------------------------ + // Edge Case Tests + // ------------------------------------------ + + @Test + void testUpdateTaskStatusOnNonExistentTask() { + // updateTaskStatus on non-existent task should complete without error + // (silently ignores due to computeIfPresent) + StepVerifier.create(taskStore.updateTaskStatus("nonexistent", null, TaskStatus.COMPLETED, "done")) + .verifyComplete(); + } + + @Test + void testStoreTaskResultOnNonExistentTask() { + // storeTaskResult on non-existent task should throw McpError + StepVerifier + .create(taskStore.storeTaskResult("nonexistent", null, TaskStatus.COMPLETED, createTestResult("Result"))) + .expectError(io.modelcontextprotocol.spec.McpError.class) + .verify(); + } + + @Test + void testRequestCancellationOnNonExistentTask() { + // requestCancellation on non-existent task should return null + StepVerifier.create(taskStore.requestCancellation("nonexistent", null)).verifyComplete(); + } + + @Test + void testRequestCancellationOnAlreadyTerminalTask() { + // Per MCP spec: cancellation of tasks in terminal status MUST be rejected with + // error code -32602 + var createMono = taskStore.createTask(createDefaultOptions()); + + StepVerifier.create(createMono.flatMap(task -> + // First mark as completed + taskStore.updateTaskStatus(task.taskId(), null, TaskStatus.COMPLETED, null) + // Then try to cancel - should fail with McpError + .then(taskStore.requestCancellation(task.taskId(), null)))).expectErrorSatisfies(error -> { + assertThat(error).isInstanceOf(io.modelcontextprotocol.spec.McpError.class); + var mcpError = (io.modelcontextprotocol.spec.McpError) error; + assertThat(mcpError.getJsonRpcError().code()) + .isEqualTo(io.modelcontextprotocol.spec.McpSchema.ErrorCodes.INVALID_PARAMS); + assertThat(mcpError.getMessage()).contains("terminal"); + }).verify(); + } + + @Test + void testIsCancellationRequestedOnNonExistentTask() { + // Should return false for non-existent task + StepVerifier.create(taskStore.isCancellationRequested("nonexistent", null)).consumeNextWith(isCancelled -> { + assertThat(isCancelled).isFalse(); + }).verifyComplete(); + } + + @Test + void testListTasksWithInvalidCursor() { + // Create some tasks first + taskStore.createTask(createDefaultOptions()).block(); + taskStore.createTask(createDefaultOptions()).block(); + + // List with invalid cursor should return empty result + StepVerifier.create(taskStore.listTasks("invalid-cursor-id", null)).consumeNextWith(result -> { + assertThat(result.tasks()).isEmpty(); + assertThat(result.nextCursor()).isNull(); + }).verifyComplete(); + } + + // ------------------------------------------ + // Watch Task Until Terminal Tests + // ------------------------------------------ + + @Test + void testWatchTaskUntilTerminalCompletesWhenTaskIsTerminal() { + // Create and immediately complete a task + var task = taskStore.createTask(createDefaultOptions()).block(); + assertThat(task).isNotNull(); + + // Mark as completed + taskStore.updateTaskStatus(task.taskId(), null, TaskStatus.COMPLETED, null).block(); + + // Watch should complete quickly since task is already terminal + StepVerifier.create(taskStore.watchTaskUntilTerminal(task.taskId(), null, java.time.Duration.ofSeconds(5))) + .consumeNextWith(t -> { + assertThat(t.status()).isEqualTo(TaskStatus.COMPLETED); + assertThat(t.isTerminal()).isTrue(); + }) + .verifyComplete(); + } + + @Test + void testWatchTaskUntilTerminalEmitsUpdatesUntilTerminal() { + var task = taskStore.createTask(createDefaultOptions()).block(); + assertThat(task).isNotNull(); + + // Schedule status updates in background + var executor = Executors.newSingleThreadScheduledExecutor(); + try { + executor.schedule(() -> { + taskStore.updateTaskStatus(task.taskId(), null, TaskStatus.COMPLETED, null).block(); + }, 500, TimeUnit.MILLISECONDS); + + // Watch should emit at least one update before completing + StepVerifier.create(taskStore.watchTaskUntilTerminal(task.taskId(), null, java.time.Duration.ofSeconds(10))) + .thenAwait(java.time.Duration.ofSeconds(2)) + .expectNextMatches(t -> t.taskId().equals(task.taskId())) + .thenCancel(); + } + finally { + executor.shutdownNow(); + } + } + + @Test + void testWatchTaskUntilTerminalTimesOut() { + var task = taskStore.createTask(createDefaultOptions()).block(); + assertThat(task).isNotNull(); + + // Don't complete the task - it should timeout + StepVerifier.create(taskStore.watchTaskUntilTerminal(task.taskId(), null, java.time.Duration.ofMillis(500))) + .expectError(java.util.concurrent.TimeoutException.class) + .verify(); + } + + @Test + void testWatchTaskUntilTerminalWithNonexistentTask() { + StepVerifier.create(taskStore.watchTaskUntilTerminal("nonexistent", null, java.time.Duration.ofSeconds(1))) + .expectError(McpError.class) + .verify(); + } + + // ------------------------------------------ + // Session Isolation Tests + // ------------------------------------------ + + @Test + void testSessionIsolation() { + // Create tasks for different sessions + var session1Task = taskStore + .createTask(CreateTaskOptions.builder(createTestRequest("test-tool")).sessionId("session-1").build()) + .block(); + var session2Task = taskStore + .createTask(CreateTaskOptions.builder(createTestRequest("test-tool")).sessionId("session-2").build()) + .block(); + var noSessionTask = taskStore.createTask(createDefaultOptions()).block(); + + // List tasks for session-1 + var session1Result = taskStore.listTasks(null, "session-1").block(); + assertThat(session1Result.tasks()).hasSize(1); + assertThat(session1Result.tasks().get(0).taskId()).isEqualTo(session1Task.taskId()); + + // List tasks for session-2 + var session2Result = taskStore.listTasks(null, "session-2").block(); + assertThat(session2Result.tasks()).hasSize(1); + assertThat(session2Result.tasks().get(0).taskId()).isEqualTo(session2Task.taskId()); + + // List all tasks (no session filter) + var allResult = taskStore.listTasks(null, null).block(); + assertThat(allResult.tasks()).hasSize(3); + } + + @Test + void testShutdownStopsCleanupExecutor() { + // Create a separate store instance to test shutdown + var store = new InMemoryTaskStore(); + + // Shutdown should complete without error + store.shutdown().block(); + + // After shutdown, operations should still work (map is still accessible) + // but cleanup will no longer run (hard to verify directly) + var task = store.createTask(createDefaultOptions()).block(); + assertThat(task).isNotNull(); + + // Clean up + store.shutdown().block(); + } + + // ------------------------------------------ + // Stress Tests + // ------------------------------------------ + + /** + * Runs a concurrent test with success counting. + * @param numTasks number of tasks to submit + * @param numThreads number of threads in the pool + * @param timeoutSeconds timeout for all tasks to complete + * @param task the task to execute (receives task index, throws on failure) + * @return the number of successful completions + */ + private int runConcurrentWithCount(int numTasks, int numThreads, int timeoutSeconds, StressTestAction task) + throws InterruptedException { + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch doneLatch = new CountDownLatch(numTasks); + AtomicInteger successCount = new AtomicInteger(0); + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + + try { + for (int i = 0; i < numTasks; i++) { + final int idx = i; + executor.submit(() -> { + try { + startLatch.await(); + task.execute(idx); + successCount.incrementAndGet(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + catch (Exception ignored) { + // Task failed - don't count as success + } + finally { + doneLatch.countDown(); + } + }); + } + startLatch.countDown(); + assertThat(doneLatch.await(timeoutSeconds, TimeUnit.SECONDS)).isTrue(); + return successCount.get(); + } + finally { + executor.shutdownNow(); + } + } + + @FunctionalInterface + interface StressTestAction { + + void execute(int index) throws Exception; + + } + + @Test + void testConcurrentStoreTaskResult() throws InterruptedException { + // Create multiple tasks + int numTasks = 50; + List taskIds = new ArrayList<>(); + for (int i = 0; i < numTasks; i++) { + var task = taskStore.createTask(createDefaultOptions()).block(); + taskIds.add(task.taskId()); + } + + // Concurrently store results for all tasks + int successes = runConcurrentWithCount(numTasks, 20, 10, idx -> { + String taskId = taskIds.get(idx); + CallToolResult result = CallToolResult.builder().addTextContent("Result for " + taskId).build(); + taskStore.storeTaskResult(taskId, null, TaskStatus.COMPLETED, result).block(); + }); + + assertThat(successes).isEqualTo(numTasks); + + // Verify all results are stored + for (String taskId : taskIds) { + assertThat(taskStore.getTaskResult(taskId, null).block()).isNotNull(); + } + } + + @Test + void testConcurrentStoreAndReadResults() throws InterruptedException { + // Create a task and store initial result + var task = taskStore.createTask(createDefaultOptions()).block(); + String taskId = task.taskId(); + CallToolResult initialResult = CallToolResult.builder().addTextContent("Initial result").build(); + taskStore.storeTaskResult(taskId, null, TaskStatus.COMPLETED, initialResult).block(); + + // Concurrent reads of the result + int numReads = 100; + int successfulReads = runConcurrentWithCount(numReads, 20, 10, idx -> { + var result = taskStore.getTaskResult(taskId, null).block(); + if (result == null) { + throw new AssertionError("Result was null"); + } + }); + + assertThat(successfulReads).isEqualTo(numReads); + } + + @Test + void testRapidCreateAndCancelTasks() throws InterruptedException { + int numTasks = 100; + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch doneLatch = new CountDownLatch(numTasks * 2); + AtomicInteger cancellations = new AtomicInteger(0); + List taskIds = Collections.synchronizedList(new ArrayList<>()); + ExecutorService executor = Executors.newFixedThreadPool(20); + + try { + // Create tasks + for (int i = 0; i < numTasks; i++) { + executor.submit(() -> { + try { + startLatch.await(); + var task = taskStore.createTask(createDefaultOptions()).block(); + if (task != null) { + taskIds.add(task.taskId()); + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + finally { + doneLatch.countDown(); + } + }); + } + + // Attempt to cancel tasks (some may not exist yet) + for (int i = 0; i < numTasks; i++) { + final int idx = i; + executor.submit(() -> { + try { + startLatch.await(); + // Try to cancel with delay to let some tasks be created + Thread.sleep(idx % 5); + if (!taskIds.isEmpty()) { + String taskId = taskIds.get(idx % Math.max(1, taskIds.size())); + var cancelled = taskStore.requestCancellation(taskId, null).block(); + if (cancelled != null) { + cancellations.incrementAndGet(); + } + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + catch (Exception ignored) { + // Expected - task may not exist yet + } + finally { + doneLatch.countDown(); + } + }); + } + + startLatch.countDown(); + assertThat(doneLatch.await(15, TimeUnit.SECONDS)).isTrue(); + + // At least some tasks should have been created + assertThat(taskIds).isNotEmpty(); + } + finally { + executor.shutdownNow(); + } + } + + @Test + void testHighVolumeTaskCreation() throws InterruptedException { + int numTasks = 1000; + int created = runConcurrentWithCount(numTasks, 50, 30, idx -> { + var task = taskStore.createTask(createDefaultOptions()).block(); + if (task == null || task.taskId() == null) { + throw new AssertionError("Task creation failed"); + } + }); + + assertThat(created).isEqualTo(numTasks); + + // Verify all tasks are listable + int totalListed = 0; + String cursor = null; + do { + var result = taskStore.listTasks(cursor, null).block(); + assertThat(result).isNotNull(); + totalListed += result.tasks().size(); + cursor = result.nextCursor(); + } + while (cursor != null); + assertThat(totalListed).isEqualTo(numTasks); + } + + @Test + void testHighVolumeStatusUpdates() throws InterruptedException { + var task = taskStore.createTask(createDefaultOptions()).block(); + String taskId = task.taskId(); + + int numUpdates = 500; + int updates = runConcurrentWithCount(numUpdates, 20, 15, idx -> { + taskStore.updateTaskStatus(taskId, null, TaskStatus.WORKING, "Update " + idx).block(); + }); + + assertThat(updates).isEqualTo(numUpdates); + + // Task should still be valid + var finalTask = taskStore.getTask(taskId, null).map(GetTaskFromStoreResult::task).block(); + assertThat(finalTask).isNotNull(); + assertThat(finalTask.status()).isEqualTo(TaskStatus.WORKING); + } + + @Test + void testMultipleStoreInstances() { + // Create multiple store instances + InMemoryTaskStore store1 = new InMemoryTaskStore<>(); + InMemoryTaskStore store2 = new InMemoryTaskStore<>(); + InMemoryTaskStore store3 = new InMemoryTaskStore<>(); + + try { + // Each store should work independently + var task1 = store1.createTask(createDefaultOptions()).block(); + var task2 = store2.createTask(createDefaultOptions()).block(); + var task3 = store3.createTask(createDefaultOptions()).block(); + + assertThat(task1.taskId()).isNotEqualTo(task2.taskId()); + assertThat(task2.taskId()).isNotEqualTo(task3.taskId()); + + // Each store should only see its own tasks + assertThat(store1.listTasks(null, null).block().tasks()).hasSize(1); + assertThat(store2.listTasks(null, null).block().tasks()).hasSize(1); + assertThat(store3.listTasks(null, null).block().tasks()).hasSize(1); + } + finally { + store1.shutdown().block(); + store2.shutdown().block(); + store3.shutdown().block(); + } + } + + // ------------------------------------------ + // Boundary Condition Tests + // ------------------------------------------ + + static Stream zeroValueOptions() { + return Stream.of( + Arguments.of("ttl", CreateTaskOptions.builder(createTestRequest("test-tool")).requestedTtl(0L).build(), + 0L, null), + Arguments.of("pollInterval", + CreateTaskOptions.builder(createTestRequest("test-tool")).pollInterval(0L).build(), null, 0L)); + } + + @ParameterizedTest(name = "zero {0} should be allowed") + @MethodSource("zeroValueOptions") + void testCreateTaskOptionsWithZeroValuesAllowed(String field, CreateTaskOptions options, Long expectedTtl, + Long expectedPollInterval) { + StepVerifier.create(taskStore.createTask(options)).consumeNextWith(task -> { + if (expectedTtl != null) { + assertThat(task.ttl()).isEqualTo(expectedTtl); + } + if (expectedPollInterval != null) { + assertThat(task.pollInterval()).isEqualTo(expectedPollInterval); + } + }).verifyComplete(); + } + + static Stream invalidCreateTaskOptions() { + return Stream.of(Arguments.of("negative TTL", + (Runnable) () -> CreateTaskOptions.builder(createTestRequest("test-tool")).requestedTtl(-1L).build(), + "requestedTtl"), + Arguments.of("negative pollInterval", + (Runnable) () -> CreateTaskOptions.builder(createTestRequest("test-tool")) + .pollInterval(-1L) + .build(), + "pollInterval"), + Arguments.of("TTL exceeds max", + (Runnable) () -> CreateTaskOptions.builder(createTestRequest("test-tool")) + .requestedTtl(TaskDefaults.MAX_TTL_MS + 1) + .build(), + "must not exceed"), + Arguments.of("pollInterval exceeds max", + (Runnable) () -> CreateTaskOptions.builder(createTestRequest("test-tool")) + .pollInterval(TaskDefaults.MAX_POLL_INTERVAL_MS + 1) + .build(), + "must not exceed"), + Arguments.of("pollInterval below min", + (Runnable) () -> CreateTaskOptions.builder(createTestRequest("test-tool")) + .pollInterval(1L) + .build(), + "must be at least")); + } + + @ParameterizedTest(name = "{0} throws IllegalArgumentException") + @MethodSource("invalidCreateTaskOptions") + void testCreateTaskOptionsWithInvalidValuesThrows(String description, Runnable builder, String expectedMessage) { + assertThatThrownBy(builder::run).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining(expectedMessage); + } + + @Test + void testCreateTaskOptionsWithMaxBoundaryValues() { + // Use the maximum allowed values (at the boundary) + var options = CreateTaskOptions.builder(createTestRequest("test-tool")) + .requestedTtl(TaskDefaults.MAX_TTL_MS) + .pollInterval(TaskDefaults.MAX_POLL_INTERVAL_MS) + .build(); + + StepVerifier.create(taskStore.createTask(options)).consumeNextWith(task -> { + assertThat(task.ttl()).isEqualTo(TaskDefaults.MAX_TTL_MS); + assertThat(task.pollInterval()).isEqualTo(TaskDefaults.MAX_POLL_INTERVAL_MS); + }).verifyComplete(); + } + + @Test + void testCreateTaskOptionsWithZeroPollIntervalAllowed() { + // Zero is allowed (means use default) + var options = CreateTaskOptions.builder(createTestRequest("test-tool")).pollInterval(0L).build(); + assertThat(options.pollInterval()).isEqualTo(0L); + } + + @Test + void testStoreWithVeryShortTtl() { + InMemoryTaskStore shortTtlStore = new InMemoryTaskStore<>(1L, 100L); + + try { + var task = shortTtlStore.createTask(createDefaultOptions()).block(); + assertThat(task).isNotNull(); + assertThat(task.ttl()).isEqualTo(1L); + } + finally { + shortTtlStore.shutdown().block(); + } + } + + @Test + void testStoreWithVeryLongPollInterval() { + InMemoryTaskStore longPollStore = new InMemoryTaskStore<>(60000L, + Long.MAX_VALUE); + + try { + var task = longPollStore.createTask(createDefaultOptions()).block(); + assertThat(task.pollInterval()).isEqualTo(Long.MAX_VALUE); + } + finally { + longPollStore.shutdown().block(); + } + } + + @ParameterizedTest(name = "getTask with \"{0}\" returns empty") + @ValueSource(strings = { "", " ", "\t", "\n" }) + void testGetTaskWithInvalidIdReturnsEmpty(String invalidId) { + StepVerifier.create(taskStore.getTask(invalidId, null)).verifyComplete(); + } + + @Test + void testStoreResultForAlreadyCompletedTask() { + var task = taskStore.createTask(createDefaultOptions()).block(); + String taskId = task.taskId(); + + CallToolResult result1 = CallToolResult.builder().addTextContent("First").build(); + taskStore.storeTaskResult(taskId, null, TaskStatus.COMPLETED, result1).block(); + + CallToolResult result2 = CallToolResult.builder().addTextContent("Second").build(); + taskStore.storeTaskResult(taskId, null, TaskStatus.COMPLETED, result2).block(); + + var storedResult = taskStore.getTaskResult(taskId, null).block(); + assertThat(storedResult).isNotNull(); + } + + @Test + void testDefaultTaskContextWithEmptyTaskIdThrows() { + assertThatThrownBy(() -> new DefaultTaskContext("", null, taskStore)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Task ID"); + } + + @Test + void testDefaultTaskContextWithNullTaskStoreThrows() { + assertThatThrownBy(() -> new DefaultTaskContext("valid-id", null, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("TaskStore"); + } + + @Test + void testDefaultTaskContextCompleteWithNullResultFails() { + var task = taskStore.createTask(createDefaultOptions()).block(); + DefaultTaskContext context = new DefaultTaskContext<>(task.taskId(), null, + taskStore); + + assertThatThrownBy(() -> context.complete(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Result must not be null"); + } + + static Stream shortTimeouts() { + return Stream.of(Arguments.of("very short (1ms)", Duration.ofMillis(1)), Arguments.of("zero", Duration.ZERO)); + } + + @ParameterizedTest(name = "{0} timeout should expire quickly") + @MethodSource("shortTimeouts") + void testWatchTaskWithShortTimeoutsExpires(String description, Duration timeout) { + var task = taskStore.createTask(createDefaultOptions()).block(); + StepVerifier.create(taskStore.watchTaskUntilTerminal(task.taskId(), null, timeout)) + .expectError(java.util.concurrent.TimeoutException.class) + .verify(Duration.ofSeconds(5)); + } + + @Test + void testListTasksWithEmptySessionId() { + var task = taskStore.createTask(CreateTaskOptions.builder(createTestRequest("test-tool")).sessionId("").build()) + .block(); + + StepVerifier.create(taskStore.listTasks(null, "")).consumeNextWith(result -> { + assertThat(result.tasks()).hasSize(1); + }).verifyComplete(); + } + + @Test + void testListTasksWithNullSessionIdReturnsAll() { + taskStore.createTask(CreateTaskOptions.builder(createTestRequest("test-tool")).sessionId("session1").build()) + .block(); + taskStore.createTask(CreateTaskOptions.builder(createTestRequest("test-tool")).sessionId("session2").build()) + .block(); + taskStore.createTask(createDefaultOptions()).block(); + + StepVerifier.create(taskStore.listTasks(null, null)).consumeNextWith(result -> { + assertThat(result.tasks()).hasSize(3); + }).verifyComplete(); + } + + // ------------------------------------------ + // Max Tasks Limit Tests + // ------------------------------------------ + + @Test + void testMaxTasksLimitEnforced() { + // Create a store with a small max tasks limit + var limitedStore = new InMemoryTaskStore(60000, 1000, null, 3); + try { + // Create 3 tasks successfully + limitedStore.createTask(createDefaultOptions()).block(); + limitedStore.createTask(createDefaultOptions()).block(); + limitedStore.createTask(createDefaultOptions()).block(); + + // 4th task should fail with McpError + StepVerifier.create(limitedStore.createTask(createDefaultOptions())) + .expectErrorMatches(e -> e instanceof McpError && e.getMessage() != null + && e.getMessage().contains("Maximum task limit reached (3)")) + .verify(); + } + finally { + limitedStore.shutdown().block(); + } + } + + // ------------------------------------------ + // TTL Expiration Tests + // ------------------------------------------ + + @Test + void testTtlExpirationCleansUpTask() { + // Create a store with very short default TTL (10ms) + var shortTtlStore = new InMemoryTaskStore(10, 1000, null, 100); + try { + // Create a task (will use the store's default 10ms TTL) + var task = shortTtlStore.createTask(createDefaultOptions()).block(); + String taskId = task.taskId(); + + // Store a result for the task + var result = createTestResult("test result"); + shortTtlStore.storeTaskResult(taskId, null, TaskStatus.COMPLETED, result).block(); + + // Verify task exists + assertThat(shortTtlStore.getTask(taskId, null).map(GetTaskFromStoreResult::task).block()).isNotNull(); + assertThat(shortTtlStore.getTaskResult(taskId, null).block()).isNotNull(); + + // Wait for TTL to expire and verify cleanup removes task and result + await().atMost(Duration.ofMillis(200)).pollInterval(Duration.ofMillis(10)).untilAsserted(() -> { + // Manually trigger cleanup (normally runs every minute) + shortTtlStore.cleanupExpiredTasks(); + // Verify task and result are cleaned up + assertThat(shortTtlStore.getTask(taskId, null).map(GetTaskFromStoreResult::task).block()).isNull(); + assertThat(shortTtlStore.getTaskResult(taskId, null).block()).isNull(); + }); + } + finally { + shortTtlStore.shutdown().block(); + } + } + + @Test + void testTtlExpirationCleansUpRelatedData() { + // Create a store with very short default TTL (10ms) + var shortTtlStore = new InMemoryTaskStore(10, 1000, null, 100); + try { + // Create a task + var task = shortTtlStore.createTask(createDefaultOptions()).block(); + String taskId = task.taskId(); + + // Request cancellation (adds to cancellationRequests set) + shortTtlStore.requestCancellation(taskId, null).block(); + + // Verify cancellation request exists + assertThat(shortTtlStore.isCancellationRequested(taskId, null).block()).isTrue(); + + // Wait for TTL to expire and verify cleanup removes all related data + await().atMost(Duration.ofMillis(200)).pollInterval(Duration.ofMillis(10)).untilAsserted(() -> { + shortTtlStore.cleanupExpiredTasks(); + // Verify all related data is cleaned up + assertThat(shortTtlStore.getTask(taskId, null).map(GetTaskFromStoreResult::task).block()).isNull(); + assertThat(shortTtlStore.isCancellationRequested(taskId, null).block()).isFalse(); + }); + } + finally { + shortTtlStore.shutdown().block(); + } + } + + @Test + void testNullTtlMeansUnlimitedLifetime() throws InterruptedException { + // The default store has a 60s TTL, but we can test with custom options + // Create task with explicit null TTL via CreateTaskOptions with requestedTtl + var task = taskStore + .createTask(CreateTaskOptions.builder(createTestRequest("test-tool")).requestedTtl(null).build()) + .block(); + + // The task should use default TTL (not null), so this test verifies the flow + // works + assertThat(task).isNotNull(); + assertThat(task.ttl()).isNotNull(); // Should have default TTL applied + } + + // ------------------------------------------ + // DefaultTaskContext Unit Tests + // ------------------------------------------ + + @Test + void testDefaultTaskContextGetTaskId() { + var task = taskStore.createTask(createDefaultOptions()).block(); + DefaultTaskContext context = new DefaultTaskContext<>(task.taskId(), null, + taskStore); + + assertThat(context.getTaskId()).isEqualTo(task.taskId()); + } + + @Test + void testDefaultTaskContextGetTask() { + var task = taskStore.createTask(createDefaultOptions()).block(); + DefaultTaskContext context = new DefaultTaskContext<>(task.taskId(), null, + taskStore); + + StepVerifier.create(context.getTask()).consumeNextWith(fetchedTask -> { + assertThat(fetchedTask.taskId()).isEqualTo(task.taskId()); + assertThat(fetchedTask.status()).isEqualTo(TaskStatus.WORKING); + }).verifyComplete(); + } + + @Test + void testDefaultTaskContextIsCancelled() { + var task = taskStore.createTask(createDefaultOptions()).block(); + DefaultTaskContext context = new DefaultTaskContext<>(task.taskId(), null, + taskStore); + + // Initially not cancelled + StepVerifier.create(context.isCancelled()).expectNext(false).verifyComplete(); + + // Request cancellation via store + taskStore.requestCancellation(task.taskId(), null).block(); + + // Now should be cancelled + StepVerifier.create(context.isCancelled()).expectNext(true).verifyComplete(); + } + + @Test + void testDefaultTaskContextRequestCancellation() { + var task = taskStore.createTask(createDefaultOptions()).block(); + DefaultTaskContext context = new DefaultTaskContext<>(task.taskId(), null, + taskStore); + + // Request cancellation via context + StepVerifier.create(context.requestCancellation()).verifyComplete(); + + // Verify task is cancelled + StepVerifier.create(context.getTask()).consumeNextWith(fetchedTask -> { + assertThat(fetchedTask.status()).isEqualTo(TaskStatus.CANCELLED); + }).verifyComplete(); + } + + @Test + void testDefaultTaskContextUpdateStatus() { + var task = taskStore.createTask(createDefaultOptions()).block(); + DefaultTaskContext context = new DefaultTaskContext<>(task.taskId(), null, + taskStore); + + // Update status with message + StepVerifier.create(context.updateStatus("Processing 50%")).verifyComplete(); + + // Verify status is still WORKING but message is updated + StepVerifier.create(context.getTask()).consumeNextWith(fetchedTask -> { + assertThat(fetchedTask.status()).isEqualTo(TaskStatus.WORKING); + assertThat(fetchedTask.statusMessage()).isEqualTo("Processing 50%"); + }).verifyComplete(); + } + + @Test + void testDefaultTaskContextRequireInput() { + var task = taskStore.createTask(createDefaultOptions()).block(); + DefaultTaskContext context = new DefaultTaskContext<>(task.taskId(), null, + taskStore); + + // Require input + StepVerifier.create(context.requireInput("Need user confirmation")).verifyComplete(); + + // Verify status is INPUT_REQUIRED + StepVerifier.create(context.getTask()).consumeNextWith(fetchedTask -> { + assertThat(fetchedTask.status()).isEqualTo(TaskStatus.INPUT_REQUIRED); + assertThat(fetchedTask.statusMessage()).isEqualTo("Need user confirmation"); + }).verifyComplete(); + } + + @Test + void testDefaultTaskContextComplete() { + var task = taskStore.createTask(createDefaultOptions()).block(); + DefaultTaskContext context = new DefaultTaskContext<>(task.taskId(), null, + taskStore); + + // Complete with result + var result = createTestResult("Success!"); + StepVerifier.create(context.complete(result)).verifyComplete(); + + // Verify task is completed + StepVerifier.create(context.getTask()).consumeNextWith(fetchedTask -> { + assertThat(fetchedTask.status()).isEqualTo(TaskStatus.COMPLETED); + }).verifyComplete(); + + // Verify result is stored + StepVerifier.create(taskStore.getTaskResult(task.taskId(), null)).expectNext(result).verifyComplete(); + } + + @Test + void testDefaultTaskContextFail() { + var task = taskStore.createTask(createDefaultOptions()).block(); + DefaultTaskContext context = new DefaultTaskContext<>(task.taskId(), null, + taskStore); + + // Fail with error message + StepVerifier.create(context.fail("Something went wrong")).verifyComplete(); + + // Verify task is failed + StepVerifier.create(context.getTask()).consumeNextWith(fetchedTask -> { + assertThat(fetchedTask.status()).isEqualTo(TaskStatus.FAILED); + assertThat(fetchedTask.statusMessage()).isEqualTo("Something went wrong"); + }).verifyComplete(); + } + + // -------------------------- + // Session Validation Tests + // -------------------------- + + @Test + void testGetTaskReturnsTaskForMatchingSession() { + String sessionId = "session-123"; + var task = taskStore + .createTask(CreateTaskOptions.builder(createTestRequest("test-tool")).sessionId(sessionId).build()) + .block(); + + // getTask with matching session ID should return the result + StepVerifier.create(taskStore.getTask(task.taskId(), sessionId)).consumeNextWith(result -> { + assertThat(result.task().taskId()).isEqualTo(task.taskId()); + assertThat(result.task().status()).isEqualTo(TaskStatus.WORKING); + }).verifyComplete(); + } + + @Test + void testGetTaskReturnsEmptyForMismatchedSession() { + String sessionId = "session-123"; + String differentSessionId = "session-456"; + var task = taskStore + .createTask(CreateTaskOptions.builder(createTestRequest("test-tool")).sessionId(sessionId).build()) + .block(); + + // Request with different session ID should return empty (access denied) + StepVerifier.create(taskStore.getTask(task.taskId(), differentSessionId)).verifyComplete(); + } + + @Test + void testGetTaskReturnsTaskWhenNoSessionRestriction() { + // Create task without session ID + var task = taskStore.createTask(CreateTaskOptions.builder(createTestRequest("test-tool")).build()).block(); + + // Any session should be able to access it + StepVerifier.create(taskStore.getTask(task.taskId(), "any-session")).consumeNextWith(result -> { + assertThat(result.task().taskId()).isEqualTo(task.taskId()); + }).verifyComplete(); + } + + @Test + void testGetTaskReturnsEmptyForNonExistentTask() { + StepVerifier.create(taskStore.getTask("non-existent-task", "some-session")).verifyComplete(); + } + + // ------------------------------------------ + // Race Condition Tests + // ------------------------------------------ + + @Test + void testConcurrentCancellationAndStatusUpdate() throws InterruptedException { + // Create task in WORKING state + var task = taskStore.createTask(createDefaultOptions()).block(); + assertThat(task).isNotNull(); + String taskId = task.taskId(); + + // Race: one thread cancels, one thread tries to complete + runConcurrent(2, 2, i -> { + if (i == 0) { + taskStore.requestCancellation(taskId, null).block(); + } + else { + taskStore.updateTaskStatus(taskId, null, TaskStatus.COMPLETED, "Completed").block(); + } + }); + + // Task should end in exactly one terminal state (either CANCELLED or COMPLETED) + var finalTask = taskStore.getTask(taskId, null).map(GetTaskFromStoreResult::task).block(); + assertThat(finalTask).isNotNull(); + assertThat(finalTask.status()).satisfiesAnyOf(status -> assertThat(status).isEqualTo(TaskStatus.CANCELLED), + status -> assertThat(status).isEqualTo(TaskStatus.COMPLETED)); + } + + @Test + void testConcurrentGetTaskDuringCleanup() throws InterruptedException { + // Create a task store with very short TTL (10ms) and default poll interval + var shortTtlStore = new InMemoryTaskStore(10L, 1000L); + + // Create a task + var task = shortTtlStore + .createTask(CreateTaskOptions.builder(createTestRequest("test-tool")).requestedTtl(10L).build()) + .block(); + assertThat(task).isNotNull(); + String taskId = task.taskId(); + + // Wait for TTL to expire by polling until cleanup would remove the task + await().atMost(Duration.ofMillis(200)).pollInterval(Duration.ofMillis(10)).until(() -> { + shortTtlStore.cleanupExpiredTasks(); + return shortTtlStore.getTask(taskId, null).map(GetTaskFromStoreResult::task).block() == null; + }); + + // Concurrent threads try to get task after cleanup - verifies no race conditions + runConcurrent(10, 10, i -> { + // getTask() should return empty without throwing + shortTtlStore.getTask(taskId, null).map(GetTaskFromStoreResult::task).block(); + }); + + // No exceptions thrown = success + } + + @Test + void testConcurrentSessionOperationsDuringListTasks() throws InterruptedException { + String sessionA = "session-A"; + String sessionB = "session-B"; + + // Create initial tasks for both sessions + for (int i = 0; i < 10; i++) { + taskStore.createTask(CreateTaskOptions.builder(createTestRequest("test-tool")).sessionId(sessionA).build()) + .block(); + taskStore.createTask(CreateTaskOptions.builder(createTestRequest("test-tool")).sessionId(sessionB).build()) + .block(); + } + + // Race: some threads call listTasks, some create new tasks + runConcurrent(50, 20, i -> { + if (i % 3 == 0) { + taskStore.listTasks(null, sessionA).block(); + } + else if (i % 3 == 1) { + taskStore.listTasks(null, sessionB).block(); + } + else { + String session = i % 2 == 0 ? sessionA : sessionB; + taskStore + .createTask(CreateTaskOptions.builder(createTestRequest("test-tool")).sessionId(session).build()) + .block(); + } + }); + + // listTasks should always return consistent snapshot (no partial results or + // exceptions) + var finalListA = taskStore.listTasks(null, sessionA).block(); + var finalListB = taskStore.listTasks(null, sessionB).block(); + assertThat(finalListA).isNotNull(); + assertThat(finalListB).isNotNull(); + // At least the initial 10 tasks should be present for each session + assertThat(finalListA.tasks().size()).isGreaterThanOrEqualTo(10); + assertThat(finalListB.tasks().size()).isGreaterThanOrEqualTo(10); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskAwareAsyncToolSpecificationTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskAwareAsyncToolSpecificationTest.java new file mode 100644 index 000000000..54d424676 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskAwareAsyncToolSpecificationTest.java @@ -0,0 +1,290 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Map; +import java.util.concurrent.ForkJoinPool; + +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetTaskResult; +import io.modelcontextprotocol.spec.McpSchema.JsonSchema; +import io.modelcontextprotocol.spec.McpSchema.Task; +import io.modelcontextprotocol.spec.McpSchema.TaskSupportMode; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.ToolAnnotations; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * Tests for {@link TaskAwareAsyncToolSpecification} and its builder. + */ +class TaskAwareAsyncToolSpecificationTest extends + AbstractTaskAwareToolSpecificationTest { + + @Override + protected TaskAwareAsyncToolSpecification.Builder createBuilder() { + return TaskAwareAsyncToolSpecification.builder(); + } + + @Override + protected TaskAwareAsyncToolSpecification.Builder withMinimalCreateTaskHandler( + TaskAwareAsyncToolSpecification.Builder builder) { + return builder.createTaskHandler((args, extra) -> Mono.just(McpSchema.CreateTaskResult.builder().build())); + } + + @Override + protected TaskAwareAsyncToolSpecification build(TaskAwareAsyncToolSpecification.Builder builder) { + return builder.build(); + } + + @Override + protected Tool getTool(TaskAwareAsyncToolSpecification spec) { + return spec.tool(); + } + + @Override + protected JsonSchema getEmptyInputSchema() { + return TaskDefaults.EMPTY_INPUT_SCHEMA; + } + + // ------------------------------------------ + // Async-Specific Builder Tests + // ------------------------------------------ + + @Test + void builderShouldAllowMethodChaining() { + TaskAwareAsyncToolSpecification.Builder builder = TaskAwareAsyncToolSpecification.builder(); + + // Verify method chaining returns the same builder instance + assertThat(builder.name("test")).isSameAs(builder); + assertThat(builder.description("desc")).isSameAs(builder); + assertThat(builder.inputSchema(TEST_SCHEMA)).isSameAs(builder); + assertThat(builder.taskSupportMode(TaskSupportMode.OPTIONAL)).isSameAs(builder); + assertThat(builder.annotations(new ToolAnnotations(null, null, null, null, null, null))).isSameAs(builder); + assertThat(builder.createTaskHandler((args, extra) -> Mono.empty())).isSameAs(builder); + assertThat(builder.getTaskHandler((exchange, request) -> Mono.empty())).isSameAs(builder); + assertThat(builder.getTaskResultHandler((exchange, request) -> Mono.empty())).isSameAs(builder); + } + + @Test + void builderShouldCreateSpecificationWithAllOptionalHandlers() { + GetTaskHandler getTaskHandler = (exchange, + request) -> Mono.just(GetTaskResult.builder() + .taskId("task-123") + .status(McpSchema.TaskStatus.COMPLETED) + .statusMessage("done") + .createdAt("now") + .lastUpdatedAt("now") + .build()); + + GetTaskResultHandler getTaskResultHandler = (exchange, request) -> Mono + .just(CallToolResult.builder().addTextContent("custom result").build()); + + TaskAwareAsyncToolSpecification spec = TaskAwareAsyncToolSpecification.builder() + .name("full-tool") + .description("A tool with all handlers") + .createTaskHandler((args, extra) -> Mono.just(McpSchema.CreateTaskResult.builder().build())) + .getTaskHandler(getTaskHandler) + .getTaskResultHandler(getTaskResultHandler) + .build(); + + assertThat(spec.getTaskHandler()).isNotNull(); + assertThat(spec.getTaskResultHandler()).isNotNull(); + } + + @Test + void builtSpecificationShouldExecuteCreateTaskHandlerCorrectly() { + String expectedTaskId = "created-task-456"; + + TaskAwareAsyncToolSpecification spec = TaskAwareAsyncToolSpecification.builder() + .name("creator-tool") + .createTaskHandler((args, extra) -> { + Task task = Task.builder() + .taskId(expectedTaskId) + .status(McpSchema.TaskStatus.WORKING) + .statusMessage("Starting...") + .createdAt("now") + .lastUpdatedAt("now") + .build(); + return Mono.just(McpSchema.CreateTaskResult.builder().task(task).build()); + }) + .build(); + + Mono resultMono = spec.createTaskHandler().createTask(Map.of(), null); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.task()).isNotNull(); + assertThat(result.task().taskId()).isEqualTo(expectedTaskId); + assertThat(result.task().status()).isEqualTo(McpSchema.TaskStatus.WORKING); + }).verifyComplete(); + } + + @Test + void callHandlerShouldReturnErrorForNonTaskCalls() { + TaskAwareAsyncToolSpecification spec = TaskAwareAsyncToolSpecification.builder() + .name("task-only-tool") + .createTaskHandler((args, extra) -> Mono.just(McpSchema.CreateTaskResult.builder().build())) + .build(); + + Mono resultMono = spec.callHandler() + .apply(null, new McpSchema.CallToolRequest("task-only-tool", Map.of())); + + StepVerifier.create(resultMono) + .expectErrorMatches(ex -> ex instanceof UnsupportedOperationException + && ex.getMessage().contains("requires task-augmented execution")) + .verify(); + } + + // ------------------------------------------ + // fromSync Conversion Tests + // ------------------------------------------ + + @Test + void fromSyncShouldConvertSyncSpecificationCorrectly() { + String expectedTaskId = "sync-task-789"; + + // Handler that doesn't use 'extra' parameter (avoids null issues in test) + TaskAwareSyncToolSpecification syncSpec = TaskAwareSyncToolSpecification.builder() + .name("sync-task-tool") + .description("A sync task tool") + .taskSupportMode(TaskSupportMode.REQUIRED) + .createTaskHandler((args, extra) -> { + // Note: Not using 'extra' here to avoid NPE in test context + Task task = Task.builder() + .taskId(expectedTaskId) + .status(McpSchema.TaskStatus.WORKING) + .createdAt("now") + .lastUpdatedAt("now") + .build(); + return McpSchema.CreateTaskResult.builder().task(task).build(); + }) + .build(); + + TaskAwareAsyncToolSpecification asyncSpec = TaskAwareAsyncToolSpecification.fromSync(syncSpec, + ForkJoinPool.commonPool()); + + assertThat(asyncSpec).isNotNull(); + assertThat(asyncSpec.tool().name()).isEqualTo("sync-task-tool"); + assertThat(asyncSpec.tool().description()).isEqualTo("A sync task tool"); + assertThat(asyncSpec.tool().execution().taskSupport()).isEqualTo(TaskSupportMode.REQUIRED); + assertThat(asyncSpec.createTaskHandler()).isNotNull(); + assertThat(asyncSpec.callHandler()).isNotNull(); + } + + @Test + void fromSyncShouldConvertOptionalHandlers() { + String customMessage = "Custom handler invoked"; + + TaskAwareSyncToolSpecification syncSpec = TaskAwareSyncToolSpecification.builder() + .name("full-sync-tool") + .createTaskHandler((args, extra) -> McpSchema.CreateTaskResult.builder().build()) + .getTaskHandler((exchange, request) -> GetTaskResult.builder() + .taskId(request.taskId()) + .status(McpSchema.TaskStatus.COMPLETED) + .statusMessage(customMessage) + .createdAt("now") + .lastUpdatedAt("now") + .build()) + .getTaskResultHandler( + (exchange, request) -> CallToolResult.builder().addTextContent("Custom result content").build()) + .build(); + + TaskAwareAsyncToolSpecification asyncSpec = TaskAwareAsyncToolSpecification.fromSync(syncSpec, + ForkJoinPool.commonPool()); + + assertThat(asyncSpec.getTaskHandler()).isNotNull(); + assertThat(asyncSpec.getTaskResultHandler()).isNotNull(); + + // Test getTaskHandler + Mono getTaskMono = asyncSpec.getTaskHandler() + .handle(null, McpSchema.GetTaskRequest.builder().taskId("test-task").build()); + + StepVerifier.create(getTaskMono).assertNext(result -> { + assertThat(result.statusMessage()).isEqualTo(customMessage); + }).verifyComplete(); + + // Test getTaskResultHandler + Mono getResultMono = asyncSpec.getTaskResultHandler() + .handle(null, McpSchema.GetTaskPayloadRequest.builder().taskId("test-task").build()); + + StepVerifier.create(getResultMono).assertNext(result -> { + CallToolResult callResult = (CallToolResult) result; + assertThat(callResult.content()).hasSize(1); + assertThat(((TextContent) callResult.content().get(0)).text()).isEqualTo("Custom result content"); + }).verifyComplete(); + } + + @Test + void fromSyncShouldThrowWhenSyncSpecIsNull() { + assertThatThrownBy(() -> TaskAwareAsyncToolSpecification.fromSync(null, ForkJoinPool.commonPool())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("sync specification must not be null"); + } + + @Test + void fromSyncShouldThrowWhenExecutorIsNull() { + TaskAwareSyncToolSpecification syncSpec = TaskAwareSyncToolSpecification.builder() + .name("test") + .createTaskHandler((args, extra) -> McpSchema.CreateTaskResult.builder().build()) + .build(); + + assertThatThrownBy(() -> TaskAwareAsyncToolSpecification.fromSync(syncSpec, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("executor must not be null"); + } + + // ------------------------------------------ + // Exception Handling Tests + // ------------------------------------------ + + @Test + void createTaskHandlerExceptionShouldPropagate() { + TaskAwareAsyncToolSpecification spec = TaskAwareAsyncToolSpecification.builder() + .name("failing-create-task") + .createTaskHandler((args, extra) -> Mono.error(new RuntimeException("createTask failed"))) + .build(); + + StepVerifier.create(spec.createTaskHandler().createTask(Map.of(), null)) + .expectErrorMatches(e -> e instanceof RuntimeException && e.getMessage().equals("createTask failed")) + .verify(); + } + + @Test + void getTaskHandlerExceptionShouldPropagate() { + TaskAwareAsyncToolSpecification spec = TaskAwareAsyncToolSpecification.builder() + .name("failing-get-task") + .createTaskHandler((args, extra) -> Mono.just(McpSchema.CreateTaskResult.builder().build())) + .getTaskHandler((exchange, request) -> Mono.error(new RuntimeException("getTask failed"))) + .build(); + + assertThat(spec.getTaskHandler()).isNotNull(); + StepVerifier.create(spec.getTaskHandler().handle(null, null)) + .expectErrorMatches(e -> e instanceof RuntimeException && e.getMessage().equals("getTask failed")) + .verify(); + } + + @Test + void getTaskResultHandlerExceptionShouldPropagate() { + TaskAwareAsyncToolSpecification spec = TaskAwareAsyncToolSpecification.builder() + .name("failing-get-result") + .createTaskHandler((args, extra) -> Mono.just(McpSchema.CreateTaskResult.builder().build())) + .getTaskResultHandler((exchange, request) -> Mono.error(new RuntimeException("getTaskResult failed"))) + .build(); + + assertThat(spec.getTaskResultHandler()).isNotNull(); + StepVerifier.create(spec.getTaskResultHandler().handle(null, null)) + .expectErrorMatches(e -> e instanceof RuntimeException && e.getMessage().equals("getTaskResult failed")) + .verify(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskAwareSyncToolSpecificationTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskAwareSyncToolSpecificationTest.java new file mode 100644 index 000000000..dcca0b3df --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskAwareSyncToolSpecificationTest.java @@ -0,0 +1,232 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetTaskResult; +import io.modelcontextprotocol.spec.McpSchema.JsonSchema; +import io.modelcontextprotocol.spec.McpSchema.Task; +import io.modelcontextprotocol.spec.McpSchema.TaskSupportMode; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.ToolAnnotations; + +/** + * Tests for {@link TaskAwareSyncToolSpecification} and its builder. + */ +class TaskAwareSyncToolSpecificationTest extends + AbstractTaskAwareToolSpecificationTest { + + @Override + protected TaskAwareSyncToolSpecification.Builder createBuilder() { + return TaskAwareSyncToolSpecification.builder(); + } + + @Override + protected TaskAwareSyncToolSpecification.Builder withMinimalCreateTaskHandler( + TaskAwareSyncToolSpecification.Builder builder) { + return builder.createTaskHandler((args, extra) -> McpSchema.CreateTaskResult.builder().build()); + } + + @Override + protected TaskAwareSyncToolSpecification build(TaskAwareSyncToolSpecification.Builder builder) { + return builder.build(); + } + + @Override + protected Tool getTool(TaskAwareSyncToolSpecification spec) { + return spec.tool(); + } + + @Override + protected JsonSchema getEmptyInputSchema() { + return TaskDefaults.EMPTY_INPUT_SCHEMA; + } + + // ------------------------------------------ + // Sync-Specific Builder Tests + // ------------------------------------------ + + @Test + void builderShouldAllowMethodChaining() { + TaskAwareSyncToolSpecification.Builder builder = TaskAwareSyncToolSpecification.builder(); + + // Verify method chaining returns the same builder instance + assertThat(builder.name("test")).isSameAs(builder); + assertThat(builder.description("desc")).isSameAs(builder); + assertThat(builder.inputSchema(TEST_SCHEMA)).isSameAs(builder); + assertThat(builder.taskSupportMode(TaskSupportMode.OPTIONAL)).isSameAs(builder); + assertThat(builder.annotations(new ToolAnnotations(null, null, null, null, null, null))).isSameAs(builder); + assertThat(builder.createTaskHandler((args, extra) -> null)).isSameAs(builder); + assertThat(builder.getTaskHandler((exchange, request) -> null)).isSameAs(builder); + assertThat(builder.getTaskResultHandler((exchange, request) -> null)).isSameAs(builder); + } + + @Test + void builderShouldCreateSpecificationWithAllOptionalHandlers() { + SyncGetTaskHandler getTaskHandler = (exchange, request) -> GetTaskResult.builder() + .taskId("task-123") + .status(McpSchema.TaskStatus.COMPLETED) + .statusMessage("done") + .createdAt("now") + .lastUpdatedAt("now") + .build(); + + SyncGetTaskResultHandler getTaskResultHandler = (exchange, + request) -> CallToolResult.builder().addTextContent("custom result").build(); + + TaskAwareSyncToolSpecification spec = TaskAwareSyncToolSpecification.builder() + .name("full-tool") + .description("A tool with all handlers") + .createTaskHandler((args, extra) -> McpSchema.CreateTaskResult.builder().build()) + .getTaskHandler(getTaskHandler) + .getTaskResultHandler(getTaskResultHandler) + .build(); + + assertThat(spec.getTaskHandler()).isNotNull(); + assertThat(spec.getTaskResultHandler()).isNotNull(); + } + + @Test + void builtSpecificationShouldExecuteCreateTaskHandlerCorrectly() { + String expectedTaskId = "created-task-456"; + + TaskAwareSyncToolSpecification spec = TaskAwareSyncToolSpecification.builder() + .name("creator-tool") + .createTaskHandler((args, extra) -> { + Task task = Task.builder() + .taskId(expectedTaskId) + .status(McpSchema.TaskStatus.WORKING) + .statusMessage("Starting...") + .createdAt("now") + .lastUpdatedAt("now") + .build(); + return McpSchema.CreateTaskResult.builder().task(task).build(); + }) + .build(); + + McpSchema.CreateTaskResult result = spec.createTaskHandler().createTask(Map.of(), null); + + assertThat(result).isNotNull(); + assertThat(result.task()).isNotNull(); + assertThat(result.task().taskId()).isEqualTo(expectedTaskId); + assertThat(result.task().status()).isEqualTo(McpSchema.TaskStatus.WORKING); + } + + @Test + void callHandlerShouldThrowForNonTaskCalls() { + TaskAwareSyncToolSpecification spec = TaskAwareSyncToolSpecification.builder() + .name("task-only-tool") + .createTaskHandler((args, extra) -> McpSchema.CreateTaskResult.builder().build()) + .build(); + + assertThatThrownBy( + () -> spec.callHandler().apply(null, new McpSchema.CallToolRequest("task-only-tool", Map.of()))) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("requires task-augmented execution"); + } + + @Test + void getTaskShouldExecuteCorrectly() { + String customMessage = "Custom status from handler"; + + TaskAwareSyncToolSpecification spec = TaskAwareSyncToolSpecification.builder() + .name("custom-get-tool") + .createTaskHandler((args, extra) -> McpSchema.CreateTaskResult.builder().build()) + .getTaskHandler((exchange, request) -> GetTaskResult.builder() + .taskId(request.taskId()) + .status(McpSchema.TaskStatus.WORKING) + .statusMessage(customMessage) + .createdAt("now") + .lastUpdatedAt("now") + .build()) + .build(); + + GetTaskResult result = spec.getTaskHandler() + .handle(null, McpSchema.GetTaskRequest.builder().taskId("test-task-id").build()); + + assertThat(result).isNotNull(); + assertThat(result.taskId()).isEqualTo("test-task-id"); + assertThat(result.statusMessage()).isEqualTo(customMessage); + } + + @Test + void getTaskResultShouldExecuteCorrectly() { + String expectedResult = "Custom result content"; + + TaskAwareSyncToolSpecification spec = TaskAwareSyncToolSpecification.builder() + .name("custom-result-tool") + .createTaskHandler((args, extra) -> McpSchema.CreateTaskResult.builder().build()) + .getTaskResultHandler( + (exchange, request) -> CallToolResult.builder().addTextContent(expectedResult).build()) + .build(); + + McpSchema.Result result = spec.getTaskResultHandler() + .handle(null, McpSchema.GetTaskPayloadRequest.builder().taskId("test-task-id").build()); + + assertThat(result).isNotNull(); + assertThat(result).isInstanceOf(CallToolResult.class); + CallToolResult callResult = (CallToolResult) result; + assertThat(callResult.content()).hasSize(1); + assertThat(((TextContent) callResult.content().get(0)).text()).isEqualTo(expectedResult); + } + + // ------------------------------------------ + // Exception Handling Tests + // ------------------------------------------ + + @Test + void createTaskHandlerExceptionShouldPropagate() { + TaskAwareSyncToolSpecification spec = TaskAwareSyncToolSpecification.builder() + .name("failing-create-task") + .createTaskHandler((args, extra) -> { + throw new RuntimeException("createTask failed"); + }) + .build(); + + assertThatThrownBy(() -> spec.createTaskHandler().createTask(Map.of(), null)) + .isInstanceOf(RuntimeException.class) + .hasMessage("createTask failed"); + } + + @Test + void getTaskHandlerExceptionShouldPropagate() { + TaskAwareSyncToolSpecification spec = TaskAwareSyncToolSpecification.builder() + .name("failing-get-task") + .createTaskHandler((args, extra) -> McpSchema.CreateTaskResult.builder().build()) + .getTaskHandler((exchange, request) -> { + throw new RuntimeException("getTask failed"); + }) + .build(); + + assertThat(spec.getTaskHandler()).isNotNull(); + assertThatThrownBy(() -> spec.getTaskHandler().handle(null, null)).isInstanceOf(RuntimeException.class) + .hasMessage("getTask failed"); + } + + @Test + void getTaskResultHandlerExceptionShouldPropagate() { + TaskAwareSyncToolSpecification spec = TaskAwareSyncToolSpecification.builder() + .name("failing-get-result") + .createTaskHandler((args, extra) -> McpSchema.CreateTaskResult.builder().build()) + .getTaskResultHandler((exchange, request) -> { + throw new RuntimeException("getTaskResult failed"); + }) + .build(); + + assertThat(spec.getTaskResultHandler()).isNotNull(); + assertThatThrownBy(() -> spec.getTaskResultHandler().handle(null, null)).isInstanceOf(RuntimeException.class) + .hasMessage("getTaskResult failed"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskHelperTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskHelperTests.java new file mode 100644 index 000000000..9f2174cbc --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskHelperTests.java @@ -0,0 +1,87 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.stream.Stream; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; + +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; + +/** + * Tests for {@link TaskHelper} utility methods. + */ +class TaskHelperTests { + + // ------------------------------------------ + // isTerminal Tests + // ------------------------------------------ + + static Stream terminalStatusCases() { + return Stream.of(Arguments.of(TaskStatus.COMPLETED, true), Arguments.of(TaskStatus.FAILED, true), + Arguments.of(TaskStatus.CANCELLED, true), Arguments.of(TaskStatus.WORKING, false), + Arguments.of(TaskStatus.INPUT_REQUIRED, false), Arguments.of(null, false)); + } + + @ParameterizedTest(name = "isTerminal({0}) should be {1}") + @MethodSource("terminalStatusCases") + void testIsTerminal(TaskStatus status, boolean expected) { + assertThat(TaskHelper.isTerminal(status)).isEqualTo(expected); + } + + // ------------------------------------------ + // isValidTransition Tests + // ------------------------------------------ + + static Stream validTransitionCases() { + return Stream.of(Arguments.of(TaskStatus.WORKING, TaskStatus.COMPLETED), + Arguments.of(TaskStatus.WORKING, TaskStatus.FAILED), + Arguments.of(TaskStatus.WORKING, TaskStatus.CANCELLED), + Arguments.of(TaskStatus.WORKING, TaskStatus.INPUT_REQUIRED), + Arguments.of(TaskStatus.INPUT_REQUIRED, TaskStatus.WORKING), + Arguments.of(TaskStatus.INPUT_REQUIRED, TaskStatus.COMPLETED)); + } + + @ParameterizedTest(name = "transition from {0} to {1} should be valid") + @MethodSource("validTransitionCases") + void testValidTransitions(TaskStatus from, TaskStatus to) { + assertThat(TaskHelper.isValidTransition(from, to)).isTrue(); + } + + static Stream invalidTransitionCases() { + return Stream.of(Arguments.of(TaskStatus.COMPLETED, TaskStatus.WORKING), + Arguments.of(TaskStatus.FAILED, TaskStatus.WORKING), + Arguments.of(TaskStatus.CANCELLED, TaskStatus.WORKING), Arguments.of(null, TaskStatus.WORKING), + Arguments.of(TaskStatus.WORKING, null)); + } + + @ParameterizedTest(name = "transition from {0} to {1} should be invalid") + @MethodSource("invalidTransitionCases") + void testInvalidTransitions(TaskStatus from, TaskStatus to) { + assertThat(TaskHelper.isValidTransition(from, to)).isFalse(); + } + + // ------------------------------------------ + // getStatusDescription Tests + // ------------------------------------------ + + @ParameterizedTest + @EnumSource(TaskStatus.class) + void testGetStatusDescriptionReturnsNonNull(TaskStatus status) { + assertThat(TaskHelper.getStatusDescription(status)).isNotNull().isNotEmpty(); + } + + @Test + void testGetStatusDescriptionWithNull() { + assertThat(TaskHelper.getStatusDescription(null)).isEqualTo("Unknown"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskTestUtils.java b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskTestUtils.java new file mode 100644 index 000000000..d5ff643da --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/experimental/tasks/TaskTestUtils.java @@ -0,0 +1,126 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.IntConsumer; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.ServerTaskCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TaskSupportMode; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.ToolExecution; + +/** + * Testing utilities for MCP tasks in mcp-core module. + * + *

+ * This is a minimal version of task test utilities containing only what's needed for + * mcp-core unit tests. The full TaskTestUtils with polling helpers and test constants is + * available in the mcp-test module. + * + */ +public final class TaskTestUtils { + + private TaskTestUtils() { + // Utility class - no instantiation + } + + /** + * Default server capabilities with tasks enabled for tests. + */ + public static final ServerCapabilities DEFAULT_TASK_CAPABILITIES = ServerCapabilities.builder() + .tasks(ServerTaskCapabilities.builder().list().cancel().toolsCall().build()) + .tools(true) + .build(); + + /** + * Creates a tool with the given name and task support mode for tests. + * @param name the tool name + * @param title the tool title + * @param mode the task support mode + * @return a Tool with task execution support + */ + public static Tool createTaskTool(String name, String title, TaskSupportMode mode) { + return McpSchema.Tool.builder() + .name(name) + .title(title) + .inputSchema(EMPTY_JSON_SCHEMA) + .execution(ToolExecution.builder().taskSupport(mode).build()) + .build(); + } + + /** + * Creates a test CallToolRequest for use in CreateTaskOptions. This is the standard + * originating request type used in tests. + * @param toolName the name of the tool for the request + * @return a CallToolRequest with the given tool name and null arguments + */ + public static McpSchema.CallToolRequest createTestRequest(String toolName) { + return new McpSchema.CallToolRequest(toolName, null); + } + + /** + * Runs concurrent operations with proper synchronization. + * + *

+ * This method executes the given operation across multiple threads, ensuring all + * threads start at approximately the same time using a latch. This is useful for + * testing thread safety and concurrent access patterns. + * + *

+ * Example usage: + * + *

{@code
+	 * TaskTestUtils.runConcurrent(100, 10, i -> {
+	 *     taskStore.createTask(null).block();
+	 * });
+	 * }
+ * @param numOperations total number of operations to execute + * @param numThreads number of threads in the thread pool + * @param operation the operation to run, receiving the operation index (0 to + * numOperations-1) + * @throws InterruptedException if the current thread is interrupted while waiting + * @throws RuntimeException if operations don't complete within 10 seconds + */ + public static void runConcurrent(int numOperations, int numThreads, IntConsumer operation) + throws InterruptedException { + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch doneLatch = new CountDownLatch(numOperations); + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + try { + for (int i = 0; i < numOperations; i++) { + final int index = i; + executor.submit(() -> { + try { + startLatch.await(); + operation.accept(index); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + finally { + doneLatch.countDown(); + } + }); + } + startLatch.countDown(); + boolean completed = doneLatch.await(10, TimeUnit.SECONDS); + if (!completed) { + throw new RuntimeException("Operations did not complete within 10 seconds"); + } + } + finally { + executor.shutdownNow(); + } + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 090710248..d9d026838 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -6,7 +6,13 @@ import java.time.Duration; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import io.modelcontextprotocol.experimental.tasks.CreateTaskOptions; +import io.modelcontextprotocol.experimental.tasks.InMemoryTaskStore; +import io.modelcontextprotocol.experimental.tasks.TaskAwareAsyncToolSpecification; +import io.modelcontextprotocol.experimental.tasks.TaskStore; +import io.modelcontextprotocol.experimental.tasks.TaskTestUtils; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -15,6 +21,8 @@ import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; +import io.modelcontextprotocol.spec.McpSchema.TaskSupportMode; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; @@ -43,6 +51,8 @@ public abstract class AbstractMcpAsyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; + private static final String TEST_TASK_TOOL_NAME = "task-tool"; + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); protected void onStart() { @@ -719,4 +729,359 @@ void testRootsChangeHandlers() { .doesNotThrowAnyException(); } + // --------------------------------------- + // Tasks Tests + // --------------------------------------- + + /** Creates a server with task capabilities and the given task store. */ + protected McpAsyncServer createTaskServer(TaskStore taskStore) { + return prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(TaskTestUtils.DEFAULT_TASK_CAPABILITIES) + .taskStore(taskStore) + .build(); + } + + /** Creates a server with task capabilities, task store, and a task-aware tool. */ + protected McpAsyncServer createTaskServer(TaskStore taskStore, + TaskAwareAsyncToolSpecification taskTool) { + return prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(TaskTestUtils.DEFAULT_TASK_CAPABILITIES) + .taskStore(taskStore) + .taskTools(taskTool) + .build(); + } + + /** + * Creates a simple task-aware tool that returns a text result. + */ + protected TaskAwareAsyncToolSpecification createSimpleTaskTool(String name, TaskSupportMode mode, + String resultText) { + return TaskAwareAsyncToolSpecification.builder() + .name(name) + .description("Test task tool") + .taskSupportMode(mode) + .createTaskHandler((args, extra) -> extra.createTask().flatMap(task -> { + // Immediately complete the task with the result + CallToolResult result = CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(resultText))) + .build(); + return extra.taskStore() + .storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, result) + .thenReturn(McpSchema.CreateTaskResult.builder().task(task).build()); + })) + .build(); + } + + @Test + void testServerWithTaskStore() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + assertThat(server.getTaskStore()).isSameAs(taskStore); + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreCreateAndGet() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + AtomicReference taskIdRef = new AtomicReference<>(); + StepVerifier.create(taskStore.createTask( + CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).requestedTtl(60000L).build())) + .consumeNextWith(task -> { + assertThat(task.taskId()).isNotNull().isNotEmpty(); + assertThat(task.status()).isEqualTo(TaskStatus.WORKING); + taskIdRef.set(task.taskId()); + }) + .verifyComplete(); + + // Get the task + StepVerifier.create(taskStore.getTask(taskIdRef.get(), null)).consumeNextWith(storeResult -> { + assertThat(storeResult.task().taskId()).isEqualTo(taskIdRef.get()); + assertThat(storeResult.task().status()).isEqualTo(TaskStatus.WORKING); + }).verifyComplete(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreUpdateStatus() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + AtomicReference taskIdRef = new AtomicReference<>(); + StepVerifier + .create(taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build())) + .consumeNextWith(task -> { + taskIdRef.set(task.taskId()); + }) + .verifyComplete(); + + // Update status + StepVerifier.create(taskStore.updateTaskStatus(taskIdRef.get(), null, TaskStatus.WORKING, "Processing...")) + .verifyComplete(); + + // Verify status updated + StepVerifier.create(taskStore.getTask(taskIdRef.get(), null)).consumeNextWith(storeResult -> { + assertThat(storeResult.task().status()).isEqualTo(TaskStatus.WORKING); + assertThat(storeResult.task().statusMessage()).isEqualTo("Processing..."); + }).verifyComplete(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreStoreResult() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + AtomicReference taskIdRef = new AtomicReference<>(); + StepVerifier + .create(taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build())) + .consumeNextWith(task -> { + taskIdRef.set(task.taskId()); + }) + .verifyComplete(); + + // Store result + CallToolResult result = CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Done!"))) + .isError(false) + .build(); + + StepVerifier.create(taskStore.storeTaskResult(taskIdRef.get(), null, TaskStatus.COMPLETED, result)) + .verifyComplete(); + + // Verify task is completed + StepVerifier.create(taskStore.getTask(taskIdRef.get(), null)).consumeNextWith(storeResult -> { + assertThat(storeResult.task().status()).isEqualTo(TaskStatus.COMPLETED); + }).verifyComplete(); + + // Verify result can be retrieved + StepVerifier.create(taskStore.getTaskResult(taskIdRef.get(), null)).consumeNextWith(retrievedResult -> { + assertThat(retrievedResult).isInstanceOf(CallToolResult.class); + }).verifyComplete(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreListTasks() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a few tasks + StepVerifier + .create(taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build())) + .expectNextCount(1) + .verifyComplete(); + StepVerifier + .create(taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build())) + .expectNextCount(1) + .verifyComplete(); + + // List tasks + StepVerifier.create(taskStore.listTasks(null, null)).consumeNextWith(result -> { + assertThat(result.tasks()).isNotNull(); + assertThat(result.tasks()).hasSizeGreaterThanOrEqualTo(2); + }).verifyComplete(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreRequestCancellation() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + AtomicReference taskIdRef = new AtomicReference<>(); + StepVerifier + .create(taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build())) + .consumeNextWith(task -> { + taskIdRef.set(task.taskId()); + }) + .verifyComplete(); + + // Request cancellation + StepVerifier.create(taskStore.requestCancellation(taskIdRef.get(), null)).consumeNextWith(task -> { + assertThat(task.taskId()).isEqualTo(taskIdRef.get()); + }).verifyComplete(); + + // Verify cancellation was requested + StepVerifier.create(taskStore.isCancellationRequested(taskIdRef.get(), null)).consumeNextWith(isCancelled -> { + assertThat(isCancelled).isTrue(); + }).verifyComplete(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testToolWithTaskSupportRequired() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var tool = createSimpleTaskTool(TEST_TASK_TOOL_NAME, TaskSupportMode.REQUIRED, "Task completed!"); + var server = createTaskServer(taskStore, tool); + + assertThat(server.getTaskStore()).isSameAs(taskStore); + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testToolWithTaskSupportOptional() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var tool = createSimpleTaskTool(TEST_TASK_TOOL_NAME, TaskSupportMode.OPTIONAL, "Done"); + var server = createTaskServer(taskStore, tool); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testTerminalStateCannotTransition() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create and complete a task + AtomicReference taskIdRef = new AtomicReference<>(); + StepVerifier + .create(taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build())) + .consumeNextWith(task -> { + taskIdRef.set(task.taskId()); + }) + .verifyComplete(); + + // Complete the task + CallToolResult result = CallToolResult.builder().content(List.of()).isError(false).build(); + StepVerifier.create(taskStore.storeTaskResult(taskIdRef.get(), null, TaskStatus.COMPLETED, result)) + .verifyComplete(); + + // Trying to update status should fail or be ignored (implementation-dependent) + // The InMemoryTaskStore silently ignores invalid transitions + StepVerifier.create(taskStore.updateTaskStatus(taskIdRef.get(), null, TaskStatus.WORKING, "Should not work")) + .verifyComplete(); + + // Status should still be COMPLETED + StepVerifier.create(taskStore.getTask(taskIdRef.get(), null)).consumeNextWith(storeResult -> { + assertThat(storeResult.task().status()).isEqualTo(TaskStatus.COMPLETED); + }).verifyComplete(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // CreateTaskHandler Tests + // --------------------------------------- + + @Test + void testToolWithCreateTaskHandler() { + // Test that a tool with createTaskHandler can be registered + TaskStore taskStore = new InMemoryTaskStore<>(); + + // Create a tool with createTaskHandler + var tool = TaskAwareAsyncToolSpecification.builder() + .name("create-task-handler-tool") + .description("A tool using createTaskHandler") + .createTaskHandler((args, extra) -> extra.createTask(opts -> opts.requestedTtl(60000L).pollInterval(1000L)) + .flatMap(task -> { + // Store result immediately for this test + CallToolResult result = CallToolResult.builder() + .addTextContent("Created via createTaskHandler") + .isError(false) + .build(); + return extra.taskStore() + .storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, result) + .thenReturn(McpSchema.CreateTaskResult.builder().task(task).build()); + })) + .build(); + + var server = createTaskServer(taskStore, tool); + + assertThat(server.getTaskStore()).isSameAs(taskStore); + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testToolWithAllThreeHandlers() { + // Test that a tool with all three handlers can be registered + TaskStore taskStore = new InMemoryTaskStore<>(); + + var tool = TaskAwareAsyncToolSpecification.builder() + .name("three-handler-tool") + .description("A tool with createTask, getTask, and getTaskResult handlers") + .createTaskHandler((args, extra) -> extra.createTask() + .map(task -> McpSchema.CreateTaskResult.builder().task(task).build())) + .getTaskHandler((exchange, request) -> { + // Custom getTask handler + return Mono.just(McpSchema.GetTaskResult.builder() + .taskId(request.taskId()) + .status(TaskStatus.WORKING) + .statusMessage("Custom status from handler") + .build()); + }) + .getTaskResultHandler((exchange, request) -> { + // Custom getTaskResult handler + return Mono.just(CallToolResult.builder().addTextContent("Custom result from handler").build()); + }) + .build(); + + var server = createTaskServer(taskStore, tool); + + // Verify all handlers are set + assertThat(tool.createTaskHandler()).isNotNull(); + assertThat(tool.getTaskHandler()).isNotNull(); + assertThat(tool.getTaskResultHandler()).isNotNull(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void builderShouldThrowWhenNormalToolAndTaskToolShareSameName() { + String duplicateName = "duplicate-tool-name"; + + Tool normalTool = McpSchema.Tool.builder() + .name(duplicateName) + .title("A normal tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> { + prepareAsyncServerBuilder() + .tool(normalTool, + (exchange, args) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .taskTools(TaskAwareAsyncToolSpecification.builder() + .name(duplicateName) + .description("A task tool") + .createTaskHandler((args, extra) -> Mono.just(McpSchema.CreateTaskResult.builder().build())) + .build()) + .build(); + }).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("already registered") + .hasMessageContaining(duplicateName); + } + + @Test + void builderShouldThrowWhenTaskToolsRegisteredWithoutTaskStore() { + assertThatThrownBy(() -> { + prepareAsyncServerBuilder() + .taskTools(TaskAwareAsyncToolSpecification.builder() + .name("task-tool-without-store") + .description("A task tool that needs a TaskStore") + .createTaskHandler((args, extra) -> Mono.just(McpSchema.CreateTaskResult.builder().build())) + .build()) + // Note: NOT setting .taskStore() + .build(); + }).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Task-aware tools registered but no TaskStore configured"); + } + } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java index 1f5387f37..13e085bf8 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java @@ -9,6 +9,7 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -22,7 +23,14 @@ import java.util.stream.Collectors; import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.experimental.tasks.InMemoryTaskMessageQueue; +import io.modelcontextprotocol.experimental.tasks.InMemoryTaskStore; +import io.modelcontextprotocol.experimental.tasks.TaskAwareAsyncToolSpecification; +import io.modelcontextprotocol.experimental.tasks.TaskMessageQueue; +import io.modelcontextprotocol.experimental.tasks.TaskStore; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -40,7 +48,15 @@ import io.modelcontextprotocol.spec.McpSchema.PromptReference; import io.modelcontextprotocol.spec.McpSchema.Role; import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ResponseMessage; +import io.modelcontextprotocol.spec.McpSchema.ResultMessage; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.ServerTaskCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TaskCreatedMessage; +import io.modelcontextprotocol.spec.McpSchema.TaskMetadata; +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; +import io.modelcontextprotocol.spec.McpSchema.TaskStatusMessage; +import io.modelcontextprotocol.spec.McpSchema.TaskSupportMode; import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.util.Utils; @@ -56,10 +72,12 @@ import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertWith; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; +// KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpClientServerIntegrationTests { protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); @@ -849,6 +867,49 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { } } + /** + * Verifies that calling a normal (non-task-aware) tool with task metadata returns an + * error instead of silently stripping the task metadata. + */ + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testNormalToolRejectsTaskMetadata(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + McpSyncServer mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("normal-tool") + .description("A normal tool that does not support tasks") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + return CallToolResult.builder().addTextContent("This should not be reached").build(); + }) + .build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + mcpClient.initialize(); + + // Call normal tool WITH task metadata - should return error, not silently + // strip + var requestWithTaskMetadata = new McpSchema.CallToolRequest("normal-tool", Map.of(), + McpSchema.TaskMetadata.builder().ttl(Duration.ofMillis(60000L)).build(), null); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> mcpClient.callTool(requestWithTaskMetadata)) + .satisfies(error -> { + assertThat(error.getJsonRpcError().code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND); + assertThat(error.getMessage()).contains("normal-tool"); + assertThat(error.getMessage()).contains("does not support task-augmented requests"); + }); + } + finally { + mcpServer.closeGracefully(); + } + } + @ParameterizedTest(name = "{0} : {displayName} ") @MethodSource("clientsForTesting") void testToolCallSuccessWithTransportContextExtraction(String clientType) { @@ -1753,4 +1814,812 @@ private double evaluateExpression(String expression) { }; } + // =================================================================== + // Task Lifecycle Integration Tests + // =================================================================== + + /** Default server capabilities with tasks enabled for task lifecycle tests. */ + protected static final ServerCapabilities TASK_SERVER_CAPABILITIES = ServerCapabilities.builder() + .tasks(ServerTaskCapabilities.builder().list().cancel().toolsCall().build()) + .tools(true) + .build(); + + /** Default client capabilities with tasks enabled for task lifecycle tests. */ + protected static final ClientCapabilities TASK_CLIENT_CAPABILITIES = ClientCapabilities.builder() + .tasks(ClientCapabilities.ClientTaskCapabilities.builder().list().cancel().build()) + .build(); + + /** Default task metadata for test calls. */ + protected static final TaskMetadata DEFAULT_TASK_METADATA = TaskMetadata.builder() + .ttl(Duration.ofMillis(60000L)) + .build(); + + /** Default request timeout for task test clients. */ + protected static final Duration TASK_REQUEST_TIMEOUT = Duration.ofSeconds(30); + + /** Creates a server with task capabilities and the given task-aware tools. */ + protected McpAsyncServer createTaskServer(TaskStore taskStore, + TaskAwareAsyncToolSpecification... taskTools) { + return createTaskServer(taskStore, null, taskTools); + } + + /** Creates a server with task capabilities, message queue, and task-aware tools. */ + protected McpAsyncServer createTaskServer(TaskStore taskStore, + TaskMessageQueue messageQueue, TaskAwareAsyncToolSpecification... taskTools) { + var builder = prepareAsyncServerBuilder().serverInfo("task-test-server", "1.0.0") + .capabilities(TASK_SERVER_CAPABILITIES) + .taskStore(taskStore); + + if (messageQueue != null) { + builder.taskMessageQueue(messageQueue); + } + + if (taskTools != null && taskTools.length > 0) { + builder.taskTools(taskTools); + } + + return builder.build(); + } + + /** Creates a client with task capabilities. */ + protected McpSyncClient createTaskClient(String clientType, String name) { + return createTaskClient(clientType, name, TASK_CLIENT_CAPABILITIES, null); + } + + /** Creates a client with custom capabilities and optional elicitation handler. */ + protected McpSyncClient createTaskClient(String clientType, String name, ClientCapabilities capabilities, + Function elicitationHandler) { + var builder = clientBuilders.get(clientType) + .clientInfo(new McpSchema.Implementation(name, "0.0.0")) + .capabilities(capabilities) + .requestTimeout(TASK_REQUEST_TIMEOUT); + + if (elicitationHandler != null) { + builder.elicitation(elicitationHandler); + } + + return builder.build(); + } + + /** Extracts the task ID from a list of response messages. */ + protected String extractTaskId(List> messages) { + for (var msg : messages) { + if (msg instanceof TaskCreatedMessage tcm) { + return tcm.task().taskId(); + } + } + return null; + } + + /** Extracts all task statuses from a list of response messages. */ + protected List extractTaskStatuses(List> messages) { + List statuses = new ArrayList<>(); + for (var msg : messages) { + if (msg instanceof TaskCreatedMessage tcm) { + statuses.add(tcm.task().status()); + } + else if (msg instanceof TaskStatusMessage tsm) { + statuses.add(tsm.task().status()); + } + } + return statuses; + } + + /** + * Asserts that task status transitions are valid (no transitions from terminal + * states). + */ + protected void assertValidStateTransitions(List statuses) { + TaskStatus previousState = null; + for (TaskStatus state : statuses) { + if (previousState != null) { + assertThat(previousState).as("Terminal states cannot transition to other states") + .isNotIn(TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED); + } + previousState = state; + } + } + + // ===== List Tasks Tests ===== + + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testListTasks(String clientType) { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + try (var client = createTaskClient(clientType, "Task Test Client")) { + client.initialize(); + + var result = client.listTasks(); + assertThat(result).isNotNull(); + assertThat(result.tasks()).isNotNull(); + } + finally { + server.closeGracefully().block(); + } + } + + // ===== INPUT_REQUIRED and Elicitation Flow Tests ===== + + /** + * Test: Elicitation during task execution. + * + *

+ * This test demonstrates the elicitation flow during task execution: + *

    + *
  1. Client calls task-augmented tool + *
  2. Tool creates task in WORKING state + *
  3. Tool needs user input → sends elicitation request + *
  4. Client responds to elicitation + *
  5. Task continues → COMPLETED + *
+ */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testElicitationDuringTaskExecution(String clientType) throws InterruptedException { + TaskStore taskStore = new InMemoryTaskStore<>(); + TaskMessageQueue messageQueue = new InMemoryTaskMessageQueue(); + + AtomicReference taskIdRef = new AtomicReference<>(); + AtomicReference elicitationResponse = new AtomicReference<>(); + CountDownLatch elicitationReceivedLatch = new CountDownLatch(1); + + // Tool that needs user input during execution + BiFunction> handler = (exchange, + request) -> { + String taskId = exchange.getCurrentTaskId(); + if (taskId == null) { + return Mono.error(new RuntimeException("Task ID not available")); + } + taskIdRef.set(taskId); + + return exchange.createElicitation(new ElicitRequest("Please provide a number:", null, null, null)) + .doOnNext(result -> { + elicitationResponse.set(result.content() != null && !result.content().isEmpty() + ? result.content().get("value").toString() : "no-response"); + elicitationReceivedLatch.countDown(); + }) + .then(Mono.defer(() -> Mono.just(CallToolResult.builder() + .content(List.of(new TextContent("Got user input: " + elicitationResponse.get()))) + .isError(false) + .build()))); + }; + + var tool = TaskAwareAsyncToolSpecification.builder() + .name("needs-input-tool") + .description("Test tool requiring input") + .taskSupportMode(TaskSupportMode.REQUIRED) + .createTaskHandler((args, extra) -> extra.createTask() + .flatMap(task -> handler + .apply(extra.exchange().withTaskContext(task.taskId()), + new McpSchema.CallToolRequest("needs-input-tool", args, null, null)) + .flatMap(result -> extra.taskStore() + .storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, result) + .thenReturn(task)) + .onErrorResume(error -> extra.taskStore() + .updateTaskStatus(task.taskId(), null, TaskStatus.FAILED, error.getMessage()) + .thenReturn(task))) + .map(t -> McpSchema.CreateTaskResult.builder().task(t).build())) + .build(); + + var server = createTaskServer(taskStore, messageQueue, tool); + + ClientCapabilities elicitationCapabilities = ClientCapabilities.builder() + .elicitation() + .tasks(ClientCapabilities.ClientTaskCapabilities.builder().list().cancel().build()) + .build(); + + try (var client = createTaskClient(clientType, "Elicitation Test Client", elicitationCapabilities, + (elicitRequest) -> new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("value", "42"), null))) { + client.initialize(); + + var request = new McpSchema.CallToolRequest("needs-input-tool", Map.of(), DEFAULT_TASK_METADATA, null); + var messages = client.callToolStream(request); + var observedStates = extractTaskStatuses(messages); + + assertThat(taskIdRef.get()).as("Task ID should have been set").isNotNull(); + + boolean elicitationCompleted = elicitationReceivedLatch.await(10, TimeUnit.SECONDS); + assertThat(elicitationCompleted).as("Elicitation should be received and processed").isTrue(); + + await().atMost(Duration.ofSeconds(15)).untilAsserted(() -> { + var task = client.getTask(McpSchema.GetTaskRequest.builder().taskId(taskIdRef.get()).build()); + assertThat(task.status()).isIn(TaskStatus.COMPLETED, TaskStatus.FAILED); + }); + + assertThat(elicitationResponse.get()).isEqualTo("42"); + assertValidStateTransitions(observedStates); + } + finally { + server.closeGracefully().block(); + } + } + + // ===== Task Capability Negotiation Tests ===== + + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testServerReportsTaskCapabilities(String clientType) { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + try (var client = createTaskClient(clientType, "Task Test Client")) { + var initResult = client.initialize(); + assertThat(initResult).isNotNull(); + assertThat(initResult.capabilities()).isNotNull(); + assertThat(initResult.capabilities().tasks()).isNotNull(); + } + finally { + server.closeGracefully().block(); + } + } + + // ===== Automatic Polling Shim Tests ===== + + /** + * Tests the automatic polling shim: when a tool with createTaskHandler is called + * WITHOUT task metadata, the server should automatically create a task, poll until + * completion, and return the final result directly. + */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testAutomaticPollingShimWithCreateTaskHandler(String clientType) { + TaskStore taskStore = new InMemoryTaskStore<>(); + + // The tool creates a task, stores result immediately, and returns + var tool = TaskAwareAsyncToolSpecification.builder() + .name("auto-polling-tool") + .description("A tool that uses createTaskHandler") + .taskSupportMode(TaskSupportMode.OPTIONAL) + .createTaskHandler((args, extra) -> extra.createTask(opts -> opts.requestedTtl(60000L).pollInterval(100L)) + .flatMap(task -> { + // Immediately store result (simulating fast completion) + CallToolResult result = CallToolResult.builder() + .addTextContent("Result from createTaskHandler: " + args.getOrDefault("input", "default")) + .isError(false) + .build(); + return extra.taskStore() + .storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, result) + .thenReturn(McpSchema.CreateTaskResult.builder().task(task).build()); + })) + .build(); + + var server = createTaskServer(taskStore, tool); + + try (var client = createTaskClient(clientType, "Auto Polling Test Client", ClientCapabilities.builder().build(), + null)) { + client.initialize(); + + // Call tool WITHOUT task metadata - should trigger automatic polling shim + var request = new McpSchema.CallToolRequest("auto-polling-tool", Map.of("input", "test-value"), null, null); + var messages = client.callToolStream(request); + + // The automatic polling shim should poll and return the final result + assertThat(messages).as("Should have response messages").isNotEmpty(); + + // The last message should be a ResultMessage with the final CallToolResult + ResponseMessage lastMsg = messages.get(messages.size() - 1); + assertThat(lastMsg).as("Last message should be ResultMessage").isInstanceOf(ResultMessage.class); + + ResultMessage resultMsg = (ResultMessage) lastMsg; + assertThat(resultMsg.result()).isNotNull(); + assertThat(resultMsg.result().content()).isNotEmpty(); + + // Verify the content came from our createTaskHandler + TextContent textContent = (TextContent) resultMsg.result().content().get(0); + assertThat(textContent.text()).contains("Result from createTaskHandler").contains("test-value"); + } + finally { + server.closeGracefully().block(); + } + } + + /** + * Tests that a tool with createTaskHandler still works correctly when called WITH + * task metadata (the normal task-augmented flow). + */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testCreateTaskHandlerWithTaskMetadata(String clientType) { + TaskStore taskStore = new InMemoryTaskStore<>(); + + // Track if createTaskHandler was invoked + AtomicBoolean createTaskHandlerInvoked = new AtomicBoolean(false); + + var tool = TaskAwareAsyncToolSpecification.builder() + .name("create-task-tool") + .description("A tool that uses createTaskHandler") + .taskSupportMode(TaskSupportMode.OPTIONAL) + .createTaskHandler((args, extra) -> { + createTaskHandlerInvoked.set(true); + + return extra.createTask(opts -> opts.pollInterval(500L)).flatMap(task -> { + // Store result immediately + CallToolResult result = CallToolResult.builder() + .addTextContent("Task created via createTaskHandler!") + .isError(false) + .build(); + return extra.taskStore() + .storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, result) + .thenReturn(McpSchema.CreateTaskResult.builder().task(task).build()); + }); + }) + .build(); + + var server = createTaskServer(taskStore, tool); + + try (var client = createTaskClient(clientType, "CreateTask Test Client")) { + client.initialize(); + + // Call with task metadata - should use createTaskHandler directly + var request = new McpSchema.CallToolRequest("create-task-tool", Map.of(), DEFAULT_TASK_METADATA, null); + var messages = client.callToolStream(request); + + assertThat(createTaskHandlerInvoked.get()).as("createTaskHandler should have been invoked").isTrue(); + assertThat(messages).as("Should have response messages").isNotEmpty(); + + // Should have task creation and result messages + String taskId = extractTaskId(messages); + assertThat(taskId).as("Should have created a task").isNotNull(); + } + finally { + server.closeGracefully().block(); + } + } + + /** + * Demonstrates wrapping an external async API with tasks. + * + *

+ * Tasks are designed for external services that process jobs asynchronously. Status + * checks happen lazily when the client polls - no background threads needed. + */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testExternalAsyncApiPattern(String clientType) throws InterruptedException { + TaskStore taskStore = new InMemoryTaskStore<>(); + + // Simulates an external async API (e.g., a batch processing service) + var externalApi = new SimulatedExternalAsyncApi(); + + var tool = TaskAwareAsyncToolSpecification.builder() + .name("external-job") + .description("Submits work to an external async API") + .inputSchema(EMPTY_JSON_SCHEMA) + .taskSupportMode(TaskSupportMode.OPTIONAL) + .createTaskHandler((args, extra) -> { + // Submit job to external API and use its ID as the MCP task ID + String externalJobId = externalApi.submitJob((String) args.get("input")); + return extra.createTask(opts -> opts.taskId(externalJobId)) + .map(task -> McpSchema.CreateTaskResult.builder().task(task).build()); + }) + .getTaskHandler((exchange, request) -> { + // request.taskId() IS the external job ID - no mapping needed! + SimulatedExternalAsyncApi.JobStatus status = externalApi.checkStatus(request.taskId()); + TaskStatus mcpStatus = switch (status) { + case PENDING, RUNNING -> TaskStatus.WORKING; + case COMPLETED -> TaskStatus.COMPLETED; + case FAILED -> TaskStatus.FAILED; + }; + + // Get timestamps from the TaskStore + return taskStore.getTask(request.taskId(), null) + .map(storeResult -> McpSchema.GetTaskResult.builder() + .taskId(request.taskId()) + .status(mcpStatus) + .statusMessage(status.toString()) + .createdAt(storeResult.task().createdAt()) + .lastUpdatedAt(storeResult.task().lastUpdatedAt()) + .ttl(storeResult.task().ttl()) + .pollInterval(storeResult.task().pollInterval()) + .build()); + }) + .getTaskResultHandler((exchange, request) -> { + // request.taskId() IS the external job ID + String result = externalApi.getResult(request.taskId()); + return Mono.just(CallToolResult.builder().addTextContent(result).build()); + }) + .build(); + + var server = createTaskServer(taskStore, tool); + + try (var client = createTaskClient(clientType, "External API Client")) { + client.initialize(); + + // Submit job via tool call + // Note: TaskMetadata is inlined here (instead of using a constant) for + // copy-paste clarity when using this test as an example + var request = new McpSchema.CallToolRequest("external-job", Map.of("input", "test-data"), + McpSchema.TaskMetadata.builder().ttl(Duration.ofMillis(60000L)).build(), null); + var createResult = client.callToolTask(request); + + assertThat(createResult.task()).isNotNull(); + String taskId = createResult.task().taskId(); + + // Poll until external job completes + await().atMost(Duration.ofSeconds(10)).pollInterval(Duration.ofMillis(100)).untilAsserted(() -> { + var task = client.getTask(taskId); + assertThat(task.status()).isEqualTo(TaskStatus.COMPLETED); + }); + + // Fetch result + var result = client.getTaskResult(taskId, new TypeRef() { + }); + + assertThat(result.content()).hasSize(1); + assertThat(((TextContent) result.content().get(0)).text()).contains("Processed: test-data"); + } + finally { + server.closeGracefully().block(); + } + } + + /** + * Test: Task cancellation workflow example. + * + *

+ * This test demonstrates the correct pattern for task cancellation, including: + *

    + *
  1. Creating a task that takes time to complete
  2. + *
  3. Requesting cancellation while task is running
  4. + *
  5. Verifying the task is in CANCELLED state
  6. + *
  7. Verifying that cancelling a terminal task returns an error
  8. + *
+ * + *

+ * Per the MCP specification, cancellation of tasks in terminal status MUST be + * rejected with error code -32602 (Invalid params). + */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testTaskCancellationWorkflow(String clientType) throws InterruptedException { + TaskStore taskStore = new InMemoryTaskStore<>(); + + var tool = TaskAwareAsyncToolSpecification.builder() + .name("slow-job") + .description("A job that takes a while to complete") + .inputSchema(EMPTY_JSON_SCHEMA) + .taskSupportMode(TaskSupportMode.OPTIONAL) + .createTaskHandler((args, extra) -> extra.createTask() + .map(task -> McpSchema.CreateTaskResult.builder().task(task).build())) + .build(); + + var server = createTaskServer(taskStore, tool); + + try (var client = createTaskClient(clientType, "Cancellation Test Client")) { + client.initialize(); + + // Create a task + var request = new McpSchema.CallToolRequest("slow-job", Map.of(), + McpSchema.TaskMetadata.builder().ttl(Duration.ofMillis(60000L)).build(), null); + var createResult = client.callToolTask(request); + + assertThat(createResult.task()).isNotNull(); + String taskId = createResult.task().taskId(); + + // Verify task is in WORKING state + var task = client.getTask(taskId); + assertThat(task.status()).isEqualTo(TaskStatus.WORKING); + + // Cancel the task + var cancelResult = client.cancelTask(taskId); + assertThat(cancelResult.status()).isEqualTo(TaskStatus.CANCELLED); + + // Verify task state persisted as CANCELLED + task = client.getTask(taskId); + assertThat(task.status()).isEqualTo(TaskStatus.CANCELLED); + + // Attempt to cancel the already-cancelled task - should fail with -32602 + assertThatThrownBy(() -> client.cancelTask(taskId)).isInstanceOf(McpError.class).satisfies(e -> { + McpError error = (McpError) e; + assertThat(error.getJsonRpcError().code()).isEqualTo(McpSchema.ErrorCodes.INVALID_PARAMS); + assertThat(error.getMessage()).contains("terminal"); + }); + } + finally { + server.closeGracefully().block(); + } + } + + /** + * Test: Task failure handling example using SimulatedExternalAsyncApi. + * + *

+ * This test demonstrates how to handle task failures in the external API pattern: + *

    + *
  1. Submit a job that will fail
  2. + *
  3. Poll for task status until it reaches FAILED
  4. + *
  5. Verify the error information is accessible
  6. + *
+ */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testTaskFailureHandling(String clientType) throws InterruptedException { + TaskStore taskStore = new InMemoryTaskStore<>(); + + // Simulates an external async API that will fail + var externalApi = new SimulatedExternalAsyncApi(); + + var tool = TaskAwareAsyncToolSpecification.builder() + .name("failing-job") + .description("Submits work that will fail") + .inputSchema(EMPTY_JSON_SCHEMA) + .taskSupportMode(TaskSupportMode.OPTIONAL) + .createTaskHandler((args, extra) -> { + // Submit job but with a non-existent ID to simulate failure + return extra.createTask(opts -> opts.taskId("non-existent-job")) + .map(task -> McpSchema.CreateTaskResult.builder().task(task).build()); + }) + .getTaskHandler((exchange, request) -> { + // Check status - this job doesn't exist in the API, so it will report + // FAILED + SimulatedExternalAsyncApi.JobStatus status = externalApi.checkStatus(request.taskId()); + TaskStatus mcpStatus = switch (status) { + case PENDING, RUNNING -> TaskStatus.WORKING; + case COMPLETED -> TaskStatus.COMPLETED; + case FAILED -> TaskStatus.FAILED; + }; + + return taskStore.getTask(request.taskId(), null) + .map(storeResult -> McpSchema.GetTaskResult.builder() + .taskId(request.taskId()) + .status(mcpStatus) + .statusMessage(mcpStatus == TaskStatus.FAILED ? "External job failed" : null) + .createdAt(storeResult.task().createdAt()) + .lastUpdatedAt(storeResult.task().lastUpdatedAt()) + .ttl(storeResult.task().ttl()) + .pollInterval(storeResult.task().pollInterval()) + .build()); + }) + .build(); + + var server = createTaskServer(taskStore, tool); + + try (var client = createTaskClient(clientType, "Failure Test Client")) { + client.initialize(); + + // Submit job that will fail + var request = new McpSchema.CallToolRequest("failing-job", Map.of(), + McpSchema.TaskMetadata.builder().ttl(Duration.ofMillis(60000L)).build(), null); + var createResult = client.callToolTask(request); + + assertThat(createResult.task()).isNotNull(); + String taskId = createResult.task().taskId(); + + // The job should immediately report as FAILED since it doesn't exist in the + // API + var task = client.getTask(taskId); + assertThat(task.status()).isEqualTo(TaskStatus.FAILED); + assertThat(task.statusMessage()).contains("failed"); + } + finally { + server.closeGracefully().block(); + } + } + + // ===== Client-Side Task Hosting Tests ===== + + /** + * Test: Client-side task hosting for sampling requests. + * + *

+ * This test verifies that when a server sends a task-augmented sampling request to a + * client that has a taskStore configured, the client correctly: + *

    + *
  1. Creates a task in its local taskStore + *
  2. Returns CreateTaskResult immediately + *
  3. Executes the sampling handler in the background + *
  4. Stores the result when complete + *
+ */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testClientSideTaskHostingForSampling(String clientType) throws InterruptedException { + CountDownLatch samplingHandlerInvoked = new CountDownLatch(1); + AtomicReference receivedPrompt = new AtomicReference<>(); + + // Create a server with a tool that sends task-augmented sampling to client + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("trigger-sampling") + .description("Triggers a task-augmented sampling request to client") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + // Send task-augmented sampling request to client + CreateMessageRequest samplingRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(Role.USER, new TextContent("Test prompt")))) + .systemPrompt("system-prompt") + .maxTokens(100) + .task(TaskMetadata.builder().ttl(Duration.ofMillis(30000L)).build()) + .build(); + + return exchange.createMessageTask(samplingRequest).flatMap(createTaskResult -> { + // Poll for task completion + String taskId = createTaskResult.task().taskId(); + return pollForTaskCompletion(exchange, taskId, new TypeRef() { + }).map(result -> CallToolResult.builder() + .content(List.of(new TextContent("Sampling task completed: " + taskId))) + .build()); + }); + }) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + // Create client with taskStore for hosting tasks + TaskStore clientTaskStore = new InMemoryTaskStore<>(); + var clientBuilder = clientBuilders.get(clientType) + .clientInfo(new McpSchema.Implementation("Task-Hosting Client", "1.0.0")) + .capabilities(ClientCapabilities.builder() + .sampling() + .tasks(ClientCapabilities.ClientTaskCapabilities.builder() + .list() + .cancel() + .samplingCreateMessage() + .build()) + .build()) + .taskStore(clientTaskStore) + .sampling(request -> { + receivedPrompt.set(request.systemPrompt()); + samplingHandlerInvoked.countDown(); + return new CreateMessageResult(Role.ASSISTANT, new TextContent("Test response"), "model-id", + CreateMessageResult.StopReason.END_TURN); + }); + + try (var client = clientBuilder.build()) { + client.initialize(); + + // Trigger the tool which will send task-augmented sampling to client + var result = client.callTool(new McpSchema.CallToolRequest("trigger-sampling", Map.of())); + + // Verify sampling handler was invoked + boolean handlerInvoked = samplingHandlerInvoked.await(10, TimeUnit.SECONDS); + assertThat(handlerInvoked).as("Sampling handler should have been invoked").isTrue(); + assertThat(receivedPrompt.get()).isEqualTo("system-prompt"); + + // Verify the tool completed successfully + assertThat(result.content()).isNotEmpty(); + } + finally { + server.closeGracefully().block(); + } + } + + /** + * Test: Client-side task hosting for elicitation requests. + * + *

+ * Similar to sampling, verifies task-augmented elicitation works correctly when the + * client has a taskStore configured. + */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testClientSideTaskHostingForElicitation(String clientType) throws InterruptedException { + CountDownLatch elicitationHandlerInvoked = new CountDownLatch(1); + AtomicReference receivedMessage = new AtomicReference<>(); + + // Create a server with a tool that sends task-augmented elicitation to client + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("trigger-elicitation") + .description("Triggers a task-augmented elicitation request to client") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + // Send task-augmented elicitation request to client + ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please enter your name:") + .task(TaskMetadata.builder().ttl(Duration.ofMillis(30000L)).build()) + .build(); + + return exchange.createElicitationTask(elicitRequest).flatMap(createTaskResult -> { + // Poll for task completion + String taskId = createTaskResult.task().taskId(); + return pollForTaskCompletion(exchange, taskId, new TypeRef() { + }).map(result -> CallToolResult.builder() + .content(List.of(new TextContent("Elicitation task completed: " + taskId))) + .build()); + }); + }) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + // Create client with taskStore for hosting tasks + TaskStore clientTaskStore = new InMemoryTaskStore<>(); + var clientBuilder = clientBuilders.get(clientType) + .clientInfo(new McpSchema.Implementation("Task-Hosting Client", "1.0.0")) + .capabilities(ClientCapabilities.builder() + .elicitation() + .tasks(ClientCapabilities.ClientTaskCapabilities.builder().list().cancel().elicitationCreate().build()) + .build()) + .taskStore(clientTaskStore) + .elicitation(request -> { + receivedMessage.set(request.message()); + elicitationHandlerInvoked.countDown(); + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("name", "Test User"), null); + }); + + try (var client = clientBuilder.build()) { + client.initialize(); + + // Trigger the tool which will send task-augmented elicitation to client + var result = client.callTool(new McpSchema.CallToolRequest("trigger-elicitation", Map.of())); + + // Verify elicitation handler was invoked + boolean handlerInvoked = elicitationHandlerInvoked.await(10, TimeUnit.SECONDS); + assertThat(handlerInvoked).as("Elicitation handler should have been invoked").isTrue(); + assertThat(receivedMessage.get()).isEqualTo("Please enter your name:"); + + // Verify the tool completed successfully + assertThat(result.content()).isNotEmpty(); + } + finally { + server.closeGracefully().block(); + } + } + + /** + * Helper to poll for task completion on client-hosted tasks. + * @param The expected result type (e.g., CreateMessageResult, ElicitResult) + */ + private Mono pollForTaskCompletion(McpAsyncServerExchange exchange, + String taskId, TypeRef resultTypeRef) { + return Mono.defer(() -> exchange.getTask(McpSchema.GetTaskRequest.builder().taskId(taskId).build())) + .flatMap(task -> { + if (task.status().isTerminal()) { + return exchange.getTaskResult(McpSchema.GetTaskPayloadRequest.builder().taskId(taskId).build(), + resultTypeRef); + } + return Mono.delay(Duration.ofMillis(100)).then(pollForTaskCompletion(exchange, taskId, resultTypeRef)); + }) + .timeout(Duration.ofSeconds(30)); + } + + /** + * Simulates an external async API (e.g., batch processing, ML inference). + */ + static class SimulatedExternalAsyncApi { + + enum JobStatus { + + PENDING, RUNNING, COMPLETED, FAILED + + } + + private final ConcurrentHashMap jobs = new ConcurrentHashMap<>(); + + private record JobState(String input, long completionTime) { + } + + String submitJob(String input) { + String jobId = "job-" + System.nanoTime(); + // Job completes after 300ms + jobs.put(jobId, new JobState(input, System.currentTimeMillis() + 300)); + return jobId; + } + + JobStatus checkStatus(String jobId) { + JobState state = jobs.get(jobId); + if (state == null) { + return JobStatus.FAILED; + } + return System.currentTimeMillis() >= state.completionTime ? JobStatus.COMPLETED : JobStatus.RUNNING; + } + + String getResult(String jobId) { + JobState state = jobs.get(jobId); + return state != null ? "Processed: " + state.input : "Error: job not found"; + } + + } + } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 915c658e3..fed74df59 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -6,6 +6,10 @@ import java.util.List; +import io.modelcontextprotocol.experimental.tasks.CreateTaskOptions; +import io.modelcontextprotocol.experimental.tasks.InMemoryTaskStore; +import io.modelcontextprotocol.experimental.tasks.TaskStore; +import io.modelcontextprotocol.experimental.tasks.TaskTestUtils; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -14,6 +18,8 @@ import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; +import io.modelcontextprotocol.spec.McpSchema.TaskSupportMode; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; @@ -40,6 +46,8 @@ public abstract class AbstractMcpSyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; + private static final String TEST_TASK_TOOL_NAME = "task-tool"; + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); protected void onStart() { @@ -675,4 +683,262 @@ void testRootsChangeHandlers() { assertThatCode(noConsumersServer::closeGracefully).doesNotThrowAnyException(); } + // --------------------------------------- + // Tasks Tests + // --------------------------------------- + + /** Creates a server with task capabilities and the given task store. */ + protected McpSyncServer createTaskServer(TaskStore taskStore) { + return prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(TaskTestUtils.DEFAULT_TASK_CAPABILITIES) + .taskStore(taskStore) + .build(); + } + + @Test + void testServerWithTaskStore() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + assertThat(server.getAsyncServer().getTaskStore()).isSameAs(taskStore); + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreCreateAndGet() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task (blocking) + var task = taskStore.createTask( + CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).requestedTtl(60000L).build()) + .block(); + + assertThat(task).isNotNull(); + assertThat(task.taskId()).isNotNull().isNotEmpty(); + assertThat(task.status()).isEqualTo(TaskStatus.WORKING); + + // Get the task (blocking) + var storeResult = taskStore.getTask(task.taskId(), null).block(); + var retrievedTask = storeResult.task(); + + assertThat(retrievedTask).isNotNull(); + assertThat(retrievedTask.taskId()).isEqualTo(task.taskId()); + assertThat(retrievedTask.status()).isEqualTo(TaskStatus.WORKING); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreUpdateStatus() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + var task = taskStore.createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build()) + .block(); + assertThat(task).isNotNull(); + + // Update status + taskStore.updateTaskStatus(task.taskId(), null, TaskStatus.WORKING, "Processing...").block(); + + // Verify status updated + var updatedTask = taskStore.getTask(task.taskId(), null).block().task(); + assertThat(updatedTask).isNotNull(); + assertThat(updatedTask.status()).isEqualTo(TaskStatus.WORKING); + assertThat(updatedTask.statusMessage()).isEqualTo("Processing..."); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreStoreResult() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + var task = taskStore.createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build()) + .block(); + assertThat(task).isNotNull(); + + // Store result + CallToolResult result = CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Done!"))) + .isError(false) + .build(); + + taskStore.storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, result).block(); + + // Verify task is completed + var completedTask = taskStore.getTask(task.taskId(), null).block().task(); + assertThat(completedTask).isNotNull(); + assertThat(completedTask.status()).isEqualTo(TaskStatus.COMPLETED); + + // Verify result can be retrieved + var retrievedResult = taskStore.getTaskResult(task.taskId(), null).block(); + assertThat(retrievedResult).isNotNull(); + assertThat(retrievedResult).isInstanceOf(CallToolResult.class); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreListTasks() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a few tasks + taskStore.createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build()).block(); + taskStore.createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build()).block(); + + // List tasks + var listResult = taskStore.listTasks(null, null).block(); + assertThat(listResult).isNotNull(); + assertThat(listResult.tasks()).isNotNull(); + assertThat(listResult.tasks()).hasSizeGreaterThanOrEqualTo(2); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreRequestCancellation() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + var task = taskStore.createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build()) + .block(); + assertThat(task).isNotNull(); + + // Request cancellation + var cancelledTask = taskStore.requestCancellation(task.taskId(), null).block(); + assertThat(cancelledTask).isNotNull(); + assertThat(cancelledTask.taskId()).isEqualTo(task.taskId()); + + // Verify cancellation was requested + var isCancelled = taskStore.isCancellationRequested(task.taskId(), null).block(); + assertThat(isCancelled).isTrue(); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testToolWithTaskSupportRequired() { + TaskStore taskStore = new InMemoryTaskStore<>(); + Tool taskTool = TaskTestUtils.createTaskTool(TEST_TASK_TOOL_NAME, "Task-based tool", TaskSupportMode.REQUIRED); + + var server = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(TaskTestUtils.DEFAULT_TASK_CAPABILITIES) + .taskStore(taskStore) + .tool(taskTool, + (exchange, args) -> CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Task completed!"))) + .isError(false) + .build()) + .build(); + + assertThat(server.getAsyncServer().getTaskStore()).isSameAs(taskStore); + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testToolWithTaskSupportOptional() { + TaskStore taskStore = new InMemoryTaskStore<>(); + Tool taskTool = TaskTestUtils.createTaskTool(TEST_TASK_TOOL_NAME, "Optional task tool", + TaskSupportMode.OPTIONAL); + + var server = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(TaskTestUtils.DEFAULT_TASK_CAPABILITIES) + .taskStore(taskStore) + .tool(taskTool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testTerminalStateCannotTransition() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create and complete a task + var task = taskStore.createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build()) + .block(); + assertThat(task).isNotNull(); + + // Complete the task + CallToolResult result = CallToolResult.builder().content(List.of()).isError(false).build(); + taskStore.storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, result).block(); + + // Trying to update status should fail or be ignored (implementation-dependent) + // The InMemoryTaskStore silently ignores invalid transitions + taskStore.updateTaskStatus(task.taskId(), null, TaskStatus.WORKING, "Should not work").block(); + + // Status should still be COMPLETED + var finalTask = taskStore.getTask(task.taskId(), null).block().task(); + assertThat(finalTask).isNotNull(); + assertThat(finalTask.status()).isEqualTo(TaskStatus.COMPLETED); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + /** + * Example: Using sync tool specification for external API pattern. + * + *

+ * This test demonstrates the sync equivalent of the external API pattern shown in + * integration tests. The key differences from the async version are: + *

    + *
  1. Use {@code TaskAwareSyncToolSpecification} instead of async variant
  2. + *
  3. Handlers return values directly instead of {@code Mono}
  4. + *
  5. Task store calls use {@code .block()} for synchronous execution
  6. + *
+ * + *

+ * This example shows how to create a task-aware sync tool that wraps an external + * async API, demonstrating that the pattern works the same way regardless of whether + * you're using the sync or async server API. + */ + @Test + void testSyncExternalApiPatternExample() { + TaskStore taskStore = new InMemoryTaskStore<>(); + + // For this example, we simulate an external API call and manually create a + // task + // This demonstrates the sync tool pattern equivalent to the async + // testExternalAsyncApiPattern + + var server = createTaskServer(taskStore); + + // Step 1: Create a task (simulating what a sync createTask handler would do) + var task = taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("external-job")) + .taskId("external-job-123") + .requestedTtl(60000L) + .build()) + .block(); + + assertThat(task).isNotNull(); + assertThat(task.taskId()).isEqualTo("external-job-123"); + assertThat(task.status()).isEqualTo(TaskStatus.WORKING); + + // Step 2: Simulate external API completing the job and storing result + // storeTaskResult atomically sets the terminal status AND stores the result + CallToolResult result = CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Processed: test-data"))) + .isError(false) + .build(); + taskStore.storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, result).block(); + + // Verify final state + var finalTask = taskStore.getTask(task.taskId(), null).block().task(); + assertThat(finalTask).isNotNull(); + assertThat(finalTask.status()).isEqualTo(TaskStatus.COMPLETED); + + var finalResult = taskStore.getTaskResult(task.taskId(), null).block(); + assertThat(finalResult).isNotNull(); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java index 640d34c9c..7c820924c 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java @@ -24,6 +24,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -695,4 +696,300 @@ void testPingMultipleCalls() { verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } + // --------------------------------------- + // List Tasks Tests + // --------------------------------------- + + @Test + void testListTasksWithNullCapabilities() { + McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, clientInfo); + + StepVerifier.create(exchangeWithNullCapabilities.listTasks("cursor")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + }); + + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_TASKS_LIST), any(), any(TypeRef.class)); + } + + @Test + void testListTasksWithoutTasksCapabilities() { + McpSchema.ClientCapabilities capabilitiesWithoutTasks = McpSchema.ClientCapabilities.builder() + .roots(true) + .build(); + + McpAsyncServerExchange exchangeWithoutTasks = new McpAsyncServerExchange(mockSession, capabilitiesWithoutTasks, + clientInfo); + + StepVerifier.create(exchangeWithoutTasks.listTasks("cursor")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessage("Client must be configured with tasks capabilities"); + }); + + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_TASKS_LIST), any(), any(TypeRef.class)); + } + + @Test + void testListTasksWithoutListCapability() { + McpSchema.ClientCapabilities capabilitiesWithoutList = McpSchema.ClientCapabilities.builder() + .tasks(McpSchema.ClientCapabilities.ClientTaskCapabilities.builder().build()) + .build(); + + McpAsyncServerExchange exchangeWithoutList = new McpAsyncServerExchange(mockSession, capabilitiesWithoutList, + clientInfo); + + StepVerifier.create(exchangeWithoutList.listTasks("cursor")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessage("Client must be configured with tasks.list capability"); + }); + + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_TASKS_LIST), any(), any(TypeRef.class)); + } + + @Test + void testListTasksWithSinglePage() { + McpSchema.ClientCapabilities capabilitiesWithTasks = McpSchema.ClientCapabilities.builder() + .tasks(McpSchema.ClientCapabilities.ClientTaskCapabilities.builder().list().build()) + .build(); + + McpAsyncServerExchange exchangeWithTasks = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithTasks, clientInfo, McpTransportContext.EMPTY); + + List tasks = Arrays.asList( + McpSchema.Task.builder() + .taskId("task-1") + .status(McpSchema.TaskStatus.WORKING) + .createdAt("2024-01-01T00:00:00Z") + .lastUpdatedAt("2024-01-01T00:00:00Z") + .build(), + McpSchema.Task.builder() + .taskId("task-2") + .status(McpSchema.TaskStatus.COMPLETED) + .createdAt("2024-01-01T00:00:00Z") + .lastUpdatedAt("2024-01-01T00:01:00Z") + .build()); + McpSchema.ListTasksResult singlePageResult = McpSchema.ListTasksResult.builder().tasks(tasks).build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_TASKS_LIST), any(McpSchema.PaginatedRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.just(singlePageResult)); + + StepVerifier.create(exchangeWithTasks.listTasks()).assertNext(result -> { + assertThat(result.tasks()).hasSize(2); + assertThat(result.tasks().get(0).taskId()).isEqualTo("task-1"); + assertThat(result.tasks().get(1).taskId()).isEqualTo("task-2"); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.tasks() + .add(McpSchema.Task.builder() + .taskId("test") + .status(McpSchema.TaskStatus.WORKING) + .createdAt("2024-01-01T00:00:00Z") + .lastUpdatedAt("2024-01-01T00:00:00Z") + .build())) + .isInstanceOf(UnsupportedOperationException.class); + }).verifyComplete(); + } + + @Test + void testListTasksWithMultiplePages() { + McpSchema.ClientCapabilities capabilitiesWithTasks = McpSchema.ClientCapabilities.builder() + .tasks(McpSchema.ClientCapabilities.ClientTaskCapabilities.builder().list().build()) + .build(); + + McpAsyncServerExchange exchangeWithTasks = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithTasks, clientInfo, McpTransportContext.EMPTY); + + List page1Tasks = Arrays.asList(McpSchema.Task.builder() + .taskId("task-1") + .status(McpSchema.TaskStatus.WORKING) + .createdAt("2024-01-01T00:00:00Z") + .lastUpdatedAt("2024-01-01T00:00:00Z") + .build()); + List page2Tasks = Arrays.asList(McpSchema.Task.builder() + .taskId("task-2") + .status(McpSchema.TaskStatus.COMPLETED) + .createdAt("2024-01-01T00:00:00Z") + .lastUpdatedAt("2024-01-01T00:01:00Z") + .build()); + + McpSchema.ListTasksResult page1Result = McpSchema.ListTasksResult.builder() + .tasks(page1Tasks) + .nextCursor("cursor1") + .build(); + McpSchema.ListTasksResult page2Result = McpSchema.ListTasksResult.builder().tasks(page2Tasks).build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_TASKS_LIST), eq(new McpSchema.PaginatedRequest(null)), + any(TypeRef.class))) + .thenReturn(Mono.just(page1Result)); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_TASKS_LIST), eq(new McpSchema.PaginatedRequest("cursor1")), + any(TypeRef.class))) + .thenReturn(Mono.just(page2Result)); + + StepVerifier.create(exchangeWithTasks.listTasks()).assertNext(result -> { + assertThat(result.tasks()).hasSize(2); + assertThat(result.tasks().get(0).taskId()).isEqualTo("task-1"); + assertThat(result.tasks().get(1).taskId()).isEqualTo("task-2"); + assertThat(result.nextCursor()).isNull(); + + // Verify that the returned list is unmodifiable + assertThatThrownBy(() -> result.tasks() + .add(McpSchema.Task.builder() + .taskId("test") + .status(McpSchema.TaskStatus.WORKING) + .createdAt("2024-01-01T00:00:00Z") + .lastUpdatedAt("2024-01-01T00:00:00Z") + .build())) + .isInstanceOf(UnsupportedOperationException.class); + }).verifyComplete(); + } + + // --------------------------------------- + // Cancel Task Tests + // --------------------------------------- + + @Test + void testCancelTaskWithNullCapabilities() { + McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange(mockSession, null, clientInfo); + + StepVerifier + .create(exchangeWithNullCapabilities + .cancelTask(McpSchema.CancelTaskRequest.builder().taskId("task-1").build())) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + }); + + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_TASKS_CANCEL), any(), any(TypeRef.class)); + } + + @Test + void testCancelTaskWithoutTasksCapabilities() { + McpSchema.ClientCapabilities capabilitiesWithoutTasks = McpSchema.ClientCapabilities.builder() + .roots(true) + .build(); + + McpAsyncServerExchange exchangeWithoutTasks = new McpAsyncServerExchange(mockSession, capabilitiesWithoutTasks, + clientInfo); + + StepVerifier + .create(exchangeWithoutTasks.cancelTask(McpSchema.CancelTaskRequest.builder().taskId("task-1").build())) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessage("Client must be configured with tasks capabilities"); + }); + + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_TASKS_CANCEL), any(), any(TypeRef.class)); + } + + @Test + void testCancelTaskWithoutCancelCapability() { + McpSchema.ClientCapabilities capabilitiesWithoutCancel = McpSchema.ClientCapabilities.builder() + .tasks(McpSchema.ClientCapabilities.ClientTaskCapabilities.builder().build()) + .build(); + + McpAsyncServerExchange exchangeWithoutCancel = new McpAsyncServerExchange(mockSession, + capabilitiesWithoutCancel, clientInfo); + + StepVerifier + .create(exchangeWithoutCancel.cancelTask(McpSchema.CancelTaskRequest.builder().taskId("task-1").build())) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessage("Client must be configured with tasks.cancel capability"); + }); + + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_TASKS_CANCEL), any(), any(TypeRef.class)); + } + + @Test + void testCancelTaskSuccess() { + McpSchema.ClientCapabilities capabilitiesWithTasks = McpSchema.ClientCapabilities.builder() + .tasks(McpSchema.ClientCapabilities.ClientTaskCapabilities.builder().cancel().build()) + .build(); + + McpAsyncServerExchange exchangeWithTasks = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithTasks, clientInfo, McpTransportContext.EMPTY); + + McpSchema.CancelTaskResult expectedResult = McpSchema.CancelTaskResult.builder() + .taskId("task-1") + .status(McpSchema.TaskStatus.CANCELLED) + .createdAt("2024-01-01T00:00:00Z") + .lastUpdatedAt("2024-01-01T00:01:00Z") + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_TASKS_CANCEL), + argThat((McpSchema.CancelTaskRequest req) -> "task-1".equals(req.taskId())), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier + .create(exchangeWithTasks.cancelTask(McpSchema.CancelTaskRequest.builder().taskId("task-1").build())) + .assertNext(result -> { + assertThat(result.taskId()).isEqualTo("task-1"); + assertThat(result.status()).isEqualTo(McpSchema.TaskStatus.CANCELLED); + }) + .verifyComplete(); + } + + @Test + void testCancelTaskByIdSuccess() { + McpSchema.ClientCapabilities capabilitiesWithTasks = McpSchema.ClientCapabilities.builder() + .tasks(McpSchema.ClientCapabilities.ClientTaskCapabilities.builder().cancel().build()) + .build(); + + McpAsyncServerExchange exchangeWithTasks = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithTasks, clientInfo, McpTransportContext.EMPTY); + + McpSchema.CancelTaskResult expectedResult = McpSchema.CancelTaskResult.builder() + .taskId("task-1") + .status(McpSchema.TaskStatus.CANCELLED) + .createdAt("2024-01-01T00:00:00Z") + .lastUpdatedAt("2024-01-01T00:01:00Z") + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_TASKS_CANCEL), + argThat((McpSchema.CancelTaskRequest req) -> "task-1".equals(req.taskId())), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithTasks.cancelTask("task-1")).assertNext(result -> { + assertThat(result.taskId()).isEqualTo("task-1"); + assertThat(result.status()).isEqualTo(McpSchema.TaskStatus.CANCELLED); + }).verifyComplete(); + } + + @Test + void testCancelTaskByIdWithNullId() { + McpSchema.ClientCapabilities capabilitiesWithTasks = McpSchema.ClientCapabilities.builder() + .tasks(McpSchema.ClientCapabilities.ClientTaskCapabilities.builder().cancel().build()) + .build(); + + McpAsyncServerExchange exchangeWithTasks = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithTasks, clientInfo, McpTransportContext.EMPTY); + + assertThatThrownBy(() -> exchangeWithTasks.cancelTask((String) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Task ID must not be null or empty"); + } + + @Test + void testCancelTaskWithSessionError() { + McpSchema.ClientCapabilities capabilitiesWithTasks = McpSchema.ClientCapabilities.builder() + .tasks(McpSchema.ClientCapabilities.ClientTaskCapabilities.builder().cancel().build()) + .build(); + + McpAsyncServerExchange exchangeWithTasks = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithTasks, clientInfo, McpTransportContext.EMPTY); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_TASKS_CANCEL), any(McpSchema.CancelTaskRequest.class), + any(TypeRef.class))) + .thenReturn(Mono.error(new RuntimeException("Session communication error"))); + + StepVerifier + .create(exchangeWithTasks.cancelTask(McpSchema.CancelTaskRequest.builder().taskId("task-1").build())) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("Session communication error"); + }); + } + } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 0b5ce55cd..e4311af9b 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -317,9 +317,11 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { Optional contentType = response.headers().contentType(); long contentLength = response.headers().contentLength().orElse(-1); // Existing SDKs consume notifications with no response body nor - // content type - if (contentType.isEmpty() || contentLength == 0 - || response.statusCode().equals(HttpStatus.ACCEPTED)) { + // content type. Per the MCP spec, 202 Accepted is used to + // acknowledge + // notifications/responses where no reply is expected. + boolean isAccepted = response.statusCode() == HttpStatus.ACCEPTED; + if (contentType.isEmpty() || contentLength == 0 || isAccepted) { logger.trace("Message was successfully sent via POST for session {}", sessionRepresentation); // signal the caller that the message was successfully diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java index 270bc4308..d332833a2 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2024-2026 the original author or authors. */ package io.modelcontextprotocol; @@ -9,8 +9,10 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; @@ -22,7 +24,16 @@ import java.util.stream.Collectors; import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.experimental.tasks.InMemoryTaskMessageQueue; +import io.modelcontextprotocol.experimental.tasks.InMemoryTaskStore; +import io.modelcontextprotocol.experimental.tasks.TaskAwareAsyncToolSpecification; +import io.modelcontextprotocol.experimental.tasks.TaskMessageQueue; +import io.modelcontextprotocol.experimental.tasks.TaskStore; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.server.McpAsyncServer; +import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.McpSyncServer; @@ -44,7 +55,15 @@ import io.modelcontextprotocol.spec.McpSchema.PromptReference; import io.modelcontextprotocol.spec.McpSchema.Role; import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ResponseMessage; +import io.modelcontextprotocol.spec.McpSchema.ResultMessage; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.ServerTaskCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TaskCreatedMessage; +import io.modelcontextprotocol.spec.McpSchema.TaskMetadata; +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; +import io.modelcontextprotocol.spec.McpSchema.TaskStatusMessage; +import io.modelcontextprotocol.spec.McpSchema.TaskSupportMode; import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.util.Utils; @@ -64,6 +83,7 @@ import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; +// KEEP IN SYNC with the class in mcp-core module public abstract class AbstractMcpClientServerIntegrationTests { protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); @@ -1757,4 +1777,669 @@ private double evaluateExpression(String expression) { }; } + // =================================================================== + // Task Lifecycle Integration Tests + // =================================================================== + + /** Default server capabilities with tasks enabled for task lifecycle tests. */ + protected static final ServerCapabilities TASK_SERVER_CAPABILITIES = ServerCapabilities.builder() + .tasks(ServerTaskCapabilities.builder().list().cancel().toolsCall().build()) + .tools(true) + .build(); + + /** Default client capabilities with tasks enabled for task lifecycle tests. */ + protected static final ClientCapabilities TASK_CLIENT_CAPABILITIES = ClientCapabilities.builder() + .tasks(ClientCapabilities.ClientTaskCapabilities.builder().list().cancel().build()) + .build(); + + /** Default task metadata for test calls. */ + protected static final TaskMetadata DEFAULT_TASK_METADATA = TaskMetadata.builder() + .ttl(Duration.ofMillis(60000L)) + .build(); + + /** Default request timeout for task test clients. */ + protected static final Duration TASK_REQUEST_TIMEOUT = Duration.ofSeconds(30); + + /** Creates a server with task capabilities and the given tools. */ + protected McpAsyncServer createTaskServer(TaskStore taskStore, + TaskAwareAsyncToolSpecification... taskTools) { + return createTaskServer(taskStore, null, taskTools); + } + + /** Creates a server with task capabilities, message queue, and task-aware tools. */ + protected McpAsyncServer createTaskServer(TaskStore taskStore, + TaskMessageQueue messageQueue, TaskAwareAsyncToolSpecification... taskTools) { + var builder = prepareAsyncServerBuilder().serverInfo("task-test-server", "1.0.0") + .capabilities(TASK_SERVER_CAPABILITIES) + .taskStore(taskStore); + + if (messageQueue != null) { + builder.taskMessageQueue(messageQueue); + } + + if (taskTools != null && taskTools.length > 0) { + builder.taskTools(taskTools); + } + + return builder.build(); + } + + /** Creates a client with task capabilities. */ + protected McpSyncClient createTaskClient(String clientType, String name) { + return createTaskClient(clientType, name, TASK_CLIENT_CAPABILITIES, null); + } + + /** Creates a client with custom capabilities and optional elicitation handler. */ + protected McpSyncClient createTaskClient(String clientType, String name, ClientCapabilities capabilities, + Function elicitationHandler) { + var builder = clientBuilders.get(clientType) + .clientInfo(new McpSchema.Implementation(name, "0.0.0")) + .capabilities(capabilities) + .requestTimeout(TASK_REQUEST_TIMEOUT); + + if (elicitationHandler != null) { + builder.elicitation(elicitationHandler); + } + + return builder.build(); + } + + /** Extracts the task ID from a list of response messages. */ + protected String extractTaskId(List> messages) { + for (var msg : messages) { + if (msg instanceof TaskCreatedMessage tcm) { + return tcm.task().taskId(); + } + } + return null; + } + + /** Extracts all task statuses from a list of response messages. */ + protected List extractTaskStatuses(List> messages) { + List statuses = new ArrayList<>(); + for (var msg : messages) { + if (msg instanceof TaskCreatedMessage tcm) { + statuses.add(tcm.task().status()); + } + else if (msg instanceof TaskStatusMessage tsm) { + statuses.add(tsm.task().status()); + } + } + return statuses; + } + + /** + * Asserts that task status transitions are valid (no transitions from terminal + * states). + */ + protected void assertValidStateTransitions(List statuses) { + TaskStatus previousState = null; + for (TaskStatus state : statuses) { + if (previousState != null) { + assertThat(previousState).as("Terminal states cannot transition to other states") + .isNotIn(TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED); + } + previousState = state; + } + } + + // ===== Task-Augmented Tool Call Tests ===== + + /** + * Demonstrates wrapping an external async API with tasks. + * + *

+ * Tasks are designed for external services that process jobs asynchronously. Status + * checks happen lazily when the client polls - no background threads needed. + */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testExternalAsyncApiPattern(String clientType) throws InterruptedException { + TaskStore taskStore = new InMemoryTaskStore<>(); + + // Simulates an external async API (e.g., a batch processing service) + var externalApi = new SimulatedExternalAsyncApi(); + + var tool = TaskAwareAsyncToolSpecification.builder() + .name("external-job") + .description("Submits work to an external async API") + .inputSchema(EMPTY_JSON_SCHEMA) + .taskSupportMode(TaskSupportMode.OPTIONAL) + .createTaskHandler((args, extra) -> { + // Submit job to external API and use its ID as the MCP task ID + String externalJobId = externalApi.submitJob((String) args.get("input")); + return extra.createTask(opts -> opts.taskId(externalJobId)) + .map(task -> McpSchema.CreateTaskResult.builder().task(task).build()); + }) + .getTaskHandler((exchange, request) -> { + // request.taskId() IS the external job ID - no mapping needed! + SimulatedExternalAsyncApi.JobStatus status = externalApi.checkStatus(request.taskId()); + TaskStatus mcpStatus = switch (status) { + case PENDING, RUNNING -> TaskStatus.WORKING; + case COMPLETED -> TaskStatus.COMPLETED; + case FAILED -> TaskStatus.FAILED; + }; + + // Get timestamps from the TaskStore + return taskStore.getTask(request.taskId(), null) + .map(storeResult -> McpSchema.GetTaskResult.builder() + .taskId(request.taskId()) + .status(mcpStatus) + .statusMessage(status.toString()) + .createdAt(storeResult.task().createdAt()) + .lastUpdatedAt(storeResult.task().lastUpdatedAt()) + .ttl(storeResult.task().ttl()) + .pollInterval(storeResult.task().pollInterval()) + .build()); + }) + .getTaskResultHandler((exchange, request) -> { + // request.taskId() IS the external job ID + String result = externalApi.getResult(request.taskId()); + return Mono.just(CallToolResult.builder().addTextContent(result).build()); + }) + .build(); + + var server = createTaskServer(taskStore, tool); + + try (var client = createTaskClient(clientType, "External API Client")) { + client.initialize(); + + // Submit job via tool call + var request = new McpSchema.CallToolRequest("external-job", Map.of("input", "test-data"), + DEFAULT_TASK_METADATA, null); + var createResult = client.callToolTask(request); + + assertThat(createResult.task()).isNotNull(); + String taskId = createResult.task().taskId(); + + // Poll until external job completes + await().atMost(Duration.ofSeconds(10)).pollInterval(Duration.ofMillis(100)).untilAsserted(() -> { + var task = client.getTask(taskId); + assertThat(task.status()).isEqualTo(TaskStatus.COMPLETED); + }); + + // Fetch result + var result = client.getTaskResult(taskId, new TypeRef() { + }); + + assertThat(result.content()).hasSize(1); + assertThat(((TextContent) result.content().get(0)).text()).contains("Processed: test-data"); + } + finally { + server.closeGracefully().block(); + } + } + + /** + * Simulates an external async API (e.g., batch processing, ML inference). + */ + static class SimulatedExternalAsyncApi { + + enum JobStatus { + + PENDING, RUNNING, COMPLETED, FAILED + + } + + private final ConcurrentHashMap jobs = new ConcurrentHashMap<>(); + + private record JobState(String input, long completionTime) { + } + + String submitJob(String input) { + String jobId = "job-" + UUID.randomUUID().toString().substring(0, 8); + jobs.put(jobId, new JobState(input, System.currentTimeMillis() + 300)); + return jobId; + } + + JobStatus checkStatus(String jobId) { + JobState state = jobs.get(jobId); + if (state == null) { + return JobStatus.FAILED; + } + return System.currentTimeMillis() >= state.completionTime ? JobStatus.COMPLETED : JobStatus.RUNNING; + } + + String getResult(String jobId) { + JobState state = jobs.get(jobId); + return state != null ? "Processed: " + state.input : "Error: job not found"; + } + + } + + // ===== List Tasks Tests ===== + + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testListTasks(String clientType) { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + try (var client = createTaskClient(clientType, "Task Test Client")) { + client.initialize(); + + var result = client.listTasks(); + assertThat(result).isNotNull(); + assertThat(result.tasks()).isNotNull(); + } + finally { + server.closeGracefully().block(); + } + } + + // ===== INPUT_REQUIRED and Elicitation Flow Tests ===== + + /** + * Test: Elicitation during task execution. + * + *

+ * This test demonstrates the elicitation flow during task execution: + *

    + *
  1. Client calls task-augmented tool + *
  2. Tool creates task in WORKING state + *
  3. Tool needs user input → sends elicitation request + *
  4. Client responds to elicitation + *
  5. Task continues → COMPLETED + *
+ */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testElicitationDuringTaskExecution(String clientType) throws InterruptedException { + TaskStore taskStore = new InMemoryTaskStore<>(); + TaskMessageQueue messageQueue = new InMemoryTaskMessageQueue(); + + AtomicReference taskIdRef = new AtomicReference<>(); + AtomicReference elicitationResponse = new AtomicReference<>(); + CountDownLatch elicitationReceivedLatch = new CountDownLatch(1); + + // Tool that needs user input during execution + BiFunction> handler = (exchange, + request) -> { + String taskId = exchange.getCurrentTaskId(); + if (taskId == null) { + return Mono.error(new RuntimeException("Task ID not available")); + } + taskIdRef.set(taskId); + + return exchange.createElicitation(new ElicitRequest("Please provide a number:", null, null, null)) + .doOnNext(result -> { + elicitationResponse.set(result.content() != null && !result.content().isEmpty() + ? result.content().get("value").toString() : "no-response"); + elicitationReceivedLatch.countDown(); + }) + .then(Mono.defer(() -> Mono.just(CallToolResult.builder() + .content(List.of(new TextContent("Got user input: " + elicitationResponse.get()))) + .isError(false) + .build()))); + }; + + var tool = TaskAwareAsyncToolSpecification.builder() + .name("needs-input-tool") + .description("Test tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .taskSupportMode(TaskSupportMode.REQUIRED) + .createTaskHandler((args, extra) -> extra.createTask().flatMap(task -> { + McpSchema.CallToolRequest syntheticRequest = new McpSchema.CallToolRequest("needs-input-tool", args, + null, null); + return handler.apply(extra.exchange().withTaskContext(task.taskId()), syntheticRequest) + .flatMap(result -> extra.taskStore() + .storeTaskResult(task.taskId(), extra.sessionId(), TaskStatus.COMPLETED, result) + .thenReturn(task)) + .onErrorResume(error -> extra.taskStore() + .updateTaskStatus(task.taskId(), extra.sessionId(), TaskStatus.FAILED, error.getMessage()) + .thenReturn(task)); + }).map(task -> McpSchema.CreateTaskResult.builder().task(task).build())) + .build(); + + var server = createTaskServer(taskStore, messageQueue, tool); + + ClientCapabilities elicitationCapabilities = ClientCapabilities.builder() + .elicitation() + .tasks(ClientCapabilities.ClientTaskCapabilities.builder().list().cancel().build()) + .build(); + + try (var client = createTaskClient(clientType, "Elicitation Test Client", elicitationCapabilities, + (elicitRequest) -> new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("value", "42"), null))) { + client.initialize(); + + var request = new McpSchema.CallToolRequest("needs-input-tool", Map.of(), DEFAULT_TASK_METADATA, null); + var messages = client.callToolStream(request); + var observedStates = extractTaskStatuses(messages); + + if (taskIdRef.get() != null) { + boolean elicitationCompleted = elicitationReceivedLatch.await(10, TimeUnit.SECONDS); + assertThat(elicitationCompleted).as("Elicitation should be received and processed").isTrue(); + + await().atMost(Duration.ofSeconds(15)).untilAsserted(() -> { + var task = client.getTask(McpSchema.GetTaskRequest.builder().taskId(taskIdRef.get()).build()); + assertThat(task.status()).isIn(TaskStatus.COMPLETED, TaskStatus.FAILED); + }); + + assertThat(elicitationResponse.get()).isEqualTo("42"); + assertValidStateTransitions(observedStates); + } + } + finally { + server.closeGracefully().block(); + } + } + + // ===== Task Capability Negotiation Tests ===== + + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testServerReportsTaskCapabilities(String clientType) { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + try (var client = createTaskClient(clientType, "Task Test Client")) { + var initResult = client.initialize(); + assertThat(initResult).isNotNull(); + assertThat(initResult.capabilities()).isNotNull(); + assertThat(initResult.capabilities().tasks()).isNotNull(); + } + finally { + server.closeGracefully().block(); + } + } + + // ===== Automatic Polling Shim Tests ===== + + /** + * Tests the automatic polling shim: when a tool with createTaskHandler is called + * WITHOUT task metadata, the server should automatically create a task, poll until + * completion, and return the final result directly. + */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testAutomaticPollingShimWithCreateTaskHandler(String clientType) { + TaskStore taskStore = new InMemoryTaskStore<>(); + + // The tool creates a task, stores result immediately, and returns + var tool = TaskAwareAsyncToolSpecification.builder() + .name("auto-polling-tool") + .description("A tool that uses createTaskHandler") + .taskSupportMode(TaskSupportMode.OPTIONAL) + .createTaskHandler((args, extra) -> extra.createTask(opts -> opts.requestedTtl(60000L).pollInterval(100L)) + .flatMap(task -> { + // Immediately store result (simulating fast completion) + CallToolResult result = CallToolResult.builder() + .addTextContent("Result from createTaskHandler: " + args.getOrDefault("input", "default")) + .isError(false) + .build(); + return extra.taskStore() + .storeTaskResult(task.taskId(), extra.sessionId(), TaskStatus.COMPLETED, result) + .thenReturn(McpSchema.CreateTaskResult.builder().task(task).build()); + })) + .build(); + + var server = createTaskServer(taskStore, tool); + + try (var client = createTaskClient(clientType, "Auto Polling Test Client", ClientCapabilities.builder().build(), + null)) { + client.initialize(); + + // Call tool WITHOUT task metadata - should trigger automatic polling shim + var request = new McpSchema.CallToolRequest("auto-polling-tool", Map.of("input", "test-value"), null, null); + var messages = client.callToolStream(request); + + // The automatic polling shim should poll and return the final result + assertThat(messages).as("Should have response messages").isNotEmpty(); + + // The last message should be a ResultMessage with the final CallToolResult + ResponseMessage lastMsg = messages.get(messages.size() - 1); + assertThat(lastMsg).as("Last message should be ResultMessage").isInstanceOf(ResultMessage.class); + + ResultMessage resultMsg = (ResultMessage) lastMsg; + assertThat(resultMsg.result()).isNotNull(); + assertThat(resultMsg.result().content()).isNotEmpty(); + + // Verify the content came from our createTaskHandler + TextContent textContent = (TextContent) resultMsg.result().content().get(0); + assertThat(textContent.text()).contains("Result from createTaskHandler").contains("test-value"); + } + finally { + server.closeGracefully().block(); + } + } + + /** + * Tests that a tool with createTaskHandler still works correctly when called WITH + * task metadata (the normal task-augmented flow). + */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testCreateTaskHandlerWithTaskMetadata(String clientType) { + TaskStore taskStore = new InMemoryTaskStore<>(); + + // Track if createTaskHandler was invoked + AtomicBoolean createTaskHandlerInvoked = new AtomicBoolean(false); + + var tool = TaskAwareAsyncToolSpecification.builder() + .name("create-task-tool") + .description("A tool that uses createTaskHandler") + .taskSupportMode(TaskSupportMode.OPTIONAL) + .createTaskHandler((args, extra) -> { + createTaskHandlerInvoked.set(true); + + return extra.createTask(opts -> opts.pollInterval(500L)).flatMap(task -> { + // Store result immediately + CallToolResult result = CallToolResult.builder() + .addTextContent("Task created via createTaskHandler!") + .isError(false) + .build(); + return extra.taskStore() + .storeTaskResult(task.taskId(), extra.sessionId(), TaskStatus.COMPLETED, result) + .thenReturn(McpSchema.CreateTaskResult.builder().task(task).build()); + }); + }) + .build(); + + var server = createTaskServer(taskStore, tool); + + try (var client = createTaskClient(clientType, "CreateTask Test Client")) { + client.initialize(); + + // Call with task metadata - should use createTaskHandler directly + var request = new McpSchema.CallToolRequest("create-task-tool", Map.of(), DEFAULT_TASK_METADATA, null); + var messages = client.callToolStream(request); + + assertThat(createTaskHandlerInvoked.get()).as("createTaskHandler should have been invoked").isTrue(); + assertThat(messages).as("Should have response messages").isNotEmpty(); + + // Should have task creation and result messages + String taskId = extractTaskId(messages); + assertThat(taskId).as("Should have created a task").isNotNull(); + } + finally { + server.closeGracefully().block(); + } + } + + // ===== Client-Side Task Hosting Tests ===== + + /** + * Test: Client-side task hosting for sampling requests. + * + *

+ * This test verifies that when a server sends a task-augmented sampling request to a + * client that has a taskStore configured, the client correctly: + *

    + *
  1. Creates a task in its local taskStore + *
  2. Returns CreateTaskResult immediately + *
  3. Executes the sampling handler in the background + *
  4. Stores the result when complete + *
+ */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testClientSideTaskHostingForSampling(String clientType) throws InterruptedException { + CountDownLatch samplingHandlerInvoked = new CountDownLatch(1); + AtomicReference receivedPrompt = new AtomicReference<>(); + + // Create a server with a tool that sends task-augmented sampling to client + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("trigger-sampling") + .description("Triggers a task-augmented sampling request to client") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + // Send task-augmented sampling request to client + CreateMessageRequest samplingRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(Role.USER, new TextContent("Test prompt")))) + .systemPrompt("system-prompt") + .maxTokens(100) + .task(TaskMetadata.builder().ttl(Duration.ofMillis(30000L)).build()) + .build(); + + return exchange.createMessageTask(samplingRequest).flatMap(createTaskResult -> { + // Poll for task completion + String taskId = createTaskResult.task().taskId(); + return pollForTaskCompletion(exchange, taskId, new TypeRef() { + }).map(result -> CallToolResult.builder() + .content(List.of(new TextContent("Sampling task completed: " + taskId))) + .build()); + }); + }) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + // Create client with taskStore for hosting tasks + TaskStore clientTaskStore = new InMemoryTaskStore<>(); + var clientBuilder = clientBuilders.get(clientType) + .clientInfo(new McpSchema.Implementation("Task-Hosting Client", "1.0.0")) + .capabilities(ClientCapabilities.builder() + .sampling() + .tasks(ClientCapabilities.ClientTaskCapabilities.builder() + .list() + .cancel() + .samplingCreateMessage() + .build()) + .build()) + .taskStore(clientTaskStore) + .sampling(request -> { + receivedPrompt.set(request.systemPrompt()); + samplingHandlerInvoked.countDown(); + return new CreateMessageResult(Role.ASSISTANT, new TextContent("Test response"), "model-id", + CreateMessageResult.StopReason.END_TURN); + }); + + try (var client = clientBuilder.build()) { + client.initialize(); + + // Trigger the tool which will send task-augmented sampling to client + var result = client.callTool(new McpSchema.CallToolRequest("trigger-sampling", Map.of())); + + // Verify sampling handler was invoked + boolean handlerInvoked = samplingHandlerInvoked.await(10, TimeUnit.SECONDS); + assertThat(handlerInvoked).as("Sampling handler should have been invoked").isTrue(); + assertThat(receivedPrompt.get()).isEqualTo("system-prompt"); + + // Verify the tool completed successfully + assertThat(result.content()).isNotEmpty(); + } + finally { + server.closeGracefully().block(); + } + } + + /** + * Test: Client-side task hosting for elicitation requests. + * + *

+ * Similar to sampling, verifies task-augmented elicitation works correctly when the + * client has a taskStore configured. + */ + @ParameterizedTest(name = "{0} : {displayName}") + @MethodSource("clientsForTesting") + void testClientSideTaskHostingForElicitation(String clientType) throws InterruptedException { + CountDownLatch elicitationHandlerInvoked = new CountDownLatch(1); + AtomicReference receivedMessage = new AtomicReference<>(); + + // Create a server with a tool that sends task-augmented elicitation to client + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("trigger-elicitation") + .description("Triggers a task-augmented elicitation request to client") + .inputSchema(EMPTY_JSON_SCHEMA) + .build()) + .callHandler((exchange, request) -> { + // Send task-augmented elicitation request to client + ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please enter your name:") + .task(TaskMetadata.builder().ttl(Duration.ofMillis(30000L)).build()) + .build(); + + return exchange.createElicitationTask(elicitRequest).flatMap(createTaskResult -> { + // Poll for task completion + String taskId = createTaskResult.task().taskId(); + return pollForTaskCompletion(exchange, taskId, new TypeRef() { + }).map(result -> CallToolResult.builder() + .content(List.of(new TextContent("Elicitation task completed: " + taskId))) + .build()); + }); + }) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + // Create client with taskStore for hosting tasks + TaskStore clientTaskStore = new InMemoryTaskStore<>(); + var clientBuilder = clientBuilders.get(clientType) + .clientInfo(new McpSchema.Implementation("Task-Hosting Client", "1.0.0")) + .capabilities(ClientCapabilities.builder() + .elicitation() + .tasks(ClientCapabilities.ClientTaskCapabilities.builder().list().cancel().elicitationCreate().build()) + .build()) + .taskStore(clientTaskStore) + .elicitation(request -> { + receivedMessage.set(request.message()); + elicitationHandlerInvoked.countDown(); + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("name", "Test User"), null); + }); + + try (var client = clientBuilder.build()) { + client.initialize(); + + // Trigger the tool which will send task-augmented elicitation to client + var result = client.callTool(new McpSchema.CallToolRequest("trigger-elicitation", Map.of())); + + // Verify elicitation handler was invoked + boolean handlerInvoked = elicitationHandlerInvoked.await(10, TimeUnit.SECONDS); + assertThat(handlerInvoked).as("Elicitation handler should have been invoked").isTrue(); + assertThat(receivedMessage.get()).isEqualTo("Please enter your name:"); + + // Verify the tool completed successfully + assertThat(result.content()).isNotEmpty(); + } + finally { + server.closeGracefully().block(); + } + } + + /** + * Helper to poll for task completion on client-hosted tasks. + * @param The expected result type (e.g., CreateMessageResult, ElicitResult) + */ + private Mono pollForTaskCompletion(McpAsyncServerExchange exchange, + String taskId, TypeRef resultTypeRef) { + return Mono.defer(() -> exchange.getTask(McpSchema.GetTaskRequest.builder().taskId(taskId).build())) + .flatMap(task -> { + if (task.status().isTerminal()) { + return exchange.getTaskResult(McpSchema.GetTaskPayloadRequest.builder().taskId(taskId).build(), + resultTypeRef); + } + return Mono.delay(Duration.ofMillis(100)).then(pollForTaskCompletion(exchange, taskId, resultTypeRef)); + }) + .timeout(Duration.ofSeconds(30)); + } + } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskTestUtils.java b/mcp-test/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskTestUtils.java new file mode 100644 index 000000000..5d9b5b6a6 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/experimental/tasks/TaskTestUtils.java @@ -0,0 +1,330 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.experimental.tasks; + +import java.time.Duration; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.function.IntConsumer; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.JsonSchema; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.ServerTaskCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Task; +import io.modelcontextprotocol.spec.McpSchema.TaskMetadata; +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; +import io.modelcontextprotocol.spec.McpSchema.TaskSupportMode; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.ToolExecution; +import reactor.core.publisher.Mono; + +/** + * Testing utilities for MCP tasks. + * + *

+ * This class provides helper methods for testing task-based operations, including polling + * for task status changes and waiting for specific task states. + * + *

+ * This is an experimental API that may change in future releases. + * + */ +public final class TaskTestUtils { + + private TaskTestUtils() { + // Utility class - no instantiation + } + + /** + * Default timeout for waiting for task status changes. + */ + public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(30); + + /** + * Default poll interval for checking task status. + */ + public static final Duration DEFAULT_POLL_INTERVAL = Duration.ofMillis(100); + + /** + * Waits for a task to reach a specific status. + * + *

+ * This method repeatedly calls the provided getTask function at regular intervals + * until the task reaches the desired status or the timeout is exceeded. + * + *

+ * Example usage: + * + *

{@code
+	 * Task completedTask = TaskTestUtils.waitForTaskStatus(
+	 *     taskId -> client.getTask(taskId),
+	 *     "task-123",
+	 *     TaskStatus.COMPLETED
+	 * ).block();
+	 * }
+ * @param getTask function that retrieves a task by ID (returns Mono.empty() if not + * found) + * @param taskId the task identifier to poll + * @param desiredStatus the status to wait for + * @return a Mono emitting the Task once it reaches the desired status + * @throws java.util.concurrent.TimeoutException if the task doesn't reach the desired + * status within the default timeout (30 seconds) + */ + public static Mono waitForTaskStatus(Function> getTask, String taskId, + TaskStatus desiredStatus) { + return waitForTaskStatus(getTask, taskId, desiredStatus, DEFAULT_TIMEOUT, DEFAULT_POLL_INTERVAL); + } + + /** + * Waits for a task to reach a specific status with custom timeout. + * + *

+ * This method repeatedly calls the provided getTask function at regular intervals + * until the task reaches the desired status or the timeout is exceeded. + * @param getTask function that retrieves a task by ID (returns Mono.empty() if not + * found) + * @param taskId the task identifier to poll + * @param desiredStatus the status to wait for + * @param timeout maximum time to wait for the desired status + * @return a Mono emitting the Task once it reaches the desired status + * @throws java.util.concurrent.TimeoutException if the task doesn't reach the desired + * status within the timeout + */ + public static Mono waitForTaskStatus(Function> getTask, String taskId, + TaskStatus desiredStatus, Duration timeout) { + return waitForTaskStatus(getTask, taskId, desiredStatus, timeout, DEFAULT_POLL_INTERVAL); + } + + /** + * Waits for a task to reach a specific status with custom timeout and poll interval. + * + *

+ * This method repeatedly calls the provided getTask function at regular intervals + * until the task reaches the desired status or the timeout is exceeded. + * @param getTask function that retrieves a task by ID (returns Mono.empty() if not + * found) + * @param taskId the task identifier to poll + * @param desiredStatus the status to wait for + * @param timeout maximum time to wait for the desired status + * @param pollInterval interval between status checks + * @return a Mono emitting the Task once it reaches the desired status + * @throws java.util.concurrent.TimeoutException if the task doesn't reach the desired + * status within the timeout + */ + public static Mono waitForTaskStatus(Function> getTask, String taskId, + TaskStatus desiredStatus, Duration timeout, Duration pollInterval) { + return reactor.core.publisher.Flux.interval(pollInterval) + .flatMap(tick -> getTask.apply(taskId)) + .filter(task -> task != null && task.status() == desiredStatus) + .next() + .timeout(timeout); + } + + /** + * Waits for a task to reach any terminal status (COMPLETED, FAILED, or CANCELLED). + * + *

+ * This is useful when you want to wait for a task to finish, regardless of whether it + * succeeds or fails. + * @param getTask function that retrieves a task by ID + * @param taskId the task identifier to poll + * @return a Mono emitting the Task once it reaches a terminal status + * @throws java.util.concurrent.TimeoutException if the task doesn't reach a terminal + * status within the default timeout (30 seconds) + */ + public static Mono waitForTerminal(Function> getTask, String taskId) { + return waitForTerminal(getTask, taskId, DEFAULT_TIMEOUT); + } + + /** + * Waits for a task to reach any terminal status with custom timeout. + * @param getTask function that retrieves a task by ID + * @param taskId the task identifier to poll + * @param timeout maximum time to wait + * @return a Mono emitting the Task once it reaches a terminal status + * @throws java.util.concurrent.TimeoutException if the task doesn't reach a terminal + * status within the timeout + */ + public static Mono waitForTerminal(Function> getTask, String taskId, Duration timeout) { + return reactor.core.publisher.Flux.interval(DEFAULT_POLL_INTERVAL) + .flatMap(tick -> getTask.apply(taskId)) + .filter(task -> task != null && task.isTerminal()) + .next() + .timeout(timeout); + } + + // ------------------------------------------ + // Shared Test Constants + // ------------------------------------------ + + /** + * An empty JSON schema for tools that don't require input parameters. + */ + public static final JsonSchema EMPTY_JSON_SCHEMA = new JsonSchema("object", Collections.emptyMap(), null, null, + null, null); + + /** + * Default task metadata with 60-second TTL for tests. + */ + public static final TaskMetadata DEFAULT_TASK_METADATA = TaskMetadata.builder() + .ttl(Duration.ofMillis(60000L)) + .build(); + + /** + * Default server capabilities with tasks enabled for tests. + */ + public static final ServerCapabilities DEFAULT_TASK_CAPABILITIES = ServerCapabilities.builder() + .tasks(ServerTaskCapabilities.builder().list().cancel().toolsCall().build()) + .tools(true) + .build(); + + /** + * Default tool name for task-supported tools in tests. + */ + public static final String DEFAULT_TASK_TOOL_NAME = "slow-operation"; + + /** + * Default tool arguments for task-supported tools in tests. + */ + public static final Map DEFAULT_TASK_TOOL_ARGS = Map.of("message", "Test message from Java SDK"); + + // ------------------------------------------ + // Shared Test Helpers + // ------------------------------------------ + + /** + * Creates a task-augmented CallToolRequest using the default tool and arguments. + * @return a CallToolRequest with default task metadata + */ + public static CallToolRequest createDefaultTaskRequest() { + return new CallToolRequest(DEFAULT_TASK_TOOL_NAME, DEFAULT_TASK_TOOL_ARGS, DEFAULT_TASK_METADATA, null); + } + + /** + * Creates a task-augmented CallToolRequest with custom tool name and arguments. + * @param toolName the tool name + * @param args the tool arguments + * @return a CallToolRequest with default task metadata + */ + public static CallToolRequest createTaskRequest(String toolName, Map args) { + return new CallToolRequest(toolName, args, DEFAULT_TASK_METADATA, null); + } + + /** + * Creates a tool with the given name and task support mode for tests. + * @param name the tool name + * @param title the tool title + * @param mode the task support mode + * @return a Tool with task execution support + */ + public static Tool createTaskTool(String name, String title, TaskSupportMode mode) { + return McpSchema.Tool.builder() + .name(name) + .title(title) + .inputSchema(EMPTY_JSON_SCHEMA) + .execution(ToolExecution.builder().taskSupport(mode).build()) + .build(); + } + + /** + * Creates a test CallToolRequest for use in CreateTaskOptions. This is the standard + * originating request type used in tests. + * @param toolName the name of the tool for the request + * @return a CallToolRequest with the given tool name and null arguments + */ + public static CallToolRequest createTestRequest(String toolName) { + return new CallToolRequest(toolName, null); + } + + // ------------------------------------------ + // Concurrency Test Helpers + // ------------------------------------------ + + /** + * Runs concurrent operations with proper synchronization. + * + *

+ * This method executes the given operation across multiple threads, ensuring all + * threads start at approximately the same time using a latch. This is useful for + * testing thread safety and concurrent access patterns. + * + *

+ * Example usage: + * + *

{@code
+	 * TaskTestUtils.runConcurrent(100, 10, i -> {
+	 *     taskStore.createTask(null).block();
+	 * });
+	 * }
+ * @param numOperations total number of operations to execute + * @param numThreads number of threads in the thread pool + * @param operation the operation to run, receiving the operation index (0 to + * numOperations-1) + * @param assertionProvider provides the assertion method to use (to avoid test + * dependency in main sources) + * @throws InterruptedException if the current thread is interrupted while waiting + * @throws RuntimeException if operations don't complete within 10 seconds + */ + public static void runConcurrent(int numOperations, int numThreads, IntConsumer operation, + java.util.function.BiConsumer assertionProvider) throws InterruptedException { + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch doneLatch = new CountDownLatch(numOperations); + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + try { + for (int i = 0; i < numOperations; i++) { + final int index = i; + executor.submit(() -> { + try { + startLatch.await(); + operation.accept(index); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + finally { + doneLatch.countDown(); + } + }); + } + startLatch.countDown(); + boolean completed = doneLatch.await(10, TimeUnit.SECONDS); + assertionProvider.accept(completed, "Operations did not complete within 10 seconds"); + } + finally { + executor.shutdownNow(); + } + } + + /** + * Runs concurrent operations with proper synchronization (simplified version that + * throws on timeout). + * + *

+ * This method executes the given operation across multiple threads, ensuring all + * threads start at approximately the same time using a latch. + * @param numOperations total number of operations to execute + * @param numThreads number of threads in the thread pool + * @param operation the operation to run, receiving the operation index (0 to + * numOperations-1) + * @throws InterruptedException if the current thread is interrupted while waiting + * @throws RuntimeException if operations don't complete within 10 seconds + */ + public static void runConcurrent(int numOperations, int numThreads, IntConsumer operation) + throws InterruptedException { + runConcurrent(numOperations, numThreads, operation, (result, message) -> { + if (!result) { + throw new RuntimeException(message); + } + }); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index d6677ec9a..74235f220 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -1,12 +1,18 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2026 the original author or authors. */ package io.modelcontextprotocol.server; import java.time.Duration; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import io.modelcontextprotocol.experimental.tasks.CreateTaskOptions; +import io.modelcontextprotocol.experimental.tasks.InMemoryTaskStore; +import io.modelcontextprotocol.experimental.tasks.TaskAwareAsyncToolSpecification; +import io.modelcontextprotocol.experimental.tasks.TaskStore; +import io.modelcontextprotocol.experimental.tasks.TaskTestUtils; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -15,13 +21,13 @@ import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; +import io.modelcontextprotocol.spec.McpSchema.TaskSupportMode; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -36,6 +42,7 @@ * * @author Christian Tzolov */ +// KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpAsyncServerTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -44,6 +51,8 @@ public abstract class AbstractMcpAsyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; + private static final String TEST_TASK_TOOL_NAME = "task-tool"; + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); protected void onStart() { @@ -64,10 +73,7 @@ void tearDown() { // --------------------------------------- // Server Lifecycle Tests // --------------------------------------- - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "sse", "streamable" }) - void testConstructorWithInvalidArguments(String serverType) { + void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport provider must not be null"); @@ -723,4 +729,359 @@ void testRootsChangeHandlers() { .doesNotThrowAnyException(); } + // --------------------------------------- + // Tasks Tests + // --------------------------------------- + + /** Creates a server with task capabilities and the given task store. */ + protected McpAsyncServer createTaskServer(TaskStore taskStore) { + return prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(TaskTestUtils.DEFAULT_TASK_CAPABILITIES) + .taskStore(taskStore) + .build(); + } + + /** Creates a server with task capabilities, task store, and a task-aware tool. */ + protected McpAsyncServer createTaskServer(TaskStore taskStore, + TaskAwareAsyncToolSpecification taskTool) { + return prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(TaskTestUtils.DEFAULT_TASK_CAPABILITIES) + .taskStore(taskStore) + .taskTools(taskTool) + .build(); + } + + /** + * Creates a simple task-aware tool that returns a text result. + */ + protected TaskAwareAsyncToolSpecification createSimpleTaskTool(String name, TaskSupportMode mode, + String resultText) { + return TaskAwareAsyncToolSpecification.builder() + .name(name) + .description("Test task tool") + .taskSupportMode(mode) + .createTaskHandler((args, extra) -> extra.createTask().flatMap(task -> { + // Immediately complete the task with the result + CallToolResult result = CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(resultText))) + .build(); + return extra.taskStore() + .storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, result) + .thenReturn(McpSchema.CreateTaskResult.builder().task(task).build()); + })) + .build(); + } + + @Test + void testServerWithTaskStore() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + assertThat(server.getTaskStore()).isSameAs(taskStore); + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreCreateAndGet() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + AtomicReference taskIdRef = new AtomicReference<>(); + StepVerifier.create(taskStore.createTask( + CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).requestedTtl(60000L).build())) + .consumeNextWith(task -> { + assertThat(task.taskId()).isNotNull().isNotEmpty(); + assertThat(task.status()).isEqualTo(TaskStatus.WORKING); + taskIdRef.set(task.taskId()); + }) + .verifyComplete(); + + // Get the task + StepVerifier.create(taskStore.getTask(taskIdRef.get(), null)).consumeNextWith(storeResult -> { + assertThat(storeResult.task().taskId()).isEqualTo(taskIdRef.get()); + assertThat(storeResult.task().status()).isEqualTo(TaskStatus.WORKING); + }).verifyComplete(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreUpdateStatus() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + AtomicReference taskIdRef = new AtomicReference<>(); + StepVerifier + .create(taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build())) + .consumeNextWith(task -> { + taskIdRef.set(task.taskId()); + }) + .verifyComplete(); + + // Update status + StepVerifier.create(taskStore.updateTaskStatus(taskIdRef.get(), null, TaskStatus.WORKING, "Processing...")) + .verifyComplete(); + + // Verify status updated + StepVerifier.create(taskStore.getTask(taskIdRef.get(), null)).consumeNextWith(storeResult -> { + assertThat(storeResult.task().status()).isEqualTo(TaskStatus.WORKING); + assertThat(storeResult.task().statusMessage()).isEqualTo("Processing..."); + }).verifyComplete(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreStoreResult() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + AtomicReference taskIdRef = new AtomicReference<>(); + StepVerifier + .create(taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build())) + .consumeNextWith(task -> { + taskIdRef.set(task.taskId()); + }) + .verifyComplete(); + + // Store result + CallToolResult result = CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Done!"))) + .isError(false) + .build(); + + StepVerifier.create(taskStore.storeTaskResult(taskIdRef.get(), null, TaskStatus.COMPLETED, result)) + .verifyComplete(); + + // Verify task is completed + StepVerifier.create(taskStore.getTask(taskIdRef.get(), null)).consumeNextWith(storeResult -> { + assertThat(storeResult.task().status()).isEqualTo(TaskStatus.COMPLETED); + }).verifyComplete(); + + // Verify result can be retrieved + StepVerifier.create(taskStore.getTaskResult(taskIdRef.get(), null)).consumeNextWith(retrievedResult -> { + assertThat(retrievedResult).isInstanceOf(CallToolResult.class); + }).verifyComplete(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreListTasks() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a few tasks + StepVerifier + .create(taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build())) + .expectNextCount(1) + .verifyComplete(); + StepVerifier + .create(taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build())) + .expectNextCount(1) + .verifyComplete(); + + // List tasks + StepVerifier.create(taskStore.listTasks(null, null)).consumeNextWith(result -> { + assertThat(result.tasks()).isNotNull(); + assertThat(result.tasks()).hasSizeGreaterThanOrEqualTo(2); + }).verifyComplete(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreRequestCancellation() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + AtomicReference taskIdRef = new AtomicReference<>(); + StepVerifier + .create(taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build())) + .consumeNextWith(task -> { + taskIdRef.set(task.taskId()); + }) + .verifyComplete(); + + // Request cancellation + StepVerifier.create(taskStore.requestCancellation(taskIdRef.get(), null)).consumeNextWith(task -> { + assertThat(task.taskId()).isEqualTo(taskIdRef.get()); + }).verifyComplete(); + + // Verify cancellation was requested + StepVerifier.create(taskStore.isCancellationRequested(taskIdRef.get(), null)).consumeNextWith(isCancelled -> { + assertThat(isCancelled).isTrue(); + }).verifyComplete(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testToolWithTaskSupportRequired() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var tool = createSimpleTaskTool(TEST_TASK_TOOL_NAME, TaskSupportMode.REQUIRED, "Task completed!"); + var server = createTaskServer(taskStore, tool); + + assertThat(server.getTaskStore()).isSameAs(taskStore); + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testToolWithTaskSupportOptional() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var tool = createSimpleTaskTool(TEST_TASK_TOOL_NAME, TaskSupportMode.OPTIONAL, "Done"); + var server = createTaskServer(taskStore, tool); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testTerminalStateCannotTransition() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create and complete a task + AtomicReference taskIdRef = new AtomicReference<>(); + StepVerifier + .create(taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build())) + .consumeNextWith(task -> { + taskIdRef.set(task.taskId()); + }) + .verifyComplete(); + + // Complete the task + CallToolResult result = CallToolResult.builder().content(List.of()).isError(false).build(); + StepVerifier.create(taskStore.storeTaskResult(taskIdRef.get(), null, TaskStatus.COMPLETED, result)) + .verifyComplete(); + + // Trying to update status should fail or be ignored (implementation-dependent) + // The InMemoryTaskStore silently ignores invalid transitions + StepVerifier.create(taskStore.updateTaskStatus(taskIdRef.get(), null, TaskStatus.WORKING, "Should not work")) + .verifyComplete(); + + // Status should still be COMPLETED + StepVerifier.create(taskStore.getTask(taskIdRef.get(), null)).consumeNextWith(storeResult -> { + assertThat(storeResult.task().status()).isEqualTo(TaskStatus.COMPLETED); + }).verifyComplete(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // CreateTaskHandler Tests + // --------------------------------------- + + @Test + void testToolWithCreateTaskHandler() { + // Test that a tool with createTaskHandler can be registered + TaskStore taskStore = new InMemoryTaskStore<>(); + + // Create a tool with createTaskHandler + var tool = TaskAwareAsyncToolSpecification.builder() + .name("create-task-handler-tool") + .description("A tool using createTaskHandler") + .createTaskHandler((args, extra) -> extra.createTask(opts -> opts.requestedTtl(60000L).pollInterval(1000L)) + .flatMap(task -> { + // Store result immediately for this test + CallToolResult result = CallToolResult.builder() + .addTextContent("Created via createTaskHandler") + .isError(false) + .build(); + return extra.taskStore() + .storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, result) + .thenReturn(McpSchema.CreateTaskResult.builder().task(task).build()); + })) + .build(); + + var server = createTaskServer(taskStore, tool); + + assertThat(server.getTaskStore()).isSameAs(taskStore); + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testToolWithAllThreeHandlers() { + // Test that a tool with all three handlers can be registered + TaskStore taskStore = new InMemoryTaskStore<>(); + + var tool = TaskAwareAsyncToolSpecification.builder() + .name("three-handler-tool") + .description("A tool with createTask, getTask, and getTaskResult handlers") + .createTaskHandler((args, extra) -> extra.createTask() + .map(task -> McpSchema.CreateTaskResult.builder().task(task).build())) + .getTaskHandler((exchange, request) -> { + // Custom getTask handler + return Mono.just(McpSchema.GetTaskResult.builder() + .taskId(request.taskId()) + .status(TaskStatus.WORKING) + .statusMessage("Custom status from handler") + .build()); + }) + .getTaskResultHandler((exchange, request) -> { + // Custom getTaskResult handler + return Mono.just(CallToolResult.builder().addTextContent("Custom result from handler").build()); + }) + .build(); + + var server = createTaskServer(taskStore, tool); + + // Verify all handlers are set + assertThat(tool.createTaskHandler()).isNotNull(); + assertThat(tool.getTaskHandler()).isNotNull(); + assertThat(tool.getTaskResultHandler()).isNotNull(); + + assertThatCode(() -> server.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void builderShouldThrowWhenNormalToolAndTaskToolShareSameName() { + String duplicateName = "duplicate-tool-name"; + + Tool normalTool = McpSchema.Tool.builder() + .name(duplicateName) + .title("A normal tool") + .inputSchema(EMPTY_JSON_SCHEMA) + .build(); + + assertThatThrownBy(() -> { + prepareAsyncServerBuilder() + .tool(normalTool, + (exchange, args) -> Mono + .just(CallToolResult.builder().content(List.of()).isError(false).build())) + .taskTools(TaskAwareAsyncToolSpecification.builder() + .name(duplicateName) + .description("A task tool") + .createTaskHandler((args, extra) -> Mono.just(McpSchema.CreateTaskResult.builder().build())) + .build()) + .build(); + }).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("already registered") + .hasMessageContaining(duplicateName); + } + + @Test + void builderShouldThrowWhenTaskToolsRegisteredWithoutTaskStore() { + assertThatThrownBy(() -> { + prepareAsyncServerBuilder() + .taskTools(TaskAwareAsyncToolSpecification.builder() + .name("task-tool-without-store") + .description("A task tool that needs a TaskStore") + .createTaskHandler((args, extra) -> Mono.just(McpSchema.CreateTaskResult.builder().build())) + .build()) + // Note: NOT setting .taskStore() + .build(); + }).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Task-aware tools registered but no TaskStore configured"); + } + } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 0a59d0aae..ed0038661 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -1,11 +1,15 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2026 the original author or authors. */ package io.modelcontextprotocol.server; import java.util.List; +import io.modelcontextprotocol.experimental.tasks.CreateTaskOptions; +import io.modelcontextprotocol.experimental.tasks.InMemoryTaskStore; +import io.modelcontextprotocol.experimental.tasks.TaskStore; +import io.modelcontextprotocol.experimental.tasks.TaskTestUtils; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; @@ -14,6 +18,8 @@ import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TaskStatus; +import io.modelcontextprotocol.spec.McpSchema.TaskSupportMode; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; @@ -31,6 +37,7 @@ * * @author Christian Tzolov */ +// KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpSyncServerTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -39,6 +46,8 @@ public abstract class AbstractMcpSyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; + private static final String TEST_TASK_TOOL_NAME = "task-tool"; + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); protected void onStart() { @@ -321,7 +330,6 @@ void testAddResource() { Resource resource = Resource.builder() .uri(TEST_RESOURCE_URI) .name("Test Resource") - .title("Test Resource") .mimeType("text/plain") .description("Test resource description") .build(); @@ -353,7 +361,6 @@ void testAddResourceWithoutCapability() { Resource resource = Resource.builder() .uri(TEST_RESOURCE_URI) .name("Test Resource") - .title("Test Resource") .mimeType("text/plain") .description("Test resource description") .build(); @@ -383,7 +390,6 @@ void testListResources() { Resource resource = Resource.builder() .uri(TEST_RESOURCE_URI) .name("Test Resource") - .title("Test Resource") .mimeType("text/plain") .description("Test resource description") .build(); @@ -408,7 +414,6 @@ void testRemoveResource() { Resource resource = Resource.builder() .uri(TEST_RESOURCE_URI) .name("Test Resource") - .title("Test Resource") .mimeType("text/plain") .description("Test resource description") .build(); @@ -678,4 +683,262 @@ void testRootsChangeHandlers() { assertThatCode(noConsumersServer::closeGracefully).doesNotThrowAnyException(); } + // --------------------------------------- + // Tasks Tests + // --------------------------------------- + + /** Creates a server with task capabilities and the given task store. */ + protected McpSyncServer createTaskServer(TaskStore taskStore) { + return prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(TaskTestUtils.DEFAULT_TASK_CAPABILITIES) + .taskStore(taskStore) + .build(); + } + + @Test + void testServerWithTaskStore() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + assertThat(server.getAsyncServer().getTaskStore()).isSameAs(taskStore); + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreCreateAndGet() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task (blocking) + var task = taskStore.createTask( + CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).requestedTtl(60000L).build()) + .block(); + + assertThat(task).isNotNull(); + assertThat(task.taskId()).isNotNull().isNotEmpty(); + assertThat(task.status()).isEqualTo(TaskStatus.WORKING); + + // Get the task (blocking) + var storeResult = taskStore.getTask(task.taskId(), null).block(); + var retrievedTask = storeResult.task(); + + assertThat(retrievedTask).isNotNull(); + assertThat(retrievedTask.taskId()).isEqualTo(task.taskId()); + assertThat(retrievedTask.status()).isEqualTo(TaskStatus.WORKING); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreUpdateStatus() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + var task = taskStore.createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build()) + .block(); + assertThat(task).isNotNull(); + + // Update status + taskStore.updateTaskStatus(task.taskId(), null, TaskStatus.WORKING, "Processing...").block(); + + // Verify status updated + var updatedTask = taskStore.getTask(task.taskId(), null).block().task(); + assertThat(updatedTask).isNotNull(); + assertThat(updatedTask.status()).isEqualTo(TaskStatus.WORKING); + assertThat(updatedTask.statusMessage()).isEqualTo("Processing..."); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreStoreResult() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + var task = taskStore.createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build()) + .block(); + assertThat(task).isNotNull(); + + // Store result + CallToolResult result = CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Done!"))) + .isError(false) + .build(); + + taskStore.storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, result).block(); + + // Verify task is completed + var completedTask = taskStore.getTask(task.taskId(), null).block().task(); + assertThat(completedTask).isNotNull(); + assertThat(completedTask.status()).isEqualTo(TaskStatus.COMPLETED); + + // Verify result can be retrieved + var retrievedResult = taskStore.getTaskResult(task.taskId(), null).block(); + assertThat(retrievedResult).isNotNull(); + assertThat(retrievedResult).isInstanceOf(CallToolResult.class); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreListTasks() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a few tasks + taskStore.createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build()).block(); + taskStore.createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build()).block(); + + // List tasks + var listResult = taskStore.listTasks(null, null).block(); + assertThat(listResult).isNotNull(); + assertThat(listResult.tasks()).isNotNull(); + assertThat(listResult.tasks()).hasSizeGreaterThanOrEqualTo(2); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testTaskStoreRequestCancellation() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create a task + var task = taskStore.createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build()) + .block(); + assertThat(task).isNotNull(); + + // Request cancellation + var cancelledTask = taskStore.requestCancellation(task.taskId(), null).block(); + assertThat(cancelledTask).isNotNull(); + assertThat(cancelledTask.taskId()).isEqualTo(task.taskId()); + + // Verify cancellation was requested + var isCancelled = taskStore.isCancellationRequested(task.taskId(), null).block(); + assertThat(isCancelled).isTrue(); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testToolWithTaskSupportRequired() { + TaskStore taskStore = new InMemoryTaskStore<>(); + Tool taskTool = TaskTestUtils.createTaskTool(TEST_TASK_TOOL_NAME, "Task-based tool", TaskSupportMode.REQUIRED); + + var server = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(TaskTestUtils.DEFAULT_TASK_CAPABILITIES) + .taskStore(taskStore) + .tool(taskTool, + (exchange, args) -> CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Task completed!"))) + .isError(false) + .build()) + .build(); + + assertThat(server.getAsyncServer().getTaskStore()).isSameAs(taskStore); + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testToolWithTaskSupportOptional() { + TaskStore taskStore = new InMemoryTaskStore<>(); + Tool taskTool = TaskTestUtils.createTaskTool(TEST_TASK_TOOL_NAME, "Optional task tool", + TaskSupportMode.OPTIONAL); + + var server = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(TaskTestUtils.DEFAULT_TASK_CAPABILITIES) + .taskStore(taskStore) + .tool(taskTool, (exchange, args) -> CallToolResult.builder().content(List.of()).isError(false).build()) + .build(); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + @Test + void testTerminalStateCannotTransition() { + TaskStore taskStore = new InMemoryTaskStore<>(); + var server = createTaskServer(taskStore); + + // Create and complete a task + var task = taskStore.createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("test-tool")).build()) + .block(); + assertThat(task).isNotNull(); + + // Complete the task + CallToolResult result = CallToolResult.builder().content(List.of()).isError(false).build(); + taskStore.storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, result).block(); + + // Trying to update status should fail or be ignored (implementation-dependent) + // The InMemoryTaskStore silently ignores invalid transitions + taskStore.updateTaskStatus(task.taskId(), null, TaskStatus.WORKING, "Should not work").block(); + + // Status should still be COMPLETED + var finalTask = taskStore.getTask(task.taskId(), null).block().task(); + assertThat(finalTask).isNotNull(); + assertThat(finalTask.status()).isEqualTo(TaskStatus.COMPLETED); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + + /** + * Example: Using sync tool specification for external API pattern. + * + *

+ * This test demonstrates the sync equivalent of the external API pattern shown in + * integration tests. The key differences from the async version are: + *

    + *
  1. Use {@code TaskAwareSyncToolSpecification} instead of async variant
  2. + *
  3. Handlers return values directly instead of {@code Mono}
  4. + *
  5. Task store calls use {@code .block()} for synchronous execution
  6. + *
+ * + *

+ * This example shows how to create a task-aware sync tool that wraps an external + * async API, demonstrating that the pattern works the same way regardless of whether + * you're using the sync or async server API. + */ + @Test + void testSyncExternalApiPatternExample() { + TaskStore taskStore = new InMemoryTaskStore<>(); + + // For this example, we simulate an external API call and manually create a + // task + // This demonstrates the sync tool pattern equivalent to the async + // testExternalAsyncApiPattern + + var server = createTaskServer(taskStore); + + // Step 1: Create a task (simulating what a sync createTask handler would do) + var task = taskStore + .createTask(CreateTaskOptions.builder(TaskTestUtils.createTestRequest("external-job")) + .taskId("external-job-123") + .requestedTtl(60000L) + .build()) + .block(); + + assertThat(task).isNotNull(); + assertThat(task.taskId()).isEqualTo("external-job-123"); + assertThat(task.status()).isEqualTo(TaskStatus.WORKING); + + // Step 2: Simulate external API completing the job and storing result + // storeTaskResult atomically sets the terminal status AND stores the result + CallToolResult result = CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Processed: test-data"))) + .isError(false) + .build(); + taskStore.storeTaskResult(task.taskId(), null, TaskStatus.COMPLETED, result).block(); + + // Verify final state + var finalTask = taskStore.getTask(task.taskId(), null).block().task(); + assertThat(finalTask).isNotNull(); + assertThat(finalTask.status()).isEqualTo(TaskStatus.COMPLETED); + + var finalResult = taskStore.getTaskResult(task.taskId(), null).block(); + assertThat(finalResult).isNotNull(); + + assertThatCode(server::closeGracefully).doesNotThrowAnyException(); + } + }