Skip to content

Commit 6b15880

Browse files
committed
Support range search, fix milvus-io#245
1 parent 116a6da commit 6b15880

File tree

5 files changed

+201
-7
lines changed

5 files changed

+201
-7
lines changed

src/impl/MilvusClientImpl.cpp

+7-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
#include "TypeUtils.h"
2424
#include "common.pb.h"
25-
#include "milvus.grpc.pb.h"
2625
#include "milvus.pb.h"
2726
#include "schema.pb.h"
2827

@@ -741,7 +740,13 @@ MilvusClientImpl::Search(const SearchArguments& arguments, SearchResults& result
741740

742741
kv_pair = rpc_request.add_search_params();
743742
kv_pair->set_key(milvus::KeyParams());
744-
kv_pair->set_value(arguments.ExtraParams());
743+
// merge extra params with range search
744+
auto json = nlohmann::json::parse(arguments.ExtraParams());
745+
if (arguments.RangeSearch()) {
746+
json["range_filter"] = arguments.RangeFilter();
747+
json["radius"] = arguments.Radius();
748+
}
749+
kv_pair->set_value(json.dump());
745750

746751
rpc_request.set_travel_timestamp(arguments.TravelTimestamp());
747752
rpc_request.set_guarantee_timestamp(arguments.GuaranteeTimestamp());

src/impl/types/SearchArguments.cpp

+58-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "milvus/types/SearchArguments.h"
1818

1919
#include <nlohmann/json.hpp>
20+
#include <utility>
2021

2122
namespace milvus {
2223
namespace {
@@ -28,7 +29,7 @@ struct Validation {
2829
bool required;
2930

3031
Status
31-
Validate(const SearchArguments& data, std::unordered_map<std::string, int64_t> params) const {
32+
Validate(const SearchArguments&, std::unordered_map<std::string, int64_t> params) const {
3233
auto it = params.find(param);
3334
if (it != params.end()) {
3435
auto value = it->second;
@@ -43,7 +44,7 @@ struct Validation {
4344
};
4445

4546
Status
46-
validate(const SearchArguments& data, std::unordered_map<std::string, int64_t> params) {
47+
validate(const SearchArguments& data, const std::unordered_map<std::string, int64_t>& params) {
4748
auto status = Status::OK();
4849
auto validations = {
4950
Validation{"nprobe", 1, 65536, false},
@@ -128,7 +129,7 @@ SearchArguments::TargetVectors() const {
128129

129130
Status
130131
SearchArguments::AddTargetVector(std::string field_name, const std::string& vector) {
131-
return AddTargetVector(field_name, std::string{vector});
132+
return AddTargetVector(std::move(field_name), std::string{vector});
132133
}
133134

134135
Status
@@ -223,6 +224,20 @@ SearchArguments::TopK() const {
223224
return topk_;
224225
}
225226

227+
int64_t
228+
SearchArguments::Nprobe() const {
229+
if (extra_params_.find("nprobe") != extra_params_.end()) {
230+
return extra_params_.at("nprobe");
231+
}
232+
return 1;
233+
}
234+
235+
Status
236+
SearchArguments::SetNprobe(int64_t nprobe) {
237+
extra_params_["nprobe"] = nprobe;
238+
return Status::OK();
239+
}
240+
226241
Status
227242
SearchArguments::SetRoundDecimal(int round_decimal) {
228243
round_decimal_ = round_decimal;
@@ -236,6 +251,12 @@ SearchArguments::RoundDecimal() const {
236251

237252
Status
238253
SearchArguments::SetMetricType(::milvus::MetricType metric_type) {
254+
if (((metric_type == MetricType::IP && metric_type_ == MetricType::L2) ||
255+
(metric_type == MetricType::L2 && metric_type_ == MetricType::IP)) &&
256+
range_search_) {
257+
// switch radius and range_filter
258+
std::swap(radius_, range_filter_);
259+
}
239260
metric_type_ = metric_type;
240261
return Status::OK();
241262
}
@@ -251,7 +272,7 @@ SearchArguments::AddExtraParam(std::string key, int64_t value) {
251272
return Status::OK();
252273
}
253274

254-
const std::string
275+
std::string
255276
SearchArguments::ExtraParams() const {
256277
return ::nlohmann::json(extra_params_).dump();
257278
}
@@ -261,4 +282,37 @@ SearchArguments::Validate() const {
261282
return validate(*this, extra_params_);
262283
}
263284

285+
float
286+
SearchArguments::Radius() const {
287+
return radius_;
288+
}
289+
290+
float
291+
SearchArguments::RangeFilter() const {
292+
return range_filter_;
293+
}
294+
295+
Status
296+
SearchArguments::SetRange(float from, float to) {
297+
auto low = std::min(from, to);
298+
auto high = std::max(from, to);
299+
if (metric_type_ == MetricType::IP) {
300+
radius_ = low;
301+
range_filter_ = high;
302+
range_search_ = true;
303+
} else if (metric_type_ == MetricType::L2) {
304+
radius_ = high;
305+
range_filter_ = low;
306+
range_search_ = true;
307+
} else {
308+
return {StatusCode::INVALID_AGUMENT, "Metric type is not supported"};
309+
}
310+
return Status::OK();
311+
}
312+
313+
bool
314+
SearchArguments::RangeSearch() const {
315+
return range_search_;
316+
}
317+
264318
} // namespace milvus

src/include/milvus/types/SearchArguments.h

+45-1
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,18 @@ class SearchArguments {
164164
int64_t
165165
TopK() const;
166166

167+
/**
168+
* @brief Get nprobe
169+
*/
170+
int64_t
171+
Nprobe() const;
172+
173+
/**
174+
* @brief Set nprobe
175+
*/
176+
Status
177+
SetNprobe(int64_t nlist);
178+
167179
/**
168180
* @brief Specifies the decimal place of the returned results.
169181
*/
@@ -197,7 +209,7 @@ class SearchArguments {
197209
/**
198210
* @brief Get extra param
199211
*/
200-
const std::string
212+
std::string
201213
ExtraParams() const;
202214

203215
/**
@@ -207,6 +219,35 @@ class SearchArguments {
207219
Status
208220
Validate() const;
209221

222+
/**
223+
* @brief Get range radius
224+
* @return
225+
*/
226+
float
227+
Radius() const;
228+
229+
/**
230+
* @brief Get range filter
231+
* @return
232+
*/
233+
float
234+
RangeFilter() const;
235+
236+
/**
237+
* @brief Set range radius
238+
* @param from range radius from
239+
* @param to range radius to
240+
*/
241+
Status
242+
SetRange(float from, float to);
243+
244+
/**
245+
* @brief Get if do range search
246+
* @return
247+
*/
248+
bool
249+
RangeSearch() const;
250+
210251
private:
211252
std::string collection_name_;
212253
std::set<std::string> partition_names_;
@@ -225,6 +266,9 @@ class SearchArguments {
225266
int64_t topk_{1};
226267
int round_decimal_{-1};
227268

269+
float radius_;
270+
float range_filter_;
271+
bool range_search_{false};
228272
::milvus::MetricType metric_type_{::milvus::MetricType::L2};
229273
};
230274

test/st/TestSearch.cpp

+70
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,76 @@ TEST_F(MilvusServerTestSearch, SearchWithoutIndex) {
139139
dropCollection();
140140
}
141141

142+
TEST_F(MilvusServerTestSearch, RangeSearch) {
143+
std::vector<milvus::FieldDataPtr> fields{
144+
std::make_shared<milvus::Int16FieldData>("age", std::vector<int16_t>{12, 13, 14, 15, 16, 17, 18}),
145+
std::make_shared<milvus::VarCharFieldData>(
146+
"name", std::vector<std::string>{"Tom", "Jerry", "Lily", "Foo", "Bar", "Jake", "Jonathon"}),
147+
std::make_shared<milvus::FloatVecFieldData>("face", std::vector<std::vector<float>>{
148+
std::vector<float>{0.1f, 0.2f, 0.3f, 0.4f},
149+
std::vector<float>{0.2f, 0.3f, 0.4f, 0.5f},
150+
std::vector<float>{0.3f, 0.4f, 0.5f, 0.6f},
151+
std::vector<float>{0.4f, 0.5f, 0.6f, 0.7f},
152+
std::vector<float>{0.5f, 0.6f, 0.7f, 0.8f},
153+
std::vector<float>{0.6f, 0.7f, 0.8f, 0.9f},
154+
std::vector<float>{0.7f, 0.8f, 0.9f, 1.0f},
155+
})};
156+
157+
createCollectionAndPartitions(true);
158+
auto dml_results = insertRecords(fields);
159+
loadCollection();
160+
161+
milvus::SearchArguments arguments{};
162+
arguments.SetCollectionName(collection_name);
163+
arguments.AddPartitionName(partition_name);
164+
arguments.SetRange(0.3, 1.0);
165+
arguments.SetTopK(10);
166+
arguments.AddOutputField("age");
167+
arguments.AddOutputField("name");
168+
arguments.AddTargetVector("face", std::vector<float>{0.f, 0.f, 0.f, 0.f});
169+
arguments.AddTargetVector("face", std::vector<float>{1.f, 1.f, 1.f, 1.f});
170+
milvus::SearchResults search_results{};
171+
auto status = client_->Search(arguments, search_results);
172+
EXPECT_EQ(status.Message(), "OK");
173+
EXPECT_TRUE(status.IsOk());
174+
175+
const auto& results = search_results.Results();
176+
EXPECT_EQ(results.size(), 2);
177+
178+
// validate results
179+
auto validateScores = [&results](int firstRet, int secondRet) {
180+
// check score should between range
181+
for (const auto& result : results) {
182+
for (const auto& score : result.Scores()) {
183+
EXPECT_GE(score, 0.3);
184+
EXPECT_LE(score, 1.0);
185+
}
186+
}
187+
EXPECT_EQ(results.at(0).Ids().IntIDArray().size(), firstRet);
188+
EXPECT_EQ(results.at(1).Ids().IntIDArray().size(), secondRet);
189+
};
190+
191+
// valid score in range is 3, 2
192+
validateScores(3, 2);
193+
194+
// add fields, then search again, should be 6 and 4
195+
insertRecords(fields);
196+
loadCollection();
197+
status = client_->Search(arguments, search_results);
198+
EXPECT_TRUE(status.IsOk());
199+
validateScores(6, 4);
200+
201+
// add fields twice, and now it should be 12, 8, as limit is 10, then should be 10, 8
202+
insertRecords(fields);
203+
insertRecords(fields);
204+
loadCollection();
205+
status = client_->Search(arguments, search_results);
206+
EXPECT_TRUE(status.IsOk());
207+
validateScores(10, 8);
208+
209+
dropCollection();
210+
}
211+
142212
TEST_F(MilvusServerTestSearch, SearchWithStringFilter) {
143213
std::vector<milvus::FieldDataPtr> fields{
144214
std::make_shared<milvus::Int16FieldData>("age", std::vector<int16_t>{12, 13}),

test/ut/TestSearchArguments.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,24 @@ TEST_F(SearchArgumentsTest, ValidateTesting) {
170170
EXPECT_TRUE(status.IsOk());
171171
}
172172
}
173+
174+
TEST_F(SearchArgumentsTest, Nprobe) {
175+
milvus::SearchArguments arguments;
176+
arguments.AddExtraParam("nprobe", 10);
177+
EXPECT_EQ(10, arguments.Nprobe());
178+
179+
arguments.SetNprobe(20);
180+
EXPECT_EQ(20, arguments.Nprobe());
181+
}
182+
183+
TEST_F(SearchArgumentsTest, RangeSearchParams) {
184+
milvus::SearchArguments arguments;
185+
arguments.SetMetricType(milvus::MetricType::IP);
186+
arguments.SetRange(0.1, 0.2);
187+
EXPECT_NEAR(0.1, arguments.Radius(), 0.00001);
188+
EXPECT_NEAR(0.2, arguments.RangeFilter(), 0.00001);
189+
190+
arguments.SetMetricType(milvus::MetricType::L2);
191+
EXPECT_NEAR(0.2, arguments.Radius(), 0.00001);
192+
EXPECT_NEAR(0.1, arguments.RangeFilter(), 0.00001);
193+
}

0 commit comments

Comments
 (0)