Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Adopt new strategy for faiss IVF range search
Browse files: browse the repository at this point in the history
Signed-off-by: Yudong Cai <[email protected]>
  • Loading branch information
cydrain authored and liliu-z committed Aug 8, 2023
1 parent 7c870a3 commit cbe86cf
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 172 deletions.
16 changes: 3 additions & 13 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -446,13 +446,6 @@ IvfIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi
Normalize(dataset);
}

auto nprobe = ivf_cfg.nprobe.value();

int parallel_mode = 0;
if (nprobe > 1 && nq <= 4) {
parallel_mode = 1;
}

float radius = ivf_cfg.radius.value();
float range_filter = ivf_cfg.range_filter.value();
bool is_ip = (index_->metric_type == faiss::METRIC_INNER_PRODUCT);
Expand All @@ -467,7 +460,6 @@ IvfIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi
std::vector<size_t> result_lims(nq + 1);

try {
size_t max_codes = 0;
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
Expand All @@ -476,15 +468,13 @@ IvfIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi
faiss::RangeSearchResult res(1);
if constexpr (std::is_same<T, faiss::IndexBinaryIVF>::value) {
auto cur_data = (const uint8_t*)xq + index * dim / 8;
index_->range_search_thread_safe(1, cur_data, radius, &res, nprobe, bitset);
index_->range_search_thread_safe(1, cur_data, radius, &res, index_->nlist, bitset);
} else if constexpr (std::is_same<T, faiss::IndexIVFFlat>::value) {
auto cur_data = (const float*)xq + index * dim;
index_->range_search_without_codes_thread_safe(1, cur_data, radius, &res, nprobe, parallel_mode,
max_codes, bitset);
index_->range_search_without_codes_thread_safe(1, cur_data, radius, &res, index_->nlist, 0, bitset);
} else {
auto cur_data = (const float*)xq + index * dim;
index_->range_search_thread_safe(1, cur_data, radius, &res, nprobe, parallel_mode, max_codes,
bitset);
index_->range_search_thread_safe(1, cur_data, radius, &res, index_->nlist, 0, bitset);
}
auto elem_cnt = res.lims[1];
result_dist_array[index].resize(elem_cnt);
Expand Down
3 changes: 1 addition & 2 deletions src/index/ivf/ivf_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ class IvfConfig : public BaseConfig {
.set_default(8)
.description("number of probes at query time.")
.for_search()
.set_range(1, 65536)
.for_range_search();
.set_range(1, 65536);
}
};

Expand Down
1 change: 0 additions & 1 deletion thirdparty/faiss/faiss/IndexBinaryIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,6 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
float radius,
RangeQueryResult& result,
const BitsetView bitset) const override {
size_t nup = 0;
for (size_t j = 0; j < n; j++) {
if (bitset.empty() || !bitset.test(ids[j])) {
float dis = hc.compute(codes);
Expand Down
4 changes: 3 additions & 1 deletion thirdparty/faiss/faiss/IndexBinaryIVFThreadSafe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
float radius,
RangeQueryResult& result,
const BitsetView bitset) const override {
size_t nup = 0;
for (size_t j = 0; j < n; j++) {
if (bitset.empty() || !bitset.test(ids[j])) {
float dis = hc.compute(codes);
Expand Down Expand Up @@ -801,9 +800,12 @@ void IndexBinaryIVF::range_search_preassigned_thread_safe(
scanner->set_query(x + i * code_size);

RangeQueryResult& qres = pres.new_result(i);
size_t prev_nres = qres.nres;

for (size_t ik = 0; ik < nprobe; ik++) {
scan_list_func(i, ik, qres);
if (qres.nres == prev_nres) break;
prev_nres = qres.nres;
}
}

Expand Down
53 changes: 9 additions & 44 deletions thirdparty/faiss/faiss/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -866,55 +866,20 @@ void IndexIVF::range_search_preassigned(
}
};

if (parallel_mode == 0) {
#pragma omp for
for (idx_t i = 0; i < nx; i++) {
scanner->set_query(x + i * d);

RangeQueryResult& qres = pres.new_result(i);

for (size_t ik = 0; ik < nprobe; ik++) {
scan_list_func(i, ik, qres);
}
}

} else if (parallel_mode == 1) {
for (size_t i = 0; i < nx; i++) {
scanner->set_query(x + i * d);
for (idx_t i = 0; i < nx; i++) {
scanner->set_query(x + i * d);

RangeQueryResult& qres = pres.new_result(i);

#pragma omp for schedule(dynamic)
for (int64_t ik = 0; ik < nprobe; ik++) {
scan_list_func(i, ik, qres);
}
}
} else if (parallel_mode == 2) {
std::vector<RangeQueryResult*> all_qres(nx);
RangeQueryResult* qres = nullptr;
RangeQueryResult& qres = pres.new_result(i);
size_t prev_nres = qres.nres;

#pragma omp for schedule(dynamic)
for (idx_t iik = 0; iik < nx * (idx_t)nprobe; iik++) {
idx_t i = iik / (idx_t)nprobe;
idx_t ik = iik % (idx_t)nprobe;
if (qres == nullptr || qres->qno != i) {
FAISS_ASSERT(!qres || i > qres->qno);
qres = &pres.new_result(i);
scanner->set_query(x + i * d);
}
scan_list_func(i, ik, *qres);
for (size_t ik = 0; ik < nprobe; ik++) {
scan_list_func(i, ik, qres);
if (qres.nres == prev_nres) break;
prev_nres = qres.nres;
}
} else {
FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode);
}
if (parallel_mode == 0) {
pres.finalize();
} else {
#pragma omp barrier
#pragma omp single
RangeSearchPartialResult::merge(all_pres, false);
#pragma omp barrier
}
pres.finalize();
}

if (interrupt) {
Expand Down
2 changes: 0 additions & 2 deletions thirdparty/faiss/faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,6 @@ struct IndexIVF : Index, Level1Quantizer {
float radius,
RangeSearchResult* result,
const size_t nprobe,
const int parallel_mode,
const size_t max_codes,
const BitsetView bitset = nullptr) const;

Expand All @@ -310,7 +309,6 @@ struct IndexIVF : Index, Level1Quantizer {
float radius,
RangeSearchResult* result,
const size_t nprobe,
const int parallel_mode,
const size_t max_codes,
const BitsetView bitset = nullptr) const;

Expand Down
110 changes: 56 additions & 54 deletions thirdparty/faiss/faiss/IndexIVFPQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -807,13 +807,11 @@ struct KnnSearchResults {

size_t nup;

inline void add(idx_t j, float dis, const BitsetView bitset = nullptr) {
if (bitset.empty() || !bitset.test(ids[j])) {
if (C::cmp(heap_sim[0], dis)) {
idx_t id = ids ? ids[j] : lo_build(key, j);
heap_replace_top<C>(k, heap_sim, heap_ids, dis, id);
nup++;
}
inline void add(idx_t j, float dis) {
if (C::cmp(heap_sim[0], dis)) {
idx_t id = ids ? ids[j] : lo_build(key, j);
heap_replace_top<C>(k, heap_sim, heap_ids, dis, id);
nup++;
}
}
};
Expand All @@ -827,12 +825,10 @@ struct RangeSearchResults {
float radius;
RangeQueryResult& rres;

inline void add(idx_t j, float dis, const BitsetView bitset = nullptr) {
if (bitset.empty() || !bitset.test(ids[j])) {
if (C::cmp(radius, dis)) {
idx_t id = ids ? ids[j] : lo_build(key, j);
rres.add(dis, id);
}
inline void add(idx_t j, float dis) {
if (C::cmp(radius, dis)) {
idx_t id = ids ? ids[j] : lo_build(key, j);
rres.add(dis, id);
}
}
};
Expand Down Expand Up @@ -878,17 +874,18 @@ struct IVFPQScannerT : QueryTables {
SearchResultType& res,
const BitsetView bitset = nullptr) const {
for (size_t j = 0; j < ncode; j++) {
PQDecoder decoder(codes, pq.nbits);
codes += pq.code_size;
float dis = dis0;
const float* tab = sim_table;

for (size_t m = 0; m < pq.M; m++) {
dis += tab[decoder.decode()];
tab += pq.ksub;
}
if (bitset.empty() || !bitset.test(res.ids[j])) {
PQDecoder decoder(codes, pq.nbits);
codes += pq.code_size;
float dis = dis0;
const float* tab = sim_table;

res.add(j, dis, bitset);
for (size_t m = 0; m < pq.M; m++) {
dis += tab[decoder.decode()];
tab += pq.ksub;
}
res.add(j, dis);
}
}
}

Expand All @@ -901,18 +898,20 @@ struct IVFPQScannerT : QueryTables {
SearchResultType& res,
const BitsetView bitset = nullptr) const {
for (size_t j = 0; j < ncode; j++) {
PQDecoder decoder(codes, pq.nbits);
codes += pq.code_size;
if (bitset.empty() || !bitset.test(res.ids[j])) {
PQDecoder decoder(codes, pq.nbits);
codes += pq.code_size;

float dis = dis0;
const float* tab = sim_table_2;
float dis = dis0;
const float* tab = sim_table_2;

for (size_t m = 0; m < pq.M; m++) {
int ci = decoder.decode();
dis += sim_table_ptrs[m][ci] - 2 * tab[ci];
tab += pq.ksub;
for (size_t m = 0; m < pq.M; m++) {
int ci = decoder.decode();
dis += sim_table_ptrs[m][ci] - 2 * tab[ci];
tab += pq.ksub;
}
res.add(j, dis);
}
res.add(j, dis, bitset);
}
}

Expand All @@ -939,16 +938,18 @@ struct IVFPQScannerT : QueryTables {
}

for (size_t j = 0; j < ncode; j++) {
pq.decode(codes, decoded_vec);
codes += pq.code_size;
if (bitset.empty() || !bitset.test(res.ids[j])) {
pq.decode(codes, decoded_vec);
codes += pq.code_size;

float dis;
if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
dis = dis0 + fvec_inner_product(decoded_vec, qi, d);
} else {
dis = fvec_L2sqr(decoded_vec, dvec, d);
float dis;
if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
dis = dis0 + fvec_inner_product(decoded_vec, qi, d);
} else {
dis = fvec_L2sqr(decoded_vec, dvec, d);
}
res.add(j, dis);
}
res.add(j, dis, bitset);
}
}

Expand All @@ -970,21 +971,22 @@ struct IVFPQScannerT : QueryTables {
HammingComputer hc(q_code.data(), code_size);

for (size_t j = 0; j < ncode; j++) {
const uint8_t* b_code = codes;
int hd = hc.compute(b_code);
if (hd < ht) {
n_hamming_pass++;
PQDecoder decoder(codes, pq.nbits);

float dis = dis0;
const float* tab = sim_table;

for (size_t m = 0; m < pq.M; m++) {
dis += tab[decoder.decode()];
tab += pq.ksub;
if (bitset.empty() || !bitset.test(res.ids[j])) {
const uint8_t* b_code = codes;
int hd = hc.compute(b_code);
if (hd < ht) {
n_hamming_pass++;
PQDecoder decoder(codes, pq.nbits);

float dis = dis0;
const float* tab = sim_table;

for (size_t m = 0; m < pq.M; m++) {
dis += tab[decoder.decode()];
tab += pq.ksub;
}
res.add(j, dis);
}

res.add(j, dis, bitset);
}
codes += code_size;
}
Expand Down
10 changes: 6 additions & 4 deletions thirdparty/faiss/faiss/IndexIVFSpectralHash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,12 @@ struct IVFScanner : InvertedListScanner {
RangeQueryResult& res,
const BitsetView bitset) const override {
for (size_t j = 0; j < list_size; j++) {
float dis = hc.compute(codes);
if (dis < radius) {
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
res.add(dis, id);
if (bitset.empty() || !bitset.test(ids[j])) {
float dis = hc.compute(codes);
if (dis < radius) {
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
res.add(dis, id);
}
}
codes += code_size;
}
Expand Down
Loading

0 comments on commit cbe86cf

Please sign in to comment.