Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
refactor binaryset and reuse binaryset memory after load
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Aug 10, 2023
1 parent 2de62cb commit c491525
Show file tree
Hide file tree
Showing 26 changed files with 314 additions and 156 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ut.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ jobs:
run: |
sudo apt update \
&& sudo apt install -y cmake g++ gcc libopenblas-dev libaio-dev libcurl4-openssl-dev libevent-dev libgflags-dev python3 python3-pip python3-setuptools \
&& pip3 install conan==1.58.0 pytest faiss-cpu numpy wheel \
&& pip3 install conan==1.58.0 swig==4.1.1 pytest faiss-cpu numpy wheel \
&& conan remote add default-conan-local https://milvus01.jfrog.io/artifactory/api/conan/default-conan-local
- name: Build
run: |
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Here's a list of verified OS types where Knowhere can successfully build and run
```bash
$ sudo apt install build-essential libopenblas-dev libaio-dev python3-dev python3-pip
$ pip3 install conan==1.59.0 --user
$ pip3 install swig==4.1.1 --user
$ export PATH=$PATH:$HOME/.local/bin
```

Expand Down
1 change: 1 addition & 0 deletions ci/E2E2.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pipeline {
sh "apt-get update || true"
sh "apt-get install libaio-dev libopenblas-dev libcurl4-openssl-dev libdouble-conversion-dev libevent-dev libgflags-dev git -y"
sh "pip3 install conan==1.58.0"
sh "pip3 install swig==4.1.1"
sh "conan remote add default-conan-local https://milvus01.jfrog.io/artifactory/api/conan/default-conan-local"
sh "rm -rf /usr/local/lib/cmake/"
sh "mkdir build"
Expand Down
1 change: 1 addition & 0 deletions ci/E2E_GPU.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pipeline {
sh "git config --global --add safe.directory '*'"
sh "git submodule update --recursive --init"
sh "pip3 install conan==1.58.0"
sh "pip3 install swig==4.1.1"
sh "conan remote add default-conan-local https://milvus01.jfrog.io/artifactory/api/conan/default-conan-local"
sh "rm -rf /usr/local/lib/cmake/"
sh "mkdir build"
Expand Down
1 change: 1 addition & 0 deletions ci/UT_GPU.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pipeline {
sh "apt-get update || true"
sh "apt-get install libaio-dev libcurl4-openssl-dev libdouble-conversion-dev libevent-dev libgflags-dev git -y"
sh "pip3 install conan==1.58.0"
sh "pip3 install swig==4.1.1"
sh "conan remote add default-conan-local https://milvus01.jfrog.io/artifactory/api/conan/default-conan-local"
sh "rm -rf /usr/local/lib/cmake/"
sh "mkdir build"
Expand Down
97 changes: 43 additions & 54 deletions include/knowhere/binaryset.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,80 +20,69 @@

namespace knowhere {

// A byte buffer held through a shared_ptr, together with a recorded length.
struct Binary {
// Buffer contents; ownership is shared through the shared_ptr.
std::shared_ptr<uint8_t[]> data;
// Recorded length of `data`; CopyBinary copies exactly this many bytes.
int64_t size = 0;
};
using BinaryPtr = std::shared_ptr<Binary>;

// Duplicates the payload of `bin` into a newly allocated array of bin->size
// bytes and returns it as a raw pointer. Nothing else keeps track of the
// returned array, so the caller must release it with delete[].
// `bin` is dereferenced unconditionally and must not be null.
inline uint8_t*
CopyBinary(const BinaryPtr& bin) {
    const auto byte_count = bin->size;
    auto* duplicate = new uint8_t[byte_count];
    std::memcpy(duplicate, bin->data.get(), byte_count);
    return duplicate;
}

class BinarySet {
public:
BinaryPtr
GetByName(const std::string& name) const {
if (Contains(name)) {
return binary_map_.at(name);
}
return nullptr;
BinarySet() : data_(nullptr), size_(0) {
}

// This API is used to be compatible with knowhere-1.x.
// It tries each key name one by one, and returns the first matched.
BinaryPtr
GetByNames(const std::vector<std::string>& names) const {
for (auto& name : names) {
if (Contains(name)) {
return binary_map_.at(name);
}
}
return nullptr;
BinarySet(std::unique_ptr<uint8_t[]>& data, uint64_t size) {
Clear();
Set(data, size);
}

void
Append(const std::string& name, BinaryPtr binary) {
binary_map_[name] = std::move(binary);
// Ownership-transferring constructor. T&& is a forwarding reference, so this
// overload matches non-const lvalues as well as rvalues; in both cases the
// source's buffer is swapped into this object and the source's size_ is set
// to 0.
// NOTE(review): the template is unconstrained; confirm that only BinarySet
// arguments are intended.
template <typename T>
BinarySet(T&& binset) {
// data_ is still null at this point, so Clear() changes nothing here.
Clear();
size_ = binset.size_;
// After the swap the source holds this object's previous (null) pointer.
data_.swap(binset.data_);
binset.size_ = 0;
}

// Ownership-transferring assignment. Any buffer currently held is released,
// then the source's buffer and size are taken over; the source is left with
// a null buffer and size 0. T&& is a forwarding reference, so non-const
// lvalues are accepted (and emptied) as well as rvalues.
// Returns *this.
template <typename T>
BinarySet&
operator=(T&& binset) {
    // Self-assignment guard: without it, Clear() below would free this
    // object's own buffer before the swap, silently discarding the data.
    // The comparison goes through const void* because T is unconstrained.
    if (static_cast<const void*>(this) == static_cast<const void*>(&binset)) {
        return *this;
    }
    Clear();
    size_ = binset.size_;
    // data_ is null after Clear(), so the source ends up holding null.
    data_.swap(binset.data_);
    binset.size_ = 0;
    return *this;
}

void
Append(const std::string& name, std::shared_ptr<uint8_t[]> data, int64_t size) {
auto binary = std::make_shared<Binary>();
binary->data = data;
binary->size = size;
binary_map_[name] = std::move(binary);
Set(std::unique_ptr<uint8_t[]>& data, uint64_t size) {
data_ = std::move(data);
size_ = size;
}

BinaryPtr
Erase(const std::string& name) {
BinaryPtr result = nullptr;
auto it = binary_map_.find(name);
if (it != binary_map_.end()) {
result = it->second;
binary_map_.erase(it);
void
Clear() {
if (data_ != nullptr) {
data_.reset(nullptr);
size_ = 0;
}
return result;
}

void
clear() {
binary_map_.clear();
// Gives up ownership of the buffer: the returned unique_ptr holds whatever
// data_ held, and this object is left with a null buffer and size_ == 0.
std::unique_ptr<uint8_t[]>
Release() {
    std::unique_ptr<uint8_t[]> released = std::move(data_);
    size_ = 0;
    return released;
}

bool
Contains(const std::string& key) const {
return binary_map_.find(key) != binary_map_.end();
const uint8_t*
GetData() const {
return data_.get();
}

public:
std::map<std::string, BinaryPtr> binary_map_;
const uint64_t
GetSize() const {
return size_;
}

private:
std::unique_ptr<uint8_t[]> data_;
uint64_t size_;
};

using BinarySetPtr = std::shared_ptr<BinarySet>;
} // namespace knowhere

#endif /* BINARYSET_H */
5 changes: 3 additions & 2 deletions include/knowhere/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,9 @@ class Index {
return this->node->Serialize(binset);
}

template <typename Bin_T>
Status
Deserialize(const BinarySet& binset, const Json& json = {}) {
Deserialize(Bin_T&& binset, const Json& json = {}) {
Json json_(json);
auto cfg = this->node->CreateConfig();
{
Expand All @@ -225,7 +226,7 @@ class Index {
if (res != Status::success) {
return res;
}
return this->node->Deserialize(binset, *cfg);
return this->node->Deserialize(std::forward<Bin_T>(binset), *cfg);
}

Status
Expand Down
3 changes: 3 additions & 0 deletions include/knowhere/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class IndexNode : public Object {
virtual Status
Deserialize(const BinarySet& binset, const Config& config) = 0;

virtual Status
Deserialize(BinarySet&& binset, const Config& config) = 0;

virtual Status
DeserializeFromFile(const std::string& filename, const Config& config) = 0;

Expand Down
5 changes: 5 additions & 0 deletions include/knowhere/index_node_thread_pool_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ class IndexNodeThreadPoolWrapper : public IndexNode {
return index_node_->Deserialize(binset, config);
}

// Rvalue overload: hands the caller's BinarySet on to the wrapped index node.
// `binset` is a named rvalue reference and therefore an lvalue inside this
// function; passing it unchanged would select the wrapped node's
// const BinarySet& overload, so its BinarySet&& overload would never be
// reached through this wrapper. The cast (equivalent to std::move) restores
// the rvalue category so the matching overload is called.
Status
Deserialize(BinarySet&& binset, const Config& config) override {
    return index_node_->Deserialize(static_cast<BinarySet&&>(binset), config);
}

Status
DeserializeFromFile(const std::string& filename, const Config& config) override {
return index_node_->DeserializeFromFile(filename, config);
Expand Down
8 changes: 5 additions & 3 deletions python/knowhere/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from . import swigknowhere
from .swigknowhere import Status
from .swigknowhere import GetBinarySet, GetNullDataSet, GetNullBitSetView
from .swigknowhere import Status, BinarySet
from .swigknowhere import GetNullDataSet, GetNullBitSetView, GetBinarySet
import numpy as np

def CreateIndex(name):
Expand All @@ -10,6 +10,9 @@ def CreateIndex(name):
def CreateBitSet(bits_num):
return swigknowhere.BitSet(bits_num)

# def GetBinarySet():
# return BinarySet


def ArrayToDataSet(arr):
if arr.ndim == 1:
Expand All @@ -25,7 +28,6 @@ def ArrayToDataSet(arr):
"""
)


def DataSetToArray(ans):
dim = swigknowhere.DataSet_Dim(ans)
rows = swigknowhere.DataSet_Rows(ans)
Expand Down
22 changes: 11 additions & 11 deletions python/knowhere/knowhere.i
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ typedef uint64_t size_t;
#include <numpy/arrayobject.h>
#endif
#include <knowhere/expected.h>
#include <knowhere/binaryset.h>
#include <knowhere/factory.h>
#include <knowhere/comp/local_file_manager.h>
using namespace knowhere;
Expand All @@ -45,11 +46,9 @@ import_array();
%include <std_pair.i>
%include <std_map.i>
%include <std_shared_ptr.i>
%include <std_unique_ptr.i>
%include <exception.i>
%shared_ptr(knowhere::DataSet)
%shared_ptr(knowhere::BinarySet)
%template(DataSetPtr) std::shared_ptr<knowhere::DataSet>;
%template(BinarySetPtr) std::shared_ptr<knowhere::BinarySet>;
%include <knowhere/expected.h>
%include <knowhere/dataset.h>
%include <knowhere/binaryset.h>
Expand Down Expand Up @@ -183,15 +182,15 @@ class IndexWrap {
}

knowhere::Status
Serialize(knowhere::BinarySetPtr binset) {
Serialize(knowhere::BinarySet& binset) {
GILReleaser rel;
return idx.Serialize(*binset);
return idx.Serialize(binset);
}

knowhere::Status
Deserialize(knowhere::BinarySetPtr binset, const std::string& json) {
Deserialize(knowhere::BinarySet binset, const std::string& json) {
GILReleaser rel;
return idx.Deserialize(*binset, knowhere::Json::parse(json));
return idx.Deserialize(binset, knowhere::Json::parse(json));
}

int64_t
Expand Down Expand Up @@ -239,6 +238,11 @@ class BitSet {
int num_bits_ = 0;
};

// Returns a default-constructed knowhere::BinarySet by value.
BinarySet
GetBinarySet() {
return knowhere::BinarySet();
}

knowhere::BitsetView
GetNullBitSetView() {
return nullptr;
Expand Down Expand Up @@ -286,10 +290,6 @@ int64_t DataSet_Dim(knowhere::DataSetPtr results){
return results->GetDim();
}

knowhere::BinarySetPtr GetBinarySet() {
return std::make_shared<knowhere::BinarySet>();
}

knowhere::DataSetPtr GetNullDataSet() {
return nullptr;
}
Expand Down
6 changes: 6 additions & 0 deletions src/index/cagra/cagra.cu
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ class CagraIndexNode : public IndexNode {
return Status::success;
}

// Rvalue-BinarySet overload: not implemented for this index node.
// Logs an error and returns Status::not_implemented; neither `binset` nor
// `config` is read.
Status
Deserialize(BinarySet&& binset, const Config& config) override {
LOG_KNOWHERE_ERROR_ << "Not support Deserialization from BinarySet&& yet.";
return Status::not_implemented;
}

Status
DeserializeFromFile(const std::string& filename, const Config& config) override {
}
Expand Down
5 changes: 5 additions & 0 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ class DiskANNIndexNode : public IndexNode {
Status
Deserialize(const BinarySet& binset, const Config& cfg) override;

// Rvalue-BinarySet overload: delegates to the const BinarySet& overload and
// returns its Status. The cast makes the choice of that overload explicit;
// `binset` is not moved from here.
Status
Deserialize(BinarySet&& binset, const Config& config) override {
return Deserialize(static_cast<const BinarySet&>(binset), config);
}

Status
DeserializeFromFile(const std::string& filename, const Config& config) override {
LOG_KNOWHERE_ERROR_ << "DiskANN doesn't support Deserialization from file.";
Expand Down
26 changes: 13 additions & 13 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,9 @@ class FlatIndexNode : public IndexNode {
if constexpr (std::is_same<T, faiss::IndexBinaryFlat>::value) {
faiss::write_index_binary(index_.get(), &writer);
}
std::shared_ptr<uint8_t[]> data(writer.data_);
binset.Append(Type(), data, writer.rp);
std::unique_ptr<uint8_t[]> data(writer.data_);
binset.Set(data, writer.rp);

return Status::success;
} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what();
Expand All @@ -269,18 +270,11 @@ class FlatIndexNode : public IndexNode {

Status
Deserialize(const BinarySet& binset, const Config& config) override {
std::vector<std::string> names = {"IVF", // compatible with knowhere-1.x
"BinaryIVF", // compatible with knowhere-1.x
Type()};
auto binary = binset.GetByNames(names);
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
}

MemoryIOReader reader;
reader.total = binary->size;
reader.data_ = binary->data.get();

reader.total = const_cast<BinarySet&>(binset).GetSize();
reader.data_ = const_cast<BinarySet&>(binset).GetData();

if constexpr (std::is_same<T, faiss::IndexFlat>::value) {
faiss::Index* index = faiss::read_index(&reader);
index_.reset(static_cast<T*>(index));
Expand All @@ -292,6 +286,12 @@ class FlatIndexNode : public IndexNode {
return Status::success;
}

// Rvalue-BinarySet overload: not implemented for this index node.
// Logs an error and returns Status::not_implemented; neither `binset` nor
// `config` is read.
Status
Deserialize(BinarySet&& binset, const Config& config) override {
LOG_KNOWHERE_ERROR_ << "Not support Deserialization from BinarySet&& yet.";
return Status::not_implemented;
}

Status
DeserializeFromFile(const std::string& filename, const Config& config) override {
auto cfg = static_cast<const knowhere::BaseConfig&>(config);
Expand Down
6 changes: 6 additions & 0 deletions src/index/gpu/flat_gpu/flat_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ class GpuFlatIndexNode : public IndexNode {
return Status::success;
}

// Rvalue-BinarySet overload: not implemented for this index node.
// Logs an error and returns Status::not_implemented; neither `binset` nor
// `config` is read.
Status
Deserialize(BinarySet&& binset, const Config& config) override {
LOG_KNOWHERE_ERROR_ << "Not support Deserialization from BinarySet&& yet.";
return Status::not_implemented;
}

Status
DeserializeFromFile(const std::string& filename, const Config& config) override {
LOG_KNOWHERE_ERROR_ << "GpuFlatIndex doesn't support Deserialization from file.";
Expand Down
Loading

0 comments on commit c491525

Please sign in to comment.