diff --git a/codegen/src/main/java/io/helidon/extensions/mcp/codegen/McpCodegen.java b/codegen/src/main/java/io/helidon/extensions/mcp/codegen/McpCodegen.java index 0d9e2dc7..59ee9e4b 100644 --- a/codegen/src/main/java/io/helidon/extensions/mcp/codegen/McpCodegen.java +++ b/codegen/src/main/java/io/helidon/extensions/mcp/codegen/McpCodegen.java @@ -88,6 +88,7 @@ import static io.helidon.extensions.mcp.codegen.McpTypes.MCP_RESOURCE_UNSUBSCRIBER_INTERFACE; import static io.helidon.extensions.mcp.codegen.McpTypes.MCP_ROLE; import static io.helidon.extensions.mcp.codegen.McpTypes.MCP_ROLE_ENUM; +import static io.helidon.extensions.mcp.codegen.McpTypes.MCP_SAMPLING; import static io.helidon.extensions.mcp.codegen.McpTypes.MCP_SERVER; import static io.helidon.extensions.mcp.codegen.McpTypes.MCP_SERVER_CONFIG; import static io.helidon.extensions.mcp.codegen.McpTypes.MCP_TOOL; @@ -448,6 +449,10 @@ private void addResourceMethod(Method.Builder builder, String uri, ClassModel.Bu parameters.add("request.features().cancellation()"); continue; } + if (MCP_SAMPLING.equals(parameter.typeName())) { + parameters.add("request.features().sampling()"); + continue; + } if (isResourceTemplate(uri)) { if (MCP_PARAMETERS.equals(parameter.typeName())) { parameters.add("request.parameters()"); @@ -590,6 +595,16 @@ private void addPromptMethod(Method.Builder builder, ClassModel.Builder classMod builder.addContentLine("var cancellation = features.cancellation();"); continue; } + if (MCP_SAMPLING.equals(param.typeName())) { + if (!featuresLocalVar) { + addFeaturesLocalVar(builder, classModel); + featuresLocalVar = true; + } + parameters.add("sampling"); + classModel.addImport(MCP_SAMPLING); + builder.addContentLine("var sampling = features.sampling();"); + continue; + } if (!parametersLocalVar) { addParametersLocalVar(builder, classModel); parametersLocalVar = true; @@ -750,83 +765,62 @@ private void addToolMethod(Method.Builder builder, ClassModel.Builder classModel .addAnnotation(Annotations.OVERRIDE); builder.addContentLine("return request -> {"); - boolean featuresLocalVar = false; - boolean parametersLocalVar = false; for (TypedElementInfo param : element.parameterArguments()) { if (MCP_REQUEST.equals(param.typeName())) { parameters.add("request"); continue; } - if (MCP_FEATURES.equals(param.typeName()) && !featuresLocalVar) { + if (MCP_FEATURES.equals(param.typeName())) { addFeaturesLocalVar(builder, classModel); - parameters.add("features"); - featuresLocalVar = true; + parameters.add("request.features()"); continue; } if (MCP_LOGGER.equals(param.typeName())) { - if (!featuresLocalVar) { - addFeaturesLocalVar(builder, classModel); - featuresLocalVar = true; - } parameters.add("logger"); - builder.addContentLine("var logger = features.logger();"); + builder.addContentLine("var logger = request.features().logger();"); continue; } if (MCP_PROGRESS.equals(param.typeName())) { - if (!featuresLocalVar) { - addFeaturesLocalVar(builder, classModel); - featuresLocalVar = true; - } parameters.add("progress"); classModel.addImport(MCP_PROGRESS); - builder.addContentLine("var progress = features.progress();"); + builder.addContentLine("var progress = request.features().progress();"); continue; } if (MCP_CANCELLATION.equals(param.typeName())) { - if (!featuresLocalVar) { - addFeaturesLocalVar(builder, classModel); - featuresLocalVar = true; - } parameters.add("cancellation"); classModel.addImport(MCP_CANCELLATION); - builder.addContentLine("var cancellation = features.cancellation();"); + builder.addContentLine("var cancellation = request.features().cancellation();"); + continue; + } + if (MCP_SAMPLING.equals(param.typeName())) { + parameters.add("sampling"); + classModel.addImport(MCP_SAMPLING); + builder.addContentLine("var sampling = request.features().sampling();"); continue; } if (TypeNames.STRING.equals(param.typeName())) { - if (!parametersLocalVar) { - addParametersLocalVar(builder, classModel); - parametersLocalVar = true; - } parameters.add(param.elementName()); builder.addContent("var ") .addContent(param.elementName()) - .addContent(" = parameters.get(\"") + .addContent(" = request.parameters().get(\"") .addContent(param.elementName()) .addContentLine("\").asString().orElse(\"\");"); continue; } if (isBoolean(param.typeName())) { - if (!parametersLocalVar) { - addParametersLocalVar(builder, classModel); - parametersLocalVar = true; - } parameters.add(param.elementName()); builder.addContent("boolean ") .addContent(param.elementName()) - .addContent(" = parameters.get(\"") + .addContent(" = request.parameters().get(\"") .addContent(param.elementName()) .addContentLine("\").asBoolean().orElse(false);"); continue; } if (isNumber(param.typeName())) { - if (!parametersLocalVar) { - addParametersLocalVar(builder, classModel); - parametersLocalVar = true; - } parameters.add(param.elementName()); builder.addContent("var ") .addContent(param.elementName()) - .addContent(" = parameters.get(\"") + .addContent(" = request.parameters().get(\"") .addContent(param.elementName()) .addContent("\").as") .addContent(param.typeName().className()) @@ -836,28 +830,19 @@ private void addToolMethod(Method.Builder builder, ClassModel.Builder classModel if (isList(param.typeName())) { TypeName typeArg = param.typeName().typeArguments().getFirst(); addToListMethod(classModel, typeArg); - - if (!parametersLocalVar) { - addParametersLocalVar(builder, classModel); - parametersLocalVar = true; - } parameters.add(param.elementName()); builder.addContent("var ") .addContent(param.elementName()) - .addContent(" = toList(parameters.get(\"") + .addContent(" = toList(request.parameters().get(\"") .addContent(param.elementName()) .addContentLine("\").asList().orElse(null));"); continue; } - if (!parametersLocalVar) { - addParametersLocalVar(builder, classModel); - parametersLocalVar = true; - } parameters.add(param.elementName()); builder.addContent(param.typeName().classNameWithEnclosingNames()) .addContent(" ") .addContent(param.elementName()) - .addContent(" = parameters.get(\"") + .addContent(" = request.parameters().get(\"") .addContent(param.elementName()) .addContent("\").as(") .addContent(param.typeName()) @@ -905,7 +890,9 @@ private void addToolDescriptionMethod(Method.Builder builder, String description builder.name("description") .addAnnotation(Annotations.OVERRIDE) .returnType(TypeNames.STRING) - .addContentLine("return \"" + description + "\";"); + .addContent("return \"") + .addContent(description) + .addContentLine("\";"); } private void addToolAnnotationsMethod(Method.Builder builder, Annotation toolAnnotation) { @@ -1077,9 +1064,10 @@ private TypeName generatedTypeName(TypeName factoryTypeName, String suffix) { private boolean isIgnoredSchemaElement(TypeName typeName) { return MCP_REQUEST.equals(typeName) - || MCP_FEATURES.equals(typeName) || MCP_LOGGER.equals(typeName) + || MCP_FEATURES.equals(typeName) || MCP_PROGRESS.equals(typeName) + || MCP_SAMPLING.equals(typeName) || MCP_CANCELLATION.equals(typeName); } diff --git a/codegen/src/main/java/io/helidon/extensions/mcp/codegen/McpCodegenProvider.java b/codegen/src/main/java/io/helidon/extensions/mcp/codegen/McpCodegenProvider.java index 6afe68f7..3723f6f6 100644 --- a/codegen/src/main/java/io/helidon/extensions/mcp/codegen/McpCodegenProvider.java +++ b/codegen/src/main/java/io/helidon/extensions/mcp/codegen/McpCodegenProvider.java @@ -26,7 +26,7 @@ import static io.helidon.extensions.mcp.codegen.McpTypes.MCP_SERVER; /** - * Mcp code generator provider. + * MCP code generator provider. */ public class McpCodegenProvider implements CodegenExtensionProvider { /** diff --git a/codegen/src/main/java/io/helidon/extensions/mcp/codegen/McpTypes.java b/codegen/src/main/java/io/helidon/extensions/mcp/codegen/McpTypes.java index 89364b6c..b862b0f3 100644 --- a/codegen/src/main/java/io/helidon/extensions/mcp/codegen/McpTypes.java +++ b/codegen/src/main/java/io/helidon/extensions/mcp/codegen/McpTypes.java @@ -48,6 +48,7 @@ private McpTypes() { static final TypeName MCP_REQUEST = TypeName.create("io.helidon.extensions.mcp.server.McpRequest"); static final TypeName MCP_FEATURES = TypeName.create("io.helidon.extensions.mcp.server.McpFeatures"); static final TypeName MCP_PROGRESS = TypeName.create("io.helidon.extensions.mcp.server.McpProgress"); + static final TypeName MCP_SAMPLING = TypeName.create("io.helidon.extensions.mcp.server.McpSampling"); static final TypeName MCP_TOOL_INTERFACE = TypeName.create("io.helidon.extensions.mcp.server.McpTool"); static final TypeName MCP_PARAMETERS = TypeName.create("io.helidon.extensions.mcp.server.McpParameters"); static final TypeName MCP_PROMPT_INTERFACE = TypeName.create("io.helidon.extensions.mcp.server.McpPrompt"); diff --git a/docs/mcp-declarative/README.md b/docs/mcp-declarative/README.md index e8531840..28a90642 100644 --- a/docs/mcp-declarative/README.md +++ b/docs/mcp-declarative/README.md @@ -494,6 +494,20 @@ List cancellationTool(McpCancellation cancellation) { } ``` +### Sampling + +See the full [sampling documentation details](../mcp/README.md#sampling) + +#### Example + +Below is an example of a tool that uses the Sampling feature. `McpSampling` object can be used as method parameter. + +```java +@Mcp.Tool("Uses MCP Sampling to ask the connected client model.") +List samplingTool(McpSampling sampling) { +} +``` + ## References - [MCP Specification](https://modelcontextprotocol.io/introduction) diff --git a/docs/mcp/README.md b/docs/mcp/README.md index 343b7de8..822fe457 100644 --- a/docs/mcp/README.md +++ b/docs/mcp/README.md @@ -659,6 +659,103 @@ private class CancellationTool implements McpTool { } ``` +### Sampling + +The MCP Sampling feature provides a standardized mechanism that allows servers to request LLM sampling operations from language +models through connected clients. It enables servers to seamlessly integrate AI capabilities into their workflows without +requiring API keys. Like other MCP features, sampling can be accessed via the MCP request features. +Sampling support is optional for clients, and servers can verify its availability using the `enabled` method: + +```java +var sampling = request.features().sampling(); +if (!sampling.enabled()) { +} +``` + +If the client supports sampling, you can send a sampling request using the request method. A builder is provided to configure +and customize the sampling request as needed: + +```java +McpSamplingRequest request = McpSamplingRequest.builder() + .maxTokens(1) + .temperature(0.1) + .costPriority(0.1) + .speedPriority(0.1) + .hints(List.of("hint1")) + .metadata(JsonValue.TRUE) + .intelligencePriority(0.1) + .systemPrompt("system prompt") + .timeout(Duration.ofSeconds(10)) + .stopSequences(List.of("stop1")) + .includeContext(McpIncludeContext.NONE) + .addMessage(McpSamplingMessages.textContent("text", McpRole.USER)) + .build(); +``` + +Once your request is built, send it using the sampling feature. The request method may throw an `McpSamplingException` if an +error occurs during processing. On success, it returns an McpSamplingResponse containing the response message, the model used, +and optionally a stop reason. + +```java +try { + McpSamplingResponse response = sampling.request(req -> req.addMessage(message)); +} catch(McpSamplingException exception) { + // Manage error +} +``` + +The messages you send are prompts to the language model, and they follow the same structure as MCP prompts. You can use the +`McpSamplingMessages` utility class to create different types of messages for the client model: + +```java +var text = McpSamplingMessages.textContent("Explain Helidon MCP in one paragraph.", McpRole.USER); +var image = McpSamplingMessages.imageContent(pngBytes, MediaTypes.create("image/png"), McpRole.USER); +var audio = McpSamplingMessages.audioContent(wavBytes, MediaTypes.create("audio/wav"), McpRole.USER); +``` + +#### Example + +Below is an example of a tool that uses the Sampling feature. If the connected client does not support sampling, the tool +throws a `McpToolErrorException`. + +```java +class SamplingTool implements McpTool { + @Override + public String name() { + return "sampling-tool"; + } + + @Override + public String description() { + return "Uses MCP Sampling to ask the connected client model."; + } + + @Override + public String schema() { + return ""; + } + + @Override + public List process(McpRequest request) { + var sampling = request.features().sampling(); + + if (!sampling.enabled()) { + throw new McpToolErrorException("This tool requires sampling feature"); + } + + try { + McpSamplingResponse response = sampling.request(req -> req + .timeout(Duration.ofSeconds(10)) + .systemPrompt("You are a concise, helpful assistant.") + .addMessage(McpSamplingMessages.textContent("Write a 3-line summary of Helidon MCP Sampling.", McpRole.USER))); + return List.of(McpToolContents.textContent(response.asTextMessage())); + } catch (McpSamplingException e) { + throw new McpToolErrorException(e.getMessage()); + } + } +} +``` + ## Configuration MCP server configuration can be defined using Helidon configuration files. Example in YAML: diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/Mcp.java b/server/src/main/java/io/helidon/extensions/mcp/server/Mcp.java index e300777a..cd2ec2ce 100644 --- a/server/src/main/java/io/helidon/extensions/mcp/server/Mcp.java +++ b/server/src/main/java/io/helidon/extensions/mcp/server/Mcp.java @@ -34,12 +34,12 @@ public final class Mcp { /** * Annotation to define an MCP server. An MCP Server aggregates several MCP * components like tools, prompts, resources and completions. - * - *

The primary components include:

+ *

+ * The primary components include: *

    *
  • * {@link io.helidon.extensions.mcp.server.Mcp.Tool} - - * Tool is a function that computes a set of inputs and return a result. Mcp server uses tools to + * Tool is a function that computes a set of inputs and return a result. MCP server uses tools to * interact with the outside world to reach real time data through API calls, access to databases * or performing any kind of computation. *
  • @@ -61,7 +61,8 @@ public final class Mcp { * This way, the server can suggest where are resources located and which arguments can be used. * *
- *

The MCP server can be configured using the following annotations:

+ *

+ * The MCP server can be configured using the following annotations: *

    *
  • * {@link io.helidon.extensions.mcp.server.Mcp.Version} - @@ -208,8 +209,8 @@ public final class Mcp { * Annotation to define an MCP resource. * A resource is a none static method and must be located in a class annotated with * {@link io.helidon.extensions.mcp.server.Mcp.Server}. This way, the resource is automatically registered to the server. - * - *

    This annotation supports two kinds of Resource:

    + *

    + * This annotation supports two kinds of Resource: *

      *
    • * {@code Regular Resource} where the resource {@link java.net.URI} points to an MCP resource such as diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpCapability.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpCapability.java index 7fe0d702..4b5b508c 100644 --- a/server/src/main/java/io/helidon/extensions/mcp/server/McpCapability.java +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpCapability.java @@ -25,7 +25,7 @@ enum McpCapability { COMPLETION, PAGINATION, SAMPLING, - ROOT, + ROOTS, PROGRESS; String text() { diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpContent.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpContent.java index 8c251104..06d804f9 100644 --- a/server/src/main/java/io/helidon/extensions/mcp/server/McpContent.java +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpContent.java @@ -19,10 +19,10 @@ /** * General content type for all MCP component contents. */ -public sealed interface McpContent permits McpEmbeddedResource, +public sealed interface McpContent permits McpTextContent, + McpMediaContent, McpResourceContent, - McpTextContent, - McpMediaContent { + McpEmbeddedResource { /** * Content type. * diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpDecorators.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpDecorators.java new file mode 100644 index 00000000..f4b57a98 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpDecorators.java @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.extensions.mcp.server; + +import java.util.Optional; + +import io.helidon.builder.api.Prototype; + +import static io.helidon.extensions.mcp.server.McpPagination.DEFAULT_PAGE_SIZE; + +/** + * Placeholder for the MCP configuration decorators. + */ +final class McpDecorators { + private McpDecorators() { + } + + /** + * Enforce positive page size. + *

      + * See {@link io.helidon.extensions.mcp.server.McpPagination}. + */ + static class PageSizeDecorator implements Prototype.OptionDecorator, Integer> { + @Override + public void decorate(McpServerConfig.BuilderBase builder, Integer pageSize) { + if (pageSize < DEFAULT_PAGE_SIZE) { + throw new IllegalArgumentException("Page size must be greater than zero"); + } + } + } + + /** + * Enforce intelligence priority value between 0 and 1. + *

      + * See {@link io.helidon.extensions.mcp.server.McpSamplingRequest}. + */ + static class IntelligencePriorityDecorator implements Prototype.OptionDecorator, Optional> { + @Override + public void decorate(McpSamplingRequest.BuilderBase builder, Optional value) { + value.filter(McpDecorators::isPositiveAndLessThanOne) + .orElseThrow(() -> new IllegalArgumentException("Intelligence priority must be in range [0, 1]")); + } + } + + /** + * Enforce speed priority value between 0 and 1. + *

      + * See {@link io.helidon.extensions.mcp.server.McpSamplingRequest}. + */ + static class SpeedPriorityDecorator implements Prototype.OptionDecorator, Optional> { + @Override + public void decorate(McpSamplingRequest.BuilderBase builder, Optional value) { + value.filter(McpDecorators::isPositiveAndLessThanOne) + .orElseThrow(() -> new IllegalArgumentException("Speed priority must be in range [0, 1]")); + } + } + + /** + * Enforce cost priority value between 0 and 1. + *

      + * See {@link io.helidon.extensions.mcp.server.McpSamplingRequest}. + */ + static class CostPriorityDecorator implements Prototype.OptionDecorator, Optional> { + @Override + public void decorate(McpSamplingRequest.BuilderBase builder, Optional value) { + value.filter(McpDecorators::isPositiveAndLessThanOne) + .orElseThrow(() -> new IllegalArgumentException("Cost priority must be in range [0, 1]")); + } + } + + static boolean isPositiveAndLessThanOne(Double value) { + return 0 <= value && value <= 1.0; + } +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpFeatures.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpFeatures.java index d119ab71..635a4936 100644 --- a/server/src/main/java/io/helidon/extensions/mcp/server/McpFeatures.java +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpFeatures.java @@ -43,6 +43,10 @@ * {@link io.helidon.extensions.mcp.server.McpSubscriptions} - MCP subscription feature. * Sends notifications to the subscribed clients. *

    • + *
    • + * {@link io.helidon.extensions.mcp.server.McpSampling} - MCP Sampling feature. + * Send sampling messages to client. + *
    • *
    */ public final class McpFeatures { @@ -51,8 +55,9 @@ public final class McpFeatures { private final McpSession session; private SseSink sseSink; - private McpProgress progress; private McpLogger logger; + private McpSampling sampling; + private McpProgress progress; private McpSubscriptions subscriptions; McpFeatures(McpSession session) { @@ -111,6 +116,23 @@ public McpLogger logger() { return logger; } + /** + * Get a {@link io.helidon.extensions.mcp.server.McpSampling} feature. + * + * @return the MCP sampling + */ + public McpSampling sampling() { + if (sampling == null) { + if (response != null) { + sseSink = getOrCreateSseSink(); + sampling = new McpSampling(session, sseSink); + } else { + sampling = new McpSampling(session); + } + } + return sampling; + } + /** * Get a {@link io.helidon.extensions.mcp.server.McpSubscriptions} feature. * diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpIncludeContext.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpIncludeContext.java new file mode 100644 index 00000000..62c75cda --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpIncludeContext.java @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.extensions.mcp.server; + +/** + * Include context values to be used as part of Sampling request. + */ +public enum McpIncludeContext { + /** + * None. + */ + NONE("none"), + + /** + * This server. + */ + THIS_SERVER("thisServer"), + + /** + * All server. + */ + ALL_SERVERS("allServers"),; + + private final String literal; + + McpIncludeContext(String literal) { + this.literal = literal; + } + + String text() { + return literal; + } +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpJsonRpc.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpJsonRpc.java index 680f862a..a0201136 100644 --- a/server/src/main/java/io/helidon/extensions/mcp/server/McpJsonRpc.java +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpJsonRpc.java @@ -21,21 +21,29 @@ import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import io.helidon.common.media.type.MediaType; +import io.helidon.common.media.type.MediaTypes; +import io.helidon.jsonrpc.core.JsonRpcError; + import jakarta.json.Json; import jakarta.json.JsonArrayBuilder; import jakarta.json.JsonBuilderFactory; import jakarta.json.JsonObject; import jakarta.json.JsonObjectBuilder; import jakarta.json.JsonReaderFactory; +import jakarta.json.JsonString; import jakarta.json.JsonStructure; import jakarta.json.JsonValue; import jakarta.json.JsonWriter; import jakarta.json.JsonWriterFactory; import jakarta.json.stream.JsonGenerator; +import static io.helidon.jsonrpc.core.JsonRpcError.INTERNAL_ERROR; + final class McpJsonRpc { static final JsonBuilderFactory JSON_BUILDER_FACTORY = Json.createBuilderFactory(Map.of()); static final JsonReaderFactory JSON_READER_FACTORY = Json.createReaderFactory(Map.of()); @@ -239,7 +247,7 @@ static JsonObject listResources(McpPage page) { static JsonObject listTools(McpPage page, String protocolVersion) { JsonArrayBuilder builder = JSON_BUILDER_FACTORY.createArrayBuilder(); page.components().stream() - .map(t -> McpJsonRpc.toJson(t, protocolVersion)) + .map(t -> toJson(t, protocolVersion)) .forEach(builder::add); JsonObjectBuilder resources = JSON_BUILDER_FACTORY.createObjectBuilder() .add("tools", builder); @@ -278,7 +286,7 @@ static JsonObject listPrompts(McpPage page) { static JsonObjectBuilder toJson(McpToolResourceContent content) { return JSON_BUILDER_FACTORY.createObjectBuilder() .add("type", content.type().text()) - .add("resource", McpJsonRpc.toJson(content.content()) + .add("resource", toJson(content.content()) .add("uri", content.uri().toASCIIString())); } @@ -319,7 +327,7 @@ static JsonObjectBuilder resourceTemplates(McpResource resource) { static JsonObject readResource(String uri, List contents) { JsonArrayBuilder array = JSON_BUILDER_FACTORY.createArrayBuilder(); for (McpResourceContent content : contents) { - JsonObjectBuilder builder = McpJsonRpc.toJson(content); + JsonObjectBuilder builder = toJson(content); builder.add("uri", uri); array.add(builder); } @@ -329,7 +337,7 @@ static JsonObject readResource(String uri, List contents) { static JsonObject toJson(List contents, String description) { JsonArrayBuilder array = JSON_BUILDER_FACTORY.createArrayBuilder(); for (McpPromptContent prompt : contents) { - array.add(McpJsonRpc.toJson(prompt)); + array.add(toJson(prompt)); } return JSON_BUILDER_FACTORY.createObjectBuilder() .add("description", description) @@ -339,18 +347,18 @@ static JsonObject toJson(List contents, String description) { static JsonObjectBuilder toJson(McpPromptContent content) { if (content instanceof McpPromptImageContent image) { - return McpJsonRpc.toJson(image); + return toJson(image); } if (content instanceof McpPromptTextContent text) { - return McpJsonRpc.toJson(text); + return toJson(text); } if (content instanceof McpPromptResourceContent resource) { - return McpJsonRpc.toJson(resource); + return toJson(resource); } if (content instanceof McpPromptAudioContent resource) { - return McpJsonRpc.toJson(resource); + return toJson(resource); } - return null; + throw new IllegalArgumentException("Unsupported content type: " + content.getClass().getName()); } static JsonObjectBuilder toJson(McpContent content) { @@ -366,7 +374,20 @@ static JsonObjectBuilder toJson(McpContent content) { if (content instanceof McpAudioContent audio) { return toJson(audio); } - return null; + throw new IllegalArgumentException("Unsupported content type: " + content.getClass().getName()); + } + + static JsonObjectBuilder toJson(McpSamplingMessage message) { + if (message instanceof McpSamplingTextMessageImpl text) { + return toJson(text); + } + if (message instanceof McpSamplingImageMessageImpl image) { + return toJson(image); + } + if (message instanceof McpSamplingAudioMessageImpl resource) { + return toJson(resource); + } + throw new IllegalArgumentException("Unsupported content type: " + message.getClass().getName()); } static JsonObjectBuilder toJson(McpResourceContent content) { @@ -376,7 +397,7 @@ static JsonObjectBuilder toJson(McpResourceContent content) { if (content instanceof McpResourceBinaryContent binary) { return toJson(binary); } - return null; + throw new IllegalArgumentException("Unsupported content type: " + content.getClass().getName()); } static JsonObjectBuilder toJson(McpPromptResourceContent resource) { @@ -384,14 +405,14 @@ static JsonObjectBuilder toJson(McpPromptResourceContent resource) { .add("role", resource.role().text()) .add("content", JSON_BUILDER_FACTORY.createObjectBuilder() .add("type", resource.type().text()) - .add("resource", McpJsonRpc.toJson(resource.content()) + .add("resource", toJson(resource.content()) .add("uri", resource.uri().toASCIIString()))); } static JsonObjectBuilder toJson(McpPromptImageContent image) { return JSON_BUILDER_FACTORY.createObjectBuilder() .add("role", image.role().text()) - .add("content", McpJsonRpc.toJson(image.content())); + .add("content", toJson(image.content())); } static JsonObjectBuilder toJson(McpPromptTextContent content) { @@ -403,7 +424,33 @@ static JsonObjectBuilder toJson(McpPromptTextContent content) { static JsonObjectBuilder toJson(McpPromptAudioContent audio) { return JSON_BUILDER_FACTORY.createObjectBuilder() .add("role", audio.role().text()) - .add("content", McpJsonRpc.toJson(audio.content())); + .add("content", toJson(audio.content())); + } + + static JsonObjectBuilder toJson(McpSamplingImageMessage image) { + return JSON_BUILDER_FACTORY.createObjectBuilder() + .add("role", image.role().text()) + .add("content", JSON_BUILDER_FACTORY.createObjectBuilder() + .add("type", image.type().text()) + .add("data", image.encodeBase64Data()) + .add("mimeType", image.mediaType().text())); + } + + static JsonObjectBuilder toJson(McpSamplingTextMessage text) { + return JSON_BUILDER_FACTORY.createObjectBuilder() + .add("role", text.role().text()) + .add("content", JSON_BUILDER_FACTORY.createObjectBuilder() + .add("type", text.type().text()) + .add("text", text.text())); + } + + static JsonObjectBuilder toJson(McpSamplingAudioMessage audio) { + return JSON_BUILDER_FACTORY.createObjectBuilder() + .add("role", audio.role().text()) + .add("content", JSON_BUILDER_FACTORY.createObjectBuilder() + .add("type", audio.type().text()) + .add("data", audio.encodeBase64Data()) + .add("mimeType", audio.mediaType().text())); } static JsonObjectBuilder toJson(McpTextContent content) { @@ -439,7 +486,6 @@ static JsonObjectBuilder toJson(McpResourceTextContent content) { } static JsonObject toJson(McpProgress progress, int newProgress, String message) { - JsonObjectBuilder response = JSON_BUILDER_FACTORY.createObjectBuilder(); JsonObjectBuilder params = JSON_BUILDER_FACTORY.createObjectBuilder() .add("progress", newProgress) .add("total", progress.total()); @@ -451,30 +497,20 @@ static JsonObject toJson(McpProgress progress, int newProgress, String message) if (message != null) { params.add("message", message); } - response.add("jsonrpc", "2.0"); - response.add("method", McpJsonRpc.METHOD_NOTIFICATION_PROGRESS); - response.add("params", params); - return response.build(); + return createJsonRpcNotification(METHOD_NOTIFICATION_PROGRESS, params); } static JsonObject createLoggingNotification(McpLogger.Level level, String name, String message) { - return JSON_BUILDER_FACTORY.createObjectBuilder() - .add("jsonrpc", "2.0") - .add("method", METHOD_NOTIFICATION_MESSAGE) - .add("params", JSON_BUILDER_FACTORY.createObjectBuilder() - .add("level", level.text()) - .add("logger", name) - .add("data", message)) - .build(); + var params = JSON_BUILDER_FACTORY.createObjectBuilder() + .add("level", level.text()) + .add("logger", name) + .add("data", message); + return createJsonRpcNotification(METHOD_NOTIFICATION_MESSAGE, params); } static JsonObject createUpdateNotification(String uri) { - return JSON_BUILDER_FACTORY.createObjectBuilder() - .add("jsonrpc", "2.0") - .add("method", METHOD_NOTIFICATION_UPDATE) - .add("params", JSON_BUILDER_FACTORY.createObjectBuilder() - .add("uri", uri)) - .build(); + var params = JSON_BUILDER_FACTORY.createObjectBuilder().add("uri", uri); + return createJsonRpcNotification(METHOD_NOTIFICATION_UPDATE, params); } static JsonObject toJson(McpCompletionContent content) { @@ -486,8 +522,79 @@ static JsonObject toJson(McpCompletionContent content) { .build(); } + static JsonObjectBuilder toJson(McpSamplingRequest request) { + var hints = JSON_BUILDER_FACTORY.createArrayBuilder(); + var params = JSON_BUILDER_FACTORY.createObjectBuilder(); + var messages = JSON_BUILDER_FACTORY.createArrayBuilder(); + var sequences = JSON_BUILDER_FACTORY.createArrayBuilder(); + var modelPreference = JSON_BUILDER_FACTORY.createObjectBuilder(); + + request.hints() + .stream() + .flatMap(List::stream) + .map(hint -> JSON_BUILDER_FACTORY.createObjectBuilder().add("name", hint)) + .forEach(hints::add); + request.hints().map(it -> modelPreference.add("hints", hints)); + request.speedPriority().map(speed -> modelPreference.add("speedPriority", speed)); + request.costPriority().map(priority -> modelPreference.add("costPriority", priority)); + request.intelligencePriority().map(intelligence -> modelPreference.add("intelligencePriority", intelligence)); + params.add("modelPreference", modelPreference); + + request.messages().stream() + .map(McpJsonRpc::toJson) + .forEach(messages::add); + params.add("messages", messages); + params.add("maxTokens", request.maxTokens()); + request.systemPrompt().map(prompt -> params.add("systemPrompt", prompt)); + request.temperature().map(temperature -> params.add("temperature", temperature)); + request.includeContext().map(context -> params.add("includeContext", context.text())); + request.stopSequences() + .stream() + .flatMap(List::stream) + .forEach(sequences::add); + request.stopSequences().map(it -> params.add("stopSequences", sequences)); + request.metadata().map(metadata -> params.add("metadata", metadata)); + return params; + } + + static JsonObject createSamplingRequest(long id, McpSamplingRequest request) { + var params = toJson(request); + return createJsonRpcRequest(id, METHOD_SAMPLING_CREATE_MESSAGE, params); + } + static JsonObject disconnectSession() { - return JSON_BUILDER_FACTORY.createObjectBuilder().add("disconnect", true).build(); + return JSON_BUILDER_FACTORY.createObjectBuilder() + .add("disconnect", true) + .build(); + } + + static McpSamplingResponse createSamplingResponse(JsonObject object) throws McpSamplingException { + find(object, "error") + .filter(McpJsonRpc::isJsonObject) + .map(JsonValue::asJsonObject) + .map(JsonRpcError::create) + .ifPresent(error -> { + throw new McpSamplingException(error.message()); + }); + try { + var result = find(object, "result") + .filter(McpJsonRpc::isJsonObject) + .map(JsonValue::asJsonObject) + .orElseThrow(() -> new McpSamplingException(String.format("Sampling result not found: %s", object))); + + String model = result.getString("model"); + McpRole role = McpRole.valueOf(result.getString("role").toUpperCase()); + McpSamplingMessage message = parseMessage(role, result.getJsonObject("content")); + McpStopReason stopReason = find(result, "stopReason") + .filter(McpJsonRpc::isJsonString) + .map(JsonString.class::cast) + .map(JsonString::getString) + .map(McpStopReason::map) + .orElse(null); + return new McpSamplingResponseImpl(message, model, stopReason); + } catch (Exception e) { + throw new McpSamplingException("Wrong sampling response format", e); + } } static String prettyPrint(JsonStructure json) { @@ -497,4 +604,77 @@ static String prettyPrint(JsonStructure json) { } return baos.toString(StandardCharsets.UTF_8); } + + static JsonObject createJsonRpcNotification(String method, JsonObjectBuilder params) { + return JSON_BUILDER_FACTORY.createObjectBuilder() + .add("jsonrpc", "2.0") + .add("method", method) + .add("params", params) + .build(); + } + + static JsonObject createJsonRpcRequest(long id, String method, JsonObjectBuilder params) { + return JSON_BUILDER_FACTORY.createObjectBuilder() + .add("jsonrpc", "2.0") + .add("id", id) + .add("method", method) + .add("params", params) + .build(); + } + + static JsonObject createJsonRpcErrorResponse(long id, JsonObjectBuilder params) { + return JSON_BUILDER_FACTORY.createObjectBuilder() + .add("jsonrpc", "2.0") + .add("id", id) + .add("error", params) + .build(); + } + + static JsonObject createJsonRpcResultResponse(long id, JsonValue params) { + return JSON_BUILDER_FACTORY.createObjectBuilder() + .add("jsonrpc", "2.0") + .add("id", id) + .add("result", params) + .build(); + } + + static JsonObject timeoutResponse(long requestId) { + var error = JSON_BUILDER_FACTORY.createObjectBuilder() + .add("code", INTERNAL_ERROR) + .add("message", "response timeout"); + return createJsonRpcErrorResponse(requestId, error); + } + + private static McpSamplingMessage parseMessage(McpRole role, JsonObject object) { + String type = object.getString("type").toUpperCase(); + McpSamplingMessageType messageType = McpSamplingMessageType.valueOf(type); + return switch (messageType) { + case TEXT -> new McpSamplingTextMessageImpl(object.getString("text"), role); + case IMAGE -> { + byte[] data = object.getString("data").getBytes(StandardCharsets.UTF_8); + MediaType mediaType = MediaTypes.create(object.getString("mimeType")); + yield new McpSamplingImageMessageImpl(data, mediaType, role); + } + case AUDIO -> { + byte[] data = object.getString("data").getBytes(StandardCharsets.UTF_8); + MediaType mediaType = MediaTypes.create(object.getString("mimeType")); + yield new McpSamplingAudioMessageImpl(data, mediaType, role); + } + }; + } + + private static Optional find(JsonObject object, String key) { + if (object.containsKey(key)) { + return Optional.of(object.get(key)); + } + return Optional.empty(); + } + + private static boolean isJsonObject(JsonValue value) { + return JsonValue.ValueType.OBJECT.equals(value.getValueType()); + } + + private static boolean isJsonString(JsonValue value) { + return JsonValue.ValueType.STRING.equals(value.getValueType()); + } } diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpPage.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpPage.java index 35ac5dc4..2530c549 100644 --- a/server/src/main/java/io/helidon/extensions/mcp/server/McpPage.java +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpPage.java @@ -28,12 +28,28 @@ class McpPage { private final String cursor; private final boolean isLast; + /** + * Create a page with provided parameters. + * + * @param components the page components + * @param cursor the cursor pointing to the next page + * @param isLast {@code true} if this is the last page, {@code false} otherwise + */ McpPage(List components, String cursor, boolean isLast) { this.cursor = cursor; this.isLast = isLast; this.components = components; } + /** + * Create a single page with provided content. + * + * @param components the page components + */ + McpPage(List components) { + this(components, "", true); + } + String cursor() { return cursor; } diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpPagination.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpPagination.java index 58a16130..3e331de6 100644 --- a/server/src/main/java/io/helidon/extensions/mcp/server/McpPagination.java +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpPagination.java @@ -23,8 +23,6 @@ import java.util.concurrent.ConcurrentMap; import java.util.stream.Collectors; -import io.helidon.builder.api.Prototype; - /** * Support for MCP pagination feature. *

    @@ -67,7 +65,7 @@ class McpPagination { // Pagination is disabled if (pageSize == DEFAULT_PAGE_SIZE) { - pages.put(prevCursor, new McpPage<>(components, "", true)); + pages.put(prevCursor, new McpPage<>(components)); return; } @@ -83,7 +81,7 @@ class McpPagination { if (total % pageSize != 0) { int lastPageStart = total - (total % pageSize); List lastPage = components.subList(lastPageStart, total); - pages.put(prevCursor, new McpPage<>(lastPage, "", true)); + pages.put(prevCursor, new McpPage<>(lastPage)); } } @@ -101,13 +99,4 @@ List content() { .flatMap(Collection::stream) .collect(Collectors.toList()); } - - static class PageSizeDecorator implements Prototype.OptionDecorator, Integer> { - @Override - public void decorate(McpServerConfig.BuilderBase builder, Integer pageSize) { - if (pageSize < DEFAULT_PAGE_SIZE) { - throw new IllegalArgumentException("Page size must be greater than zero"); - } - } - } } diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpParameters.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpParameters.java index c27da2c9..47e8a6bb 100644 --- a/server/src/main/java/io/helidon/extensions/mcp/server/McpParameters.java +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpParameters.java @@ -35,7 +35,7 @@ import jakarta.json.JsonValue; /** - * Mcp client parameters provided to {@link McpTool} and {@link McpPrompt}. + * MCP client parameters provided to {@link McpTool} and {@link McpPrompt}. */ public final class McpParameters { private static final Mappers MAPPERS = Mappers.create(); @@ -63,7 +63,7 @@ private McpParameters(JsonValue root, String key) { } /** - * Get Mcp parameter node. + * Get MCP parameter node. * * @param key node key * @return parameter diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSampling.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSampling.java new file mode 100644 index 00000000..8c249a52 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSampling.java @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.extensions.mcp.server; + +import java.lang.System.Logger.Level; +import java.util.function.Consumer; + +import io.helidon.http.sse.SseEvent; +import io.helidon.webserver.sse.SseSink; + +import jakarta.json.JsonObject; + +import static io.helidon.extensions.mcp.server.McpJsonRpc.createSamplingRequest; +import static io.helidon.extensions.mcp.server.McpJsonRpc.createSamplingResponse; +import static io.helidon.extensions.mcp.server.McpJsonRpc.prettyPrint; + +/** + * MCP Sampling feature. + */ +public final class McpSampling extends McpFeature { + private static final System.Logger LOGGER = System.getLogger(McpSampling.class.getName()); + + McpSampling(McpSession session) { + super(session); + } + + McpSampling(McpSession session, SseSink sseSink) { + super(session, sseSink); + } + + /** + * Whether the connected client supports sampling feature. + * + * @return {@code true} if the connected client supports sampling feature, + * {@code false} otherwise. + */ + public boolean enabled() { + return session() + .capabilities() + .contains(McpCapability.SAMPLING); + } + + /** + * Send the provided sampling request to the client and return its response. + * + * @param request sampling request + * @return sampling response + * @throws io.helidon.extensions.mcp.server.McpSamplingException when an error occurs + */ + public McpSamplingResponse request(Consumer request) throws McpSamplingException { + var builder = McpSamplingRequest.builder(); + request.accept(builder); + return request(builder.build()); + } + + /** + * Send the provided sampling request to the client and return its response. + * + * @param request sampling request + * @return sampling response + * @throws io.helidon.extensions.mcp.server.McpSamplingException when an error occurs + */ + public McpSamplingResponse request(McpSamplingRequest request) throws McpSamplingException { + long id = session().jsonRpcId(); + JsonObject payload = createSamplingRequest(id, request); + + if (LOGGER.isLoggable(Level.DEBUG)) { + LOGGER.log(Level.DEBUG, "Sampling request:\n" + prettyPrint(payload)); + } + sseSink().ifPresentOrElse(sink -> sink.emit(SseEvent.builder() + .name("message") + .data(payload) + .build()), + () -> session().send(payload)); + JsonObject response = session().pollResponse(id, request.timeout()); + if (LOGGER.isLoggable(Level.DEBUG)) { + LOGGER.log(Level.DEBUG, "Sampling response:\n" + prettyPrint(response)); + } + return createSamplingResponse(response); + } + +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingAudioMessage.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingAudioMessage.java new file mode 100644 index 00000000..fc92bba4 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingAudioMessage.java @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.extensions.mcp.server; + +/** + * MCP sampling audio content. + */ +public sealed interface McpSamplingAudioMessage extends McpSamplingMessage, + McpSamplingMediaMessage permits McpSamplingAudioMessageImpl { +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingAudioMessageImpl.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingAudioMessageImpl.java new file mode 100644 index 00000000..a65a19e0 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingAudioMessageImpl.java @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.extensions.mcp.server; + +import java.util.Base64; + +import io.helidon.common.media.type.MediaType; + +/** + * MCP sampling audio content. + */ +final class McpSamplingAudioMessageImpl implements McpSamplingAudioMessage { + private final byte[] data; + private final McpRole role; + private final MediaType type; + + McpSamplingAudioMessageImpl(byte[] data, MediaType type, McpRole role) { + this.data = data; + this.role = role; + this.type = type; + } + + @Override + public McpSamplingMessageType type() { + return McpSamplingMessageType.AUDIO; + } + + @Override + public McpRole role() { + return role; + } + + @Override + public MediaType mediaType() { + return type; + } + + @Override + public byte[] data() { + return data; + } + + @Override + public byte[] decodeBase64Data() { + return Base64.getDecoder().decode(data); + } + + @Override + public String encodeBase64Data() { + return Base64.getEncoder().encodeToString(data); + } +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingException.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingException.java new file mode 100644 index 00000000..bcf85aa2 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingException.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.extensions.mcp.server; + +/** + * MCP sampling exception thrown during a sampling request to the client. + */ +public class McpSamplingException extends RuntimeException { + /** + * Creates a new MCP sampling exception with specified details message. + * + * @param message exception message + */ + McpSamplingException(String message) { + super(message); + } + + /** + * Creates a new MCP sampling exception with specified details message and its cause. + * + * @param message exception message + * @param cause exception cause + */ + McpSamplingException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingImageMessage.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingImageMessage.java new file mode 100644 index 00000000..635dd184 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingImageMessage.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.extensions.mcp.server; + +/** + * MCP sampling image content. + */ +public sealed interface McpSamplingImageMessage extends McpSamplingMessage, + McpSamplingMediaMessage permits McpSamplingImageMessageImpl { + +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingImageMessageImpl.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingImageMessageImpl.java new file mode 100644 index 00000000..1ae8fd68 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingImageMessageImpl.java @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.extensions.mcp.server; + +import java.util.Base64; + +import io.helidon.common.media.type.MediaType; + +/** + * MCP sampling image content. + */ +final class McpSamplingImageMessageImpl implements McpSamplingImageMessage { + private final byte[] data; + private final McpRole role; + private final MediaType type; + + McpSamplingImageMessageImpl(byte[] data, MediaType mediaType, McpRole role) { + this.data = data; + this.role = role; + this.type = mediaType; + } + + @Override + public McpSamplingMessageType type() { + return McpSamplingMessageType.IMAGE; + } + + @Override + public McpRole role() { + return role; + } + + @Override + public MediaType mediaType() { + return type; + } + + @Override + public byte[] data() { + return data; + } + + @Override + public byte[] decodeBase64Data() { + return Base64.getDecoder().decode(data); + } + + @Override + public String encodeBase64Data() { + return Base64.getEncoder().encodeToString(data); + } +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingMediaMessage.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingMediaMessage.java new file mode 100644 index 00000000..131dfbc2 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingMediaMessage.java @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.extensions.mcp.server; + +import io.helidon.common.media.type.MediaType; + +/** + * MCP sampling media content. + */ +public sealed interface McpSamplingMediaMessage permits McpSamplingAudioMessage, McpSamplingImageMessage { + /** + * Image content raw data. + * + * @return content + */ + byte[] data(); + + /** + * Returns the decoded image data using base64 decoder. + * + * @return decoded content. + */ + byte[] decodeBase64Data(); + + /** + * Returns the encoded image data using base64 encoder. + * + * @return content in base64. + */ + String encodeBase64Data(); + + /** + * Image content MIME type. + * + * @return MIME type + */ + MediaType mediaType(); +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingMessage.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingMessage.java new file mode 100644 index 00000000..f274a7b0 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingMessage.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.extensions.mcp.server; + +/** + * MCP sampling message. + */ +public sealed interface McpSamplingMessage permits McpSamplingTextMessage, McpSamplingImageMessage, McpSamplingAudioMessage { + /** + * Sampling message role. + * + * @return role + */ + McpRole role(); + + /** + * Sampling message type. + * + * @return type + */ + McpSamplingMessageType type(); +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingMessageType.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingMessageType.java new file mode 100644 index 00000000..9cb9a7af --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingMessageType.java @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.extensions.mcp.server; + +/** + * Sampling message types. + */ +public enum McpSamplingMessageType { + /** + * Sampling text message type. + */ + TEXT, + /** + * Sampling image message type. + */ + IMAGE, + /** + * Sampling audio message type. + */ + AUDIO; + + /** + * Returns lower case sampling message type name. + * + * @return type name + */ + String text() { + return this.name().toLowerCase(); + } +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingMessages.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingMessages.java new file mode 100644 index 00000000..cc1cc8df --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingMessages.java @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.extensions.mcp.server; + +import java.util.Objects; + +import io.helidon.common.media.type.MediaType; + +/** + * {@link io.helidon.extensions.mcp.server.McpSamplingMessage} factory class. + */ +public final class McpSamplingMessages { + private McpSamplingMessages() { + } + + /** + * Create a sampling text message. + * + * @param text text + * @param role role + * @return a sampling text message + */ + public static McpSamplingMessage textMessage(String text, McpRole role) { + Objects.requireNonNull(role, "role must not be null"); + Objects.requireNonNull(text, "text must not be null"); + return new McpSamplingTextMessageImpl(text, role); + } + + /** + * Create a sampling image message. + * + * @param data data + * @param mediaType media type + * @param role role + * @return a sampling image message + */ + public static McpSamplingMessage imageMessage(byte[] data, MediaType mediaType, McpRole role) { + Objects.requireNonNull(role, "role must not be null"); + Objects.requireNonNull(data, "data must not be null"); + Objects.requireNonNull(mediaType, "media type must not be null"); + return new McpSamplingImageMessageImpl(data, mediaType, role); + } + + /** + * Create a sampling audio message. + * + * @param data data + * @param mediaType media type + * @param role role + * @return a sampling audio message + */ + public static McpSamplingMessage audioMessage(byte[] data, MediaType mediaType, McpRole role) { + Objects.requireNonNull(data, "data must not be null"); + Objects.requireNonNull(role, "role must not be null"); + Objects.requireNonNull(mediaType, "media type must not be null"); + return new McpSamplingAudioMessageImpl(data, mediaType, role); + } +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingRequestBlueprint.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingRequestBlueprint.java new file mode 100644 index 00000000..46d5ce0a --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingRequestBlueprint.java @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.extensions.mcp.server; + +import java.time.Duration; +import java.util.List; +import java.util.Optional; + +import io.helidon.builder.api.Option; +import io.helidon.builder.api.Prototype; + +import jakarta.json.JsonValue; + +/** + * Configuration of an MCP sampling request. + */ +@Prototype.Blueprint +interface McpSamplingRequestBlueprint { + /** + * Sampling messages sent to the client. + * + * @return messages + */ + @Option.Singular + List messages(); + + /** + * Sampling model hints. + * + * @return hints + */ + Optional> hints(); + + /** + * Sampling cost priority. + * + * @return cost priority + */ + @Option.Decorator(McpDecorators.CostPriorityDecorator.class) + Optional costPriority(); + + /** + * Sampling speed priority. + * + * @return speed priority + */ + @Option.Decorator(McpDecorators.SpeedPriorityDecorator.class) + Optional speedPriority(); + + /** + * Sampling intelligence priority. + * + * @return intelligence priority + */ + @Option.Decorator(McpDecorators.IntelligencePriorityDecorator.class) + Optional intelligencePriority(); + + /** + * Sampling system prompt. + * + * @return system prompt + */ + Optional systemPrompt(); + + /** + * Sampling temperature. + * + * @return temperature + */ + Optional temperature(); + + /** + * Sampling max tokens. + * + * @return max tokens + */ + @Option.DefaultInt(100) + Integer maxTokens(); + + /** + * Sampling stop sequence. + * + * @return stop sequence + */ + Optional> stopSequences(); + + /** + * Sampling include context. + * + * @return include context + */ + Optional includeContext(); + + /** + * Optional metadata to pass through to the LLM provider. + * The format of this metadata is provider-specific. + * + * @return metadata + */ + Optional metadata(); + + /** + * Sampling request timeout. Default is five seconds. + * + * @return timeout + */ + @Option.Default("PT5S") + Duration timeout(); +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingResponse.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingResponse.java new file mode 100644 index 00000000..c9210dc7 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingResponse.java @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.extensions.mcp.server; + +import java.util.Optional; + +/** + * Configuration of an MCP sampling response. + */ +public sealed interface McpSamplingResponse permits McpSamplingResponseImpl { + /** + * Sampling response message. + * + * @return response + */ + McpSamplingMessage message(); + + /** + * Returns sampling response message as text message. + * + * @return message as text + * @throws McpSamplingException if the message is not a text + */ + McpSamplingTextMessage asTextMessage() throws McpSamplingException; + + /** + * Returns sampling response message as image message. + * + * @return message as image + * @throws McpSamplingException if the message is not an image + */ + McpSamplingImageMessage asImageMessage() throws McpSamplingException; + + /** + * Returns sampling response message as audio message. + * + * @return message as audio + * @throws McpSamplingException if the message is not an audio + */ + McpSamplingAudioMessage asAudioMessage() throws McpSamplingException; + + /** + * Sampling model used. + * + * @return model + */ + String model(); + + /** + * Sampling stop reason. + * + * @return stop reason + */ + Optional stopReason(); +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingResponseImpl.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingResponseImpl.java new file mode 100644 index 00000000..f5bed2b3 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingResponseImpl.java @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.extensions.mcp.server; + +import java.util.Optional; + +final class McpSamplingResponseImpl implements McpSamplingResponse { + private final String model; + private final McpStopReason stopReason; + private final McpSamplingMessage message; + + McpSamplingResponseImpl(McpSamplingMessage message, String model, McpStopReason stopReason) { + this.message = message; + this.model = model; + this.stopReason = stopReason; + } + + @Override + public McpSamplingMessage message() { + return message; + } + + @Override + public McpSamplingTextMessage asTextMessage() throws McpSamplingException { + if (message instanceof McpSamplingTextMessage text) { + return text; + } + throw new McpSamplingException("Sampling message is not text"); + } + + @Override + public McpSamplingImageMessage asImageMessage() throws McpSamplingException { + if (message instanceof McpSamplingImageMessage image) { + return image; + } + throw new McpSamplingException("Sampling message is not an image"); + } + + @Override + public McpSamplingAudioMessage asAudioMessage() throws McpSamplingException { + if (message instanceof McpSamplingAudioMessage audio) { + return audio; + } + throw new McpSamplingException("Sampling message is not an audio"); + } + + @Override + public String model() { + return model; + } + + @Override + public Optional stopReason() { + return Optional.ofNullable(stopReason); + } +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingTextMessage.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingTextMessage.java new file mode 100644 index 00000000..a6c4ff85 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingTextMessage.java @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.extensions.mcp.server; + +/** + * MCP sampling text content. + */ +public sealed interface McpSamplingTextMessage extends McpSamplingMessage permits McpSamplingTextMessageImpl { + /** + * Text content as string. + * + * @return text + */ + String text(); +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingTextMessageImpl.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingTextMessageImpl.java new file mode 100644 index 00000000..9c9bc987 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSamplingTextMessageImpl.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.extensions.mcp.server; + +/** + * MCP sampling text content. + */ +final class McpSamplingTextMessageImpl implements McpSamplingTextMessage { + private final String text; + private final McpRole role; + + McpSamplingTextMessageImpl(String text, McpRole role) { + this.text = text; + this.role = role; + } + + @Override + public McpSamplingMessageType type() { + return McpSamplingMessageType.TEXT; + } + + @Override + public McpRole role() { + return role; + } + + @Override + public String text() { + return text; + } +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpServerConfigBlueprint.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpServerConfigBlueprint.java index 43dbebff..5c737590 100644 --- a/server/src/main/java/io/helidon/extensions/mcp/server/McpServerConfigBlueprint.java +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpServerConfigBlueprint.java @@ -75,7 +75,7 @@ interface McpServerConfigBlueprint extends Prototype.Factory { */ @Option.Configured @Option.DefaultInt(DEFAULT_PAGE_SIZE) - @Option.Decorator(McpPagination.PageSizeDecorator.class) + @Option.Decorator(McpDecorators.PageSizeDecorator.class) int toolsPageSize(); /** @@ -85,7 +85,7 @@ interface McpServerConfigBlueprint extends Prototype.Factory { */ @Option.Configured @Option.DefaultInt(DEFAULT_PAGE_SIZE) - @Option.Decorator(McpPagination.PageSizeDecorator.class) + @Option.Decorator(McpDecorators.PageSizeDecorator.class) int promptsPageSize(); /** @@ -95,7 +95,7 @@ interface McpServerConfigBlueprint extends Prototype.Factory { */ @Option.Configured @Option.DefaultInt(DEFAULT_PAGE_SIZE) - @Option.Decorator(McpPagination.PageSizeDecorator.class) + @Option.Decorator(McpDecorators.PageSizeDecorator.class) int resourcesPageSize(); /** @@ -105,7 +105,7 @@ interface McpServerConfigBlueprint extends Prototype.Factory { */ @Option.Configured @Option.DefaultInt(DEFAULT_PAGE_SIZE) - @Option.Decorator(McpPagination.PageSizeDecorator.class) + @Option.Decorator(McpDecorators.PageSizeDecorator.class) int resourceTemplatesPageSize(); /** diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpServerFeature.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpServerFeature.java index 86ef000a..b88a37e5 100644 --- a/server/src/main/java/io/helidon/extensions/mcp/server/McpServerFeature.java +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpServerFeature.java @@ -28,7 +28,6 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.Consumer; import java.util.function.Function; -import java.util.function.Supplier; import io.helidon.builder.api.RuntimeType; import io.helidon.common.mapper.OptionalValue; @@ -229,7 +228,9 @@ private void mcpMetadata(ServerRequest request, ServerResponse response) { if (providers.isEmpty()) { response.status(Status.NOT_FOUND_404); response.send(); - log(Level.DEBUG, () -> "Security is not enabled, add OIDC security provider to the configuration"); + if (LOGGER.isLoggable(Level.DEBUG)) { + LOGGER.log(Level.DEBUG, "Security is not enabled, add OIDC security provider to the configuration"); + } return; } for (Config provider : providers.get()) { @@ -242,7 +243,9 @@ private void mcpMetadata(ServerRequest request, ServerResponse response) { return; } } - log(Level.DEBUG, () -> "Cannot find \"oidc.identity-uri\" property"); + if (LOGGER.isLoggable(Level.DEBUG)) { + LOGGER.log(Level.DEBUG, "Cannot find \"oidc.identity-uri\" property"); + } response.status(Status.NOT_FOUND_404); response.send(); } @@ -348,10 +351,16 @@ private Optional handleErrorRequest(ServerRequest req, JsonObject if (isResponse(object)) { Optional session = findSession(req); if (session.isPresent()) { - session.get().send(object); + if (LOGGER.isLoggable(Level.TRACE)) { + LOGGER.log(Level.TRACE, "Client response:\n" + prettyPrint(object)); + } + session.get().sendResponse(object); return Optional.empty(); } } + if (LOGGER.isLoggable(Level.TRACE)) { + LOGGER.log(Level.TRACE, "Wrong format message received:\n" + prettyPrint(object)); + } return Optional.of(JsonRpcError.create(INVALID_REQUEST, "Invalid request")); } @@ -368,7 +377,9 @@ private void notificationCancelRpc(JsonRpcRequest req, JsonRpcResponse res) { Optional foundSession = findSession(req); if (foundSession.isEmpty()) { res.status(Status.NOT_FOUND_404).send(); - log(Level.TRACE, () -> "No session found for cancellation request: %s".formatted(req.asJsonObject())); + if (LOGGER.isLoggable(Level.TRACE)) { + LOGGER.log(Level.TRACE, "No session found for cancellation request: %s".formatted(req.asJsonObject())); + } return; } McpSession session = foundSession.get(); @@ -378,7 +389,9 @@ private void notificationCancelRpc(JsonRpcRequest req, JsonRpcResponse res) { if (requestId.isEmpty() || reason.isEmpty() || !JsonValue.ValueType.STRING.equals(reason.get().getValueType())) { - log(Level.TRACE, () -> "Malformed cancellation request: %s".formatted(req.asJsonObject())); + if (LOGGER.isLoggable(Level.TRACE)) { + LOGGER.log(Level.TRACE, "Malformed cancellation request: %s".formatted(req.asJsonObject())); + } return; } String cancelReason = ((JsonString) reason.get()).getString(); @@ -388,7 +401,9 @@ private void notificationCancelRpc(JsonRpcRequest req, JsonRpcResponse res) { private void initializeRpc(JsonRpcRequest req, JsonRpcResponse res) { Optional foundSession = findSession(req); - log(Level.TRACE, () -> "Request:\n" + prettyPrint(req.asJsonObject())); + if (LOGGER.isLoggable(Level.TRACE)) { + LOGGER.log(Level.TRACE, "Request:\n" + prettyPrint(req.asJsonObject())); + } // is this streamable HTTP? if (foundSession.isEmpty()) { @@ -407,7 +422,9 @@ private void initializeRpc(JsonRpcRequest req, JsonRpcResponse res) { session.protocolVersion(protocolVersion); res.header(SESSION_ID_HEADER, sessionId); res.result(toJson(protocolVersion, capabilities, config)); - log(Level.TRACE, () -> "Streamable HTTP transport"); + if (LOGGER.isLoggable(Level.TRACE)) { + LOGGER.log(Level.TRACE, "Streamable HTTP transport"); + } res.send(); } else { McpSession session = foundSession.get(); @@ -419,10 +436,14 @@ private void initializeRpc(JsonRpcRequest req, JsonRpcResponse res) { session.state(INITIALIZING); } res.result(toJson(protocolVersion, capabilities, config)); - log(Level.TRACE, () -> "SSE transport"); + if (LOGGER.isLoggable(Level.TRACE)) { + LOGGER.log(Level.TRACE, "SSE transport"); + } session.send(res); } - log(Level.TRACE, () -> "Response:\n" + prettyPrint(res.asJsonObject())); + if (LOGGER.isLoggable(Level.TRACE)) { + LOGGER.log(Level.TRACE, "Response:\n" + prettyPrint(res.asJsonObject())); + } } private String parseClientVersion(McpParameters parameters) { @@ -441,12 +462,12 @@ private void parseClientCapabilities(McpSession session, McpParameters parameter if (capabilities.get(McpCapability.SAMPLING.text()).isPresent()) { session.capabilities(McpCapability.SAMPLING); } - capabilities.get("roots") + capabilities.get(McpCapability.ROOTS.text()) .get("listChanged") .asBoolean() .ifPresent(listChanged -> { if (listChanged) { - session.capabilities(McpCapability.ROOT); + session.capabilities(McpCapability.ROOTS); } }); } @@ -643,7 +664,9 @@ private void resourceSubscribeRpc(JsonRpcRequest req, JsonRpcResponse res) { try { session.blockSubscribe(resourceUri); } catch (InterruptedException e) { - log(Level.TRACE, () -> "Subscriber thread interrupted"); + if (LOGGER.isLoggable(Level.TRACE)) { + LOGGER.log(Level.TRACE, "Subscriber thread interrupted"); + } } // send final response after unblocking @@ -904,8 +927,10 @@ private void processSimpleCall(JsonRpcRequest req, JsonRpcResponse res, Consumer * @param session the active session */ private void sendResponse(JsonRpcRequest req, JsonRpcResponse res, McpSession session) { - log(Level.TRACE, () -> "Request:\n" + prettyPrint(req.asJsonObject())); - log(Level.TRACE, () -> "Response:\n" + prettyPrint(res.asJsonObject())); + if (LOGGER.isLoggable(Level.TRACE)) { + LOGGER.log(Level.TRACE, "Request:\n" + prettyPrint(req.asJsonObject())); + LOGGER.log(Level.TRACE, "Response:\n" + prettyPrint(res.asJsonObject())); + } if (isStreamableHttp(req.headers())) { res.send(); @@ -943,8 +968,10 @@ private void sendResponse(JsonRpcRequest req, McpSession session, JsonValue requestId, SseSink sseSink) { - log(Level.TRACE, () -> "Request:\n" + prettyPrint(req.asJsonObject())); - log(Level.TRACE, () -> "Response:\n" + prettyPrint(res.asJsonObject())); + if (LOGGER.isLoggable(Level.TRACE)) { + LOGGER.log(Level.TRACE, "Request:\n" + prettyPrint(req.asJsonObject())); + LOGGER.log(Level.TRACE, "Response:\n" + prettyPrint(res.asJsonObject())); + } // send response as HTTP or SSE with streamable HTTP if (isStreamableHttp(req.headers())) { @@ -1030,6 +1057,10 @@ private Optional sendError(JsonRpcRequest request, JsonRpcResponse response, Throwable throwable, int errorCode) { + if (LOGGER.isLoggable(Level.DEBUG)) { + LOGGER.log(Level.DEBUG, "Send error response because of: ", throwable); + } + // Look up session to send an error to the client var session = findSession(request); if (session.isEmpty()) { @@ -1044,22 +1075,16 @@ private Optional sendError(JsonRpcRequest request, sseSink = features.get().sseSink().orElse(null); } - log(Level.TRACE, () -> "Request:\n" + prettyPrint(request.asJsonObject())); - log(Level.TRACE, () -> "Response:\n" + prettyPrint(response.asJsonObject())); - // If streamable HTTP transport and did not switch to SSE // the handler manages the response if (isStreamableHttp(request.headers()) && sseSink == null) { + if (LOGGER.isLoggable(Level.TRACE)) { + LOGGER.log(Level.TRACE, "Response:\n" + prettyPrint(response.asJsonObject())); + } session.get().clearRequest(requestId); return response.error(); } sendResponse(request, response, session.get(), requestId, sseSink); return Optional.empty(); } - - private void log(Level level, Supplier message) { - if (LOGGER.isLoggable(level)) { - LOGGER.log(level, message); - } - } } diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpSession.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpSession.java index faeb38b6..ced6b0dd 100644 --- a/server/src/main/java/io/helidon/extensions/mcp/server/McpSession.java +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpSession.java @@ -29,6 +29,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.function.Consumer; @@ -45,6 +46,7 @@ import jakarta.json.JsonObject; import jakarta.json.JsonValue; +import static io.helidon.extensions.mcp.server.McpJsonRpc.timeoutResponse; import static io.helidon.extensions.mcp.server.McpServerFeature.isStreamableHttp; import static io.helidon.extensions.mcp.server.McpSession.State.UNINITIALIZED; @@ -55,8 +57,10 @@ class McpSession { private final McpServerConfig config; private final Set capabilities; private final Context context = Context.create(); + private final AtomicLong jsonRpcId = new AtomicLong(0); private final AtomicBoolean active = new AtomicBoolean(true); private final BlockingQueue queue = new LinkedBlockingQueue<>(); + private final BlockingQueue responses = new LinkedBlockingQueue<>(); private final LruCache features = LruCache.create(); private final ReadWriteLock lock = new ReentrantReadWriteLock(); @@ -96,6 +100,29 @@ void poll(Consumer consumer) { } } + JsonObject pollResponse(long requestId, Duration timeout) { + while (active.get()) { + try { + JsonObject response = responses.poll(timeout.toMillis(), TimeUnit.MILLISECONDS); + if (response != null) { + long id = response.getJsonNumber("id").longValue(); + if (id == requestId) { + return response; + } + } else { + return timeoutResponse(requestId); + } + } catch (ClassCastException e) { + if (LOGGER.isLoggable(Level.TRACE)) { + LOGGER.log(Level.TRACE, "Received a response with wrong request id type", e); + } + } catch (InterruptedException e) { + throw new McpInternalException("Session interrupted.", e); + } + } + throw new McpInternalException("Session disconnected"); + } + void send(JsonObject message) { try { queue.put(message); @@ -108,6 +135,14 @@ void send(JsonRpcResponse response) { send(response.status(Status.ACCEPTED_202).asJsonObject()); } + void sendResponse(JsonObject response) { + try { + responses.put(response); + } catch (InterruptedException e) { + throw new UncheckedException(e); + } + } + void disconnect() { if (active.compareAndSet(true, false)) { queue.add(McpJsonRpc.disconnectSession()); @@ -136,6 +171,16 @@ Optional features(JsonValue requestId) { return features.get(requestId); } + /** + * Generates a unique JSON-RPC {@code id} for an outbound request to the client. + * The returned identifier is guaranteed to be unused by any prior request in this session. + * + * @return a new request id + */ + long jsonRpcId() { + return jsonRpcId.getAndIncrement(); + } + void clearRequest(JsonValue requestId) { features.remove(requestId); } @@ -144,6 +189,10 @@ void capabilities(McpCapability capability) { capabilities.add(capability); } + Set capabilities() { + return capabilities; + } + State state() { return state; } @@ -178,7 +227,7 @@ Optional findSubscription(String uri) { return Optional.empty(); } throw new IllegalArgumentException("Subscription not found: " + uri); - } finally { + } finally { lock.readLock().unlock(); } } @@ -235,7 +284,7 @@ Optional unsubscribe(JsonRpcRequest req, String uri) { } log(Level.DEBUG, () -> "Removed subscription for " + uri); return Optional.ofNullable(sseSink); - } finally { + } finally { lock.writeLock().unlock(); } } diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpStopReason.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpStopReason.java new file mode 100644 index 00000000..06a06f49 --- /dev/null +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpStopReason.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.extensions.mcp.server; + +/** + * Sampling request stop reasons. + */ +public enum McpStopReason { + /** + * End turn. + */ + END_TURN, + /** + * Stop sequence. + */ + STOP_SEQUENCE, + /** + * Max tokens. + */ + MAX_TOKENS; + + String text() { + return this.name().toLowerCase().replace("_", ""); + } + + static McpStopReason map(String reason) { + reason = reason.toLowerCase(); + for (McpStopReason stopReason : McpStopReason.values()) { + if (stopReason.text().equals(reason)) { + return stopReason; + } + } + throw new IllegalArgumentException("Unknown stop reason: " + reason); + } +} diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpTextContent.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpTextContent.java index 06cb4e96..66aa39f8 100644 --- a/server/src/main/java/io/helidon/extensions/mcp/server/McpTextContent.java +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpTextContent.java @@ -20,7 +20,6 @@ * Text content. */ sealed interface McpTextContent extends McpContent permits McpTextContent.McpTextContentImpl { - /** * Text content as string. * diff --git a/server/src/main/java/io/helidon/extensions/mcp/server/McpToolErrorException.java b/server/src/main/java/io/helidon/extensions/mcp/server/McpToolErrorException.java index b3c2ff9e..46e57eed 100644 --- a/server/src/main/java/io/helidon/extensions/mcp/server/McpToolErrorException.java +++ b/server/src/main/java/io/helidon/extensions/mcp/server/McpToolErrorException.java @@ -18,10 +18,11 @@ import java.util.Arrays; import java.util.List; +import java.util.Objects; /** * Tool error exception are sending a tool response to the client with - * the provided contents and an error flag. + * the provided content and an error flag. */ public final class McpToolErrorException extends RuntimeException { private final List contents; @@ -44,6 +45,18 @@ public McpToolErrorException(McpToolContent... contents) { this.contents = Arrays.asList(contents); } + /** + * Creates a tool error exception with provided message. + * + * @param messages error messages + */ + public McpToolErrorException(String... messages) { + this.contents = Arrays.stream(messages) + .filter(Objects::nonNull) + .map(McpToolContents::textContent) + .toList(); + } + List contents() { return contents; } diff --git a/server/src/test/java/io/helidon/extensions/mcp/server/McpSamplingRequestTest.java b/server/src/test/java/io/helidon/extensions/mcp/server/McpSamplingRequestTest.java new file mode 100644 index 00000000..a53bb3b2 --- /dev/null +++ b/server/src/test/java/io/helidon/extensions/mcp/server/McpSamplingRequestTest.java @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.extensions.mcp.server; + +import java.time.Duration; +import java.util.List; + +import jakarta.json.JsonValue; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +class McpSamplingRequestTest { + + @Test + void testDefaultValues() { + McpSamplingRequest request = McpSamplingRequest.create(); + + assertThat(request.maxTokens(), is(100)); + assertThat(request.hints().isEmpty(), is(true)); + assertThat(request.messages().isEmpty(), is(true)); + assertThat(request.metadata().isEmpty(), is(true)); + assertThat(request.temperature().isEmpty(), is(true)); + assertThat(request.costPriority().isEmpty(), is(true)); + assertThat(request.systemPrompt().isEmpty(), is(true)); + assertThat(request.stopSequences().isEmpty(), is(true)); + assertThat(request.speedPriority().isEmpty(), is(true)); + assertThat(request.includeContext().isEmpty(), is(true)); + assertThat(request.timeout(), is(Duration.ofSeconds(5))); + assertThat(request.intelligencePriority().isEmpty(), is(true)); + } + + @Test + void testCustomValues() { + McpSamplingRequest request = McpSamplingRequest.builder() + .maxTokens(1) + .temperature(0.1) + .costPriority(0.1) + .speedPriority(0.1) + .hints(List.of("hint1")) + .metadata(JsonValue.TRUE) + .intelligencePriority(0.1) + .systemPrompt("system prompt") + .timeout(Duration.ofSeconds(10)) + .stopSequences(List.of("stop1")) + .includeContext(McpIncludeContext.NONE) + .addMessage(McpSamplingMessages.textMessage("text", McpRole.USER)) + .build(); + + assertThat(request.maxTokens(), is(1)); + assertThat(request.timeout(), is(Duration.ofSeconds(10))); + + assertThat(request.hints().isEmpty(), is(false)); + assertThat(request.hints().get(), is(List.of("hint1"))); + + assertThat(request.messages().isEmpty(), is(false)); + assertThat(request.messages().size(), is(1)); + + var message = request.messages().getFirst(); + assertThat(message, instanceOf(McpSamplingTextMessage.class)); + assertThat(message.role(), is(McpRole.USER)); + assertThat(((McpSamplingTextMessage) message).text(), is("text")); + + assertThat(request.metadata().isEmpty(), is(false)); + assertThat(request.metadata().get(), is(JsonValue.TRUE)); + + assertThat(request.includeContext().isEmpty(), is(false)); + assertThat(request.includeContext().get(), is(McpIncludeContext.NONE)); + + assertThat(request.systemPrompt().isEmpty(), is(false)); + assertThat(request.systemPrompt().get(), is("system prompt")); + + assertThat(request.stopSequences().isEmpty(), is(false)); + assertThat(request.stopSequences().get(), is(List.of("stop1"))); + + assertThat(request.temperature().isEmpty(), is(false)); + assertThat(request.temperature().get(), is(0.1)); + + assertThat(request.costPriority().isEmpty(), is(false)); + assertThat(request.costPriority().get(), is(0.1)); + + assertThat(request.speedPriority().isEmpty(), is(false)); + assertThat(request.speedPriority().get(), is(0.1)); + + assertThat(request.intelligencePriority().isEmpty(), is(false)); + assertThat(request.intelligencePriority().get(), is(0.1)); + } + + @ParameterizedTest + @ValueSource(doubles = {1.1, -1.1}) + void testIntelligencePriorityDecorator(double value) { + try { + McpSamplingRequest.builder() + .intelligencePriority(value) + .build(); + assertThat("Setting a value outside of range [0, 1] must throw an exception", true, is(false)); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage(), is("Intelligence priority must be in range [0, 1]")); + } + } + + @ParameterizedTest + @ValueSource(doubles = {1.1, -1.1}) + void testCostPriorityDecorator(double value) { + try { + McpSamplingRequest.builder() + .costPriority(value) + .build(); + assertThat("Setting a value outside of range [0, 1] must throw an exception", true, is(false)); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage(), is("Cost priority must be in range [0, 1]")); + } + } + + @ParameterizedTest + @ValueSource(doubles = {1.1, -1.1}) + void testSpeedPriorityDecorator(double value) { + try { + McpSamplingRequest.builder() + .speedPriority(value) + .build(); + assertThat("Setting a value outside of range [0, 1] must throw an exception", true, is(false)); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage(), is("Speed priority must be in range [0, 1]")); + } + } +} diff --git a/server/src/test/java/io/helidon/extensions/mcp/server/McpSamplingResponseTest.java b/server/src/test/java/io/helidon/extensions/mcp/server/McpSamplingResponseTest.java new file mode 100644 index 00000000..f2f7eb65 --- /dev/null +++ b/server/src/test/java/io/helidon/extensions/mcp/server/McpSamplingResponseTest.java @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.helidon.extensions.mcp.server; + +import java.nio.charset.StandardCharsets; + +import io.helidon.common.media.type.MediaTypes; + +import org.junit.jupiter.api.Test; + +import static io.helidon.extensions.mcp.server.McpSamplingMessages.audioMessage; +import static io.helidon.extensions.mcp.server.McpSamplingMessages.imageMessage; +import static io.helidon.extensions.mcp.server.McpSamplingMessages.textMessage; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class McpSamplingResponseTest { + + @Test + void testSamplingResponseTextMessage() { + var message = textMessage("text", McpRole.USER); + McpSamplingResponse response = new McpSamplingResponseImpl(message, "helidon-model", McpStopReason.END_TURN); + + assertThat(response.model(), is("helidon-model")); + assertThat(response.stopReason().isPresent(), is(true)); + assertThat(response.stopReason().get(), is(McpStopReason.END_TURN)); + assertThat(response.message(), instanceOf(McpSamplingTextMessage.class)); + + McpSamplingTextMessage text = response.asTextMessage(); + assertThat(text.role(), is(McpRole.USER)); + assertThat(text.text(), is("text")); + + assertThrows(McpSamplingException.class, response::asImageMessage); + assertThrows(McpSamplingException.class, response::asAudioMessage); + } + + @Test + void testSamplingResponseImageMessage() { + var data = "data".getBytes(StandardCharsets.UTF_8); + var message = imageMessage(data, MediaTypes.TEXT_PLAIN, McpRole.USER); + McpSamplingResponse response = new McpSamplingResponseImpl(message, "helidon-model", McpStopReason.END_TURN); + + assertThat(response.model(), is("helidon-model")); + assertThat(response.stopReason().isPresent(), is(true)); + assertThat(response.stopReason().get(), is(McpStopReason.END_TURN)); + assertThat(response.message(), instanceOf(McpSamplingImageMessage.class)); + + McpSamplingImageMessage image = response.asImageMessage(); + assertThat(image.role(), is(McpRole.USER)); + assertThat(image.data(), is(data)); + + assertThrows(McpSamplingException.class, response::asTextMessage); + assertThrows(McpSamplingException.class, response::asAudioMessage); + } + + @Test + void testSamplingResponseAudioMessage() { + var data = "data".getBytes(StandardCharsets.UTF_8); + var message = audioMessage(data, MediaTypes.TEXT_PLAIN, McpRole.USER); + McpSamplingResponse response = new McpSamplingResponseImpl(message, "helidon-model", McpStopReason.END_TURN); + + assertThat(response.model(), is("helidon-model")); + assertThat(response.stopReason().isPresent(), is(true)); + assertThat(response.stopReason().get(), is(McpStopReason.END_TURN)); + assertThat(response.message(), instanceOf(McpSamplingAudioMessage.class)); + + McpSamplingAudioMessage image = response.asAudioMessage(); + assertThat(image.role(), is(McpRole.USER)); + assertThat(image.data(), is(data)); + + assertThrows(McpSamplingException.class, response::asTextMessage); + assertThrows(McpSamplingException.class, response::asImageMessage); + } +} diff --git a/tests/codegen/src/test/java/io/helidon/extensions/mcp/codegen/McpTypesTest.java b/tests/codegen/src/test/java/io/helidon/extensions/mcp/codegen/McpTypesTest.java index e72a6622..7a4b50e8 100644 --- a/tests/codegen/src/test/java/io/helidon/extensions/mcp/codegen/McpTypesTest.java +++ b/tests/codegen/src/test/java/io/helidon/extensions/mcp/codegen/McpTypesTest.java @@ -47,6 +47,7 @@ import io.helidon.extensions.mcp.server.McpResourceSubscriber; import io.helidon.extensions.mcp.server.McpResourceUnsubscriber; import io.helidon.extensions.mcp.server.McpRole; +import io.helidon.extensions.mcp.server.McpSampling; import io.helidon.extensions.mcp.server.McpServerConfig; import io.helidon.extensions.mcp.server.McpTool; import io.helidon.extensions.mcp.server.McpToolAnnotations; @@ -114,6 +115,7 @@ void testTypes() { checkField(toCheck, checked, fields, "MCP_REQUEST", McpRequest.class); checkField(toCheck, checked, fields, "MCP_FEATURES", McpFeatures.class); checkField(toCheck, checked, fields, "MCP_PROGRESS", McpProgress.class); + checkField(toCheck, checked, fields, "MCP_SAMPLING", McpSampling.class); checkField(toCheck, checked, fields, "MCP_TOOL_INTERFACE", McpTool.class); checkField(toCheck, checked, fields, "MCP_PARAMETERS", McpParameters.class); checkField(toCheck, checked, fields, "MCP_PROMPT_INTERFACE", McpPrompt.class); diff --git a/tests/declarative/src/main/java/io/helidon/extensions/mcp/tests/declarative/McpSamplingServer.java b/tests/declarative/src/main/java/io/helidon/extensions/mcp/tests/declarative/McpSamplingServer.java new file mode 100644 index 00000000..a91c3423 --- /dev/null +++ b/tests/declarative/src/main/java/io/helidon/extensions/mcp/tests/declarative/McpSamplingServer.java @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.extensions.mcp.tests.declarative; + +import java.util.List; + +import io.helidon.common.media.type.MediaTypes; +import io.helidon.extensions.mcp.server.Mcp; +import io.helidon.extensions.mcp.server.McpPromptContent; +import io.helidon.extensions.mcp.server.McpRequest; +import io.helidon.extensions.mcp.server.McpResourceContent; +import io.helidon.extensions.mcp.server.McpSampling; +import io.helidon.extensions.mcp.server.McpToolContent; + +@Mcp.Server +@Mcp.Path("/sampling") +class McpSamplingServer { + + @Mcp.Tool("Sampling tool") + List tool(McpSampling sampling) { + return List.of(); + } + + @Mcp.Tool("Sampling tool") + List tool1(McpSampling sampling, String value) { + return List.of(); + } + + @Mcp.Tool("Sampling tool") + String tool4(McpSampling sampling) { + return ""; + } + + @Mcp.Tool("Sampling tool") + String tool5(McpSampling sampling, String value) { + return ""; + } + + @Mcp.Prompt("Sampling prompt") + List prompt(McpSampling sampling) { + return List.of(); + } + + @Mcp.Prompt("Sampling prompt") + List prompt1(McpSampling sampling, String value) { + return List.of(); + } + + @Mcp.Prompt("Sampling prompt") + String prompt4(McpSampling sampling) { + return ""; + } + + @Mcp.Prompt("Sampling prompt") + String prompt5(McpSampling sampling, String value) { + return ""; + } + + @Mcp.Resource(uri = "https://example.com", + description = "Sampling resource", + mediaType = MediaTypes.TEXT_PLAIN_VALUE) + List resource(McpSampling sampling) { + return List.of(); + } + + @Mcp.Resource(uri = "https://example.com", + description = "Sampling resource", + mediaType = MediaTypes.TEXT_PLAIN_VALUE) + List resource1(McpSampling sampling, McpRequest request) { + return List.of(); + } + + @Mcp.Resource(uri = "https://example.com", + description = "Sampling resource", + mediaType = MediaTypes.TEXT_PLAIN_VALUE) + String resource4(McpSampling sampling) { + return ""; + } + + @Mcp.Resource(uri = "https://example.com", + description = "Sampling resource", + mediaType = MediaTypes.TEXT_PLAIN_VALUE) + String resource5(McpSampling sampling, McpRequest request) { + return ""; + } +} diff --git a/tests/mcp/src/main/java/io/helidon/extensions/mcp/tests/SamplingServer.java b/tests/mcp/src/main/java/io/helidon/extensions/mcp/tests/SamplingServer.java new file mode 100644 index 00000000..9b88f032 --- /dev/null +++ b/tests/mcp/src/main/java/io/helidon/extensions/mcp/tests/SamplingServer.java @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.extensions.mcp.tests; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.List; +import java.util.function.Function; + +import io.helidon.common.media.type.MediaTypes; +import io.helidon.extensions.mcp.server.McpContent; +import io.helidon.extensions.mcp.server.McpException; +import io.helidon.extensions.mcp.server.McpRequest; +import io.helidon.extensions.mcp.server.McpSampling; +import io.helidon.extensions.mcp.server.McpSamplingException; +import io.helidon.extensions.mcp.server.McpSamplingMessage; +import io.helidon.extensions.mcp.server.McpSamplingMessages; +import io.helidon.extensions.mcp.server.McpSamplingResponse; +import io.helidon.extensions.mcp.server.McpServerFeature; +import io.helidon.extensions.mcp.server.McpTool; +import io.helidon.extensions.mcp.server.McpToolContent; +import io.helidon.extensions.mcp.server.McpToolErrorException; +import io.helidon.json.schema.Schema; +import io.helidon.webserver.http.HttpRouting; + +import static io.helidon.extensions.mcp.server.McpRole.USER; +import static io.helidon.extensions.mcp.server.McpToolContents.textContent; + +class SamplingServer { + private SamplingServer() { + } + + static void setUpRoute(HttpRouting.Builder builder) { + builder.addFeature(McpServerFeature.builder() + .path("/") + .addTool(new EnabledTool()) + .addTool(new SamplingTool()) + .addTool(new ErrorSamplingTool()) + .addTool(new TimeoutSamplingTool()) + .addTool(new MultipleSamplingRequestTool()) + ); + } + + private static class SamplingTool implements McpTool { + @Override + public String name() { + return "sampling-tool"; + } + + @Override + public String description() { + return "A tool that returns sampling response as tool content."; + } + + @Override + public String schema() { + return Schema.builder().build().generate(); + } + + @Override + public Function> tool() { + return this::sampling; + } + + List sampling(McpRequest request) { + McpSampling sampling = request.features().sampling(); + McpContent.ContentType requestType = request.parameters() + .get("type") + .asString() + .map(String::toUpperCase) + .map(McpContent.ContentType::valueOf) + .orElseThrow(() -> new McpToolErrorException("Error while parsing content type")); + + McpSamplingMessage message = createMessage(requestType); + McpSamplingResponse response = sampling.request(req -> req.addMessage(message)); + var type = response.message().type(); + return switch (type) { + case TEXT -> List.of(textContent(response.asTextMessage().text())); + case IMAGE -> List.of(textContent(new String(response.asImageMessage().data()))); + case AUDIO -> List.of(textContent(new String(response.asAudioMessage().data()))); + }; + } + + McpSamplingMessage createMessage(McpContent.ContentType type) { + return switch (type) { + case TEXT -> McpSamplingMessages.textMessage("samplingMessage", USER); + case IMAGE -> McpSamplingMessages.imageMessage("samplingMessage".getBytes(StandardCharsets.UTF_8), + MediaTypes.TEXT_PLAIN, + USER); + case AUDIO -> McpSamplingMessages.audioMessage("samplingMessage".getBytes(StandardCharsets.UTF_8), + MediaTypes.TEXT_PLAIN, + USER); + default -> throw new McpToolErrorException(textContent("Unsupported sampling message type: " + type)); + }; + } + } + + private static class EnabledTool extends SamplingTool { + @Override + public String name() { + return "enabled-tool"; + } + + @Override + public Function> tool() { + return this::enabledSampling; + } + + private List enabledSampling(McpRequest request) { + McpSampling sampling = request.features().sampling(); + if (sampling.enabled()) { + return sampling(request); + } + throw new McpToolErrorException(textContent("sampling is disabled")); + } + } + + private static class MultipleSamplingRequestTool extends SamplingTool { + private final McpSamplingMessage message = McpSamplingMessages.textMessage("ignored", USER); + + @Override + public String name() { + return "multiple-sampling-tool"; + } + + @Override + public Function> tool() { + return request -> { + McpSampling sampling = request.features().sampling(); + var response = sampling.request(req -> req.addMessage(message)); + return sampling(request); + }; + } + } + + private static class TimeoutSamplingTool extends SamplingTool { + @Override + public String name() { + return "timeout-tool"; + } + + @Override + public Function> tool() { + return request -> { + try { + request.features() + .sampling() + .request(req -> req.timeout(Duration.ofSeconds(2)) + .addMessage(McpSamplingMessages.textMessage("timeout", USER))); + throw new McpException("Timeout should have been triggered"); + } catch (McpSamplingException e) { + throw new McpToolErrorException(e.getMessage()); + } + }; + } + } + + private static class ErrorSamplingTool extends SamplingTool { + @Override + public String name() { + return "error-tool"; + } + + @Override + public Function> tool() { + return request -> { + try { + request.features() + .sampling() + .request(req -> req.addMessage(McpSamplingMessages.textMessage("error", USER))); + throw new McpException("MCP sampling exception should have been triggered"); + } catch (McpSamplingException e) { + throw new McpToolErrorException(e.getMessage()); + } + }; + } + } +} diff --git a/tests/mcp/src/main/java/io/helidon/extensions/mcp/tests/ToolErrorResultServer.java b/tests/mcp/src/main/java/io/helidon/extensions/mcp/tests/ToolErrorResultServer.java index 5052c1f3..a663ff5f 100644 --- a/tests/mcp/src/main/java/io/helidon/extensions/mcp/tests/ToolErrorResultServer.java +++ b/tests/mcp/src/main/java/io/helidon/extensions/mcp/tests/ToolErrorResultServer.java @@ -23,10 +23,11 @@ import io.helidon.extensions.mcp.server.McpServerFeature; import io.helidon.extensions.mcp.server.McpTool; import io.helidon.extensions.mcp.server.McpToolContent; -import io.helidon.extensions.mcp.server.McpToolContents; import io.helidon.extensions.mcp.server.McpToolErrorException; import io.helidon.webserver.http.HttpRouting; +import static io.helidon.extensions.mcp.server.McpToolContents.textContent; + class ToolErrorResultServer { private ToolErrorResultServer() { } @@ -35,11 +36,11 @@ static void setUpRoute(HttpRouting.Builder builder) { builder.addFeature(McpServerFeature.builder() .path("/") .addTool(new FailingTool()) - .addTool(new FailingTool1())); + .addTool(new FailingTool1()) + .addTool(new FailingTool2())); } private static class FailingTool implements McpTool { - @Override public String name() { return "failing-tool"; @@ -57,13 +58,12 @@ public String schema() { @Override public Function> tool() { - McpToolContent content = McpToolContents.textContent("Tool error message"); + McpToolContent content = textContent("Tool error message"); throw new McpToolErrorException(content); } } private static class FailingTool1 extends FailingTool { - @Override public String name() { return "failing-tool-1"; @@ -71,8 +71,22 @@ public String name() { @Override public Function> tool() { - McpToolContent content = McpToolContents.textContent("Tool error message"); + McpToolContent content = textContent("Tool error message"); throw new McpToolErrorException(List.of(content)); } } + + private static class FailingTool2 extends FailingTool { + @Override + public String name() { + return "failing-tool-2"; + } + + @Override + public Function> tool() { + McpToolContent content = textContent("Tool error message"); + McpToolContent content1 = textContent("Second error message"); + throw new McpToolErrorException(content, content1); + } + } } diff --git a/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractLangchain4jMcpExceptionTest.java b/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractLangchain4jMcpExceptionTest.java index 9a2c2256..82e3aaef 100644 --- a/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractLangchain4jMcpExceptionTest.java +++ b/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractLangchain4jMcpExceptionTest.java @@ -18,15 +18,12 @@ import java.util.Map; -import io.helidon.jsonrpc.core.JsonRpcError; import io.helidon.webserver.http.HttpRouting; import io.helidon.webserver.testing.junit5.SetUpRoute; import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.exception.LangChain4jException; import dev.langchain4j.mcp.client.McpClient; -import dev.langchain4j.mcp.client.protocol.McpGetPromptRequest; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; diff --git a/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractLangchain4jToolErrorResultTest.java b/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractLangchain4jToolErrorResultTest.java index d670dbd2..bbdfc50f 100644 --- a/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractLangchain4jToolErrorResultTest.java +++ b/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractLangchain4jToolErrorResultTest.java @@ -23,6 +23,7 @@ import dev.langchain4j.exception.ToolExecutionException; import dev.langchain4j.mcp.client.McpClient; import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -53,4 +54,15 @@ void testFailingToolResult(String name) { assertThat(e.getMessage(), is("Tool error message")); } } + + @Test + void testMultipleErrorMessages() { + try { + var result = client.executeTool(ToolExecutionRequest.builder() + .name("failing-tool-2") + .build()); + } catch (ToolExecutionException e) { + assertThat(e.getMessage(), is("Tool error message\nSecond error message")); + } + } } diff --git a/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractMcpSdkSamplingTest.java b/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractMcpSdkSamplingTest.java new file mode 100644 index 00000000..8403bc02 --- /dev/null +++ b/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractMcpSdkSamplingTest.java @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.extensions.mcp.tests; + +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import io.helidon.common.media.type.MediaTypes; +import io.helidon.webserver.http.HttpRouting; +import io.helidon.webserver.testing.junit5.SetUpRoute; + +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static io.modelcontextprotocol.spec.McpSchema.CreateMessageResult.StopReason.STOP_SEQUENCE; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +abstract class AbstractMcpSdkSamplingTest extends AbstractMcpSdkTest { + private static final String SAMPLING_CLIENT_TEXT = "samplingMessage"; + private static final String SAMPLING_ERROR_MESSAGE = "sampling error message"; + + @SetUpRoute + static void routing(HttpRouting.Builder builder) { + SamplingServer.setUpRoute(builder); + } + + protected McpSchema.CreateMessageResult samplingHandler(McpSchema.CreateMessageRequest request) { + var messages = request.messages(); + assertThat(messages.size(), is(1)); + + var message = messages.getFirst(); + return switch (message.content().type()) { + case "text" -> testTextMessage(message); + case "image" -> testImageMessage(message); + case "audio" -> testAudioMessage(message); + default -> throw new IllegalStateException("Wrong sampling message type"); + }; + } + + private McpSchema.CreateMessageResult testAudioMessage(McpSchema.SamplingMessage message) { + assertThat(message.content(), instanceOf(McpSchema.AudioContent.class)); + assertThat(message.role(), is(McpSchema.Role.USER)); + + var audio = (McpSchema.AudioContent) message.content(); + assertThat(decode(audio.data()), is(SAMPLING_CLIENT_TEXT)); + assertThat(audio.mimeType(), is(MediaTypes.TEXT_PLAIN_VALUE)); + + var annotations = new McpSchema.Annotations(List.of(), 1.0); + var result = new McpSchema.AudioContent(annotations, SAMPLING_CLIENT_TEXT, MediaTypes.TEXT_PLAIN_VALUE); + return new McpSchema.CreateMessageResult(McpSchema.Role.USER, + result, + "test-model", + STOP_SEQUENCE); + } + + private McpSchema.CreateMessageResult testImageMessage(McpSchema.SamplingMessage message) { + assertThat(message.content(), instanceOf(McpSchema.ImageContent.class)); + assertThat(message.role(), is(McpSchema.Role.USER)); + + var image = (McpSchema.ImageContent) message.content(); + assertThat(decode(image.data()), is(SAMPLING_CLIENT_TEXT)); + assertThat(image.mimeType(), is(MediaTypes.TEXT_PLAIN_VALUE)); + + var annotations = new McpSchema.Annotations(List.of(), 1.0); + var result = new McpSchema.ImageContent(annotations, SAMPLING_CLIENT_TEXT, MediaTypes.TEXT_PLAIN_VALUE); + return new McpSchema.CreateMessageResult(McpSchema.Role.USER, result, "test-model", STOP_SEQUENCE); + } + + private McpSchema.CreateMessageResult testTextMessage(McpSchema.SamplingMessage message) { + assertThat(message.content(), instanceOf(McpSchema.TextContent.class)); + assertThat(message.role(), is(McpSchema.Role.USER)); + + var text = (McpSchema.TextContent) message.content(); + var result = new McpSchema.TextContent(SAMPLING_CLIENT_TEXT); + + if ("timeout".equals(text.text())) { + try { + TimeUnit.SECONDS.sleep(4); + return new McpSchema.CreateMessageResult(McpSchema.Role.USER, result, "test-model", STOP_SEQUENCE); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + if ("error".equals(text.text())) { + throw new RuntimeException(SAMPLING_ERROR_MESSAGE); + } + + return new McpSchema.CreateMessageResult(McpSchema.Role.USER, result, "test-model", STOP_SEQUENCE); + } + + private String decode(String data) { + return new String(Base64.getDecoder().decode(data)); + } + + @ParameterizedTest + @ValueSource(strings = {"image", "audio"}) + void testContentTypeSamplingTool(String type) { + var result = client().callTool(McpSchema.CallToolRequest.builder() + .name("sampling-tool") + .arguments(Map.of("type", type)) + .build()); + List contents = result.content(); + assertThat(contents.size(), is(1)); + + McpSchema.Content content = contents.getFirst(); + assertThat(content, instanceOf(McpSchema.TextContent.class)); + + McpSchema.TextContent textContent = (McpSchema.TextContent) content; + assertThat(textContent.text(), is(SAMPLING_CLIENT_TEXT)); + } + + @Test + void testTextContentTypeSamplingTool() { + var result = client().callTool(McpSchema.CallToolRequest.builder() + .name("sampling-tool") + .arguments(Map.of("type", "text")) + .build()); + List contents = result.content(); + assertThat(contents.size(), is(1)); + + McpSchema.Content content = contents.getFirst(); + assertThat(content, instanceOf(McpSchema.TextContent.class)); + + McpSchema.TextContent textContent = (McpSchema.TextContent) content; + assertThat(textContent.text(), is(SAMPLING_CLIENT_TEXT)); + } + + @Test + void testEnabledTool() { + var result = client().callTool(McpSchema.CallToolRequest.builder() + .name("enabled-tool") + .arguments(Map.of("type", "text")) + .build()); + List contents = result.content(); + assertThat(contents.size(), is(1)); + + McpSchema.Content content = contents.getFirst(); + assertThat(content, instanceOf(McpSchema.TextContent.class)); + + McpSchema.TextContent textContent = (McpSchema.TextContent) content; + assertThat(textContent.text(), is(SAMPLING_CLIENT_TEXT)); + } + + @Test + void testTimeoutTool() { + var result = client().callTool(McpSchema.CallToolRequest.builder() + .name("timeout-tool") + .arguments(Map.of("type", "text")) + .build()); + assertThat(result.isError(), is(true)); + + var contents = result.content(); + assertThat(result.content().size(), is(1)); + + McpSchema.Content content = contents.getFirst(); + assertThat(content, instanceOf(McpSchema.TextContent.class)); + + McpSchema.TextContent textContent = (McpSchema.TextContent) content; + assertThat(textContent.text(), is("response timeout")); + } + + @Test + void testErrorTool() { + var result = client().callTool(McpSchema.CallToolRequest.builder() + .name("error-tool") + .arguments(Map.of("type", "text")) + .build()); + assertThat(result.isError(), is(true)); + + var contents = result.content(); + assertThat(result.content().size(), is(1)); + + McpSchema.Content content = contents.getFirst(); + assertThat(content, instanceOf(McpSchema.TextContent.class)); + + McpSchema.TextContent textContent = (McpSchema.TextContent) content; + assertThat(textContent.text(), is(SAMPLING_ERROR_MESSAGE)); + } +} diff --git a/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractMcpSdkToolErrorResultTest.java b/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractMcpSdkToolErrorResultTest.java index fddff463..b8d90494 100644 --- a/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractMcpSdkToolErrorResultTest.java +++ b/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/AbstractMcpSdkToolErrorResultTest.java @@ -20,6 +20,7 @@ import io.helidon.webserver.testing.junit5.SetUpRoute; import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -48,4 +49,25 @@ void testFailingToolResult(String name) { McpSchema.TextContent text = (McpSchema.TextContent) content; assertThat(text.text(), is("Tool error message")); } + + @Test + void testMultipleMessageError() { + var result = client().callTool(McpSchema.CallToolRequest.builder() + .name("failing-tool-2") + .build()); + assertThat(result.isError(), is(true)); + assertThat(result.content().size(), is(2)); + + var content = result.content().getFirst(); + assertThat(content, instanceOf(McpSchema.TextContent.class)); + + McpSchema.TextContent text = (McpSchema.TextContent) content; + assertThat(text.text(), is("Tool error message")); + + content = result.content().get(1); + assertThat(content, instanceOf(McpSchema.TextContent.class)); + + text = (McpSchema.TextContent) content; + assertThat(text.text(), is("Second error message")); + } } diff --git a/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/McpSdkSseSamplingTest.java b/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/McpSdkSseSamplingTest.java new file mode 100644 index 00000000..afe9fe09 --- /dev/null +++ b/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/McpSdkSseSamplingTest.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.extensions.mcp.tests; + +import io.helidon.webserver.WebServer; +import io.helidon.webserver.testing.junit5.ServerTest; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; + +@ServerTest +class McpSdkSseSamplingTest extends AbstractMcpSdkSamplingTest { + private final McpSyncClient client; + + McpSdkSseSamplingTest(WebServer server) { + client = McpClient.sync(sse(server.port())) + .sampling(this::samplingHandler) + .build(); + client.initialize(); + } + + @Override + McpSyncClient client() { + return client; + } +} diff --git a/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/McpSdkStreamableSamplingTest.java b/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/McpSdkStreamableSamplingTest.java new file mode 100644 index 00000000..38be438b --- /dev/null +++ b/tests/mcp/src/test/java/io/helidon/extensions/mcp/tests/McpSdkStreamableSamplingTest.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.extensions.mcp.tests; + +import io.helidon.webserver.WebServer; +import io.helidon.webserver.testing.junit5.ServerTest; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; + +@ServerTest +class McpSdkStreamableSamplingTest extends AbstractMcpSdkSamplingTest { + private final McpSyncClient client; + + McpSdkStreamableSamplingTest(WebServer server) { + client = McpClient.sync(streamable(server.port())) + .sampling(this::samplingHandler) + .build(); + client.initialize(); + } + + @Override + McpSyncClient client() { + return client; + } +}