From commits-return-30190-archive-asf-public=cust-asf.ponee.io@mxnet.incubator.apache.org Wed Apr 4 09:48:28 2018 Return-Path: X-Original-To: archive-asf-public@cust-asf.ponee.io Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by mx-eu-01.ponee.io (Postfix) with SMTP id 27766180677 for ; Wed, 4 Apr 2018 09:48:27 +0200 (CEST) Received: (qmail 51343 invoked by uid 500); 4 Apr 2018 07:48:26 -0000 Mailing-List: contact commits-help@mxnet.incubator.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@mxnet.incubator.apache.org Delivered-To: mailing list commits@mxnet.incubator.apache.org Received: (qmail 51332 invoked by uid 99); 4 Apr 2018 07:48:26 -0000 Received: from ec2-52-202-80-70.compute-1.amazonaws.com (HELO gitbox.apache.org) (52.202.80.70) by apache.org (qpsmtpd/0.29) with ESMTP; Wed, 04 Apr 2018 07:48:26 +0000 From: GitBox To: commits@mxnet.apache.org Subject: [GitHub] haojin2 commented on a change in pull request #10208: [MXNET-117] Sparse operator broadcast_mul/div(csr, dense) = csr Message-ID: <152282810638.30204.14790496664349735345.gitbox@gitbox.apache.org> Date: Wed, 04 Apr 2018 07:48:26 -0000 Content-Type: text/plain; charset=utf-8 Content-Transfer-Encoding: 8bit haojin2 commented on a change in pull request #10208: [MXNET-117] Sparse operator broadcast_mul/div(csr, dense) = csr URL: https://github.com/apache/incubator-mxnet/pull/10208#discussion_r179052324 ########## File path: src/operator/tensor/elemwise_binary_broadcast_op.h ########## @@ -185,6 +230,84 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, } } +template +void BinaryBroadcastCsrDnsCsrImpl(const OpContext& ctx, + const NDArray& csr, + const NDArray& dns, + const OpReqType req, + const NDArray& output) { + using namespace mshadow; + using namespace mxnet_op; + using namespace csr; + CHECK(req != kAddTo && req != kWriteInplace); + mshadow::Stream *s = 
ctx.get_stream(); + bool col_vec; + if (dns.shape().ndim() == 1) { + col_vec = false; + } else { + col_vec = (dns.shape()[0] == csr.shape()[0])? true : false; + } + + if (csr.storage_initialized()) { + const nnvm::dim_t nnz = csr.storage_shape()[0]; + const nnvm::dim_t num_rows = output.shape()[0]; + output.CheckAndAlloc({Shape1(num_rows + 1), Shape1(nnz)}); + + MSHADOW_TYPE_SWITCH(output.dtype(), DType, { + MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), CType, { + MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIndPtr), RType, { + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + Kernel, xpu>::Launch( + s, num_rows, csr.data().dptr(), csr.aux_data(kIdx).dptr(), + csr.aux_data(kIndPtr).dptr(), dns.data().dptr(), + output.data().dptr(), csr.shape()[1], col_vec); + Copy(output.aux_data(kIdx).FlatTo1D(), + csr.aux_data(kIdx).FlatTo1D()); + Copy(output.aux_data(kIndPtr).FlatTo1D(), + csr.aux_data(kIndPtr).FlatTo1D()); + }); + }); + }); + }); + // If input csr is an empty matrix, fill zeros and return + } else { + FillZerosCsrImpl(s, output); + return; + } +} + +template +void BinaryBroadcastComputeCsrEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_LE(inputs[1].shape().ndim(), 2U) << "input dense matrix should have less than 2 dimensions"; + const NDArray& lhs = inputs[0]; + const NDArray& rhs = inputs[1]; + const NDArray& out = outputs[0]; + const auto lhs_stype = lhs.storage_type(); + const auto rhs_stype = rhs.storage_type(); + const auto out_stype = out.storage_type(); + // If the input is not a vector + if ((rhs.shape().ndim() != 1U) && (rhs.shape()[0] != 1) && (rhs.shape()[1] != 1)) { + // Currently do not support elementwise_mul/div(csr, dense) = csr, log and exit + LogUnimplementedOp(attrs, ctx, inputs, req, outputs); + } else { + if (req[0] != kNullOp) { + // broadcast(CSR, 
Dense(1D)) = CSR + if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kCSRStorage) { Review comment: Made it work ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services