From: GitBox
To: dev@singa.apache.org
Subject: [GitHub] [singa] nudles commented on a change in pull request #662: CUDNN LSTM
Date: Sat, 11 Apr 2020 14:41:55 -0000

nudles commented on a change in pull request #662: CUDNN LSTM
URL: https://github.com/apache/singa/pull/662#discussion_r407071662

##########
File path: python/singa/autograd.py
##########
@@ -3330,6 +3330,94 @@ def step_forward(self, x, h, c, Wx, Wh, Bx, Bh):
         return hout, cout


+class _RNN(Operation):
+    """ RNN operation with c++ backend
+    """
+    def __init__(self, handle):
+        assert singa.USE_CUDA is True, "Not able to run without CUDA"
+        super(_RNN, self).__init__()
+        self.handle = handle
+
+    def forward(self, x, W):
+        # TODO: CPU forward
+
+        # GPU forward
+        if training:
+            y = singa.GpuRNNForwardTraining(x, W, self.handle)
+            self.inputs = (x, W, y)
+        else:
+            y = singa.GpuRNNForwardInference(x, W, self.handle)
+
+        return y
+
+    def backward(self, dy):
+        assert training is True and hasattr(
+            self, "inputs"), "Please set training as True before do BP. "
+
+        # TODO: CPU backward
+
+        # GPU backward
+        dx = singa.GpuRNNBackwardx(self.inputs[2], dy, self.inputs[1], self.handle)
+        dW = singa.GpuRNNBackwardW(self.inputs[0], self.inputs[2], self.handle)
+        return dx, dW
+
+class RNN_direct(Layer):
+    """ `RNN_direct` class implements with c++ backend and run the operation
+    directly on cuDNN
+
+    While `RNN` class implements with high level singa API
+    """
+    def __init__(self, input_size, hidden_size, rnn_mode="lstm"):
+        """
+        Args:
+            input_size: input feature dim
+            hidden_size: hidden feature dim
+            rnn_mode: accepted value: "vanilla", "tanh", "relu", "lstm", "gru"
+        """
+        assert singa.USE_CUDA is True, "Not able to run without CUDA"
+
+        self.rnn_mode = rnn_mode
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+
+        # TODO: CPU parameter
+
+        # GPU parameter
+        # cudnn_rnn_mode: 0 - RNN RELU, 1 - RNN TANH, 2 - LSTM, 3 - GRU
+        if self.rnn_mode == "lstm":
+            self.cudnn_rnn_mode = 2
+        elif self.rnn_mode == "vanilla" or self.rnn_mode == "tanh":
+            self.cudnn_rnn_mode = 1
+        elif self.rnn_mode == "relu":
+            self.cudnn_rnn_mode = 0
+        elif self.rnn_mode == "gru":
+            self.cudnn_rnn_mode = 3
+
+    def __call__(self, x):
+        if not hasattr(self, "handle"):
+            cpp_x = singa.VecTensor()
+            [cpp_x.append(i.data) for i in x]
+
+            # TODO: CPU handle
+
+            # GPU handle
+            self.handle = singa.CudnnRNNHandle(cpp_x, self.input_size, self.hidden_size, self.cudnn_rnn_mode)
+
+            self.W = Tensor(shape=(self.handle.weights_size,),
+                            requires_grad=True,
+                            stores_grad=True)
+            self.W.gaussian(0.0, 1.0)
+
+        return _RNN(self.handle)(x, self.W)[0]
+
+    def get_params(self):
+        return self.W

Review comment:
Since W is created only when the layer is first called during forward propagation, calling this function after the layer has been constructed but before its first forward pass will fail, because self.W does not exist yet.
We can add an assert that prints a hint, e.g., "W is not initialized until the layer has been forwarded with data at least once".
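The reviewer's suggestion can be sketched as follows. `LazyRNNLayer` is a hypothetical, pure-Python stand-in for `RNN_direct` (no SINGA or CUDA involved): like the patch, it creates `W` only on the first forward call, and `get_params()` is guarded by an assert carrying the suggested hint. The weight count is an assumed placeholder for `handle.weights_size` (a single-layer, unidirectional LSTM layout), not what SINGA actually computes.

```python
import random


class LazyRNNLayer:
    """Hypothetical pure-Python stand-in for RNN_direct (not SINGA code)."""

    def __init__(self, input_size, hidden_size):
        self.input_size = input_size
        self.hidden_size = hidden_size

    def __call__(self, x):
        # Like RNN_direct.__call__, create the parameters lazily, on the
        # first forward pass only.
        if not hasattr(self, "W"):
            # Assumed placeholder for handle.weights_size: a single-layer,
            # unidirectional LSTM with 4 gates, input and recurrent weight
            # matrices, and two bias vectors per gate.
            weights_size = 4 * self.hidden_size * (
                self.input_size + self.hidden_size + 2)
            self.W = [random.gauss(0.0, 1.0) for _ in range(weights_size)]
        # A real layer would run the RNN here; the input is returned
        # unchanged because only parameter creation matters in this sketch.
        return x

    def get_params(self):
        # Reviewer's suggestion: fail with an explicit hint instead of a
        # bare AttributeError when W has not been created yet.
        assert hasattr(self, "W"), (
            "W is not initialized until the layer is forwarded with data "
            "at least once")
        return self.W
```

Calling `get_params()` on a fresh `LazyRNNLayer(3, 2)` fails with the hint; after one call such as `layer([[0.0, 0.0, 0.0]])` it returns a list of 4 × 2 × (3 + 2 + 2) = 56 weights under the assumed layout. Because `assert` statements are skipped when Python runs with `-O`, raising an explicit exception such as `RuntimeError` with the same message would keep the hint in optimized runs too.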