@@ -58,25 +58,31 @@ def CreateIndexDiskANN(index_name, index_prefix, metric_type, simd_type="auto"):
58
58
'diskann_f' 'diskann_i8' 'diskann_ui8'."""
59
59
)
60
60
61
- def CreateAsyncIndex (index_name , simd_type = "auto" , * args ):
62
61
62
+ def CreateAsyncIndex (index_name , index_prefix = "" , metric_type = "" , simd_type = "auto" ):
63
63
if simd_type not in ["auto" , "avx512" , "avx2" , "avx" , "sse4_2" ]:
64
64
raise ValueError ("simd type only support auto avx512 avx2 avx sse4_2" )
65
65
66
+ SetSimdType (simd_type )
67
+
66
68
if index_name not in ["bin_flat" , "bin_ivf_flat" , "flat" , "ivf_flat" , "ivf_pq" , "ivf_sq8" ,
67
- "hnsw" , "annoy" ,"gpu_flat" , "gpu_ivf_flat" ,
69
+ "hnsw" , "annoy" , "gpu_flat" , "gpu_ivf_flat" ,
68
70
"gpu_ivf_pq" , "gpu_ivf_sq8" , "diskann_f" , "diskann_i8" , "diskann_ui8" ]:
69
71
raise ValueError (
70
- """ index name only support
72
+ """ index name only support
71
73
'bin_flat', 'bin_ivf_flat', 'flat', 'ivf_flat', 'ivf_pq', 'ivf_sq8',
72
74
'hnsw', 'annoy', 'gpu_flat', 'gpu_ivf_flat',
73
75
'gpu_ivf_pq', 'gpu_ivf_sq8', 'diskann_f', 'diskann_i8', 'diskann_ui8'."""
74
- )
75
- if index_name in ["diskann_f" , "diskann_i8" , "diskann_ui8" ] :
76
- index_prefix = args [0 ]
77
- metric_type = args [1 ]
78
- return AsyncIndex (index_name , index_prefix , metric_type )
79
- return AsyncIndex (index_name )
76
+ )
77
+
78
+ if index_name in ["diskann_f" , "diskann_i8" , "diskann_ui8" ]:
79
+ if index_prefix == "" :
80
+ raise ValueError ("Must pass index_prefix to DiskANN" )
81
+ if metric_type == "" :
82
+ raise ValueError ("Must pass metric_type to DiskANN" )
83
+ return AsyncIndex (index_name , index_prefix , metric_type )
84
+ else :
85
+ return AsyncIndex (index_name )
80
86
81
87
82
88
class GpuContext :
0 commit comments