parallel caffe
Parallel Caffe is contributed by Inspur Ltd., who claim a 10.49x speedup on an 8-GPU instance. It originates from Caffe with a few modifications:
- framework
  - MPI is used for data parallelism
  - each MPI process runs one solver
  - the training code is mostly untouched
  - a parameter server (PS) thread is used: every solver computes its own gradients, uploads them to the PS, and the PS computes the new parameters and sends them back to the solvers
- class/files
  - solver/SGDSolver
  - data_layer/base_datalayer (parallel data reading and distribution)
  - net (some interfaces and parameter-update optimization)
  - other (header files, some interfaces, etc.)
We will dissect these modifications to get at the essence of the parallelization.
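Before diving into the code, here is a minimal, hedged sketch of the overall flow implied by the list above. It is an illustration only, not the actual Inspur code: rank 0 plays the parameter server and every other rank runs one solver; `NUM_PARAMS`, `ITERS`, `lr`, and `compute_local_gradient` are placeholders invented for the example.

```cpp
// Hedged sketch of parameter-server data parallelism over MPI (illustration only).
#include <mpi.h>
#include <vector>

static const int NUM_PARAMS = 1024;   // flattened parameter count of the net (placeholder)
static const int ITERS = 100;
static const double lr = 0.01;

// Stand-in for forward + backward on the local batch.
void compute_local_gradient(const std::vector<double>& w, std::vector<double>& g) {
  for (size_t i = 0; i < g.size(); ++i) g[i] = 0.0;  // dummy gradient
}

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  std::vector<double> weights(NUM_PARAMS, 0.0), grad(NUM_PARAMS, 0.0);
  for (int it = 0; it < ITERS; ++it) {
    if (rank == 0) {                              // parameter server
      std::vector<double> sum(NUM_PARAMS, 0.0), recv(NUM_PARAMS, 0.0);
      for (int c = 1; c < size; ++c) {            // gather one diff per solver
        MPI_Recv(recv.data(), NUM_PARAMS, MPI_DOUBLE, MPI_ANY_SOURCE, 0,
                 MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        for (int i = 0; i < NUM_PARAMS; ++i) sum[i] += recv[i];
      }
      for (int i = 0; i < NUM_PARAMS; ++i)        // SGD step on the averaged diff
        weights[i] -= lr * sum[i] / (size - 1);
      for (int dst = 1; dst < size; ++dst)        // push new weights back to the solvers
        MPI_Send(weights.data(), NUM_PARAMS, MPI_DOUBLE, dst, 1, MPI_COMM_WORLD);
    } else {                                      // solver (client)
      compute_local_gradient(weights, grad);      // forward + backward on the local batch
      MPI_Send(grad.data(), NUM_PARAMS, MPI_DOUBLE, 0, 0, MPI_COMM_WORLD);
      MPI_Recv(weights.data(), NUM_PARAMS, MPI_DOUBLE, 0, 1,
               MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    }
  }
  MPI_Finalize();
  return 0;
}
```

Parallel Caffe elaborates this basic loop with one server-side thread per client, derived MPI datatypes, and timeouts, as the following sections show.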
mpi
The mpi.h header introduces six functions that wrap the original MPI calls:
```cpp
template <typename Dtype>
int caffe_mpi_send(void *buf, int count, int dest, int tag, MPI_Comm comm);
int caffe_mpi_send(void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm);

template <typename Dtype>
int caffe_mpi_recv(void *buf, int count, int source, int tag, MPI_Comm comm, MPI_Status *status);
int caffe_mpi_recv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Status *status);

template <typename Dtype>
int caffe_mpi_isend(void *buf, int count, int dest, int tag, MPI_Comm comm, MPI_Request *req);
int caffe_mpi_isend(void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *req);
```
The template versions exist so that callers can pass Caffe's Dtype directly, without supplying an explicit MPI_Datatype; the non-template overloads take the datatype as a parameter. In mpi.cpp, the definitions are straightforward:
```cpp
template<>
int caffe_mpi_send<double>(void *buf, int count, int dest, int tag, MPI_Comm comm) {
  return MPI_Send(buf, count, MPI_DOUBLE, dest, tag, comm);
}

int caffe_mpi_send(void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm) {
  return MPI_Send(buf, count, datatype, dest, tag, comm);
}

template<>
int caffe_mpi_recv<double>(void *buf, int count, int dest, int tag, MPI_Comm comm, MPI_Status *status) {
  return MPI_Recv(buf, count, MPI_DOUBLE, dest, tag, comm, status);
}

int caffe_mpi_recv(void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Status *status) {
  return MPI_Recv(buf, count, datatype, dest, tag, comm, status);
}

template <>
int caffe_mpi_isend<double>(void *buf, int count, int dest, int tag, MPI_Comm comm, MPI_Request *req) {
  return MPI_Isend(buf, count, MPI_DOUBLE, dest, tag, comm, req);
}

int caffe_mpi_isend(void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *req) {
  return MPI_Isend(buf, count, datatype, dest, tag, comm, req);
}
```
The template functions are instantiated for float and double; the code above only shows the double versions. MPI_Isend is an asynchronous call, so it does not block the sending thread. No MPI_Irecv wrapper is needed, because the receiver thread is separate from the network thread and only ever receives one message from rank 0.
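The float instantiations are not shown in the excerpt above; by symmetry they should simply substitute MPI_FLOAT for MPI_DOUBLE. This is my reconstruction, not a verbatim copy of the source:

```cpp
// Sketch of the float specializations (assumed, not taken verbatim from parallel caffe).
template<>
int caffe_mpi_send<float>(void *buf, int count, int dest, int tag, MPI_Comm comm) {
  return MPI_Send(buf, count, MPI_FLOAT, dest, tag, comm);
}

template<>
int caffe_mpi_recv<float>(void *buf, int count, int source, int tag,
                          MPI_Comm comm, MPI_Status *status) {
  return MPI_Recv(buf, count, MPI_FLOAT, source, tag, comm, status);
}

template<>
int caffe_mpi_isend<float>(void *buf, int count, int dest, int tag,
                           MPI_Comm comm, MPI_Request *req) {
  return MPI_Isend(buf, count, MPI_FLOAT, dest, tag, comm, req);
}
```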
base_datalayer
In parallel Caffe, base_datalayer adds several data members:
```cpp
int datum_channels_;
int datum_height_;
int datum_width_;
int datum_size_;
Blob<Dtype> data_mean_;
const Dtype* mean_;
int rank;
```
Of this jumble, only rank is self-explanatory: it is the MPI rank of the machine. Data is thus fed to the network in parallel, and each machine holds a replica of the whole network. The datum_* members come from an older Caffe version; the newest Caffe moved that logic into MemoryDataLayer. base_datalayer also defines the corresponding getters:
```cpp
int datum_channels() const { return datum_channels_; }
int datum_height() const { return datum_height_; }
int datum_width() const { return datum_width_; }
int datum_size() const { return datum_size_; }
```
In the BaseDataLayer declaration, parallel Caffe adds functions to distinguish the CPU/GPU and root/test combinations (the declarations seem to follow an older Caffe Forward interface):
```cpp
virtual void Forward_cpu_test(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top);
virtual void Forward_cpu_root(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top, const int source);
virtual void Forward_gpu_test(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top);
virtual void Forward_gpu_root(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top, const int source);
int rank;
```
A rank data member is added here as well, which is puzzling: the parent class already has the same member.
base_data_layer.cpp contains several modifications. First, the MPI rank is obtained in the BaseDataLayer constructor and in LayerSetUp via:

```cpp
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
```
In the Forward_cpu definition, caffe_mpi_recv is called to receive data from rank 0:
```cpp
MPI_Status status;
status.MPI_ERROR = 0;
caffe_mpi_recv<Dtype>((*top)[0]->mutable_cpu_data(), prefetch_data_.count(),
    0, TAG_DATA_OUT, MPI_COMM_WORLD, &status);
DLOG(INFO) << "Recv Dataout status " << status.MPI_ERROR;
if (this->output_labels_) {
  caffe_mpi_recv<Dtype>((*top)[1]->mutable_cpu_data(), prefetch_label_.count(),
      0, TAG_DATA_OUT_IF, MPI_COMM_WORLD, &status);
  DLOG(INFO) << "Recv Dataout if status " << status.MPI_ERROR;
}
```
In Forward_cpu_root, by contrast, the server actively sends the data and labels to the clients:
```cpp
template<typename Dtype>
void BasePrefetchingDataLayer<Dtype>::Forward_cpu_root(
    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top,
    const int source) {
  switch (this->layer_param_.data_param().backend()) {
  case DataParameter_DB_LEVELDB: {
    Forward_cpu_test(bottom, top);
    caffe_mpi_send<Dtype>((*top)[0]->mutable_cpu_data(), prefetch_data_.count(),
        source, TAG_DATA_OUT, MPI_COMM_WORLD);
    if (this->output_labels_) {
      caffe_mpi_send<Dtype>((*top)[1]->mutable_cpu_data(), prefetch_label_.count(),
          source, TAG_DATA_OUT_IF, MPI_COMM_WORLD);
    }
  }
  break;
  case DataParameter_DB_LMDB: {
  }
  break;
  default:
    LOG(FATAL) << "Unknown database backend";
  }
}
```
For GPU mode, the layer first receives the messages on the CPU, then synchronizes the GPU with the CPU.
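A minimal sketch of what that GPU path presumably looks like, based only on the description above (not verbatim parallel-caffe code); it relies on Caffe's SyncedMemory uploading fresh CPU data to the GPU on the next gpu_data() access:

```cpp
// Hedged sketch of the GPU path: receive into the CPU buffer, then let
// SyncedMemory copy the data to the GPU when gpu_data() is touched.
template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::Forward_gpu(
    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
  MPI_Status status;
  status.MPI_ERROR = 0;
  // Receive the prefetched batch from rank 0 into host memory.
  caffe_mpi_recv<Dtype>((*top)[0]->mutable_cpu_data(), prefetch_data_.count(),
                        0, TAG_DATA_OUT, MPI_COMM_WORLD, &status);
  if (this->output_labels_) {
    caffe_mpi_recv<Dtype>((*top)[1]->mutable_cpu_data(), prefetch_label_.count(),
                          0, TAG_DATA_OUT_IF, MPI_COMM_WORLD, &status);
  }
  // Touching gpu_data() makes SyncedMemory upload the fresh CPU data to the GPU.
  (*top)[0]->gpu_data();
  if (this->output_labels_) {
    (*top)[1]->gpu_data();
  }
}
```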
SGDSolver
The SGDSolver declaration gains three new functions:
```cpp
virtual void ComputeUpdateValueServerThread();
virtual void ComputeUpdateValueClientThread(int& mpi_source, int tid);
virtual void GetValue(int &mpi_source, const int tid);
```
But the bigger picture is that there is a parameter server here.
```cpp
typedef struct TPRAMA {
  void* layer;
  int tid;
} tprama;

class lockmutex {
 public:
  lockmutex(pthread_mutex_t* mut) { mutex = mut; lock(); }
  ~lockmutex() { unlock(); }
  void lock() { pthread_mutex_lock(mutex); }
  void unlock() { pthread_mutex_unlock(mutex); }
 private:
  pthread_mutex_t *mutex;
};

class atomInt {
 public:
  atomInt(int val = 0) {
    atomValue = val;
    pthread_rwlock_init(&rwlockAtom, NULL);
  }
  ~atomInt() {
    pthread_rwlock_destroy(&rwlockAtom);
  }
  int getValue() {
    int ret;
    pthread_rwlock_rdlock(&rwlockAtom);
    ret = atomValue;
    pthread_rwlock_unlock(&rwlockAtom);
    return ret;
  }
  int add(int val) {
    int ret;
    pthread_rwlock_wrlock(&rwlockAtom);
    atomValue += val;
    ret = atomValue;
    pthread_rwlock_unlock(&rwlockAtom);
    return ret;
  }
  int sub(int val) {
    int ret;
    pthread_rwlock_wrlock(&rwlockAtom);
    atomValue -= val;
    ret = atomValue;
    pthread_rwlock_unlock(&rwlockAtom);
    return ret;
  }
 private:
  int atomValue;
  pthread_rwlock_t rwlockAtom;
};

#define WAIT_SEC (3)
#define WAIT_USEC (0)
```
It defines an atomInt class. But as far as I'm concerned, doesn't C++11 already provide std::atomic<int> for exactly this?
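For comparison, here is a sketch of the same counter written with the C++11 standard library; std::atomic provides the same atomic read-modify-write semantics without a hand-rolled rwlock wrapper:

```cpp
// Sketch: atomInt replaced by std::atomic<int>; add/sub return the new value,
// matching atomInt::add and atomInt::sub above.
#include <atomic>

std::atomic<int> counter{0};

int add(int val) { return counter.fetch_add(val) + val; }
int sub(int val) { return counter.fetch_sub(val) - val; }
int getValue()   { return counter.load(); }
```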
solver.cpp, where the major modifications live, begins with a batch of mutex, condition variable, and semaphore definitions:
```cpp
std::queue<int> idleQ;
sem_t semQ;              // wait program finish
int taskS1;
int upNum = 0;
int upSum;
void *tempDiff = NULL;   // like float/double
int *flagCC = NULL;
pthread_mutex_t mutexFin;  // =PTHREAD_MUTEX_INITIALIZER; check and wait program finish
pthread_cond_t condFin;    // =PTHREAD_MUTEX_INITIALIZER; check and wait program finish
pthread_mutex_t mutexUp = PTHREAD_MUTEX_INITIALIZER;    // wait update net paramater in server thread
pthread_cond_t condUp = PTHREAD_COND_INITIALIZER;       // wait update net paramater in server thread
pthread_mutex_t mutexCtrl = PTHREAD_MUTEX_INITIALIZER;  // when update net paramaters finished, broadcast to send data to MPI clients
pthread_cond_t condCtrl = PTHREAD_COND_INITIALIZER;
pthread_mutex_t mutexData = PTHREAD_MUTEX_INITIALIZER;  // update data and diff
pthread_mutex_t mutexQ = PTHREAD_MUTEX_INITIALIZER;     // update idleQ
atomInt taskSum, taskS;
```
upNum counts how many nodes have delivered their updates so far, while upSum is the total number of client nodes (the MPI_COMM_WORLD size minus the server).
The definition of GetValue is straightforward: it receives the diff from some client. Because we cannot predict the order in which messages arrive, the first caffe_mpi_recv uses MPI_ANY_SOURCE to accept whichever message comes in. Once that header message (essentially the diff of the last parameter blob) is received, we can identify the sender and store it in mpi_source. The remaining blobs are then received from that source one by one and stored into diff, which is tempDiff[tid]. After all layer diffs of that client have been transferred, flagCC[tid] is set to 1 and upNum is incremented. So flagCC[i] indicates whether tempDiff[i] has been filled, and upNum is the number of clients whose diffs have been received.
```cpp
template <typename Dtype>
void SGDSolver<Dtype>::GetValue(int &mpi_source, const int tid) {
  MPI_Status status;
  Dtype **diff = ((Dtype***)tempDiff)[tid];
  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
  for (int param_id = net_params.size() - 1; param_id >= 0; --param_id) {
    memset(&status, 0, sizeof(status));
    if (param_id == net_params.size() - 1) {
      caffe_mpi_recv<Dtype>(&diff[param_id][0], net_params[param_id]->count(),
          MPI_ANY_SOURCE, TAG_UPDATE_1, MPI_COMM_WORLD, &status);
      mpi_source = status.MPI_SOURCE;
    } else {
      caffe_mpi_recv<Dtype>(&diff[param_id][0], net_params[param_id]->count(),
          mpi_source, TAG_UPDATE, MPI_COMM_WORLD, &status);
    }
  }
  {
    lockmutex lockm(&mutexData);
    flagCC[tid] = 1;
    ++upNum;
    pthread_cond_broadcast(&condUp);
  }
}
```
What were they thinking? Why not use C++11? Rolling everything on raw pthreads kills portability.
ServerUpdate
To update the net parameters in the server thread, the following thread routine is run:
```cpp
template <typename Dtype>
void* ComputeValueThreadServer(void* param) {
  SGDSolver<Dtype>* layer = static_cast<SGDSolver<Dtype>*>(((tprama*) param)->layer);
  //int tid = ((tprama*)param)->tid;
  struct timeval now_time;
  struct timespec wait_time;
  int timeoutret;
  while (true) {
    if (taskSum.getValue() <= 0) {
      LOG(INFO) << "Server out";
      pthread_exit(NULL);
    }
    gettimeofday(&now_time, NULL);
    wait_time.tv_sec = now_time.tv_sec + WAIT_SEC;
    wait_time.tv_nsec = now_time.tv_usec * 1000 + WAIT_USEC;  // nano seconds
    {
      lockmutex lockm(&mutexData);
      while (upNum < upSum) {
        timeoutret = pthread_cond_timedwait(&condUp, &mutexData, &wait_time);
        if (timeoutret == ETIMEDOUT) {
          LOG(INFO) << "time out " << upNum;
          break;
        }
      }
      if (upNum > 0) {
        layer->ComputeValueServer();
        pthread_cond_broadcast(&condCtrl);
      }
    }
  }
}
```
It uses pthread_cond_timedwait, which could be replaced perfectly well by std::condition_variable::wait_for; likewise, pthread_cond_broadcast maps directly onto std::condition_variable::notify_all.
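As an illustration only (a sketch that assumes the surrounding logic stays the same), the server-side wait could be rewritten with C++11 primitives like this:

```cpp
// Sketch of the server wait using C++11; the names mirror the globals above.
#include <chrono>
#include <condition_variable>
#include <mutex>

std::mutex mutexData;
std::condition_variable condUp;
int upNum = 0, upSum = 0;

void server_wait_and_update() {
  std::unique_lock<std::mutex> lock(mutexData);
  // Wait until every client has reported, or give up after WAIT_SEC seconds.
  bool all_arrived = condUp.wait_for(lock, std::chrono::seconds(3),
                                     [] { return upNum >= upSum; });
  if (!all_arrived) {
    // timed out: proceed with whatever diffs have arrived so far
  }
  // ComputeValueServer() would run here, then the client threads would be
  // woken with condCtrl.notify_all() instead of pthread_cond_broadcast().
}
```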
ComputeValueServer wraps the update and test jobs:
```cpp
template <typename Dtype>
void Solver<Dtype>::ComputeValueServer() {
  ComputeUpdateValueServerThread();
  ++itest;
  if (itest % param_.test_interval() == 0)
    TestAll();
  upNum = 0;
}
```
ComputeUpdateValueServerThread has a fairly large body, so I will explain it piece by piece. It is a template member function with the following declaration:
```cpp
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValueServerThread()
```
Its counterpart in stock Caffe is SGDSolver<Dtype>::ComputeUpdateValue, and the body appears identical to that of an older Caffe version, so I will not walk through it again; see section ~\ref{sec:solver}. A few modifications are interesting, though.
```cpp
for (int i = 0; i < upSum; ++i) {
  if (flagCC[i] == 1) {
    Dtype **diff = ((Dtype***)tempDiff)[i];
    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
      caffe_axpy(net_params[param_id]->count(), (Dtype)1,
          &diff[param_id][0], net_params[param_id]->mutable_cpu_diff());
    }
  }
}
```
The listing above gathers all the diffs sent from the distributed nodes. The diffs are stored in tempDiff[upSum][net_size][net_count] and are accumulated into the server's own diff. Once all diffs are gathered, the network update proceeds, including the L1 or L2 regularization. After the server has updated the network, it copies the new weights back into tempDiff, as illustrated below:
```cpp
for (int i = 0; i < upSum; ++i) {
  if (flagCC[i] == 1) {
    //Dtype **data = ((Dtype***)tempData)[i]; //test del tempData20150113
    Dtype **diff = ((Dtype***)tempDiff)[i];
    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
      caffe_copy(net_params[param_id]->count(),
          net_params[param_id]->cpu_data(), &diff[param_id][0]);
    }
  }
}
```
So, in short, tempDiff is the buffer holding the diff/weight messages exchanged between server and clients.
Client Update
As for ComputeValueThreadClient, it appears to drive all the per-client synchronization work and to wait for the messages to be transmitted successfully.
```cpp
template <typename Dtype>
void* ComputeValueThreadClient(void* param) {
  SGDSolver<Dtype>* layer = static_cast<SGDSolver<Dtype>*>(((tprama*)param)->layer);
  int tid = ((tprama*)param)->tid;
  CHECK(layer);
  int flagFin = 0;
  if (taskSum.getValue() <= 0) {
    LOG(INFO) << "client task out";
    pthread_exit(NULL);
  }
  while (true) {
    if (taskS.getValue() < taskS1) break;
    layer->ComputeValueClient(tid);
    sem_post(&semQ);
    pthread_mutex_lock(&mutexFin);
    if (taskSum.sub(1) <= 0) flagFin = 1;
    taskS.sub(1);
    pthread_cond_signal(&condFin);
    pthread_mutex_unlock(&mutexFin);
    if (flagFin) break;
  }
}
```
taskSum records how many iterations remain to be processed (see the server initialization below). Every time layer->ComputeValueClient finishes, taskSum is decremented, and once taskSum reaches 0 the synchronization work is done.
The ComputeValueClient definition calls ComputeUpdateValueClientThread:
```cpp
template <typename Dtype>
void Solver<Dtype>::ComputeValueClient(int tid) {
  int mpi_source;
  ComputeUpdateValueClientThread(mpi_source, tid);
  {
    lockmutex lockm(&mutexData);
    while (upNum != 0) {
      pthread_cond_wait(&condCtrl, &mutexData);
      break;
    }
    flagCC[tid] = 0;
  }
  Dtype **diff = ((Dtype***)tempDiff)[tid];
  caffe_mpi_send(diff[0], 1, mpiTypeDiff, mpi_source, TAG_NET_OUT, MPI_COMM_WORLD);
  pthread_mutex_lock(&mutexQ);
  idleQ.push(mpi_source);
  pthread_mutex_unlock(&mutexQ);
}
```
My understanding is that this function receives the diffs sent from some client (the source) and waits for the server thread to perform the update; after the update, upNum is reset to 0. It then calls caffe_mpi_send to ship the updated weights back to that client (mpi_source), and pushes mpi_source onto idleQ to mark the client as successfully updated. It looks as though flagCC and idleQ encode essentially the same information, which seems redundant. And look at this questionable snippet:
```cpp
while (upNum != 0) {
  pthread_cond_wait(&condCtrl, &mutexData);
  break;
}
```
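The unconditional break means the thread waits at most once and then carries on whether or not upNum has actually returned to 0. If the intent is to block until the server has finished its update, the conventional condition-variable loop would look like this (a sketch, not the original code):

```cpp
// Sketch: keep waiting until the server thread has reset upNum to 0,
// re-checking the predicate after every wakeup to be safe against
// spurious wakeups, instead of breaking after the first one.
pthread_mutex_lock(&mutexData);
while (upNum != 0) {
  pthread_cond_wait(&condCtrl, &mutexData);
}
flagCC[tid] = 0;
pthread_mutex_unlock(&mutexData);
```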
ComputeUpdateValueClientThread(mpi_source, tid) is merely a wrapper around GetValue:
```cpp
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValueClientThread(int& mpi_source, int tid) {
  GetValue(mpi_source, tid);
}
```
So, in short, ComputeValueThreadClient gets the diff from a client, waits for the server thread to update, and then sends back the updated weights.
Solver
Here comes the core of parallel Caffe: void Solver<Dtype>::Solve(const char* resume_file). The leading lines are identical to original Caffe.
```cpp
Caffe::set_phase(Caffe::TRAIN);
LOG(INFO) << "Solving " << net_->name();
PreSolve();
iter_ = 0;
if (resume_file) {
  LOG(INFO) << "Restoring previous solver status from " << resume_file;
  Restore(resume_file);
}
// Remember the initial iter_ value; will be non-zero if we loaded from a
// resume_file above.
const int start_iter = iter_;
// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
vector<Blob<Dtype>*> bottom_vec;
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
```
Depending on the MPI rank, the execution path differs vastly; this is how the server and client roles are differentiated.
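In outline (a sketch with hypothetical helper names, not the actual code), Solve() splits into the two paths covered by the next two subsections:

```cpp
// Sketch of the dispatch inside Solve(); RunServerLoop and RunClientLoop are
// invented names standing in for the server and client code paths below.
#include <mpi.h>

void RunServerLoop() { /* allocate tempDiff/flagCC, spawn threads, drive iterations */ }
void RunClientLoop() { /* receive iter_, train a batch, exchange diffs/weights */ }

void SolveDispatchSketch() {
  int rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  if (rank == 0) {
    RunServerLoop();   // rank 0: parameter server
  } else {
    RunClientLoop();   // all other ranks: workers
  }
}
```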
Server solver
When the rank is 0 (i.e. on the server), some initialization code runs first. upSum is initialized to msize - 1, where msize is the MPI_Comm_size. taskSum is the total number of iterations remaining to be done.
```cpp
pthread_mutex_init(&mutexFin, NULL);
pthread_cond_init(&condFin, NULL);
sem_init(&semQ, 0, idleQ.size());
taskSum.add(param_.max_iter() - iter_);
int msize;
int tNetCount = 0;
MPI_Comm_size(MPI_COMM_WORLD, &msize);
upSum = msize - 1;
taskS.add(taskSum.getValue());
taskS1 = upSum;
```
Next comes the memory allocation:

- flagCC is an array of int with size upSum.
- tempDiff is a three-dimensional array: the first dimension is the number of clients (upSum), the second dimension is the parameter-blob (layer) index, and the third dimension is the index within the blob (the total size per client is tNetCount). Each client's slice of tempDiff is allocated as one contiguous block and then sliced per blob.
- netDataType is an array recording the MPI datatype corresponding to Dtype for every blob.
- displacement is an array recording every blob's starting address relative to the first blob's buffer.
- blocklen is an array recording every blob's element count.
netDataType, displacement, and blocklen are then combined via MPI_Type_struct into the derived MPI datatype mpiTypeDiff, which is registered with MPI_Type_commit.
```cpp
flagCC = new int[upSum];
memset(flagCC, 0, sizeof(int) * upSum);
//tempData = new Dtype**[upSum]; //test del tempData20150113
tempDiff = new Dtype**[upSum];
for (int j = 0; j < net_params.size(); ++j) {
  tNetCount += net_params[j]->count();
}
for (int i = 0; i < upSum; ++i) {
  //((Dtype***)tempData)[i] = new Dtype*[net_params.size()]; //test del tempData20150113
  ((Dtype***)tempDiff)[i] = new Dtype*[net_params.size()];
  ((Dtype***)tempDiff)[i][0] = new Dtype[tNetCount];
  for (int j = 1; j < net_params.size(); ++j) {
    ((Dtype***)tempDiff)[i][j] =
        ((Dtype***)tempDiff)[i][j-1] + net_params[j-1]->count();
  }
}
MPI_Datatype *netDataType = new MPI_Datatype[net_params.size()];
int *blocklen = new int[net_params.size()];
MPI_Aint *displacement = new MPI_Aint[net_params.size()];
Dtype **diff = ((Dtype***)tempDiff)[0];
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
  blocklen[param_id] = net_params[param_id]->count();
  if (typeid(Dtype) == typeid(float))
    netDataType[param_id] = MPI_FLOAT;
  else if (typeid(Dtype) == typeid(double))
    netDataType[param_id] = MPI_DOUBLE;
  else
    LOG(FATAL) << "This datetype is not support!" << typeid(Dtype).name();
  displacement[param_id] = (char*) diff[param_id] - (char*) diff[0];
}
MPI_Type_struct(net_params.size(), blocklen, displacement, netDataType, &mpiTypeDiff);
MPI_Type_commit(&mpiTypeDiff);
delete[] netDataType;
delete[] blocklen;
delete[] displacement;
```
After all the necessary memory allocation, threads are created to do the message gathering and weight updating: one thread runs ComputeValueThreadServer, and a vector of threads runs ComputeValueThreadClient, one per client. The server thread gets tid -1, and the client threads get tids assigned consecutively from 0 to upSum - 1.
```cpp
pthread_t threads;
pthread_t *threadc = new pthread_t[msize-1];
tprama pramas;
pramas.layer = static_cast<void*>(this);
pramas.tid = -1;
CHECK(!pthread_create(&threads, NULL, ComputeValueThreadServer<Dtype>, &pramas))
    << "Pthread(solve) execution failed.";
tprama *pramac = new tprama[msize-1];
for (int i = 0; i < upSum; ++i) {
  pramac[i].layer = static_cast<void*>(this);
  pramac[i].tid = i;
  CHECK(!pthread_create(&threadc[i], NULL, ComputeValueThreadClient<Dtype>, &pramac[i]))
      << "Pthread(solve) execution failed.";
}
```
From here on, the iteration for-loop runs. Each iteration, the server sends a TAG_ITER message to the next available (idle) client to signal a new iteration:
```cpp
int qfront;
for (; iter_ < param_.max_iter(); ++iter_) {
  sem_wait(&semQ);
  pthread_mutex_lock(&mutexQ);
  if (!idleQ.empty()) {
    qfront = idleQ.front();
    idleQ.pop();
    pthread_mutex_unlock(&mutexQ);
    caffe_mpi_send(&iter_, 1, MPI_INT, qfront, TAG_ITER, MPI_COMM_WORLD);
    /*Dtype loss = */net_->ForwardBackwardRoot(bottom_vec, qfront);
  } else {
    pthread_mutex_unlock(&mutexQ);
    LOG(FATAL) << "ERROR! idleQ is empty!";
  }
}
```
After all iterations have finished, the server sends the end-of-iteration message (-1) to every remaining MPI client, then runs the tests; finally the MPI and thread resources are freed.
```cpp
pthread_mutex_lock(&mutexFin);
while (taskSum.getValue() > 0) {
  pthread_cond_wait(&condFin, &mutexFin);
  LOG(INFO) << "TaskSum " << taskSum.getValue();
}
pthread_mutex_unlock(&mutexFin);
pthread_mutex_destroy(&mutexFin);
pthread_cond_destroy(&condFin);
while (!idleQ.empty()) {
  int flagFin = -1;
  caffe_mpi_send(&flagFin, 1, MPI_INT, idleQ.front(), TAG_ITER, MPI_COMM_WORLD);
  idleQ.pop();
}
TestAll();
sleep(WAIT_SEC);
for (int i = 0; i < upSum; ++i) {
  pthread_cancel(threadc[i]);
}
pthread_cancel(threads);
delete[] threadc;
delete[] pramac;
LOG(INFO) << "DESTROY " << (pthread_mutex_destroy(&mutexData));
```
Then the server does the final snapshot and test work.
```cpp
if (param_.snapshot_after_train()) {
  Snapshot();
}
if (param_.display() && iter_ % param_.display() == 0) {
  Dtype loss;
  net_->taskiter = 0;
  net_->ForwardTest(bottom_vec, &loss);
  LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
}
if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
  TestAll();
}
LOG(INFO) << "Optimization Done.";
```
Client solver
The client part is much simpler than the server's. All of its code lives inside a while(true) loop. First it receives iter_ to learn which iteration it is on; if the received iteration is -1, the loop ends, and this is the only way to end it.
```cpp
MPI_Status status;
status.MPI_ERROR = 0;
caffe_mpi_recv(&iter_, 1, MPI_INT, 0, TAG_ITER, MPI_COMM_WORLD, &status);
if (iter_ == -1) break;
```
After the iteration-receive code comes some MPI struct registration. This piece of code is roughly the same as on the server, but the structure is poor: it could be lifted out of the while loop so the derived type is built and committed only once.
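A sketch of that suggested restructuring (not the original code): build and commit the derived MPI datatype once before the iteration loop, and free it once at the end. BuildNetDatatype is a hypothetical helper standing in for the MPI_Type_struct construction shown in the server section.

```cpp
// Hypothetical refactoring: register the derived datatype once, outside the loop.
MPI_Datatype mpiTypeCpuData;
BuildNetDatatype(net_params, &mpiTypeCpuData);   // blocklen/displacement/MPI_Type_struct, as above
MPI_Type_commit(&mpiTypeCpuData);

while (true) {
  MPI_Status status;
  status.MPI_ERROR = 0;
  caffe_mpi_recv(&iter_, 1, MPI_INT, 0, TAG_ITER, MPI_COMM_WORLD, &status);
  if (iter_ == -1) break;
  // ... train a batch, send diffs, receive new weights via mpiTypeCpuData ...
}

MPI_Type_free(&mpiTypeCpuData);                  // release the type once, at the end
```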
Then comes ComputeUpdateValueClient. This function trains a batch and sends the diff to the server, then waits to receive the new weights from the server.
```cpp
ComputeUpdateValueClient();
memset(&status, 0, sizeof(status));
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
  net_params[param_id]->mutable_cpu_data();
}
caffe_mpi_recv(net_params[0]->mutable_cpu_data(), 1,
    mpiTypeCpuData, 0, TAG_NET_OUT, MPI_COMM_WORLD, &status);
if (param_.snapshot() && iter_ > start_iter && iter_ % param_.snapshot() == 0) {
  Snapshot();  //TODO
}
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
  if (param_id == 0) {
    caffe_mpi_send<Dtype>(net_params[param_id]->mutable_cpu_diff(),
        net_params[param_id]->count(), 0, TAG_UPDATE_1, MPI_COMM_WORLD);
  } else {
    caffe_mpi_send<Dtype>(net_params[param_id]->mutable_cpu_diff(),
        net_params[param_id]->count(), 0, TAG_UPDATE, MPI_COMM_WORLD);
  }
}
```
The rest of the code is plain enough: it just outputs the loss. Nothing more to say about it.
error handling
In a distributed setting we cannot assume that everything works as expected; the most disastrous failure is nodes dropping out. So we cannot take it for granted that we will receive upSum diffs from the distributed clients. That is what the flagCC[] array is for: when client tid sends its diff, we store it at tempDiff[tid] and mark flagCC[tid] = 1.
Likewise, while waiting for incoming diffs we cannot expect to hear from every client, which is why pthread_cond_timedwait is used. The update policy is adjusted accordingly, scaling by the number of diffs actually received:
```cpp
caffe_scal(net_params[param_id]->count(), (Dtype)(1.0 / upNum),
    net_params[param_id]->mutable_cpu_diff());
```