From f0ecc0d90c339f8dc0abe3c68d5aefabf60848c6 Mon Sep 17 00:00:00 2001
From: liliu-z <105927039+liliu-z@users.noreply.github.com>
Date: Thu, 18 Aug 2022 20:40:47 +0800
Subject: [PATCH] Make HNSW search thread safe (#406)
Signed-off-by: liliu-z
Signed-off-by: liliu-z
---
knowhere/index/vector_index/IndexHNSW.cpp | 13 ++++++++-----
thirdparty/hnswlib/hnswlib/hnswalg.h | 16 ++++++++++------
thirdparty/hnswlib/hnswlib/hnswlib.h | 13 +++++++++----
3 files changed, 27 insertions(+), 15 deletions(-)
diff --git a/knowhere/index/vector_index/IndexHNSW.cpp b/knowhere/index/vector_index/IndexHNSW.cpp
index 20db90126..205cbeac8 100644
--- a/knowhere/index/vector_index/IndexHNSW.cpp
+++ b/knowhere/index/vector_index/IndexHNSW.cpp
@@ -166,7 +166,8 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais
}
}
- index_->setEf(GetIndexParamEf(config));
+ size_t ef = GetIndexParamEf(config);
+ hnswlib::SearchParam param{ef};
bool transform = (index_->metric_type_ == 1); // InnerProduct: 1
std::chrono::high_resolution_clock::time_point query_start, query_end;
@@ -179,10 +180,10 @@ if (CheckKeyInConfig(config, meta::QUERY_THREAD_NUM))
auto single_query = (float*)p_data + i * dim;
std::priority_queue> rst;
if (STATISTICS_LEVEL >= 3) {
- rst = index_->searchKnn(single_query, k, bitset, query_stats[i]);
+ rst = index_->searchKnn(single_query, k, bitset, query_stats[i], &param);
} else {
auto dummy_stat = hnswlib::StatisticsInfo();
- rst = index_->searchKnn(single_query, k, bitset, dummy_stat);
+ rst = index_->searchKnn(single_query, k, bitset, dummy_stat, &param);
}
size_t rst_size = rst.size();
@@ -246,7 +247,8 @@ IndexHNSW::QueryByRange(const DatasetPtr& dataset,
auto range_k = GetIndexParamHNSWK(config);
auto radius = GetMetaRadius(config);
- index_->setEf(GetIndexParamEf(config));
+ size_t ef = GetIndexParamEf(config);
+ hnswlib::SearchParam param{ef};
bool is_IP = (index_->metric_type_ == 1); // InnerProduct: 1
if (!is_IP) {
@@ -262,7 +264,8 @@ IndexHNSW::QueryByRange(const DatasetPtr& dataset,
auto single_query = (float*)p_data + i * dim;
auto dummy_stat = hnswlib::StatisticsInfo();
- auto rst = index_->searchRange(single_query, range_k, (is_IP ? 1.0f - radius : radius), bitset, dummy_stat);
+ auto rst =
+ index_->searchRange(single_query, range_k, (is_IP ? 1.0f - radius : radius), bitset, dummy_stat, &param);
for (auto& p : rst) {
result_dist_array[i].push_back(is_IP ? (1 - p.first) : p.first);
diff --git a/thirdparty/hnswlib/hnswlib/hnswalg.h b/thirdparty/hnswlib/hnswlib/hnswalg.h
index 6f0f43ede..4a8e4b047 100644
--- a/thirdparty/hnswlib/hnswlib/hnswalg.h
+++ b/thirdparty/hnswlib/hnswlib/hnswalg.h
@@ -572,6 +572,7 @@ class HierarchicalNSW : public AlgorithmInterface {
std::mutex global;
size_t ef_;
+ // Do not call this to set ef when multiple threads are in use; it is not thread-safe.
void
setEf(size_t ef) {
ef_ = ef;
@@ -1111,7 +1112,8 @@ class HierarchicalNSW : public AlgorithmInterface {
};
std::priority_queue>
- searchKnn(const void* query_data, size_t k, const faiss::BitsetView bitset, StatisticsInfo& stats) const {
+ searchKnn(const void* query_data, size_t k, const faiss::BitsetView bitset, StatisticsInfo& stats,
+ const SearchParam* param = nullptr) const {
std::priority_queue> result;
if (cur_element_count == 0)
return result;
@@ -1151,10 +1153,11 @@ class HierarchicalNSW : public AlgorithmInterface {
std::priority_queue, std::vector>, CompareByFirst>
top_candidates;
+ size_t ef = param ? param->ef_ : this->ef_;
if (!bitset.empty()) {
- top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef_, k), bitset, stats);
+ top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef, k), bitset, stats);
} else {
- top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef_, k), bitset, stats);
+ top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef, k), bitset, stats);
}
while (top_candidates.size() > k) {
@@ -1170,7 +1173,7 @@ class HierarchicalNSW : public AlgorithmInterface {
std::vector>
searchRange(const void* query_data, size_t range_k, float radius, const faiss::BitsetView bitset,
- StatisticsInfo& stats) const {
+ StatisticsInfo& stats, const SearchParam* param = nullptr) const {
if (cur_element_count == 0) {
return {};
}
@@ -1207,10 +1210,11 @@ class HierarchicalNSW : public AlgorithmInterface {
std::priority_queue, std::vector>, CompareByFirst>
top_candidates;
+ size_t ef = param ? param->ef_ : this->ef_;
if (!bitset.empty()) {
- top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef_, range_k), bitset, stats);
+ top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef, range_k), bitset, stats);
} else {
- top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef_, range_k), bitset, stats);
+ top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef, range_k), bitset, stats);
}
while (top_candidates.size() > range_k) {
diff --git a/thirdparty/hnswlib/hnswlib/hnswlib.h b/thirdparty/hnswlib/hnswlib/hnswlib.h
index 10abc5464..ec2665c6d 100644
--- a/thirdparty/hnswlib/hnswlib/hnswlib.h
+++ b/thirdparty/hnswlib/hnswlib/hnswlib.h
@@ -175,16 +175,21 @@ class StatisticsInfo {
std::vector accessed_points_;
};
+struct SearchParam {  // Optional per-call search parameters for searchKnn / searchRange.
+ size_t ef_;  // When a SearchParam is passed, this value is used in place of the index's own ef_ member.
+};
+
template
class AlgorithmInterface {
public:
virtual void addPoint(const void *datapoint, labeltype label)=0;
- virtual std::priority_queue>
- searchKnn(const void *, size_t, const faiss::BitsetView, hnswlib::StatisticsInfo&) const = 0;
+ virtual std::priority_queue>
+ searchKnn(const void*, size_t, const faiss::BitsetView, hnswlib::StatisticsInfo&, const SearchParam*) const = 0;
virtual std::vector>
- searchRange(const void*, size_t, float, const faiss::BitsetView, hnswlib::StatisticsInfo&) const = 0;
+ searchRange(const void*, size_t, float, const faiss::BitsetView, hnswlib::StatisticsInfo&,
+ const SearchParam*) const = 0;
// Return k nearest neighbor in the order of closer fist
virtual std::vector>
@@ -202,7 +207,7 @@ AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t
std::vector> result;
// here searchKnn returns the result in the order of further first
- auto ret = searchKnn(query_data, k, bitset, stats);
+ auto ret = searchKnn(query_data, k, bitset, stats, nullptr);
{
size_t sz = ret.size();
result.resize(sz);