Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Fix creation of diskann in async test (#482)
Browse files Browse the repository at this point in the history
Signed-off-by: zh Wang <[email protected]>

Signed-off-by: zh Wang <[email protected]>
  • Loading branch information
hhy3 authored Sep 26, 2022
1 parent 95e3ac0 commit 75177e4
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
24 changes: 15 additions & 9 deletions python/knowhere/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,31 @@ def CreateIndexDiskANN(index_name, index_prefix, metric_type, simd_type="auto"):
'diskann_f' 'diskann_i8' 'diskann_ui8'."""
)

def CreateAsyncIndex(index_name, simd_type="auto", *args):

def CreateAsyncIndex(index_name, index_prefix="", metric_type="", simd_type="auto"):
if simd_type not in ["auto", "avx512", "avx2", "avx", "sse4_2"]:
raise ValueError("simd type only support auto avx512 avx2 avx sse4_2")

SetSimdType(simd_type)

if index_name not in ["bin_flat", "bin_ivf_flat", "flat", "ivf_flat", "ivf_pq", "ivf_sq8",
"hnsw", "annoy","gpu_flat", "gpu_ivf_flat",
"hnsw", "annoy", "gpu_flat", "gpu_ivf_flat",
"gpu_ivf_pq", "gpu_ivf_sq8", "diskann_f", "diskann_i8", "diskann_ui8"]:
raise ValueError(
""" index name only support
""" index name only support
'bin_flat', 'bin_ivf_flat', 'flat', 'ivf_flat', 'ivf_pq', 'ivf_sq8',
'hnsw', 'annoy', 'gpu_flat', 'gpu_ivf_flat',
'gpu_ivf_pq', 'gpu_ivf_sq8', 'diskann_f', 'diskann_i8', 'diskann_ui8'."""
)
if index_name in ["diskann_f", "diskann_i8", "diskann_ui8"] :
index_prefix = args[0]
metric_type = args[1]
return AsyncIndex(index_name, index_prefix, metric_type)
return AsyncIndex(index_name)
)

if index_name in ["diskann_f", "diskann_i8", "diskann_ui8"]:
if index_prefix == "":
raise ValueError("Must pass index_prefix to DiskANN")
if metric_type == "":
raise ValueError("Must pass metric_type to DiskANN")
return AsyncIndex(index_name, index_prefix, metric_type)
else:
return AsyncIndex(index_name)


class GpuContext:
Expand Down
13 changes: 9 additions & 4 deletions unittest/AsyncIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#ifdef KNOWHERE_WITH_DISKANN
#include "LocalFileManager.h"
#include "knowhere/index/vector_index/IndexDiskANN.h"
#include "knowhere/index/vector_index/IndexDiskANNConfig.h"
#endif
Expand Down Expand Up @@ -54,13 +55,17 @@ class AsyncIndex : public VecIndex {
}

#ifdef KNOWHERE_WITH_DISKANN
AsyncIndex(std::string type, MetricType metric_type, std::shared_ptr<FileManager> file_manager) {
AsyncIndex(std::string type, std::string index_prefix, std::string metric_type) {
std::transform(metric_type.begin(), metric_type.end(), metric_type.begin(), toupper);
if (type == "diskann_f") {
index_ = std::make_unique<knowhere::IndexDiskANN<float>>(type, metric_type, file_manager);
index_ = std::make_unique<knowhere::IndexDiskANN<float>>(index_prefix, metric_type,
std::make_shared<LocalFileManager>());
} else if (type == "disann_ui8") {
index_ = std::make_unique<knowhere::IndexDiskANN<uint8_t>>(type, metric_type, file_manager);
index_ = std::make_unique<knowhere::IndexDiskANN<uint8_t>>(index_prefix, metric_type,
std::make_shared<LocalFileManager>());
} else if (type == "diskann_i8") {
index_ = std::make_unique<knowhere::IndexDiskANN<int8_t>>(type, metric_type, file_manager);
index_ = std::make_unique<knowhere::IndexDiskANN<int8_t>>(index_prefix, metric_type,
std::make_shared<LocalFileManager>());
} else {
KNOWHERE_THROW_FORMAT("Invalid index type %s", std::string(type).c_str());
}
Expand Down

0 comments on commit 75177e4

Please sign in to comment.