Skip to content

Commit

Permalink
Add support for radial search on k-NN and Neural query types
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Farr <[email protected]>
  • Loading branch information
Xtansia committed Oct 18, 2024
1 parent a4174df commit cf5ce82
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ This section is for maintaining a changelog for all breaking changes for the cli
## [Unreleased 2.x]

### Added
- Added support for Radial Search (`min_score` & `max_distance`) on the k-NN & Neural query types ([#]())

### Dependencies

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
public class KnnQuery extends QueryBase implements QueryVariant {
private final String field;
private final float[] vector;
private final int k;
@Nullable
private final Integer k;
@Nullable
private final Float minScore;
@Nullable
private final Float maxDistance;
@Nullable
private final Query filter;

Expand All @@ -32,7 +37,9 @@ private KnnQuery(Builder builder) {

this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field");
this.vector = ApiTypeHelper.requireNonNull(builder.vector, this, "vector");
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.k = builder.k;
this.minScore = builder.minScore;
this.maxDistance = builder.maxDistance;
this.filter = builder.filter;
}

Expand Down Expand Up @@ -66,13 +73,34 @@ public final float[] vector() {
}

/**
* Required - The number of neighbors the search of each graph will return.
* Optional - The number of neighbors the search of each graph will return.
* @return The number of neighbors to return.
*/
public final int k() {
@Nullable
public final Integer k() {
return this.k;
}

/**
* Optional - The minimum score threshold for the search results
*
* @return The minimum score threshold for the search results
*/
@Nullable
public final Float minScore() {
return this.minScore;
}

/**
* Optional - The maximum distance threshold for the search results
*
* @return The maximum distance threshold for the search results
*/
@Nullable
public final Float maxDistance() {
return this.maxDistance;
}

/**
* Optional - A query to filter the results of the query.
* @return The filter query.
Expand All @@ -88,16 +116,24 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

super.serializeInternal(generator, mapper);

// TODO: Implement the rest of the serialization.

generator.writeKey("vector");
generator.writeStartArray();
for (float value : this.vector) {
generator.write(value);
}
generator.writeEnd();

generator.write("k", this.k);
if (this.k != null) {
generator.write("k", this.k);
}

if (this.minScore != null) {
generator.write("min_score", this.minScore);
}

if (this.maxDistance != null) {
generator.write("max_distance", this.maxDistance);
}

if (this.filter != null) {
generator.writeKey("filter");
Expand All @@ -108,7 +144,7 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
}

public Builder toBuilder() {
return toBuilder(new Builder()).field(field).vector(vector).k(k).filter(filter);
return toBuilder(new Builder()).field(field).vector(vector).k(k).minScore(minScore).maxDistance(maxDistance).filter(filter);
}

/**
Expand All @@ -122,6 +158,10 @@ public static class Builder extends QueryBase.AbstractBuilder<Builder> implement
@Nullable
private Integer k;
@Nullable
private Float minScore;
@Nullable
private Float maxDistance;
@Nullable
private Query filter;

/**
Expand All @@ -146,7 +186,7 @@ public Builder vector(@Nullable float[] vector) {
}

/**
* Required - The number of neighbors the search of each graph will return.
* Optional - The number of neighbors to return.
*
* @param k The number of neighbors to return.
* @return This builder.
Expand All @@ -156,6 +196,28 @@ public Builder k(@Nullable Integer k) {
return this;
}

/**
* Optional - The minimum score threshold for the search results
*
* @param minScore The minimum score threshold for the search results
* @return This builder.
*/
public Builder minScore(@Nullable Float minScore) {
this.minScore = minScore;
return this;
}

/**
* Optional - The maximum distance threshold for the search results
*
* @param maxDistance The maximum distance threshold for the search results
* @return This builder.
*/
public Builder maxDistance(@Nullable Float maxDistance) {
this.maxDistance = maxDistance;
return this;
}

/**
* Optional - A query to filter the results of the knn query.
*
Expand Down Expand Up @@ -201,6 +263,8 @@ protected static void setupKnnQueryDeserializer(ObjectDeserializer<Builder> op)
b.vector(vector);
}, JsonpDeserializer.arrayDeserializer(JsonpDeserializer.floatDeserializer()), "vector");
op.add(Builder::k, JsonpDeserializer.integerDeserializer(), "k");
op.add(Builder::minScore, JsonpDeserializer.floatDeserializer(), "min_score");
op.add(Builder::maxDistance, JsonpDeserializer.floatDeserializer(), "max_distance");
op.add(Builder::filter, Query._DESERIALIZER, "filter");

op.setKey(Builder::field, JsonpDeserializer.stringDeserializer());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ public class NeuralQuery extends QueryBase implements QueryVariant {
private final String field;
private final String queryText;
private final String queryImage;
private final int k;
@Nullable
private final Integer k;
@Nullable
private final Float minScore;
@Nullable
private final Float maxDistance;
@Nullable
private final String modelId;
@Nullable
Expand All @@ -41,7 +46,9 @@ private NeuralQuery(NeuralQuery.Builder builder) {
}
this.queryText = builder.queryText;
this.queryImage = builder.queryImage;
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.k = builder.k;
this.minScore = builder.minScore;
this.maxDistance = builder.maxDistance;
this.modelId = builder.modelId;
this.filter = builder.filter;
}
Expand Down Expand Up @@ -90,17 +97,34 @@ public final String queryImage() {
}

/**
* Required - The number of neighbors to return.
* Optional - The number of neighbors to return.
*
* @return The number of neighbors to return.
*/
public final int k() {
@Nullable
public final Integer k() {
return this.k;
}

/**
* Builder for {@link NeuralQuery}.
* Optional - The minimum score threshold for the search results
*
* @return The minimum score threshold for the search results
*/
@Nullable
public final Float minScore() {
return this.minScore;
}

/**
* Optional - The maximum distance threshold for the search results
*
* @return The maximum distance threshold for the search results
*/
@Nullable
public final Float maxDistance() {
return this.maxDistance;
}

/**
* Optional - The model_id field if the default model for the index or field is set.
Expand Down Expand Up @@ -141,7 +165,17 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.write("model_id", this.modelId);
}

generator.write("k", this.k);
if (this.k != null) {
generator.write("k", this.k);
}

if (this.minScore != null) {
generator.write("min_score", this.minScore);
}

if (this.maxDistance != null) {
generator.write("max_distance", this.maxDistance);
}

if (this.filter != null) {
generator.writeKey("filter");
Expand All @@ -152,7 +186,7 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
}

public Builder toBuilder() {
return toBuilder(new Builder()).field(field).queryText(queryText).queryImage(queryImage).k(k).modelId(modelId).filter(filter);
return toBuilder(new Builder()).field(field).queryText(queryText).queryImage(queryImage).k(k).minScore(minScore).maxDistance(maxDistance).modelId(modelId).filter(filter);
}

/**
Expand All @@ -162,8 +196,13 @@ public static class Builder extends QueryBase.AbstractBuilder<NeuralQuery.Builde
private String field;
private String queryText;
private String queryImage;
@Nullable
private Integer k;
@Nullable
private Float minScore;
@Nullable
private Float maxDistance;
@Nullable
private String modelId;
@Nullable
private Query filter;
Expand Down Expand Up @@ -216,7 +255,7 @@ public NeuralQuery.Builder modelId(@Nullable String modelId) {
}

/**
* Required - The number of neighbors to return.
* Optional - The number of neighbors to return.
*
* @param k The number of neighbors to return.
* @return This builder.
Expand All @@ -226,6 +265,28 @@ public NeuralQuery.Builder k(@Nullable Integer k) {
return this;
}

/**
* Optional - The minimum score threshold for the search results
*
* @param minScore The minimum score threshold for the search results
* @return This builder.
*/
public NeuralQuery.Builder minScore(@Nullable Float minScore) {
this.minScore = minScore;
return this;
}

/**
* Optional - The maximum distance threshold for the search results
*
* @param maxDistance The maximum distance threshold for the search results
* @return This builder.
*/
public NeuralQuery.Builder maxDistance(@Nullable Float maxDistance) {
this.maxDistance = maxDistance;
return this;
}

/**
* Optional - A query to filter the results of the knn query.
*
Expand Down Expand Up @@ -267,6 +328,8 @@ protected static void setupNeuralQueryDeserializer(ObjectDeserializer<NeuralQuer
op.add(NeuralQuery.Builder::queryImage, JsonpDeserializer.stringDeserializer(), "query_image");
op.add(NeuralQuery.Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(NeuralQuery.Builder::k, JsonpDeserializer.integerDeserializer(), "k");
op.add(NeuralQuery.Builder::minScore, JsonpDeserializer.floatDeserializer(), "min_score");
op.add(NeuralQuery.Builder::maxDistance, JsonpDeserializer.floatDeserializer(), "max_distance");
op.add(NeuralQuery.Builder::filter, Query._DESERIALIZER, "filter");

op.setKey(NeuralQuery.Builder::field, JsonpDeserializer.stringDeserializer());
Expand Down

0 comments on commit cf5ce82

Please sign in to comment.