From: GitBox
To: dev@singa.apache.org
Reply-To: dev@singa.incubator.apache.org
Subject: [GitHub] [incubator-singa] nudles commented on a change in pull request #416: singa-onnx
Message-ID: <155223045619.14831.1284168139276365959.gitbox@gitbox.apache.org>
Date: Sun, 10 Mar 2019 15:07:36 -0000

nudles commented on a change in pull request #416: singa-onnx
URL: https://github.com/apache/incubator-singa/pull/416#discussion_r264043050


##########
File path: python/singa/sonnx.py
##########

@@ -18,272 +18,359 @@
 #
-from __future__ import division
-from singa import tensor
-from singa import autograd
-from onnx import helper,checker
-from onnx import AttributeProto, TensorProto, GraphProto
+
+from collections import deque
+from onnx import helper, checker
+from onnx import TensorProto
 from onnx import numpy_helper
-from onnx.backend.base import BackendRep as backendRep
-from onnx.backend.base import Backend as backend
-from collections import Counter, deque
+from onnx.backend.base import BackendRep as backendRep
+from onnx.backend.base import Backend as backend
+from onnx.backend.base import namedtupledict
 from . import singa_wrap as singa
-from autograd import *
-from autograd import _Conv2d,_Pooling2d,_BatchNorm2d
-#if not import, there will be an error
-from singa.tensor import to_numpy
-
+import autograd
-class BackendRep(backendRep):
-    def __init__(self,model,device):
-        self.model, self.modeldic = Backend.onnx_model_init(model,device)
-        self.handledic={}
-    def run(self,inputs):
-        self.y,self.modeldic=Backend.run(self.model, self.modeldic,inputs,self.handledic)
-        return self.y
+import tensor
+from device import create_cuda_gpu_on, get_default_device
-
-class Backend(backend):
+class Handle(object):
     @staticmethod
-    def convhandle(name,handledic,x,model):
-        if(name in handledic):return handledic
-        i = Backend.find_name(model,name)
-
-        shape = Backend.find_shape(model,i.input[1])
-        cin,cout,k=shape[1], shape[0], (shape[2],shape[2])
-        padding=(int(i.attribute[1].ints[0]),int(i.attribute[1].ints[0]))
-        stride=(int(i.attribute[2].ints[0]),int(i.attribute[2].ints[0]))
-
-        handledic[name] = singa.CudnnConvHandle(x.data, k, stride,padding, cin, cout, True)
-        handledic[name].device_id = x.device.id()
-        return handledic
-
+    def conv(inputs, attrs):
+        # inputs: a list of the input tensors
+        kernel = tuple(attrs['kernel_shape'])
+        padding = tuple(attrs['pads'])
+        stride = tuple(attrs['strides'])
+        group = 1
+        bias = len(inputs) == 3
+        x = inputs[0]
+        x_shape = inputs[0].shape
+        in_channels = x_shape[1]
+        w_shape = inputs[1].shape
+        out_channels = w_shape[0]
+        assert w_shape[1] == in_channels // group
+
+        if inputs[0].device.id() == -1:
+            if group != 1:
+                raise NotImplementedError
+            else:
+                handle = singa.ConvHandle(x.data, kernel, stride, padding,
+                                          in_channels, out_channels, bias)
+            handle.device_id = inputs[0].device.id()
+        else:
+            handle = singa.CudnnConvHandle(x.data, kernel, stride, padding,
+                                           in_channels, out_channels, bias,
+                                           group)
+            handle.device_id = inputs[0].device.id()
+        return handle
     @staticmethod
-    def MaxPool2dhandle(name,handledic,x,model):
-        if(name in handledic):return handledic
-        i = Backend.find_name(model,name)
-        k = (int(i.attribute[0].ints[0]),int(i.attribute[0].ints[0]))
-        padding=(int(i.attribute[1].ints[0]),int(i.attribute[1].ints[0]))
-        stride=(int(i.attribute[2].ints[0]),int(i.attribute[2].ints[0]))
-
-        handledic[name] = singa.CudnnPoolingHandle(x.data, k, stride, padding, True)
-        handledic[name].device_id = x.device.id()
-        return handledic
+    def max_pool(inputs, attrs):
+        x = inputs[0]
+        kernel = tuple(attrs['kernel_shape'])
+        padding = tuple(attrs['pads'])
+        stride = tuple(attrs['strides'])
+        if x.device.id() == -1:
+            handle = singa.PoolingHandle(x.data, kernel, stride, padding, True)
+            handle.device_id = inputs[0].device.id()
+        else:
+            handle = singa.CudnnPoolingHandle(x.data, kernel, stride, padding,
+                                              True)
+            handle.device_id = inputs[0].device.id()
+        return handle
     @staticmethod
-    def AveragePoolhandle(name,handledic,x,model):
-        if(name in handledic):return handledic
-        i = Backend.find_name(model,name)
-        k = (int(i.attribute[0].ints[0]),int(i.attribute[0].ints[0]))
-        padding=(int(i.attribute[1].ints[0]),int(i.attribute[1].ints[0]))
-        stride=(int(i.attribute[2].ints[0]),int(i.attribute[2].ints[0]))
-
-        handledic[name] = singa.CudnnPoolingHandle(x.data, k, stride, padding, False)
-        handledic[name].device_id = x.device.id()
-        return handledic
+    def avg_pool(inputs, attrs):
+        x = inputs[0]
+        kernel = tuple(attrs['kernel_shape'])
+        padding = tuple(attrs['pads'])
+        stride = tuple(attrs['strides'])
+        if x.device.id() == -1:
+            handle = singa.PoolingHandle(x.data, kernel, stride,
+                                         padding, False)
+            handle.device_id = inputs[0].device.id()
+        else:
+            handle = singa.CudnnPoolingHandle(x.data, kernel, stride, padding,
+                                              False)
+            handle.device_id = inputs[0].device.id()
+        return handle
     @staticmethod
-    def BatchNormalizationhandle(name,handledic,x,model):
-        if(name in handledic):return handledic
-        handledic[name] = singa.CudnnBatchNormHandle(0.9, x.data)
-        handledic[name].device_id = x.device.id()
-        return handledic
-
-
-
-
-
-
-    @staticmethod
-    def onnx_model_init(model,device):
+    def batchnorm(inputs, attrs):
+        x = inputs[0]
+        factor = attrs['momentum']
+        if x.device.id() == -1:
+            raise NotImplementedError
+        else:
+            handle = singa.CudnnBatchNormHandle(factor, x.data)
+            handle.device_id = inputs[0].device.id()
+        return handle
+
+UnaryOp = {'Relu': autograd.relu,
+           'Softmax': autograd.softmax,
+           'Flatten': autograd.flatten,
+           'Tanh': autograd.tanh,
+           'Sigmoid': autograd.sigmoid}
+BinaryOp = {'Add': autograd.add,
+            'Mul': autograd.mul,
+            'MatMul': autograd.matmul}
+
+OtherOp = {'Conv': (Handle.conv, autograd.conv2d),
+           'MaxPool': (Handle.max_pool, autograd.pooling_2d),
+           'AveragePool': (Handle.avg_pool, autograd.pooling_2d),
+           'BatchNormalization': (Handle.batchnorm, autograd.batchnorm_2d)
+           }
+
+
+class SingaBackendRep(backendRep):
+
+    def __init__(self, model, device, tensor_dict):
         '''
-        input model
-
-        return: model and model dictionary
+        Args:
+            model: onnx model proto
+            device: singa device
+            tensor_dict: dict for weight tensors
         '''
-
-        modeldic = {}
-        for i in model.graph.node:
-            if (i.op_type == 'Constant'):
-                modeldic[str(i.output[0])] = tensor.Tensor(device=device,data=numpy_helper.to_array(i.attribute[0].t),requires_grad=True, stores_grad=True)
-
-        return model,modeldic
-
+        self.model = model
+        self.device = device
+        self.tensor_dict = tensor_dict
+        self.handle_dict = {}
     @staticmethod
-    def find_name(model,name):
-        for i in model.graph.node:
-            if (i.name == name):
-                return i
-
-
-    @staticmethod
-    def find_shape(model,input):
+    def run_node(node, tensors, handles):
         '''
-        # find weight shape for layers
+        Args:
+            node: onnx node proto
+            tensors: dict from tensor name to tensor
+            handles: dict from node name to handle
+        '''
+        inputs = [tensors[x] for x in node.input]
+        outputs = node.output
+        attrs = attribute2dict(node)
+        op = node.op_type
+        if op in UnaryOp:
+            tensors[outputs[0]] = UnaryOp[op](inputs[0])
+        elif op in BinaryOp:
+            tensors[outputs[0]] = BinaryOp[op](inputs[0], inputs[1])
+        elif op in OtherOp:
+            handle, forward = OtherOp[op]
+            if node.name not in handles:
+                handles[node.name] = handle(inputs, attrs)
+            tensors[outputs[0]] = forward(handles[node.name], *inputs)
+        elif op == 'Concat':
+            tensors[outputs[0]] = autograd.cat(tuple(inputs), attrs['axis'])
+        else:
+            raise NotImplementedError('Not supported op: {}'.format(op))
+
+    def run(self, input):
+        # input_dict: dict from input name to numpy array
+        tensors = self.tensor_dict.copy()
+        key=self.model.graph.input[0].name
+        oname = self.model.graph.output[0].name
+        tensors[key] = input[0]
+
+        for node in self.model.graph.node:
+            if(node.op_type!="Constant"):
+                SingaBackendRep.run_node(node, tensors, self.handle_dict)
+        return tensors[oname]
+
+
+def attribute2dict(node):
+    # create a dictionary from the node attribute name to value
+    attr = {}
+    for a in node.attribute:
+        attr[a.name] = helper.get_attribute_value(a)
+    return attr
+
+
+class SingaBackend(backend):
+
+    @classmethod
+    def prepare(cls,
+                model,  # type: ModelProto
+                device,  # type: Text
+                **kwargs  # type: Any
+                ):  # type: (...) -> Optional[BackendRep]
+        '''
+        Args:
+            model: onnx model proto
+            device: 'CPU' or 'GPU'
+        Return:
+            SingaBackendRep instance
         '''
-        for i in model.graph.node:
-            if (i.op_type == 'Constant' and i.output[0] == input):
-                return numpy_helper.to_array(i.attribute[0].t).shape
+        super(SingaBackend, cls).prepare(model, device, **kwargs)
+        name2tensor = {}
+        for node in model.graph.node:
+            if (node.op_type == 'Constant'):
+                data = helper.get_attribute_value(node.attribute[0])
+                requires_grad, stores_grad = True, True
+                if len(node.attribute) == 3:
+                    requires_grad = helper.get_attribute_value(
+                        node.attribute[1])
+                    stores_grad = helper.get_attribute_value(node.attribute[2])
+                t = tensor.Tensor(device=device,
+                                  data=numpy_helper.to_array(data),
+                                  requires_grad=requires_grad,
+                                  stores_grad=stores_grad)
+
+                name2tensor[node.output[0]] = t
+
+        return SingaBackendRep(model, device, name2tensor)
+
+    @classmethod
+    def run_node(cls, node, inputs, device, outputs_info=None, **kwargs):
+        '''
+        Args:
+            node: onnx node proto
+            inputs: dictionary of name to numpy array; the names should match

Review comment:
   numpy array or singa tensor?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
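
For readers following the thread, here is a minimal end-to-end sketch of how the
backend introduced in this diff appears intended to be used. It is an editor's
illustration, not code from the PR: the `singa.sonnx` import path, the model file
name, and the input shape are assumptions. Also note that although `prepare()`'s
docstring says `device` is 'CPU' or 'GPU', the body passes `device` straight into
`tensor.Tensor(device=...)`, so a singa device object is used below.

    import onnx
    from singa import device, tensor
    from singa import sonnx  # assumed import path for the module in this diff

    model = onnx.load('model.onnx')      # ONNX graph storing weights as Constant nodes
    dev = device.create_cuda_gpu_on(0)   # device id != -1, so the Cudnn* handles are picked

    # prepare() collects the Constant nodes into a name->tensor dict
    rep = sonnx.SingaBackend.prepare(model, dev)

    x = tensor.Tensor(shape=(8, 3, 32, 32), device=dev)
    x.gaussian(0.0, 1.0)
    y = rep.run([x])   # run() indexes input[0] and returns the graph's output tensor

Since `run()` feeds its inputs directly into autograd operations, a singa tensor
(rather than a numpy array) seems to be expected, which is exactly the ambiguity
the review comment above points out in the `run_node` docstring.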
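`SingaBackendRep.run_node` can also be exercised one node at a time. The sketch
below, again an editor's illustration with made-up values, shows a Relu node
resolving through the `UnaryOp` table; a Conv node would instead go through
`OtherOp`, which creates and caches a singa handle under `node.name`.

    import numpy as np
    from onnx import helper
    from singa import device, tensor
    from singa.sonnx import SingaBackendRep  # assumed import path

    dev = device.create_cuda_gpu_on(0)
    x = tensor.Tensor(data=np.random.randn(2, 4).astype(np.float32), device=dev)

    node = helper.make_node('Relu', inputs=['x'], outputs=['y'], name='relu0')
    tensors = {'x': x}   # maps ONNX value names to singa tensors
    handles = {}
    SingaBackendRep.run_node(node, tensors, handles)
    # tensors['y'] now holds autograd.relu(x); handles stays empty because
    # Relu is in UnaryOp and needs no singa handle.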
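On how ONNX attributes reach the singa handles: `attribute2dict` flattens
`node.attribute` through `helper.get_attribute_value`, and `Handle.conv` reads
`kernel_shape`, `pads`, and `strides` from the resulting dict. A short sketch
with illustrative values; the 2-element `pads` matches the 2-tuple padding the
handle constructors in this diff expect:

    from onnx import helper

    conv = helper.make_node('Conv', inputs=['x', 'w', 'b'], outputs=['y'],
                            name='conv0',
                            kernel_shape=[3, 3], pads=[1, 1], strides=[1, 1])

    # equivalent to attribute2dict(conv) from the diff:
    attrs = {a.name: helper.get_attribute_value(a) for a in conv.attribute}
    # {'kernel_shape': [3, 3], 'pads': [1, 1], 'strides': [1, 1]}

    # Handle.conv then takes kernel = tuple(attrs['kernel_shape']) and so on,
    # and sets bias=True because the node has a third input ('b').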