diff --git a/Directory.Packages.props b/Directory.Packages.props
index 026d961..c707df4 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -3,7 +3,7 @@
-
+
diff --git a/Milvus.Client.Tests/SearchQueryIteratorLongKeyTests.cs b/Milvus.Client.Tests/SearchQueryIteratorLongKeyTests.cs
new file mode 100644
index 0000000..3443994
--- /dev/null
+++ b/Milvus.Client.Tests/SearchQueryIteratorLongKeyTests.cs
@@ -0,0 +1,287 @@
+using Xunit;
+
+namespace Milvus.Client.Tests;
+
+/// <summary>
+/// Integration tests for <c>QueryWithIteratorAsync</c> on a collection whose primary key is a <c>long</c>.
+/// </summary>
+[Collection("Milvus")]
+public class SearchQueryIteratorLongKeyTests : IClassFixture<SearchQueryIteratorLongKeyTests.DataCollectionFixture>,
+    IAsyncLifetime
+{
+    private const string CollectionName = nameof(SearchQueryIteratorLongKeyTests);
+
+    private readonly DataCollectionFixture _dataCollectionFixture;
+    private readonly MilvusClient _client;
+
+    public SearchQueryIteratorLongKeyTests(MilvusFixture milvusFixture, DataCollectionFixture dataCollectionFixture)
+    {
+        _client = milvusFixture.CreateClient();
+        _dataCollectionFixture = dataCollectionFixture;
+    }
+
+    private MilvusCollection Collection => _dataCollectionFixture.Collection;
+
+    public Task InitializeAsync() => Task.CompletedTask;
+
+    public Task DisposeAsync()
+    {
+        _client.Dispose();
+        return Task.CompletedTask;
+    }
+
+    // NOTE(review): every test inserts the same three primary keys into the collection shared through the
+    // class fixture; the count assertions below assume that does not surface duplicate rows - confirm.
+    [Fact]
+    public async Task QueryWithIterator_NoOutputFields()
+    {
+        await InsertItemsAsync(CreateItems());
+
+        List<IReadOnlyList<FieldData>> results = await ReadAllAsync(Collection.QueryWithIteratorAsync());
+
+        // ExtractItems sizes its output from the returned columns, so "empty" means no rows came back.
+        var returnedItems = results.SelectMany(ExtractItems).ToList();
+        Assert.Empty(returnedItems);
+    }
+
+    [Fact]
+    public async Task QueryWithIterator_OffsetNotZero()
+    {
+        var queryParameters = new QueryParameters
+        {
+            Offset = 1
+        };
+
+        var iterator = Collection.QueryWithIteratorAsync(parameters: queryParameters);
+
+        // The assertion must be awaited: an un-awaited ThrowsAsync is only a pending task and verifies nothing.
+        await Assert.ThrowsAsync<MilvusException>(() => MoveNextOnceAsync(iterator));
+    }
+
+    // Despite the name, this covers the rejected case Limit == 0.
+    [Fact]
+    public async Task QueryWithIterator_LimitNotZero()
+    {
+        var queryParameters = new QueryParameters
+        {
+            Limit = 0
+        };
+
+        var iterator = Collection.QueryWithIteratorAsync(parameters: queryParameters);
+
+        // NOTE(review): the concrete exception type for a zero limit is not visible here, so any exception is
+        // accepted; tighten to the specific type once confirmed.
+        await Assert.ThrowsAnyAsync<Exception>(() => MoveNextOnceAsync(iterator));
+    }
+
+    [Theory]
+    [InlineData("id in [1, 2, 3]", 1, null)]
+    [InlineData("id in [1, 2, 3]", 1, 2)]
+    [InlineData("id in [1, 2, 3]", 2, null)]
+    [InlineData("id in [1, 2, 3]", 2, 2)]
+    [InlineData("id in [1, 2, 3]", 1000, null)]
+    [InlineData("id in [1, 2, 3]", 1000, 2)]
+    [InlineData(null, 1, null)]
+    [InlineData(null, 1, 2)]
+    [InlineData(null, 2, null)]
+    [InlineData(null, 2, 2)]
+    [InlineData(null, 1000, null)]
+    [InlineData(null, 1000, 2)]
+    public async Task QueryWithIterator(string? expression, int batchSize, int? limit)
+    {
+        List<Item> items = CreateItems();
+        await InsertItemsAsync(items);
+
+        var queryParameters = new QueryParameters
+        {
+            OutputFields = { "id", "float_vector" },
+            Limit = limit
+        };
+
+        var iterator = Collection.QueryWithIteratorAsync(
+            expression: expression,
+            batchSize: batchSize,
+            parameters: queryParameters);
+
+        List<IReadOnlyList<FieldData>> results = await ReadAllAsync(iterator);
+
+        Item[] returnedItems = results.SelectMany(ExtractItems).ToArray();
+        Item[] expectedItems = items.Take(limit ?? int.MaxValue).ToArray();
+        Assert.Equal(expectedItems.Length, returnedItems.Length);
+
+        foreach (Item expectedItem in expectedItems)
+        {
+            Assert.Contains(expectedItem, returnedItems);
+        }
+    }
+
+    /// <summary>Builds the three rows that the tests insert.</summary>
+    private static List<Item> CreateItems() => new List<Item>
+    {
+        new(1, new[] { 10f, 20f }),
+        new(2, new[] { 30f, 40f }),
+        new(3, new[] { 50f, 60f })
+    };
+
+    /// <summary>Inserts <paramref name="items"/> into the fixture collection as "id" and "float_vector" columns.</summary>
+    private async Task InsertItemsAsync(IReadOnlyList<Item> items)
+    {
+        await Collection.InsertAsync(
+        [
+            FieldData.Create("id", items.Select(x => x.Id).ToArray()),
+            FieldData.CreateFloatVector("float_vector", items.Select(x => x.Vector).ToArray())
+        ]);
+    }
+
+    /// <summary>Drains <paramref name="iterator"/> and returns every batch it produced, in order.</summary>
+    private static async Task<List<IReadOnlyList<FieldData>>> ReadAllAsync(
+        IAsyncEnumerable<IReadOnlyList<FieldData>> iterator)
+    {
+        List<IReadOnlyList<FieldData>> results = new();
+        await foreach (IReadOnlyList<FieldData> result in iterator)
+        {
+            results.Add(result);
+        }
+
+        return results;
+    }
+
+    /// <summary>Advances <paramref name="iterator"/> by one step and disposes the enumerator afterwards.</summary>
+    private static async Task MoveNextOnceAsync(IAsyncEnumerable<IReadOnlyList<FieldData>> iterator)
+    {
+        await using IAsyncEnumerator<IReadOnlyList<FieldData>> enumerator = iterator.GetAsyncEnumerator();
+        await enumerator.MoveNextAsync();
+    }
+
+    /// <summary>
+    /// Turns one batch of column-oriented field data into row-oriented <see cref="Item"/> instances.
+    /// Columns other than "id" and "float_vector" are ignored.
+    /// </summary>
+    private static IEnumerable<Item> ExtractItems(IReadOnlyList<FieldData> result)
+    {
+        // The first column's row count is used for every column; a batch without columns yields no rows.
+        long rowCount = result.Select(x => x.RowCount).FirstOrDefault();
+
+        var items = new Item[rowCount];
+        for (int i = 0; i < rowCount; i++)
+        {
+            items[i] = new Item();
+        }
+
+        foreach (FieldData fieldData in result)
+        {
+            switch (fieldData.FieldName)
+            {
+                case "id":
+                {
+                    var ids = ((FieldData<long>) fieldData).Data;
+                    for (int j = 0; j < rowCount; j++)
+                    {
+                        items[j].Id = ids[j];
+                    }
+
+                    break;
+                }
+
+                case "float_vector":
+                {
+                    var vectors = ((FloatVectorFieldData) fieldData).Data;
+                    for (int j = 0; j < rowCount; j++)
+                    {
+                        items[j].Vector = vectors[j];
+                    }
+
+                    break;
+                }
+            }
+        }
+
+        return items;
+    }
+
+    #region Nested type: DataCollectionFixture
+
+    /// <summary>
+    /// Class fixture that drops, recreates, indexes and loads the collection used by these tests.
+    /// </summary>
+    public class DataCollectionFixture : IAsyncLifetime
+    {
+        private readonly MilvusClient _client;
+
+        public DataCollectionFixture(MilvusFixture milvusFixture)
+        {
+            _client = milvusFixture.CreateClient();
+            Collection = _client.GetCollection(CollectionName);
+        }
+
+        public MilvusCollection Collection { get; }
+
+        public async Task InitializeAsync()
+        {
+            // Drop any existing collection of the same name before recreating it.
+            await Collection.DropAsync();
+
+            await _client.CreateCollectionAsync(
+                Collection.Name,
+                new[]
+                {
+                    FieldSchema.Create<long>("id", isPrimaryKey: true),
+                    FieldSchema.CreateFloatVector("float_vector", 2)
+                });
+
+            await Collection.CreateIndexAsync("float_vector", IndexType.Flat, SimilarityMetricType.L2);
+            await Collection.WaitForIndexBuildAsync("float_vector");
+            await Collection.LoadAsync();
+            await Collection.WaitForCollectionLoadAsync();
+        }
+
+        public Task DisposeAsync()
+        {
+            _client.Dispose();
+            return Task.CompletedTask;
+        }
+    }
+
+    #endregion
+
+    #region Nested type: Item
+
+    /// <summary>Row model used by the tests; equality compares the id and the vector contents.</summary>
+    public record Item
+    {
+        public Item()
+        {
+        }
+
+        public Item(long id, ReadOnlyMemory<float> vector)
+        {
+            Id = id;
+            Vector = vector;
+        }
+
+        public long Id { get; set; }
+
+        public ReadOnlyMemory<float> Vector { get; set; }
+
+        public virtual bool Equals(Item? other)
+        {
+            return other is not null && Id == other.Id && Vector.Span.SequenceEqual(other.Vector.Span);
+        }
+
+        public override int GetHashCode()
+        {
+            var hashCode = new HashCode();
+            hashCode.Add(Id);
+
+            // Hash the elements in place rather than copying the vector into a new array first.
+            foreach (float value in Vector.Span)
+            {
+                hashCode.Add(value);
+            }
+
+            return hashCode.ToHashCode();
+        }
+    }
+
+    #endregion
+}
diff --git a/Milvus.Client.Tests/SearchQueryIteratorStringKeyTests.cs b/Milvus.Client.Tests/SearchQueryIteratorStringKeyTests.cs
new file mode 100644
index 0000000..d9cf379
--- /dev/null
+++ b/Milvus.Client.Tests/SearchQueryIteratorStringKeyTests.cs
@@ -0,0 +1,250 @@
+using Xunit;
+
+namespace Milvus.Client.Tests;
+
+/// <summary>
+/// Integration tests for <c>QueryWithIteratorAsync</c> on a collection whose primary key is a varchar string.
+/// </summary>
+[Collection("Milvus")]
+public class SearchQueryIteratorStringKeyTests : IClassFixture<SearchQueryIteratorStringKeyTests.DataCollectionFixture>,
+    IAsyncLifetime
+{
+    private const string CollectionName = nameof(SearchQueryIteratorStringKeyTests);
+
+    private readonly DataCollectionFixture _dataCollectionFixture;
+    private readonly MilvusClient _client;
+
+    public SearchQueryIteratorStringKeyTests(MilvusFixture milvusFixture, DataCollectionFixture dataCollectionFixture)
+    {
+        _client = milvusFixture.CreateClient();
+        _dataCollectionFixture = dataCollectionFixture;
+    }
+
+    private MilvusCollection Collection => _dataCollectionFixture.Collection;
+
+    public Task InitializeAsync() => Task.CompletedTask;
+
+    public Task DisposeAsync()
+    {
+        _client.Dispose();
+        return Task.CompletedTask;
+    }
+
+    // NOTE(review): every test inserts the same three primary keys into the collection shared through the
+    // class fixture; the count assertions below assume that does not surface duplicate rows - confirm.
+    [Fact]
+    public async Task QueryWithIterator_NoOutputFields()
+    {
+        await InsertItemsAsync(CreateItems());
+
+        List<IReadOnlyList<FieldData>> results = await ReadAllAsync(Collection.QueryWithIteratorAsync());
+
+        // ExtractItems sizes its output from the returned columns, so "empty" means no rows came back.
+        var returnedItems = results.SelectMany(ExtractItems).ToList();
+        Assert.Empty(returnedItems);
+    }
+
+    [Theory]
+    [InlineData("id in ['1', '2', '3']", 1, null)]
+    [InlineData("id in ['1', '2', '3']", 1, 2)]
+    [InlineData("id in ['1', '2', '3']", 2, null)]
+    [InlineData("id in ['1', '2', '3']", 2, 2)]
+    [InlineData("id in ['1', '2', '3']", 1000, null)]
+    [InlineData("id in ['1', '2', '3']", 1000, 2)]
+    [InlineData(null, 1, null)]
+    [InlineData(null, 1, 2)]
+    [InlineData(null, 2, null)]
+    [InlineData(null, 2, 2)]
+    [InlineData(null, 1000, null)]
+    [InlineData(null, 1000, 2)]
+    public async Task QueryWithIterator(string? expression, int batchSize, int? limit)
+    {
+        List<Item> items = CreateItems();
+        await InsertItemsAsync(items);
+
+        var queryParameters = new QueryParameters
+        {
+            OutputFields = { "id", "float_vector" },
+            Limit = limit
+        };
+
+        var iterator = Collection.QueryWithIteratorAsync(
+            expression: expression,
+            batchSize: batchSize,
+            parameters: queryParameters);
+
+        List<IReadOnlyList<FieldData>> results = await ReadAllAsync(iterator);
+
+        Item[] returnedItems = results.SelectMany(ExtractItems).ToArray();
+        Item[] expectedItems = items.Take(limit ?? int.MaxValue).ToArray();
+        Assert.Equal(expectedItems.Length, returnedItems.Length);
+
+        foreach (Item expectedItem in expectedItems)
+        {
+            Assert.Contains(expectedItem, returnedItems);
+        }
+    }
+
+    /// <summary>Builds the three rows that the tests insert.</summary>
+    private static List<Item> CreateItems() => new List<Item>
+    {
+        new("1", new[] { 10f, 20f }),
+        new("2", new[] { 30f, 40f }),
+        new("3", new[] { 50f, 60f })
+    };
+
+    /// <summary>Inserts <paramref name="items"/> into the fixture collection as "id" and "float_vector" columns.</summary>
+    private async Task InsertItemsAsync(IReadOnlyList<Item> items)
+    {
+        await Collection.InsertAsync(
+        [
+            FieldData.Create("id", items.Select(x => x.Id).ToArray()),
+            FieldData.CreateFloatVector("float_vector", items.Select(x => x.Vector).ToArray())
+        ]);
+    }
+
+    /// <summary>Drains <paramref name="iterator"/> and returns every batch it produced, in order.</summary>
+    private static async Task<List<IReadOnlyList<FieldData>>> ReadAllAsync(
+        IAsyncEnumerable<IReadOnlyList<FieldData>> iterator)
+    {
+        List<IReadOnlyList<FieldData>> results = new();
+        await foreach (IReadOnlyList<FieldData> result in iterator)
+        {
+            results.Add(result);
+        }
+
+        return results;
+    }
+
+    /// <summary>
+    /// Turns one batch of column-oriented field data into row-oriented <see cref="Item"/> instances.
+    /// Columns other than "id" and "float_vector" are ignored.
+    /// </summary>
+    private static IEnumerable<Item> ExtractItems(IReadOnlyList<FieldData> result)
+    {
+        // The first column's row count is used for every column; a batch without columns yields no rows.
+        long rowCount = result.Select(x => x.RowCount).FirstOrDefault();
+
+        var items = new Item[rowCount];
+        for (int i = 0; i < rowCount; i++)
+        {
+            items[i] = new Item();
+        }
+
+        foreach (FieldData fieldData in result)
+        {
+            switch (fieldData.FieldName)
+            {
+                case "id":
+                {
+                    var ids = ((FieldData<string>) fieldData).Data;
+                    for (int j = 0; j < rowCount; j++)
+                    {
+                        items[j].Id = ids[j];
+                    }
+
+                    break;
+                }
+
+                case "float_vector":
+                {
+                    var vectors = ((FloatVectorFieldData) fieldData).Data;
+                    for (int j = 0; j < rowCount; j++)
+                    {
+                        items[j].Vector = vectors[j];
+                    }
+
+                    break;
+                }
+            }
+        }
+
+        return items;
+    }
+
+    #region Nested type: DataCollectionFixture
+
+    /// <summary>
+    /// Class fixture that drops, recreates, indexes and loads the collection used by these tests.
+    /// </summary>
+    public class DataCollectionFixture : IAsyncLifetime
+    {
+        private readonly MilvusClient _client;
+
+        public DataCollectionFixture(MilvusFixture milvusFixture)
+        {
+            _client = milvusFixture.CreateClient();
+            Collection = _client.GetCollection(CollectionName);
+        }
+
+        public MilvusCollection Collection { get; }
+
+        public async Task InitializeAsync()
+        {
+            // Drop any existing collection of the same name before recreating it.
+            await Collection.DropAsync();
+
+            await _client.CreateCollectionAsync(
+                Collection.Name,
+                new[]
+                {
+                    FieldSchema.CreateVarchar("id", 16, isPrimaryKey: true),
+                    FieldSchema.CreateFloatVector("float_vector", 2)
+                });
+
+            await Collection.CreateIndexAsync("float_vector", IndexType.Flat, SimilarityMetricType.L2);
+            await Collection.WaitForIndexBuildAsync("float_vector");
+            await Collection.LoadAsync();
+            await Collection.WaitForCollectionLoadAsync();
+        }
+
+        public Task DisposeAsync()
+        {
+            _client.Dispose();
+            return Task.CompletedTask;
+        }
+    }
+
+    #endregion
+
+    #region Nested type: Item
+
+    /// <summary>Row model used by the tests; equality compares the id and the vector contents.</summary>
+    public record Item
+    {
+        public Item()
+        {
+        }
+
+        public Item(string id, ReadOnlyMemory<float> vector)
+        {
+            Id = id;
+            Vector = vector;
+        }
+
+        public string? Id { get; set; }
+
+        public ReadOnlyMemory<float> Vector { get; set; }
+
+        public virtual bool Equals(Item? other)
+        {
+            return other is not null && Id == other.Id && Vector.Span.SequenceEqual(other.Vector.Span);
+        }
+
+        public override int GetHashCode()
+        {
+            var hashCode = new HashCode();
+            hashCode.Add(Id);
+
+            // Hash the elements in place rather than copying the vector into a new array first.
+            foreach (float value in Vector.Span)
+            {
+                hashCode.Add(value);
+            }
+
+            return hashCode.ToHashCode();
+        }
+    }
+
+    #endregion
+}
diff --git a/Milvus.Client/MilvusCollection.Entity.cs b/Milvus.Client/MilvusCollection.Entity.cs
index 24b9e44..dd48f2f 100644
--- a/Milvus.Client/MilvusCollection.Entity.cs
+++ b/Milvus.Client/MilvusCollection.Entity.cs
@@ -5,6 +5,7 @@
using System.Runtime.InteropServices;
using System.Text.Json;
using Google.Protobuf.Collections;
+using KeyValuePair = Milvus.Client.Grpc.KeyValuePair;
namespace Milvus.Client;
@@ -143,7 +144,7 @@ public async Task DeleteAsync(
{
Verify.NotNullOrWhiteSpace(expression);
- var request = new DeleteRequest
+ DeleteRequest request = new DeleteRequest
{
CollectionName = Name,
Expr = expression,
@@ -390,7 +391,7 @@ public Task FlushAsync(CancellationToken cancellationToken = defaul
public async Task> GetPersistentSegmentInfosAsync(
CancellationToken cancellationToken = default)
{
- var request = new GetPersistentSegmentInfoRequest { CollectionName = Name };
+ GetPersistentSegmentInfoRequest request = new GetPersistentSegmentInfoRequest { CollectionName = Name };
GetPersistentSegmentInfoResponse response = await _client.InvokeAsync(
_client.GrpcClient.GetPersistentSegmentInfoAsync,
@@ -429,7 +430,7 @@ public async Task> QueryAsync(
PopulateQueryRequestFromParameters(request, parameters);
- var response = await _client.InvokeAsync(
+ QueryResults? response = await _client.InvokeAsync(
_client.GrpcClient.QueryAsync,
request,
static r => r.Status,
@@ -460,22 +461,22 @@ public async IAsyncEnumerable> QueryWithIteratorAsync(
throw new MilvusException("Not support offset when searching iteration");
}
- var describeResponse = await _client.InvokeAsync(
+ DescribeCollectionResponse? describeResponse = await _client.InvokeAsync(
_client.GrpcClient.DescribeCollectionAsync,
new DescribeCollectionRequest { CollectionName = Name },
r => r.Status,
cancellationToken)
.ConfigureAwait(false);
- var pkField = describeResponse.Schema.Fields.FirstOrDefault(x => x.IsPrimaryKey);
+ Grpc.FieldSchema? pkField = describeResponse.Schema.Fields.FirstOrDefault(x => x.IsPrimaryKey);
if (pkField == null)
{
throw new MilvusException("Schema must contain pk field");
}
- var isUserRequestPkField = parameters?.OutputFieldsInternal?.Contains(pkField.Name) ?? false;
- var userExpression = expression;
- var userLimit = parameters?.Limit ?? int.MaxValue;
+ bool isUserRequestPkField = parameters?.OutputFieldsInternal?.Contains(pkField.Name) ?? false;
+ string? userExpression = expression;
+ int userLimit = parameters?.Limit ?? int.MaxValue;
QueryRequest request = new()
{
@@ -486,8 +487,10 @@ public async IAsyncEnumerable> QueryWithIteratorAsync(
{userExpression: not null} => userExpression,
// If user expression is null and pk field is string
{pkField.DataType: DataType.VarChar} => $"{pkField.Name} != ''",
- // If user expression is null and pk field is not string
- _ => $"{pkField.Name} < {long.MaxValue}",
+ // If user expression is null and pk field is int
+ {pkField.DataType: DataType.Int8 or DataType.Int16 or DataType.Int32 or DataType.Int64} => $"{pkField.Name} < {long.MaxValue}",
+ // If user expression is null and pk field is not string and not int
+ _ => throw new MilvusException("Unsupported data type for primary key field")
}
};
@@ -497,17 +500,18 @@ public async IAsyncEnumerable> QueryWithIteratorAsync(
if (!isUserRequestPkField) request.OutputFields.Add(pkField.Name);
// Replace parameters required for iterator
+ string iterationBatchSize = Math.Min(batchSize, userLimit).ToString(CultureInfo.InvariantCulture);
ReplaceKeyValueItems(request.QueryParams,
new Grpc.KeyValuePair {Key = Constants.Iterator, Value = "True"},
new Grpc.KeyValuePair {Key = Constants.ReduceStopForBest, Value = "True"},
- new Grpc.KeyValuePair {Key = Constants.BatchSize, Value = batchSize.ToString(CultureInfo.InvariantCulture)},
- new Grpc.KeyValuePair {Key = Constants.Offset, Value = 0.ToString(CultureInfo.InvariantCulture)},
- new Grpc.KeyValuePair {Key = Constants.Limit, Value = Math.Min(batchSize, userLimit).ToString(CultureInfo.InvariantCulture)});
+ new Grpc.KeyValuePair {Key = Constants.BatchSize, Value = iterationBatchSize},
+ new Grpc.KeyValuePair {Key = Constants.Offset, Value = "0"},
+ new Grpc.KeyValuePair {Key = Constants.Limit, Value = iterationBatchSize});
- var processedItemsCount = 0;
- while (!cancellationToken.IsCancellationRequested)
+ int processedItemsCount = 0;
+ while (true)
{
- var response = await _client.InvokeAsync(
+ QueryResults? response = await _client.InvokeAsync(
_client.GrpcClient.QueryAsync,
request,
static r => r.Status,
@@ -516,16 +520,25 @@ public async IAsyncEnumerable> QueryWithIteratorAsync(
object? pkLastValue;
int processedDuringIterationCount;
- var pkFieldsData = response.FieldsData.Single(x => x.FieldId == pkField.FieldID);
- if (pkField.DataType == DataType.VarChar)
+ Grpc.FieldData? pkFieldsData = response.FieldsData.Single(x => x.FieldId == pkField.FieldID);
+ switch (pkField.DataType)
{
- pkLastValue = pkFieldsData.Scalars.StringData.Data.LastOrDefault();
- processedDuringIterationCount = pkFieldsData.Scalars.StringData.Data.Count;
- }
- else
- {
- pkLastValue = pkFieldsData.Scalars.IntData.Data.LastOrDefault();
- processedDuringIterationCount = pkFieldsData.Scalars.IntData.Data.Count;
+ case DataType.VarChar:
+ pkLastValue = pkFieldsData.Scalars.StringData.Data.LastOrDefault();
+ processedDuringIterationCount = pkFieldsData.Scalars.StringData.Data.Count;
+ break;
+ case DataType.Int8:
+ case DataType.Int16:
+ case DataType.Int32:
+ pkLastValue = pkFieldsData.Scalars.IntData.Data.LastOrDefault();
+ processedDuringIterationCount = pkFieldsData.Scalars.IntData.Data.Count;
+ break;
+ case DataType.Int64:
+ pkLastValue = pkFieldsData.Scalars.LongData.Data.LastOrDefault();
+ processedDuringIterationCount = pkFieldsData.Scalars.LongData.Data.Count;
+ break;
+ default:
+ throw new MilvusException("Unsupported data type for primary key field");
}
// If there are no more items to process, we should break the loop
@@ -540,7 +553,7 @@ public async IAsyncEnumerable> QueryWithIteratorAsync(
yield return ProcessReturnedFieldData(response.FieldsData);
processedItemsCount += processedDuringIterationCount;
- var leftItemsCount = userLimit - processedItemsCount;
+ int leftItemsCount = userLimit - processedItemsCount;
// If user limit is reached, we should break the loop
if(leftItemsCount <= 0) yield break;
@@ -554,13 +567,16 @@ public async IAsyncEnumerable> QueryWithIteratorAsync(
Value = Math.Min(batchSize, leftItemsCount).ToString(CultureInfo.InvariantCulture)
});
- var nextExpression = pkField.DataType == DataType.VarChar
- ? $"{pkField.Name} > '{pkLastValue}'"
- : $"{pkField.Name} > {pkLastValue}";
+ string nextExpression = pkField.DataType switch
+ {
+ DataType.VarChar => $"{pkField.Name} > '{pkLastValue}'",
+ DataType.Int8 or DataType.Int16 or DataType.Int32 or DataType.Int64 => $"{pkField.Name} > {pkLastValue}",
+ _ => throw new MilvusException("Unsupported data type for primary key field")
+ };
if (!string.IsNullOrWhiteSpace(userExpression))
{
- nextExpression += $" and {userExpression}";
+ nextExpression += $" and ({userExpression})";
}
request.Expr = nextExpression;
@@ -577,7 +593,7 @@ public async IAsyncEnumerable> QueryWithIteratorAsync(
public async Task> GetQuerySegmentInfoAsync(
CancellationToken cancellationToken = default)
{
- var request = new GetQuerySegmentInfoRequest { CollectionName = Name };
+ GetQuerySegmentInfoRequest request = new GetQuerySegmentInfoRequest { CollectionName = Name };
GetQuerySegmentInfoResponse response =
await _client.InvokeAsync(_client.GrpcClient.GetQuerySegmentInfoAsync, request, static r => r.Status,
@@ -780,14 +796,14 @@ ulong CalculateGuaranteeTimestamp(
private static void ReplaceKeyValueItems(RepeatedField collection, params Grpc.KeyValuePair[] pairs)
{
- var obsoleteParameterKeys = pairs.Select(x => x.Key).Distinct().ToArray();
- var obsoleteParameters = collection.Where(x => obsoleteParameterKeys.Contains(x.Key)).ToArray();
- foreach (var field in obsoleteParameters)
+ string[] obsoleteParameterKeys = pairs.Select(x => x.Key).Distinct().ToArray();
+ KeyValuePair[] obsoleteParameters = collection.Where(x => obsoleteParameterKeys.Contains(x.Key)).ToArray();
+ foreach (KeyValuePair field in obsoleteParameters)
{
collection.Remove(field);
}
- foreach (var pair in pairs)
+ foreach (KeyValuePair pair in pairs)
{
collection.Add(pair);
}