diff --git a/CHANGELOG.md b/CHANGELOG.md index 29843289fc..a14368bc5c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ This section is for maintaining a changelog for all breaking changes for the cli ## [Unreleased 2.x] ### Added +- Added support for Radial Search (`min_score` & `max_distance`) on the k-NN & Neural query types ([#]()) ### Dependencies diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java index 596752f47c..10d6cd6502 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java @@ -23,7 +23,12 @@ public class KnnQuery extends QueryBase implements QueryVariant { private final String field; private final float[] vector; - private final int k; + @Nullable + private final Integer k; + @Nullable + private final Float minScore; + @Nullable + private final Float maxDistance; @Nullable private final Query filter; @@ -32,7 +37,9 @@ private KnnQuery(Builder builder) { this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field"); this.vector = ApiTypeHelper.requireNonNull(builder.vector, this, "vector"); - this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k"); + this.k = builder.k; + this.minScore = builder.minScore; + this.maxDistance = builder.maxDistance; this.filter = builder.filter; } @@ -66,13 +73,34 @@ public final float[] vector() { } /** - * Required - The number of neighbors the search of each graph will return. + * Optional - The number of neighbors the search of each graph will return. * @return The number of neighbors to return. */ - public final int k() { + @Nullable + public final Integer k() { return this.k; } + /** + * Optional - The minimum score threshold for the search results + * + * @return The minimum score threshold for the search results + */ + @Nullable + public final Float minScore() { + return this.minScore; + } + + /** + * Optional - The maximum distance threshold for the search results + * + * @return The maximum distance threshold for the search results + */ + @Nullable + public final Float maxDistance() { + return this.maxDistance; + } + /** * Optional - A query to filter the results of the query. * @return The filter query. @@ -88,8 +116,6 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { super.serializeInternal(generator, mapper); - // TODO: Implement the rest of the serialization. - generator.writeKey("vector"); generator.writeStartArray(); for (float value : this.vector) { @@ -97,7 +123,17 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { } generator.writeEnd(); - generator.write("k", this.k); + if (this.k != null) { + generator.write("k", this.k); + } + + if (this.minScore != null) { + generator.write("min_score", this.minScore); + } + + if (this.maxDistance != null) { + generator.write("max_distance", this.maxDistance); + } if (this.filter != null) { generator.writeKey("filter"); @@ -108,7 +144,7 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { } public Builder toBuilder() { - return toBuilder(new Builder()).field(field).vector(vector).k(k).filter(filter); + return toBuilder(new Builder()).field(field).vector(vector).k(k).minScore(minScore).maxDistance(maxDistance).filter(filter); } /** @@ -122,6 +158,10 @@ public static class Builder extends QueryBase.AbstractBuilder implement @Nullable private Integer k; @Nullable + private Float minScore; + @Nullable + private Float maxDistance; + @Nullable private Query filter; /** @@ -146,7 +186,7 @@ public Builder vector(@Nullable float[] vector) { } /** - * Required - The number of neighbors the search of each graph will return. + * Optional - The number of neighbors to return. * * @param k The number of neighbors to return. * @return This builder. @@ -156,6 +196,28 @@ public Builder k(@Nullable Integer k) { return this; } + /** + * Optional - The minimum score threshold for the search results + * + * @param minScore The minimum score threshold for the search results + * @return This builder. + */ + public Builder minScore(@Nullable Float minScore) { + this.minScore = minScore; + return this; + } + + /** + * Optional - The maximum distance threshold for the search results + * + * @param maxDistance The maximum distance threshold for the search results + * @return This builder. + */ + public Builder maxDistance(@Nullable Float maxDistance) { + this.maxDistance = maxDistance; + return this; + } + /** * Optional - A query to filter the results of the knn query. * @@ -201,6 +263,8 @@ protected static void setupKnnQueryDeserializer(ObjectDeserializer op) b.vector(vector); }, JsonpDeserializer.arrayDeserializer(JsonpDeserializer.floatDeserializer()), "vector"); op.add(Builder::k, JsonpDeserializer.integerDeserializer(), "k"); + op.add(Builder::minScore, JsonpDeserializer.floatDeserializer(), "min_score"); + op.add(Builder::maxDistance, JsonpDeserializer.floatDeserializer(), "max_distance"); op.add(Builder::filter, Query._DESERIALIZER, "filter"); op.setKey(Builder::field, JsonpDeserializer.stringDeserializer()); diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/NeuralQuery.java b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/NeuralQuery.java index 9984f912d0..b015f24baf 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/NeuralQuery.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/NeuralQuery.java @@ -26,7 +26,12 @@ public class NeuralQuery extends QueryBase implements QueryVariant { private final String field; private final String queryText; private final String queryImage; - private final int k; + @Nullable + private final Integer k; + @Nullable + private final Float minScore; + @Nullable + private final Float maxDistance; @Nullable private final String modelId; @Nullable @@ -41,7 +46,9 @@ private NeuralQuery(NeuralQuery.Builder builder) { } this.queryText = builder.queryText; this.queryImage = builder.queryImage; - this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k"); + this.k = builder.k; + this.minScore = builder.minScore; + this.maxDistance = builder.maxDistance; this.modelId = builder.modelId; this.filter = builder.filter; } @@ -90,17 +97,34 @@ public final String queryImage() { } /** - * Required - The number of neighbors to return. + * Optional - The number of neighbors to return. * * @return The number of neighbors to return. */ - public final int k() { + @Nullable + public final Integer k() { return this.k; } /** - * Builder for {@link NeuralQuery}. + * Optional - The minimum score threshold for the search results + * + * @return The minimum score threshold for the search results + */ + @Nullable + public final Float minScore() { + return this.minScore; + } + + /** + * Optional - The maximum distance threshold for the search results + * + * @return The maximum distance threshold for the search results */ + @Nullable + public final Float maxDistance() { + return this.maxDistance; + } /** * Optional - The model_id field if the default model for the index or field is set. @@ -141,7 +165,17 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { generator.write("model_id", this.modelId); } - generator.write("k", this.k); + if (this.k != null) { + generator.write("k", this.k); + } + + if (this.minScore != null) { + generator.write("min_score", this.minScore); + } + + if (this.maxDistance != null) { + generator.write("max_distance", this.maxDistance); + } if (this.filter != null) { generator.writeKey("filter"); @@ -152,7 +186,7 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { } public Builder toBuilder() { - return toBuilder(new Builder()).field(field).queryText(queryText).queryImage(queryImage).k(k).modelId(modelId).filter(filter); + return toBuilder(new Builder()).field(field).queryText(queryText).queryImage(queryImage).k(k).minScore(minScore).maxDistance(maxDistance).modelId(modelId).filter(filter); } /** @@ -162,8 +196,13 @@ public static class Builder extends QueryBase.AbstractBuilder