Skip to content
This repository was archived by the owner on Aug 16, 2023. It is now read-only.

Commit f0ecc0d

Browse files
authored
Make HNSW search thread safe (#406)
Signed-off-by: liliu-z <[email protected]> Signed-off-by: liliu-z <[email protected]>
1 parent a9b9607 commit f0ecc0d

File tree

3 files changed

+27
-15
lines changed

3 files changed

+27
-15
lines changed

knowhere/index/vector_index/IndexHNSW.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais
166166
}
167167
}
168168

169-
index_->setEf(GetIndexParamEf(config));
169+
size_t ef = GetIndexParamEf(config);
170+
hnswlib::SearchParam param{ef};
170171
bool transform = (index_->metric_type_ == 1); // InnerProduct: 1
171172

172173
std::chrono::high_resolution_clock::time_point query_start, query_end;
@@ -179,10 +180,10 @@ if (CheckKeyInConfig(config, meta::QUERY_THREAD_NUM))
179180
auto single_query = (float*)p_data + i * dim;
180181
std::priority_queue<std::pair<float, hnswlib::labeltype>> rst;
181182
if (STATISTICS_LEVEL >= 3) {
182-
rst = index_->searchKnn(single_query, k, bitset, query_stats[i]);
183+
rst = index_->searchKnn(single_query, k, bitset, query_stats[i], &param);
183184
} else {
184185
auto dummy_stat = hnswlib::StatisticsInfo();
185-
rst = index_->searchKnn(single_query, k, bitset, dummy_stat);
186+
rst = index_->searchKnn(single_query, k, bitset, dummy_stat, &param);
186187
}
187188
size_t rst_size = rst.size();
188189

@@ -246,7 +247,8 @@ IndexHNSW::QueryByRange(const DatasetPtr& dataset,
246247

247248
auto range_k = GetIndexParamHNSWK(config);
248249
auto radius = GetMetaRadius(config);
249-
index_->setEf(GetIndexParamEf(config));
250+
size_t ef = GetIndexParamEf(config);
251+
hnswlib::SearchParam param{ef};
250252
bool is_IP = (index_->metric_type_ == 1); // InnerProduct: 1
251253

252254
if (!is_IP) {
@@ -262,7 +264,8 @@ IndexHNSW::QueryByRange(const DatasetPtr& dataset,
262264
auto single_query = (float*)p_data + i * dim;
263265

264266
auto dummy_stat = hnswlib::StatisticsInfo();
265-
auto rst = index_->searchRange(single_query, range_k, (is_IP ? 1.0f - radius : radius), bitset, dummy_stat);
267+
auto rst =
268+
index_->searchRange(single_query, range_k, (is_IP ? 1.0f - radius : radius), bitset, dummy_stat, &param);
266269

267270
for (auto& p : rst) {
268271
result_dist_array[i].push_back(is_IP ? (1 - p.first) : p.first);

thirdparty/hnswlib/hnswlib/hnswalg.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
572572
std::mutex global;
573573
size_t ef_;
574574

575+
// Do not call this to set EF in multi-thread case. This is not thread-safe.
575576
void
576577
setEf(size_t ef) {
577578
ef_ = ef;
@@ -1111,7 +1112,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
11111112
};
11121113

11131114
std::priority_queue<std::pair<dist_t, labeltype>>
1114-
searchKnn(const void* query_data, size_t k, const faiss::BitsetView bitset, StatisticsInfo& stats) const {
1115+
searchKnn(const void* query_data, size_t k, const faiss::BitsetView bitset, StatisticsInfo& stats,
1116+
const SearchParam* param = nullptr) const {
11151117
std::priority_queue<std::pair<dist_t, labeltype>> result;
11161118
if (cur_element_count == 0)
11171119
return result;
@@ -1151,10 +1153,11 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
11511153

11521154
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
11531155
top_candidates;
1156+
size_t ef = param ? param->ef_ : this->ef_;
11541157
if (!bitset.empty()) {
1155-
top_candidates = searchBaseLayerST<true, true>(currObj, query_data, std::max(ef_, k), bitset, stats);
1158+
top_candidates = searchBaseLayerST<true, true>(currObj, query_data, std::max(ef, k), bitset, stats);
11561159
} else {
1157-
top_candidates = searchBaseLayerST<false, true>(currObj, query_data, std::max(ef_, k), bitset, stats);
1160+
top_candidates = searchBaseLayerST<false, true>(currObj, query_data, std::max(ef, k), bitset, stats);
11581161
}
11591162

11601163
while (top_candidates.size() > k) {
@@ -1170,7 +1173,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
11701173

11711174
std::vector<std::pair<dist_t, labeltype>>
11721175
searchRange(const void* query_data, size_t range_k, float radius, const faiss::BitsetView bitset,
1173-
StatisticsInfo& stats) const {
1176+
StatisticsInfo& stats, const SearchParam* param = nullptr) const {
11741177
if (cur_element_count == 0) {
11751178
return {};
11761179
}
@@ -1207,10 +1210,11 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
12071210

12081211
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
12091212
top_candidates;
1213+
size_t ef = param ? param->ef_ : this->ef_;
12101214
if (!bitset.empty()) {
1211-
top_candidates = searchBaseLayerST<true, true>(currObj, query_data, std::max(ef_, range_k), bitset, stats);
1215+
top_candidates = searchBaseLayerST<true, true>(currObj, query_data, std::max(ef, range_k), bitset, stats);
12121216
} else {
1213-
top_candidates = searchBaseLayerST<false, true>(currObj, query_data, std::max(ef_, range_k), bitset, stats);
1217+
top_candidates = searchBaseLayerST<false, true>(currObj, query_data, std::max(ef, range_k), bitset, stats);
12141218
}
12151219

12161220
while (top_candidates.size() > range_k) {

thirdparty/hnswlib/hnswlib/hnswlib.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,16 +175,21 @@ class StatisticsInfo {
175175
std::vector<uint32_t> accessed_points_;
176176
};
177177

178+
struct SearchParam {
179+
size_t ef_;
180+
};
181+
178182
template<typename dist_t>
179183
class AlgorithmInterface {
180184
public:
181185
virtual void addPoint(const void *datapoint, labeltype label)=0;
182186

183-
virtual std::priority_queue<std::pair<dist_t, labeltype >>
184-
searchKnn(const void *, size_t, const faiss::BitsetView, hnswlib::StatisticsInfo&) const = 0;
187+
virtual std::priority_queue<std::pair<dist_t, labeltype>>
188+
searchKnn(const void*, size_t, const faiss::BitsetView, hnswlib::StatisticsInfo&, const SearchParam*) const = 0;
185189

186190
virtual std::vector<std::pair<dist_t, labeltype>>
187-
searchRange(const void*, size_t, float, const faiss::BitsetView, hnswlib::StatisticsInfo&) const = 0;
191+
searchRange(const void*, size_t, float, const faiss::BitsetView, hnswlib::StatisticsInfo&,
192+
const SearchParam*) const = 0;
188193

189194
// Return k nearest neighbor in the order of closer fist
190195
virtual std::vector<std::pair<dist_t, labeltype>>
@@ -202,7 +207,7 @@ AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t
202207
std::vector<std::pair<dist_t, labeltype>> result;
203208

204209
// here searchKnn returns the result in the order of further first
205-
auto ret = searchKnn(query_data, k, bitset, stats);
210+
auto ret = searchKnn(query_data, k, bitset, stats, nullptr);
206211
{
207212
size_t sz = ret.size();
208213
result.resize(sz);

0 commit comments

Comments
 (0)