Commit 4d44aa1: update bn

jiangxinglei committed Oct 22, 2024 · 1 parent 77fbd94
Showing 5 changed files with 91 additions and 43 deletions.
20 changes: 19 additions & 1 deletion core/kernels/bn_table_ops.cc
@@ -87,6 +87,8 @@ class BnStatisticsPushKernel : public AsyncOpKernel {
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
butil::IOBuf acc_buf;

+ std::vector<double*> allocated_pointers;

for (int i = 0; i < N_; i++) {
const ResourceHandle& handle = HandleFromInput(c, i);

@@ -97,12 +99,28 @@ class BnStatisticsPushKernel : public AsyncOpKernel {
CHECK(variable);

Tensor *var_tensor = variable->tensor();
- acc_buf.append_user_data(var_tensor->flat<float>().data(), var_tensor->NumElements() * sizeof(float), NoOpDeleter);

+ int num_elements = var_tensor->NumElements();
+ double* dynamic_double_data = new double[num_elements];
+ const float* float_data = var_tensor->flat<float>().data();
+ for (int i = 0; i < num_elements; ++i) {
+     // std::cout << "float data is: " << float_data[i] << std::endl;
+     dynamic_double_data[i] = static_cast<double>(float_data[i]);
+     // std::cout << "double data is: " << dynamic_double_data[i] << std::endl;
+ }
+ acc_buf.append_user_data(dynamic_double_data, num_elements * sizeof(double), NoOpDeleter);
+ // acc_buf.append_user_data(var_tensor->flat<float>().data(), var_tensor->NumElements() * sizeof(float), NoOpDeleter);
+ allocated_pointers.push_back(dynamic_double_data);
}

BnTable* table = BnTableRegistry::Instance()->Get(table_handle_);
table->Append(acc_buf, true);

+ for (auto ptr : allocated_pointers) {
+     delete[] ptr;
+ }
+ allocated_pointers.clear();

if(synchronized_){
PsCluster* cluster = PsCluster::Instance();
OP_REQUIRES_ASYNC( c, true == cluster->IsInitialized(),
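
A note on the kernel change above: BnTable::Append on the server side now cuts doubles out of the IOBuf, so the kernel widens each float tensor into a heap-allocated double buffer before appending it. Because append_user_data is passed NoOpDeleter, the IOBuf never owns that memory; the delete[] loop after table->Append is what frees it, which is safe only because Append consumes the buffer before returning. A minimal sketch of an alternative that avoids the NoOpDeleter/delete[] pairing by letting the IOBuf copy the bytes (assumes butil::IOBuf::append, which copies; AppendAsDouble is a hypothetical helper, not part of the commit):

#include <vector>
#include <butil/iobuf.h>

// Widen float -> double into a scoped vector, then copy into the IOBuf.
void AppendAsDouble(const float* float_data, int num_elements, butil::IOBuf* buf) {
    std::vector<double> widened(num_elements);
    for (int i = 0; i < num_elements; ++i) {
        widened[i] = static_cast<double>(float_data[i]);
    }
    buf->append(widened.data(), widened.size() * sizeof(double));  // bytes are copied
    // `widened` is freed on return; the IOBuf keeps its own copy.
}
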
8 changes: 6 additions & 2 deletions core/main/py_wrapper.cc
@@ -114,10 +114,14 @@ PYBIND11_MODULE(_pywrap_tn, m) {

return py::reinterpret_steal<py::object>(obj);
})
.def("create_sparse_table", [](py::object obj, std::string name, int dimension) {
.def("create_sparse_table", [](py::object obj, std::string name, int dimension, bool use_cvm) {
OptimizerBase* opt =
static_cast<OptimizerBase*>(PyCapsule_GetPointer(obj.ptr(), nullptr));

+ opt->SetUseCvm(use_cvm);

+ std::cout << "Cvm plugin is: " << opt->ShouldUseCvm() << std::endl;

PsCluster* cluster = PsCluster::Instance();

SparseTable* table = CreateSparseTable(opt, name, dimension, cluster->RankNum(), cluster->Rank());
@@ -134,7 +138,7 @@

return table->GetHandle();
})
.def("create_bn_table", [](std::string name, uint32_t bn_size, bool sync, float moment, int max_count) {
.def("create_bn_table", [](std::string name, uint32_t bn_size, bool sync, float moment, uint64_t max_count) {
PsCluster* cluster = PsCluster::Instance();

BnTable* table = CreateBnTable(name, cluster->RankNum(), cluster->Rank(), bn_size, sync, moment, max_count);
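
The new use_cvm parameter above has no default in the binding, so every existing Python caller of create_sparse_table must be updated to pass it. A hypothetical variant (not the committed code) that would keep older call sites working by defaulting the flag through py::arg:

// Sketch only: same lambda as the commit, with named, defaulted arguments.
.def("create_sparse_table",
     [](py::object obj, std::string name, int dimension, bool use_cvm) {
         // ... body as committed above ...
     },
     py::arg("obj"), py::arg("name"), py::arg("dimension"),
     py::arg("use_cvm") = false)
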
3 changes: 2 additions & 1 deletion core/ps/ps_local_server.cc
@@ -104,8 +104,9 @@ void PsLocalServer::BnStatisticsPullAsync(brpc::Controller *cntl,
Callback done) const {
BnTable *table = BnTableRegistry::Instance()->Get(request->table_handle());
CHECK(nullptr != table);
+ response->set_table_handle(request->table_handle());
butil::IOBuf& bn_statistics_buf = cntl->response_attachment();
- table->GetStatistics(request, bn_statistics_buf, response);
+ table->GetIncStatistics(bn_statistics_buf);

done();
}
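
The pull handler now fills the response handle itself and returns incremental rather than total statistics; GetIncStatistics (see bn_table.cc below) zeroes the inc_* accumulators after serializing them, so the call has read-and-reset semantics. A sketch of the assumed contract, illustration only:

// Each pull drains the deltas accumulated since the previous pull.
butil::IOBuf first, second;
table->GetIncStatistics(first);   // deltas gathered since the last pull
table->GetIncStatistics(second);  // all zeros: inc_* were reset by the first call
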
79 changes: 50 additions & 29 deletions core/ps/table/bn_table.cc
@@ -29,7 +29,7 @@

namespace tensornet {

- BnTable::BnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool synchronized, float moment, int max_count)
+ BnTable::BnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool synchronized, float moment, uint64_t max_count)
: shard_num_(shard_num)
, self_shard_id_(self_shard_id)
, name_(name)
@@ -38,7 +38,9 @@ BnTable::BnTable(const std::string& name, int shard_num, int self_shard_id, int
, max_count_(max_count)
, bn_size_(bn_size) {
total_sum_.setZero(bn_size);
+ total_sum_err_.setZero(bn_size);
total_squared_sum_.setZero(bn_size);
+ total_squared_sum_err_.setZero(bn_size);
total_count_.setZero(bn_size);
inc_sum_.setZero(bn_size);
inc_squared_sum_.setZero(bn_size);
@@ -54,35 +56,54 @@ void BnTable::SetHandle(uint32_t handle) {

void BnTable::Append(butil::IOBuf& bn_statistics_buf, bool isLocal) {
const std::lock_guard<std::mutex> lock(*mu_);
- Eigen::ArrayXf acc_sum = Eigen::ArrayXf::Zero(bn_size_);
- Eigen::ArrayXf acc_squared_sum = Eigen::ArrayXf::Zero(bn_size_);
- Eigen::ArrayXf acc_count = Eigen::ArrayXf::Zero(bn_size_);
+ Eigen::ArrayXd acc_sum = Eigen::ArrayXd::Zero(bn_size_);
+ Eigen::ArrayXd acc_squared_sum = Eigen::ArrayXd::Zero(bn_size_);
+ Eigen::ArrayXd acc_count = Eigen::ArrayXd::Zero(bn_size_);

- bn_statistics_buf.cutn(acc_sum.data(), acc_sum.size() * sizeof(float));
- bn_statistics_buf.cutn(acc_squared_sum.data(), acc_squared_sum.size() * sizeof(float));
- bn_statistics_buf.cutn(acc_count.data(), acc_count.size() * sizeof(float));
+ bn_statistics_buf.cutn(acc_sum.data(), acc_sum.size() * sizeof(double));
+ bn_statistics_buf.cutn(acc_squared_sum.data(), acc_squared_sum.size() * sizeof(double));
+ bn_statistics_buf.cutn(acc_count.data(), acc_count.size() * sizeof(double));
CHECK_EQ(bn_statistics_buf.size(), 0);

- if(synchronized_ && isLocal){
+ if(isLocal){
inc_sum_ += acc_sum;
inc_squared_sum_ += acc_squared_sum;
inc_count_ += acc_count;
}

- int cur_count = static_cast<int>(total_count_.maxCoeff());
- if(cur_count > max_count_) {
-     int acc_count_num = static_cast<int>(acc_count.maxCoeff());
-     float ratio = (float) acc_count_num / cur_count;
-     total_sum_ = total_sum_ * (1 - (1 - moment_) * ratio) + (1 - moment_) * ratio * acc_sum;
-     total_squared_sum_ = total_squared_sum_ * (1 - (1 - moment_) * ratio) + (1 - moment_) * ratio * acc_squared_sum;

+ uint64_t cur_count = static_cast<uint64_t>(total_count_.maxCoeff());

+ // std::cout << "cur_count is : " << cur_count << std::endl;
+ // PrintDetail();
+ // std::cout << "acc_count is : " << acc_count(0) << std::endl;
+ if(max_count_ > 0 && cur_count > max_count_) {
+     uint64_t acc_count_num = static_cast<uint64_t>(acc_count.maxCoeff());
+     double ratio = (double) acc_count_num / cur_count;
+     total_sum_ *= (1 - (1 - moment_) * ratio);
+     TotalSumAcc((1 - moment_) * ratio * acc_sum);
+     total_squared_sum_ *= (1 - (1 - moment_) * ratio);
+     TotalSquareSumAcc((1 - moment_) * ratio * acc_squared_sum);
} else {

-     total_sum_ += acc_sum;
-     total_squared_sum_ += acc_squared_sum;
-     total_count_ += acc_count;
+     TotalSumAcc(acc_sum);
+     TotalSquareSumAcc(acc_squared_sum);
+     total_count_ += acc_count;
}
}

+ void BnTable::TotalSquareSumAcc(Eigen::ArrayXd acc){
+     Eigen::ArrayXd y = acc - total_squared_sum_err_;
+     Eigen::ArrayXd t = total_squared_sum_ + y;
+     total_squared_sum_err_ = (t - total_squared_sum_) - y;
+     total_squared_sum_ = t;
+ }

+ void BnTable::TotalSumAcc(Eigen::ArrayXd acc){
+     Eigen::ArrayXd y = acc - total_sum_err_;
+     Eigen::ArrayXd t = total_sum_ + y;
+     total_sum_err_ = (t - total_sum_) - y;
+     total_sum_ = t;
+ }


std::tuple<Eigen::ArrayXf,Eigen::ArrayXf> BnTable::GetMoments() {
Eigen::ArrayXf global_mean = DivideNoNan(total_sum_, total_count_);
@@ -93,15 +114,15 @@ std::tuple<Eigen::ArrayXf,Eigen::ArrayXf> BnTable::GetMoments() {

void BnTable::GetStatistics(const BnStatisticsPullRequest* req, butil::IOBuf& bn_statistics_buf, BnStatisticsPullResponse* resp) {
resp->set_table_handle(req->table_handle());
- bn_statistics_buf.append(total_sum_.data(), total_sum_.size() * sizeof(float));
- bn_statistics_buf.append(total_squared_sum_.data(), total_squared_sum_.size() * sizeof(float));
- bn_statistics_buf.append(total_count_.data(), total_count_.size() * sizeof(float));
+ bn_statistics_buf.append(total_sum_.data(), total_sum_.size() * sizeof(double));
+ bn_statistics_buf.append(total_squared_sum_.data(), total_squared_sum_.size() * sizeof(double));
+ bn_statistics_buf.append(total_count_.data(), total_count_.size() * sizeof(double));
}

void BnTable::GetIncStatistics(butil::IOBuf& bn_statistics_buf) {
- bn_statistics_buf.append(inc_sum_.data(), inc_sum_.size() * sizeof(float));
- bn_statistics_buf.append(inc_squared_sum_.data(), inc_squared_sum_.size() * sizeof(float));
- bn_statistics_buf.append(inc_count_.data(), inc_count_.size() * sizeof(float));
+ bn_statistics_buf.append(inc_sum_.data(), inc_sum_.size() * sizeof(double));
+ bn_statistics_buf.append(inc_squared_sum_.data(), inc_squared_sum_.size() * sizeof(double));
+ bn_statistics_buf.append(inc_count_.data(), inc_count_.size() * sizeof(double));
inc_sum_.setZero();
inc_squared_sum_.setZero();
inc_count_.setZero();
@@ -119,16 +140,16 @@ void BnTable::Refresh() {
}


- Eigen::ArrayXf BnTable::DivideNoNan(const Eigen::ArrayXf& numerator, const Eigen::ArrayXf& denominator) {
-     Eigen::ArrayXf result = numerator;
+ Eigen::ArrayXf BnTable::DivideNoNan(const Eigen::ArrayXd& numerator, const Eigen::ArrayXd& denominator) {
+     Eigen::ArrayXd result = numerator;
for (int i = 0; i < numerator.size(); ++i) {
if (!std::isnan(denominator(i)) && denominator(i) != 0.0) {
result(i) = numerator(i) / denominator(i);
} else {
result(i) = 0.0;
}
}
-     return result;
+     return result.cast<float>();
}

void BnTable::PrintDetail(){
@@ -199,7 +220,7 @@ uint32_t BnTableRegistry::Register(BnTable* table) {
return table_handle;
}

- BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, int max_count) {
+ BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count) {
BnTable* table = new BnTable(name, shard_num, self_shard_id, bn_size, sync, moment, max_count);

table->SetHandle(BnTableRegistry::Instance()->Register(table));
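
TotalSumAcc and TotalSquareSumAcc above are element-wise Kahan (compensated) summation: the new total_sum_err_ and total_squared_sum_err_ arrays carry the low-order bits that get rounded away when a small per-batch increment is added to an already huge running sum, which is exactly the regime a long-lived BN table reaches. A minimal scalar sketch of the same idea:

#include <cstdio>

int main() {
    double naive = 1e16, sum = 1e16, err = 0.0;
    for (int i = 0; i < 1000; ++i) {
        naive += 1.0;          // each +1 is rounded away: naive never moves
        double y = 1.0 - err;  // re-inject the error left from the last step
        double t = sum + y;    // low-order bits of y may be lost here...
        err = (t - sum) - y;   // ...but are recovered into err
        sum = t;
    }
    // naive is still 1e16; sum is 1e16 + 1000 up to a final rounding.
    printf("naive=%.0f compensated=%.0f\n", naive, sum);
    return 0;
}
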
24 changes: 14 additions & 10 deletions core/ps/table/bn_table.h
@@ -31,7 +31,7 @@ namespace tensornet {

class BnTable {
public:
- BnTable(const std::string& name,int shard_num, int self_shard_id, int bn_size, bool sync, float moment, int max_count);
+ BnTable(const std::string& name,int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count);

~BnTable() = default;

@@ -44,8 +44,10 @@ class BnTable {
std::tuple<Eigen::ArrayXf,Eigen::ArrayXf> GetMoments();
std::tuple<Eigen::ArrayXf,Eigen::ArrayXf> GetIncMoments();

- Eigen::ArrayXf DivideNoNan(const Eigen::ArrayXf& numerator, const Eigen::ArrayXf& denominator);
+ Eigen::ArrayXf DivideNoNan(const Eigen::ArrayXd& numerator, const Eigen::ArrayXd& denominator);

+ void TotalSumAcc(Eigen::ArrayXd acc_sum);
+ void TotalSquareSumAcc(Eigen::ArrayXd acc_square_sum);
void Save(const std::string& filepath);
void Load(const std::string& filepath);

@@ -67,13 +69,15 @@ class BnTable {
uint32_t bn_size_ = 0;
bool synchronized_ = false;
float moment_ = 0.0;
- int max_count_ = 0;
- Eigen::ArrayXf total_sum_;
- Eigen::ArrayXf total_squared_sum_;
- Eigen::ArrayXf total_count_;
- Eigen::ArrayXf inc_sum_;
- Eigen::ArrayXf inc_squared_sum_;
- Eigen::ArrayXf inc_count_;
+ uint64_t max_count_ = 0;
+ Eigen::ArrayXd total_sum_;
+ Eigen::ArrayXd total_sum_err_;
+ Eigen::ArrayXd total_squared_sum_;
+ Eigen::ArrayXd total_squared_sum_err_;
+ Eigen::ArrayXd total_count_;
+ Eigen::ArrayXd inc_sum_;
+ Eigen::ArrayXd inc_squared_sum_;
+ Eigen::ArrayXd inc_count_;
std::unique_ptr<std::mutex> mu_;

};
@@ -100,7 +104,7 @@ class BnTableRegistry {
std::vector<BnTable*> tables_;
};

- BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, int max_count);
+ BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count);

} // namespace tensornet

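
Two behavioral notes on the header changes above. First, max_count_ is now uint64_t, and the rewritten guard in Append (max_count_ > 0 && cur_count > max_count_) makes max_count_ == 0 mean "never switch to the moving average", where the old comparison treated a zero cap as exceeded from the first batch. Second, once the cap is exceeded, each batch updates the totals as an exponential moving average weighted by the batch's share of the seen data. Illustrative numbers only, with hypothetical values for moment_ and the counts:

// With moment_ = 0.99 and a batch carrying 1% of the seen count:
//   ratio = acc_count / cur_count     -> 0.01
//   d     = (1 - moment_) * ratio     -> 0.01 * 0.01 = 1e-4
//   total_sum_ <- (1 - d) * total_sum_ + d * acc_sum   (applied via TotalSumAcc)
// The sums stop growing without bound and instead track a slowly moving
// estimate once total_count_ passes max_count_.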
