From 27e56671ee0f98e1f0520606bc27abc08e020fae Mon Sep 17 00:00:00 2001 From: Cheesaa <137966174+Cheesaa@users.noreply.github.com> Date: Wed, 19 Jul 2023 21:06:54 +0800 Subject: [PATCH] Remove metrics TANIMOTO (#994) Signed-off-by: ZhouLingyu <2585511320@qq.com> --- benchmark/hdf5/benchmark_binary.cpp | 4 +--- benchmark/hdf5/benchmark_binary_range.cpp | 4 +--- include/knowhere/comp/index_param.h | 1 - src/common/comp/brute_force.cc | 24 ++----------------- src/common/metric.h | 1 - tests/ut/test_bruteforce.cc | 6 ++--- tests/ut/test_diskann.cc | 2 +- thirdparty/faiss/faiss/IndexBinaryFlat.cpp | 21 +--------------- thirdparty/faiss/faiss/IndexBinaryIVF.cpp | 22 ++--------------- .../faiss/faiss/IndexBinaryIVFThreadSafe.cpp | 19 +-------------- thirdparty/faiss/faiss/MetricType.h | 1 - .../faiss/faiss/utils/binary_distances.cpp | 7 ------ .../faiss/faiss/utils/extra_distances-inl.h | 13 ---------- .../faiss/faiss/utils/extra_distances.cpp | 3 --- 14 files changed, 11 insertions(+), 117 deletions(-) diff --git a/benchmark/hdf5/benchmark_binary.cpp b/benchmark/hdf5/benchmark_binary.cpp index aeb5d3a3c..6f784e641 100644 --- a/benchmark/hdf5/benchmark_binary.cpp +++ b/benchmark/hdf5/benchmark_binary.cpp @@ -105,9 +105,7 @@ class Benchmark_binary : public Benchmark_knowhere, public ::testing::Test { load_hdf5_data(); assert(metric_str_ == METRIC_HAM_STR || metric_str_ == METRIC_JAC_STR || metric_str_ == METRIC_TAN_STR); - metric_type_ = (metric_str_ == METRIC_HAM_STR) ? knowhere::metric::HAMMING - : (metric_str_ == METRIC_JAC_STR) ? knowhere::metric::JACCARD - : knowhere::metric::TANIMOTO; + metric_type_ = (metric_str_ == METRIC_HAM_STR) ? knowhere::metric::HAMMING : knowhere::metric::JACCARD; cfg_[knowhere::meta::METRIC_TYPE] = metric_type_; knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2); printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold()); diff --git a/benchmark/hdf5/benchmark_binary_range.cpp b/benchmark/hdf5/benchmark_binary_range.cpp index 3d2e2bf5d..332fbbed9 100644 --- a/benchmark/hdf5/benchmark_binary_range.cpp +++ b/benchmark/hdf5/benchmark_binary_range.cpp @@ -113,9 +113,7 @@ class Benchmark_binary_range : public Benchmark_knowhere, public ::testing::Test #endif assert(metric_str_ == METRIC_HAM_STR || metric_str_ == METRIC_JAC_STR || metric_str_ == METRIC_TAN_STR); - metric_type_ = (metric_str_ == METRIC_HAM_STR) ? knowhere::metric::HAMMING - : (metric_str_ == METRIC_JAC_STR) ? knowhere::metric::JACCARD - : knowhere::metric::TANIMOTO; + metric_type_ = (metric_str_ == METRIC_HAM_STR) ? knowhere::metric::HAMMING : knowhere::metric::JACCARD; cfg_[knowhere::meta::METRIC_TYPE] = metric_type_; cfg_[knowhere::meta::RADIUS] = *gt_radius_; knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2); diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 14099f625..df910d493 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -87,7 +87,6 @@ constexpr const char* L2 = "L2"; constexpr const char* COSINE = "COSINE"; constexpr const char* HAMMING = "HAMMING"; constexpr const char* JACCARD = "JACCARD"; -constexpr const char* TANIMOTO = "TANIMOTO"; constexpr const char* SUBSTRUCTURE = "SUBSTRUCTURE"; constexpr const char* SUPERSTRUCTURE = "SUPERSTRUCTURE"; } // namespace metric diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index e5d840227..a1f283a89 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -78,17 +78,10 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); break; } - case faiss::METRIC_Jaccard: - case faiss::METRIC_Tanimoto: { + case faiss::METRIC_Jaccard: { auto cur_query = (const uint8_t*)xq + (dim / 8) * index; faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances}; binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, bitset); - - if (faiss_metric_type == faiss::METRIC_Tanimoto) { - for (int i = 0; i < topk; i++) { - cur_distances[i] = faiss::Jaccard_2_Tanimoto(cur_distances[i]); - } - } break; } case faiss::METRIC_Hamming: { @@ -181,17 +174,10 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); break; } - case faiss::METRIC_Jaccard: - case faiss::METRIC_Tanimoto: { + case faiss::METRIC_Jaccard: { auto cur_query = (const uint8_t*)xq + (dim / 8) * index; faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances}; binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, bitset); - - if (faiss_metric_type == faiss::METRIC_Tanimoto) { - for (int i = 0; i < topk; i++) { - cur_distances[i] = faiss::Jaccard_2_Tanimoto(cur_distances[i]); - } - } break; } case faiss::METRIC_Hamming: { @@ -288,12 +274,6 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da faiss::METRIC_Jaccard, cur_query, (const uint8_t*)xb, 1, nb, radius, dim / 8, &res, bitset); break; } - case faiss::METRIC_Tanimoto: { - auto cur_query = (const uint8_t*)xq + (dim / 8) * index; - faiss::binary_range_search, float>( - faiss::METRIC_Tanimoto, cur_query, (const uint8_t*)xb, 1, nb, radius, dim / 8, &res, bitset); - break; - } case faiss::METRIC_Hamming: { auto cur_query = (const uint8_t*)xq + (dim / 8) * index; faiss::binary_range_search, int>(faiss::METRIC_Hamming, cur_query, diff --git a/src/common/metric.h b/src/common/metric.h index f017b0a8f..2ac90090d 100644 --- a/src/common/metric.h +++ b/src/common/metric.h @@ -30,7 +30,6 @@ Str2FaissMetricType(std::string metric) { {metric::COSINE, faiss::MetricType::METRIC_INNER_PRODUCT}, {metric::HAMMING, faiss::MetricType::METRIC_Hamming}, {metric::JACCARD, faiss::MetricType::METRIC_Jaccard}, - {metric::TANIMOTO, faiss::MetricType::METRIC_Tanimoto}, {metric::SUBSTRUCTURE, faiss::MetricType::METRIC_Substructure}, {metric::SUPERSTRUCTURE, faiss::MetricType::METRIC_Superstructure}, }; diff --git a/tests/ut/test_bruteforce.cc b/tests/ut/test_bruteforce.cc index cb8d0814e..fd5c16f78 100644 --- a/tests/ut/test_bruteforce.cc +++ b/tests/ut/test_bruteforce.cc @@ -95,9 +95,8 @@ TEST_CASE("Test Brute Force", "[binary vector]") { const int64_t dim = 1024; const int64_t k = 5; - auto metric = - GENERATE(as{}, knowhere::metric::HAMMING, knowhere::metric::JACCARD, knowhere::metric::TANIMOTO, - knowhere::metric::SUPERSTRUCTURE, knowhere::metric::SUBSTRUCTURE); + auto metric = GENERATE(as{}, knowhere::metric::HAMMING, knowhere::metric::JACCARD, + knowhere::metric::SUPERSTRUCTURE, knowhere::metric::SUBSTRUCTURE); const auto train_ds = GenBinDataSet(nb, dim); const auto query_ds = CopyBinDataSet(train_ds, nq); @@ -105,7 +104,6 @@ TEST_CASE("Test Brute Force", "[binary vector]") { std::unordered_map radius_map = { {knowhere::metric::HAMMING, 1.0}, {knowhere::metric::JACCARD, 0.1}, - {knowhere::metric::TANIMOTO, 0.1}, }; const knowhere::Json conf = { {knowhere::meta::DIM, dim}, diff --git a/tests/ut/test_diskann.cc b/tests/ut/test_diskann.cc index 7a3eca12d..b9ef67698 100644 --- a/tests/ut/test_diskann.cc +++ b/tests/ut/test_diskann.cc @@ -98,7 +98,7 @@ TEST_CASE("Invalid diskann params test", "[diskann]") { knowhere::Status test_stat; // invalid metric type test_json = test_gen(); - test_json["metric_type"] = knowhere::metric::TANIMOTO; + test_json["metric_type"] = knowhere::metric::JACCARD; test_stat = diskann.Build(*ds_ptr, test_json); REQUIRE(test_stat == knowhere::Status::invalid_metric_type); // raw data path not exist diff --git a/thirdparty/faiss/faiss/IndexBinaryFlat.cpp b/thirdparty/faiss/faiss/IndexBinaryFlat.cpp index 7b73a2e4f..82896a1ef 100644 --- a/thirdparty/faiss/faiss/IndexBinaryFlat.cpp +++ b/thirdparty/faiss/faiss/IndexBinaryFlat.cpp @@ -45,16 +45,10 @@ void IndexBinaryFlat::search( const BitsetView bitset) const { FAISS_THROW_IF_NOT(k > 0); - if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) { + if (metric_type == METRIC_Jaccard) { float* D = reinterpret_cast(distances); float_maxheap_array_t res = {size_t(n), size_t(k), labels, D}; binary_knn_hc(METRIC_Jaccard, &res, x, xb.data(), ntotal, code_size, bitset); - - if (metric_type == METRIC_Tanimoto) { - for (int i = 0; i < k * n; i++) { - D[i] = Jaccard_2_Tanimoto(D[i]); - } - } } else if (metric_type == METRIC_Hamming) { int_maxheap_array_t res = {size_t(n), size_t(k), labels, distances}; binary_knn_hc(METRIC_Hamming, &res, x, xb.data(), ntotal, code_size, bitset); @@ -126,19 +120,6 @@ void IndexBinaryFlat::range_search( bitset); break; } - case METRIC_Tanimoto: { - binary_range_search, float>( - METRIC_Tanimoto, - x, - xb.data(), - n, - ntotal, - radius, - code_size, - result, - bitset); - break; - } case METRIC_Hamming: { binary_range_search, int>( METRIC_Hamming, diff --git a/thirdparty/faiss/faiss/IndexBinaryIVF.cpp b/thirdparty/faiss/faiss/IndexBinaryIVF.cpp index ea4cbcae5..18c995a7a 100644 --- a/thirdparty/faiss/faiss/IndexBinaryIVF.cpp +++ b/thirdparty/faiss/faiss/IndexBinaryIVF.cpp @@ -290,7 +290,7 @@ void IndexBinaryIVF::train(idx_t n, const uint8_t* x) { quantizer->reset(); IndexFlat index_tmp; - if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) { + if (metric_type == METRIC_Jaccard) { index_tmp = IndexFlat(d, METRIC_Jaccard); } else if ( metric_type == METRIC_Substructure || @@ -858,7 +858,6 @@ BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner( return select_IVFBinaryScannerL2(code_size); } case METRIC_Jaccard: - case METRIC_Tanimoto: if (store_pairs) { return select_IVFBinaryScannerJaccard(code_size); } else { @@ -884,7 +883,7 @@ void IndexBinaryIVF::search_preassigned( bool store_pairs, const IVFSearchParameters* params, const BitsetView bitset) const { - if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) { + if (metric_type == METRIC_Jaccard) { if (use_heap) { float* D = new float[k * n]; float* c_dis = new float[n * nprobe]; @@ -901,11 +900,6 @@ void IndexBinaryIVF::search_preassigned( store_pairs, params, bitset); - if (metric_type == METRIC_Tanimoto) { - for (int i = 0; i < k * n; i++) { - D[i] = Jaccard_2_Tanimoto(D[i]); - } - } memcpy(distances, D, sizeof(float) * n * k); delete[] D; delete[] c_dis; @@ -958,20 +952,8 @@ void IndexBinaryIVF::range_search( t0 = getmillisecs(); invlists->prefetch_lists(idx.get(), n * nprobe); - - if (metric_type == METRIC_Tanimoto) { - radius = Tanimoto_2_Jaccard(radius); - } - range_search_preassigned( n, x, radius, idx.get(), coarse_dis.get(), res, bitset); - - if (metric_type == METRIC_Tanimoto) { - for (auto i = 0; i < res->lims[n]; i++) { - res->distances[i] = Jaccard_2_Tanimoto(res->distances[i]); - } - } - indexIVF_stats.search_time += getmillisecs() - t0; } diff --git a/thirdparty/faiss/faiss/IndexBinaryIVFThreadSafe.cpp b/thirdparty/faiss/faiss/IndexBinaryIVFThreadSafe.cpp index d21997d44..07c6778ba 100644 --- a/thirdparty/faiss/faiss/IndexBinaryIVFThreadSafe.cpp +++ b/thirdparty/faiss/faiss/IndexBinaryIVFThreadSafe.cpp @@ -652,7 +652,7 @@ void IndexBinaryIVF::search_preassigned_thread_safe( const IVFSearchParameters* params, const size_t nprobe, const BitsetView bitset) const { - if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) { + if (metric_type == METRIC_Jaccard) { if (use_heap) { float* D = new float[k * n]; float* c_dis = new float[n * nprobe]; @@ -670,11 +670,6 @@ void IndexBinaryIVF::search_preassigned_thread_safe( params, nprobe, bitset); - if (metric_type == METRIC_Tanimoto) { - for (int i = 0; i < k * n; i++) { - D[i] = Jaccard_2_Tanimoto(D[i]); - } - } memcpy(distances, D, sizeof(float) * n * k); delete[] D; delete[] c_dis; @@ -747,20 +742,8 @@ void IndexBinaryIVF::range_search_thread_safe( t0 = getmillisecs(); invlists->prefetch_lists(idx.get(), n * nprobe); - - if (metric_type == METRIC_Tanimoto) { - radius = Tanimoto_2_Jaccard(radius); - } - range_search_preassigned_thread_safe( n, x, radius, idx.get(), coarse_dis.get(), res, nprobe, bitset); - - if (metric_type == METRIC_Tanimoto) { - for (auto i = 0; i < res->lims[n]; i++) { - res->distances[i] = Jaccard_2_Tanimoto(res->distances[i]); - } - } - indexIVF_stats.search_time += getmillisecs() - t0; } diff --git a/thirdparty/faiss/faiss/MetricType.h b/thirdparty/faiss/faiss/MetricType.h index 8f130c1f8..068335c8e 100644 --- a/thirdparty/faiss/faiss/MetricType.h +++ b/thirdparty/faiss/faiss/MetricType.h @@ -26,7 +26,6 @@ enum MetricType { /// metric_arg METRIC_Jaccard, - METRIC_Tanimoto, METRIC_Hamming, METRIC_Substructure, ///< Tversky case alpha = 0, beta = 1 METRIC_Superstructure, ///< Tversky case alpha = 1, beta = 0 diff --git a/thirdparty/faiss/faiss/utils/binary_distances.cpp b/thirdparty/faiss/faiss/utils/binary_distances.cpp index b4409d464..c2af28948 100644 --- a/thirdparty/faiss/faiss/utils/binary_distances.cpp +++ b/thirdparty/faiss/faiss/utils/binary_distances.cpp @@ -587,8 +587,6 @@ void binary_range_search( RangeSearchResult* res, const BitsetView bitset) { switch (metric_type) { - case METRIC_Tanimoto: - radius = Tanimoto_2_Jaccard(radius); case METRIC_Jaccard: { { switch (code_size) { @@ -614,11 +612,6 @@ void binary_range_search( break; } } - if (METRIC_Tanimoto == metric_type) { - for (auto i = 0; i < res->lims[na]; i++) { - res->distances[i] = Jaccard_2_Tanimoto(res->distances[i]); - } - } break; } diff --git a/thirdparty/faiss/faiss/utils/extra_distances-inl.h b/thirdparty/faiss/faiss/utils/extra_distances-inl.h index 90662fd19..f371f3e22 100644 --- a/thirdparty/faiss/faiss/utils/extra_distances-inl.h +++ b/thirdparty/faiss/faiss/utils/extra_distances-inl.h @@ -135,17 +135,4 @@ inline float VectorDistance::operator()( return 1 - accu_num / accu_den; } -template <> -inline float VectorDistance::operator()( - const float* x, - const float* y) const { - float accu_num = 0, accu_den = 0; - for (size_t i = 0; i < d; i++) { - float xi = x[i], yi = y[i]; - accu_num += xi * yi; - accu_den += xi * xi + yi * yi - xi * yi; - } - return -log2(accu_num / accu_den) ; -} - } // namespace faiss diff --git a/thirdparty/faiss/faiss/utils/extra_distances.cpp b/thirdparty/faiss/faiss/utils/extra_distances.cpp index 559452460..edd2a8a07 100644 --- a/thirdparty/faiss/faiss/utils/extra_distances.cpp +++ b/thirdparty/faiss/faiss/utils/extra_distances.cpp @@ -157,7 +157,6 @@ void pairwise_extra_distances( HANDLE_VAR(JensenShannon); HANDLE_VAR(Lp); HANDLE_VAR(Jaccard); - HANDLE_VAR(Tanimoto); #undef HANDLE_VAR default: FAISS_THROW_MSG("metric type not implemented"); @@ -189,7 +188,6 @@ void knn_extra_metrics( HANDLE_VAR(JensenShannon); HANDLE_VAR(Lp); HANDLE_VAR(Jaccard); - HANDLE_VAR(Tanimoto); #undef HANDLE_VAR default: FAISS_THROW_MSG("metric type not implemented"); @@ -217,7 +215,6 @@ DistanceComputer* get_extra_distance_computer( HANDLE_VAR(JensenShannon); HANDLE_VAR(Lp); HANDLE_VAR(Jaccard); - HANDLE_VAR(Tanimoto); #undef HANDLE_VAR default: FAISS_THROW_MSG("metric type not implemented");