@@ -41,6 +41,13 @@ class FlatIndexNode : public IndexNode {
41
41
return err;
42
42
}
43
43
44
+ std::vector<float >
45
+ convertFloat16ToFloat32 (const std::vector<float16>& input) {
46
+ std::vector<float > output (input.size ());
47
+ std::transform (input.begin (), input.end (), output.begin (), [](float16 f) { return static_cast <float >(f); });
48
+ return output;
49
+ }
50
+
44
51
Status
45
52
Train (const DataSet& dataset, const Config& cfg) override {
46
53
const FlatConfig& f_cfg = static_cast <const FlatConfig&>(cfg);
@@ -55,6 +62,14 @@ class FlatIndexNode : public IndexNode {
55
62
LOG_KNOWHERE_WARNING_ << " please check metric type: " << f_cfg.metric_type ;
56
63
return metric.error ();
57
64
}
65
+
66
+ auto dim_data = dataset.GetDim ();
67
+
68
+ // If dim_data is float16, convert it to float32
69
+ if (typeid (dim_data[0 ]) == typeid (float16)) {
70
+ dim_data = convertFloat16ToFloat32 (dim_data);
71
+ }
72
+
58
73
index_ = std::make_unique<T>(dataset.GetDim (), metric.value ());
59
74
return Status::success;
60
75
}
@@ -63,6 +78,12 @@ class FlatIndexNode : public IndexNode {
63
78
Add (const DataSet& dataset, const Config& cfg) override {
64
79
auto x = dataset.GetTensor ();
65
80
auto n = dataset.GetRows ();
81
+
82
+ if (typeid (x[0 ]) == typeid (float16)) {
83
+ std::vector<float > x_float32 (x.begin (), x.end ());
84
+ x = x_float32;
85
+ }
86
+
66
87
if constexpr (std::is_same<T, faiss::IndexFlat>::value) {
67
88
index_->add (n, (const float *)x);
68
89
}
@@ -92,6 +113,11 @@ class FlatIndexNode : public IndexNode {
92
113
auto x = dataset.GetTensor ();
93
114
auto dim = dataset.GetDim ();
94
115
116
+ // If x is float16, convert it to float32
117
+ if (typeid (x[0 ]) == typeid (float16)) {
118
+ x = convertFloat16ToFloat32 (x);
119
+ }
120
+
95
121
auto len = k * nq;
96
122
int64_t * ids = nullptr ;
97
123
float * distances = nullptr ;
@@ -150,6 +176,11 @@ class FlatIndexNode : public IndexNode {
150
176
auto xq = dataset.GetTensor ();
151
177
auto dim = dataset.GetDim ();
152
178
179
+ // If xq is float16, convert it to float32
180
+ if (typeid (xq[0 ]) == typeid (float16)) {
181
+ xq = convertFloat16ToFloat32 (xq);
182
+ }
183
+
153
184
int64_t * ids = nullptr ;
154
185
float * distances = nullptr ;
155
186
size_t * lims = nullptr ;
@@ -212,7 +243,14 @@ class FlatIndexNode : public IndexNode {
212
243
for (int64_t i = 0 ; i < rows; i++) {
213
244
index_->reconstruct (ids[i], data + i * dim);
214
245
}
215
- return GenResultDataSet (rows, dim, data);
246
+ // If original data was float16, convert it back before returning
247
+ if (typeid (dataset.GetTensor ()[0 ]) == typeid (float16)) {
248
+ auto data16 = convertFloat32ToFloat16 (data, rows * dim);
249
+ delete[] data;
250
+ return GenResultDataSet (rows, dim, data16);
251
+ } else {
252
+ return GenResultDataSet (rows, dim, data);
253
+ }
216
254
} catch (const std::exception & e) {
217
255
std::unique_ptr<float []> auto_del (data);
218
256
LOG_KNOWHERE_WARNING_ << " faiss inner error: " << e.what ();
0 commit comments