singa-commits mailing list archives

From wan...@apache.org
Subject [09/19] incubator-singa git commit: SINGA-100 Implement layers using CUDNN for GPU training
Date Wed, 16 Dec 2015 12:11:41 GMT
SINGA-100 Implement layers using CUDNN for GPU training

tmp commit: set up device in driver.cc; register cudnn layers.

1. Add kernel_sum_by_row() CUDA kernel function
2. Fix neuron_layer.h bug: cudnn.h should be included outside the namespace
3. Register CudnnSoftmaxLoss class (job.proto)
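
For context on item 2, a minimal illustration (not part of this commit) of
why a header like cudnn.h must be included at global scope rather than
inside a namespace:

    // Wrong: every type and function declared by cudnn.h would land in
    // namespace singa (e.g. singa::cudnnHandle_t), clashing with any
    // other file that includes the header at global scope.
    namespace singa {
    #include <cudnn.h>   // do not do this
    }

    // Right: include at global scope, then open the namespace.
    #include <cudnn.h>
    namespace singa { /* ... */ }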


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/8cacd83c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/8cacd83c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/8cacd83c

Branch: refs/heads/master
Commit: 8cacd83ceb5b22df473fb96fcf2e371e470e6d43
Parents: 6e56334
Author: seaokcs <seaokcs@gmail.com>
Authored: Mon Nov 30 14:48:15 2015 +0800
Committer: Wei Wang <wangwei@comp.nus.edu.sg>
Committed: Fri Dec 11 11:48:23 2015 +0800

----------------------------------------------------------------------
 include/singa/neuralnet/neuron_layer.h      | 12 ++++++---
 include/singa/utils/blob.h                  |  5 ++++
 include/singa/utils/cuda_utils.h            | 14 -----------
 include/singa/utils/math_addr.h             |  2 +-
 include/singa/utils/math_blob.h             |  8 +++---
 include/singa/utils/math_kernel.h           |  5 +++-
 src/driver.cc                               | 23 +++++++++++++++--
 src/neuralnet/neuron_layer/inner_product.cc | 32 +++++-------------------
 src/proto/job.proto                         |  9 +++++++
 src/utils/math_kernel.cu                    | 30 +++++++++++++++++++---
 src/utils/param.cc                          |  4 +++
 src/worker.cc                               |  6 +++++
 12 files changed, 95 insertions(+), 55 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8cacd83c/include/singa/neuralnet/neuron_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuron_layer.h b/include/singa/neuralnet/neuron_layer.h
index 097b209..830f731 100644
--- a/include/singa/neuralnet/neuron_layer.h
+++ b/include/singa/neuralnet/neuron_layer.h
@@ -25,6 +25,11 @@
 #include <vector>
 #include "singa/neuralnet/layer.h"
 #include "singa/proto/job.pb.h"
+
+#ifdef USE_CUDNN
+#include <cudnn.h>
+#endif
+
 namespace singa {
 
 /* Activation layer applies the following activations,
@@ -249,7 +254,6 @@ class STanhLayer : public NeuronLayer {
 
 /*************** Layers implemented using cudnn v3 ***************/
 #ifdef USE_CUDNN
-#include <cudnn.h>
 #define CHECK_CUDNN(x) CHECK_EQ(x, CUDNN_STATUS_SUCCESS)
 
 class CudnnLayer : virtual public NeuronLayer {
@@ -273,9 +277,9 @@ class CudnnLayer : virtual public NeuronLayer {
 /**
  * Activation layer implemented using cudnn v3.
  * Activation methods including
- * - "sigmoid"
- * - "tanh"
- * - "relu"
+ * - SIGMOID
+ * - TANH
+ * - RELU
  */
 class CudnnActivationLayer : public ActivationLayer, public CudnnLayer {
  public:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8cacd83c/include/singa/utils/blob.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/blob.h b/include/singa/utils/blob.h
index e7bf995..bb63208 100644
--- a/include/singa/utils/blob.h
+++ b/include/singa/utils/blob.h
@@ -294,6 +294,11 @@ class Blob {
   inline bool transpose() const {
     return transpose_;
   }
+  inline const Blob<Dtype> T() const {
+    Blob<Dtype> ret(*this);
+    ret.transpose_ = !transpose_;
+    return ret;
+  }
 
  protected:
   std::shared_ptr<SyncedMemory> data_ = nullptr;
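
For illustration (not part of the commit), a minimal usage sketch of the new
Blob<Dtype>::T() accessor, assuming the semantics shown above: it returns a
shallow copy that shares the same SyncedMemory, with only the transpose flag
flipped, so it is cheap to pass into GEMM-style helpers such as MMDot:

    Blob<float> W;
    W.Reshape(std::vector<int>{4, 3});  // hypothetical 4x3 weight blob
    Blob<float> Wt = W.T();             // shares W's memory; transpose_ flipped
    // Wt.transpose() == !W.transpose(); no elements are copied.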

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8cacd83c/include/singa/utils/cuda_utils.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/cuda_utils.h b/include/singa/utils/cuda_utils.h
index b27a6bb..0b3e0a5 100644
--- a/include/singa/utils/cuda_utils.h
+++ b/include/singa/utils/cuda_utils.h
@@ -74,18 +74,4 @@
     CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
   } while (0)
 
-#define CUBLAS_CHECK(condition) \
-  do { \
-    cublasStatus_t status = condition; \
-    CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << " " \
-      << caffe::cublasGetErrorString(status); \
-  } while (0)
-
-#define CURAND_CHECK(condition) \
-  do { \
-    curandStatus_t status = condition; \
-    CHECK_EQ(status, CURAND_STATUS_SUCCESS) << " " \
-      << caffe::curandGetErrorString(status); \
-  } while (0)
-
 #endif  // SINGA_UTILS_CUDA_UTILS_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8cacd83c/include/singa/utils/math_addr.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/math_addr.h b/include/singa/utils/math_addr.h
index 3b0eefd..7ab884e 100644
--- a/include/singa/utils/math_addr.h
+++ b/include/singa/utils/math_addr.h
@@ -231,7 +231,7 @@ void gpu_e_f(const int n, const Dtype * A, Dtype * B) {
 }
 
 template<typename Op, typename Dtype>
-void gpu_e_f(const int n, const Dtype * A, const Dtype * B, const Dtype * C) {
+void gpu_e_f(const int n, const Dtype * A, const Dtype * B, Dtype * C) {
   Op::CudaMap(A, B, C, n);
 }
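
The one-character change above matters: C is the output of the element-wise
map, so it must be writable. A hypothetical Op satisfying the gpu_e_f
contract (illustrative, not from the repository):

    // Element-wise Op whose CudaMap writes C[i] = A[i] + B[i] on the GPU.
    // With the old `const Dtype * C` parameter, a kernel writing through
    // C could not compile.
    struct AddOp {
      static void CudaMap(const float* A, const float* B, float* C, int n) {
        // launch an n-element CUDA kernel here (omitted)
      }
    };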
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8cacd83c/include/singa/utils/math_blob.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/math_blob.h b/include/singa/utils/math_blob.h
index 97d5cf7..586554e 100644
--- a/include/singa/utils/math_blob.h
+++ b/include/singa/utils/math_blob.h
@@ -466,7 +466,8 @@ void MVAddCol(Dtype alpha, Dtype beta, const Blob<Dtype> & A, Blob<Dtype> * B) {
           B->mutable_cpu_data());
     } else {
 #ifdef USE_GPU
-      singa_gpu_add_vec_row(B->gpu_data(), A.gpu_data(), A.gpu_data(), m, n, n);
+      singa_gpu_add_vec_row(A.gpu_data(), B->gpu_data(), B->mutable_gpu_data(),
+          m, n, n);
 #endif  // USE_GPU
     }
   }
@@ -504,7 +505,8 @@ void MVAddRow(Dtype alpha, Dtype beta, const Blob<Dtype> & A, Blob<Dtype> * B) {
           false, false, B->mutable_cpu_data());
     } else {
 #ifdef USE_GPU
-      singa_gpu_add_vec_row(B->gpu_data(), A.gpu_data(), A.gpu_data(), m, n, n);
+      singa_gpu_add_vec_row(A.gpu_data(), B->gpu_data(), B->mutable_gpu_data(),
+          m, n, n);
 #endif  // USE_GPU
     }
   }
@@ -583,7 +585,7 @@ void MVSumRow(Dtype alpha, Dtype beta, const Blob<Dtype> & A, Blob<Dtype> * B) {
       false, B->mutable_cpu_data());
   } else {
 #ifdef USE_GPU
-    singa_gpu_sum_row(A.gpu_data(), B->gpu_data(), m, n, n);
+    singa_gpu_sum_by_row(A.gpu_data(), B->mutable_gpu_data(), m, n, n);
     // gpu part (TODO check transpose case)
 #endif  // USE_GPU
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8cacd83c/include/singa/utils/math_kernel.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/math_kernel.h b/include/singa/utils/math_kernel.h
index d763283..37a9356 100644
--- a/include/singa/utils/math_kernel.h
+++ b/include/singa/utils/math_kernel.h
@@ -26,7 +26,10 @@ namespace singa {
 extern "C" {
   void singa_gpu_sum_vec(float *data, float *sum , int n);
 
-  void singa_gpu_sum_col(const float *src_mat_data, float *dst_vec_data,
+  void singa_gpu_sum_by_col(const float *src_mat_data, float *dst_vec_data,
+    int rows, int cols, int stride);
+
+  void singa_gpu_sum_by_row(const float *src_mat_data, float *dst_vec_data,
     int rows, int cols, int stride);
 
   void singa_gpu_add_vec_row(const float *src_vec_data,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8cacd83c/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index 8a48c30..1ae6d9f 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -29,6 +29,8 @@
 #include "singa/utils/common.h"
 #include "singa/utils/tinydir.h"
 #include "singa/utils/cluster.h"
+#include "singa/utils/context.h"
+#include "singa/proto/job.pb.h"
 #include "singa/server.h"
 #include "singa/stub.h"
 #include "singa/worker.h"
@@ -75,7 +77,16 @@ void Driver::Init(int argc, char **argv) {
   RegisterLayer<CConvolutionLayer, int>(kCConvolution);
   RegisterLayer<CPoolingLayer, int>(kCPooling);
   RegisterLayer<ConcateLayer, int>(kConcate);
-  RegisterLayer<DropoutLayer, int>(kDropout);
+
+#ifdef USE_CUDNN
+  RegisterLayer<CudnnActivationLayer, int>(kCudnnActivation);
+  RegisterLayer<CudnnConvLayer, int>(kCudnnConv);
+  RegisterLayer<CudnnPoolLayer, int>(kCudnnPool);
+  RegisterLayer<CudnnLRNLayer, int>(kCudnnLRN);
+  RegisterLayer<CudnnSoftmaxLayer, int>(kCudnnSoftmax);
+  RegisterLayer<CudnnSoftmaxLossLayer, int>(kCudnnSoftmaxLoss);
+#endif
+
   RegisterLayer<EuclideanLossLayer, int>(kEuclideanLoss);
   RegisterLayer<InnerProductLayer, int>(kInnerProduct);
   RegisterLayer<LabelLayer, int>(kLabel);
@@ -203,8 +214,16 @@ void Driver::Train(const JobProto& job_conf) {
   vector<std::thread> threads;
   for (auto server : servers)
     threads.push_back(std::thread(&Server::Run, server));
-  for (auto worker : workers)
+  int gpu = 0;
+  auto context = Singleton<Context>::Instance();
+  CHECK_LE(workers.size(), job_conf.gpu_size());
+  for (auto worker : workers) {
     threads.push_back(std::thread(&Worker::Run, worker));
+    if (gpu < job_conf.gpu_size()) {
+      int device_id = job_conf.gpu(gpu++);
+      context->SetupDevice(threads.back().get_id(), device_id);
+    }
+  }
   if (grp_size > 1 || nserver_grps > 0) {
     int nservers_per_grp = cluster->nservers_per_group();
     int lcm = LeastCommonMultiple(nservers_per_grp, nserver_grps);
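
The loop above pairs each worker thread with a device ID from job_conf at
spawn time; the worker later looks the ID up by its own thread ID (see the
worker.cc hunk below). A condensed sketch of the two sides, using only the
Context calls that appear in this commit:

    // Driver side: record the mapping when the thread is created.
    std::thread t(&Worker::Run, worker);
    context->SetupDevice(t.get_id(), job_conf.gpu(gpu++));

    // Worker side, inside Worker::Run(): recover the mapping.
    int device = context->device_id(std::this_thread::get_id());
    if (device >= 0)
      context->ActivateDevice(device);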

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8cacd83c/src/neuralnet/neuron_layer/inner_product.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/inner_product.cc b/src/neuralnet/neuron_layer/inner_product.cc
index 25b71df..3b18cd7 100644
--- a/src/neuralnet/neuron_layer/inner_product.cc
+++ b/src/neuralnet/neuron_layer/inner_product.cc
@@ -23,6 +23,7 @@
 
 #include <glog/logging.h>
 #include "singa/utils/singleton.h"
+#include "singa/utils/math_blob.h"
 
 namespace singa {
 
@@ -57,38 +58,17 @@ void InnerProductLayer::Setup(const LayerProto& conf,
 
 void InnerProductLayer::ComputeFeature(int flag,
     const vector<Layer*>& srclayers) {
-  auto data = Tensor2(&data_);
-  auto src = Tensor2(srclayers[0]->mutable_data(this));
-  auto weight = Tensor2(weight_->mutable_data());
-  auto bias = Tensor1(bias_->mutable_data());
-  if (transpose_)
-    data = dot(src, weight);
-  else
-    data = dot(src, weight.T());
-  // repmat: repeat bias vector into batchsize rows
-  data += expr::repmat(bias, batchsize_);
+  MMDot(srclayers[0]->data(this), weight_->data().T(), &data_);
+  MVAddRow(bias_->data(), &data_);
 }
 
 void InnerProductLayer::ComputeGradient(int flag,
     const vector<Layer*>& srclayers) {
-  auto src = Tensor2(srclayers[0]->mutable_data(this));
-  auto grad = Tensor2(&grad_);
-  auto weight = Tensor2(weight_->mutable_data());
-  auto gweight = Tensor2(weight_->mutable_grad());
-  auto gbias = Tensor1(bias_->mutable_grad());
 
-  gbias = expr::sum_rows(grad);
-  if (transpose_)
-    gweight = dot(src.T(), grad);
-  else
-    gweight = dot(grad.T(), src);
+  MVSumRow(1.0f, 0.0f, grad_, bias_->mutable_grad());
+  MMDot(grad_.T(), srclayers[0]->data(this), weight_->mutable_grad());
   if (srclayers[0]->mutable_grad(this) != nullptr) {
-    auto gsrc = Tensor2(srclayers[0]->mutable_grad(this));
-    if (transpose_)
-      gsrc = dot(grad, weight.T());
-    else
-      gsrc = dot(grad, weight);
+    MMDot(grad_, weight_->data(), srclayers[0]->mutable_grad(this));
   }
 }
-
 }  // namespace singa
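
For reference, the refactored backward pass computes the standard
inner-product gradients; with x the source feature (batchsize x vdim),
W the weight (hdim x vdim), and dy = grad_ (batchsize x hdim):

    // db = sum of the batch rows of dy  -> MVSumRow(1.0f, 0.0f, grad_, db)
    // dW = dy^T * x                     -> MMDot(grad_.T(), x, dW)
    // dx = dy * W                       -> MMDot(grad_, W, dx)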

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8cacd83c/src/proto/job.proto
----------------------------------------------------------------------
diff --git a/src/proto/job.proto b/src/proto/job.proto
index 12f6929..88caa44 100644
--- a/src/proto/job.proto
+++ b/src/proto/job.proto
@@ -52,6 +52,9 @@ message JobProto {
   required int32 train_steps = 16;
   // frequency of displaying training info
   optional int32 disp_freq = 17 [default = 0];
+  // GPU device IDs to use; if fewer IDs are given than workers per process,
+  // some workers run on GPU and the rest run on CPU.
+  repeated int32 gpu = 18;
 
   // frequency of test, e.g., do test every 100 training steps
   optional int32 test_freq = 20 [default = 0];
@@ -603,6 +606,12 @@ enum LayerType {
   kConvolution = 1;
   kCConvolution = 27;
   kCPooling = 28;
+  kCudnnConv = 50;
+  kCudnnPool = 51;
+  kCudnnLRN = 52;
+  kCudnnSoftmax = 53;
+  kCudnnActivation = 54;
+  kCudnnSoftmaxLoss = 55;
   kDropout = 4;
   kInnerProduct = 5;
   kLRN = 6;
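
As an illustration (not from the commit), the new repeated gpu field can be
filled through the generated protobuf API, one device ID per worker in
order; workers beyond gpu_size() fall back to CPU:

    JobProto job_conf;
    job_conf.add_gpu(0);   // first worker -> GPU 0
    job_conf.add_gpu(2);   // second worker -> GPU 2
    // job_conf.gpu_size() == 2; a third worker would run on CPU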

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8cacd83c/src/utils/math_kernel.cu
----------------------------------------------------------------------
diff --git a/src/utils/math_kernel.cu b/src/utils/math_kernel.cu
index 4dd91e0..12501fd 100644
--- a/src/utils/math_kernel.cu
+++ b/src/utils/math_kernel.cu
@@ -63,7 +63,7 @@ void kernel_sum_vec(float *data, float *sum , int n) {
 }
 
 __global__
-void kernel_sum_col(const float *src_mat_data,
+void kernel_sum_by_col(const float *src_mat_data,
     float *dst_vec_data, int rows, int cols, int stride) {
   int j = blockIdx.x;
   int THREADS = blockDim.x;
@@ -98,6 +98,19 @@ void kernel_sum_col(const float *src_mat_data,
 }
 
 __global__
+void kernel_sum_by_row(const float *src_mat_data,
+    float *dst_vec_data, int rows, int cols, int stride) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int num_threads = blockDim.x * gridDim.x;
+  for (; index < rows; index += num_threads) {
+    dst_vec_data[index] = 0.0f;
+    for (int k = 0; k < cols; k++) {
+      dst_vec_data[index] += src_mat_data[index * stride + k];
+    }
+  }
+}
+
+__global__
 void kernel_add_vec_row(const float *src_vec_data, const float *src_mat_data,
     float* des_mat_data, int rows, int cols, int stride) {
   int i = blockIdx.x * blockDim.x + threadIdx.x;
@@ -143,7 +156,7 @@ void kernel_sigmoid_grad(const float *src_data, float *des_data, int n) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   int num_threads = blockDim.x * gridDim.x;
   for (; index < n; index += num_threads) {
-  des_data[index] = src_data[index] * (1.0f - src_data[index]);
+    des_data[index] = src_data[index] * (1.0f - src_data[index]);
   }
 }
 
@@ -288,12 +301,21 @@ void singa_gpu_sum_vec(float *data, float *sum , int n) {
   kernel_sum_vec<<<num_blocks, threads_per_block>>>(data, sum, n);
 }
 
-void singa_gpu_sum_col(const float *src_mat_data, float *dst_vec_data,
+void singa_gpu_sum_by_col(const float *src_mat_data, float *dst_vec_data,
     int rows, int cols, int stride) {
   int threads_per_block = rows > CU1DBLOCK ? CU1DBLOCK : rows;
   int num_blocks = cols;
 
-  kernel_sum_col<<<num_blocks, threads_per_block>>>(src_mat_data, dst_vec_data,
+  kernel_sum_by_col<<<num_blocks, threads_per_block>>>(src_mat_data, dst_vec_data,
+      rows, cols, stride);
+}
+
+void singa_gpu_sum_by_row(const float *src_mat_data, float *dst_vec_data,
+    int rows, int cols, int stride) {
+  int threads_per_block = cols > CU1DBLOCK ? CU1DBLOCK : cols;
+  int num_blocks = rows;
+
+  kernel_sum_by_row<<<num_blocks, threads_per_block>>>(src_mat_data, dst_vec_data,
       rows, cols, stride);
 }
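
A CPU reference for the new row-sum kernel (illustrative only, not part of
the commit): each output element i holds the sum of row i of a rows x cols
matrix laid out with the given stride. Note that, unlike kernel_sum_by_col,
which reduces within a block, each GPU thread here sums one full row
serially:

    void cpu_sum_by_row(const float* src, float* dst,
                        int rows, int cols, int stride) {
      for (int i = 0; i < rows; ++i) {
        dst[i] = 0.0f;
        for (int k = 0; k < cols; ++k)
          dst[i] += src[i * stride + k];
      }
    }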
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8cacd83c/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 2ccc5a8..54fe2aa 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -237,11 +237,15 @@ Msg* Param::GenUpdateMsg(bool copy, int idx) {
   msg->set_type(kUpdate);
   msg->AddFormatFrame("i", copy);
   void* ptr = grad_.mutable_cpu_data() + slice_offset_[idx];
+  // Move the head of SyncedMemory to CPU; otherwise the updated parameter
+  // values would not be synced to the GPU (since the head is at GPU).
+  data_->mutable_cpu_data();
   if (copy) {
     msg->AddFrame(ptr, slice_size_[idx]*sizeof(float));
   } else {
     msg->AddFormatFrame("p", ptr);  // to share values of grad blob
   }
+
   pending_update_[idx] = true;
   num_pending_requests_++;
   return msg;
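
The new comment relies on the head-tracking design of SyncedMemory; a
simplified sketch of the assumed behavior (hypothetical, in the style of
Caffe's SyncedMemory):

    // SyncedMemory remembers which side holds the freshest copy.
    enum SyncedHead { UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED };
    // mutable_cpu_data(): sync GPU->CPU if head is at GPU, then mark
    //   HEAD_AT_CPU, so a later gpu_data() pushes CPU updates to the GPU.
    // gpu_data(): sync CPU->GPU if head is at CPU (read-only view).
    // Hence calling data_->mutable_cpu_data() before the CPU-side update
    // guarantees the new values reach the GPU afterwards.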

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8cacd83c/src/worker.cc
----------------------------------------------------------------------
diff --git a/src/worker.cc b/src/worker.cc
index 4756514..8c1d950 100644
--- a/src/worker.cc
+++ b/src/worker.cc
@@ -28,6 +28,7 @@
 #include "singa/utils/cluster.h"
 #include "singa/utils/factory.h"
 #include "singa/utils/singleton.h"
+#include "singa/utils/context.h"
 
 namespace singa {
 
@@ -61,6 +62,11 @@ Worker::~Worker() {
 
 void Worker::Run() {
   LOG(ERROR) << "Worker (group = " << grp_id_ << ", id = " << id_ << ") start";
+  // setup gpu device
+  auto context = Singleton<Context>::Instance();
+  int device = context->device_id(std::this_thread::get_id());
+  if (device >= 0)
+    context->ActivateDevice(device);
   auto cluster = Cluster::Get();
   int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group();
   CHECK(cluster->runtime()->JoinSGroup(grp_id_, id_, svr_grp));

