singa-commits mailing list archives

From zhaoj...@apache.org
Subject [2/4] incubator-singa git commit: SINGA-126 Python Binding for Interactive Training
Date Wed, 06 Apr 2016 08:38:12 GMT
SINGA-126 Python Binding for Interactive Training

- add two example Python scripts for interactive training
  . train_mnist.py
  . train_cifar10.py

- add methods/class in singa/layer.py
  . ComputeFeature, ComputeGradient, Feed, Setup
  . GetParams, SetParams, GetData
  . Dummy()

- add methods in singa/model.py
  . save_model_parameter
  . load_model_parameter

- add Feed function in src/neuralnet/neuron_layer/dummy.cc
  . corresponds to class Dummy() in layer.py
    note: DummyInputLayer and kDummyInput are removed

- add functions in src/worker.cc
  . Checkpoint
  . InitNetParams

- add CreateXXX functions to set up singa::XXX from string proto
  . where XXX is Layer, Updater, or Worker

- update tool/python/singa/driver.i for the SWIG wrapper

- include cifar10_mean_image in tool/python/examples/datasets/
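
The pieces above combine into the following interactive-training loop. This is
a minimal sketch condensed from the tool/python/examples/train_mnist.py script
added below; it assumes numpy arrays x (samples) and y (labels) are already
loaded, and 'workspace' is only a placeholder checkpoint folder:

    import sys
    from singa.driver import Driver
    from singa.layer import *
    from singa.model import *

    d = Driver()
    d.Init(sys.argv)

    input = Dummy()                      # data layer, fed via Feed()
    label = Dummy()                      # label layer, fed via Feed(..., is_label=1)
    nn = [input,
          Dense(2500, init='uniform'), Activation('stanh'),
          Dense(10, init='uniform')]
    loss = Loss('softmaxloss')
    sgd = SGD(lr=0.001, lr_type='step')

    batchsize = 64
    for i in range(x.shape[0] / batchsize):
        xb = x[i*batchsize:(i+1)*batchsize, :]
        yb = y[i*batchsize:(i+1)*batchsize, :]
        nn[0].Feed(xb)                   # push a mini-batch into the Dummy layer
        label.Feed(yb, is_label=1)
        for h in range(1, len(nn)):      # forward pass, layer by layer
            nn[h].ComputeFeature(nn[h-1])
        loss.ComputeFeature(nn[-1], label)
        loss.ComputeGradient(i+1, sgd)   # backward pass + parameter update

    # persist the learned parameters (see save_model_parameter in model.py);
    # 'workspace' is an assumed output folder
    save_model_parameter(i+1, 'workspace', nn)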


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/8130b7ed
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/8130b7ed
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/8130b7ed

Branch: refs/heads/master
Commit: 8130b7ed14e6b556749bf49f0e3ca2a2f6f00e2d
Parents: 8b7d1e0
Author: chonho <leech@comp.nus.edu.sg>
Authored: Thu Mar 17 13:04:40 2016 +0800
Committer: chonho <leech@comp.nus.edu.sg>
Committed: Tue Apr 5 23:22:51 2016 +0800

----------------------------------------------------------------------
 include/singa/neuralnet/input_layer.h           |   2 +
 include/singa/neuralnet/layer.h                 |  12 +
 include/singa/neuralnet/neuralnet.h             |   7 +
 include/singa/neuralnet/neuron_layer.h          |   9 +
 include/singa/utils/param.h                     |   2 +
 include/singa/utils/updater.h                   |   8 +
 include/singa/worker.h                          |   3 +
 src/driver.cc                                   |   1 +
 src/neuralnet/layer.cc                          |  18 +
 src/neuralnet/loss_layer/softmax.cc             |   1 +
 src/neuralnet/neuralnet.cc                      |   6 +
 src/neuralnet/neuron_layer/dummy.cc             |  28 ++
 src/neuralnet/neuron_layer/inner_product.cc     |   2 +
 src/proto/job.proto                             |   1 +
 src/utils/param.cc                              |   6 +
 src/utils/updater.cc                            |  13 +-
 src/worker.cc                                   |  44 ++
 .../python/examples/datasets/cifar10_mean_image | Bin 0 -> 24576 bytes
 tool/python/examples/train_cifar10.py           | 117 ++++++
 tool/python/examples/train_mnist.py             | 107 +++++
 tool/python/singa/driver.i                      | 101 ++++-
 tool/python/singa/layer.py                      | 400 +++++++++++++++++--
 tool/python/singa/model.py                      |  81 +++-
 tool/python/singa/utils/utility.py              |   7 +
 24 files changed, 926 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/include/singa/neuralnet/input_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/input_layer.h b/include/singa/neuralnet/input_layer.h
index e701eec..73d509b 100644
--- a/include/singa/neuralnet/input_layer.h
+++ b/include/singa/neuralnet/input_layer.h
@@ -200,6 +200,8 @@ class RNNLabelLayer : public InputLayer {
   void Setup(const LayerProto& proto, const vector<Layer*>& srclayers);
   void ComputeFeature(int flag, const vector<Layer*>& srclayers);
 };
+
+
 /****************Deprecated layers******************/
 /**
  * @deprecated please use the StoreInputLayer.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/include/singa/neuralnet/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/layer.h b/include/singa/neuralnet/layer.h
index c1612a2..ce47b47 100644
--- a/include/singa/neuralnet/layer.h
+++ b/include/singa/neuralnet/layer.h
@@ -57,6 +57,7 @@ inline const string AddPrefixSuffix(int unroll_idx, int partition_idx,
  * Layer::ComputeFeature() and Layer::ComputGradient()
  * functions in accordance with the NeuralNet::TrainOneBatch function.
  */
+
 class Layer {
  public:
   /**
@@ -69,6 +70,14 @@ class Layer {
 
   Layer() {}
   virtual ~Layer() {}
+
+  /**
+   * Create for python binding, production test mode
+   *
+   */
+  static Layer* CreateLayer(const string str);
+  static void SetupLayer(Layer* layer, const string str, const vector<Layer*>& srclayers);
+
   /**
    * Setup layer properties.
    *
@@ -84,6 +93,8 @@ class Layer {
     datavec_.push_back(&data_);
     gradvec_.push_back(&grad_);
   }
+
+
   /**
    * Compute features of this layer based on connected layers.
    *
@@ -108,6 +119,7 @@ class Layer {
   virtual const std::vector<Param*> GetParams() const {
     return std::vector<Param*> {};
   }
+  virtual void SetParams(std::vector<Param*>) {}
   /**
    * Return the connection type between one neuron of this layer and its source
    * layer.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/include/singa/neuralnet/neuralnet.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuralnet.h b/include/singa/neuralnet/neuralnet.h
index 33ad38c..60887e0 100644
--- a/include/singa/neuralnet/neuralnet.h
+++ b/include/singa/neuralnet/neuralnet.h
@@ -58,6 +58,13 @@ class NeuralNet {
   static NeuralNet* Create(const NetProto& net_conf, Phase phase,
                            int npartitions);
 
+  /**
+   * Create for python binding, production test mode
+   *
+   */
+  static NeuralNet* CreateNeuralNet(const string str);
+  NeuralNet() {};
+  
   static const NetProto Unrolling(const NetProto& net_conf);
   /**
    * construct the net structure from protocol buffer.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/include/singa/neuralnet/neuron_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuron_layer.h b/include/singa/neuralnet/neuron_layer.h
index f03e91b..2d73854 100644
--- a/include/singa/neuralnet/neuron_layer.h
+++ b/include/singa/neuralnet/neuron_layer.h
@@ -122,13 +122,17 @@ class DropoutLayer : public NeuronLayer {
  */
 class DummyLayer: public NeuronLayer {
  public:
+  void Setup(const std::string str, const vector<Layer*>& srclayers);
   void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
   void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
   void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
+  void Feed(int batchsize, vector<float>& data, int is_aux);
+  Layer* ToLayer() { return this;}
 
  private:
   bool input_ = false;  // use as input layer
   bool output_ = false;  // use as output layer
+  int batchsize_ = 1;  // use for input layer
 };
 
 /**
@@ -224,6 +228,11 @@ class InnerProductLayer : public NeuronLayer {
     return params;
   }
 
+  void SetParams(std::vector<Param*> params) {
+    weight_ = params.at(0);
+    bias_ = params.at(1);
+  }
+
  private:
   int batchsize_;
   int vdim_, hdim_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/include/singa/utils/param.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/param.h b/include/singa/utils/param.h
index fcaaeb7..319f2b4 100644
--- a/include/singa/utils/param.h
+++ b/include/singa/utils/param.h
@@ -155,6 +155,7 @@ class Param {
    * Init param values from checkpoint blob.
    */
   void FromProto(const BlobProto& blob);
+  void FromProto(const std::string str);
   /**
    * Dump param values to blob.
    */
@@ -211,6 +212,7 @@ class Param {
    /**
     * @return num of parameters in this Param obj.
     */
+  inline const std::vector<int>& shape() const { return data_.shape(); }
   inline int size() const { return data_.count(); }
   inline const Blob<float>& data() const { return data_; }
   inline Blob<float>* mutable_data() { return &data_; }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/include/singa/utils/updater.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/updater.h b/include/singa/utils/updater.h
index b14f72b..33ad8a7 100644
--- a/include/singa/utils/updater.h
+++ b/include/singa/utils/updater.h
@@ -22,10 +22,13 @@
 #ifndef SINGA_UTILS_UPDATER_H_
 #define SINGA_UTILS_UPDATER_H_
 
+#include <string>
 #include "singa/proto/job.pb.h"
 #include "singa/utils/param.h"
+#include "singa/neuralnet/layer.h"
 
 namespace singa {
+using std::string;
 /**
  * Base learning rate generator.
  *
@@ -87,6 +90,11 @@ class InvTLRGen : public LRGenerator {
  */
 class Updater {
  public:
+
+  /* added for python binding */
+  static Updater* CreateUpdater(const string str);
+  /* ------------------------ */
+
   static Updater* Create(const UpdaterProto& proto);
 
   virtual ~Updater() {}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/include/singa/worker.h
----------------------------------------------------------------------
diff --git a/include/singa/worker.h b/include/singa/worker.h
index 34c8000..d53e54b 100644
--- a/include/singa/worker.h
+++ b/include/singa/worker.h
@@ -61,6 +61,7 @@ class Worker {
    *
    * @return a pointer to the instance of the Worker subclass.
    */
+  static Worker* CreateWorker(const std::string str);
   static Worker* Create(const AlgProto& conf);
   virtual ~Worker();
   /**
@@ -129,6 +130,7 @@ class Worker {
    * initialized.
    */
   void InitNetParams(const JobProto& job_conf, NeuralNet* net);
+  void InitNetParams(const std::string& folder, vector<Layer*> net);
   /**
    * Checkpoint all Param objects owned by the worker onto disk.
    * The serialization is done using BlobProtos which includes the name, version
@@ -140,6 +142,7 @@ class Worker {
    * @param net the training net whose Param objects will be dumped.
    */
   void Checkpoint(int step, const std::string& folder, NeuralNet* net);
+  void Checkpoint(int step, const std::string& folder, vector<Layer*> net);
   /**
     * Train one mini-batch.
     * Test/Validation is done before training.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index 83f3953..702df5e 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -96,6 +96,7 @@ void Driver::Init(int argc, char **argv) {
   RegisterLayer<CConvolutionLayer, int>(kCConvolution);
   RegisterLayer<CPoolingLayer, int>(kCPooling);
   RegisterLayer<EmbeddingLayer, int>(kEmbedding);
+  RegisterLayer<ActivationLayer, int>(kActivation);
 
 #ifdef USE_CUDNN
   RegisterLayer<CudnnActivationLayer, int>(kCudnnActivation);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/src/neuralnet/layer.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/layer.cc b/src/neuralnet/layer.cc
index cb1f3b8..ef1629f 100644
--- a/src/neuralnet/layer.cc
+++ b/src/neuralnet/layer.cc
@@ -19,7 +19,11 @@
 *
 *************************************************************/
 
+#include "singa/worker.h"
 #include "singa/neuralnet/layer.h"
+#include "singa/neuralnet/input_layer.h"
+#include "singa/neuralnet/neuron_layer.h"
+#include "singa/neuralnet/loss_layer.h"
 
 #include <cblas.h>
 #include <glog/logging.h>
@@ -33,6 +37,20 @@ namespace singa {
 
 using std::string;
 
+void Layer::SetupLayer(Layer* layer, const string str, const vector<Layer*>& srclayers) {
+  LayerProto layer_conf;
+  layer_conf.ParseFromString(str);
+  layer->Setup(layer_conf, srclayers);
+  for (auto param : layer->GetParams())
+      param->InitValues();
+}
+
+Layer* Layer::CreateLayer(const string str) {
+  LayerProto layer_conf;
+  layer_conf.ParseFromString(str);
+  return Layer::Create(layer_conf);
+}
+
 Layer* Layer::Create(const LayerProto& proto) {
   auto* factory = Singleton<Factory<Layer>>::Instance();
   Layer* layer = nullptr;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/src/neuralnet/loss_layer/softmax.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/loss_layer/softmax.cc b/src/neuralnet/loss_layer/softmax.cc
index 7956470..9d0cb1d 100644
--- a/src/neuralnet/loss_layer/softmax.cc
+++ b/src/neuralnet/loss_layer/softmax.cc
@@ -98,6 +98,7 @@ void SoftmaxLossLayer::ComputeGradient(int flag,
   Tensor<cpu, 1> gsrc(gsrcptr, Shape1(gsrcblob->count()));
   gsrc *= scale_ / (1.0f * batchsize_);
 }
+
 const std::string SoftmaxLossLayer::ToString(bool debug, int flag) {
   if (debug)
     return Layer::ToString(debug, flag);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/src/neuralnet/neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc
index b045e06..226d8d9 100644
--- a/src/neuralnet/neuralnet.cc
+++ b/src/neuralnet/neuralnet.cc
@@ -58,6 +58,12 @@ const NetProto NetConfPreprocess(const NetProto& conf) {
   return proto;
 }
 
+NeuralNet* NeuralNet::CreateNeuralNet(const string str) {
+  NetProto net_conf;
+  net_conf.ParseFromString(str);
+  return NeuralNet::Create(net_conf,singa::kTest,1);
+}
+
 NeuralNet* NeuralNet::Create(const NetProto& net_conf, Phase phase,
     int npartitions) {
   const NetProto& full_net_conf = NetConfPreprocess(net_conf);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/src/neuralnet/neuron_layer/dummy.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/dummy.cc b/src/neuralnet/neuron_layer/dummy.cc
index 936bb5e..001e053 100644
--- a/src/neuralnet/neuron_layer/dummy.cc
+++ b/src/neuralnet/neuron_layer/dummy.cc
@@ -27,6 +27,13 @@
 
 namespace singa {
 
+void DummyLayer::Setup(const std::string str,
+                       const vector<Layer*>& srclayers) {
+  LayerProto conf;
+  conf.ParseFromString(str);
+  DummyLayer::Setup(conf, srclayers);
+}
+
 void DummyLayer::Setup(const LayerProto& proto,
                        const vector<Layer*>& srclayers) {
   NeuronLayer::Setup(proto, srclayers);
@@ -71,4 +78,25 @@ void DummyLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
     Copy(grad_, srclayers[0]->mutable_grad(this));
 }
 
+void DummyLayer::Feed(int batchsize, vector<float>& data, int is_aux){
+
+    batchsize_ = batchsize;
+    // input data
+    if (is_aux == 0) {
+      int size = data.size();
+      float* ptr = data_.mutable_cpu_data();
+      for (int i = 0; i< size; i++) { 
+          ptr[i] = data.at(i);
+      }
+    }
+    // label
+    else {
+      aux_data_.resize(batchsize_);
+      for (int i = 0; i< batchsize_; i++) {
+          aux_data_[i] = static_cast<int>(data.at(i));
+      }
+    }
+    return;
+}
+
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/src/neuralnet/neuron_layer/inner_product.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/inner_product.cc b/src/neuralnet/neuron_layer/inner_product.cc
index 1e5e93e..a7378a2 100644
--- a/src/neuralnet/neuron_layer/inner_product.cc
+++ b/src/neuralnet/neuron_layer/inner_product.cc
@@ -83,5 +83,7 @@ void InnerProductLayer::ComputeGradient(int flag,
     else
       MMDot(grad_, weight_->data(), srclayers[0]->mutable_grad(this));
   }
+  //clee auto w = weight_->mutable_cpu_data();
+  //LOG(ERROR) << srclayers[0]->name() << " " << w[0];
 }
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/src/proto/job.proto
----------------------------------------------------------------------
diff --git a/src/proto/job.proto b/src/proto/job.proto
index 7bc0ea3..6afc599 100644
--- a/src/proto/job.proto
+++ b/src/proto/job.proto
@@ -676,6 +676,7 @@ enum LayerType {
   kSoftmax = 214;
   kGRU = 215;
   kEmbedding = 216;
+  kActivation = 217;
 
   // cudnn v3
   kCudnnConv = 250;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index e1c04c7..73d8314 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -199,6 +199,12 @@ void Param::ShareFrom(Param* other) {
   grad_.ShareData(&(other->grad_), false);
 }
 
+void Param::FromProto(const string str) {
+  BlobProto blob;
+  blob.ParseFromString(str);
+  data_.FromProto(blob);
+}
+
 void Param::FromProto(const BlobProto& blob) {
   data_.FromProto(blob);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/src/utils/updater.cc
----------------------------------------------------------------------
diff --git a/src/utils/updater.cc b/src/utils/updater.cc
index fa051b1..a2180d3 100644
--- a/src/utils/updater.cc
+++ b/src/utils/updater.cc
@@ -60,8 +60,8 @@ float StepLRGen::Get(int step) {
   // do not cast int to float
   int freq = proto_.step_conf().change_freq();
   float lr = proto_.base_lr() * pow(proto_.step_conf().gamma(), step / freq);
-  LOG_IF(INFO, step % freq == 0) << "Update learning rate to " << lr
-    << " @ step " << step;
+  // LOG_IF(INFO, step % freq == 0) << "Update learning rate to " << lr
+  //   << " @ step " << step;
   return lr;
 }
 
@@ -96,6 +96,15 @@ Updater* Updater::Create(const UpdaterProto& proto) {
   return updater;
 }
 
+/**************** added for Python Binding ***************************/
+Updater* Updater::CreateUpdater(const string str) {
+  UpdaterProto conf;
+  conf.ParseFromString(str);
+  return Updater::Create(conf);
+}
+/***********************Python Binding end**************************/
+
+
 /***********************SGD with momentum******************************/
 void Updater::Init(const UpdaterProto& proto) {
   momentum_ = proto.momentum();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/src/worker.cc
----------------------------------------------------------------------
diff --git a/src/worker.cc b/src/worker.cc
index 5206513..e92d780 100644
--- a/src/worker.cc
+++ b/src/worker.cc
@@ -35,6 +35,12 @@ namespace singa {
 
 using std::string;
 
+Worker* Worker::CreateWorker(const string str) {
+  AlgProto alg_proto;
+  alg_proto.ParseFromString(str);
+  return Worker::Create(alg_proto);
+}
+
 Worker* Worker::Create(const AlgProto& conf) {
   auto factory = Singleton<Factory<singa::Worker>>::Instance();
   Worker* worker = nullptr;
@@ -154,6 +160,23 @@ void Worker::InitSockets(const NeuralNet* net) {
   }
 }
 
+void Worker::InitNetParams(const std::string& folder, vector<Layer*> net) {
+
+    std::unordered_map<string, Param*> name2param;
+    for (auto layer : net) {
+        for (auto param : layer->GetParams()) {
+          // only owners fill the memory of parameter values.
+          //if (param->owner() == param->id()) {
+            CHECK(name2param.find(param->name()) == name2param.end());
+            name2param[param->name()] = param;
+          //}
+        }
+    }
+    vector<string> paths;
+    paths.push_back(folder);
+    NeuralNet::Load(paths, name2param);
+}
+
 void Worker::InitNetParams(const JobProto& job_conf, NeuralNet* net) {
   // for each server grp, its first subscriber worker grp does the param init
   if (grp_id_ % Cluster::Get()->nworker_groups_per_server_group() == 0) {
@@ -209,6 +232,27 @@ void Worker::InitNetParams(const JobProto& job_conf, NeuralNet* net) {
   }
 }
 
+void Worker::Checkpoint(int step, const std::string& folder, vector<Layer*> net) {
+  BlobProtos bps;
+  for (auto layer : net) {
+    //if (layer->partition_id() == id_) {
+      for (auto param : layer->GetParams()) {
+        // only owners fill the memory of parameter values.
+        //if (param->owner() == param->id()) {
+          auto *blob = bps.add_blob();
+          param->ToProto(blob);
+          bps.add_version(param->version());
+          bps.add_name(param->name());
+        //}
+      }
+    //}
+  }
+  char buf[256];
+  snprintf(buf, sizeof(buf), "%s/step%d-worker0", folder.c_str(), step);
+  LOG(INFO) << "checkpoint to " << buf;
+  WriteProtoToBinaryFile(bps, buf);
+}
+
 void Worker::Checkpoint(int step, const std::string& folder, NeuralNet* net) {
   BlobProtos bps;
   for (auto layer : net->layers()) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/tool/python/examples/datasets/cifar10_mean_image
----------------------------------------------------------------------
diff --git a/tool/python/examples/datasets/cifar10_mean_image b/tool/python/examples/datasets/cifar10_mean_image
new file mode 100644
index 0000000..a4ea8a5
Binary files /dev/null and b/tool/python/examples/datasets/cifar10_mean_image differ

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/tool/python/examples/train_cifar10.py
----------------------------------------------------------------------
diff --git a/tool/python/examples/train_cifar10.py b/tool/python/examples/train_cifar10.py
new file mode 100755
index 0000000..a757595
--- /dev/null
+++ b/tool/python/examples/train_cifar10.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python
+
+#/************************************************************
+#*
+#* Licensed to the Apache Software Foundation (ASF) under one
+#* or more contributor license agreements.  See the NOTICE file
+#* distributed with this work for additional information
+#* regarding copyright ownership.  The ASF licenses this file
+#* to you under the Apache License, Version 2.0 (the
+#* "License"); you may not use this file except in compliance
+#* with the License.  You may obtain a copy of the License at
+#*
+#*   http://www.apache.org/licenses/LICENSE-2.0
+#*
+#* Unless required by applicable law or agreed to in writing,
+#* software distributed under the License is distributed on an
+#* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+#* KIND, either express or implied.  See the License for the
+#* specific language governing permissions and limitations
+#* under the License.
+#*
+#*************************************************************/
+
+'''
+Example script of CNN model for CIFAR10 dataset
+'''
+import os, sys
+import numpy as np
+
+current_path_ = os.path.dirname(__file__)
+singa_root_=os.path.abspath(os.path.join(current_path_,'../../..'))
+sys.path.append(os.path.join(singa_root_,'tool','python'))
+
+from singa.driver import Driver
+from singa.layer import *
+from singa.model import *
+from singa.utils.utility import swap32
+
+fname_mean_image = 'tool/python/examples/datasets/cifar10_mean_image'
+mean_image = np.fromfile(fname_mean_image)
+
+def load_dataset(did=1):
+    ''' CIFAR10 dataset
+        5 binary dataset, each contains 10000 images
+        1 row (1 image) includes 1 label & 3072 pixels
+        3072 pixels are  3 channels of a 32x32 image
+    '''
+    print '[Load CIFAR10 dataset]', did
+    dataset_dir_ = singa_root_ + "/examples/cifar10/cifar-10-batches-bin"
+    fname_train_data = dataset_dir_ + "/data_batch_{}.bin".format(did)
+    
+    nb_samples = 10000
+    nb_pixels = 3 * 32 * 32  
+    d = np.fromfile(fname_train_data, dtype=np.uint8)
+    d = d.reshape(nb_samples, nb_pixels + 1) # +1 for label
+    x = d[:, 1:] 
+    x = x - mean_image
+    print '   data x:', x.shape
+    y = d[:, 0]
+    y = y.reshape(nb_samples, 1) 
+    print '  label y:', y.shape
+    return x, y
+
+def get_labellist():
+    dataset_dir_ = singa_root_ + "/examples/cifar10/cifar-10-batches-bin"
+    fname_label_list = dataset_dir_ + "/batches.meta.txt"
+    label_list_ = np.genfromtxt(fname_label_list, dtype=str)
+    return label_list_
+
+#-------------------------------------------------------------------
+print '[Layer registration/declaration]'
+d = Driver()
+d.Init(sys.argv)
+
+input = Dummy()
+label = Dummy()
+
+nn = []
+nn.append(input)
+nn.append(Convolution2D(32, 5, 1, 2, w_std=0.0001, b_lr=2))
+nn.append(MaxPooling2D(pool_size=(3,3), stride=2))
+nn.append(Activation('relu'))
+nn.append(LRN2D(3, alpha=0.00005, beta=0.75))
+nn.append(Convolution2D(32, 5, 1, 2, b_lr=2))
+nn.append(Activation('relu'))
+nn.append(AvgPooling2D(pool_size=(3,3), stride=2))
+nn.append(LRN2D(3, alpha=0.00005, beta=0.75))
+nn.append(Convolution2D(64, 5, 1, 2))
+nn.append(Activation('relu'))
+nn.append(AvgPooling2D(pool_size=(3,3), stride=2))
+nn.append(Dense(10, w_wd=250, b_lr=2, b_wd=0))
+loss = Loss('softmaxloss')
+
+# updater
+sgd = SGD(decay=0.004, momentum=0.9, lr_type='manual', step=(0,60000,65000), step_lr=(0.001,0.0001,0.00001))
+
+#-------------------------------------------------------------------
+batchsize = 100 
+disp_freq = 50
+train_step = 1000
+
+for dataset_id in range(train_step / batchsize):
+
+    x, y = load_dataset(dataset_id%5+1)
+
+    print '[Start training]'
+    for i in range(x.shape[0] / batchsize):
+        xb, yb = x[i*batchsize:(i+1)*batchsize,:], y[i*batchsize:(i+1)*batchsize,:]
+        nn[0].Feed(xb, 3, 0)
+        label.Feed(yb, 1, 1)
+        for h in range(1, len(nn)):
+            nn[h].ComputeFeature(nn[h-1])
+        loss.ComputeFeature(nn[-1], label)
+        if (i+1)%disp_freq == 0:
+            print '  Step {:>3}: '.format(i+1 + dataset_id*(x.shape[0]/batchsize)),
+            loss.display()
+        loss.ComputeGradient(i+1, sgd)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/tool/python/examples/train_mnist.py
----------------------------------------------------------------------
diff --git a/tool/python/examples/train_mnist.py b/tool/python/examples/train_mnist.py
new file mode 100755
index 0000000..466bc58
--- /dev/null
+++ b/tool/python/examples/train_mnist.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+
+#/************************************************************
+#*
+#* Licensed to the Apache Software Foundation (ASF) under one
+#* or more contributor license agreements.  See the NOTICE file
+#* distributed with this work for additional information
+#* regarding copyright ownership.  The ASF licenses this file
+#* to you under the Apache License, Version 2.0 (the
+#* "License"); you may not use this file except in compliance
+#* with the License.  You may obtain a copy of the License at
+#*
+#*   http://www.apache.org/licenses/LICENSE-2.0
+#*
+#* Unless required by applicable law or agreed to in writing,
+#* software distributed under the License is distributed on an
+#* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+#* KIND, either express or implied.  See the License for the
+#* specific language governing permissions and limitations
+#* under the License.
+#*
+#*************************************************************/
+
+'''
+Example script of MLP model for MNIST dataset
+'''
+import os, sys
+import numpy as np
+
+current_path_ = os.path.dirname(__file__)
+singa_root_=os.path.abspath(os.path.join(current_path_,'../../..'))
+sys.path.append(os.path.join(singa_root_,'tool','python'))
+
+from singa.driver import Driver
+from singa.layer import *
+from singa.model import *
+from singa.utils.utility import swap32
+
+def load_dataset():
+    ''' MNIST dataset
+        train-images: 4 int32 headers & int8 pixels
+        train-labels: 2 int32 headers & int8 labels
+    '''
+    print '[Load MNIST dataset]'
+    fname_train_image = "examples/mnist/train-images-idx3-ubyte"
+    fname_train_label = "examples/mnist/train-labels-idx1-ubyte"
+    nb_header = [4, 2]
+
+    info = swap32(np.fromfile(fname_train_image, dtype=np.uint32, count=nb_header[0]))
+    nb_samples = info[1] 
+    shape = (info[2],info[3])
+    
+    x = np.fromfile(fname_train_image, dtype=np.uint8)
+    x = x[np.dtype(np.int32).itemsize*nb_header[0]:] # skip header
+    x = x.reshape(nb_samples, shape[0]*shape[1]) 
+    print '   data x:', x.shape
+    y = np.fromfile(fname_train_label, dtype=np.uint8)
+    y = y[np.dtype(np.int32).itemsize*nb_header[1]:] # skip header
+    y = y.reshape(nb_samples, 1) 
+    print '  label y:', y.shape
+
+    return x, y
+
+#-------------------------------------------------------------------
+print '[Layer registration/declaration]'
+d = Driver()
+d.Init(sys.argv)
+
+input = Dummy()
+label = Dummy()
+
+nn = []
+nn.append(input)
+nn.append(Dense(2500, init='uniform'))
+nn.append(Activation('stanh'))
+nn.append(Dense(2000, init='uniform'))
+nn.append(Activation('stanh'))
+nn.append(Dense(1500, init='uniform'))
+nn.append(Activation('stanh'))
+nn.append(Dense(1000, init='uniform'))
+nn.append(Activation('stanh'))
+nn.append(Dense(500, init='uniform'))
+nn.append(Activation('stanh'))
+nn.append(Dense(10, init='uniform'))
+loss = Loss('softmaxloss')
+
+# updater
+sgd = SGD(lr=0.001, lr_type='step')
+
+#-------------------------------------------------------------------
+batchsize = 64 
+disp_freq = 10
+
+x, y = load_dataset()
+
+print '[Start training]'
+for i in range(x.shape[0] / batchsize):
+    xb, yb = x[i*batchsize:(i+1)*batchsize,:], y[i*batchsize:(i+1)*batchsize,:]
+    nn[0].Feed(xb)
+    label.Feed(yb, is_label=1)
+    for h in range(1, len(nn)):
+        nn[h].ComputeFeature(nn[h-1])
+    loss.ComputeFeature(nn[-1], label)
+    if (i+1)%disp_freq == 0:
+        print '  Step {:>3}: '.format(i+1),
+        loss.display()
+    loss.ComputeGradient(i+1, sgd)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/tool/python/singa/driver.i
----------------------------------------------------------------------
diff --git a/tool/python/singa/driver.i b/tool/python/singa/driver.i
index f756d57..1650145 100644
--- a/tool/python/singa/driver.i
+++ b/tool/python/singa/driver.i
@@ -25,19 +25,102 @@
 %include "std_vector.i"
 %include "std_string.i"
 %include "argcargv.i"
+%include "carrays.i"
+%array_class(float, floatArray);
+
 %apply (int ARGC, char **ARGV) { (int argc, char **argv)  }
 %{
 #include "singa/driver.h"
+#include "singa/worker.h"
+#include "singa/neuralnet/neuralnet.h"
+#include "singa/neuralnet/layer.h"
+#include "singa/neuralnet/neuron_layer.h"
+#include "singa/neuralnet/loss_layer.h"
+#include "singa/utils/blob.h"
+#include "singa/utils/param.h"
+#include "singa/utils/updater.h"
+#include "singa/proto/job.pb.h"
+#include "singa/proto/common.pb.h"
 %}
 
-namespace singa{
-using std::vector;
-class Driver{
-public:
-void Train(bool resume, const std::string job_conf);
-void Init(int argc, char **argv);
-void InitLog(char* arg);
-void Test(const std::string job_conf);
-};
+namespace std {
+  %template(strVector) vector<string>;
+  %template(intVector) vector<int>;
+  %template(floatVector) vector<float>;
+  %template(layerVector) vector<singa::Layer*>;
+  %template(paramVector) vector<singa::Param*>;
 }
 
+namespace singa{
+  class Driver{
+    public:
+    void Train(bool resume, const std::string job_conf);
+    void Init(int argc, char **argv);
+    void InitLog(char* arg);
+    void Test(const std::string job_conf);
+  };
+
+  class NeuralNet{
+    public:
+     static NeuralNet* CreateNeuralNet(const std::string str);
+     void Load(const std::vector<std::string>& paths);
+     inline const std::vector<singa::Layer*>& layers();
+     inline const std::vector<singa::Layer*>& srclayers(const singa::Layer* layer);
+  };
+
+  %nodefault Worker;
+  class Worker{
+    public:
+      static singa::Worker* CreateWorker(const std::string str);
+      void InitNetParams(const std::string& folder, std::vector<singa::Layer*> net);
+      void Checkpoint(int step, const std::string& folder, std::vector<singa::Layer*> net);
+  };
+    
+  class DummyLayer{
+    public:
+      void Setup(const std::string str, const std::vector<singa::Layer*>& srclayers);
+      void Feed(int batchsize, std::vector<float>& data, int is_aux);
+      singa::Layer* ToLayer();
+  };
+
+  %nodefault Layer;
+  class Layer{
+    public:
+      static singa::Layer* CreateLayer(const std::string str);
+      static void SetupLayer(singa::Layer* layer, const std::string str, const std::vector<singa::Layer*>& srclayers);
+      virtual void ComputeFeature(int flag, const std::vector<singa::Layer*>& srclayers); 
+      virtual void ComputeGradient(int flag, const std::vector<singa::Layer*>& srclayers);
+      virtual const singa::Blob<float>& data(const singa::Layer* from); 
+      virtual const std::vector<singa::Param*> GetParams();
+      virtual const std::string ToString(bool debug, int flag);
+      void SetParams(std::vector<singa::Param*> params);
+  };
+
+  %nodefault Updater;
+  class Updater{
+    public:
+      static singa::Updater* CreateUpdater(const std::string str);
+      virtual void Update(int step, singa::Param* param, float grad_scale);
+  };
+
+  template <typename Dtype>
+  class Blob{
+    public:
+      inline int count();
+      inline const std::vector<int>& shape();
+      inline Dtype* mutable_cpu_data(); 
+      inline const Dtype* cpu_data();
+  };
+
+  class Param{
+    public:
+      inline int size();
+      inline const std::vector<int>& shape();
+      inline float* mutable_cpu_data();
+      void FromProto(const std::string str);
+      /*void ToProto(singa::BlobProto* blob); 
+      */
+  };
+
+  %template(floatBlob) Blob<float>;
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/tool/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/tool/python/singa/layer.py b/tool/python/singa/layer.py
index f838e45..35da9aa 100644
--- a/tool/python/singa/layer.py
+++ b/tool/python/singa/layer.py
@@ -25,15 +25,21 @@
 This script includes Layer class and its subclasses that
 users can configure different types of layers for their model.
 '''
-
+import numpy as np
 from singa.parameter import Parameter, set_param_field
 from singa.initializations import get_init_values
 from singa.utils.utility import setval, generate_name
 from singa.utils.message import *
 from google.protobuf import text_format
 
+from singa.driver import Layer as SingaLayer, Updater as SingaUpdater,\
+                         intVector, floatVector, layerVector,\
+                         paramVector, floatArray_frompointer, DummyLayer
+
 class Layer(object):
 
+    singaupdater = None
+
     def __init__(self, **kwargs):
         '''
         **kwargs (KEY=VALUE)
@@ -41,12 +47,314 @@ class Layer(object):
         '''
 
         self.layer = Message('Layer', **kwargs).proto
-        # required
+        # required field
         if not 'name' in kwargs:
             setval(self.layer, name=generate_name('layer', 1))
 
-        # srclayers are set in Model.build()
+        # layer connectivity is set in Model.build()
         self.is_datalayer = False
+        self.singalayer = None
+        self.srclayers = []
+
+        # set src for Rafiki
+        if 'src' in kwargs:
+            self.src = kwargs['src']
+        else:
+            self.src = None
+
+    def setup(self, srclys):
+        ''' Create singa::Layer and store srclayers
+        '''
+        if self.singalayer == None:
+            self.singalayer = SingaLayer.CreateLayer(
+                                    self.layer.SerializeToString())
+            self.singaSrclayerVector = layerVector(len(srclys))
+            for i in range(len(srclys)):
+                self.srclayers.append(srclys[i])
+                self.singaSrclayerVector[i] = srclys[i].get_singalayer()
+            # set up the layer
+            SingaLayer.SetupLayer(self.singalayer,
+                                  self.layer.SerializeToString(),
+                                  self.singaSrclayerVector)
+
+    def ComputeFeature(self, *srclys):
+        ''' The method creates and sets up singa::Layer
+            and maintains its source layers
+            then call ComputeFeature for data transformation.
+
+            *srclys = (list)  // a list of source layers
+        '''
+        # create singa::Layer and store srclayers
+        if self.singalayer == None:
+            if self.src != None:
+                srclys = self.src
+            self.singalayer = SingaLayer.CreateLayer(
+                                    self.layer.SerializeToString())
+            self.singaSrclayerVector = layerVector(len(srclys))
+            for i in range(len(srclys)):
+                self.srclayers.append(srclys[i])
+                self.singaSrclayerVector[i] = srclys[i].get_singalayer()
+            # set up the layer
+            SingaLayer.SetupLayer(self.singalayer,
+                                  self.layer.SerializeToString(),
+                                  self.singaSrclayerVector)
+
+        self.singalayer.ComputeFeature(1, self.singaSrclayerVector)
+
+    def ComputeGradient(self, step, upd=None):
+        ''' The method creates singa::Updater
+            and calls ComputeGradient for gradient computation
+            then updates the parameters.
+
+            step = (int)    // a training step
+            upd = (object)  // Updater object
+        '''
+        # create singa::Updater
+        assert upd != None, 'required Updater (see model.py)'
+        if Layer.singaupdater == None:
+            Layer.singaupdater = SingaUpdater.CreateUpdater(
+                                          upd.proto.SerializeToString())
+
+        # call ComputeGradient of Singa
+        self.singalayer.ComputeGradient(1, self.singaSrclayerVector)
+
+        # update parameters
+        singaParams = self.singalayer.GetParams()
+        for par in singaParams:
+            Layer.singaupdater.Update(step, par, 1.0)
+
+        # recursively call ComputeGradient of srclayers
+        #(TODO) what if there are multiple source layers?
+        for sly in self.srclayers:
+            if sly.srclayers != None:
+                sly.ComputeGradient(step, upd)
+
+    def GetParams(self):
+        ''' The method gets parameter values
+            singaParams[0] for weight
+            singaParams[1] for bias
+        '''
+        singaParams = self.singalayer.GetParams()
+        assert len(singaParams) == 2, 'weight and bias'
+        # for weight
+        weight_array = floatArray_frompointer(singaParams[0].mutable_cpu_data())
+        weight = [weight_array[i] for i in range(singaParams[0].size())]
+        weight = np.array(weight).reshape(singaParams[0].shape())
+        # for bias
+        bias_array = floatArray_frompointer(singaParams[1].mutable_cpu_data())
+        bias = [bias_array[i] for i in range(singaParams[1].size())]
+        bias = np.array(bias).reshape(singaParams[1].shape()[0], 1)
+
+        return weight, bias
+
+    def SetParams(self, *params):
+        ''' The method sets parameter values
+            params[0] for weight
+            params[1] for bias
+        '''
+        singaParams = self.singalayer.GetParams()
+        import pb2.common_pb2 as cm
+        for k in range(len(params)):
+            bp = cm.BlobProto()
+            bp.shape.append(int(params[k].shape[0]))
+            bp.shape.append(int(params[k].shape[1]))
+            for i in range(params[k].shape[0]):
+                for j in range(params[k].shape[1]):
+                    bp.data.append(params[k][i, j])
+            singaParams[k].FromProto(bp.SerializeToString())
+
+    def GetData(self):
+        ''' The method gets layer data values
+        '''
+        blobptr = self.singalayer.data(self.singalayer)
+        data_array = floatArray_frompointer(blobptr.mutable_cpu_data())
+        data = [data_array[i] for i in range(blobptr.count())]
+        return data
+
+    def display(self):
+        debug, flag = 0, 0
+        print self.singalayer.ToString(debug, flag)
+
+    def get_singalayer(self):
+        return self.singalayer
+
+
+class Dummy(object):
+
+    def __init__(self, shape=[], path='', dtype='', src=[], **kwargs):
+        ''' Dummy layer is used for data layer to feed/fetch input data
+            shape = (list)   // [# of samples, # of channels, img h, img w]
+            path  = (string) // path to dataset
+        '''
+        self.is_datalayer = True
+        self.srclayers = None
+        self.singalayer = None
+        if 'is_label' in kwargs: self.is_label = kwargs['is_label']
+
+        # create layer proto for Dummy layer
+        kwargs = {'name':'dummy', 'type':kDummy}
+        self.layer = Message('Layer', **kwargs).proto
+
+    def setup(self, data_shape):
+        ''' Create and Setup singa Dummy layer
+            called by load_model_parameter
+        '''
+        if self.singalayer == None:
+            setval(self.layer.dummy_conf, input=True)
+            setval(self.layer.dummy_conf, shape=data_shape)
+            self.singalayer = DummyLayer()
+            self.singalayer.Setup(self.layer.SerializeToString(),
+                                  layerVector(0))
+
+    def Feed(self, data, nb_channel=1, is_label=0):
+        ''' Create and Setup singa::DummyLayer for input data
+            Insert data using Feed()
+        '''
+        batchsize, hdim = data.shape
+        datasize = batchsize * hdim
+
+        # create and setup the dummy layer
+        if self.singalayer == None:
+            imgsize = int(np.sqrt(hdim/nb_channel))
+            shapeVector = [batchsize, nb_channel, imgsize, imgsize]
+            self.setup(shapeVector)
+
+        data = data.astype(np.float)
+        dataVector = floatVector(datasize)
+        k = 0
+        for i in range(batchsize):
+            for j in range(hdim):
+                dataVector[k] = data[i, j]
+                k += 1
+        self.singalayer.Feed(batchsize, dataVector, is_label)
+
+    def FetchData(self, batchsize):
+        sidx = self.batch_index * batchsize
+        eidx = sidx + batchsize
+        batch = self.data[sidx:eidx, :]
+
+        self.Feed(batch, self.shape[1], self.is_label)
+
+        self.batch_index += 1
+        if eidx > self.data.shape[0]:
+            self.batch_index = 0
+
+    def get_singalayer(self):
+        return self.singalayer.ToLayer()
+
+class ImageData(Dummy):
+    ''' This class will be used for Rafiki, dlaas
+    '''
+    def __init__(self, shape=[], data_path='', data_type='byte', src=[],
+                 mean_path='', mean_type='float'):
+        ''' Dummy layer is used for data layer
+            shape = (list)   // [# of samples, # of channels, img h, img w]
+            data_path  = (string) // path to dataset
+            mean_path
+        '''
+        is_label = False
+        super(ImageData, self).__init__(shape, data_path, data_type, src,
+                                    is_label=is_label,
+                                    mean_path=mean_path,
+                                    mean_type=mean_type)
+
+        # if dataset path is not specified, skip
+        # otherwise, load dataset
+        if data_path == '' or mean_path == '':
+            return
+
+        self.shape = shape
+        self.data_path = data_path
+        self.mean_path = mean_path
+        self.src = None
+        self.batch_index = 0
+
+        nb_samples = shape[0]
+        nb_pixels = shape[1]
+        for i in range(len(shape)-2):
+            nb_pixels *= shape[i+2]
+
+        if data_type == 'byte':
+            d = np.fromfile(data_path, dtype=np.uint8)
+        elif data_type == 'int':
+            d = np.fromfile(data_path, dtype=np.int)
+        self.data = d.reshape(nb_samples, nb_pixels)
+
+        if mean_type == 'float':
+            d = np.fromfile(mean_path, dtype=np.float32)
+        self.mean = d.reshape(1, nb_pixels)
+
+    def Feed(self, data, nb_channel=1, is_label=0):
+        ''' Create and Setup singa::DummyLayer for input data
+            Insert data using Feed()
+            Need to minus the mean file
+        '''
+        batchsize, hdim = data.shape
+        datasize = batchsize * hdim
+
+        # create and setup the dummy layer
+        if self.singalayer == None:
+            imgsize = int(np.sqrt(hdim/nb_channel))
+            shapeVector = [batchsize, nb_channel, imgsize, imgsize]
+            self.setup(shapeVector)
+
+        # feed input data and minus mean
+        data = data.astype(np.float)
+        dataVector = floatVector(datasize)
+        k = 0
+        for i in range(batchsize):
+            for j in range(hdim):
+                dataVector[k] = data[i, j]-self.mean[0, j]
+                k += 1
+        self.singalayer.Feed(batchsize, dataVector, is_label)
+
+class LabelData(Dummy):
+    ''' This class will be used for Rafiki, dlaas
+    '''
+    def __init__(self, shape=[], label_path='', label_type='int', src=[]):
+        ''' Dummy layer is used for label data layer
+            shape = (list)   // [# of samples, # of channels, img h, img w]
+            data_path  = (string) // path to dataset
+            mean_path
+        '''
+        is_label = True
+        super(LabelData, self).__init__(shape, label_path, label_type, src,
+                                    is_label=is_label)
+
+        # if dataset path is not specified, skip
+        # otherwise, load dataset
+        if label_path == '':
+            return
+
+        self.shape = shape
+        self.label_path = label_path
+        self.src = None
+        self.batch_index = 0
+
+        nb_samples = shape[0]
+
+        if label_type == 'int':
+            d = np.fromfile(label_path, dtype=np.int)
+        self.data = d.reshape(nb_samples, 1)
+
+    def Feed(self, data, nb_chanel=1, is_label=1):
+        ''' Create and Setup singa::DummyLayer for input data
+            Insert data using Feed()
+            Need to minus the mean file
+        '''
+        batchsize = data.shape[0]
+
+        # create and setup the dummy layer
+        if self.singalayer == None:
+            shapeVector = [batchsize, 1]
+            self.setup(shapeVector)
+
+        data = data.astype(np.float)
+        dataVector = floatVector(batchsize)
+        for i in range(batchsize):
+            dataVector[i] = data[i, 0]
+        self.singalayer.Feed(batchsize, dataVector, 1)
+
 
 class Data(Layer):
 
@@ -66,11 +374,11 @@ class Data(Layer):
         assert load != None, 'data type should be specified'
         if load == 'kData':
             super(Data, self).__init__(name=generate_name('data'),
-                                       user_type=load)
+                                       user_type=load, **kwargs)
         else:
             self.layer_type = enumLayerType(load)
             super(Data, self).__init__(name=generate_name('data'),
-                                       type=self.layer_type)
+                                       type=self.layer_type, **kwargs)
         self.is_datalayer = True
 
         # include/exclude
@@ -108,32 +416,32 @@ class Convolution2D(Layer):
                            // scale the learning rate when updating parameters.
             w_wd = (float) // weight decay multiplier for weight, used to
                            // scale the weight decay when updating parameters.
-            b_lr = (float) // learning rate multiplier for bias 
+            b_lr = (float) // learning rate multiplier for bias
             b_wd = (float) // weight decay multiplier for bias
         '''
 
         assert nb_filter > 0, 'nb_filter should be set as positive int'
         super(Convolution2D, self).__init__(name=generate_name('conv', 1),
-                                            type=kCConvolution)
+                                            type=kCConvolution, **kwargs)
         fields = {"num_filters":nb_filter}
         # for kernel
         if type(kernel) == int:
-          fields['kernel'] = kernel
+            fields['kernel'] = kernel
         else:
-          fields['kernel_x'] = kernel[0]
-          fields['kernel_y'] = kernel[1]
-        # for stride 
+            fields['kernel_x'] = kernel[0]
+            fields['kernel_y'] = kernel[1]
+        # for stride
         if type(stride) == int:
-          fields['stride'] = stride
+            fields['stride'] = stride
         else:
-          fields['stride_x'] = stride[0]
-          fields['stride_y'] = stride[1]
-        # for pad 
+            fields['stride_x'] = stride[0]
+            fields['stride_y'] = stride[1]
+        # for pad
         if type(pad) == int:
-          fields['pad'] = pad 
+            fields['pad'] = pad
         else:
-          fields['pad_x'] = pad[0]
-          fields['pad_y'] = pad[1]
+            fields['pad_x'] = pad[0]
+            fields['pad_y'] = pad[1]
 
         setval(self.layer.convolution_conf, **fields)
 
@@ -155,6 +463,7 @@ class Convolution2D(Layer):
         if activation:
             self.mask = Activation(activation=activation).layer
 
+
 class MaxPooling2D(Layer):
 
     def __init__(self, pool_size=None,
@@ -218,31 +527,53 @@ class LRN2D(Layer):
           size = (int)  // local size
         '''
 
-        super(LRN2D, self).__init__(name=generate_name('norm'), type=kLRN)
+        super(LRN2D, self).__init__(name=generate_name('norm'), type=kLRN, **kwargs)
         # required
         assert size != 0, 'local size should be set'
         self.layer.lrn_conf.local_size = size
         init_values = get_init_values('lrn2d', **kwargs)
         setval(self.layer.lrn_conf, **init_values)
 
+class Loss(Layer):
+
+    def __init__(self, lossname, topk=1, **kwargs):
+        '''
+        required
+          lossname = (string) // softmaxloss, euclideanloss
+        '''
+        self.layer_type = enumLayerType(lossname)
+        super(Loss, self).__init__(name=generate_name(lossname),
+                                         type=self.layer_type, **kwargs)
+        if lossname == 'softmaxloss':
+            self.layer.softmaxloss_conf.topk = topk
 
 class Activation(Layer):
 
-    def __init__(self, activation='stanh', topk=1):
+    def __init__(self, activation='stanh', **kwargs):
         '''
         required
-          activation = (string)
-        optional
-          topk       = (int)  // the number of results
+          activation = (string) // relu, sigmoid, tanh, stanh, softmax.
         '''
+        if activation == 'tanh':
+            print 'Warning: Tanh layer is not supported for CPU'
 
         self.name = activation
-        if activation == 'tanh': activation = 'stanh' # <-- better way to set?
-        self.layer_type = enumLayerType(activation)
+        self.layer_type = kActivation
+        if activation == 'stanh':
+            self.layer_type = kSTanh
+        elif activation == 'softmax':
+            self.layer_type = kSoftmax
         super(Activation, self).__init__(name=generate_name(self.name),
-                                         type=self.layer_type)
-        if activation == 'softmaxloss':
-            self.layer.softmaxloss_conf.topk = topk
+                                         type=self.layer_type, **kwargs)
+        if activation == 'relu':
+            self.layer.activation_conf.type = RELU
+        elif activation == 'sigmoid':
+            self.layer.activation_conf.type = SIGMOID
+        elif activation == 'tanh':
+            self.layer.activation_conf.type = TANH # for GPU
+        #elif activation == 'stanh':
+        #    self.layer.activation_conf.type = STANH
+
 
 class Dropout(Layer):
 
@@ -255,19 +586,19 @@ class Dropout(Layer):
         self.name = 'dropout'
         self.layer_type = enumLayerType(self.name)
         super(Dropout, self).__init__(name=generate_name(self.name),
-                                      type=self.layer_type)
+                                      type=self.layer_type, **kwargs)
         self.layer.dropout_conf.dropout_ratio = ratio
 
 class Accuracy(Layer):
 
-    def __init__(self):
+    def __init__(self, **kwargs):
         '''
         '''
 
         self.name = 'accuracy'
         self.layer_type = enumLayerType(self.name)
         super(Accuracy, self).__init__(name=generate_name(self.name),
-                                       type=self.layer_type)
+                                       type=self.layer_type, **kwargs)
 
 class RGB(Layer):
 
@@ -302,7 +633,7 @@ class Dense(Layer):
                            // scale the learning rate when updating parameters.
             w_wd = (float) // weight decay multiplier for weight, used to
                            // scale the weight decay when updating parameters.
-            b_lr = (float) // learning rate multiplier for bias 
+            b_lr = (float) // learning rate multiplier for bias
             b_wd = (float) // weight decay multiplier for bias
         '''
         # required
@@ -344,7 +675,7 @@ class Autoencoder(object):
         required
           hid_dim     = (int/list) // the number of nodes in hidden layers
           out_dim     = (int)      // the number of nodes in the top layer
-        optional 
+        optional
           activation  = (string)
           param_share = (bool)     // to share params in encoder and decoder
         '''
@@ -383,7 +714,8 @@ class RBM(Layer):
         self.name = kwargs['name'] if 'name' in kwargs else 'RBMVis'
         self.layer_type = kwargs['type'] if 'type' in kwargs else kRBMVis
         super(RBM, self).__init__(name=generate_name(self.name,
-                                  withnumber=False), type=self.layer_type)
+                                                     withnumber=False),
+                                  type=self.layer_type, **kwargs)
         setval(self.layer.rbm_conf, hdim=self.out_dim[-1])
         if self.layer_type == kRBMHid and sampling != None:
             if sampling == 'gaussian':

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/tool/python/singa/model.py
----------------------------------------------------------------------
diff --git a/tool/python/singa/model.py b/tool/python/singa/model.py
index f652f86..b6664bd 100644
--- a/tool/python/singa/model.py
+++ b/tool/python/singa/model.py
@@ -107,9 +107,9 @@ class Model(object):
             else:
                 # add new layer
                 if loss == 'categorical_crossentropy':
-                    self.add(Activation('softmaxloss', topk=topk))
+                    self.add(Loss('softmaxloss', topk=topk))
                 elif loss == 'mean_squared_error':
-                    self.add(Activation('euclideanloss'))
+                    self.add(Loss('euclideanloss'))
                 elif loss == 'user_loss_rnnlm': # user-defined loss layer
                     self.add(UserLossRNNLM(nclass=kwargs['nclass'],
                                            vocab_size=kwargs['in_dim']))
@@ -323,17 +323,40 @@ class Model(object):
             elif ly_type == kLRN: cudnn_ly_type = kCudnnLRN
             elif ly_type == kSoftmax: cudnn_ly_type = kCudnnSoftmax
             elif ly_type == kSoftmaxLoss: cudnn_ly_type = kCudnnSoftmaxLoss
+            elif ly_type == kActivation:
+                cudnn_ly_type = kCudnnActivation
             elif ly_type == kSTanh:
+                print 'Error report: STanh layer is not supported for GPU'
+            '''
+            elif ly_type == kReLU:
                 cudnn_ly_type = kCudnnActivation
-                net.layer[i].activation_conf.type = STANH
+                net.layer[i].activation_conf.type = RELU
             elif ly_type == kSigmoid:
                 cudnn_ly_type = kCudnnActivation
                 net.layer[i].activation_conf.type = SIGMOID
-            elif ly_type == kReLU:
+            elif ly_type == kTanh:
                 cudnn_ly_type = kCudnnActivation
-                net.layer[i].activation_conf.type = RELU
+                net.layer[i].activation_conf.type = TANH
+            '''
+            #elif ly_type == kSTanh:
+            #    print 'Error report: STanh layer is not supported for GPU'
+                #cudnn_ly_type = kCudnnActivation
+                #net.layer[i].activation_conf.type = STANH
             net.layer[i].type = cudnn_ly_type
 
+    def show(self):
+        for ly in self.jobconf.neuralnet.layer:
+            print layer(ly.name)
+
+    def layer_by_id(self, k):
+        return self.jobconf.neuralnet.layer[k]
+
+    def layer_by_name(self, name):
+        return self.layers[k]
+
+    def size(self):
+        return len(self.jobconf.neuralnet.layer)
+
 class Energy(Model):
     ''' energy model
     '''
@@ -627,3 +650,51 @@ def SingaRun_script(filename='', execpath=''):
     #TODO better format to store the result??
     return resultDic
 
+def load_model_parameter(fin, neuralnet, batchsize=1, data_shape=None):
+    """
+    this method loads model parameter
+    """
+    hly_idx = 0
+    for i in range(len(neuralnet)): 
+        if neuralnet[i].is_datalayer:
+            if data_shape == None:
+                shape = neuralnet[i].shape
+                shape[0] = batchsize
+                neuralnet[i].setup(shape)
+            else:
+                neuralnet[i].setup(data_shape)
+        else:
+            hly_idx = i
+            break
+
+    net = layerVector(len(neuralnet)-hly_idx)
+    for i in range(hly_idx, len(neuralnet)): 
+        if neuralnet[i].src==None:
+            neuralnet[i].setup(neuralnet[i-1])
+        else:
+            neuralnet[i].setup(neuralnet[i].src)
+        net[i-hly_idx] = neuralnet[i].singalayer
+
+    from singa.driver import Worker
+    alg = Algorithm(type=enumAlgType('bp')).proto
+    w = Worker.CreateWorker(alg.SerializeToString())
+    w.InitNetParams(fin, net)
+
+def save_model_parameter(step, fout, neuralnet):
+    """
+    this method saves model parameter
+    """
+    hly_idx = 0
+    for i in range(len(neuralnet)): 
+        if not neuralnet[i].is_datalayer:
+            hly_idx = i
+            break
+
+    from singa.driver import Worker
+    net = layerVector(len(neuralnet)-hly_idx)
+    for i in range(hly_idx, len(neuralnet)): 
+        net[i-hly_idx] = neuralnet[i].singalayer
+    alg = Algorithm(type=enumAlgType('bp')).proto
+    w = Worker.CreateWorker(alg.SerializeToString())
+    w.Checkpoint(step, fout, net)
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8130b7ed/tool/python/singa/utils/utility.py
----------------------------------------------------------------------
diff --git a/tool/python/singa/utils/utility.py b/tool/python/singa/utils/utility.py
index 8c6c455..bea784d 100644
--- a/tool/python/singa/utils/utility.py
+++ b/tool/python/singa/utils/utility.py
@@ -25,6 +25,7 @@
 This script includes methods to
 (1) generate name of layer, parameter, etc.
 (2) set field values for proto.
+(3) swap bits
 '''
 
 LAYERID = 0
@@ -82,3 +83,9 @@ def setval(proto, **kwargs):
                     fattr.MergeFrom(val)
                 else:
                     setattr(proto, key, val)
+
+def swap32(x):
+    return (((x << 24) & 0xFF000000) |
+            ((x <<  8) & 0x00FF0000) |
+            ((x >>  8) & 0x0000FF00) |
+            ((x >> 24) & 0x000000FF))
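
For reference, swap32 reverses the byte order of a 32-bit value; the
train_mnist.py example above uses it to decode the big-endian MNIST IDX
headers on little-endian machines. A quick sanity check with an assumed value:

    from singa.utils.utility import swap32
    # 0x12345678 -> 0x78563412 after reversing the four bytes
    print '%08x' % swap32(0x12345678)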

