From cbe86cf716dc1969fc716c29ccf8ea63e82a2b4c Mon Sep 17 00:00:00 2001 From: Yudong Cai Date: Fri, 28 Jul 2023 17:21:39 +0800 Subject: [PATCH] Adopt new strategy for faiss IVF range search Signed-off-by: Yudong Cai --- src/index/ivf/ivf.cc | 16 +-- src/index/ivf/ivf_config.h | 3 +- thirdparty/faiss/faiss/IndexBinaryIVF.cpp | 1 - .../faiss/faiss/IndexBinaryIVFThreadSafe.cpp | 4 +- thirdparty/faiss/faiss/IndexIVF.cpp | 53 ++------- thirdparty/faiss/faiss/IndexIVF.h | 2 - thirdparty/faiss/faiss/IndexIVFPQ.cpp | 110 +++++++++--------- .../faiss/faiss/IndexIVFSpectralHash.cpp | 10 +- thirdparty/faiss/faiss/IndexIVFThreadSafe.cpp | 62 ++-------- 9 files changed, 89 insertions(+), 172 deletions(-) diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index c571fa902..5d057a4e0 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -446,13 +446,6 @@ IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi Normalize(dataset); } - auto nprobe = ivf_cfg.nprobe.value(); - - int parallel_mode = 0; - if (nprobe > 1 && nq <= 4) { - parallel_mode = 1; - } - float radius = ivf_cfg.radius.value(); float range_filter = ivf_cfg.range_filter.value(); bool is_ip = (index_->metric_type == faiss::METRIC_INNER_PRODUCT); @@ -467,7 +460,6 @@ IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi std::vector result_lims(nq + 1); try { - size_t max_codes = 0; std::vector> futs; futs.reserve(nq); for (int i = 0; i < nq; ++i) { @@ -476,15 +468,13 @@ IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi faiss::RangeSearchResult res(1); if constexpr (std::is_same::value) { auto cur_data = (const uint8_t*)xq + index * dim / 8; - index_->range_search_thread_safe(1, cur_data, radius, &res, nprobe, bitset); + index_->range_search_thread_safe(1, cur_data, radius, &res, index_->nlist, bitset); } else if constexpr (std::is_same::value) { auto cur_data = (const float*)xq + index * dim; - index_->range_search_without_codes_thread_safe(1, cur_data, radius, &res, nprobe, parallel_mode, - max_codes, bitset); + index_->range_search_without_codes_thread_safe(1, cur_data, radius, &res, index_->nlist, 0, bitset); } else { auto cur_data = (const float*)xq + index * dim; - index_->range_search_thread_safe(1, cur_data, radius, &res, nprobe, parallel_mode, max_codes, - bitset); + index_->range_search_thread_safe(1, cur_data, radius, &res, index_->nlist, 0, bitset); } auto elem_cnt = res.lims[1]; result_dist_array[index].resize(elem_cnt); diff --git a/src/index/ivf/ivf_config.h b/src/index/ivf/ivf_config.h index 4acfab6af..6a7815165 100644 --- a/src/index/ivf/ivf_config.h +++ b/src/index/ivf/ivf_config.h @@ -29,8 +29,7 @@ class IvfConfig : public BaseConfig { .set_default(8) .description("number of probes at query time.") .for_search() - .set_range(1, 65536) - .for_range_search(); + .set_range(1, 65536); } }; diff --git a/thirdparty/faiss/faiss/IndexBinaryIVF.cpp b/thirdparty/faiss/faiss/IndexBinaryIVF.cpp index 18c995a7a..a0dd5a68e 100644 --- a/thirdparty/faiss/faiss/IndexBinaryIVF.cpp +++ b/thirdparty/faiss/faiss/IndexBinaryIVF.cpp @@ -409,7 +409,6 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner { float radius, RangeQueryResult& result, const BitsetView bitset) const override { - size_t nup = 0; for (size_t j = 0; j < n; j++) { if (bitset.empty() || !bitset.test(ids[j])) { float dis = hc.compute(codes); diff --git a/thirdparty/faiss/faiss/IndexBinaryIVFThreadSafe.cpp b/thirdparty/faiss/faiss/IndexBinaryIVFThreadSafe.cpp index 07c6778ba..e356a6327 100644 --- a/thirdparty/faiss/faiss/IndexBinaryIVFThreadSafe.cpp +++ b/thirdparty/faiss/faiss/IndexBinaryIVFThreadSafe.cpp @@ -175,7 +175,6 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner { float radius, RangeQueryResult& result, const BitsetView bitset) const override { - size_t nup = 0; for (size_t j = 0; j < n; j++) { if (bitset.empty() || !bitset.test(ids[j])) { float dis = hc.compute(codes); @@ -801,9 +800,12 @@ void IndexBinaryIVF::range_search_preassigned_thread_safe( scanner->set_query(x + i * code_size); RangeQueryResult& qres = pres.new_result(i); + size_t prev_nres = qres.nres; for (size_t ik = 0; ik < nprobe; ik++) { scan_list_func(i, ik, qres); + if (qres.nres == prev_nres) break; + prev_nres = qres.nres; } } diff --git a/thirdparty/faiss/faiss/IndexIVF.cpp b/thirdparty/faiss/faiss/IndexIVF.cpp index 8ef2c650b..de8c51d8d 100644 --- a/thirdparty/faiss/faiss/IndexIVF.cpp +++ b/thirdparty/faiss/faiss/IndexIVF.cpp @@ -866,55 +866,20 @@ void IndexIVF::range_search_preassigned( } }; - if (parallel_mode == 0) { #pragma omp for - for (idx_t i = 0; i < nx; i++) { - scanner->set_query(x + i * d); - - RangeQueryResult& qres = pres.new_result(i); - - for (size_t ik = 0; ik < nprobe; ik++) { - scan_list_func(i, ik, qres); - } - } - - } else if (parallel_mode == 1) { - for (size_t i = 0; i < nx; i++) { - scanner->set_query(x + i * d); + for (idx_t i = 0; i < nx; i++) { + scanner->set_query(x + i * d); - RangeQueryResult& qres = pres.new_result(i); - -#pragma omp for schedule(dynamic) - for (int64_t ik = 0; ik < nprobe; ik++) { - scan_list_func(i, ik, qres); - } - } - } else if (parallel_mode == 2) { - std::vector all_qres(nx); - RangeQueryResult* qres = nullptr; + RangeQueryResult& qres = pres.new_result(i); + size_t prev_nres = qres.nres; -#pragma omp for schedule(dynamic) - for (idx_t iik = 0; iik < nx * (idx_t)nprobe; iik++) { - idx_t i = iik / (idx_t)nprobe; - idx_t ik = iik % (idx_t)nprobe; - if (qres == nullptr || qres->qno != i) { - FAISS_ASSERT(!qres || i > qres->qno); - qres = &pres.new_result(i); - scanner->set_query(x + i * d); - } - scan_list_func(i, ik, *qres); + for (size_t ik = 0; ik < nprobe; ik++) { + scan_list_func(i, ik, qres); + if (qres.nres == prev_nres) break; + prev_nres = qres.nres; } - } else { - FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode); - } - if (parallel_mode == 0) { - pres.finalize(); - } else { -#pragma omp barrier -#pragma omp single - RangeSearchPartialResult::merge(all_pres, false); -#pragma omp barrier } + pres.finalize(); } if (interrupt) { diff --git a/thirdparty/faiss/faiss/IndexIVF.h b/thirdparty/faiss/faiss/IndexIVF.h index 6fd488846..596bf221f 100644 --- a/thirdparty/faiss/faiss/IndexIVF.h +++ b/thirdparty/faiss/faiss/IndexIVF.h @@ -300,7 +300,6 @@ struct IndexIVF : Index, Level1Quantizer { float radius, RangeSearchResult* result, const size_t nprobe, - const int parallel_mode, const size_t max_codes, const BitsetView bitset = nullptr) const; @@ -310,7 +309,6 @@ struct IndexIVF : Index, Level1Quantizer { float radius, RangeSearchResult* result, const size_t nprobe, - const int parallel_mode, const size_t max_codes, const BitsetView bitset = nullptr) const; diff --git a/thirdparty/faiss/faiss/IndexIVFPQ.cpp b/thirdparty/faiss/faiss/IndexIVFPQ.cpp index e63fca5db..be6d6f601 100644 --- a/thirdparty/faiss/faiss/IndexIVFPQ.cpp +++ b/thirdparty/faiss/faiss/IndexIVFPQ.cpp @@ -807,13 +807,11 @@ struct KnnSearchResults { size_t nup; - inline void add(idx_t j, float dis, const BitsetView bitset = nullptr) { - if (bitset.empty() || !bitset.test(ids[j])) { - if (C::cmp(heap_sim[0], dis)) { - idx_t id = ids ? ids[j] : lo_build(key, j); - heap_replace_top(k, heap_sim, heap_ids, dis, id); - nup++; - } + inline void add(idx_t j, float dis) { + if (C::cmp(heap_sim[0], dis)) { + idx_t id = ids ? ids[j] : lo_build(key, j); + heap_replace_top(k, heap_sim, heap_ids, dis, id); + nup++; } } }; @@ -827,12 +825,10 @@ struct RangeSearchResults { float radius; RangeQueryResult& rres; - inline void add(idx_t j, float dis, const BitsetView bitset = nullptr) { - if (bitset.empty() || !bitset.test(ids[j])) { - if (C::cmp(radius, dis)) { - idx_t id = ids ? ids[j] : lo_build(key, j); - rres.add(dis, id); - } + inline void add(idx_t j, float dis) { + if (C::cmp(radius, dis)) { + idx_t id = ids ? ids[j] : lo_build(key, j); + rres.add(dis, id); } } }; @@ -878,17 +874,18 @@ struct IVFPQScannerT : QueryTables { SearchResultType& res, const BitsetView bitset = nullptr) const { for (size_t j = 0; j < ncode; j++) { - PQDecoder decoder(codes, pq.nbits); - codes += pq.code_size; - float dis = dis0; - const float* tab = sim_table; - - for (size_t m = 0; m < pq.M; m++) { - dis += tab[decoder.decode()]; - tab += pq.ksub; - } + if (bitset.empty() || !bitset.test(res.ids[j])) { + PQDecoder decoder(codes, pq.nbits); + codes += pq.code_size; + float dis = dis0; + const float* tab = sim_table; - res.add(j, dis, bitset); + for (size_t m = 0; m < pq.M; m++) { + dis += tab[decoder.decode()]; + tab += pq.ksub; + } + res.add(j, dis); + } } } @@ -901,18 +898,20 @@ struct IVFPQScannerT : QueryTables { SearchResultType& res, const BitsetView bitset = nullptr) const { for (size_t j = 0; j < ncode; j++) { - PQDecoder decoder(codes, pq.nbits); - codes += pq.code_size; + if (bitset.empty() || !bitset.test(res.ids[j])) { + PQDecoder decoder(codes, pq.nbits); + codes += pq.code_size; - float dis = dis0; - const float* tab = sim_table_2; + float dis = dis0; + const float* tab = sim_table_2; - for (size_t m = 0; m < pq.M; m++) { - int ci = decoder.decode(); - dis += sim_table_ptrs[m][ci] - 2 * tab[ci]; - tab += pq.ksub; + for (size_t m = 0; m < pq.M; m++) { + int ci = decoder.decode(); + dis += sim_table_ptrs[m][ci] - 2 * tab[ci]; + tab += pq.ksub; + } + res.add(j, dis); } - res.add(j, dis, bitset); } } @@ -939,16 +938,18 @@ struct IVFPQScannerT : QueryTables { } for (size_t j = 0; j < ncode; j++) { - pq.decode(codes, decoded_vec); - codes += pq.code_size; + if (bitset.empty() || !bitset.test(res.ids[j])) { + pq.decode(codes, decoded_vec); + codes += pq.code_size; - float dis; - if (METRIC_TYPE == METRIC_INNER_PRODUCT) { - dis = dis0 + fvec_inner_product(decoded_vec, qi, d); - } else { - dis = fvec_L2sqr(decoded_vec, dvec, d); + float dis; + if (METRIC_TYPE == METRIC_INNER_PRODUCT) { + dis = dis0 + fvec_inner_product(decoded_vec, qi, d); + } else { + dis = fvec_L2sqr(decoded_vec, dvec, d); + } + res.add(j, dis); } - res.add(j, dis, bitset); } } @@ -970,21 +971,22 @@ struct IVFPQScannerT : QueryTables { HammingComputer hc(q_code.data(), code_size); for (size_t j = 0; j < ncode; j++) { - const uint8_t* b_code = codes; - int hd = hc.compute(b_code); - if (hd < ht) { - n_hamming_pass++; - PQDecoder decoder(codes, pq.nbits); - - float dis = dis0; - const float* tab = sim_table; - - for (size_t m = 0; m < pq.M; m++) { - dis += tab[decoder.decode()]; - tab += pq.ksub; + if (bitset.empty() || !bitset.test(res.ids[j])) { + const uint8_t* b_code = codes; + int hd = hc.compute(b_code); + if (hd < ht) { + n_hamming_pass++; + PQDecoder decoder(codes, pq.nbits); + + float dis = dis0; + const float* tab = sim_table; + + for (size_t m = 0; m < pq.M; m++) { + dis += tab[decoder.decode()]; + tab += pq.ksub; + } + res.add(j, dis); } - - res.add(j, dis, bitset); } codes += code_size; } diff --git a/thirdparty/faiss/faiss/IndexIVFSpectralHash.cpp b/thirdparty/faiss/faiss/IndexIVFSpectralHash.cpp index cc72309b8..64ba0c045 100644 --- a/thirdparty/faiss/faiss/IndexIVFSpectralHash.cpp +++ b/thirdparty/faiss/faiss/IndexIVFSpectralHash.cpp @@ -286,10 +286,12 @@ struct IVFScanner : InvertedListScanner { RangeQueryResult& res, const BitsetView bitset) const override { for (size_t j = 0; j < list_size; j++) { - float dis = hc.compute(codes); - if (dis < radius) { - int64_t id = store_pairs ? lo_build(list_no, j) : ids[j]; - res.add(dis, id); + if (bitset.empty() || !bitset.test(ids[j])) { + float dis = hc.compute(codes); + if (dis < radius) { + int64_t id = store_pairs ? lo_build(list_no, j) : ids[j]; + res.add(dis, id); + } } codes += code_size; } diff --git a/thirdparty/faiss/faiss/IndexIVFThreadSafe.cpp b/thirdparty/faiss/faiss/IndexIVFThreadSafe.cpp index 7081aeef2..53383e4b5 100644 --- a/thirdparty/faiss/faiss/IndexIVFThreadSafe.cpp +++ b/thirdparty/faiss/faiss/IndexIVFThreadSafe.cpp @@ -515,7 +515,6 @@ void IndexIVF::range_search_thread_safe( float radius, RangeSearchResult* result, const size_t nprobe, - const int parallel_mode, const size_t max_codes, const BitsetView bitset) const { const size_t final_nprobe = std::min(nlist, nprobe); @@ -529,8 +528,7 @@ void IndexIVF::range_search_thread_safe( t0 = getmillisecs(); invlists->prefetch_lists(keys.get(), nx * final_nprobe); - IVFSearchParameters params = - gen_search_param(final_nprobe, parallel_mode, max_codes); + IVFSearchParameters params = gen_search_param(final_nprobe, 0, max_codes); range_search_preassigned( nx, @@ -553,7 +551,6 @@ void IndexIVF::range_search_without_codes_thread_safe( float radius, RangeSearchResult* result, const size_t nprobe, - const int parallel_mode, const size_t max_codes, const BitsetView bitset) const { const size_t final_nprobe = std::min(nlist, nprobe); @@ -567,8 +564,7 @@ void IndexIVF::range_search_without_codes_thread_safe( t0 = getmillisecs(); invlists->prefetch_lists(keys.get(), nx * final_nprobe); - IVFSearchParameters params = - gen_search_param(final_nprobe, parallel_mode, max_codes); + IVFSearchParameters params = gen_search_param(final_nprobe, 0, max_codes); range_search_preassigned_without_codes( nx, @@ -675,7 +671,6 @@ void IndexIVF::range_search_preassigned_without_codes( radius, qres, bitset); - } catch (const std::exception& e) { std::lock_guard lock(exception_mutex); exception_string = @@ -684,55 +679,20 @@ void IndexIVF::range_search_preassigned_without_codes( } }; - if (parallel_mode == 0) { #pragma omp for - for (idx_t i = 0; i < nx; i++) { - scanner->set_query(x + i * d); + for (idx_t i = 0; i < nx; i++) { + scanner->set_query(x + i * d); - RangeQueryResult& qres = pres.new_result(i); + RangeQueryResult& qres = pres.new_result(i); + size_t prev_nres = qres.nres; - for (size_t ik = 0; ik < nprobe; ik++) { - scan_list_func(i, ik, qres, bitset); - } + for (size_t ik = 0; ik < nprobe; ik++) { + scan_list_func(i, ik, qres, bitset); + if (qres.nres == prev_nres) break; + prev_nres = qres.nres; } - - } else if (parallel_mode == 1) { - for (size_t i = 0; i < nx; i++) { - scanner->set_query(x + i * d); - - RangeQueryResult& qres = pres.new_result(i); - -#pragma omp for schedule(dynamic) - for (int64_t ik = 0; ik < nprobe; ik++) { - scan_list_func(i, ik, qres, bitset); - } - } - } else if (parallel_mode == 2) { - std::vector all_qres(nx); - RangeQueryResult* qres = nullptr; - -#pragma omp for schedule(dynamic) - for (idx_t iik = 0; iik < nx * (idx_t)nprobe; iik++) { - idx_t i = iik / (idx_t)nprobe; - idx_t ik = iik % (idx_t)nprobe; - if (qres == nullptr || qres->qno != i) { - FAISS_ASSERT(!qres || i > qres->qno); - qres = &pres.new_result(i); - scanner->set_query(x + i * d); - } - scan_list_func(i, ik, *qres, bitset); - } - } else { - FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode); - } - if (parallel_mode == 0) { - pres.finalize(); - } else { -#pragma omp barrier -#pragma omp single - RangeSearchPartialResult::merge(all_pres, false); -#pragma omp barrier } + pres.finalize(); } if (interrupt) {