Skip to content
This repository was archived by the owner on Aug 16, 2023. It is now read-only.

Commit cbe86cf

Browse files
cydrainliliu-z
authored andcommitted
Adopt new strategy for faiss IVF range search
Signed-off-by: Yudong Cai <[email protected]>
1 parent 7c870a3 commit cbe86cf

File tree

9 files changed

+89
-172
lines changed

9 files changed

+89
-172
lines changed

src/index/ivf/ivf.cc

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -446,13 +446,6 @@ IvfIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi
446446
Normalize(dataset);
447447
}
448448

449-
auto nprobe = ivf_cfg.nprobe.value();
450-
451-
int parallel_mode = 0;
452-
if (nprobe > 1 && nq <= 4) {
453-
parallel_mode = 1;
454-
}
455-
456449
float radius = ivf_cfg.radius.value();
457450
float range_filter = ivf_cfg.range_filter.value();
458451
bool is_ip = (index_->metric_type == faiss::METRIC_INNER_PRODUCT);
@@ -467,7 +460,6 @@ IvfIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi
467460
std::vector<size_t> result_lims(nq + 1);
468461

469462
try {
470-
size_t max_codes = 0;
471463
std::vector<folly::Future<folly::Unit>> futs;
472464
futs.reserve(nq);
473465
for (int i = 0; i < nq; ++i) {
@@ -476,15 +468,13 @@ IvfIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi
476468
faiss::RangeSearchResult res(1);
477469
if constexpr (std::is_same<T, faiss::IndexBinaryIVF>::value) {
478470
auto cur_data = (const uint8_t*)xq + index * dim / 8;
479-
index_->range_search_thread_safe(1, cur_data, radius, &res, nprobe, bitset);
471+
index_->range_search_thread_safe(1, cur_data, radius, &res, index_->nlist, bitset);
480472
} else if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
481473
auto cur_data = (const float*)xq + index * dim;
482-
index_->range_search_without_codes_thread_safe(1, cur_data, radius, &res, nprobe, parallel_mode,
483-
max_codes, bitset);
474+
index_->range_search_without_codes_thread_safe(1, cur_data, radius, &res, index_->nlist, 0, bitset);
484475
} else {
485476
auto cur_data = (const float*)xq + index * dim;
486-
index_->range_search_thread_safe(1, cur_data, radius, &res, nprobe, parallel_mode, max_codes,
487-
bitset);
477+
index_->range_search_thread_safe(1, cur_data, radius, &res, index_->nlist, 0, bitset);
488478
}
489479
auto elem_cnt = res.lims[1];
490480
result_dist_array[index].resize(elem_cnt);

src/index/ivf/ivf_config.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ class IvfConfig : public BaseConfig {
2929
.set_default(8)
3030
.description("number of probes at query time.")
3131
.for_search()
32-
.set_range(1, 65536)
33-
.for_range_search();
32+
.set_range(1, 65536);
3433
}
3534
};
3635

thirdparty/faiss/faiss/IndexBinaryIVF.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,6 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
409409
float radius,
410410
RangeQueryResult& result,
411411
const BitsetView bitset) const override {
412-
size_t nup = 0;
413412
for (size_t j = 0; j < n; j++) {
414413
if (bitset.empty() || !bitset.test(ids[j])) {
415414
float dis = hc.compute(codes);

thirdparty/faiss/faiss/IndexBinaryIVFThreadSafe.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
175175
float radius,
176176
RangeQueryResult& result,
177177
const BitsetView bitset) const override {
178-
size_t nup = 0;
179178
for (size_t j = 0; j < n; j++) {
180179
if (bitset.empty() || !bitset.test(ids[j])) {
181180
float dis = hc.compute(codes);
@@ -801,9 +800,12 @@ void IndexBinaryIVF::range_search_preassigned_thread_safe(
801800
scanner->set_query(x + i * code_size);
802801

803802
RangeQueryResult& qres = pres.new_result(i);
803+
size_t prev_nres = qres.nres;
804804

805805
for (size_t ik = 0; ik < nprobe; ik++) {
806806
scan_list_func(i, ik, qres);
807+
if (qres.nres == prev_nres) break;
808+
prev_nres = qres.nres;
807809
}
808810
}
809811

thirdparty/faiss/faiss/IndexIVF.cpp

Lines changed: 9 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -866,55 +866,20 @@ void IndexIVF::range_search_preassigned(
866866
}
867867
};
868868

869-
if (parallel_mode == 0) {
870869
#pragma omp for
871-
for (idx_t i = 0; i < nx; i++) {
872-
scanner->set_query(x + i * d);
873-
874-
RangeQueryResult& qres = pres.new_result(i);
875-
876-
for (size_t ik = 0; ik < nprobe; ik++) {
877-
scan_list_func(i, ik, qres);
878-
}
879-
}
880-
881-
} else if (parallel_mode == 1) {
882-
for (size_t i = 0; i < nx; i++) {
883-
scanner->set_query(x + i * d);
870+
for (idx_t i = 0; i < nx; i++) {
871+
scanner->set_query(x + i * d);
884872

885-
RangeQueryResult& qres = pres.new_result(i);
886-
887-
#pragma omp for schedule(dynamic)
888-
for (int64_t ik = 0; ik < nprobe; ik++) {
889-
scan_list_func(i, ik, qres);
890-
}
891-
}
892-
} else if (parallel_mode == 2) {
893-
std::vector<RangeQueryResult*> all_qres(nx);
894-
RangeQueryResult* qres = nullptr;
873+
RangeQueryResult& qres = pres.new_result(i);
874+
size_t prev_nres = qres.nres;
895875

896-
#pragma omp for schedule(dynamic)
897-
for (idx_t iik = 0; iik < nx * (idx_t)nprobe; iik++) {
898-
idx_t i = iik / (idx_t)nprobe;
899-
idx_t ik = iik % (idx_t)nprobe;
900-
if (qres == nullptr || qres->qno != i) {
901-
FAISS_ASSERT(!qres || i > qres->qno);
902-
qres = &pres.new_result(i);
903-
scanner->set_query(x + i * d);
904-
}
905-
scan_list_func(i, ik, *qres);
876+
for (size_t ik = 0; ik < nprobe; ik++) {
877+
scan_list_func(i, ik, qres);
878+
if (qres.nres == prev_nres) break;
879+
prev_nres = qres.nres;
906880
}
907-
} else {
908-
FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode);
909-
}
910-
if (parallel_mode == 0) {
911-
pres.finalize();
912-
} else {
913-
#pragma omp barrier
914-
#pragma omp single
915-
RangeSearchPartialResult::merge(all_pres, false);
916-
#pragma omp barrier
917881
}
882+
pres.finalize();
918883
}
919884

920885
if (interrupt) {

thirdparty/faiss/faiss/IndexIVF.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,6 @@ struct IndexIVF : Index, Level1Quantizer {
300300
float radius,
301301
RangeSearchResult* result,
302302
const size_t nprobe,
303-
const int parallel_mode,
304303
const size_t max_codes,
305304
const BitsetView bitset = nullptr) const;
306305

@@ -310,7 +309,6 @@ struct IndexIVF : Index, Level1Quantizer {
310309
float radius,
311310
RangeSearchResult* result,
312311
const size_t nprobe,
313-
const int parallel_mode,
314312
const size_t max_codes,
315313
const BitsetView bitset = nullptr) const;
316314

thirdparty/faiss/faiss/IndexIVFPQ.cpp

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -807,13 +807,11 @@ struct KnnSearchResults {
807807

808808
size_t nup;
809809

810-
inline void add(idx_t j, float dis, const BitsetView bitset = nullptr) {
811-
if (bitset.empty() || !bitset.test(ids[j])) {
812-
if (C::cmp(heap_sim[0], dis)) {
813-
idx_t id = ids ? ids[j] : lo_build(key, j);
814-
heap_replace_top<C>(k, heap_sim, heap_ids, dis, id);
815-
nup++;
816-
}
810+
inline void add(idx_t j, float dis) {
811+
if (C::cmp(heap_sim[0], dis)) {
812+
idx_t id = ids ? ids[j] : lo_build(key, j);
813+
heap_replace_top<C>(k, heap_sim, heap_ids, dis, id);
814+
nup++;
817815
}
818816
}
819817
};
@@ -827,12 +825,10 @@ struct RangeSearchResults {
827825
float radius;
828826
RangeQueryResult& rres;
829827

830-
inline void add(idx_t j, float dis, const BitsetView bitset = nullptr) {
831-
if (bitset.empty() || !bitset.test(ids[j])) {
832-
if (C::cmp(radius, dis)) {
833-
idx_t id = ids ? ids[j] : lo_build(key, j);
834-
rres.add(dis, id);
835-
}
828+
inline void add(idx_t j, float dis) {
829+
if (C::cmp(radius, dis)) {
830+
idx_t id = ids ? ids[j] : lo_build(key, j);
831+
rres.add(dis, id);
836832
}
837833
}
838834
};
@@ -878,17 +874,18 @@ struct IVFPQScannerT : QueryTables {
878874
SearchResultType& res,
879875
const BitsetView bitset = nullptr) const {
880876
for (size_t j = 0; j < ncode; j++) {
881-
PQDecoder decoder(codes, pq.nbits);
882-
codes += pq.code_size;
883-
float dis = dis0;
884-
const float* tab = sim_table;
885-
886-
for (size_t m = 0; m < pq.M; m++) {
887-
dis += tab[decoder.decode()];
888-
tab += pq.ksub;
889-
}
877+
if (bitset.empty() || !bitset.test(res.ids[j])) {
878+
PQDecoder decoder(codes, pq.nbits);
879+
codes += pq.code_size;
880+
float dis = dis0;
881+
const float* tab = sim_table;
890882

891-
res.add(j, dis, bitset);
883+
for (size_t m = 0; m < pq.M; m++) {
884+
dis += tab[decoder.decode()];
885+
tab += pq.ksub;
886+
}
887+
res.add(j, dis);
888+
}
892889
}
893890
}
894891

@@ -901,18 +898,20 @@ struct IVFPQScannerT : QueryTables {
901898
SearchResultType& res,
902899
const BitsetView bitset = nullptr) const {
903900
for (size_t j = 0; j < ncode; j++) {
904-
PQDecoder decoder(codes, pq.nbits);
905-
codes += pq.code_size;
901+
if (bitset.empty() || !bitset.test(res.ids[j])) {
902+
PQDecoder decoder(codes, pq.nbits);
903+
codes += pq.code_size;
906904

907-
float dis = dis0;
908-
const float* tab = sim_table_2;
905+
float dis = dis0;
906+
const float* tab = sim_table_2;
909907

910-
for (size_t m = 0; m < pq.M; m++) {
911-
int ci = decoder.decode();
912-
dis += sim_table_ptrs[m][ci] - 2 * tab[ci];
913-
tab += pq.ksub;
908+
for (size_t m = 0; m < pq.M; m++) {
909+
int ci = decoder.decode();
910+
dis += sim_table_ptrs[m][ci] - 2 * tab[ci];
911+
tab += pq.ksub;
912+
}
913+
res.add(j, dis);
914914
}
915-
res.add(j, dis, bitset);
916915
}
917916
}
918917

@@ -939,16 +938,18 @@ struct IVFPQScannerT : QueryTables {
939938
}
940939

941940
for (size_t j = 0; j < ncode; j++) {
942-
pq.decode(codes, decoded_vec);
943-
codes += pq.code_size;
941+
if (bitset.empty() || !bitset.test(res.ids[j])) {
942+
pq.decode(codes, decoded_vec);
943+
codes += pq.code_size;
944944

945-
float dis;
946-
if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
947-
dis = dis0 + fvec_inner_product(decoded_vec, qi, d);
948-
} else {
949-
dis = fvec_L2sqr(decoded_vec, dvec, d);
945+
float dis;
946+
if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
947+
dis = dis0 + fvec_inner_product(decoded_vec, qi, d);
948+
} else {
949+
dis = fvec_L2sqr(decoded_vec, dvec, d);
950+
}
951+
res.add(j, dis);
950952
}
951-
res.add(j, dis, bitset);
952953
}
953954
}
954955

@@ -970,21 +971,22 @@ struct IVFPQScannerT : QueryTables {
970971
HammingComputer hc(q_code.data(), code_size);
971972

972973
for (size_t j = 0; j < ncode; j++) {
973-
const uint8_t* b_code = codes;
974-
int hd = hc.compute(b_code);
975-
if (hd < ht) {
976-
n_hamming_pass++;
977-
PQDecoder decoder(codes, pq.nbits);
978-
979-
float dis = dis0;
980-
const float* tab = sim_table;
981-
982-
for (size_t m = 0; m < pq.M; m++) {
983-
dis += tab[decoder.decode()];
984-
tab += pq.ksub;
974+
if (bitset.empty() || !bitset.test(res.ids[j])) {
975+
const uint8_t* b_code = codes;
976+
int hd = hc.compute(b_code);
977+
if (hd < ht) {
978+
n_hamming_pass++;
979+
PQDecoder decoder(codes, pq.nbits);
980+
981+
float dis = dis0;
982+
const float* tab = sim_table;
983+
984+
for (size_t m = 0; m < pq.M; m++) {
985+
dis += tab[decoder.decode()];
986+
tab += pq.ksub;
987+
}
988+
res.add(j, dis);
985989
}
986-
987-
res.add(j, dis, bitset);
988990
}
989991
codes += code_size;
990992
}

thirdparty/faiss/faiss/IndexIVFSpectralHash.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,12 @@ struct IVFScanner : InvertedListScanner {
286286
RangeQueryResult& res,
287287
const BitsetView bitset) const override {
288288
for (size_t j = 0; j < list_size; j++) {
289-
float dis = hc.compute(codes);
290-
if (dis < radius) {
291-
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
292-
res.add(dis, id);
289+
if (bitset.empty() || !bitset.test(ids[j])) {
290+
float dis = hc.compute(codes);
291+
if (dis < radius) {
292+
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
293+
res.add(dis, id);
294+
}
293295
}
294296
codes += code_size;
295297
}

0 commit comments

Comments
 (0)