Skip to content
This repository was archived by the owner on Aug 16, 2023. It is now read-only.

Commit 75177e4

Browse files
authored
Fix creation of diskann in async test (#482)
Signed-off-by: zh Wang <[email protected]> Signed-off-by: zh Wang <[email protected]>
1 parent 95e3ac0 commit 75177e4

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

python/knowhere/__init__.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,25 +58,31 @@ def CreateIndexDiskANN(index_name, index_prefix, metric_type, simd_type="auto"):
5858
'diskann_f' 'diskann_i8' 'diskann_ui8'."""
5959
)
6060

61-
def CreateAsyncIndex(index_name, simd_type="auto", *args):
6261

62+
def CreateAsyncIndex(index_name, index_prefix="", metric_type="", simd_type="auto"):
6363
if simd_type not in ["auto", "avx512", "avx2", "avx", "sse4_2"]:
6464
raise ValueError("simd type only support auto avx512 avx2 avx sse4_2")
6565

66+
SetSimdType(simd_type)
67+
6668
if index_name not in ["bin_flat", "bin_ivf_flat", "flat", "ivf_flat", "ivf_pq", "ivf_sq8",
67-
"hnsw", "annoy","gpu_flat", "gpu_ivf_flat",
69+
"hnsw", "annoy", "gpu_flat", "gpu_ivf_flat",
6870
"gpu_ivf_pq", "gpu_ivf_sq8", "diskann_f", "diskann_i8", "diskann_ui8"]:
6971
raise ValueError(
70-
""" index name only support
72+
""" index name only support
7173
'bin_flat', 'bin_ivf_flat', 'flat', 'ivf_flat', 'ivf_pq', 'ivf_sq8',
7274
'hnsw', 'annoy', 'gpu_flat', 'gpu_ivf_flat',
7375
'gpu_ivf_pq', 'gpu_ivf_sq8', 'diskann_f', 'diskann_i8', 'diskann_ui8'."""
74-
)
75-
if index_name in ["diskann_f", "diskann_i8", "diskann_ui8"] :
76-
index_prefix = args[0]
77-
metric_type = args[1]
78-
return AsyncIndex(index_name, index_prefix, metric_type)
79-
return AsyncIndex(index_name)
76+
)
77+
78+
if index_name in ["diskann_f", "diskann_i8", "diskann_ui8"]:
79+
if index_prefix == "":
80+
raise ValueError("Must pass index_prefix to DiskANN")
81+
if metric_type == "":
82+
raise ValueError("Must pass metric_type to DiskANN")
83+
return AsyncIndex(index_name, index_prefix, metric_type)
84+
else:
85+
return AsyncIndex(index_name)
8086

8187

8288
class GpuContext:

unittest/AsyncIndex.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "knowhere/index/vector_index/IndexIVFPQ.h"
2323
#include "knowhere/index/vector_index/IndexIVFSQ.h"
2424
#ifdef KNOWHERE_WITH_DISKANN
25+
#include "LocalFileManager.h"
2526
#include "knowhere/index/vector_index/IndexDiskANN.h"
2627
#include "knowhere/index/vector_index/IndexDiskANNConfig.h"
2728
#endif
@@ -54,13 +55,17 @@ class AsyncIndex : public VecIndex {
5455
}
5556

5657
#ifdef KNOWHERE_WITH_DISKANN
57-
AsyncIndex(std::string type, MetricType metric_type, std::shared_ptr<FileManager> file_manager) {
58+
AsyncIndex(std::string type, std::string index_prefix, std::string metric_type) {
59+
std::transform(metric_type.begin(), metric_type.end(), metric_type.begin(), toupper);
5860
if (type == "diskann_f") {
59-
index_ = std::make_unique<knowhere::IndexDiskANN<float>>(type, metric_type, file_manager);
61+
index_ = std::make_unique<knowhere::IndexDiskANN<float>>(index_prefix, metric_type,
62+
std::make_shared<LocalFileManager>());
6063
} else if (type == "disann_ui8") {
61-
index_ = std::make_unique<knowhere::IndexDiskANN<uint8_t>>(type, metric_type, file_manager);
64+
index_ = std::make_unique<knowhere::IndexDiskANN<uint8_t>>(index_prefix, metric_type,
65+
std::make_shared<LocalFileManager>());
6266
} else if (type == "diskann_i8") {
63-
index_ = std::make_unique<knowhere::IndexDiskANN<int8_t>>(type, metric_type, file_manager);
67+
index_ = std::make_unique<knowhere::IndexDiskANN<int8_t>>(index_prefix, metric_type,
68+
std::make_shared<LocalFileManager>());
6469
} else {
6570
KNOWHERE_THROW_FORMAT("Invalid index type %s", std::string(type).c_str());
6671
}

0 commit comments

Comments
 (0)