Skip to content
This repository was archived by the owner on Aug 16, 2023. It is now read-only.

Commit 028915f

Browse files
committed
Try to support float16 for flat.cc
Signed-off-by: jjyaoao <[email protected]>
1 parent fbf2b6e commit 028915f

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

src/index/flat/flat.cc

+39-1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ class FlatIndexNode : public IndexNode {
4141
return err;
4242
}
4343

44+
std::vector<float>
45+
convertFloat16ToFloat32(const std::vector<float16>& input) {
46+
std::vector<float> output(input.size());
47+
std::transform(input.begin(), input.end(), output.begin(), [](float16 f) { return static_cast<float>(f); });
48+
return output;
49+
}
50+
4451
Status
4552
Train(const DataSet& dataset, const Config& cfg) override {
4653
const FlatConfig& f_cfg = static_cast<const FlatConfig&>(cfg);
@@ -55,6 +62,14 @@ class FlatIndexNode : public IndexNode {
5562
LOG_KNOWHERE_WARNING_ << "please check metric type: " << f_cfg.metric_type;
5663
return metric.error();
5764
}
65+
66+
auto dim_data = dataset.GetDim();
67+
68+
// If dim_data is float16, convert it to float32
69+
if (typeid(dim_data[0]) == typeid(float16)) {
70+
dim_data = convertFloat16ToFloat32(dim_data);
71+
}
72+
5873
index_ = std::make_unique<T>(dataset.GetDim(), metric.value());
5974
return Status::success;
6075
}
@@ -63,6 +78,12 @@ class FlatIndexNode : public IndexNode {
6378
Add(const DataSet& dataset, const Config& cfg) override {
6479
auto x = dataset.GetTensor();
6580
auto n = dataset.GetRows();
81+
82+
if (typeid(x[0]) == typeid(float16)) {
83+
std::vector<float> x_float32(x.begin(), x.end());
84+
x = x_float32;
85+
}
86+
6687
if constexpr (std::is_same<T, faiss::IndexFlat>::value) {
6788
index_->add(n, (const float*)x);
6889
}
@@ -92,6 +113,11 @@ class FlatIndexNode : public IndexNode {
92113
auto x = dataset.GetTensor();
93114
auto dim = dataset.GetDim();
94115

116+
// If x is float16, convert it to float32
117+
if (typeid(x[0]) == typeid(float16)) {
118+
x = convertFloat16ToFloat32(x);
119+
}
120+
95121
auto len = k * nq;
96122
int64_t* ids = nullptr;
97123
float* distances = nullptr;
@@ -150,6 +176,11 @@ class FlatIndexNode : public IndexNode {
150176
auto xq = dataset.GetTensor();
151177
auto dim = dataset.GetDim();
152178

179+
// If xq is float16, convert it to float32
180+
if (typeid(xq[0]) == typeid(float16)) {
181+
xq = convertFloat16ToFloat32(xq);
182+
}
183+
153184
int64_t* ids = nullptr;
154185
float* distances = nullptr;
155186
size_t* lims = nullptr;
@@ -212,7 +243,14 @@ class FlatIndexNode : public IndexNode {
212243
for (int64_t i = 0; i < rows; i++) {
213244
index_->reconstruct(ids[i], data + i * dim);
214245
}
215-
return GenResultDataSet(rows, dim, data);
246+
// If original data was float16, convert it back before returning
247+
if (typeid(dataset.GetTensor()[0]) == typeid(float16)) {
248+
auto data16 = convertFloat32ToFloat16(data, rows * dim);
249+
delete[] data;
250+
return GenResultDataSet(rows, dim, data16);
251+
} else {
252+
return GenResultDataSet(rows, dim, data);
253+
}
216254
} catch (const std::exception& e) {
217255
std::unique_ptr<float[]> auto_del(data);
218256
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();

0 commit comments

Comments
 (0)