From: zhaojing@apache.org
To: commits@singa.incubator.apache.org
Date: Mon, 16 Jul 2018 03:13:27 -0000
Subject: [1/4] incubator-singa git commit: SINGA-384 Implement ResNet using autograd API

Repository: incubator-singa
Updated Branches:
  refs/heads/master 76779be72 -> 870c5df0b


SINGA-384 Implement ResNet using autograd API

Add ResNet as an example of autograd.
Rename autograd operations to be consistent with torch.
Pass the inference of ResNet.

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/2b5c3f70
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/2b5c3f70
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/2b5c3f70

Branch: refs/heads/master
Commit: 2b5c3f709ee2c0530f4a97ea26a34f55bff36c6e
Parents: 76779be
Author: Wang Wei
Authored: Fri Jul 13 16:06:32 2018 +0800
Committer: Wang Wei
Committed: Mon Jul 16 10:04:13 2018 +0800

----------------------------------------------------------------------
 examples/autograd/resnet.py      | 226 ++++++++++++++++++++++++++++++++++
 python/singa/autograd.py         |  38 +++---
 src/model/operation/batchnorm.cc |   1 -
 3 files changed, 243 insertions(+), 22 deletions(-)
----------------------------------------------------------------------
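For readers tracking the rename: the new names follow torch's conventions
(Conv2D -> Conv2d, BatchNorm -> BatchNorm2d, MaxPooling2D/AvgPooling2D ->
MaxPool2d/AvgPool2d, concat -> cat). Below is a minimal illustrative sketch
of the renamed layer API, not part of the commit; it assumes a SINGA build
that includes this change, with constructor arguments as shown in the diff:

    from singa import autograd

    # torch-style names after this commit
    conv = autograd.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
    bn = autograd.BatchNorm2d(16)                       # formerly autograd.BatchNorm
    pool = autograd.MaxPool2d(kernel_size=2, stride=2)  # formerly MaxPooling2D

    def conv_bn_relu_pool(x):
        # the conv -> batchnorm -> relu -> pool pattern used by the ResNet example
        return pool(autograd.relu(bn(conv(x))))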
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2b5c3f70/examples/autograd/resnet.py
----------------------------------------------------------------------
diff --git a/examples/autograd/resnet.py b/examples/autograd/resnet.py
new file mode 100644
index 0000000..930d9e0
--- /dev/null
+++ b/examples/autograd/resnet.py
@@ -0,0 +1,226 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+# the code is modified from
+# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+
+from singa import autograd
+from singa import tensor
+from singa import device
+
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152']
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return autograd.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                           padding=1, bias=False)
+
+
+class BasicBlock(autograd.Layer):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = autograd.BatchNorm2d(planes)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = autograd.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def __call__(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = autograd.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = autograd.relu(out)
+
+        return out
+
+
+class Bottleneck(autograd.Layer):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = autograd.Conv2d(
+            inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = autograd.BatchNorm2d(planes)
+        self.conv2 = autograd.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                                     padding=1, bias=False)
+        self.bn2 = autograd.BatchNorm2d(planes)
+        self.conv3 = autograd.Conv2d(
+            planes, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = autograd.BatchNorm2d(planes * self.expansion)
+
+        self.downsample = downsample
+        self.stride = stride
+
+    def __call__(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = autograd.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = autograd.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = autograd.relu(out)
+
+        return out
+
+
+class ResNet(autograd.Layer):
+
+    def __init__(self, block, layers, num_classes=1000):
+        self.inplanes = 64
+        super(ResNet, self).__init__()
+        self.conv1 = autograd.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                                     bias=False)
+        self.bn1 = autograd.BatchNorm2d(64)
+        self.maxpool = autograd.MaxPool2d(
+            kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = autograd.AvgPool2d(7, stride=1)
+        self.fc = autograd.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            conv = autograd.Conv2d(self.inplanes, planes * block.expansion,
+                                   kernel_size=1, stride=stride, bias=False)
+            bn = autograd.BatchNorm2d(planes * block.expansion)
+            downsample = lambda x: bn(conv(x))
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        def forward(x):
+            for layer in layers:
+                x = layer(x)
+            return x
+        return forward
+
+    def __call__(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = autograd.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = autograd.flatten(x)
+        x = self.fc(x)
+
+        return x
+
+
+def resnet18(pretrained=False, **kwargs):
+    """Constructs a ResNet-18 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+
+    return model
+
+
+def resnet34(pretrained=False, **kwargs):
+    """Constructs a ResNet-34 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+
+    return model
+
+
+def resnet50(pretrained=False, **kwargs):
+    """Constructs a ResNet-50 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+
+    return model
+
+
+def resnet101(pretrained=False, **kwargs):
+    """Constructs a ResNet-101 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+
+    return model
+
+
+def resnet152(pretrained=False, **kwargs):
+    """Constructs a ResNet-152 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+
+    return model
+
+
+if __name__ == '__main__':
+
+    model = resnet18()
+    x = tensor.Tensor((16, 3, 224, 224), device.create_cuda_gpu())
+    x.set_value(float(0.1))
+    autograd.training = True
+    y = model(x)
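The __main__ block above is the commit's smoke test. For reference, a
slightly fuller inference sketch; this is illustrative, not part of the
commit, and assumes a CUDA-enabled SINGA build with
examples/autograd/resnet.py importable as `resnet` (the batch size and the
choice of resnet50 are arbitrary):

    from singa import autograd
    from singa import device
    from singa import tensor

    from resnet import resnet50

    dev = device.create_cuda_gpu()

    model = resnet50()
    x = tensor.Tensor((8, 3, 224, 224), dev)  # NCHW input batch
    x.set_value(float(0.1))                   # constant dummy input

    autograd.training = True  # the example sets this flag before the forward pass
    y = model(x)              # class scores, shape (8, 1000)
    print(y.shape)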
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2b5c3f70/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index faa9685..c77c174 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -572,12 +572,12 @@ class Concat(Operation):
         return tuple(dxs)
 
 
-def concat(xs, axis=0):
+def cat(xs, axis=0):
     # xs is a tuple of multiple Tensors
     return Concat(axis)(*xs)[0]
 
 
-class _Conv2D(Operation):
+class _Conv2d(Operation):
 
     def __init__(self, handle):
         self.handle = handle
@@ -627,10 +627,10 @@ class _Conv2D(Operation):
 
 
 def conv2d(handle, x, W, b):
-    return _Conv2D(handle)(x, W, b)[0]
+    return _Conv2d(handle)(x, W, b)[0]
 
 
-class Conv2D(Layer):
+class Conv2d(Layer):
 
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, bias=True, **kwargs):
@@ -693,10 +693,6 @@ class Conv2D(Layer):
 
     def __call__(self, x):
         assert x.shape[1] == self.in_channels, 'in_channels mismatched'
-        assert (x.shape[2] + 2 * self.padding[0] - self.kernel_size[0]
-                ) % self.stride[0] == 0, 'invalid padding or strides.'
-        assert (x.shape[3] + 2 * self.padding[1] - self.kernel_size[1]
-                ) % self.stride[1] == 0, 'invalid padding or stride.'
 
         self.device_check(x, self.W, self.b)
@@ -720,7 +716,7 @@ class Conv2D(Layer):
         return y
 
 
-class BatchNorm(Layer):
+class BatchNorm2d(Layer):
 
     def __init__(self, num_features, momentum=0.9):
         self.channels = num_features
@@ -765,7 +761,7 @@
         return y
 
 
-class _BatchNorm(Operation):
+class _BatchNorm2d(Operation):
 
     def __init__(self, handle, running_mean, running_var):
         self.running_mean = running_mean.data
@@ -805,11 +801,11 @@
         return dx, ds, db
 
 
-def batchnorm(handle, x, scale, bias, running_mean, running_var):
-    return _BatchNorm(handle, running_mean, running_var)(x, scale, bias)[0]
+def batchnorm_2d(handle, x, scale, bias, running_mean, running_var):
+    return _BatchNorm2d(handle, running_mean, running_var)(x, scale, bias)[0]
 
 
-class _Pooling2D(Operation):
+class _Pooling2d(Operation):
 
     def __init__(self, handle):
         self.handle = handle
@@ -838,7 +834,7 @@ def pooling_2d(handle, x):
-    return _Pooling2D(handle)(x)[0]
+    return _Pooling2d(handle)(x)[0]
 
 
-class Pooling2D(Layer):
+class Pooling2d(Layer):
 
     def __init__(self, kernel_size, stride=None, padding=0, is_max=True):
         if isinstance(kernel_size, int):
@@ -897,31 +893,31 @@ class Pooling2D(Layer):
         return y
 
 
-class MaxPooling2D(Pooling2D):
+class MaxPool2d(Pooling2d):
 
     def __init__(self, kernel_size, stride=None, padding=0):
-        super(MaxPooling2D, self).__init__(kernel_size, stride, padding, True)
+        super(MaxPool2d, self).__init__(kernel_size, stride, padding, True)
 
 
-class AvgPooling2D(Pooling2D):
+class AvgPool2d(Pooling2d):
 
     def __init__(self, kernel_size, stride=None, padding=0):
-        super(AvgPooling2D, self).__init__(kernel_size, stride, padding, False)
+        super(AvgPool2d, self).__init__(kernel_size, stride, padding, False)
 
 
-class MaxPooling1D(Pooling2D):
+class MaxPool1d(Pooling2d):
 
     def __init__(self, kernel_size, stride=None, padding=0):
        if stride is None:
            stride = kernel_size
-        super(MaxPooling2D, self).__init__(
+        super(MaxPool1d, self).__init__(
            (1, kernel_size), (0, stride), (0, padding), True)
 
 
-class AvgPooling1D(Pooling2D):
+class AvgPool1d(Pooling2d):
 
     def __init__(self, kernel_size, stride=None, padding=0):
        if stride is None:
            stride = kernel_size
-        super(MaxPooling2D, self).__init__(
+        super(AvgPool1d, self).__init__(
            (1, kernel_size), (0, stride), (0, padding), False)
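Of the renames above, concat -> cat changes a function rather than a class.
A tiny usage sketch (illustrative only; the shapes and the default host
device are assumptions, and set_value just fills each tensor with a constant):

    from singa import autograd
    from singa import device
    from singa import tensor

    dev = device.get_default_device()  # host device

    t1 = tensor.Tensor((2, 3), dev)
    t2 = tensor.Tensor((2, 3), dev)
    t1.set_value(1.0)
    t2.set_value(2.0)

    # xs is a tuple of Tensors, as the comment in the diff notes
    out = autograd.cat((t1, t2), axis=0)  # shape (4, 3)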
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2b5c3f70/src/model/operation/batchnorm.cc
----------------------------------------------------------------------
diff --git a/src/model/operation/batchnorm.cc b/src/model/operation/batchnorm.cc
index 29eaba9..4673919 100755
--- a/src/model/operation/batchnorm.cc
+++ b/src/model/operation/batchnorm.cc
@@ -121,7 +121,6 @@ const std::vector<Tensor> GpuBatchNormBackward(const CudnnBatchNormHandle &cbnh,
   CHECK_EQ(mean.device()->lang(), kCuda);
   CHECK_EQ(var.device()->lang(), kCuda);
 
-  vector<Tensor> out_grads;
   Tensor dx;
   dx.ResetLike(dy);