mxnet-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] eric-haibin-lin commented on a change in pull request #8331: [sparse] slice for csr on two dimensions, cpu implementation
Date Thu, 01 Jan 1970 00:00:00 GMT
eric-haibin-lin commented on a change in pull request #8331: [sparse] slice for csr on two
dimensions, cpu implementation
URL: https://github.com/apache/incubator-mxnet/pull/8331#discussion_r149505293
 
 

 ##########
 File path: src/operator/tensor/matrix_op-inl.h
 ##########
 @@ -551,48 +550,43 @@ void SliceCsrIndPtrImpl(const int begin, const int end, RunContext
ctx,
 }
 
 /*
- * Slice a CSR NDArray
+ * Slice a CSR NDArray for first dimension
  * Only implemented for CPU
  */
 template<typename xpu>
-void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
-                  const NDArray &in, OpReqType req, const NDArray &out) {
+void SliceDimOneCsrImpl(const TShape &begin, const TShape &end, const OpContext&
ctx,
+                        const NDArray &in, const NDArray &out) {
   using namespace mshadow;
   using namespace mxnet_op;
   using namespace csr;
-  CHECK((std::is_same<xpu, cpu>::value)) << "Slice for CSR input only implemented
for CPU";
-  if (req == kNullOp) return;
-  CHECK_NE(req, kAddTo) << "kAddTo for Slice on CSR input is not supported";
-  CHECK_NE(req, kWriteInplace) << "kWriteInplace for Slice on CSR input is not supported";
-  const TShape ishape = in.shape();
-  int begin = *param.begin[0];
-  if (begin < 0) begin += ishape[0];
-  int end = *param.end[0];
-  if (end < 0) end += ishape[0];
-  int indptr_len = end - begin + 1;
+  CHECK((std::is_same<xpu, cpu>::value)) << "SliceDimOneCsrImpl is only implemented
for CPU";
+  nnvm::dim_t begin_row = begin[0];
+  nnvm::dim_t end_row = end[0];
+  nnvm::dim_t indptr_len = end_row - begin_row + 1;
   out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len));
-  if (!in.storage_initialized()) {
-    out.set_aux_shape(kIndPtr, Shape1(0));
-    return;
-  }
   // assume idx indptr share the same type
   MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
     MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, {
       MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
-        auto in_indptr = in.aux_data(kIndPtr).dptr<RType>();
-        auto out_indptr = out.aux_data(kIndPtr).dptr<RType>();
-        SliceCsrIndPtrImpl<cpu, RType>(begin, end, ctx.run_ctx, in_indptr, out_indptr);
+        RType* in_indptr = in.aux_data(kIndPtr).dptr<RType>();
+        RType* out_indptr = out.aux_data(kIndPtr).dptr<RType>();
+        SliceCsrIndPtrImpl<cpu, RType>(begin_row, end_row, ctx.run_ctx, in_indptr,
out_indptr);
 
         // retrieve nnz (CPU implementation)
         int nnz = out_indptr[indptr_len - 1];
+        // return csr zeros if nnz = 0
+        if (nnz == 0) {
+          out.set_aux_shape(kIdx, Shape1(0));
 
 Review comment:
   According to PR #7935 (https://github.com/apache/incubator-mxnet/pull/7935), setting the aux shape
alone is not sufficient to generate a valid CSR NDArray. Let's always call `ZerosCsrImpl` instead.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message