mxnet-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] eric-haibin-lin closed pull request #10717: fix race condition in dot(csr.T, dns) on GPU
Date Sat, 28 Apr 2018 00:55:18 GMT
eric-haibin-lin closed pull request #10717: fix race condition in dot(csr.T, dns) on GPU
URL: https://github.com/apache/incubator-mxnet/pull/10717
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh
index c546c4351a2..e97926da62a 100644
--- a/src/operator/tensor/dot-inl.cuh
+++ b/src/operator/tensor/dot-inl.cuh
@@ -29,6 +29,7 @@
 #include <mxnet/operator.h>
 #include "./util/tensor_util-inl.h"
 #include "./util/tensor_util-inl.cuh"
+#include "./indexing_op.h"
 
 namespace mxnet {
 namespace op {
@@ -287,50 +288,6 @@ struct DotCsrTransDnsDnsWarpBlockKernel {
   }
 };
 
-/*!
- * \brief GPU warp kernel of dot(csr.T, dns) = rsp
- * Parallelization by columns: 1 warp computes one lhs column for one rhs column
- */
-struct DotCsrTransDnsRspWarpKernel {
-  /*!
-   * \brief
-   * \param tid              global thread id
-   * \param out              output rsp matrix data
-   * \param row_flg_sum_out  inclusive prefix sum array over 0/1 marked row flag array
-   * \param data_l           csr matrix data
-   * \param indptr_l         csr matrix row index pointer
-   * \param col_idx_l        csr matrix column indices
-   * \param data_r           dns matrix data
-   * \param num_cols_r       dns matrix number of columns
-   */
-  template<typename DType, typename IType, typename CType>
-  __device__ __forceinline__ static void Map(int tid,
-                                             DType* out,
-                                             const nnvm::dim_t* row_flg_sum_out,
-                                             const DType* data_l,
-                                             const IType* indptr_l,
-                                             const CType* col_idx_l,
-                                             const DType* data_r,
-                                             const nnvm::dim_t num_cols_r) {
-    using nnvm::dim_t;
-    const dim_t warp_id = tid / 32;           // global warp id
-    const dim_t lane = tid & (32-1);          // local thread id within warp
-    const dim_t icol = warp_id / num_cols_r;  // lhs column that this warp computes
-    const dim_t kcol = warp_id % num_cols_r;  // rhs column that this warp computes
-
-    // Compute range of nnz elements in this column
-    const dim_t low  = static_cast<dim_t>(indptr_l[icol]);
-    const dim_t high = static_cast<dim_t>(indptr_l[icol+1]);
-
-    // Iterate through the nnz elements in this column
-    for (dim_t j = low+lane; j < high; j+=32) {
-      const dim_t irow = static_cast<dim_t>(col_idx_l[j]);
-      const dim_t rsp_row = row_flg_sum_out[irow]-1;
-      const DType val = data_l[j]*data_r[icol*num_cols_r+kcol];
-      atomicAdd(static_cast<DType *>(&(out[rsp_row*num_cols_r+kcol])), val);
-    }
-  }
-};
 
 /*!
  * \brief GPU Kernel of dot(csr.T, rsp1) = rsp2
@@ -442,6 +399,41 @@ struct DotCsrRspDnsScalarKernel {
   }
 };
 
+/*!
+ * \brief GPU Kernel to scatter row id to corresponding entries
+ * \param tid         global thread id
+ * \param csr_indptr  indptr array of csr
+ * \param csr_rows    array of row id of csr elements
+ * \param num_rows    total number of rows in csr matrix
+ * Parallelization by output elements: 1 thread/row
+ */
+struct CsrRowScatterKernel {
+  template<typename CType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             const CType* csr_indptr,
+                                             CType* csr_rows,
+                                             const nnvm::dim_t num_rows) {
+    if (tid < num_rows) {
+      for (CType i = csr_indptr[tid]; i < csr_indptr[tid+1]; ++i) {
+        csr_rows[i] = tid;
+      }
+    }
+  }
+};
+
+/*
+ * \brief the kernel to generate a lookup table for positions of row ids
+ * \param i thread id
+ * \param out output table
+ * \param data the input row id in sorted order
+ */
+struct MarkLookupTable {
+  template<typename IType, typename DType>
+  MSHADOW_XINLINE static void Map(int i, IType* out, const DType* data) {
+    out[static_cast<nnvm::dim_t>(data[i])] = i;
+  }
+};
+
 /*!
  * \brief GPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
  */
@@ -576,8 +568,60 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
   });
 }
 
+struct DotCsrTransDnsRspKernel {
+  /*!
+   * \brief
+   * \param tid              global thread id
+   * \param out              output rsp matrix data
+   * \param lookup_table     lookup table from row in lhs to row in dst
+   * \param sorted_indices   csr matrix column indices in sorted order
+   * \param nnz              number of non-zeros in csr matrix
+   * \param original_idx     original indices to the unsorted csr column indices
+   * \param rhs              dns rhs data
+   * \param val_array        csr matrix data
+   * \param idx_array        csr matrix row indices
+   * \param row_length       length of a row in the output rsp matrix
+   */
+  template<typename DType, typename IType>
+  __device__ __forceinline__ static void Map(int thread_id,
+                                             DType* out,
+                                             const IType* lookup_table,
+                                             const IType* sorted_indices,
+                                             const nnvm::dim_t nnz,
+                                             const IType* original_idx,
+                                             const DType* rhs,
+                                             const DType* val_array,
+                                             const IType* idx_array,
+                                             const nnvm::dim_t row_length) {
+    int tid = thread_id / row_length;
+    const nnvm::dim_t offset = thread_id % row_length;
+    if (tid == 0 || sorted_indices[tid - 1] != sorted_indices[tid]) {
+      DType acc = 0;
+      const IType src_row_idx = sorted_indices[tid];
+      const IType dst_row_idx = lookup_table[src_row_idx];
+      const IType out_offset = dst_row_idx * row_length + offset;
+      do {
+        const IType idx = original_idx[tid];
+        const DType val = val_array[idx];
+        const DType col_idx = idx_array[idx];
+        const IType rhs_offset = col_idx * row_length + offset;
+        acc += rhs[rhs_offset] * val;
+        tid++;
+      } while (tid < nnz && sorted_indices[tid - 1] == sorted_indices[tid]);
+      out[out_offset] = acc;
+    }
+  }
+};
+
+// Returns integer log2(a) rounded up
+inline int log2i(size_t a) {
+  int k = 1;
+  while (a >>= 1) k++;
+  return k;
+}
+
 /*!
- * \brief GPU Impl of dot(csr, dns) = rsp and dot(csr.T, dns) = rsp
+ * \brief GPU Impl of dot(csr.T, dns) = rsp
  */
 inline void DotCsrDnsRspImpl(const OpContext& ctx,
                              const gpu& gpu_dev,
@@ -597,92 +641,114 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
   }
 
   using mshadow::Shape1;
+  using mshadow::Tensor;
   using mxnet_op::Kernel;
   using mxnet_op::set_zero;
   using nnvm::dim_t;
+  using namespace csr;
 
   const TBlob data_l = lhs.data();
-  const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
-  const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
+  const TBlob indptr_l = lhs.aux_data(kIndPtr);
+  const TBlob col_idx_l = lhs.aux_data(kIdx);
   const TBlob& data_r = rhs;
+  size_t nnz = lhs.aux_data(kIdx).Size();
 
   const dim_t num_rows_l = lhs.shape()[0];
   const dim_t num_cols_l = lhs.shape()[1];
   const dim_t num_cols_r = rhs.shape_[1];
-  const dim_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize;
-  dim_t num_threads;
-  // TODO: remove kernel dependency on warpSize=32
-  if (threads_per_warp != 32) {
-    LOG(FATAL) << "DotCsrDnsRspImpl GPU kernels expect warpSize=32";
-  }
-
+  CHECK_EQ(ret->aux_type(rowsparse::kIdx), col_idx_l.type_flag_)
+    << "Mismatch indices dtype detected";
   MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
     MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
         if (trans_lhs) {
-          // Compute number of non-zero rows (nnr) of output matrix
-          // - alloc temp storage for row_flg array and for cub's prefix sum
-          // - mark non-zero columns of csr matrix in row_flg
-          // - compute inclusive prefix sum over marked array
-          // - copy last value (nnr_out) from device to host
-          dim_t* row_flg_out = NULL;
-          void* d_temp_storage = NULL;
-          size_t temp_storage_bytes = 0;
-          cub::DeviceScan::InclusiveSum(d_temp_storage,
-                                        temp_storage_bytes,
-                                        row_flg_out,
-                                        row_flg_out,
-                                        num_cols_l,
-                                        mshadow::Stream<gpu>::GetStream(s));
-          mshadow::Tensor<gpu, 1, char> workspace = ctx.requested[0]
-              .get_space_typed<gpu, 1, char>(Shape1(num_cols_l * sizeof(dim_t) +
-                                                    temp_storage_bytes), s);
-          row_flg_out = reinterpret_cast<dim_t*>(workspace.dptr_);
-          d_temp_storage = workspace.dptr_ + num_cols_l*sizeof(dim_t);
-          num_threads = num_cols_l;
-          Kernel<set_zero, gpu>::Launch(s, num_threads, row_flg_out);
-          num_threads = num_rows_l * threads_per_warp;
-          Kernel<MarkCsrColWarpKernel, gpu>::Launch(s, num_threads,
-              row_flg_out, col_idx_l.dptr<CType>(), indptr_l.dptr<IType>(),
-              num_rows_l, num_cols_l);
-          cub::DeviceScan::InclusiveSum(d_temp_storage,
-                                        temp_storage_bytes,
-                                        row_flg_out,
-                                        row_flg_out,
-                                        num_cols_l,
-                                        mshadow::Stream<gpu>::GetStream(s));
-          dim_t nnr_out = 0;
-          CUDA_CALL(cudaMemcpy(&nnr_out, &row_flg_out[num_cols_l-1], sizeof(dim_t),
-                               cudaMemcpyDeviceToHost));
-          if (0 == nnr_out) {
-            FillZerosRspImpl(s, *ret);
-            return;
-          }
-
-          // Allocate output matrix space
-          ret->CheckAndAlloc({Shape1(nnr_out)});
-          const TBlob data_out_blob = ret->data();
-          const TBlob row_idx_out_blob = ret->aux_data(rowsparse::kIdx);
-          MSHADOW_IDX_TYPE_SWITCH(row_idx_out_blob.type_flag_, RType, {  // row idx type
-            DType* data_out = data_out_blob.dptr<DType>();
-            RType* row_idx_out = row_idx_out_blob.dptr<RType>();
-            num_threads = nnr_out * num_cols_r;
-            Kernel<set_zero, gpu>::Launch(s, num_threads, data_out);
-            num_threads = nnr_out;
-            Kernel<set_zero, gpu>::Launch(s, num_threads, row_idx_out);
-
-            // Fill row_idx array of output matrix, using the row_flg values
-            num_threads = num_cols_l;
-            Kernel<FillRspRowIdxKernel, gpu>::Launch(s, num_threads,
-                row_idx_out, row_flg_out, num_cols_l);
-
-            // Perform matrix-matrix multiply
-            num_threads = threads_per_warp * num_rows_l * num_cols_r;
-            Kernel<DotCsrTransDnsRspWarpKernel, gpu>::Launch(s, num_threads,
-                data_out, row_flg_out,
-                data_l.dptr<DType>(), indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(),
-                data_r.dptr<DType>(), num_cols_r);
-          });
+          IType* col_idx_l_ptr = col_idx_l.dptr<IType>();
+          // temporary memory layout
+          size_t* nnr_ptr = nullptr;
+          IType* original_idx_ptr = nullptr;
+          IType* row_idx_ptr = nullptr;
+          IType* col_idx_copy_ptr = nullptr;
+          IType* lookup_table_ptr = nullptr;
+          char* temp_storage_ptr = nullptr;
+
+          // estimate temp space for unique.
+          const size_t nnr_bytes = sizeof(size_t);
+          size_t unique_temp_bytes = 0;
+          size_t *null_ptr = nullptr;
+          size_t *null_dptr = nullptr;
+          cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+          cub::DeviceSelect::Unique(NULL, unique_temp_bytes, null_dptr, null_dptr,
+                                    null_ptr, nnz, stream);
+          // the temp storage for sort and unique
+          size_t original_idx_bytes = nnz * sizeof(IType);
+          size_t row_idx_bytes = nnz * sizeof(IType);
+          size_t col_idx_copy_bytes = nnz * sizeof(IType);
+          size_t lookup_table_bytes = num_cols_l * sizeof(IType);
+          size_t sort_temp_bytes = SortByKeyWorkspaceSize<IType, IType, gpu>(nnz);
+          size_t total_temp_bytes = std::max(sort_temp_bytes, unique_temp_bytes);
+
+          // layout: original_idx, col_idx_copy, temp_storage
+          size_t total_workspace_bytes = nnr_bytes + original_idx_bytes + row_idx_bytes +
+                                         col_idx_copy_bytes +
+                                         lookup_table_bytes + total_temp_bytes;
+          // request temp space
+          Tensor<gpu, 1, char> workspace = ctx.requested[0]
+              .get_space_typed<gpu, 1, char>(Shape1(total_workspace_bytes), s);
+          // update individual temp space ptrs
+          nnr_ptr = reinterpret_cast<size_t*>(workspace.dptr_);
+          original_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_ + nnr_bytes);
+          row_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_ + nnr_bytes +
+                                                 original_idx_bytes);
+          col_idx_copy_ptr = reinterpret_cast<IType*>(workspace.dptr_ + nnr_bytes +
+                                                      original_idx_bytes + row_idx_bytes);
+          lookup_table_ptr = reinterpret_cast<IType*>(workspace.dptr_ + nnr_bytes +
+                                                      original_idx_bytes + row_idx_bytes +
+                                                      col_idx_copy_bytes);
+          temp_storage_ptr = workspace.dptr_ + nnr_bytes + original_idx_bytes +
+                             row_idx_bytes + col_idx_copy_bytes + lookup_table_bytes;
+
+          // Fill original_idx
+          Kernel<range_fwd, gpu>::Launch(
+            s, nnz, 1, IType(0), IType(1), kWriteTo, original_idx_ptr);
+          // Make a copy of col_idx_l
+          Kernel<mxnet_op::op_with_req<mshadow_op::identity, kWriteTo>, gpu>::Launch(
+            s, nnz, col_idx_copy_ptr, col_idx_l_ptr);
+
+          // Construct the tensors needed for SortByKey
+          Tensor<gpu, 1, IType> col_idx_copy(col_idx_copy_ptr, Shape1(nnz), s);
+          Tensor<gpu, 1, IType> original_idx(original_idx_ptr, Shape1(nnz), s);
+          Tensor<gpu, 1, char> temp_storage(temp_storage_ptr, Shape1(total_temp_bytes), s);
+
+          int num_bits = log2i(num_cols_l - 1);
+          SortByKey(col_idx_copy, original_idx, true, &temp_storage, 0, num_bits);
+
+          // over-allocate aux indices
+          ret->CheckAndAllocAuxData(rowsparse::kIdx, Shape1(nnz));
+          // compute unique indices
+          IType* ret_idx_ptr = ret->aux_data(rowsparse::kIdx).dptr<IType>();
+          cub::DeviceSelect::Unique(temp_storage_ptr, unique_temp_bytes, col_idx_copy_ptr, ret_idx_ptr,
+                                    nnr_ptr, nnz, stream);
+          // retrieve num non-zero rows
+          size_t nnr = 0;
+          CUDA_CALL(cudaMemcpy(&nnr, nnr_ptr, nnr_bytes, cudaMemcpyDeviceToHost));
+          // allocate data
+          ret->CheckAndAllocData(mshadow::Shape2(nnz, num_cols_r));
+          // generate lookup table
+          Kernel<MarkLookupTable, gpu>::Launch(s, nnr, lookup_table_ptr, ret_idx_ptr);
+
+          // Scatter csr indptr to row id
+          Kernel<CsrRowScatterKernel, gpu>::Launch(
+            s, num_rows_l, indptr_l.dptr<IType>(), row_idx_ptr, num_rows_l);
+
+          Kernel<DotCsrTransDnsRspKernel, gpu>::Launch(s, nnz * num_cols_r,
+                 ret->data().dptr<DType>(),
+                 lookup_table_ptr, col_idx_copy_ptr, nnz,
+                 original_idx_ptr, data_r.dptr<DType>(),
+                 data_l.dptr<DType>(),
+                 row_idx_ptr, num_cols_r);
+
+          // reshape aux data
+          ret->set_aux_shape(rowsparse::kIdx, Shape1(nnr));
         } else {
           LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns) = rsp yet.";
         }
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 34794866546..dd0c5a68ab4 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1294,6 +1294,30 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
     test_sparse_dot_zero_output(rand_shape_2d(50, 200), False, 40)
     test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40)
 
+@with_seed()
+def test_sparse_dot_determinism():
+    def test_dot_determinism(lhs_stype, rhs_stype, lhs_density, rhs_density, transpose_a, transpose_b):
+        lhs_row = rnd.randint(50, 100)
+        lhs_col = rnd.randint(50, 100)
+        if transpose_a:
+            if transpose_b:
+                rhs_shape = (rnd.randint(50, 100), lhs_row)
+            else:
+                rhs_shape = (lhs_row, rnd.randint(50, 100))
+        else:
+            if transpose_b:
+                rhs_shape = (rnd.randint(50, 100), lhs_col)
+            else:
+                rhs_shape = (lhs_col, rnd.randint(50, 100))
+        lhs_shape = (lhs_row, lhs_col)
+        lhs = rand_ndarray(lhs_shape, lhs_stype, density=lhs_density)
+        rhs = rand_ndarray(rhs_shape, rhs_stype, density=rhs_density)
+        res1 = mx.nd.sparse.dot(lhs, rhs, transpose_a=transpose_a, transpose_b=transpose_b)
+        res2 = mx.nd.sparse.dot(lhs, rhs, transpose_a=transpose_a, transpose_b=transpose_b)
+        assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.0, atol=0.0)
+
+    test_dot_determinism('csr', 'default', 0.1, 1.0, True, False)
+
 
 @with_seed()
 def test_sparse_slice():
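
In rough terms, the replacement kernels work as follows: instead of the removed DotCsrTransDnsRspWarpKernel, which accumulated into the output with atomicAdd (the apparent source of the non-deterministic results the new test guards against), the new path scatters CSR row ids per element, sorts the column indices while remembering their original positions, takes the unique column ids as the output row ids, builds a lookup table, and reduces each sorted segment in a fixed order. A rough NumPy sketch of that strategy is below; the helper name and signature are illustrative only, not part of the patch:

import numpy as np

def csr_t_dot_dns_deterministic(data, indptr, col_idx, dns):
    """Illustrative NumPy analogue of the deterministic dot(csr.T, dns) = rsp path."""
    num_rows_l = indptr.size - 1
    # CsrRowScatterKernel analogue: row id of every stored CSR element
    rows = np.repeat(np.arange(num_rows_l), np.diff(indptr))
    # sort stored elements by column index, remembering their original positions
    original_idx = np.argsort(col_idx, kind='stable')
    sorted_cols = col_idx[original_idx]
    # unique column ids become the row ids of the row_sparse output
    out_row_ids = np.unique(sorted_cols)
    lookup = {c: i for i, c in enumerate(out_row_ids)}  # MarkLookupTable analogue
    out = np.zeros((out_row_ids.size, dns.shape[1]), dtype=dns.dtype)
    # each run of equal sorted_cols is reduced in one fixed-order pass,
    # mirroring the per-segment loop in DotCsrTransDnsRspKernel (no atomics)
    for col, idx in zip(sorted_cols, original_idx):
        out[lookup[col]] += data[idx] * dns[rows[idx]]
    return out_row_ids, out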

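The added test_sparse_dot_determinism case exercises this end to end by running the same dot twice and asserting bit-identical results. A minimal stand-alone reproduction along the same lines (assuming a CUDA-enabled MXNet build; the shapes and density below are made up for illustration) would be:

import mxnet as mx

ctx = mx.gpu(0)
dense = mx.nd.random.uniform(shape=(64, 80), ctx=ctx)
lhs = (dense * (dense > 0.9)).tostype('csr')   # roughly 10% dense CSR lhs
rhs = mx.nd.random.uniform(shape=(64, 50), ctx=ctx)

res1 = mx.nd.sparse.dot(lhs, rhs, transpose_a=True)
res2 = mx.nd.sparse.dot(lhs, rhs, transpose_a=True)
# With the atomicAdd-based kernel the two results could differ in the last bits
# from run to run; the sort-based kernel is expected to match exactly.
assert (res1.asnumpy() == res2.asnumpy()).all()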

 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
