singa-commits mailing list archives

From wan...@apache.org
Subject [18/19] incubator-singa git commit: SINGA-100 Implement layers using CUDNN for GPU training
Date Wed, 16 Dec 2015 12:11:50 GMT
SINGA-100 Implement layers using CUDNN for GPU training

Replace CopyFrom with Copy in the softmax loss layers: CopyFrom copies both the CPU and the GPU memory, while Copy copies only one of them.
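
For illustration, a minimal sketch of the single-sided copy idea, in plain C++/CUDA and not the exact SINGA API (the real Copy operates on Blob objects in include/singa/utils/math_blob.h):

// Sketch only: copy exactly one side of the memory, never both.
#include <algorithm>
#include <cstddef>
#include <cuda_runtime.h>

template <typename Dtype>
void CopyOneSide(const Dtype* src, Dtype* dst, size_t count, bool on_gpu) {
  if (!on_gpu) {
    // CPU path: host-to-host copy only.
    std::copy(src, src + count, dst);
  } else {
    // GPU path: device-side copy only; cudaMemcpyDefault lets the driver
    // infer the direction (requires unified virtual addressing).
    cudaMemcpy(dst, src, count * sizeof(Dtype), cudaMemcpyDefault);
  }
}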

Rename CudnnLayer to CudnnBase; this rename was made earlier but lost during a rebase.

Fix bugs that caused GPU (CUDA) operations to hang. Possible causes include:
1. Different threads accessing the same GPU memory address; the non-owner thread would get stuck.
2. cudnnCreate getting stuck when the GPU device is also in use by other programs; cuDNN handles are now created once per device and shared (see the sketch after this list).
3. CudnnSoftmaxLossLayer::Setup forgetting to call LossLayer::Setup, which left layer_conf_ uninitialized.
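
For illustration, a minimal sketch of the per-device handle sharing referenced in item 2, simplified from the pattern in include/singa/utils/context.h (the class name and omitted error checking are illustrative, not the actual SINGA code):

// Sketch only: one cuDNN handle per GPU device, created lazily and shared by
// all layers running on that device, so cudnnCreate runs once per device
// instead of once per layer.
#include <vector>
#include <cuda_runtime.h>
#include <cudnn.h>

class CudnnHandleCache {
 public:
  explicit CudnnHandleCache(int num_devices) : handles_(num_devices, nullptr) {}
  ~CudnnHandleCache() {
    for (auto h : handles_)
      if (h != nullptr) cudnnDestroy(h);
  }
  cudnnHandle_t Get(int device_id) {
    if (handles_.at(device_id) == nullptr) {
      cudaSetDevice(device_id);           // bind the handle to this device
      cudnnCreate(&handles_[device_id]);  // create only on first use
    }
    return handles_[device_id];
  }
 private:
  std::vector<cudnnHandle_t> handles_;
};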

TODO
1. Move the cuDNN headers and libs before CUDA when compiling; otherwise there can be link conflicts if multiple cuDNN versions exist on the system.
2. Replace malloc with cudaMallocHost in blob.h/blob.cc, which pins the CPU memory for efficient and stable memory transfers between CPU and GPU (a sketch follows this list).
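
For illustration, a minimal sketch of TODO 2 as it might look, assuming the MallocHost/FreeHost wrappers from include/singa/utils/blob.h gain a pinned-memory switch (the "pinned" flag is hypothetical):

// Sketch only: pinned (page-locked) host memory gives faster and more
// predictable cudaMemcpy transfers between CPU and GPU.
#include <cstdlib>
#include <cuda_runtime.h>

inline void MallocHost(void** ptr, size_t size, bool pinned) {
  if (pinned)
    cudaMallocHost(ptr, size);   // page-locked allocation
  else
    *ptr = std::malloc(size);    // ordinary pageable allocation
}

inline void FreeHost(void* ptr, bool pinned) {
  if (pinned)
    cudaFreeHost(ptr);
  else
    std::free(ptr);
}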

Tested with both single-GPU and multi-GPU setups.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/f3b47c70
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/f3b47c70
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/f3b47c70

Branch: refs/heads/master
Commit: f3b47c706e09315f2eb8eb1d599597093c980962
Parents: 372245f
Author: Wei Wang <wangwei@comp.nus.edu.sg>
Authored: Mon Dec 14 22:27:12 2015 +0800
Committer: Wei Wang <wangwei@comp.nus.edu.sg>
Committed: Wed Dec 16 19:20:44 2015 +0800

----------------------------------------------------------------------
 Makefile.gpu                                    | 16 ++++----
 include/singa/neuralnet/neuron_layer.h          | 20 ++++-----
 include/singa/utils/blob.h                      |  3 ++
 include/singa/utils/context.h                   | 43 ++++++++++++++++++++
 include/singa/utils/math_addr.h                 |  4 +-
 include/singa/utils/math_blob.h                 |  2 +
 src/neuralnet/loss_layer/cudnn_softmaxloss.cc   | 13 +++---
 src/neuralnet/loss_layer/softmax.cc             |  4 +-
 src/neuralnet/neuron_layer/cudnn_activation.cc  |  2 +-
 src/neuralnet/neuron_layer/cudnn_convolution.cc |  6 ++-
 src/neuralnet/neuron_layer/cudnn_lrn.cc         |  2 +-
 src/neuralnet/neuron_layer/cudnn_pooling.cc     |  2 +-
 src/neuralnet/neuron_layer/cudnn_softmax.cc     |  2 +-
 src/server.cc                                   |  1 +
 src/stub.cc                                     |  8 ++--
 src/utils/blob.cc                               | 12 +++---
 src/utils/param.cc                              | 12 +++---
 src/worker.cc                                   | 25 +++++++++++-
 18 files changed, 123 insertions(+), 54 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/Makefile.gpu
----------------------------------------------------------------------
diff --git a/Makefile.gpu b/Makefile.gpu
index 9cbbc42..c35a445 100644
--- a/Makefile.gpu
+++ b/Makefile.gpu
@@ -20,17 +20,17 @@
 
 ###################User Config Varaibles #############################
 # third-party library installation folder
-HOME_DIR := /home/wangwei/local
-
-CUDA_DIR := /usr/local/cuda
+HOME_DIR := /media/hd1/home/wangwei/local
 
 # must config the cudnn folder if using cudnn
-CUDNN_DIR := /home/wangwei/local/cuda
+CUDNN_DIR := $(HOME_DIR)/cudnn
+
+CUDA_DIR := /usr/local/cuda
 
 # Lib folder for system and external libs. You may need to change it.
-LIBRARY_DIRS := $(HOME_DIR)/lib64 $(HOME_DIR)/lib $(HOME_DIR)/local/lib $(CUDA_DIR)/lib64 $(CUDA_DIR)/lib $(CUDNN_DIR)/lib64
+LIBRARY_DIRS := $(HOME_DIR)/lib64 $(HOME_DIR)/lib $(CUDNN_DIR)/lib64 $(CUDA_DIR)/lib64 $(CUDA_DIR)/lib
 # Header folder for system and external libs. You may need to change it.
-INCLUDE_DIRS := $(HOME_DIR)/include ./include $(HOME_DIR)/local/include/zookeeper $(CUDA_DIR)/include $(CUDNN_DIR)/include
+INCLUDE_DIRS := ./include $(HOME_DIR)/include $(CUDNN_DIR)/include $(CUDA_DIR)/include 
 # g++ location, should support c++11, tested with 4.8.1
 CXX := g++
 CUCXX := nvcc
@@ -49,7 +49,7 @@ BUILD_DIR := .libs
 MSHADOW_FLAGS :=-DMSHADOW_USE_CUDA=0 -DMSHADOW_USE_CBLAS=1 -DMSHADOW_USE_MKL=0
 ZK_FLAGS :=-DTHREADED -fpermissive
 CXXFLAGS := -O2 -msse3 -Wall -pthread -fPIC -std=c++11 -Wno-unknown-pragmas \
-	$(MSHADOW_FLAGS) $(ZK_FLAGS)\
+	$(MSHADOW_FLAGS) -DUSE_CUDNN $(ZK_FLAGS)\
 	-funroll-loops $(foreach includedir, $(INCLUDE_DIRS), -I$(includedir))
 CUCXXFLAGS := -DUSE_CUDNN $(MSHADOW_FLAGS) -std=c++11 $(CUDA_ARCH) \
 	$(foreach includedir, $(INCLUDE_DIRS), -I$(includedir))
@@ -103,7 +103,7 @@ singa: $(PROTO_OBJS) $(SINGA_OBJS) $(SINGA_CUDA_OBJS)
 	$(CXX) -shared -o $(BUILD_DIR)/libsinga.so $(SINGA_OBJS)
 	$(CXX) $(SINGA_OBJS) $(SINGA_CUDA_OBJS) src/main.cc -o singa $(CXXFLAGS) $(LDFLAGS)
 	@echo
-	$(CXX) $(SINGA_OBJS) $(SINGA_CUDA_OBJS) src/utils/tool.cc -o singatool $(CXXFLAGS) $(LDFLAGS)
+	$(CXX) $(BUILD_DIR)/libsinga.so src/utils/tool.cc -o singatool $(CXXFLAGS) $(LDFLAGS) -Wl,-unresolved-symbols=ignore-in-shared-libs
 	@echo
 
 loader: proto $(LOADER_OBJS)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/include/singa/neuralnet/neuron_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuron_layer.h b/include/singa/neuralnet/neuron_layer.h
index 17891e1..393eeb4 100644
--- a/include/singa/neuralnet/neuron_layer.h
+++ b/include/singa/neuralnet/neuron_layer.h
@@ -25,6 +25,8 @@
 #include <vector>
 #include "singa/neuralnet/layer.h"
 #include "singa/proto/job.pb.h"
+#include "singa/utils/context.h"
+#include "singa/utils/singleton.h"
 
 #ifdef USE_CUDNN
 #include <cudnn.h>
@@ -274,11 +276,9 @@ class STanhLayer : public NeuronLayer {
 #ifdef USE_CUDNN
 #define CHECK_CUDNN(x) CHECK_EQ(x, CUDNN_STATUS_SUCCESS)
 
-class CudnnLayer : virtual public NeuronLayer {
+class CudnnBase : virtual public NeuronLayer {
  public:
-  ~CudnnLayer() {
-    if (handle_ != nullptr)
-      CHECK_CUDNN(cudnnDestroy(handle_));
+  ~CudnnBase() {
     if (src_desc_ != nullptr)
       CHECK_CUDNN(cudnnDestroyTensorDescriptor(src_desc_));
     if (my_desc_ != nullptr)
@@ -286,9 +286,9 @@ class CudnnLayer : virtual public NeuronLayer {
   }
   void virtual InitCudnn() {
     CHECK(!has_init_cudnn_);
-    CHECK_CUDNN(cudnnCreate(&handle_));
     CHECK_CUDNN(cudnnCreateTensorDescriptor(&src_desc_));
     CHECK_CUDNN(cudnnCreateTensorDescriptor(&my_desc_));
+    handle_ = Singleton<Context>::Instance()->cudnn_handle();
     has_init_cudnn_ = true;
   }
  protected:
@@ -304,7 +304,7 @@ class CudnnLayer : virtual public NeuronLayer {
  * - TANH
  * - RELU
  */
-class CudnnActivationLayer : public ActivationLayer, public CudnnLayer {
+class CudnnActivationLayer : public ActivationLayer, public CudnnBase {
  public:
   void InitCudnn() override;
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
@@ -317,7 +317,7 @@ class CudnnActivationLayer : public ActivationLayer, public CudnnLayer {
 /**
  * Convolution layer implemeneted using cudnn (v3 version backward functions).
  */
-class CudnnConvLayer : public ConvolutionLayer, public CudnnLayer {
+class CudnnConvLayer : public ConvolutionLayer, public CudnnBase {
  public:
   ~CudnnConvLayer();
   void InitCudnn() override;
@@ -334,7 +334,7 @@ class CudnnConvLayer : public ConvolutionLayer, public CudnnLayer {
   size_t workspace_byte_limit_, workspace_count_;
 };
 
-class CudnnLRNLayer : public LRNLayer, public CudnnLayer {
+class CudnnLRNLayer : public LRNLayer, public CudnnBase {
  public:
   ~CudnnLRNLayer();
   void InitCudnn() override;
@@ -348,7 +348,7 @@ class CudnnLRNLayer : public LRNLayer, public CudnnLayer {
 /**
  * Pooling layer implemented using cudnn.
  */
-class CudnnPoolLayer : public PoolingLayer, public CudnnLayer {
+class CudnnPoolLayer : public PoolingLayer, public CudnnBase {
  public:
   ~CudnnPoolLayer();
   void InitCudnn() override;
@@ -362,7 +362,7 @@ class CudnnPoolLayer : public PoolingLayer, public CudnnLayer {
 /**
  * Cudnn Softmax layer.
  */
-class CudnnSoftmaxLayer : public SoftmaxLayer, public CudnnLayer {
+class CudnnSoftmaxLayer : public SoftmaxLayer, public CudnnBase {
  public:
   void InitCudnn() override;
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/include/singa/utils/blob.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/blob.h b/include/singa/utils/blob.h
index 7e1e516..256400e 100644
--- a/include/singa/utils/blob.h
+++ b/include/singa/utils/blob.h
@@ -71,12 +71,15 @@
 
 namespace singa {
 
+// TODO(wangwei) use cudaMallocHost depending on Context::device.
 inline void MallocHost(void** ptr, size_t size) {
   *ptr = malloc(size);
+  // cudaMallocHost(ptr, size);
 }
 
 inline void FreeHost(void* ptr) {
   free(ptr);
+  // cudaFreeHost(ptr);
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/include/singa/utils/context.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/context.h b/include/singa/utils/context.h
index c23b338..5d49dc1 100644
--- a/include/singa/utils/context.h
+++ b/include/singa/utils/context.h
@@ -31,6 +31,11 @@
 
 #ifdef USE_GPU
 #include "singa/utils/cuda_utils.h"
+
+#ifdef USE_CUDNN
+#include <cudnn.h>
+#endif
+
 #endif
 
 namespace singa {
@@ -64,6 +69,13 @@ class Context {
         }
       }
     }
+#ifdef USE_CUDNN
+    for (auto& handle : cudnn_handle_) {
+      if (handle != nullptr)
+        CHECK_EQ(cudnnDestroy(handle), CUDNN_STATUS_SUCCESS);
+      handle = nullptr;
+    }
+#endif
 #endif
     for (auto& entry : rand_generator_) {
       if (entry.second != nullptr) {
@@ -80,6 +92,9 @@ class Context {
 #ifdef USE_GPU
       cublas_handle_.push_back(nullptr);
       curand_generator_.push_back(nullptr);
+#ifdef USE_CUDNN
+      cudnn_handle_.push_back(nullptr);
+#endif
 #endif
     }
   }
@@ -191,6 +206,7 @@ class Context {
    */
   curandGenerator_t curand_generator(const int device_id) {
     CHECK_GE(device_id, 0);
+    CHECK_LT(device_id, cudnn_handle_.size());
     if (curand_generator_.at(device_id) == nullptr) {
       // TODO(wangwei) handle user set seed
       /*
@@ -204,6 +220,28 @@ class Context {
     return curand_generator_[device_id];
   }
 
+#ifdef USE_CUDNN
+  cudnnHandle_t cudnn_handle() {
+    return cudnn_handle(std::this_thread::get_id());
+  }
+
+  cudnnHandle_t cudnn_handle(const std::thread::id thread_id) {
+    return cudnn_handle(device_id(thread_id));
+  }
+
+  cudnnHandle_t cudnn_handle(const int device_id) {
+    CHECK_GE(device_id, 0);
+    CHECK_LT(device_id, cudnn_handle_.size());
+    if (cudnn_handle_.at(device_id) == nullptr) {
+      ActivateDevice(device_id);
+      // LOG(ERROR) << "create cudnn handle for device " << device_id;
+      CHECK_EQ(cudnnCreate(&cudnn_handle_[device_id]), CUDNN_STATUS_SUCCESS);
+    }
+    // LOG(ERROR) << "use cudnn handle from device " << device_id;
+    return cudnn_handle_[device_id];
+  }
+#endif
+
 #endif
 
  protected:
@@ -220,7 +258,12 @@ class Context {
   std::vector<cublasHandle_t> cublas_handle_;
   //!< cublas rand generator indexed by GPU device ID
   std::vector<curandGenerator_t> curand_generator_;
+
+#ifdef USE_CUDNN
+  std::vector<cudnnHandle_t> cudnn_handle_;
 #endif
+#endif
+
 };
 
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/include/singa/utils/math_addr.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/math_addr.h b/include/singa/utils/math_addr.h
index d641251..524e13e 100644
--- a/include/singa/utils/math_addr.h
+++ b/include/singa/utils/math_addr.h
@@ -198,8 +198,8 @@ void gpu_gemv(cublasHandle_t handle, const Dtype * A, const Dtype * B,
 }
 
 template<typename Dtype>
-void gpu_axpy(cublasHandle_t handle, const Dtype * A, const int n,
-    const Dtype alpha, Dtype * B) {
+void gpu_axpy(cublasHandle_t handle, const int n, const Dtype alpha,
+    const Dtype * A, Dtype * B) {
   cublasSaxpy(handle, n, &alpha, A, 1, B, 1);
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/include/singa/utils/math_blob.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/math_blob.h b/include/singa/utils/math_blob.h
index bc38bd4..8807bdb 100644
--- a/include/singa/utils/math_blob.h
+++ b/include/singa/utils/math_blob.h
@@ -344,6 +344,8 @@ void Copy(const Blob<Dtype>& A, Blob<Dtype>* B) {
     std::copy(A.cpu_data(), A.cpu_data() + A.count(), B->mutable_cpu_data());
   } else {
 #ifdef USE_GPU
+  CUDA_CHECK(cudaMemcpy(static_cast<Dtype*>(B->mutable_gpu_data()),
+             A.gpu_data(), sizeof(Dtype) * A.count(), cudaMemcpyDefault));
 #endif
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/src/neuralnet/loss_layer/cudnn_softmaxloss.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/loss_layer/cudnn_softmaxloss.cc b/src/neuralnet/loss_layer/cudnn_softmaxloss.cc
index 1fe228b..0d4ba45 100644
--- a/src/neuralnet/loss_layer/cudnn_softmaxloss.cc
+++ b/src/neuralnet/loss_layer/cudnn_softmaxloss.cc
@@ -27,30 +27,26 @@
 namespace singa {
 void CudnnSoftmaxLossLayer::Setup(const LayerProto& conf,
     const vector<Layer*>& srclayers) {
+  LossLayer::Setup(conf, srclayers);
   softmax_.Setup(conf, vector<Layer*> {srclayers.at(0)});
   data_.Reshape(softmax_.data(this).shape());
   data_.ShareData(softmax_.mutable_data(this), false);
   batchsize_ = data_.shape(0);
   dim_ = data_.count() / batchsize_;
-  LOG(ERROR) << batchsize_ << " " << dim_;
 }
 void CudnnSoftmaxLossLayer::ComputeFeature(int flag,
     const vector<Layer*>& srclayers) {
   softmax_.ComputeFeature(flag, srclayers);
-  // compute loss
   Blob<int> label(batchsize_);
   int *labelptr = label.mutable_cpu_data();
-
   // aux_data: vector<int>, convert vector to int array.
   for (int i = 0; i < batchsize_; ++i) {
     labelptr[i] = srclayers[1]->aux_data(this)[i];
   }
 
   Blob<float> loss(batchsize_);
-
-  const float *prob = data_.gpu_data();
-  singa_gpu_softmaxloss_forward(batchsize_, dim_, prob, label.gpu_data(),
-      loss.mutable_gpu_data());
+  singa_gpu_softmaxloss_forward(batchsize_, dim_, data_.gpu_data(),
+      label.gpu_data(), loss.mutable_gpu_data());
   loss_ += Asum(loss);
   counter_++;
 }
@@ -58,7 +54,8 @@ void CudnnSoftmaxLossLayer::ComputeFeature(int flag,
 void CudnnSoftmaxLossLayer::ComputeGradient(int flag,
     const vector<Layer*>& srclayers) {
   Blob<float>* gsrcblob = srclayers[0]->mutable_grad(this);
-  gsrcblob->CopyFrom(data_);
+  Copy(data_, gsrcblob);
+  // gsrcblob->CopyFrom(data_);
   float* gsrcptr = gsrcblob->mutable_gpu_data();
 
   Blob<int> label(batchsize_);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/src/neuralnet/loss_layer/softmax.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/loss_layer/softmax.cc b/src/neuralnet/loss_layer/softmax.cc
index ce858ac..5a3d257 100644
--- a/src/neuralnet/loss_layer/softmax.cc
+++ b/src/neuralnet/loss_layer/softmax.cc
@@ -24,6 +24,7 @@
 #include <algorithm>
 #include "singa/neuralnet/loss_layer.h"
 #include "mshadow/tensor.h"
+#include "singa/utils/math_blob.h"
 
 namespace singa {
 
@@ -88,7 +89,8 @@ void SoftmaxLossLayer::ComputeGradient(int flag,
     const vector<Layer*>& srclayers) {
   const auto& label = srclayers[1]->aux_data();
   Blob<float>* gsrcblob = srclayers[0]->mutable_grad(this);
-  gsrcblob->CopyFrom(data_);
+  Copy(data_, gsrcblob);
+//  gsrcblob->CopyFrom(data_);
   float* gsrcptr = gsrcblob->mutable_cpu_data();
   for (int n = 0; n < batchsize_; n++) {
     gsrcptr[n*dim_ + static_cast<int>(label[n])] -= 1.0f;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/src/neuralnet/neuron_layer/cudnn_activation.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/cudnn_activation.cc b/src/neuralnet/neuron_layer/cudnn_activation.cc
index 9088d84..5405b53 100644
--- a/src/neuralnet/neuron_layer/cudnn_activation.cc
+++ b/src/neuralnet/neuron_layer/cudnn_activation.cc
@@ -24,7 +24,7 @@
 namespace singa {
 
 void CudnnActivationLayer::InitCudnn() {
-  CudnnLayer::InitCudnn();
+  CudnnBase::InitCudnn();
 
   // TODO(wangwei) make the mode case insensitive
   if (layer_conf_.activation_conf().type() == SIGMOID)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/src/neuralnet/neuron_layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/cudnn_convolution.cc b/src/neuralnet/neuron_layer/cudnn_convolution.cc
index 2cf376c..ca67a7d 100644
--- a/src/neuralnet/neuron_layer/cudnn_convolution.cc
+++ b/src/neuralnet/neuron_layer/cudnn_convolution.cc
@@ -32,7 +32,7 @@ CudnnConvLayer::~CudnnConvLayer() {
 }
 
 void CudnnConvLayer::InitCudnn() {
-  CudnnLayer::InitCudnn();
+  CudnnBase::InitCudnn();
   // convert MB to bytes
   workspace_byte_limit_
     = layer_conf_.convolution_conf().workspace_byte_limit() << 20;
@@ -149,7 +149,6 @@ void CudnnConvLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
         &beta,
         my_desc_,
         data_.mutable_gpu_data()));
-
   if (bias_) {
     beta = 1.f;
     CHECK_CUDNN(cudnnAddTensor(handle_,
@@ -167,6 +166,7 @@ void
 CudnnConvLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
   float alpha = 1.f, beta = 0.f;
   Blob<float> workspace(vector<int>{static_cast<int>(workspace_count_)});
+  // LOG(ERROR) << "backward bias";
   if (bias_) {
     CHECK_CUDNN(cudnnConvolutionBackwardBias(handle_,
           &alpha,
@@ -176,6 +176,7 @@ CudnnConvLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
           bias_desc_,
           bias_->mutable_grad()->mutable_gpu_data()));
   }
+  // LOG(ERROR) << "backward w";
   CHECK_CUDNN(cudnnConvolutionBackwardFilter_v3(handle_,
         &alpha,
         src_desc_,
@@ -189,6 +190,7 @@ CudnnConvLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
         &beta,
         filter_desc_,
         weight_->mutable_grad()->mutable_gpu_data()));
+  // LOG(ERROR) << "backward src";
   if (srclayers[0]->mutable_grad(this) != nullptr) {
     CHECK_CUDNN(cudnnConvolutionBackwardData_v3(handle_,
           &alpha,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/src/neuralnet/neuron_layer/cudnn_lrn.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/cudnn_lrn.cc b/src/neuralnet/neuron_layer/cudnn_lrn.cc
index f976b16..8237b13 100644
--- a/src/neuralnet/neuron_layer/cudnn_lrn.cc
+++ b/src/neuralnet/neuron_layer/cudnn_lrn.cc
@@ -30,7 +30,7 @@ CudnnLRNLayer::~CudnnLRNLayer() {
 
 void CudnnLRNLayer::InitCudnn() {
   mode_ = CUDNN_LRN_CROSS_CHANNEL_DIM1;
-  CudnnLayer::InitCudnn();
+  CudnnBase::InitCudnn();
   CHECK_CUDNN(cudnnCreateLRNDescriptor(&norm_desc_));
   CHECK_CUDNN(cudnnSetLRNDescriptor(norm_desc_,
         lsize_,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/src/neuralnet/neuron_layer/cudnn_pooling.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/cudnn_pooling.cc b/src/neuralnet/neuron_layer/cudnn_pooling.cc
index 217df51..4c4c038 100644
--- a/src/neuralnet/neuron_layer/cudnn_pooling.cc
+++ b/src/neuralnet/neuron_layer/cudnn_pooling.cc
@@ -30,7 +30,7 @@ CudnnPoolLayer::~CudnnPoolLayer() {
 }
 
 void CudnnPoolLayer::InitCudnn() {
-  CudnnLayer::InitCudnn();
+  CudnnBase::InitCudnn();
   CHECK_CUDNN(cudnnCreatePoolingDescriptor(&pool_desc_));
   CHECK_CUDNN(cudnnSetTensor4dDescriptor(src_desc_,
         CUDNN_TENSOR_NCHW,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/src/neuralnet/neuron_layer/cudnn_softmax.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/cudnn_softmax.cc b/src/neuralnet/neuron_layer/cudnn_softmax.cc
index a1a492e..bf5a8d3 100644
--- a/src/neuralnet/neuron_layer/cudnn_softmax.cc
+++ b/src/neuralnet/neuron_layer/cudnn_softmax.cc
@@ -24,7 +24,7 @@
 namespace singa {
 
 void CudnnSoftmaxLayer::InitCudnn() {
-  CudnnLayer::InitCudnn();
+  CudnnBase::InitCudnn();
   CHECK_CUDNN(cudnnSetTensor4dDescriptor(src_desc_,
         CUDNN_TENSOR_NCHW,
         CUDNN_DATA_FLOAT,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/src/server.cc
----------------------------------------------------------------------
diff --git a/src/server.cc b/src/server.cc
index 003d4e3..bd7b5f8 100644
--- a/src/server.cc
+++ b/src/server.cc
@@ -200,6 +200,7 @@ const vector<Msg*> Server::HandleUpdate(Msg **msg) {
     auto param = entry->shares.at(0);
     // extract and aggregate gradients
     param->ParseUpdateMsgs(request);
+    // DLOG(ERROR) << "update param " << param->id() << " @ step " << step;
     updater_->Update(step, param, 1.0f / entry->num_total);
     param->set_version(param->version() + 1);
     // response to all shares of this param

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/src/stub.cc
----------------------------------------------------------------------
diff --git a/src/stub.cc b/src/stub.cc
index 184a1f3..82526de 100644
--- a/src/stub.cc
+++ b/src/stub.cc
@@ -26,11 +26,11 @@
 #include <map>
 #include <thread>
 #include <set>
-#include "mshadow/tensor.h"
 #include "singa/proto/common.pb.h"
 #include "singa/utils/cluster.h"
 #include "singa/utils/common.h"
 #include "singa/utils/tinydir.h"
+#include "singa/utils/math_blob.h"
 
 namespace singa {
 
@@ -242,11 +242,9 @@ const vector<Msg*> Stub::HandleUpdateRequest(ParamEntry *entry, Msg** msg) {
     // average local gradient
     if (entry->num_local > 1) {
       auto it = entry->shares.begin();
-      auto shape = mshadow::Shape1((*it)->size());
-      mshadow::Tensor<mshadow::cpu, 1> sum((*it)->mutable_cpu_grad(), shape);
+      auto sum = it;
       for (++it; it != entry->shares.end(); it++) {
-        mshadow::Tensor<mshadow::cpu, 1> grad((*it)->mutable_cpu_grad(), shape);
-        sum += grad;
+        AXPY(1.0f, (*it)->grad(), (*sum)->mutable_grad());
       }
     }
     int step = (*msg)->trgt_version();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/src/utils/blob.cc
----------------------------------------------------------------------
diff --git a/src/utils/blob.cc b/src/utils/blob.cc
index 735226d..8e30060 100644
--- a/src/utils/blob.cc
+++ b/src/utils/blob.cc
@@ -213,13 +213,13 @@ void Blob<Dtype>::CopyFrom(const Blob& source) {
 }
 
 template <typename Dtype>
-void Blob<Dtype>::CopyFrom(const Blob& source, bool reshape) {
-  if (!std::equal(shape_.begin(), shape_.end(), source.shape_.begin())) {
-    if (reshape) {
-      Reshape(source.shape_);
-    } else {
+void Blob<Dtype>::CopyFrom(const Blob& source, bool shape_check) {
+  LOG(WARNING) << "Better use Copy(const Blob&, Blob*)";
+  CHECK_EQ (source.count(), count()) << " cp between blobs of diff size";
+
+  if (shape_check &&
+      !std::equal(shape_.begin(), shape_.end(), source.shape_.begin())) {
       LOG(FATAL) << "Trying to copy blobs of different sizes.";
-    }
   }
 #ifndef CPU_ONLY
   CUDA_CHECK(cudaMemcpy(static_cast<Dtype*>(data_->mutable_gpu_data()),

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 097fa61..70a969f 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -212,8 +212,8 @@ Msg* Param::GenPutMsg(bool copy, int idx) {
   CHECK_LT(idx, num_slices_);
   Msg* msg = new Msg();
   msg->set_type(kPut);
-  void* ptr = mutable_cpu_data() + slice_offset_[idx];
-  void* p = ptr;
+  const void* ptr = data_.cpu_data() + slice_offset_[idx];
+  const void* p = ptr;
   if (copy) p = nullptr;
   msg->AddFormatFrame("iffp", slice_size_[idx], lr_scale(), wd_scale(), p);
   if (copy) {
@@ -226,7 +226,8 @@ Msg* Param::GenGetMsg(bool copy, int idx) {
   CHECK_LT(idx, num_slices_);
   Msg* msg = new Msg();
   msg->set_type(kGet);
-  msg->AddFormatFrame("ip",  copy, data_.cpu_data() + slice_offset_[idx]);
+  msg->AddFormatFrame("ip",  copy, data_.mutable_cpu_data()
+      + slice_offset_[idx]);
   pending_get_[idx] = true;
   num_pending_requests_++;
   return msg;
@@ -237,10 +238,7 @@ Msg* Param::GenUpdateMsg(bool copy, int idx) {
   Msg* msg = new Msg();
   msg->set_type(kUpdate);
   msg->AddFormatFrame("i", copy);
-  void* ptr = grad_.mutable_cpu_data() + slice_offset_[idx];
-  // to change the head of SyncMem to cpu; otherwise, the updated parameter
-  //   // values would not be synced to gpu (since the head is at gpu).
-  data_.mutable_cpu_data();
+  const void* ptr = grad_.cpu_data() + slice_offset_[idx];
   if (copy) {
     msg->AddFrame(ptr, slice_size_[idx]*sizeof(float));
   } else {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f3b47c70/src/worker.cc
----------------------------------------------------------------------
diff --git a/src/worker.cc b/src/worker.cc
index 333408d..01872d0 100644
--- a/src/worker.cc
+++ b/src/worker.cc
@@ -241,6 +241,9 @@ int Worker::Put(int step, Param* param) {
     LOG(WARNING) << "Null dealer in worker (" << grp_id_ << ", " << id_ << ")";
     return 1;
   }
+  // set Blob head to cpu to avoid calling cudaMemcpy by the stub thread, which
+  // would hang on some machines.
+  param->data().cpu_data();
   Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
   msg->set_trgt(ParamTrgt(param->owner(), 0), step);
   msg->set_type(kPut);
@@ -255,6 +258,10 @@ int Worker::Get(int step, Param* param) {
     LOG(WARNING) << "Null dealer in worker (" << grp_id_ << ", " << id_ << ")";
     return 1;
   }
+  // set Blob head to cpu to avoid calling cudaMemcpy by the stub thread, which
+  // would hang on some machines.
+  param->mutable_data()->mutable_cpu_data();
+
   Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
   msg->set_trgt(ParamTrgt(param->owner(), 0), step);
   msg->set_type(kGet);
@@ -268,6 +275,18 @@ int Worker::Update(int step, Param* param) {
     LOG(WARNING) << "Null dealer in worker (" << grp_id_ << ", " << id_ << ")";
     return 1;
   }
+  // head of data Blob (SyncMem) to cpu, because the stub thread may use
+  // cudaMemcpy copy gradients into msgs. cudaMemcpy hangs when called by the
+  // stub thread on some GPU machines.
+  // TODO(wangwei) fix this issue and remove the following line.
+  // optimize for training with single worker by removing stub and server, and
+  // updating parameters locally inside the worker GPU. Then we do not need to
+  // transfer gradients and parameter values between GPU-CPU.
+  param->grad().cpu_data();
+  // change the head of SyncMem to cpu; otherwise, the updated parameter
+  // values would not be synced to gpu (since the head is at gpu).
+  param->mutable_data()->mutable_cpu_data();
+
   Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
   msg->set_trgt(ParamTrgt(param->owner(), 0), step);
   msg->set_type(kUpdate);
@@ -288,8 +307,10 @@ int Worker::CollectAll(int step, NeuralNet* net) {
 }
 
 int Worker::Collect(int step, Param* param) {
-  while (param->version() <= param->last_version())
+  while (param->version() <= param->last_version()) {
     std::this_thread::sleep_for(std::chrono::milliseconds(kCollectSleepTime));
+    // LOG(ERROR) << "wait  "<< param->id() << " at " << step << " by " <<id_;
+  }
   return 1;
 }
 
@@ -332,6 +353,7 @@ void BPWorker::Forward(int step, Phase phase, NeuralNet* net) {
           Collect(step, p);
         }
       }
+      // LOG(ERROR) << layer->name() << " forward";
       layer->ComputeFeature(phase | kForward, net->srclayers(layer));
       // TODO(wangwei): enable this for model partition
       // send data to other workers
@@ -350,6 +372,7 @@ void BPWorker::Backward(int step, NeuralNet* net) {
       // send data to other workers
       // if (typeid(layer) == typeid(BridgeSrcLayer))
       //   ReceiveBlobs(false, true, layer, net);
+      // LOG(ERROR) << layer->name() << " backward";
       layer->ComputeGradient(kTrain | kBackward, net->srclayers(layer));
       for (Param* p : layer->GetParams())
         Update(step, p);

