Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Remove metrics TANIMOTO (#994)
Browse files Browse the repository at this point in the history
Signed-off-by: ZhouLingyu <[email protected]>
  • Loading branch information
Cheesaa authored Jul 19, 2023
1 parent f5c0b7b commit 27e5667
Show file tree
Hide file tree
Showing 14 changed files with 11 additions and 117 deletions.
4 changes: 1 addition & 3 deletions benchmark/hdf5/benchmark_binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ class Benchmark_binary : public Benchmark_knowhere, public ::testing::Test {
load_hdf5_data<true>();

assert(metric_str_ == METRIC_HAM_STR || metric_str_ == METRIC_JAC_STR || metric_str_ == METRIC_TAN_STR);
metric_type_ = (metric_str_ == METRIC_HAM_STR) ? knowhere::metric::HAMMING
: (metric_str_ == METRIC_JAC_STR) ? knowhere::metric::JACCARD
: knowhere::metric::TANIMOTO;
metric_type_ = (metric_str_ == METRIC_HAM_STR) ? knowhere::metric::HAMMING : knowhere::metric::JACCARD;
cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2);
printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold());
Expand Down
4 changes: 1 addition & 3 deletions benchmark/hdf5/benchmark_binary_range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@ class Benchmark_binary_range : public Benchmark_knowhere, public ::testing::Test
#endif

assert(metric_str_ == METRIC_HAM_STR || metric_str_ == METRIC_JAC_STR || metric_str_ == METRIC_TAN_STR);
metric_type_ = (metric_str_ == METRIC_HAM_STR) ? knowhere::metric::HAMMING
: (metric_str_ == METRIC_JAC_STR) ? knowhere::metric::JACCARD
: knowhere::metric::TANIMOTO;
metric_type_ = (metric_str_ == METRIC_HAM_STR) ? knowhere::metric::HAMMING : knowhere::metric::JACCARD;
cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
cfg_[knowhere::meta::RADIUS] = *gt_radius_;
knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2);
Expand Down
1 change: 0 additions & 1 deletion include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ constexpr const char* L2 = "L2";
constexpr const char* COSINE = "COSINE";
constexpr const char* HAMMING = "HAMMING";
constexpr const char* JACCARD = "JACCARD";
constexpr const char* TANIMOTO = "TANIMOTO";
constexpr const char* SUBSTRUCTURE = "SUBSTRUCTURE";
constexpr const char* SUPERSTRUCTURE = "SUPERSTRUCTURE";
} // namespace metric
Expand Down
24 changes: 2 additions & 22 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,10 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard:
case faiss::METRIC_Tanimoto: {
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, bitset);

if (faiss_metric_type == faiss::METRIC_Tanimoto) {
for (int i = 0; i < topk; i++) {
cur_distances[i] = faiss::Jaccard_2_Tanimoto(cur_distances[i]);
}
}
break;
}
case faiss::METRIC_Hamming: {
Expand Down Expand Up @@ -181,17 +174,10 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard:
case faiss::METRIC_Tanimoto: {
case faiss::METRIC_Jaccard: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances};
binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, bitset);

if (faiss_metric_type == faiss::METRIC_Tanimoto) {
for (int i = 0; i < topk; i++) {
cur_distances[i] = faiss::Jaccard_2_Tanimoto(cur_distances[i]);
}
}
break;
}
case faiss::METRIC_Hamming: {
Expand Down Expand Up @@ -288,12 +274,6 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
faiss::METRIC_Jaccard, cur_query, (const uint8_t*)xb, 1, nb, radius, dim / 8, &res, bitset);
break;
}
case faiss::METRIC_Tanimoto: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(
faiss::METRIC_Tanimoto, cur_query, (const uint8_t*)xb, 1, nb, radius, dim / 8, &res, bitset);
break;
}
case faiss::METRIC_Hamming: {
auto cur_query = (const uint8_t*)xq + (dim / 8) * index;
faiss::binary_range_search<faiss::CMin<int, int64_t>, int>(faiss::METRIC_Hamming, cur_query,
Expand Down
1 change: 0 additions & 1 deletion src/common/metric.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ Str2FaissMetricType(std::string metric) {
{metric::COSINE, faiss::MetricType::METRIC_INNER_PRODUCT},
{metric::HAMMING, faiss::MetricType::METRIC_Hamming},
{metric::JACCARD, faiss::MetricType::METRIC_Jaccard},
{metric::TANIMOTO, faiss::MetricType::METRIC_Tanimoto},
{metric::SUBSTRUCTURE, faiss::MetricType::METRIC_Substructure},
{metric::SUPERSTRUCTURE, faiss::MetricType::METRIC_Superstructure},
};
Expand Down
6 changes: 2 additions & 4 deletions tests/ut/test_bruteforce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,15 @@ TEST_CASE("Test Brute Force", "[binary vector]") {
const int64_t dim = 1024;
const int64_t k = 5;

auto metric =
GENERATE(as<std::string>{}, knowhere::metric::HAMMING, knowhere::metric::JACCARD, knowhere::metric::TANIMOTO,
knowhere::metric::SUPERSTRUCTURE, knowhere::metric::SUBSTRUCTURE);
auto metric = GENERATE(as<std::string>{}, knowhere::metric::HAMMING, knowhere::metric::JACCARD,
knowhere::metric::SUPERSTRUCTURE, knowhere::metric::SUBSTRUCTURE);

const auto train_ds = GenBinDataSet(nb, dim);
const auto query_ds = CopyBinDataSet(train_ds, nq);

std::unordered_map<std::string, float> radius_map = {
{knowhere::metric::HAMMING, 1.0},
{knowhere::metric::JACCARD, 0.1},
{knowhere::metric::TANIMOTO, 0.1},
};
const knowhere::Json conf = {
{knowhere::meta::DIM, dim},
Expand Down
2 changes: 1 addition & 1 deletion tests/ut/test_diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ TEST_CASE("Invalid diskann params test", "[diskann]") {
knowhere::Status test_stat;
// invalid metric type
test_json = test_gen();
test_json["metric_type"] = knowhere::metric::TANIMOTO;
test_json["metric_type"] = knowhere::metric::JACCARD;
test_stat = diskann.Build(*ds_ptr, test_json);
REQUIRE(test_stat == knowhere::Status::invalid_metric_type);
// raw data path not exist
Expand Down
21 changes: 1 addition & 20 deletions thirdparty/faiss/faiss/IndexBinaryFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,10 @@ void IndexBinaryFlat::search(
const BitsetView bitset) const {
FAISS_THROW_IF_NOT(k > 0);

if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
if (metric_type == METRIC_Jaccard) {
float* D = reinterpret_cast<float*>(distances);
float_maxheap_array_t res = {size_t(n), size_t(k), labels, D};
binary_knn_hc(METRIC_Jaccard, &res, x, xb.data(), ntotal, code_size, bitset);

if (metric_type == METRIC_Tanimoto) {
for (int i = 0; i < k * n; i++) {
D[i] = Jaccard_2_Tanimoto(D[i]);
}
}
} else if (metric_type == METRIC_Hamming) {
int_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
binary_knn_hc(METRIC_Hamming, &res, x, xb.data(), ntotal, code_size, bitset);
Expand Down Expand Up @@ -126,19 +120,6 @@ void IndexBinaryFlat::range_search(
bitset);
break;
}
case METRIC_Tanimoto: {
binary_range_search<CMin<float, int64_t>, float>(
METRIC_Tanimoto,
x,
xb.data(),
n,
ntotal,
radius,
code_size,
result,
bitset);
break;
}
case METRIC_Hamming: {
binary_range_search<CMin<int, int64_t>, int>(
METRIC_Hamming,
Expand Down
22 changes: 2 additions & 20 deletions thirdparty/faiss/faiss/IndexBinaryIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ void IndexBinaryIVF::train(idx_t n, const uint8_t* x) {
quantizer->reset();

IndexFlat index_tmp;
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
if (metric_type == METRIC_Jaccard) {
index_tmp = IndexFlat(d, METRIC_Jaccard);
} else if (
metric_type == METRIC_Substructure ||
Expand Down Expand Up @@ -858,7 +858,6 @@ BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner(
return select_IVFBinaryScannerL2<false>(code_size);
}
case METRIC_Jaccard:
case METRIC_Tanimoto:
if (store_pairs) {
return select_IVFBinaryScannerJaccard<true>(code_size);
} else {
Expand All @@ -884,7 +883,7 @@ void IndexBinaryIVF::search_preassigned(
bool store_pairs,
const IVFSearchParameters* params,
const BitsetView bitset) const {
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
if (metric_type == METRIC_Jaccard) {
if (use_heap) {
float* D = new float[k * n];
float* c_dis = new float[n * nprobe];
Expand All @@ -901,11 +900,6 @@ void IndexBinaryIVF::search_preassigned(
store_pairs,
params,
bitset);
if (metric_type == METRIC_Tanimoto) {
for (int i = 0; i < k * n; i++) {
D[i] = Jaccard_2_Tanimoto(D[i]);
}
}
memcpy(distances, D, sizeof(float) * n * k);
delete[] D;
delete[] c_dis;
Expand Down Expand Up @@ -958,20 +952,8 @@ void IndexBinaryIVF::range_search(

t0 = getmillisecs();
invlists->prefetch_lists(idx.get(), n * nprobe);

if (metric_type == METRIC_Tanimoto) {
radius = Tanimoto_2_Jaccard(radius);
}

range_search_preassigned(
n, x, radius, idx.get(), coarse_dis.get(), res, bitset);

if (metric_type == METRIC_Tanimoto) {
for (auto i = 0; i < res->lims[n]; i++) {
res->distances[i] = Jaccard_2_Tanimoto(res->distances[i]);
}
}

indexIVF_stats.search_time += getmillisecs() - t0;
}

Expand Down
19 changes: 1 addition & 18 deletions thirdparty/faiss/faiss/IndexBinaryIVFThreadSafe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ void IndexBinaryIVF::search_preassigned_thread_safe(
const IVFSearchParameters* params,
const size_t nprobe,
const BitsetView bitset) const {
if (metric_type == METRIC_Jaccard || metric_type == METRIC_Tanimoto) {
if (metric_type == METRIC_Jaccard) {
if (use_heap) {
float* D = new float[k * n];
float* c_dis = new float[n * nprobe];
Expand All @@ -670,11 +670,6 @@ void IndexBinaryIVF::search_preassigned_thread_safe(
params,
nprobe,
bitset);
if (metric_type == METRIC_Tanimoto) {
for (int i = 0; i < k * n; i++) {
D[i] = Jaccard_2_Tanimoto(D[i]);
}
}
memcpy(distances, D, sizeof(float) * n * k);
delete[] D;
delete[] c_dis;
Expand Down Expand Up @@ -747,20 +742,8 @@ void IndexBinaryIVF::range_search_thread_safe(

t0 = getmillisecs();
invlists->prefetch_lists(idx.get(), n * nprobe);

if (metric_type == METRIC_Tanimoto) {
radius = Tanimoto_2_Jaccard(radius);
}

range_search_preassigned_thread_safe(
n, x, radius, idx.get(), coarse_dis.get(), res, nprobe, bitset);

if (metric_type == METRIC_Tanimoto) {
for (auto i = 0; i < res->lims[n]; i++) {
res->distances[i] = Jaccard_2_Tanimoto(res->distances[i]);
}
}

indexIVF_stats.search_time += getmillisecs() - t0;
}

Expand Down
1 change: 0 additions & 1 deletion thirdparty/faiss/faiss/MetricType.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ enum MetricType {
/// metric_arg

METRIC_Jaccard,
METRIC_Tanimoto,
METRIC_Hamming,
METRIC_Substructure, ///< Tversky case alpha = 0, beta = 1
METRIC_Superstructure, ///< Tversky case alpha = 1, beta = 0
Expand Down
7 changes: 0 additions & 7 deletions thirdparty/faiss/faiss/utils/binary_distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,6 @@ void binary_range_search(
RangeSearchResult* res,
const BitsetView bitset) {
switch (metric_type) {
case METRIC_Tanimoto:
radius = Tanimoto_2_Jaccard(radius);
case METRIC_Jaccard: {
{
switch (code_size) {
Expand All @@ -614,11 +612,6 @@ void binary_range_search(
break;
}
}
if (METRIC_Tanimoto == metric_type) {
for (auto i = 0; i < res->lims[na]; i++) {
res->distances[i] = Jaccard_2_Tanimoto(res->distances[i]);
}
}
break;
}

Expand Down
13 changes: 0 additions & 13 deletions thirdparty/faiss/faiss/utils/extra_distances-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,4 @@ inline float VectorDistance<METRIC_Jaccard>::operator()(
return 1 - accu_num / accu_den;
}

/// Tanimoto-style distance between two float vectors.
///
/// Walks the first `d` entries of `x` and `y`, accumulating
///   - the dot product        sum(x[i] * y[i]), and
///   - the denominator        sum(x[i]^2 + y[i]^2 - x[i] * y[i]),
/// then returns the negated base-2 logarithm of (dot product / denominator).
///
/// @param x  pointer to the first vector (at least `d` readable floats)
/// @param y  pointer to the second vector (at least `d` readable floats)
/// @return   -log2(dot / denominator)
template <>
inline float VectorDistance<METRIC_Tanimoto>::operator()(
        const float* x,
        const float* y) const {
    float dot = 0;
    float denom = 0;
    size_t k = 0;
    while (k < d) {
        const float a = x[k];
        const float b = y[k];
        dot += a * b;
        denom += a * a + b * b - a * b;
        ++k;
    }
    // Keep the unqualified log2 call so overload selection is unchanged.
    return -log2(dot / denom);
}

} // namespace faiss
3 changes: 0 additions & 3 deletions thirdparty/faiss/faiss/utils/extra_distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ void pairwise_extra_distances(
HANDLE_VAR(JensenShannon);
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(Tanimoto);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
Expand Down Expand Up @@ -189,7 +188,6 @@ void knn_extra_metrics(
HANDLE_VAR(JensenShannon);
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(Tanimoto);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
Expand Down Expand Up @@ -217,7 +215,6 @@ DistanceComputer* get_extra_distance_computer(
HANDLE_VAR(JensenShannon);
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(Tanimoto);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
Expand Down

0 comments on commit 27e5667

Please sign in to comment.