diff --git a/python/knowhere/__init__.py b/python/knowhere/__init__.py index 59a9a1903..a19c73339 100644 --- a/python/knowhere/__init__.py +++ b/python/knowhere/__init__.py @@ -58,25 +58,31 @@ def CreateIndexDiskANN(index_name, index_prefix, metric_type, simd_type="auto"): 'diskann_f' 'diskann_i8' 'diskann_ui8'.""" ) -def CreateAsyncIndex(index_name, simd_type="auto", *args): +def CreateAsyncIndex(index_name, index_prefix="", metric_type="", simd_type="auto"): if simd_type not in ["auto", "avx512", "avx2", "avx", "sse4_2"]: raise ValueError("simd type only support auto avx512 avx2 avx sse4_2") + SetSimdType(simd_type) + if index_name not in ["bin_flat", "bin_ivf_flat", "flat", "ivf_flat", "ivf_pq", "ivf_sq8", - "hnsw", "annoy","gpu_flat", "gpu_ivf_flat", + "hnsw", "annoy", "gpu_flat", "gpu_ivf_flat", "gpu_ivf_pq", "gpu_ivf_sq8", "diskann_f", "diskann_i8", "diskann_ui8"]: raise ValueError( - """ index name only support + """ index name only support 'bin_flat', 'bin_ivf_flat', 'flat', 'ivf_flat', 'ivf_pq', 'ivf_sq8', 'hnsw', 'annoy', 'gpu_flat', 'gpu_ivf_flat', 'gpu_ivf_pq', 'gpu_ivf_sq8', 'diskann_f', 'diskann_i8', 'diskann_ui8'.""" - ) - if index_name in ["diskann_f", "diskann_i8", "diskann_ui8"] : - index_prefix = args[0] - metric_type = args[1] - return AsyncIndex(index_name, index_prefix, metric_type) - return AsyncIndex(index_name) + ) + + if index_name in ["diskann_f", "diskann_i8", "diskann_ui8"]: + if index_prefix == "": + raise ValueError("Must pass index_prefix to DiskANN") + if metric_type == "": + raise ValueError("Must pass metric_type to DiskANN") + return AsyncIndex(index_name, index_prefix, metric_type) + else: + return AsyncIndex(index_name) class GpuContext: diff --git a/unittest/AsyncIndex.h b/unittest/AsyncIndex.h index a300e765d..dbff9dfb1 100644 --- a/unittest/AsyncIndex.h +++ b/unittest/AsyncIndex.h @@ -22,6 +22,7 @@ #include "knowhere/index/vector_index/IndexIVFPQ.h" #include "knowhere/index/vector_index/IndexIVFSQ.h" #ifdef KNOWHERE_WITH_DISKANN +#include "LocalFileManager.h" #include "knowhere/index/vector_index/IndexDiskANN.h" #include "knowhere/index/vector_index/IndexDiskANNConfig.h" #endif @@ -54,13 +55,17 @@ class AsyncIndex : public VecIndex { } #ifdef KNOWHERE_WITH_DISKANN - AsyncIndex(std::string type, MetricType metric_type, std::shared_ptr file_manager) { + AsyncIndex(std::string type, std::string index_prefix, std::string metric_type) { + std::transform(metric_type.begin(), metric_type.end(), metric_type.begin(), toupper); if (type == "diskann_f") { - index_ = std::make_unique>(type, metric_type, file_manager); + index_ = std::make_unique>(index_prefix, metric_type, + std::make_shared()); } else if (type == "disann_ui8") { - index_ = std::make_unique>(type, metric_type, file_manager); + index_ = std::make_unique>(index_prefix, metric_type, + std::make_shared()); } else if (type == "diskann_i8") { - index_ = std::make_unique>(type, metric_type, file_manager); + index_ = std::make_unique>(index_prefix, metric_type, + std::make_shared()); } else { KNOWHERE_THROW_FORMAT("Invalid index type %s", std::string(type).c_str()); }