diff --git a/knowhere/index/vector_index/IndexHNSW.cpp b/knowhere/index/vector_index/IndexHNSW.cpp index 20db90126..205cbeac8 100644 --- a/knowhere/index/vector_index/IndexHNSW.cpp +++ b/knowhere/index/vector_index/IndexHNSW.cpp @@ -166,7 +166,8 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais } } - index_->setEf(GetIndexParamEf(config)); + size_t ef = GetIndexParamEf(config); + hnswlib::SearchParam param{ef}; bool transform = (index_->metric_type_ == 1); // InnerProduct: 1 std::chrono::high_resolution_clock::time_point query_start, query_end; @@ -179,10 +180,10 @@ if (CheckKeyInConfig(config, meta::QUERY_THREAD_NUM)) auto single_query = (float*)p_data + i * dim; std::priority_queue> rst; if (STATISTICS_LEVEL >= 3) { - rst = index_->searchKnn(single_query, k, bitset, query_stats[i]); + rst = index_->searchKnn(single_query, k, bitset, query_stats[i], ¶m); } else { auto dummy_stat = hnswlib::StatisticsInfo(); - rst = index_->searchKnn(single_query, k, bitset, dummy_stat); + rst = index_->searchKnn(single_query, k, bitset, dummy_stat, ¶m); } size_t rst_size = rst.size(); @@ -246,7 +247,8 @@ IndexHNSW::QueryByRange(const DatasetPtr& dataset, auto range_k = GetIndexParamHNSWK(config); auto radius = GetMetaRadius(config); - index_->setEf(GetIndexParamEf(config)); + size_t ef = GetIndexParamEf(config); + hnswlib::SearchParam param{ef}; bool is_IP = (index_->metric_type_ == 1); // InnerProduct: 1 if (!is_IP) { @@ -262,7 +264,8 @@ IndexHNSW::QueryByRange(const DatasetPtr& dataset, auto single_query = (float*)p_data + i * dim; auto dummy_stat = hnswlib::StatisticsInfo(); - auto rst = index_->searchRange(single_query, range_k, (is_IP ? 1.0f - radius : radius), bitset, dummy_stat); + auto rst = + index_->searchRange(single_query, range_k, (is_IP ? 1.0f - radius : radius), bitset, dummy_stat, ¶m); for (auto& p : rst) { result_dist_array[i].push_back(is_IP ? (1 - p.first) : p.first); diff --git a/thirdparty/hnswlib/hnswlib/hnswalg.h b/thirdparty/hnswlib/hnswlib/hnswalg.h index 6f0f43ede..4a8e4b047 100644 --- a/thirdparty/hnswlib/hnswlib/hnswalg.h +++ b/thirdparty/hnswlib/hnswlib/hnswalg.h @@ -572,6 +572,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::mutex global; size_t ef_; + // Do not call this to set EF in multi-thread case. This is not thread-safe. void setEf(size_t ef) { ef_ = ef; @@ -1111,7 +1112,8 @@ class HierarchicalNSW : public AlgorithmInterface { }; std::priority_queue> - searchKnn(const void* query_data, size_t k, const faiss::BitsetView bitset, StatisticsInfo& stats) const { + searchKnn(const void* query_data, size_t k, const faiss::BitsetView bitset, StatisticsInfo& stats, + const SearchParam* param = nullptr) const { std::priority_queue> result; if (cur_element_count == 0) return result; @@ -1151,10 +1153,11 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> top_candidates; + size_t ef = param ? param->ef_ : this->ef_; if (!bitset.empty()) { - top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef_, k), bitset, stats); + top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef, k), bitset, stats); } else { - top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef_, k), bitset, stats); + top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef, k), bitset, stats); } while (top_candidates.size() > k) { @@ -1170,7 +1173,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::vector> searchRange(const void* query_data, size_t range_k, float radius, const faiss::BitsetView bitset, - StatisticsInfo& stats) const { + StatisticsInfo& stats, const SearchParam* param = nullptr) const { if (cur_element_count == 0) { return {}; } @@ -1207,10 +1210,11 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> top_candidates; + size_t ef = param ? param->ef_ : this->ef_; if (!bitset.empty()) { - top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef_, range_k), bitset, stats); + top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef, range_k), bitset, stats); } else { - top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef_, range_k), bitset, stats); + top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef, range_k), bitset, stats); } while (top_candidates.size() > range_k) { diff --git a/thirdparty/hnswlib/hnswlib/hnswlib.h b/thirdparty/hnswlib/hnswlib/hnswlib.h index 10abc5464..ec2665c6d 100644 --- a/thirdparty/hnswlib/hnswlib/hnswlib.h +++ b/thirdparty/hnswlib/hnswlib/hnswlib.h @@ -175,16 +175,21 @@ class StatisticsInfo { std::vector accessed_points_; }; +struct SearchParam { + size_t ef_; +}; + template class AlgorithmInterface { public: virtual void addPoint(const void *datapoint, labeltype label)=0; - virtual std::priority_queue> - searchKnn(const void *, size_t, const faiss::BitsetView, hnswlib::StatisticsInfo&) const = 0; + virtual std::priority_queue> + searchKnn(const void*, size_t, const faiss::BitsetView, hnswlib::StatisticsInfo&, const SearchParam*) const = 0; virtual std::vector> - searchRange(const void*, size_t, float, const faiss::BitsetView, hnswlib::StatisticsInfo&) const = 0; + searchRange(const void*, size_t, float, const faiss::BitsetView, hnswlib::StatisticsInfo&, + const SearchParam*) const = 0; // Return k nearest neighbor in the order of closer fist virtual std::vector> @@ -202,7 +207,7 @@ AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t std::vector> result; // here searchKnn returns the result in the order of further first - auto ret = searchKnn(query_data, k, bitset, stats); + auto ret = searchKnn(query_data, k, bitset, stats, nullptr); { size_t sz = ret.size(); result.resize(sz);