singa-commits mailing list archives

From wang...@apache.org
Subject [41/51] [abbrv] incubator-singa git commit: SINGA-227 Add Split and Merge Layer and add ResNet Implementation
Date Wed, 17 Aug 2016 18:03:02 GMT
SINGA-227 Add Split and Merge Layer and add ResNet Implementation

Add a Python ResNet implementation and the Split and Merge layers.

Discard the Split and Merge layers in the ResNet implementation.

Add add_split and add_merge functions for cases where multiple inputs or
outputs are involved when creating the net.

Change comments in resnet.py.
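
For reference, a minimal sketch of how the new helpers compose a residual
branch (condensed from the Block() helper in examples/cifar10/resnet.py
below; the layer names here are illustrative):

    net.add_split("split2a", "relu1")   # fan one output out to two branches
    net.add(layer.Conv2D("conv2a_br2a", 16, 3, 1, pad=1), "split2a")
    net.add(layer.BatchNormalization("bn2a_br2a"), "conv2a_br2a")
    net.add_merge("merge2a", ["split2a", "bn2a_br2a"])  # element-wise sum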


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/7ebea537
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/7ebea537
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/7ebea537

Branch: refs/heads/master
Commit: 7ebea537edfdd2b82e9aa2c8596033e0b2cab337
Parents: cdc5ffd
Author: jixin <jixin@comp.nus.edu.sg>
Authored: Wed Aug 10 23:29:40 2016 +0800
Committer: Wei Wang <wangwei@comp.nus.edu.sg>
Committed: Wed Aug 17 11:37:59 2016 +0800

----------------------------------------------------------------------
 examples/cifar10/resnet.py | 350 ++++++++++++++++++++++++++++++++++++++++
 examples/cifar10/train.py  |  20 ++-
 src/model/layer/merge.cc   |  62 +++++++
 src/model/layer/merge.h    |  52 ++++++
 src/model/layer/split.cc   |  54 +++++++
 src/model/layer/split.h    |  52 ++++++
 src/proto/model.proto      |  12 ++
 7 files changed, 597 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7ebea537/examples/cifar10/resnet.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/resnet.py b/examples/cifar10/resnet.py
new file mode 100644
index 0000000..c9b3e2b
--- /dev/null
+++ b/examples/cifar10/resnet.py
@@ -0,0 +1,350 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+""" The resnet model is adapted from http://torch.ch/blog/2016/02/04/resnets.html
+The best validation accuracy we achieved is about 83% without data augmentation.
+The performance could be improved by tuning some hyper-parameters, including
+learning rate, weight decay, max_epoch, parameter initialization, etc.
+"""
+
+import sys
+import os
+import math
+import cPickle as pickle
+
+#sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
+# Use the python modules by installing PySINGA from build/python:
+#   pip install -e .
+
+from singa import tensor
+from singa import layer
+from singa import initializer
+from singa import metric
+from singa import loss
+from singa import net as ffnet
+from singa.proto.model_pb2 import kTrain, kEval
+
+class ResNet(object):
+
+    def __init__(self, loss=None, metric=None):
+        self.loss = loss
+        self.metric = metric
+        self.layers = []
+        self.src_layers = {}
+        self.dst_layers = {}
+        self.layer_shapes = {}
+        self.layer_names = []
+
+    def to_device(self, dev):
+        for lyr in self.layers:
+            lyr.to_device(dev)
+
+    def find(self, name):
+        for i in xrange(len(self.layers)):
+            if self.layers[i].name == name:
+                return self.layers[i]
+        assert False, "Undefined layer %s." % name
+        return None
+
+    def add(self, lyr, src_lyr_name=''):
+        """Append a layer into the layer list.
+        This function will get the sample shape from the last layer to setup
+        the newly added layer. For the first layer, it is setup outside.
+        The calling function should ensure the correctness of the layer order.
+        Args:
+            lyr (Layer): the layer to be added
+            src_lyr_name: list type, name of the src layer to the current layer
+        """
+        if len(self.layers) > 0 and lyr.has_setup is False:
+            #assert src_lyr_name in dst_layers, "Undefined src layer %s" % src_lyr_name
+            shape = self.layer_shapes[src_lyr_name]
+            lyr.setup(shape)
+        print lyr.name, ': ', lyr.get_output_sample_shape()
+        if src_lyr_name != '':
+            self.src_layers[lyr.name] = [src_lyr_name]
+        self.layers.append(lyr)
+        self.layer_shapes[lyr.name] = lyr.get_output_sample_shape()            
+        self.layer_names.append(lyr.name)
+
+        if src_lyr_name != '':
+            if src_lyr_name in self.dst_layers:
+                self.dst_layers[src_lyr_name].append(lyr.name)
+            else:
+                self.dst_layers[src_lyr_name] = [lyr.name]
+        if lyr.name in self.src_layers:
+            print 'src: ', self.src_layers[lyr.name]
+        else:
+            print 'src: null'
+        #print self.layer_names
+        print "----------------------------------------"
+
+    def add_split(self, lyr_name, src_lyr_name):
+        assert src_lyr_name in self.layer_shapes, "Undefined src layer %s." % src_lyr_name
+        self.src_layers[lyr_name] = [src_lyr_name]
+        self.layer_shapes[lyr_name] = self.layer_shapes[src_lyr_name]
+        self.layer_names.append(lyr_name)
+        if src_lyr_name in self.dst_layers:
+            self.dst_layers[src_lyr_name].append(lyr_name)
+        else:
+            self.dst_layers[src_lyr_name] = [lyr_name]
+        print lyr_name, ': ', self.layer_shapes[lyr_name]
+        if lyr_name in self.src_layers:
+            print 'src: ', self.src_layers[lyr_name]
+        else:
+            print 'src: null'
+        print "----------------------------------------"
+   
+    def add_merge(self, lyr_name, src_lyr_names):
+        self.src_layers[lyr_name] = src_lyr_names
+        self.layer_shapes[lyr_name] = self.layer_shapes[src_lyr_names[0]]
+        self.layer_names.append(lyr_name)
+        for i in xrange(len(src_lyr_names)):
+            if src_lyr_names[i] in self.dst_layers:
+                self.dst_layers[src_lyr_names[i]].append(lyr_name)
+            else:
+                self.dst_layers[src_lyr_names[i]] = [lyr_name]
+        print lyr_name, ': ', self.layer_shapes[lyr_name]
+        if lyr_name in self.src_layers:
+            print 'src: ', self.src_layers[lyr_name]
+        else:
+            print 'src: null'
+        print "----------------------------------------"
+
+    def param_values(self):
+        values = []
+        for lyr in self.layers:
+            values.extend(lyr.param_values())
+        return values
+
+    def param_specs(self):
+        specs = []
+        for lyr in self.layers:
+            specs.extend(lyr.param_specs)
+        return specs
+
+    def param_names(self):
+        return [spec.name for spec in self.param_specs()]
+
+    def train(self, x, y):
+        out = self.forward(kTrain, x)
+        l = self.loss.forward(kTrain, out, y)
+        m = None
+        if self.metric is not None:
+            m = self.metric.evaluate(out, y)
+        return self.backward(), (l.l1(), m)
+
+    def evaluate(self, x, y):
+        """Evaluate the loss and metric of the given data"""
+        out = self.forward(kEval, x)
+        l = None
+        m = None
+        assert self.loss is not None or self.metric is not None,\
+            'Cannot do evaluation, as neither loss nor metric is set'
+        if self.loss is not None:
+            l = self.loss.evaluate(kEval, out, y)
+        if self.metric is not None:
+            m = self.metric.evaluate(out, y)
+        return l, m
+
+    def predict(self, x):
+        xx = self.forward(kEval, x)
+        return tensor.softmax(xx)
+
+    def forward(self, flag, x):
+        #print x.l1()
+        outputs = {'': x}
+        for idx, name in enumerate(self.layer_names):
+            #print 'forward layer', name
+            if idx == 0:
+                outputs[name] = self.find(name).forward(flag, outputs[''])
+                del outputs['']
+                continue
+
+            if 'split' in name:
+                src = self.src_layers[name][0]
+                #print 'src: ', src
+                outputs[name] = []
+                for i in xrange(len(self.dst_layers[name])):
+                    outputs[name].append(outputs[src])
+                del outputs[src]
+            elif 'merge' in name:
+                srcs = self.src_layers[name]
+                #print 'src: ', srcs
+                for i in xrange(len(srcs)):
+                    if 'split' in srcs[i]:
+                        if i > 0:
+                            data += outputs[srcs[i]][0]
+                        else:
+                            data = outputs[srcs[i]][0]
+                        del outputs[srcs[i]][0]
+                        if len(outputs[srcs[i]]) == 0:
+                            del outputs[srcs[i]]
+                    else:
+                        if i > 0:
+                            data += outputs[srcs[i]]
+                        else:
+                            data = outputs[srcs[i]]
+                        del outputs[srcs[i]]
+                outputs[name] = data
+            else:
+                src = self.src_layers[name][0]
+                #print 'src: ', src
+                if 'split' in src:
+                    outputs[name] = self.find(name).forward(flag, outputs[src][0])
+                    del outputs[src][0]
+                    if len(outputs[src]) == 0:
+                        del outputs[src]
+                else:
+                    outputs[name] = self.find(name).forward(flag, outputs[src])
+                    del outputs[src]
+                
+        #    print lyr.name, x.l1()
+        return outputs[name]
+
+    def backward(self, flag=kTrain):
+        grad = self.loss.backward()
+        pgrads = []
+        in_grads = {'': grad}
+        for idx, name in enumerate(reversed(self.layer_names)):
+            #print 'backward layer', name
+            if idx == 0:
+                lyr = self.find(name)
+                grad, _pgrads = lyr.backward(flag, in_grads[''])
+                for g in reversed(_pgrads):
+                    pgrads.append(g)
+                in_grads[name] = grad
+                del in_grads['']
+                continue
+
+            if 'merge' in name:
+                src = self.dst_layers[name][0]
+                #print 'src: ', src
+                in_grads[name] = []
+                for i in xrange(len(self.src_layers[name])):
+                    in_grads[name].append(in_grads[src])
+                del in_grads[src]
+            elif 'split' in name:
+                srcs = self.dst_layers[name]
+                #print 'src: ', srcs
+                for i in xrange(len(srcs)):
+                    if 'merge' in srcs[i]:
+                        if i > 0:
+                            data += in_grads[srcs[i]][0]
+                        else:
+                            data = in_grads[srcs[i]][0]
+                        del in_grads[srcs[i]][0]
+                        if len(in_grads[srcs[i]]) == 0:
+                            del in_grads[srcs[i]]
+                    else:
+                        if i > 0:
+                            data += in_grads[srcs[i]]
+                        else:
+                            data = in_grads[srcs[i]]
+                        del in_grads[srcs[i]]
+                in_grads[name] = data
+            else:
+                src = self.dst_layers[name][0]
+                #print 'src: ', src
+                if 'merge' in src:
+                    grad, _pgrads = self.find(name).backward(flag, in_grads[src][0])
+                    del in_grads[src][0]
+                    if len(in_grads[src]) == 0:
+                        del in_grads[src]
+                else:
+                    grad, _pgrads = self.find(name).backward(flag, in_grads[src])
+                    del in_grads[src]
+                for g in reversed(_pgrads):
+                    pgrads.append(g)
+                in_grads[name] = grad
+
+
+        return reversed(pgrads)
+
+    def save(self, f):
+        """Save model parameters using cpickle"""
+        params = {}
+        for (specs, val) in zip(self.param_specs(), self.param_values()):
+            val.to_host()
+            params[specs.name] = tensor.to_numpy(val)
+        with open(f, 'wb') as fd:
+            pickle.dump(params, fd)
+
+    def load(self, f):
+        """Load model parameters using cpickle"""
+        with open(f, 'rb') as fd:
+            params = pickle.load(fd)
+        for (specs, val) in zip(self.param_specs(), self.param_values()):
+            val.copy_from_numpy(params[specs.name])
+
+def Block(net, name, nb_filters, stride, std, src):
+    #net.add(layer.Split("split" + name, 2), srcs)
+    net.add_split("split" + name, src)
+    if stride > 1:
+        net.add(layer.Conv2D("conv" + name + "_br1", nb_filters, 1, stride, pad=0), "split" + name)
+        net.add(layer.BatchNormalization("bn" + name + "_br1"), "conv" + name + "_br1")
+        net.add(layer.Conv2D("conv" + name + "_br2a", nb_filters, 3, stride, pad=1), "split" + name)
+    else:
+        net.add(layer.Conv2D("conv" + name + "_br2a", nb_filters, 3, stride, pad=1), "split" + name)
+    net.add(layer.BatchNormalization("bn" + name + "_br2a"), "conv" + name + "_br2a")
+    net.add(layer.Activation("relu" + name + "_br2a"), "bn" + name + "_br2a")
+    net.add(layer.Conv2D("conv" + name + "_br2b", nb_filters, 3, 1, pad=1), "relu" + name + "_br2a")
+    net.add(layer.BatchNormalization("bn" + name + "_br2b"), "conv" + name + "_br2b")
+    if stride > 1:
+        net.add_merge("merge" + name, ["bn" + name + "_br1", "bn" + name + "_br2b"])
+    else:
+        net.add_merge("merge" + name, ["split" + name, "bn" + name + "_br2b"])
+
+def create_net():
+    net = ResNet(loss.SoftmaxCrossEntropy(), metric.Accuracy())
+    net.add(layer.Conv2D("conv1", 16, 3, 1, pad=1, input_sample_shape=(3, 32, 32)))
+    net.add(layer.BatchNormalization("bn1"), "conv1")
+    net.add(layer.Activation("relu1"), "bn1")
+   
+    Block(net, "2a", 16, 1, 0.01, "relu1")
+    Block(net, "2b", 16, 1, 0.01, "merge2a")
+    Block(net, "2c", 16, 1, 0.01, "merge2b")
+
+    Block(net, "3a", 32, 2, 0.01, "merge2c")
+    Block(net, "3b", 32, 1, 0.01, "merge3a")
+    Block(net, "3c", 32, 1, 0.01, "merge3b")
+
+    Block(net, "4a", 64, 2, 0.01, "merge3c")
+    Block(net, "4b", 64, 1, 0.01, "merge4a")
+    Block(net, "4c", 64, 1, 0.01, "merge4b")
+
+    net.add(layer.AvgPooling2D("pool4", 8, 8, border_mode='valid'), "merge4c")
+    net.add(layer.Flatten('flat'), "pool4")
+    net.add(layer.Dense('ip5', 10), "flat")
+    net.add(layer.Softmax('softmax'), "ip5")
+    print 'Start initialization............'
+    for (p, name) in zip(net.param_values(), net.param_names()):
+        print name, p.shape
+        if 'mean' in name or 'beta' in name:
+            p.set_value(0.0)
+        elif 'var' in name:
+            p.set_value(1.0)
+        elif 'gamma' in name:
+            initializer.uniform(p, 0, 1)
+        elif len(p.shape) > 1:
+            if 'conv' in name:
+                #initializer.gaussian(p, 0, math.sqrt(2.0/p.shape[1]))
+                initializer.gaussian(p, 0, math.sqrt(2.0/(9.0*p.shape[0])))
+            else:
+                initializer.gaussian(p, 0, 0.02)
+        else:
+            p.set_value(0)
+        print name, p.l1()
+
+    return net
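
A note on the traversal in forward()/backward() above: split and merge nodes
are recognized by the substrings 'split' and 'merge' in their names, so nodes
registered via add_split()/add_merge() must keep those substrings, as Block()
does. A condensed, runnable view of the forward bookkeeping, with plain
numbers standing in for tensors (illustrative names, not SINGA code):

    outputs = {'relu1': 1.0}                    # output of the split's src layer
    outputs['split2a'] = [outputs.pop('relu1')] * 2   # one copy per destination
    branch = 2.0 * outputs['split2a'].pop(0)    # stand-in for the conv branch
    outputs['merge2a'] = branch + outputs['split2a'].pop(0)  # element-wise sum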

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7ebea537/examples/cifar10/train.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/train.py b/examples/cifar10/train.py
index 8f596e5..6b7631e 100644
--- a/examples/cifar10/train.py
+++ b/examples/cifar10/train.py
@@ -33,7 +33,7 @@ from singa.proto import core_pb2
 
 import alexnet
 import vgg
-
+import resnet
 
 def load_dataset(filepath):
     print 'Loading data file %s' % filepath
@@ -94,6 +94,13 @@ def alexnet_lr(epoch):
     else:
         return 0.00001
 
+def resnet_lr(epoch):
+    if epoch < 80:
+        return 0.02
+    elif epoch < 120:
+        return 0.005
+    else:
+        return 0.001
 
 def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
           use_cpu=False):
@@ -152,9 +159,8 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
     net.save('model.bin')  # save model params into checkpoint file
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Train vgg/alexnet for '
-                                     'cifar10 dataset')
-    parser.add_argument('model', choices=['vgg', 'alexnet'], default='alexnet')
+    parser = argparse.ArgumentParser(description='Train vgg/alexnet for cifar10')
+    parser.add_argument('model', choices=['vgg', 'alexnet', 'resnet'], default='alexnet')
     parser.add_argument('data', default='cifar-10-batches-py')
     parser.add_argument('--use_cpu', action='store_true')
     args = parser.parse_args()
@@ -168,8 +174,12 @@ if __name__ == '__main__':
         net = alexnet.create_net(args.use_cpu)
         train((train_x, train_y, test_x, test_y), net, 160, alexnet_lr, 0.004,
               use_cpu=args.use_cpu)
-    else:
+    elif args.model == 'vgg':
         train_x, test_x = normalize_for_vgg(train_x, test_x)
         net = vgg.create_net(args.use_cpu)
         train((train_x, train_y, test_x, test_y), net, 250, vgg_lr, 0.0005,
               use_cpu=args.use_cpu)
+    else:
+        train_x, test_x = normalize_for_vgg(train_x, test_x)
+        net = resnet.create_net()
+        train((train_x, train_y, test_x, test_y), net, 200, resnet_lr, 1e-4)
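
With this change the ResNet example can be selected from the command line,
e.g. python train.py resnet cifar-10-batches-py; it trains for 200 epochs
with the resnet_lr schedule above and weight decay 1e-4. Note that
resnet.create_net() takes no use_cpu flag, so this branch ignores --use_cpu.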

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7ebea537/src/model/layer/merge.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/merge.cc b/src/model/layer/merge.cc
new file mode 100644
index 0000000..a30c3b3
--- /dev/null
+++ b/src/model/layer/merge.cc
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "singa/model/layer.h"
+#include "./merge.h"
+namespace singa {
+
+RegisterLayerClass(singa_merge, Merge);
+
+void Merge::Setup(const Shape& in_sample, const LayerConf& conf) {
+  Layer::Setup(in_sample, conf);
+  MergeConf merge_conf = conf.merge_conf();
+  input_size_ = merge_conf.input_size();
+  out_sample_shape_ = in_sample;
+}
+
+const vector<Tensor> Merge::Forward(int flag, const vector<Tensor>& inputs)
{
+  vector<Tensor> outputs;
+  //input_size_ = inputs.size();
+  if (input_size_ == 1u) {
+    outputs = inputs;
+  } else {
+    Tensor sum = inputs.at(0);
+    for (size_t i = 1; i < inputs.size(); i++) {
+      Tensor temp = inputs.at(i);
+      CHECK_EQ(sum.nDim(), temp.nDim());
+      for (size_t j = 0; j < temp.nDim(); j++)
+        CHECK_EQ(sum.shape(j), temp.shape(j));
+      sum += temp;
+    }
+    outputs.push_back(sum);
+  }
+  return outputs;
+}
+
+const std::pair<vector<Tensor>, vector<Tensor>> Merge::Backward(
+    int flag, const vector<Tensor>& grads) {
+  vector<Tensor> input_grad, param_grad;
+  if (grads.size() != 1u) {
+    LOG(INFO) << "Merge layer only has one output tensor.";
+  }
+  for (size_t i = 0; i < input_size_; i++)
+    input_grad.push_back(grads.at(0));
+  return std::make_pair(input_grad, param_grad);
+}
+
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7ebea537/src/model/layer/merge.h
----------------------------------------------------------------------
diff --git a/src/model/layer/merge.h b/src/model/layer/merge.h
new file mode 100644
index 0000000..9c34192
--- /dev/null
+++ b/src/model/layer/merge.h
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SRC_MODEL_LAYER_MERGE_H_
+#define SRC_MODEL_LAYER_MERGE_H_
+#include <string>
+#include <utility>
+#include <vector>
+#include "singa/model/layer.h"
+
+namespace singa {
+class Merge : public Layer {
+ public:
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override { return "Merge"; }
+
+  /// \copydoc Layer::Setup(const LayerConf&);
+  void Setup(const Shape& in_sample, const LayerConf& conf) override;
+  const Shape GetOutputSampleShape() const override {
+    CHECK(out_sample_shape_.size()) << "You may have not called Setup()";
+    return out_sample_shape_;
+  }
+  /// \copydoc Layer::Forward(int flag, const vector<Tensor>&)
+  const vector<Tensor> Forward(int flag, const vector<Tensor>& inputs) override;
+
+  /// \copydoc Layer::Backward(int, const vector<Tensor>&);
+  const std::pair<vector<Tensor>, vector<Tensor>> Backward(
+      int flag, const vector<Tensor>& grads) override;
+
+  const size_t input_size() const { return input_size_; }
+
+ protected:
+  // Shape of one output sample, and the number of input tensors to merge.
+  Shape out_sample_shape_;
+  size_t input_size_;
+};
+}  // namespace singa
+#endif  // SRC_MODEL_LAYER_MERGE_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7ebea537/src/model/layer/split.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/split.cc b/src/model/layer/split.cc
new file mode 100644
index 0000000..fd1ab7d
--- /dev/null
+++ b/src/model/layer/split.cc
@@ -0,0 +1,54 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "singa/model/layer.h"
+#include "./split.h"
+namespace singa {
+
+RegisterLayerClass(singa_split, Split);
+
+void Split::Setup(const Shape& in_sample, const LayerConf& conf) {
+  Layer::Setup(in_sample, conf);
+  SplitConf split_conf = conf.split_conf();
+  output_size_ = split_conf.output_size();
+  out_sample_shape_ = in_sample;
+}
+
+const vector<Tensor> Split::Forward(int flag, const vector<Tensor>& inputs)
{
+  vector<Tensor> outputs;
+  if (inputs.size() != 1)
+    LOG(FATAL) << "Split layer only has one input tensor.";
+  for (size_t i = 0; i < output_size_; i++)
+    outputs.push_back(inputs.at(0));
+  return outputs;
+}
+
+const std::pair<vector<Tensor>, vector<Tensor>> Split::Backward(
+    int flag, const vector<Tensor>& grads) {
+  vector<Tensor> input_grad, param_grad;
+  CHECK_EQ(grads.size(), output_size_);
+  
+  /// Input_grad is the sum of all the output gradients.
+  Tensor temp = grads.at(0);
+  for (size_t i = 1; i < output_size_; i++)
+    temp += grads.at(i);
+  input_grad.push_back(temp);
+  return std::make_pair(input_grad, param_grad);
+}
+
+}  // namespace singa
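
Taken together, Split and Merge are mirror images: Split::Forward copies its
single input to output_size_ outputs and Split::Backward sums the consumers'
gradients, while Merge::Forward sums input_size_ same-shape inputs and
Merge::Backward hands each input a copy of the output gradient. A small
numpy sketch of the arithmetic (illustrative only, not SINGA code):

    import numpy as np

    x = np.ones((2, 3))
    outs = [x, x]                           # Split::Forward, output_size = 2
    g0, g1 = np.full((2, 3), 0.5), np.full((2, 3), 0.25)
    in_grad = g0 + g1                       # Split::Backward: sum gradients
    merged = outs[0] + outs[1]              # Merge::Forward: element-wise sum
    grads = [in_grad, in_grad]              # Merge::Backward: copy gradient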

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7ebea537/src/model/layer/split.h
----------------------------------------------------------------------
diff --git a/src/model/layer/split.h b/src/model/layer/split.h
new file mode 100644
index 0000000..79e70f6
--- /dev/null
+++ b/src/model/layer/split.h
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SRC_MODEL_LAYER_SPLIT_H_
+#define SRC_MODEL_LAYER_SPLIT_H_
+#include <string>
+#include <utility>
+#include <vector>
+#include "singa/model/layer.h"
+
+namespace singa {
+class Split : public Layer {
+ public:
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override { return "Split"; }
+
+  /// \copydoc Layer::Setup(const LayerConf&);
+  void Setup(const Shape& in_sample, const LayerConf& conf) override;
+  const Shape GetOutputSampleShape() const override {
+    CHECK(out_sample_shape_.size()) << "You may have not called Setup()";
+    return out_sample_shape_;
+  }
+  /// \copydoc Layer::Forward(int flag, const vector<Tensor>&)
+  const vector<Tensor> Forward(int flag, const vector<Tensor>& inputs) override;
+
+  /// \copydoc Layer::Backward(int, const vector<Tensor>&);
+  const std::pair<vector<Tensor>, vector<Tensor>> Backward(
+      int flag, const vector<Tensor>& grads) override;
+
+  const size_t output_size() const { return output_size_; }
+
+ protected:
+  // Shape of one output sample, and the number of output tensors produced.
+  Shape out_sample_shape_;
+  size_t output_size_;
+};
+}  // namespace singa
+#endif  // SRC_MODEL_LAYER_SPLIT_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7ebea537/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index 6923820..1796e9c 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -241,6 +241,8 @@ message LayerConf {
   optional DenseConf dense_conf = 201;
   optional MetricConf metric_conf = 200;
   optional BatchNormConf batchnorm_conf = 202;
+  optional SplitConf split_conf = 203;
+  optional MergeConf merge_conf = 204;
 }
 
 // Message that stores hyper-parameters used to apply transformation
@@ -948,3 +950,13 @@ message BatchNormConf {
   // newMean*factor + runningMean*(1-factor).
   optional double factor = 1 [default = 0.9];
 }
+
+message SplitConf {
+  // Indicate the number of outputs
+  optional int32 output_size = 1 [default = 2];
+}
+
+message MergeConf {
+  // Indicate the number of inputs
+  optional int32 input_size = 1 [default = 2];
+}
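
As an illustration of the new conf fields, a split node could be configured
from Python through the generated protobuf classes (the name/type fields are
assumed to be the usual LayerConf fields; the type string matches the
RegisterLayerClass call above):

    from singa.proto import model_pb2

    conf = model_pb2.LayerConf()
    conf.name = 'split2a'        # 'name'/'type' assumed per usual LayerConf
    conf.type = 'singa_split'
    conf.split_conf.output_size = 2   # fan out to two destination layers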

