Skip to content

Commit

Permalink
SINGA-21 Code Review-2
Browse files Browse the repository at this point in the history
Rebased to the latest master. This pull request should be applied at the front (last in the rebase order).
Tested with mnist and cifar10, with different cluster settings.
  • Loading branch information
nudles committed Jun 24, 2015
1 parent cfde471 commit 7d39f88
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 53 deletions.
1 change: 1 addition & 0 deletions examples/cifar10/cluster.conf
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ nserver_groups: 1
nservers_per_group: 1
nworkers_per_group: 1
nworkers_per_procs: 1
nservers_per_procs: 1
workspace: "examples/cifar10/"
22 changes: 22 additions & 0 deletions examples/mnist/Makefile.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Libraries the create_shard tool links against (singa plus its glog and
# protobuf dependencies).
# NOTE(review): this variable is not referenced below -- the link flags
# are spelled out inline in the `create` recipe; confirm it is needed.
libs :=singa glog protobuf

# download: fetch and unpack the raw MNIST archives.
# create:   build and run the shard-creation tool.
# ("all" is declared phony but has no rule in this file.)
.PHONY: all download create

# Convenience alias: "make download" fetches the dataset.
download: mnist

# Fetch the four MNIST archives and decompress them in place.
mnist:
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
gunzip train-images-idx3-ubyte.gz && gunzip train-labels-idx1-ubyte.gz
gunzip t10k-images-idx3-ubyte.gz && gunzip t10k-labels-idx1-ubyte.gz

# Compile create_shard.cc against the in-tree singa build (rpath points
# at ../../.libs so the binary runs without installing), then convert
# the raw MNIST files into train and test record shards.
create:
$(CXX) create_shard.cc -std=c++11 -lsinga -lprotobuf -lglog -I../../include \
-L../../.libs/ -Wl,-unresolved-symbols=ignore-in-shared-libs -Wl,-rpath=../../.libs/ \
-o create_shard.bin
mkdir mnist_train_shard
mkdir mnist_test_shard
./create_shard.bin train-images-idx3-ubyte train-labels-idx1-ubyte mnist_train_shard
./create_shard.bin t10k-images-idx3-ubyte t10k-labels-idx1-ubyte mnist_test_shard
4 changes: 3 additions & 1 deletion examples/mnist/cluster.conf
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ nworker_groups: 1
nserver_groups: 1
nservers_per_group: 1
nworkers_per_group: 1
workspace: "examples/cifar10/"
nservers_per_procs: 1
nworkers_per_procs: 1
workspace: "examples/mnist/"
2 changes: 1 addition & 1 deletion examples/mnist/model.conf
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: "deep-big-simple-mlp"
train_steps: 10000
train_steps: 1000
test_steps:10
test_frequency:60
display_frequency:30
Expand Down
2 changes: 1 addition & 1 deletion include/trainer/trainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class Trainer{
protected:

vector<shared_ptr<Server>> CreateServers(int nthread, const ModelProto& mproto,
const vector<int> slices, vector<HandleContext>* ctx);
const vector<int> slices, vector<HandleContext*>* ctx);
vector<shared_ptr<Worker>> CreateWorkers(int nthread,
const ModelProto& mproto, vector<int> *slice_size);

Expand Down
9 changes: 3 additions & 6 deletions include/utils/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

namespace singa {

std::string IntVecToString(const std::vector<int>& vec) ;
std::string VStringPrintf(std::string fmt, va_list l) ;
std::string StringPrintf(std::string fmt, ...) ;
void ReadProtoFromTextFile(const char* filename,
google::protobuf::Message* proto);
void WriteProtoToTextFile(const google::protobuf::Message& proto,
Expand All @@ -17,13 +20,7 @@ void ReadProtoFromBinaryFile(const char* filename,
google::protobuf::Message* proto);
void WriteProtoToBinaryFile(const google::protobuf::Message& proto,
const char* filename);
std::string IntVecToString(const std::vector<int>& vec);
std::string StringPrintf(std::string fmt, ...);
inline float rand_real() {
return static_cast<float>(rand()) / (RAND_MAX + 1.0f);
}

<<<<<<< HEAD
/*
inline void Sleep(int millisec=1){
std::this_thread::sleep_for(std::chrono::milliseconds(millisec));
Expand Down
23 changes: 0 additions & 23 deletions src/proto/cluster.proto
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,3 @@ message ServerTopology {
// neighbor group id
repeated int32 neighbor = 3;
}
// Types of messages exchanged between the distributed entities.
// Values are wire-visible; do not renumber existing entries.
enum MsgType{
kGet=0;
kPut=1;
kSync=2;
kUpdate=3;
kSyncRequest=4;
kSyncResponse=5;
kStop=6;
kData=7;
kRGet=8;
kRUpdate=9;
kConnect=10;
kMetric=11;
};

// Kinds of communicating endpoints used when addressing messages
// (see set_src/set_dst call sites, e.g. Worker::Stop).
// Values are wire-visible; do not renumber existing entries.
enum EntityType{
kWorkerParam=0;
kWorkerLayer=1;
kServer=2;
kStub=3;
kRuntime=4;
};

30 changes: 16 additions & 14 deletions src/trainer/trainer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ void HandleWorkerFinish(void * ctx){

const std::unordered_map<int, vector<std::pair<int, int>>> SliceParams(int num,
const vector<shared_ptr<Param>>& params){
CHECK_GT(num,0);
std::unordered_map<int, vector<std::pair<int, int>>> paramid2slices;
if (num==0)
return paramid2slices;
vector<int> param_size;
int avg=0;
for(const auto& x:params){
Expand All @@ -43,7 +45,6 @@ const std::unordered_map<int, vector<std::pair<int, int>>> SliceParams(int num,
LOG(INFO)<<"Slicer, param avg="<<avg<<", diff= "<<diff;

int capacity=avg, sliceid=0, nbox=0;
std::unordered_map<int, vector<std::pair<int, int>>> paramid2slices;
for(auto& param: params){
if(param->id()!=param->owner())
continue;
Expand Down Expand Up @@ -115,7 +116,7 @@ const vector<int> PartitionSlice(int num, const vector<int>& slices){
vector<shared_ptr<Server>> Trainer::CreateServers(int nthreads,
const ModelProto & mproto,
const vector<int> slices,
vector<HandleContext>* ctx){
vector<HandleContext*>* ctx){
auto cluster=Cluster::Get();
vector<shared_ptr<Server>> servers;
if(!cluster->has_server())
Expand All @@ -137,10 +138,10 @@ vector<shared_ptr<Server>> Trainer::CreateServers(int nthreads,
auto server=make_shared<Server>(nthreads++, gid, sid);
server->Setup(mproto.updater(), server_shard_, slice2group);
servers.push_back(server);
HandleContext hc{dealer, gid, sid};
auto *hc=new HandleContext{dealer, gid, sid};
ctx->push_back(hc);
CHECK(cluster->runtime()->sWatchSGroup(gid, sid, HandleWorkerFinish,
&(ctx->back())));
CHECK(cluster->runtime()->WatchSGroup(gid, sid, HandleWorkerFinish,
ctx->back()));
}
}
return servers;
Expand Down Expand Up @@ -174,12 +175,12 @@ vector<shared_ptr<Worker>> Trainer::CreateWorkers(int nthreads,
auto net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain,
cluster->nworkers_per_group());
int lcm=LeastCommonMultiple(cluster->nserver_groups(), cluster->nservers_per_group());
auto paramid2slices=SliceParams(lcm, net->params()); // sliceid, size
for(auto param: net->params()){
if(param->id()==param->owner())
for(auto entry: paramid2slices[param->id()])
slice_size->push_back(entry.second);
}
auto paramid2slices=SliceParams(lcm, net->params()); // sliceid, size
for(auto param: net->params()){
if(param->id()==param->owner())
for(auto entry: paramid2slices[param->id()])
slice_size->push_back(entry.second);
}

for(int gid=gstart;gid<gend;gid++){
shared_ptr<NeuralNet> train_net, test_net, validation_net;
Expand Down Expand Up @@ -257,10 +258,11 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
// create workers
vector<int> slices;
vector<shared_ptr<Worker>> workers=CreateWorkers(nthreads, mproto, &slices);
slice2server_=PartitionSlice(cluster->nservers_per_group(), slices);
if(cluster->nserver_groups()&&cluster->nservers_per_group())
slice2server_=PartitionSlice(cluster->nservers_per_group(), slices);
nthreads+=workers.size();
// create servers
vector<HandleContext> ctx;
vector<HandleContext*> ctx;
vector<shared_ptr<Server>> servers=CreateServers(nthreads, mproto, slices,
&ctx);

Expand Down
13 changes: 6 additions & 7 deletions src/trainer/worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ void Worker::Setup(const ModelProto& model,
train_net_=train_net;
modelproto_=model;
auto cluster=Cluster::Get();
if(cluster->nserver_groups()&&cluster->server_update()){
int sgid=group_id_/cluster->nworker_groups_per_server_group();
CHECK(cluster->runtime()->JoinSGroup(group_id_, worker_id_, sgid));
}else{
if(!(cluster->nserver_groups()&&cluster->server_update())){
updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance()
->Create("Updater"));
updater_->Init(model.updater());
Expand All @@ -33,7 +30,7 @@ void Worker::ConnectStub(shared_ptr<Dealer> dealer, EntityType type){
if(updater_==nullptr){
auto cluster=Cluster::Get();
int sgid=group_id_/cluster->nworker_groups_per_server_group();
CHECK(cluster->runtime()->wJoinSGroup(group_id_, worker_id_, sgid));
CHECK(cluster->runtime()->JoinSGroup(group_id_, worker_id_, sgid));
}

dealer->Connect(kInprocRouterEndpoint);
Expand Down Expand Up @@ -93,8 +90,10 @@ void Worker::Run(){

void Worker::Stop(){
auto cluster=Cluster::Get();
int sgid=group_id_/cluster->nworker_groups_per_server_group();
cluster->runtime()->LeaveSGroup(group_id_, worker_id_, sgid);
if(updater_ == nullptr){
int sgid=group_id_/cluster->nworker_groups_per_server_group();
cluster->runtime()->LeaveSGroup(group_id_, worker_id_, sgid);
}
Msg* msg=new Msg();
msg->set_src(group_id_, worker_id_, kWorkerParam);
msg->set_dst(-1,-1, kStub);
Expand Down
1 change: 1 addition & 0 deletions src/utils/cluster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <fstream>
#include "utils/cluster.h"
#include "proto/cluster.pb.h"
#include "proto/common.pb.h"
#include <sys/stat.h>
#include <sys/types.h>
namespace singa {
Expand Down
35 changes: 35 additions & 0 deletions src/utils/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,41 @@ using google::protobuf::io::ZeroCopyInputStream;
using google::protobuf::Message;

const int kBufLen = 1024;
/**
 * Render a vector of ints as a human-readable tuple, e.g. "(1, 2, 3)".
 * Intended for log/debug output.  Fixed: the previous version appended
 * a separator after every element, producing "(1, 2, 3, )".
 */
std::string IntVecToString(const std::vector<int>& vec) {
  std::string disp = "(";
  for (size_t i = 0; i < vec.size(); ++i) {
    if (i > 0)
      disp += ", ";
    disp += std::to_string(vec[i]);
  }
  return disp + ")";
}
/**
 * Expand a printf-style format against an already-started va_list.
 * Formats into a fixed 32 KB scratch buffer, so longer output is
 * silently truncated.
 */
std::string VStringPrintf(std::string fmt, va_list l) {
  char scratch[32768];
  vsnprintf(scratch, sizeof(scratch), fmt.c_str(), l);
  return std::string(scratch);
}

/**
 * printf-style formatting that returns a std::string.
 * NOTE(review): va_start on a non-trivial (std::string) parameter is
 * only conditionally supported by the C++ standard; it works on the
 * common ABIs but a const char* overload would be strictly portable —
 * confirm the targeted toolchains.
 */
std::string StringPrintf(std::string fmt, ...) {
  va_list args;
  va_start(args, fmt);
  std::string formatted = VStringPrintf(fmt, args);
  va_end(args);
  return formatted;
}

/**
 * Park the process so a debugger can attach (e.g. gdb -p <pid>):
 * prints the pid and hostname, then sleeps in a loop until a debugger
 * sets `i` non-zero to let execution resume.
 */
void Debug() {
  // volatile: the flag must be re-read each iteration, otherwise an
  // optimizing build may fold the loop into an unconditional spin and
  // a debugger's write to `i` would never terminate it.
  volatile int i = 0;
  char hostname[256];
  gethostname(hostname, sizeof(hostname));
  printf("PID %d on %s ready for attach\n", getpid(), hostname);
  fflush(stdout);
  while (0 == i)
    sleep(5);
}

// the proto related functions are from Caffe.
void ReadProtoFromTextFile(const char* filename, Message* proto) {
Expand Down

0 comments on commit 7d39f88

Please sign in to comment.