From 70634fc9a73f5c856818cb5d80137f0e8631624b Mon Sep 17 00:00:00 2001 From: zh Wang Date: Mon, 7 Aug 2023 21:57:40 +0800 Subject: [PATCH] Fix cosine bruteforce Signed-off-by: zh Wang --- src/common/comp/brute_force.cc | 33 ++--- thirdparty/faiss/faiss/utils/distances.cpp | 154 +++++++++++++++++++++ thirdparty/faiss/faiss/utils/distances.h | 19 +++ 3 files changed, 188 insertions(+), 18 deletions(-) diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index a1f283a89..7810af423 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -33,11 +33,6 @@ expected BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset) { std::string metric_str = config[meta::METRIC_TYPE].get(); - bool is_cosine = IsMetricType(metric_str, metric::COSINE); - if (is_cosine) { - Normalize(*base_dataset); - } - auto xb = base_dataset->GetTensor(); auto nb = base_dataset->GetRows(); auto dim = base_dataset->GetDim(); @@ -71,11 +66,13 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset } case faiss::METRIC_INNER_PRODUCT: { auto cur_query = (float*)xq + dim * index; - if (is_cosine) { + faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; + if (IsMetricType(metric_str, metric::COSINE)) { NormalizeVec(cur_query, dim); + faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); + } else { + faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); } - faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; - faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); break; } case faiss::METRIC_Jaccard: { @@ -123,11 +120,6 @@ Status BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis, const Json& config, const BitsetView& bitset) { std::string metric_str = 
config[meta::METRIC_TYPE].get(); - bool is_cosine = IsMetricType(metric_str, metric::COSINE); - if (is_cosine) { - Normalize(*base_dataset); - } - auto xb = base_dataset->GetTensor(); auto nb = base_dataset->GetRows(); auto dim = base_dataset->GetDim(); @@ -167,11 +159,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ } case faiss::METRIC_INNER_PRODUCT: { auto cur_query = (float*)xq + dim * index; - if (is_cosine) { + faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; + if (IsMetricType(metric_str, metric::COSINE)) { NormalizeVec(cur_query, dim); + faiss::knn_cosine(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); + } else { + faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); } - faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; - faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); break; } case faiss::METRIC_Jaccard: { @@ -262,10 +256,13 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da case faiss::METRIC_INNER_PRODUCT: { is_ip = true; auto cur_query = (float*)xq + dim * index; - if (is_cosine) { + if (IsMetricType(metric_str, metric::COSINE)) { NormalizeVec(cur_query, dim); + faiss::range_search_cosine(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset); + } else { + faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, + bitset); } - faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset); break; } case faiss::METRIC_Jaccard: { diff --git a/thirdparty/faiss/faiss/utils/distances.cpp b/thirdparty/faiss/faiss/utils/distances.cpp index b8c5e2dbb..793419227 100644 --- a/thirdparty/faiss/faiss/utils/distances.cpp +++ b/thirdparty/faiss/faiss/utils/distances.cpp @@ -14,6 +14,7 @@ #include #include #include +#include "simd/hook.h" #include @@ -284,6 +285,44 @@ void 
exhaustive_L2sqr_seq(
     }
 }
 
+namespace { // divides by ||y|| only: callers are expected to pass unit-norm x
+float fvec_cosine(const float* x, const float* y, size_t d) {
+    return fvec_inner_product(x, y, d) / sqrtf(fvec_norm_L2sqr(y, d));
+}
+} // namespace
+
+// Scan all y per query; yields a true cosine only when x rows are unit-norm.
+template <class ResultHandler>
+void exhaustive_cosine_seq(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        ResultHandler& res,
+        const BitsetView bitset) {
+    using SingleResultHandler = typename ResultHandler::SingleResultHandler;
+    int nt = std::min(int(nx), omp_get_max_threads());
+#pragma omp parallel num_threads(nt)
+    {
+        SingleResultHandler resi(res);
+#pragma omp for
+        for (int64_t i = 0; i < nx; i++) {
+            const float* x_i = x + i * d;
+            const float* y_j = y;
+            resi.begin(i);
+            for (size_t j = 0; j < ny; j++) {
+                if (bitset.empty() || !bitset.test(j)) { // set bit => skip j
+                    float disij = fvec_cosine(x_i, y_j, d);
+                    resi.add_result(disij, j);
+                }
+                y_j += d;
+            }
+            resi.end();
+        }
+    }
+}
+
 /** Find the nearest neighbors for nx queries in a set of ny vectors */
 template <class ResultHandler>
 void exhaustive_inner_product_blas(
@@ -426,6 +465,76 @@ void exhaustive_L2sqr_blas(
     }
 }
 
+// BLAS path: sgemm gives <x_i, y_j> blockwise, then divides by ||y_j||.
+template <class ResultHandler>
+void exhaustive_cosine_blas(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        ResultHandler& res,
+        const BitsetView bitset = nullptr) {
+    // BLAS does not like empty matrices
+    if (nx == 0 || ny == 0)
+        return;
+
+    /* block sizes */
+    const size_t bs_x = distance_compute_blas_query_bs;
+    const size_t bs_y = distance_compute_blas_database_bs;
+    std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
+    // one norm per database vector: y_norms[j] is read with j in [0, ny)
+    std::unique_ptr<float[]> y_norms(new float[ny]);
+
+    fvec_norms_L2(y_norms.get(), y, d, ny);
+
+    for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
+        size_t i1 = i0 + bs_x;
+        if (i1 > nx)
+            i1 = nx;
+
+        res.begin_multiple(i0, i1);
+
+        for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
+            size_t j1 = j0 + bs_y;
+            if (j1 > ny)
+                j1 = ny;
+            /* compute the actual dot products */
+            {
+                float one = 1, zero = 0;
+                FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
+                sgemm_("Transpose",
+                       "Not transpose",
+                       &nyi,
+                       &nxi,
+                       &di,
+                       &one,
+                       y + j0 * d,
+                       &di,
+                       x + i0 * d,
+                       &di,
+                       &zero,
+                       ip_block.get(),
+                       &nyi);
+            }
+#pragma omp parallel for
+            for (int64_t i = i0; i < i1; i++) {
+                float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
+
+                for (size_t j = j0; j < j1; j++) {
+                    float ip = *ip_line;
+                    float dis = ip / y_norms[j]; // NaN/inf if y_j is all zeros
+                    *ip_line = dis;
+                    ip_line++;
+                }
+            }
+            res.add_results(j0, j1, ip_block.get(), bitset);
+        }
+        res.end_multiple();
+        InterruptCallback::check();
+    }
+}
+
 template <class DistanceCorrection>
 static void knn_jaccard_blas(
         const float* x,
@@ -577,6 +686,34 @@ void knn_L2sqr(
     }
 }
 
+// Top-k by cosine: heap for small k, reservoir otherwise; BLAS for many queries.
+void knn_cosine(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        float_minheap_array_t* ha,
+        const BitsetView bitset) {
+    if (ha->k < distance_compute_min_k_reservoir) {
+        HeapResultHandler<CMin<float, int64_t>> res(
+                ha->nh, ha->val, ha->ids, ha->k);
+        if (nx < distance_compute_blas_threshold) {
+            exhaustive_L2sqr_IP_seq(x, y, d, nx, ny, res, fvec_cosine, bitset);
+        } else {
+            exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset);
+        }
+    } else {
+        ReservoirResultHandler<CMin<float, int64_t>> res(
+                ha->nh, ha->val, ha->ids, ha->k);
+        if (nx < distance_compute_blas_threshold) {
+            exhaustive_L2sqr_IP_seq(x, y, d, nx, ny, res, fvec_cosine, bitset);
+        } else {
+            exhaustive_cosine_blas(x, y, d, nx, ny, res, bitset);
+        }
+    }
+}
+
 struct NopDistanceCorrection {
     float operator()(float dis, size_t /*qno*/, size_t /*bno*/) const {
         return dis;
@@ -640,6 +777,23 @@ void range_search_inner_product(
     }
 }
 
+void range_search_cosine(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        float radius,
+        RangeSearchResult* res,
+        const BitsetView bitset) {
+    RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
+    if (nx < distance_compute_blas_threshold) { // few queries: plain scan
+        exhaustive_cosine_seq(x, y, d, nx, ny, resh, bitset);
+    } else {
+        exhaustive_cosine_blas(x, y, d, nx, ny, resh, bitset);
+    }
+}
+
 /***************************************************************************
  * 
compute a subset of distances
  ***************************************************************************/
diff --git a/thirdparty/faiss/faiss/utils/distances.h b/thirdparty/faiss/faiss/utils/distances.h
index ebc51f7f2..2d015a3ef 100644
--- a/thirdparty/faiss/faiss/utils/distances.h
+++ b/thirdparty/faiss/faiss/utils/distances.h
@@ -199,6 +199,15 @@ void knn_L2sqr(
         const float* y_norm2 = nullptr,
         const BitsetView bitset = nullptr);
 
+void knn_cosine(
+        const float* x, // queries (nx * d), expected L2-normalized by caller
+        const float* y, // database vectors (ny * d), any norm
+        size_t d,
+        size_t nx,
+        size_t ny,
+        float_minheap_array_t* ha, // result heap: ha->k entries per query
+        const BitsetView bitset = nullptr);
+
 void knn_jaccard(
         const float* x,
         const float* y,
@@ -265,6 +274,16 @@ void range_search_inner_product(
         RangeSearchResult* result,
         const BitsetView bitset = nullptr);
 
+void range_search_cosine(
+        const float* x, // queries (nx * d), expected L2-normalized by caller
+        const float* y, // database vectors (ny * d), any norm
+        size_t d,
+        size_t nx,
+        size_t ny,
+        float radius,
+        RangeSearchResult* result,
+        const BitsetView bitset = nullptr);
+
 /***************************************************************************
  * PQ tables computations
  ***************************************************************************/