Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
fix diskann ip (#904)
Browse files Browse the repository at this point in the history
Signed-off-by: xianliang <[email protected]>
  • Loading branch information
foxspy authored May 24, 2023
1 parent df91db9 commit 98e2a16
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 24 deletions.
15 changes: 9 additions & 6 deletions thirdparty/DiskANN/include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <stack>
#include <string>
#include <unordered_map>
#include <functional>
#include "tsl/robin_set.h"
#include "tsl/robin_map.h"

Expand Down Expand Up @@ -93,13 +94,13 @@ namespace diskann {
public:
// Constructor for Bulk operations and for creating the index object solely
// for loading a prexisting index.
DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points,
DISKANN_DLLEXPORT Index(Metric m, bool ip_prepared, const size_t dim, const size_t max_points,
const bool dynamic_index,
const bool enable_tags = false,
const bool support_eager_delete = false);

// Constructor for incremental index
DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points,
DISKANN_DLLEXPORT Index(Metric m, bool ip_prepared, const size_t dim, const size_t max_points,
const bool dynamic_index,
const Parameters &indexParameters,
const Parameters &searchParameters,
Expand Down Expand Up @@ -163,8 +164,8 @@ namespace diskann {
// insertions possible only when id corresponding to tag does not already
// exist in the graph
DISKANN_DLLEXPORT int insert_point(
const T *point,
const TagT tag);
const T *point,
const TagT tag);

// call before triggering deleteions - sets important flags required for
// deletion related operations
Expand Down Expand Up @@ -234,7 +235,7 @@ namespace diskann {
// change.
DISKANN_DLLEXPORT static const int METADATA_ROWS = 5;

// For Bulk Index FastL2 search, we interleave the data with graph
// For Bulk Index FastL2 search, we interleave the data with graph
DISKANN_DLLEXPORT void optimize_index_layout();

// For FastL2 search on optimized layout
Expand Down Expand Up @@ -352,13 +353,15 @@ namespace diskann {
private:
Metric _dist_metric = diskann::L2;
size_t _dim = 0;
size_t _padding_id = 0;
size_t _aligned_dim = 0;
T * _data = nullptr;
size_t _nd = 0; // number of active points i.e. existing in the graph
size_t _max_points = 0; // total number of points in given data set
size_t _num_frozen_pts = 0;
bool _has_built = false;
DISTFUN<T> _distance = nullptr;
DISTFUN<T> _func = nullptr;
std::function<T(const T*, const T*, size_t)> _distance;
unsigned _width = 0;
unsigned _ep = 0;
size_t _max_range_of_loaded_graph = 0;
Expand Down
15 changes: 9 additions & 6 deletions thirdparty/DiskANN/src/aux_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ namespace diskann {

template<typename T>
int build_merged_vamana_index(std::string base_file,
bool ip_prepared,
diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build,
double sampling_rate, double ram_budget,
Expand Down Expand Up @@ -528,7 +529,7 @@ namespace diskann {

std::unique_ptr<diskann::Index<T>> _pvamanaIndex =
std::unique_ptr<diskann::Index<T>>(new diskann::Index<T>(
compareMetric, base_dim, base_num, false, false));
compareMetric, ip_prepared, base_dim, base_num, false, false));
_pvamanaIndex->build(base_file.c_str(), base_num, paras);

_pvamanaIndex->save(mem_index_path.c_str(), true);
Expand Down Expand Up @@ -571,7 +572,7 @@ namespace diskann {
get_bin_metadata(shard_base_file, shard_base_pts, shard_base_dim);
std::unique_ptr<diskann::Index<T>> _pvamanaIndex =
std::unique_ptr<diskann::Index<T>>(
new diskann::Index<T>(compareMetric, shard_base_dim,
new diskann::Index<T>(compareMetric, ip_prepared, shard_base_dim,
shard_base_pts, false)); // TODO: Single?
_pvamanaIndex->build(shard_base_file.c_str(), shard_base_pts, paras);
_pvamanaIndex->save(shard_index_file.c_str());
Expand Down Expand Up @@ -908,6 +909,7 @@ namespace diskann {
bool use_disk_pq = disk_pq_dims != 0;

bool reorder_data = config.reorder;
bool ip_prepared = false;

std::string base_file = config.data_file_path;
std::string data_file_to_use = base_file;
Expand Down Expand Up @@ -944,6 +946,7 @@ namespace diskann {
std::string norm_file =
get_disk_index_max_base_norm_file(disk_index_path);
diskann::save_bin<float>(norm_file, &max_norm_of_base, 1, 1);
ip_prepared = true;
}

unsigned R = config.max_degree;
Expand Down Expand Up @@ -1044,7 +1047,7 @@ namespace diskann {
#endif
auto graph_s = std::chrono::high_resolution_clock::now();
diskann::build_merged_vamana_index<T>(
data_file_to_use.c_str(), diskann::Metric::L2, L, R,
data_file_to_use.c_str(), ip_prepared, diskann::Metric::L2, L, R,
config.accelerate_build, p_val, indexing_ram_budget, mem_index_path,
medoids_path, centroids_path);
auto graph_e = std::chrono::high_resolution_clock::now();
Expand Down Expand Up @@ -1141,17 +1144,17 @@ namespace diskann {
const BuildConfig &config);

template DISKANN_DLLEXPORT int build_merged_vamana_index<int8_t>(
std::string base_file, diskann::Metric compareMetric, unsigned L,
std::string base_file, bool ip_prepared, diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build, double sampling_rate,
double ram_budget, std::string mem_index_path, std::string medoids_path,
std::string centroids_file);
template DISKANN_DLLEXPORT int build_merged_vamana_index<float>(
std::string base_file, diskann::Metric compareMetric, unsigned L,
std::string base_file, bool ip_prepared, diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build, double sampling_rate,
double ram_budget, std::string mem_index_path, std::string medoids_path,
std::string centroids_file);
template DISKANN_DLLEXPORT int build_merged_vamana_index<uint8_t>(
std::string base_file, diskann::Metric compareMetric, unsigned L,
std::string base_file, bool ip_prepared, diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build, double sampling_rate,
double ram_budget, std::string mem_index_path, std::string medoids_path,
std::string centroids_file);
Expand Down
34 changes: 22 additions & 12 deletions thirdparty/DiskANN/src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ namespace diskann {
// Initialize an index with metric m, load the data of type T with filename
// (bin), and initialize max_points
template<typename T, typename TagT>
Index<T, TagT>::Index(Metric m, const size_t dim, const size_t max_points,
const bool dynamic_index,
const Parameters &indexParams,
Index<T, TagT>::Index(Metric m, bool ip_prepared, const size_t dim, const size_t max_points,
const bool dynamic_index, const Parameters &indexParams,
const Parameters &searchParams, const bool enable_tags,
const bool support_eager_delete)
: Index(m, dim, max_points, dynamic_index, enable_tags, support_eager_delete) { // Thank you C++ 11!
: Index(m, ip_prepared, dim, max_points, dynamic_index, enable_tags,
support_eager_delete) { // Thank you C++ 11!
_indexingQueueSize = indexParams.Get<uint32_t>("L");
_indexingRange = indexParams.Get<uint32_t>("R");
_indexingMaxC = indexParams.Get<uint32_t>("C");
Expand All @@ -227,9 +227,9 @@ namespace diskann {
}

template<typename T, typename TagT>
Index<T, TagT>::Index(Metric m, const size_t dim, const size_t max_points,
const bool dynamic_index,
const bool enable_tags, const bool support_eager_delete)
Index<T, TagT>::Index(Metric m, bool ip_prepared, const size_t dim, const size_t max_points,
const bool dynamic_index, const bool enable_tags,
const bool support_eager_delete)
: _dist_metric(m), _dim(dim), _max_points(max_points),
_dynamic_index(dynamic_index), _enable_tags(enable_tags),
_support_eager_delete(support_eager_delete) {
Expand Down Expand Up @@ -279,7 +279,17 @@ namespace diskann {
_in_graph.reserve(_max_points + _num_frozen_pts);
_in_graph.resize(_max_points + _num_frozen_pts);
}
this->_distance = get_distance_function<T>(m);

this->_func = get_distance_function<T>(m);
if (ip_prepared) {
_padding_id = _dim - 1;
this->_distance = [this](const T* x, const T* y, size_t n) -> T {
auto ret = _func(x, y, n);
return ret + 2*x[_padding_id]*y[_padding_id];
};
} else {
this->_distance = _func;
}

_locks = std::vector<std::mutex>(_max_points + _num_frozen_pts);

Expand Down Expand Up @@ -387,7 +397,7 @@ namespace diskann {
std::ofstream out;
open_file_to_write(out, graph_file);

_u64 file_offset = 0; // we will use this if we want
_u64 file_offset = 0; // we will use this if we want
out.seekp(file_offset, out.beg);
_u64 index_size = 24;
_u32 max_degree = 0;
Expand Down Expand Up @@ -1181,14 +1191,14 @@ namespace diskann {

template<typename T, typename TagT>
void Index<T, TagT>::prune_neighbors(const unsigned location,
std::vector<Neighbor> &pool,
std::vector<Neighbor> &pool,
std::vector<unsigned> &pruned_list) {
prune_neighbors(location, pool, _indexingRange, _indexingMaxC, _indexingAlpha, pruned_list);
}

template<typename T, typename TagT>
void Index<T, TagT>::prune_neighbors(const unsigned location,
std::vector<Neighbor> &pool, const _u32 range, const _u32 max_candidate_size, const float alpha,
std::vector<Neighbor> &pool, const _u32 range, const _u32 max_candidate_size, const float alpha,
std::vector<unsigned> &pruned_list) {
if (pool.size() == 0) {
std::stringstream ss;
Expand Down Expand Up @@ -2715,7 +2725,7 @@ namespace diskann {
}

template<typename T, typename TagT>
int Index<T, TagT>::insert_point(const T *point,
int Index<T, TagT>::insert_point(const T *point,
const TagT tag) {
std::shared_lock<std::shared_timed_mutex> lock(_update_lock);
unsigned range = _indexingRange; // parameters.Get<unsigned>("R");
Expand Down

0 comments on commit 98e2a16

Please sign in to comment.