From 92877febde03406fe99495c45b5cb582b9cd553d Mon Sep 17 00:00:00 2001 From: Volodymyr Shkolka Date: Sat, 13 Jul 2024 14:07:42 +0300 Subject: [PATCH] Added Query iterator method Continuation (#89) --- Directory.Packages.props | 2 +- .../SearchQueryIteratorLongKeyTests.cs | 256 ++++++++++++++++++ .../SearchQueryIteratorStringKeyTests.cs | 229 ++++++++++++++++ Milvus.Client/MilvusCollection.Entity.cs | 86 +++--- 4 files changed, 537 insertions(+), 36 deletions(-) create mode 100644 Milvus.Client.Tests/SearchQueryIteratorLongKeyTests.cs create mode 100644 Milvus.Client.Tests/SearchQueryIteratorStringKeyTests.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index 026d961..c707df4 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -3,7 +3,7 @@ - + diff --git a/Milvus.Client.Tests/SearchQueryIteratorLongKeyTests.cs b/Milvus.Client.Tests/SearchQueryIteratorLongKeyTests.cs new file mode 100644 index 0000000..3443994 --- /dev/null +++ b/Milvus.Client.Tests/SearchQueryIteratorLongKeyTests.cs @@ -0,0 +1,256 @@ +using Xunit; + +namespace Milvus.Client.Tests; + +[Collection("Milvus")] +public class SearchQueryIteratorLongKeyTests : IClassFixture, + IAsyncLifetime +{ + private const string CollectionName = nameof(SearchQueryIteratorLongKeyTests); + private readonly DataCollectionFixture _dataCollectionFixture; + private readonly MilvusClient Client; + + public SearchQueryIteratorLongKeyTests(MilvusFixture milvusFixture, DataCollectionFixture dataCollectionFixture) + { + Client = milvusFixture.CreateClient(); + _dataCollectionFixture = dataCollectionFixture; + } + + public Task InitializeAsync() => Task.CompletedTask; + + public Task DisposeAsync() + { + Client.Dispose(); + return Task.CompletedTask; + } + + private MilvusCollection Collection => _dataCollectionFixture.Collection; + + [Fact] + public async Task QueryWithIterator_NoOutputFields() + { + var items = new List + { + new(1, new[] { 10f, 20f }), + new(2, new[] { 30f, 40f }), + new(3, new[] { 50f, 60f }) + }; + + await Collection.InsertAsync( + [ + FieldData.Create("id", items.Select(x => x.Id).ToArray()), + FieldData.CreateFloatVector("float_vector", items.Select(x => x.Vector).ToArray()) + ]); + + var iterator = Collection.QueryWithIteratorAsync(); + + List> results = new(); + await foreach (var result in iterator) + { + results.Add(result); + } + + var returnedItems = results.SelectMany(ExtractItems).ToList(); + Assert.Empty(returnedItems); + } + + [Fact] + public void QueryWithIterator_OffsetNotZero() + { + var queryParameters = new QueryParameters + { + Offset = 1 + }; + + var iterator = Collection.QueryWithIteratorAsync(parameters: queryParameters); + + var exception = Assert.ThrowsAsync(async () => await iterator.GetAsyncEnumerator().MoveNextAsync()); + Assert.NotNull(exception); + } + + [Fact] + public void QueryWithIterator_LimitNotZero() + { + var queryParameters = new QueryParameters + { + Limit = 0 + }; + + var iterator = Collection.QueryWithIteratorAsync(parameters: queryParameters); + + var exception = Assert.ThrowsAsync(async () => await iterator.GetAsyncEnumerator().MoveNextAsync()); + Assert.NotNull(exception); + } + + [Theory] + [InlineData("id in [1, 2, 3]", 1, null)] + [InlineData("id in [1, 2, 3]", 1, 2)] + [InlineData("id in [1, 2, 3]", 2, null)] + [InlineData("id in [1, 2, 3]", 2, 2)] + [InlineData("id in [1, 2, 3]", 1000, null)] + [InlineData("id in [1, 2, 3]", 1000, 2)] + [InlineData(null, 1, null)] + [InlineData(null, 1, 2)] + [InlineData(null, 2, null)] + [InlineData(null, 2, 2)] + [InlineData(null, 1000, null)] + [InlineData(null, 1000, 2)] + public async Task QueryWithIterator(string? expression, int batchSize, int? limit) + { + var items = new List + { + new(1, new[] { 10f, 20f }), + new(2, new[] { 30f, 40f }), + new(3, new[] { 50f, 60f }) + }; + + await Collection.InsertAsync( + [ + FieldData.Create("id", items.Select(x => x.Id).ToArray()), + FieldData.CreateFloatVector("float_vector", items.Select(x => x.Vector).ToArray()) + ]); + + var queryParameters = new QueryParameters + { + OutputFields = { "id", "float_vector" }, + Limit = limit + }; + + var iterator = Collection.QueryWithIteratorAsync( + expression: expression, + batchSize: batchSize, + parameters: queryParameters); + + List> results = new(); + await foreach (var result in iterator) + { + results.Add(result); + } + + var returnedItems = results.SelectMany(ExtractItems).ToArray(); + var expectedItems = items.Take(limit ?? int.MaxValue).ToArray(); + Assert.Equal(expectedItems.Length, returnedItems.Length); + + foreach (var expectedItem in expectedItems) + { + Assert.Contains(expectedItem, returnedItems); + } + } + + IEnumerable ExtractItems(IReadOnlyList result) + { + long rowCount = result.Select(x => x.RowCount).FirstOrDefault(); + + var items = new Item[rowCount]; + for (int i = 0; i < rowCount; i++) + { + items[i] = new Item(); + } + + foreach (var fieldData in result) + { + switch (fieldData.FieldName) + { + case "id": + { + for (int j = 0; j < rowCount; j++) + { + items[j].Id = ((FieldData) fieldData).Data[j]; + } + + break; + } + + case "float_vector": + { + for (int j = 0; j < rowCount; j++) + { + items[j].Vector = ((FloatVectorFieldData) fieldData).Data[j]; + } + + break; + } + } + } + + return items; + } + + #region Nested type: DataCollectionFixture + + public class DataCollectionFixture : IAsyncLifetime + { + public MilvusCollection Collection; + private readonly MilvusClient Client; + + public DataCollectionFixture(MilvusFixture milvusFixture) + { + Client = milvusFixture.CreateClient(); + Collection = Client.GetCollection(CollectionName); + } + + public async Task InitializeAsync() + { + await Collection.DropAsync(); + + await Client.CreateCollectionAsync( + Collection.Name, + new[] + { + FieldSchema.Create("id", isPrimaryKey: true), + FieldSchema.CreateFloatVector("float_vector", 2) + }); + + await Collection.CreateIndexAsync("float_vector", IndexType.Flat, SimilarityMetricType.L2); + await Collection.WaitForIndexBuildAsync("float_vector"); + await Collection.LoadAsync(); + await Collection.WaitForCollectionLoadAsync(); + } + + public Task DisposeAsync() + { + Client.Dispose(); + return Task.CompletedTask; + } + } + + #endregion + + #region Nested type: Item + + public record Item + { + public Item(long id, ReadOnlyMemory vector) + { + Id = id; + Vector = vector; + } + + public Item() + { + } + + public virtual bool Equals(Item? other) + { + return other != null && Id == other.Id && Vector.Span.SequenceEqual(other.Vector.Span); + } + + public long Id { get; set; } + + public ReadOnlyMemory Vector { get; set; } + + public override int GetHashCode() + { + var hashCode = new HashCode(); + hashCode.Add(Id); + foreach (float value in Vector.ToArray()) + { + hashCode.Add(value); + } + + return hashCode.ToHashCode(); + } + } + + #endregion +} diff --git a/Milvus.Client.Tests/SearchQueryIteratorStringKeyTests.cs b/Milvus.Client.Tests/SearchQueryIteratorStringKeyTests.cs new file mode 100644 index 0000000..d9cf379 --- /dev/null +++ b/Milvus.Client.Tests/SearchQueryIteratorStringKeyTests.cs @@ -0,0 +1,229 @@ +using Xunit; + +namespace Milvus.Client.Tests; + +[Collection("Milvus")] +public class SearchQueryIteratorStringKeyTests : IClassFixture, + IAsyncLifetime +{ + private const string CollectionName = nameof(SearchQueryIteratorStringKeyTests); + + private readonly DataCollectionFixture _dataCollectionFixture; + private readonly MilvusClient Client; + + public SearchQueryIteratorStringKeyTests(MilvusFixture milvusFixture, DataCollectionFixture dataCollectionFixture) + { + Client = milvusFixture.CreateClient(); + _dataCollectionFixture = dataCollectionFixture; + } + + public Task InitializeAsync() => Task.CompletedTask; + + public Task DisposeAsync() + { + Client.Dispose(); + return Task.CompletedTask; + } + + private MilvusCollection Collection => _dataCollectionFixture.Collection; + + [Fact] + public async Task QueryWithIterator_NoOutputFields() + { + var items = new List + { + new("1", new[] { 10f, 20f }), + new("2", new[] { 30f, 40f }), + new("3", new[] { 50f, 60f }) + }; + + await Collection.InsertAsync( + [ + FieldData.Create("id", items.Select(x => x.Id).ToArray()), + FieldData.CreateFloatVector("float_vector", items.Select(x => x.Vector).ToArray()) + ]); + + var iterator = Collection.QueryWithIteratorAsync(); + + List> results = new(); + await foreach (var result in iterator) + { + results.Add(result); + } + + var returnedItems = results.SelectMany(ExtractItems).ToList(); + Assert.Empty(returnedItems); + } + + [Theory] + [InlineData("id in ['1', '2', '3']", 1, null)] + [InlineData("id in ['1', '2', '3']", 1, 2)] + [InlineData("id in ['1', '2', '3']", 2, null)] + [InlineData("id in ['1', '2', '3']", 2, 2)] + [InlineData("id in ['1', '2', '3']", 1000, null)] + [InlineData("id in ['1', '2', '3']", 1000, 2)] + [InlineData(null, 1, null)] + [InlineData(null, 1, 2)] + [InlineData(null, 2, null)] + [InlineData(null, 2, 2)] + [InlineData(null, 1000, null)] + [InlineData(null, 1000, 2)] + public async Task QueryWithIterator(string? expression, int batchSize, int? limit) + { + var items = new List + { + new("1", new[] { 10f, 20f }), + new("2", new[] { 30f, 40f }), + new("3", new[] { 50f, 60f }) + }; + + await Collection.InsertAsync( + [ + FieldData.Create("id", items.Select(x => x.Id).ToArray()), + FieldData.CreateFloatVector("float_vector", items.Select(x => x.Vector).ToArray()) + ]); + + var queryParameters = new QueryParameters + { + OutputFields = { "id", "float_vector" }, + Limit = limit + }; + + var iterator = Collection.QueryWithIteratorAsync( + expression: expression, + batchSize: batchSize, + parameters: queryParameters); + + List> results = new(); + await foreach (var result in iterator) + { + results.Add(result); + } + + var returnedItems = results.SelectMany(ExtractItems).ToArray(); + var expectedItems = items.Take(limit ?? int.MaxValue).ToArray(); + Assert.Equal(expectedItems.Length, returnedItems.Length); + + foreach (var expectedItem in expectedItems) + { + Assert.Contains(expectedItem, returnedItems); + } + } + + IEnumerable ExtractItems(IReadOnlyList result) + { + long rowCount = result.Select(x => x.RowCount).FirstOrDefault(); + + var items = new Item[rowCount]; + for (int i = 0; i < rowCount; i++) + { + items[i] = new Item(); + } + + foreach (var fieldData in result) + { + switch (fieldData.FieldName) + { + case "id": + { + for (int j = 0; j < rowCount; j++) + { + items[j].Id = ((FieldData) fieldData).Data[j]; + } + + break; + } + + case "float_vector": + { + for (int j = 0; j < rowCount; j++) + { + items[j].Vector = ((FloatVectorFieldData) fieldData).Data[j]; + } + + break; + } + } + } + + return items; + } + + #region Nested type: DataCollectionFixture + + public class DataCollectionFixture : IAsyncLifetime + { + public MilvusCollection Collection; + private readonly MilvusClient Client; + + public DataCollectionFixture(MilvusFixture milvusFixture) + { + Client = milvusFixture.CreateClient(); + Collection = Client.GetCollection(CollectionName); + } + + public async Task InitializeAsync() + { + await Collection.DropAsync(); + + await Client.CreateCollectionAsync( + Collection.Name, + new[] + { + FieldSchema.CreateVarchar("id", 16, isPrimaryKey: true), + FieldSchema.CreateFloatVector("float_vector", 2) + }); + + await Collection.CreateIndexAsync("float_vector", IndexType.Flat, SimilarityMetricType.L2); + await Collection.WaitForIndexBuildAsync("float_vector"); + await Collection.LoadAsync(); + await Collection.WaitForCollectionLoadAsync(); + } + + public Task DisposeAsync() + { + Client.Dispose(); + return Task.CompletedTask; + } + } + + #endregion + + #region Nested type: Item + + public record Item + { + public Item(string id, ReadOnlyMemory vector) + { + Id = id; + Vector = vector; + } + + public Item() + { + } + + public virtual bool Equals(Item? other) + { + return other != null && Id == other.Id && Vector.Span.SequenceEqual(other.Vector.Span); + } + + public string? Id { get; set; } + + public ReadOnlyMemory Vector { get; set; } + + public override int GetHashCode() + { + var hashCode = new HashCode(); + hashCode.Add(Id); + foreach (float value in Vector.ToArray()) + { + hashCode.Add(value); + } + + return hashCode.ToHashCode(); + } + } + + #endregion +} diff --git a/Milvus.Client/MilvusCollection.Entity.cs b/Milvus.Client/MilvusCollection.Entity.cs index 24b9e44..dd48f2f 100644 --- a/Milvus.Client/MilvusCollection.Entity.cs +++ b/Milvus.Client/MilvusCollection.Entity.cs @@ -5,6 +5,7 @@ using System.Runtime.InteropServices; using System.Text.Json; using Google.Protobuf.Collections; +using KeyValuePair = Milvus.Client.Grpc.KeyValuePair; namespace Milvus.Client; @@ -143,7 +144,7 @@ public async Task DeleteAsync( { Verify.NotNullOrWhiteSpace(expression); - var request = new DeleteRequest + DeleteRequest request = new DeleteRequest { CollectionName = Name, Expr = expression, @@ -390,7 +391,7 @@ public Task FlushAsync(CancellationToken cancellationToken = defaul public async Task> GetPersistentSegmentInfosAsync( CancellationToken cancellationToken = default) { - var request = new GetPersistentSegmentInfoRequest { CollectionName = Name }; + GetPersistentSegmentInfoRequest request = new GetPersistentSegmentInfoRequest { CollectionName = Name }; GetPersistentSegmentInfoResponse response = await _client.InvokeAsync( _client.GrpcClient.GetPersistentSegmentInfoAsync, @@ -429,7 +430,7 @@ public async Task> QueryAsync( PopulateQueryRequestFromParameters(request, parameters); - var response = await _client.InvokeAsync( + QueryResults? response = await _client.InvokeAsync( _client.GrpcClient.QueryAsync, request, static r => r.Status, @@ -460,22 +461,22 @@ public async IAsyncEnumerable> QueryWithIteratorAsync( throw new MilvusException("Not support offset when searching iteration"); } - var describeResponse = await _client.InvokeAsync( + DescribeCollectionResponse? describeResponse = await _client.InvokeAsync( _client.GrpcClient.DescribeCollectionAsync, new DescribeCollectionRequest { CollectionName = Name }, r => r.Status, cancellationToken) .ConfigureAwait(false); - var pkField = describeResponse.Schema.Fields.FirstOrDefault(x => x.IsPrimaryKey); + Grpc.FieldSchema? pkField = describeResponse.Schema.Fields.FirstOrDefault(x => x.IsPrimaryKey); if (pkField == null) { throw new MilvusException("Schema must contain pk field"); } - var isUserRequestPkField = parameters?.OutputFieldsInternal?.Contains(pkField.Name) ?? false; - var userExpression = expression; - var userLimit = parameters?.Limit ?? int.MaxValue; + bool isUserRequestPkField = parameters?.OutputFieldsInternal?.Contains(pkField.Name) ?? false; + string? userExpression = expression; + int userLimit = parameters?.Limit ?? int.MaxValue; QueryRequest request = new() { @@ -486,8 +487,10 @@ public async IAsyncEnumerable> QueryWithIteratorAsync( {userExpression: not null} => userExpression, // If user expression is null and pk field is string {pkField.DataType: DataType.VarChar} => $"{pkField.Name} != ''", - // If user expression is null and pk field is not string - _ => $"{pkField.Name} < {long.MaxValue}", + // If user expression is null and pk field is int + {pkField.DataType: DataType.Int8 or DataType.Int16 or DataType.Int32 or DataType.Int64} => $"{pkField.Name} < {long.MaxValue}", + // If user expression is null and pk field is not string and not int + _ => throw new MilvusException("Unsupported data type for primary key field") } }; @@ -497,17 +500,18 @@ public async IAsyncEnumerable> QueryWithIteratorAsync( if (!isUserRequestPkField) request.OutputFields.Add(pkField.Name); // Replace parameters required for iterator + string iterationBatchSize = Math.Min(batchSize, userLimit).ToString(CultureInfo.InvariantCulture); ReplaceKeyValueItems(request.QueryParams, new Grpc.KeyValuePair {Key = Constants.Iterator, Value = "True"}, new Grpc.KeyValuePair {Key = Constants.ReduceStopForBest, Value = "True"}, - new Grpc.KeyValuePair {Key = Constants.BatchSize, Value = batchSize.ToString(CultureInfo.InvariantCulture)}, - new Grpc.KeyValuePair {Key = Constants.Offset, Value = 0.ToString(CultureInfo.InvariantCulture)}, - new Grpc.KeyValuePair {Key = Constants.Limit, Value = Math.Min(batchSize, userLimit).ToString(CultureInfo.InvariantCulture)}); + new Grpc.KeyValuePair {Key = Constants.BatchSize, Value = iterationBatchSize}, + new Grpc.KeyValuePair {Key = Constants.Offset, Value = "0"}, + new Grpc.KeyValuePair {Key = Constants.Limit, Value = iterationBatchSize}); - var processedItemsCount = 0; - while (!cancellationToken.IsCancellationRequested) + int processedItemsCount = 0; + while (true) { - var response = await _client.InvokeAsync( + QueryResults? response = await _client.InvokeAsync( _client.GrpcClient.QueryAsync, request, static r => r.Status, @@ -516,16 +520,25 @@ public async IAsyncEnumerable> QueryWithIteratorAsync( object? pkLastValue; int processedDuringIterationCount; - var pkFieldsData = response.FieldsData.Single(x => x.FieldId == pkField.FieldID); - if (pkField.DataType == DataType.VarChar) + Grpc.FieldData? pkFieldsData = response.FieldsData.Single(x => x.FieldId == pkField.FieldID); + switch (pkField.DataType) { - pkLastValue = pkFieldsData.Scalars.StringData.Data.LastOrDefault(); - processedDuringIterationCount = pkFieldsData.Scalars.StringData.Data.Count; - } - else - { - pkLastValue = pkFieldsData.Scalars.IntData.Data.LastOrDefault(); - processedDuringIterationCount = pkFieldsData.Scalars.IntData.Data.Count; + case DataType.VarChar: + pkLastValue = pkFieldsData.Scalars.StringData.Data.LastOrDefault(); + processedDuringIterationCount = pkFieldsData.Scalars.StringData.Data.Count; + break; + case DataType.Int8: + case DataType.Int16: + case DataType.Int32: + pkLastValue = pkFieldsData.Scalars.IntData.Data.LastOrDefault(); + processedDuringIterationCount = pkFieldsData.Scalars.IntData.Data.Count; + break; + case DataType.Int64: + pkLastValue = pkFieldsData.Scalars.LongData.Data.LastOrDefault(); + processedDuringIterationCount = pkFieldsData.Scalars.LongData.Data.Count; + break; + default: + throw new MilvusException("Unsupported data type for primary key field"); } // If there are no more items to process, we should break the loop @@ -540,7 +553,7 @@ public async IAsyncEnumerable> QueryWithIteratorAsync( yield return ProcessReturnedFieldData(response.FieldsData); processedItemsCount += processedDuringIterationCount; - var leftItemsCount = userLimit - processedItemsCount; + int leftItemsCount = userLimit - processedItemsCount; // If user limit is reached, we should break the loop if(leftItemsCount <= 0) yield break; @@ -554,13 +567,16 @@ public async IAsyncEnumerable> QueryWithIteratorAsync( Value = Math.Min(batchSize, leftItemsCount).ToString(CultureInfo.InvariantCulture) }); - var nextExpression = pkField.DataType == DataType.VarChar - ? $"{pkField.Name} > '{pkLastValue}'" - : $"{pkField.Name} > {pkLastValue}"; + string nextExpression = pkField.DataType switch + { + DataType.VarChar => $"{pkField.Name} > '{pkLastValue}'", + DataType.Int8 or DataType.Int16 or DataType.Int32 or DataType.Int64 => $"{pkField.Name} > {pkLastValue}", + _ => throw new MilvusException("Unsupported data type for primary key field") + }; if (!string.IsNullOrWhiteSpace(userExpression)) { - nextExpression += $" and {userExpression}"; + nextExpression += $" and ({userExpression})"; } request.Expr = nextExpression; @@ -577,7 +593,7 @@ public async IAsyncEnumerable> QueryWithIteratorAsync( public async Task> GetQuerySegmentInfoAsync( CancellationToken cancellationToken = default) { - var request = new GetQuerySegmentInfoRequest { CollectionName = Name }; + GetQuerySegmentInfoRequest request = new GetQuerySegmentInfoRequest { CollectionName = Name }; GetQuerySegmentInfoResponse response = await _client.InvokeAsync(_client.GrpcClient.GetQuerySegmentInfoAsync, request, static r => r.Status, @@ -780,14 +796,14 @@ ulong CalculateGuaranteeTimestamp( private static void ReplaceKeyValueItems(RepeatedField collection, params Grpc.KeyValuePair[] pairs) { - var obsoleteParameterKeys = pairs.Select(x => x.Key).Distinct().ToArray(); - var obsoleteParameters = collection.Where(x => obsoleteParameterKeys.Contains(x.Key)).ToArray(); - foreach (var field in obsoleteParameters) + string[] obsoleteParameterKeys = pairs.Select(x => x.Key).Distinct().ToArray(); + KeyValuePair[] obsoleteParameters = collection.Where(x => obsoleteParameterKeys.Contains(x.Key)).ToArray(); + foreach (KeyValuePair field in obsoleteParameters) { collection.Remove(field); } - foreach (var pair in pairs) + foreach (KeyValuePair pair in pairs) { collection.Add(pair); }