
Commit

Try to support float16 for flat.cc
jjyaoao committed May 13, 2023
1 parent fbf2b6e commit be92afc
Showing 1 changed file with 39 additions and 1 deletion.
src/index/flat/flat.cc (39 additions, 1 deletion)
@@ -41,6 +41,13 @@ class FlatIndexNode : public IndexNode {
return err;
}

std::vector<float>
convertFloat16ToFloat32(const std::vector<float16>& input) {
std::vector<float> output(input.size());
std::transform(input.begin(), input.end(), output.begin(), [](float16 f) { return static_cast<float>(f); });
return output;
}

Status
Train(const DataSet& dataset, const Config& cfg) override {
const FlatConfig& f_cfg = static_cast<const FlatConfig&>(cfg);
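
Note: float16 is not a built-in C++ type, so the helper added above assumes some half-precision type that converts to float. The self-contained sketch below illustrates the same widening pattern with a hypothetical fp16 stand-in; fp16 and the demo main() are assumptions for illustration only, and a real half-precision type would store 16 bits rather than a float.

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for a half-precision type: for the widening pattern
// it only needs to convert to float. A real float16 would hold 16 bits.
struct fp16 {
    float value;
    operator float() const { return value; }
};

std::vector<float>
convertFloat16ToFloat32(const std::vector<fp16>& input) {
    std::vector<float> output(input.size());
    std::transform(input.begin(), input.end(), output.begin(),
                   [](fp16 f) { return static_cast<float>(f); });
    return output;
}

int main() {
    std::vector<fp16> half = {{0.5f}, {1.25f}, {-3.0f}};
    for (float f : convertFloat16ToFloat32(half)) {
        std::printf("%g\n", f);  // prints 0.5, 1.25, -3
    }
}
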
@@ -55,6 +62,14 @@ class FlatIndexNode : public IndexNode {
LOG_KNOWHERE_WARNING_ << "please check metric type: " << f_cfg.metric_type;
return metric.error();
}

auto dim_data = dataset.GetDim();

// If dim_data is float16, convert it to float32
if (typeid(dim_data[0]) == typeid(float16)) {
dim_data = convertFloat16ToFloat32(dim_data);
}

index_ = std::make_unique<T>(dataset.GetDim(), metric.value());
return Status::success;
}
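
Note: typeid here resolves against the declared type of the expression rather than the runtime contents, and GetDim() supplies the dimension that is passed to the index constructor on the next line rather than the vector data, so this check cannot detect float16 input at runtime. If the element type were known at compile time, say as a template parameter DataT (which this commit does not have), the branch could instead mirror the if constexpr dispatch that Add() below already uses on the index type. A sketch under that assumption, with DataT and ToFloat32 as illustrative names:

#include <type_traits>
#include <vector>

// Sketch only: choose the float16 path at compile time rather than via typeid.
template <typename DataT>
std::vector<float>
ToFloat32(const std::vector<DataT>& input) {
    if constexpr (std::is_same_v<DataT, float>) {
        return input;  // already float32, returned as a copy
    } else {
        // Any element type convertible to float (e.g. a half type) is widened here.
        return std::vector<float>(input.begin(), input.end());
    }
}
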
@@ -63,6 +78,12 @@ class FlatIndexNode : public IndexNode {
Add(const DataSet& dataset, const Config& cfg) override {
auto x = dataset.GetTensor();
auto n = dataset.GetRows();

if (typeid(x[0]) == typeid(float16)) {
std::vector<float> x_float32(x.begin(), x.end());
x = x_float32;
}

if constexpr (std::is_same<T, faiss::IndexFlat>::value) {
index_->add(n, (const float*)x);
}
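
Note: in Add() the tensor obtained from GetTensor() is consumed through a plain (const float*) cast, so a float16 path ultimately needs a pointer-plus-count widening rather than the std::vector overload defined earlier. A sketch of that buffer-level conversion under the same fp16 assumption as above; WidenBuffer is an illustrative name, not part of this commit.

#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch: widen a raw buffer of `count` half-precision values to float32
// before handing it to a float-only API. fp16 is the stand-in defined above.
std::vector<float>
WidenBuffer(const fp16* src, std::size_t count) {
    std::vector<float> dst(count);
    std::transform(src, src + count, dst.begin(),
                   [](fp16 v) { return static_cast<float>(v); });
    return dst;
}

A call site shaped like this Add() would widen all n rows of dim values each and keep the copy alive until the add returns, e.g. auto buf = WidenBuffer(static_cast<const fp16*>(x), n * dim); index_->add(n, buf.data()); (with dim taken from the dataset as in the other methods).
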
@@ -92,6 +113,11 @@ class FlatIndexNode : public IndexNode {
auto x = dataset.GetTensor();
auto dim = dataset.GetDim();

// If x is float16, convert it to float32
if (typeid(x[0]) == typeid(float16)) {
x = convertFloat16ToFloat32(x);
}

auto len = k * nq;
int64_t* ids = nullptr;
float* distances = nullptr;
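
Note: the query tensor in Search() needs the same widening, and the converted copy must stay alive for the duration of the faiss call that fills the ids and distances buffers allocated here. A sketch reusing the helpers above; SearchFp16 is an illustrative wrapper, not part of this commit, and it assumes the faiss float32 search signature (n, x, k, distances, labels).

#include <cstdint>
#include <vector>
#include <faiss/IndexFlat.h>

// Sketch: widen the nq * dim query values once, then run the float32 search.
void
SearchFp16(const faiss::IndexFlat& index, const fp16* queries, int64_t nq,
           int64_t dim, int64_t k, float* distances, int64_t* ids) {
    std::vector<float> widened =
        WidenBuffer(queries, static_cast<std::size_t>(nq * dim));
    index.search(nq, widened.data(), k, distances, ids);
}
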
@@ -150,6 +176,11 @@ class FlatIndexNode : public IndexNode {
auto xq = dataset.GetTensor();
auto dim = dataset.GetDim();

// If xq is float16, convert it to float32
if (typeid(xq[0]) == typeid(float16)) {
xq = convertFloat16ToFloat32(xq);
}

int64_t* ids = nullptr;
float* distances = nullptr;
size_t* lims = nullptr;
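
Note: this query path (with its lims buffer) repeats the same widening as Search() above, so the remaining question is what float16 itself stores. The fp16 stand-in used in the earlier sketches just wraps a float; if the incoming data is really IEEE 754 binary16, the widening is a bit-level decode along these lines (half_to_float is an illustrative name, not part of this commit).

#include <cstdint>
#include <cstring>

// Sketch: decode one IEEE 754 binary16 value into a float32.
// Handles zeros, subnormals, normals, infinities and NaNs.
float
half_to_float(uint16_t h) {
    uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
    uint32_t exp  = (h >> 10) & 0x1Fu;
    uint32_t mant = h & 0x3FFu;
    uint32_t bits;
    if (exp == 0) {
        if (mant == 0) {
            bits = sign;                                   // signed zero
        } else {
            // Subnormal half: renormalize the mantissa for float32.
            uint32_t e = 113;                              // 127 - 15 + 1
            while ((mant & 0x400u) == 0) { mant <<= 1; --e; }
            bits = sign | (e << 23) | ((mant & 0x3FFu) << 13);
        }
    } else if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (mant << 13);          // inf or NaN
    } else {
        bits = sign | ((exp + 112u) << 23) | (mant << 13); // normal: rebias exponent
    }
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}
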
@@ -212,7 +243,14 @@ class FlatIndexNode : public IndexNode {
for (int64_t i = 0; i < rows; i++) {
index_->reconstruct(ids[i], data + i * dim);
}
return GenResultDataSet(rows, dim, data);
// If original data was float16, convert it back before returning
if (typeid(dataset.GetTensor()[0]) == typeid(float16)) {
auto data16 = convertFloat32ToFloat16(data, rows * dim);
delete[] data;
return GenResultDataSet(rows, dim, data16);
} else {
return GenResultDataSet(rows, dim, data);
}
} catch (const std::exception& e) {
std::unique_ptr<float[]> auto_del(data);
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
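
Note: the reconstruction path above calls convertFloat32ToFloat16(data, rows * dim), but that function is not defined anywhere in this diff. A minimal sketch of the narrowing direction, matching that pointer-plus-count call shape and again using the hypothetical fp16 stand-in; a real implementation would also have to produce whatever buffer GenResultDataSet expects to take ownership of.

#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch: narrow a float32 buffer back to the half-precision stand-in type.
std::vector<fp16>
convertFloat32ToFloat16(const float* input, std::size_t count) {
    std::vector<fp16> output(count);
    std::transform(input, input + count, output.begin(),
                   [](float f) { return fp16{f}; });
    return output;
}
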
