Under the Hood: Parallel Caffe

parallel caffe

Parallel Caffe was contributed by Inspur Ltd., who claim a 10.49x speedup on an 8-GPU instance. Parallel Caffe is derived from Caffe, with the following modifications:

  • framework

      • uses MPI for data parallelism

      • each MPI process runs one solver

      • the training code is mostly untouched

      • uses a parameter server (a thread): every solver computes the update for each parameter and uploads it to the parameter server (PS), which computes the new parameters and sends them back to the solvers.

  • classes/files

      • solver/SGDSolver

      • data_layer/base_datalayer (parallel data reading or distribution)

      • net (some interfaces and parameter update optimizations)

      • other (headers, some interfaces, etc.)

Let's dissect all these modifications to get at the essence of the parallelization.

mpi

The mpi.h header introduces six functions that wrap the raw MPI calls:

template <typename Dtype>
int caffe_mpi_send(void *buf, int count, int dest, int tag,
                    MPI_Comm comm);

int caffe_mpi_send(void *buf, int count, MPI_Datatype datatype, int dest, int tag,
                    MPI_Comm comm);

template <typename Dtype>
int caffe_mpi_recv(void *buf, int count,  int source, int tag,
                    MPI_Comm comm, MPI_Status *status);

int caffe_mpi_recv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
                    MPI_Comm comm, MPI_Status *status);

template <typename Dtype>
int caffe_mpi_isend(void *buf, int count, int dest, int tag,
                    MPI_Comm comm, MPI_Request *req);

int caffe_mpi_isend(void *buf, int count, MPI_Datatype datatype, int dest, int tag,
                    MPI_Comm comm, MPI_Request *req);

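The template versions let the caller pass the element type as the template argument Dtype (float or double) and pick the matching MPI datatype internally. The non-template overloads take an explicit MPI_Datatype, which the code uses for MPI_INT and for derived datatypes such as mpiTypeDiff.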

The definitions in mpi.cpp are straightforward:

template<>
int caffe_mpi_send<double>(void *buf, int count,  int dest, int tag,
                    MPI_Comm comm) {
    return MPI_Send(buf, count, MPI_DOUBLE, dest, tag,
                    comm);
}

int caffe_mpi_send(void *buf, int count, MPI_Datatype datatype, int dest, int tag,
                    MPI_Comm comm) {
    return MPI_Send(buf, count, datatype, dest, tag,
                    comm);
}
template<>
int caffe_mpi_recv<double>(void *buf, int count,  int dest, int tag,
                    MPI_Comm comm, MPI_Status *status) {
    return MPI_Recv(buf, count, MPI_DOUBLE, dest, tag,
                    comm, status);
}

int caffe_mpi_recv(void *buf, int count, MPI_Datatype datatype, int dest, int tag,
                    MPI_Comm comm, MPI_Status *status) {
    return MPI_Recv(buf, count, datatype, dest, tag,
                    comm, status);
}
template <>
int caffe_mpi_isend<double>(void *buf, int count, int dest, int tag,
                    MPI_Comm comm, MPI_Request *req) {
    return MPI_Isend(buf, count, MPI_DOUBLE, dest, tag,comm, req);
}

int caffe_mpi_isend(void *buf, int count, MPI_Datatype datatype, int dest, int tag,
                    MPI_Comm comm, MPI_Request *req) {
    return MPI_Isend(buf, count, datatype, dest, tag,comm, req);
}

The template functions are specialized for float and double; the code above shows only the double versions. MPI_Isend is a non-blocking call: it returns without blocking the sending thread, and the request has to be completed later (for example with MPI_Wait or MPI_Test) before the buffer may be reused. There is no MPI_Irecv wrapper, because the receiving thread is separate from the network thread and the receiver would only receive one message from rank 0.
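The per-type specializations could be collapsed into a single template by mapping Dtype to its datatype with a traits class. The sketch below is hypothetical: DatatypeTag stands in for MPI_Datatype so that it compiles without an MPI installation, and in a real build get() would return MPI_FLOAT or MPI_DOUBLE. The same trait could also replace the typeid checks that fill netDataType in Solve.

```cpp
#include <cassert>

// Hypothetical stand-in for MPI_Datatype, so this sketch compiles without MPI.
enum class DatatypeTag { Float, Double };

// The primary template is declared but never defined: using an unsupported
// Dtype fails at compile time instead of at run time.
template <typename Dtype> struct DatatypeOf;

template <> struct DatatypeOf<float> {
  static DatatypeTag get() { return DatatypeTag::Float; }   // real code: MPI_FLOAT
};
template <> struct DatatypeOf<double> {
  static DatatypeTag get() { return DatatypeTag::Double; }  // real code: MPI_DOUBLE
};

// One generic wrapper instead of one explicit specialization per type.
// A real version would forward to MPI_Send(buf, count, <mapped type>, dest, tag, comm);
// this sketch only reports which datatype would be chosen.
template <typename Dtype>
DatatypeTag DatatypeForSend(const Dtype* buf, int count) {
  assert(buf != nullptr && count >= 0);
  return DatatypeOf<Dtype>::get();
}
```

Calling DatatypeForSend with a float buffer yields the float tag and with a double buffer the double tag; adding a new type means adding one trait specialization rather than three function specializations.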

base_datalayer

In Parallel Caffe, base_datalayer gains some data members:

int datum_channels_;
int datum_height_;
int datum_width_;
int datum_size_;
Blob<Dtype> data_mean_;
const Dtype* mean_;
int rank;

Well, amid the mess, the only member whose meaning I know is rank: the rank of the machine. So the data is fed to the network in parallel, and each machine holds a replica of the whole network.

The datum_* members come from an older Caffe version; the newest Caffe moved them into MemoryDataLayer. BaseDataLayer defines getters for them:

int datum_channels() const { return datum_channels_; }
int datum_height() const { return datum_height_; }
int datum_width() const { return datum_width_; }
int datum_size() const { return datum_size_; }

In the BasePrefetchingDataLayer declaration, Parallel Caffe adds functions that distinguish the combinations of CPU/GPU and root/test mode. These functions may be leftovers from an older Caffe version.

virtual void Forward_cpu_test(const vector<Blob<Dtype>*>& bottom,
  vector<Blob<Dtype>*>* top);
virtual void Forward_cpu_root(const vector<Blob<Dtype>*>& bottom,
  vector<Blob<Dtype>*>* top,const int source);
virtual void Forward_gpu_test(const vector<Blob<Dtype>*>& bottom,
  vector<Blob<Dtype>*>* top);
virtual void Forward_gpu_root(const vector<Blob<Dtype>*>& bottom,
  vector<Blob<Dtype>*>* top, const int source);
int rank;

It adds a rank data member as well. But why would it need one when the parent class already has the same member? The derived member simply shadows the parent's.

base_data_layer.cpp has some modifications too. First, the MPI rank is obtained in the BaseDataLayer constructor and in LayerSetUp with:

MPI_Comm_rank (MPI_COMM_WORLD, &rank);

In Forward_cpu, the client calls caffe_mpi_recv to receive the data (and the labels) from rank 0:

MPI_Status status;
status.MPI_ERROR = 0;
// Receive the data batch from rank 0 directly into the top blob.
caffe_mpi_recv<Dtype>((*top)[0]->mutable_cpu_data(), prefetch_data_.count(),
        0, TAG_DATA_OUT, MPI_COMM_WORLD, &status);
DLOG(INFO) << "Recv Dataout status " << status.MPI_ERROR;
if (this->output_labels_) {
    // Labels travel as a separate message with their own tag.
    caffe_mpi_recv<Dtype>((*top)[1]->mutable_cpu_data(), prefetch_label_.count(),
            0, TAG_DATA_OUT_IF, MPI_COMM_WORLD, &status);
    DLOG(INFO) << "Recv Dataout if status " << status.MPI_ERROR;
}

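Forward_cpu_root is the server-side counterpart: the root reads a batch locally and actively sends the data and labels to client source. Note that only the LevelDB backend is handled; the LMDB case is an empty stub.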

template<typename Dtype>
void BasePrefetchingDataLayer<Dtype>::Forward_cpu_root(
        const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top,
        const int source)
{
    switch (this->layer_param_.data_param().backend())
    {
        case DataParameter_DB_LEVELDB:
        {
            Forward_cpu_test(bottom, top);
            caffe_mpi_send<Dtype>((*top)[0]->mutable_cpu_data(), prefetch_data_.count(),
                    source, TAG_DATA_OUT, MPI_COMM_WORLD);
            if (this->output_labels_)
            {
                caffe_mpi_send<Dtype>((*top)[1]->mutable_cpu_data(), prefetch_label_.count(),
                        source, TAG_DATA_OUT_IF, MPI_COMM_WORLD);
            }
        }
            break;
        case DataParameter_DB_LMDB:
        {
        }
            break;
        default:
            LOG(FATAL) << "Unknown database backend";
    }
}

In GPU mode, it first receives all the messages into CPU memory and then synchronizes the GPU copy with the CPU.

SGDSolver

In the SGDSolver declaration, there are three new functions:

virtual void ComputeUpdateValueServerThread();
virtual void ComputeUpdateValueClientThread(int& mpi_source,int tid);
virtual void GetValue(int &mpi_source,const int tid);

But the bigger picture is that there is a parameter server here. The header also defines some threading helpers:

typedef struct TPRAMA{
void* layer;
int tid;
}tprama;

class lockmutex{
public:
        lockmutex(pthread_mutex_t* mut){mutex = mut;lock();};
        ~lockmutex(){unlock();};
        void lock(){pthread_mutex_lock(mutex);};
        void unlock(){pthread_mutex_unlock(mutex);};
private:
        pthread_mutex_t *mutex;
};

class atomInt{
public:
        atomInt(int val=0){
                atomValue=val;
                pthread_rwlock_init(&rwlockAtom,NULL);
        };
        ~atomInt(){
                pthread_rwlock_destroy(&rwlockAtom);
        };
        int getValue(){
                int ret;
                pthread_rwlock_rdlock(&rwlockAtom);
                ret=atomValue;
                pthread_rwlock_unlock(&rwlockAtom);
                return ret;
        }
        int add(int val){
                int ret;
                pthread_rwlock_wrlock(&rwlockAtom);
                atomValue+=val;
                ret=atomValue;
                pthread_rwlock_unlock(&rwlockAtom);
                return ret;
        };
        int sub(int val){
                int ret;
                pthread_rwlock_wrlock(&rwlockAtom);
                atomValue-=val;
                ret=atomValue;
                pthread_rwlock_unlock(&rwlockAtom);
                return ret;
        };
private:
        int atomValue;
        pthread_rwlock_t rwlockAtom;
};

#define WAIT_SEC (3)
#define WAIT_USEC (0)

It defines an atomInt class: an int guarded by a pthread read-write lock. But doesn't C++11 already have std::atomic_int?
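It does: std::atomic_int is a typedef for std::atomic<int>. A minimal sketch of an equivalent class (AtomInt is a hypothetical name) shows that no lock is needed. The only subtlety is that fetch_add and fetch_sub return the value before the change, so add() and sub() adjust the result to return the new value, as atomInt does.

```cpp
#include <atomic>
#include <cassert>

// Hypothetical drop-in replacement for the pthread_rwlock-based atomInt.
class AtomInt {
 public:
  explicit AtomInt(int val = 0) : value_(val) {}
  int getValue() const { return value_.load(); }
  // fetch_add/fetch_sub return the old value; return the new one like atomInt.
  int add(int val) { return value_.fetch_add(val) + val; }
  int sub(int val) { return value_.fetch_sub(val) - val; }

 private:
  std::atomic<int> value_;
};
```

With this, `if (taskSum.sub(1) <= 0)` in the client threads behaves exactly as before, minus the read-write lock.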

solver.cpp, where the major modifications lie, starts with a batch of global mutex, condition variable and semaphore definitions:

std::queue<int> idleQ;
sem_t semQ;//wait program finish
int taskS1;
int upNum=0;
int upSum;
void * tempDiff=NULL;//like float/double
int *flagCC=NULL;
pthread_mutex_t mutexFin;//=PTHREAD_MUTEX_INITIALIZER;//check and wait program finish
pthread_cond_t condFin;//=PTHREAD_MUTEX_INITIALIZER;//check and wait program finish
pthread_mutex_t mutexUp=PTHREAD_MUTEX_INITIALIZER;//wait update net paramater in server thread
pthread_cond_t condUp=PTHREAD_COND_INITIALIZER;//wait update net paramater in server thread
pthread_mutex_t mutexCtrl=PTHREAD_MUTEX_INITIALIZER;//when update net paramaters finished, broadcast to send data to MPI clients
pthread_cond_t condCtrl=PTHREAD_COND_INITIALIZER;
pthread_mutex_t mutexData=PTHREAD_MUTEX_INITIALIZER;//update data and diff
pthread_mutex_t mutexQ=PTHREAD_MUTEX_INITIALIZER;//update idleQ
atomInt taskSum,taskS;

upNum counts how many clients have delivered their diffs in the current round, while upSum is the total number of clients, i.e. the size of MPI_COMM_WORLD minus one for the server.

GetValue is straightforward: it receives the diffs from one client. Since we cannot predict the order in which messages arrive, the first receive uses MPI_ANY_SOURCE. Once that header message (the diff of the last parameter blob, i.e. the top layer) has arrived, status.MPI_SOURCE identifies the sender, which is stored in mpi_source. The remaining diffs are then received blob by blob from that same source into diff, which is tempDiff[tid]. Once all of that client's diffs have arrived, GetValue sets flagCC[tid]=1 and increments upNum. So flagCC[i] indicates whether tempDiff[i] holds fresh diffs, and upNum is the number of clients whose diffs have been received.

template <typename Dtype>
void SGDSolver<Dtype>::GetValue(int &mpi_source,const int tid) 
{
    MPI_Status status;
    Dtype **diff = ((Dtype***)tempDiff)[tid];
    vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
    for (int param_id = net_params.size()-1; param_id >= 0; --param_id) 
    {
        memset(&status,0,sizeof(status));
        if(param_id==net_params.size()-1)
        {
            caffe_mpi_recv<Dtype>(&diff[param_id][0],net_params[param_id]->count(),
                            MPI_ANY_SOURCE,TAG_UPDATE_1,MPI_COMM_WORLD,&status);
            mpi_source=status.MPI_SOURCE;
        }
        else
        {
            caffe_mpi_recv<Dtype>(&diff[param_id][0],net_params[param_id]->count(),
                            mpi_source,TAG_UPDATE,MPI_COMM_WORLD,&status);
        }
    }
    {
            lockmutex lockm(&mutexData);
            flagCC[tid] = 1;
            ++upNum;
            pthread_cond_broadcast(&condUp);
    }
}

What the heck were you thinking? Why not use C++11? You kill the portability, you bastard!
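For what it's worth, here is roughly what the bookkeeping at the end of GetValue could look like with C++11 primitives. This is a hypothetical sketch: UpdateBoard and its methods are my own names, and the globals flagCC, upNum, mutexData and condUp become members of one object. The server thread would wait on cond_up_ exactly as it waits on condUp today.

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <vector>

// Hypothetical C++11 home for flagCC, upNum, mutexData and condUp.
class UpdateBoard {
 public:
  explicit UpdateBoard(int num_clients) : flags_(num_clients, 0), up_num_(0) {}

  // Called by client thread `tid` once all of its diff blobs have arrived.
  void MarkReceived(int tid) {
    {
      std::lock_guard<std::mutex> lock(mutex_);  // replaces lockmutex
      assert(tid >= 0 && tid < static_cast<int>(flags_.size()));
      flags_[tid] = 1;                           // flagCC[tid] = 1
      ++up_num_;                                 // ++upNum
    }
    cond_up_.notify_all();                       // pthread_cond_broadcast(&condUp)
  }

  int UpNum() {
    std::lock_guard<std::mutex> lock(mutex_);
    return up_num_;
  }

  bool IsReceived(int tid) {
    std::lock_guard<std::mutex> lock(mutex_);
    return flags_[tid] == 1;
  }

 private:
  std::vector<int> flags_;           // flagCC
  int up_num_;                       // upNum
  std::mutex mutex_;                 // mutexData
  std::condition_variable cond_up_;  // condUp
};
```

Keeping the flags, the counter and their mutex in one object also makes it obvious which lock protects which state, something the original globals leave to comments.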

Server Update

The server thread, which updates the net parameters, runs the following function:

template <typename Dtype>
void* ComputeValueThreadServer(void* param) {
    SGDSolver<Dtype>* layer = static_cast<SGDSolver<Dtype>*>( ((tprama*) param)->layer);
    //int tid = ((tprama*)param)->tid;
    struct timeval now_time;
    struct timespec wait_time;
    int timeoutret;
    while(true){
        if(taskSum.getValue() <=0){LOG(INFO)<<"Server out"; pthread_exit(NULL); }
        gettimeofday(&now_time,NULL);
        wait_time.tv_sec      =       now_time.tv_sec + WAIT_SEC;
        wait_time.tv_nsec     =       now_time.tv_usec*1000 + WAIT_USEC;//nano seconds
        {
            lockmutex lockm(&mutexData);
            while(upNum < upSum){
                timeoutret=pthread_cond_timedwait(&condUp,&mutexData,&wait_time);
                if(timeoutret==ETIMEDOUT){
                    LOG(INFO)<<"time out " << upNum;
                    break;
                }
            }
            if(upNum>0){
                layer->ComputeValueServer();
                pthread_cond_broadcast(&condCtrl);
            }
        }
    }
}

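Well, it uses pthread_cond_timedwait, which could easily be replaced by std::condition_variable::wait_until (which, like pthread_cond_timedwait, takes an absolute deadline) or wait_for (which takes a relative timeout). Likewise, pthread_cond_broadcast maps directly to std::condition_variable::notify_all.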
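A hypothetical C++11 version of one pass of that loop could look like this. WaitAndApply is my own name; apply stands in for ComputeValueServer, and the reference parameters stand in for the globals mutexData, condUp, condCtrl and upNum. Like the original, the update runs while the mutex is still held.

```cpp
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <mutex>

// Hypothetical C++11 version of one pass of the loop in ComputeValueThreadServer.
// Waits until all `up_sum` clients have reported or `timeout` elapses; if at least
// one client reported, runs `apply` under the lock, resets the counter and wakes
// the client threads waiting for the new weights.
template <typename ApplyFn>
bool WaitAndApply(std::mutex& m, std::condition_variable& cond_up,
                  std::condition_variable& cond_ctrl, int& up_num, int up_sum,
                  std::chrono::milliseconds timeout, ApplyFn apply) {
  std::unique_lock<std::mutex> lock(m);
  // The predicate overload re-checks the condition after every wakeup and
  // simply returns false on timeout, so no ETIMEDOUT handling is needed.
  cond_up.wait_for(lock, timeout, [&] { return up_num >= up_sum; });
  if (up_num == 0) return false;  // timed out with nothing to apply
  apply(up_num);                  // update using the up_num diffs received
  up_num = 0;                     // ComputeValueServer resets upNum the same way
  cond_ctrl.notify_all();         // replaces pthread_cond_broadcast(&condCtrl)
  return true;
}
```

The server thread would then just call WaitAndApply in a loop until taskSum reaches zero; the deadline arithmetic with gettimeofday and tv_nsec disappears.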

ComputeValueServer wraps the parameter update and the periodic test:

template <typename Dtype>
void Solver<Dtype>::ComputeValueServer(){
    ComputeUpdateValueServerThread();
    ++itest;
    if(itest % param_.test_interval() ==0)
        TestAll();
    upNum=0;
}

ComputeUpdateValueServerThread has a pretty big body, so I will go through it bit by bit. It is declared as:

template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValueServerThread()

Its counterpart in the original Caffe is SGDSolver<Dtype>::ComputeUpdateValue, and the code appears to be identical to that of an older Caffe version, so I'm not going to explain it again; see the earlier section on the solver. There are, however, some interesting modifications.

for(int i=0;i<upSum;++i){
    if(flagCC[i]==1){
        Dtype **diff = ((Dtype***)tempDiff)[i];
        for(int param_id = 0; param_id < net_params.size(); ++param_id){
            caffe_axpy(net_params[param_id]->count(),(Dtype)1,
                    &diff[param_id][0],net_params[param_id]->mutable_cpu_diff());
        }
    }
}

The code above gathers the diffs sent by the distributed nodes. The diffs are stored in tempDiff[client][param_id][element], and the diffs of every client whose flagCC entry is set are added to the server's diff. Once all diffs are gathered, the server performs the usual network update with L1 or L2 regularization. After the update, it writes the new network weights back into tempDiff, as shown below:

for(int i=0;i<upSum;++i){
    if(flagCC[i]==1){
        //Dtype **data = ((Dtype***)tempData)[i]; //test del tempData20150113
        Dtype **diff = ((Dtype***)tempDiff)[i];
        for(int param_id = 0; param_id < net_params.size(); ++param_id){
            caffe_copy(net_params[param_id]->count(),
                    net_params[param_id]->cpu_data(),
                    &diff[param_id][0]);
        }
    }
}

So in short, tempDiff is the buffer for the diff/weight messages exchanged between the server and the clients: it holds the diffs on the way in and the updated weights on the way out.
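Stripped of the Caffe types, the gather step is an element-wise sum over the clients whose flag is set. The sketch below is hypothetical: it uses std::vector<float> buffers in place of tempDiff and the server's cpu_diff, and an explicit loop in place of caffe_axpy with alpha = 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of the gather step: every client slot i whose flag is
// set adds its flattened diff buffer into the server's diff, the same effect
// as caffe_axpy(count, 1, diff_i, server_diff) for each blob.
// Returns the number of contributing clients (what the original tracks in upNum).
int AccumulateDiffs(const std::vector<std::vector<float>>& client_diffs,
                    const std::vector<int>& flags,
                    std::vector<float>& server_diff) {
  assert(client_diffs.size() == flags.size());
  int contributors = 0;
  for (std::size_t i = 0; i < client_diffs.size(); ++i) {
    if (flags[i] != 1) continue;  // client i did not report this round
    assert(client_diffs[i].size() == server_diff.size());
    for (std::size_t k = 0; k < server_diff.size(); ++k)
      server_diff[k] += client_diffs[i][k];  // y = 1 * x + y
    ++contributors;
  }
  return contributors;
}
```

For example, with three clients whose diffs are {1, 2}, {10, 20} and {100, 200} and flags {1, 0, 1}, a zeroed server diff ends up as {101, 202}: the client that did not report is simply skipped.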

Client Update

ComputeValueThreadClient, for its part, seems to drive all the synchronization work and wait until the messages have been transmitted successfully.

template <typename Dtype>
void* ComputeValueThreadClient(void* param) {
    SGDSolver<Dtype>* layer = static_cast<SGDSolver<Dtype>*>(((tprama*)param)->layer);
    int tid = ((tprama*)param)->tid;
    CHECK(layer);
    int flagFin = 0;
    if (taskSum.getValue() <= 0) { LOG(INFO) << "client task out"; pthread_exit(NULL); }
    while (true) {
        if (taskS.getValue() < taskS1) break;
        layer->ComputeValueClient(tid);  // receive diffs, wait for the update, send weights back
        sem_post(&semQ);                 // one more client is idle
        pthread_mutex_lock(&mutexFin);
        if (taskSum.sub(1) <= 0) flagFin = 1;
        taskS.sub(1);
        pthread_cond_signal(&condFin);
        pthread_mutex_unlock(&mutexFin);
        if (flagFin) break;
    }
    return NULL;
}

taskSum records how many updates (iterations) remain to be done. Every time layer->ComputeValueClient finishes, taskSum is decremented, and once it reaches 0 the whole job is finished.

ComputeValueClient in turn calls ComputeUpdateValueClientThread:

template <typename Dtype>
void Solver<Dtype>::ComputeValueClient(int tid){
    int mpi_source;
    ComputeUpdateValueClientThread(mpi_source,tid);
    {
        lockmutex lockm(&mutexData);
        while(upNum!=0){
            pthread_cond_wait(&condCtrl,&mutexData);
            break;
        }
        flagCC[tid]=0;
    }
    Dtype **diff = ((Dtype***)tempDiff)[tid];
    caffe_mpi_send(diff[0],1,mpiTypeDiff,mpi_source,TAG_NET_OUT,MPI_COMM_WORLD);
    pthread_mutex_lock(&mutexQ);
    idleQ.push(mpi_source);
    pthread_mutex_unlock(&mutexQ);
}

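My understanding is that this function receives the diffs sent by some client and waits for the server thread to apply the update; after the update, upNum is reset to 0. It then calls caffe_mpi_send to send the updated weights (now stored in tempDiff[tid]) back to that client, whose rank is mpi_source. After the send, mpi_source is pushed onto idleQ to mark the client as updated and ready for the next iteration. flagCC looks partly redundant with idleQ, although flagCC is indexed by thread id while idleQ holds MPI ranks. And look at this piece of shit: the while loop below ends with an unconditional break, so it is really an if. The thread waits at most once and never re-checks upNum, so a spurious wakeup could let it proceed while upNum is still non-zero, i.e. before the server has written the new weights into tempDiff.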

while(upNum!=0)
{
    pthread_cond_wait(&condCtrl,&mutexData);
    break;
}
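A predicate-based wait fixes this, because the predicate is re-evaluated after every wakeup. Here is a hypothetical sketch of the client-side wait with C++11 primitives; the function name and parameters are assumptions that mirror mutexData, condCtrl, upNum and flagCC[tid].

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>

// Hypothetical fix for the wait in ComputeValueClient: block until the server
// thread has applied the update, which it signals by resetting upNum to 0 and
// broadcasting condCtrl. Spurious wakeups just re-check the predicate.
void WaitForServerUpdate(std::mutex& m, std::condition_variable& cond_ctrl,
                         const int& up_num, int& flag_cc_tid) {
  std::unique_lock<std::mutex> lock(m);
  cond_ctrl.wait(lock, [&] { return up_num == 0; });
  flag_cc_tid = 0;  // flagCC[tid] = 0, still under the lock as in the original
}
```

If upNum is already 0 when the thread arrives, wait returns immediately, which matches what the original if-in-disguise does in that case.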

ComputeUpdateValueClientThread(mpi_source, tid) is just a wrapper around GetValue:

template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValueClientThread(int& mpi_source,int tid)
{
        GetValue(mpi_source,tid);
}

So in short, ComputeValueThreadClient simply gets the diff from a client, waits for the server thread to update, and then sends the updated weights back.

Solver

Here comes the core of Parallel Caffe: void Solver<Dtype>::Solve(const char* resume_file). The leading lines are identical to the original Caffe:

Caffe::set_phase(Caffe::TRAIN);
LOG(INFO) << "Solving " << net_->name();
PreSolve();

iter_ = 0;
if (resume_file) {
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);
}
// Remember the initial iter_ value; will be non-zero if we loaded from a
// resume_file above.
const int start_iter = iter_;

// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
vector<Blob<Dtype>*> bottom_vec;
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();

From here on, the execution path differs completely depending on the MPI rank, so we look at the server and client roles separately.

Server solver

When the rank is 0 (i.e. the server), it first runs some initialization code. upSum is initialized to msize-1, where msize is the size of MPI_COMM_WORLD returned by MPI_Comm_size. taskSum is the number of iterations that remain to be done.

pthread_mutex_init(&mutexFin,NULL);
pthread_cond_init(&condFin,NULL);
sem_init(&semQ,0,idleQ.size());
taskSum.add(param_.max_iter()-iter_);
int msize;
int tNetCount=0;
MPI_Comm_size (MPI_COMM_WORLD, &msize);
upSum= msize -1 ;
taskS.add(taskSum.getValue());
taskS1=upSum;

Now it's time for memory allocation:

  • flagCC is an int array of size upSum.

  • tempDiff is a three-dimensional array: the first dimension is the client index (upSum entries), the second is the parameter blob index, and the third is the element index within that blob. During allocation it is treated as a one-dimensional array: each client gets one contiguous block of tNetCount elements, and the per-blob pointers point into that block.

  • netDataType records the MPI datatype of every parameter blob.

  • displacement records every blob's starting byte offset relative to blob 0.

  • blocklen records every blob's element count.

Together, netDataType, displacement and blocklen define an MPI struct datatype, mpiTypeDiff, which is then committed with MPI_Type_commit so that it can be used in communication.

flagCC=new int[upSum];
memset(flagCC,0,sizeof(int)*upSum);
//tempData=new Dtype**[upSum];  //test del tempData20150113
tempDiff=new Dtype**[upSum];
for(int j=0;j<net_params.size();++j)
{
    tNetCount += net_params[j]->count();
}
for(int i=0;i<upSum;++i)
{
    //((Dtype***)tempData)[i]=new Dtype*[net_params.size()]; //test del tempData20150113
    ((Dtype***)tempDiff)[i]=new Dtype*[net_params.size()];

    ((Dtype***)tempDiff)[i][0] = new Dtype[tNetCount];
    for(int j=1;j<net_params.size();++j)
    {
        ((Dtype***)tempDiff)[i][j]= ((Dtype***)tempDiff)[i][j-1]+net_params[j-1]->count();
    }
}

MPI_Datatype *netDataType=new MPI_Datatype[net_params.size()];
int *blocklen = new int[net_params.size()];
MPI_Aint *displacement = new MPI_Aint[net_params.size()];
Dtype **diff = ((Dtype***)tempDiff)[0];
for (int param_id = 0; param_id < net_params.size(); ++param_id)
{
    blocklen[param_id]=net_params[param_id]->count();
    if(typeid(Dtype)==typeid(float))
        netDataType[param_id] = MPI_FLOAT;
    else if(typeid(Dtype)==typeid(double))
        netDataType[param_id] = MPI_DOUBLE;
    else
        LOG(FATAL)<<"This datetype is not support!"<<typeid(Dtype).name();
    displacement[param_id] = (char*) diff[param_id]- (char*) diff[0];
}
MPI_Type_struct(net_params.size(),blocklen,displacement,netDataType,&mpiTypeDiff);
MPI_Type_commit(&mpiTypeDiff);
delete[] netDataType;
delete[] blocklen;
delete[] displacement;
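The layout trick is easier to see without the casts. The hypothetical sketch below (FlatParams is my own name) builds one contiguous buffer per client, per-blob pointers into it, and the byte displacements that the code above passes to MPI_Type_struct, with the blob counts playing the role of blocklen. Two side notes, hedged: MPI_Type_struct was deprecated in MPI-2 and removed in MPI-3.0, so current MPI libraries expect MPI_Type_create_struct with the same arguments; and as far as I can tell, because the server's blobs are contiguous, mpiTypeDiff describes the same bytes as tNetCount consecutive elements.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of the tempDiff layout for one client: a single buffer of
// tNetCount elements, blob j starting after the elements of blobs 0..j-1.
template <typename Dtype>
struct FlatParams {
  std::vector<Dtype> storage;                     // tempDiff[i][0] = new Dtype[tNetCount]
  std::vector<Dtype*> blob;                       // tempDiff[i][j]
  std::vector<std::ptrdiff_t> byte_displacement;  // `displacement` for the struct type

  explicit FlatParams(const std::vector<int>& counts) {
    std::size_t total = 0;
    for (int c : counts) total += static_cast<std::size_t>(c);  // tNetCount
    storage.assign(total, Dtype(0));
    std::size_t offset = 0;
    for (int c : counts) {
      blob.push_back(storage.data() + offset);
      byte_displacement.push_back(static_cast<std::ptrdiff_t>(offset * sizeof(Dtype)));
      offset += static_cast<std::size_t>(c);
    }
  }

  // Copying would leave `blob` pointing into the old buffer, so forbid it.
  FlatParams(const FlatParams&) = delete;
  FlatParams& operator=(const FlatParams&) = delete;
};
```

For blobs of 6, 2 and 4 floats, the buffer holds 12 elements and blob 2 starts 8 elements (32 bytes with 4-byte floats) after blob 0.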

After all the necessary memory allocation, it creates the threads that do the message gathering and weight updating: one thread runs ComputeValueThreadServer, and an array of threads runs ComputeValueThreadClient, one per client. The server thread gets tid -1, and the client threads get tids 0 to upSum-1.

pthread_t threads;
pthread_t *threadc=new pthread_t[msize-1];
tprama pramas;
pramas.layer=static_cast<void*>(this);
pramas.tid=-1;
CHECK(!pthread_create(&threads, NULL, ComputeValueThreadServer<Dtype>,
            &pramas)) << "Pthread(solve) execution failed.";
tprama *pramac = new tprama[msize-1];
for(int i=0;i<upSum;++i)
{
    pramac[i].layer = static_cast<void*>(this);
    pramac[i].tid = i;
    CHECK(!pthread_create(&threadc[i], NULL, ComputeValueThreadClient<Dtype>,
                &pramac[i])) << "Pthread(solve) execution failed.";
}

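Next comes the iteration loop. In each iteration the server waits on semQ until some client is idle, pops that client's rank from idleQ, sends it the iteration number with TAG_ITER, and then calls ForwardBackwardRoot for that client, which presumably feeds it a batch of data through the Forward_*_root functions.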

int qfront;
for (; iter_ < param_.max_iter(); ++iter_) 
{
    sem_wait(&semQ);
    pthread_mutex_lock(&mutexQ);
    if(!idleQ.empty())
    {
        qfront=idleQ.front();
        idleQ.pop();
        pthread_mutex_unlock(&mutexQ);
        caffe_mpi_send(&iter_,1,MPI_INT,qfront,TAG_ITER,MPI_COMM_WORLD);
        /*Dtype loss = */net_->ForwardBackwardRoot(bottom_vec,qfront);
    }
    else
    {
        pthread_mutex_unlock(&mutexQ);
        LOG(FATAL)<<"ERROR! idleQ is empty!";
    }
}
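The idleQ, mutexQ and semQ trio together implement a blocking queue of idle client ranks. A hypothetical C++11 version (IdleQueue is my name, not the project's) folds the semaphore into a condition variable, which also makes the "idleQ is empty" fatal branch unreachable.

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <queue>

// Hypothetical replacement for idleQ + mutexQ + semQ: a blocking queue of idle
// client ranks. Pop() blocks until a rank is available, which is what
// sem_wait(&semQ) followed by idleQ.pop() achieves in the original.
class IdleQueue {
 public:
  // Client-thread side: idleQ.push(mpi_source) plus sem_post(&semQ).
  void Push(int rank) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      ranks_.push(rank);
    }
    cond_.notify_one();
  }

  // Server side: wait for an idle client and take it (FIFO order).
  int Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [this] { return !ranks_.empty(); });
    int rank = ranks_.front();
    ranks_.pop();
    return rank;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cond_;
  std::queue<int> ranks_;
};
```

With it, the dispatch loop reduces to `int rank = idle.Pop();` followed by the TAG_ITER send and ForwardBackwardRoot for that rank, and each client thread calls `idle.Push(mpi_source)` where the original pushes onto idleQ and posts semQ.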

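After the loop, the server waits until taskSum drops to 0, sends the end-of-iteration sentinel (-1) to every client in idleQ, and runs a test. Finally it releases the thread and synchronization resources, cancelling the worker threads after a sleep(WAIT_SEC) grace period.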

pthread_mutex_lock(&mutexFin);
while(taskSum.getValue()>0)
{
    pthread_cond_wait(&condFin,&mutexFin);
    LOG(INFO)<<"TaskSum "<<taskSum.getValue();
}
pthread_mutex_unlock(&mutexFin);
pthread_mutex_destroy(&mutexFin);
pthread_cond_destroy(&condFin);
while(!idleQ.empty())
{
    int flagFin= -1;
    caffe_mpi_send(&flagFin,1,MPI_INT,idleQ.front(),TAG_ITER,MPI_COMM_WORLD);
    idleQ.pop();
}
TestAll();
sleep(WAIT_SEC);
for(int i=0;i<upSum;++i)
{
    pthread_cancel(threadc[i]);
}
pthread_cancel(threads);
delete[] threadc;
delete[] pramac;
LOG(INFO)<<"DESTROY "<< (pthread_mutex_destroy(&mutexData));

Then the server does the final snapshot and test work.

if (param_.snapshot_after_train()) { Snapshot(); }
if (param_.display() && iter_ % param_.display() == 0) {
    Dtype loss;
    net_->taskiter=0;
    net_->ForwardTest(bottom_vec, &loss);
    LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
}
if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
    TestAll();
}
LOG(INFO) << "Optimization Done.";

Client solver

The client part is much simpler than the server's. Everything sits inside a while(true) loop. First, the client receives iter_ to learn which iteration it is on. If the iteration is -1, the loop ends, and that is the only way out of the loop.

MPI_Status status;
status.MPI_ERROR=0;
caffe_mpi_recv(&iter_,1,MPI_INT,0,TAG_ITER,MPI_COMM_WORLD,&status);
if(iter_== -1)break;
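The control protocol is simple enough to sketch without MPI. In the hypothetical version below, recv_iter stands in for the caffe_mpi_recv call on TAG_ITER and train_step for ComputeUpdateValueClient plus the weight exchange; RunClientLoop and kStopIteration are my own names.

```cpp
#include <cassert>
#include <vector>

// The sentinel the server sends to idle clients after the last iteration.
const int kStopIteration = -1;

// Hypothetical sketch of the client-side control loop: receive an iteration
// number from rank 0, run one training step for it, and stop at the sentinel.
// Returns the number of steps performed.
template <typename RecvFn, typename StepFn>
int RunClientLoop(RecvFn recv_iter, StepFn train_step) {
  int steps = 0;
  while (true) {
    const int iter = recv_iter();
    if (iter == kStopIteration) break;  // the only way out of the loop
    train_step(iter);
    ++steps;
  }
  return steps;
}
```

Any message after the sentinel is never read, so the server must send -1 only once a client has no more work; the original guarantees this by sending it after taskSum has dropped to 0.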

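After receiving the iteration number, the client registers an MPI struct datatype (mpiTypeCpuData, used below). The code is roughly the same as on the server, but the structure is shit: the type is rebuilt on every pass through the while loop, although this code could simply be lifted out of the loop.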

Then comes ComputeUpdateValueClient. This function trains on a batch and sends the diff to the server, then waits to receive the new weights from the server.

ComputeUpdateValueClient();
memset(&status,0,sizeof(status));
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
for (int param_id = 0; param_id < net_params.size(); ++param_id)
{
    net_params[param_id]->mutable_cpu_data();
}
caffe_mpi_recv(net_params[0]->mutable_cpu_data(),1,
        mpiTypeCpuData,0,TAG_NET_OUT,MPI_COMM_WORLD,&status);
if (param_.snapshot() && iter_ > start_iter &&
                    iter_ % param_.snapshot() == 0) 
{
    Snapshot();//TODO
}
for (int param_id = 0; param_id < net_params.size(); ++param_id) 
{
    if(param_id==0)
    {
        caffe_mpi_send<Dtype>(net_params[param_id]->mutable_cpu_diff(),net_params[param_id]->count(),0,TAG_UPDATE_1,MPI_COMM_WORLD);
    }
    else
    {
        caffe_mpi_send<Dtype>(net_params[param_id]->mutable_cpu_diff(),net_params[param_id]->count(),0,TAG_UPDATE,MPI_COMM_WORLD);
    }
}

The rest of the code is plain: it just outputs the loss. Nothing to talk about.

error handling

In a distributed setting we cannot assume that everything works as expected. The most disastrous case is a node dropping out, so the server cannot take it for granted that it will receive upSum diffs from the clients. That is what the flagCC[] array is for: it records each client's state. When the diff handled by client thread tid arrives, it is stored in tempDiff[tid] and flagCC[tid] is set to 1.

And since we cannot expect to hear from every client while waiting for incoming diffs, the server waits with pthread_cond_timedwait, and the update policy is adjusted to average over the upNum clients that actually reported instead of over upSum:

caffe_scal(net_params[param_id]->count(), (Dtype)(1.0 / upNum),
            net_params[param_id]->mutable_cpu_diff());
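Written out, and assuming (as the snippet suggests) that caffe_scal is applied to the server diff after the gather loop, the normalization turns the summed gradient into a mean over the clients that actually reported, with $g_i$ the diff of client $i$:

```latex
S = \{\, i : \mathrm{flagCC}[i] = 1 \,\}, \qquad
|S| = \mathrm{upNum} \le \mathrm{upSum}, \qquad
\bar{g} = \frac{1}{\mathrm{upNum}} \sum_{i \in S} g_i .
```

For example, with upSum = 4 and one client timing out, diffs of 0.3, 0.5 and 0.1 average to 0.9 / 3 = 0.3; dividing by upSum instead would give 0.225, silently shrinking the step by a factor of 3/4 whenever a node drops out.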
Published: 2015-04-21 21:31