singa-commits mailing list archives

From wang...@apache.org
Subject [1/6] incubator-singa git commit: SINGA-186 Create Python Tensor class
Date Fri, 01 Jul 2016 08:24:44 GMT
Repository: incubator-singa
Updated Branches:
  refs/heads/dev f026b6050 -> cde7dcf6b


SINGA-186 Create Python Tensor class

- Update core_tensor.i and tensor.py to cover the following methods
  (a short usage sketch follows these notes)
  . Random operations (Bernoulli, Gaussian, Uniform)
  . Axpy
  . Mult
  . Pow
  . AddColumn, SubColumn, MultColumn, DivColumn, and the corresponding Row variants
  . SumColumns, SumRows

- Revise model_layer.i and layer.h to handle Shape (an alias for std::vector<size_t>)

TODO: implement from_array() in Python
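
A minimal usage sketch of the Python API this commit wires up (pow, axpy,
mult, the column/row helpers, and the random fillers). It is illustrative
only: the import path, the default device, and the shapes used here are
assumptions, not part of this commit.

    # hypothetical usage; assumes tensor.py and the SWIG modules are on sys.path
    import tensor

    a = tensor.Tensor((3, 2))     # 2-D tensor on the default (host) device
    a.set_value(0.0)              # floatSetValue via SWIG
    a.gaussian(0.0, 1.0)          # fill with N(0, 1) samples
    b = tensor.Tensor((3, 2))
    b.uniform(-1.0, 1.0)          # fill with U(-1, 1) samples

    c = tensor.pow(a, 2.0)        # element-wise power, returns a new Tensor
    tensor.axpy(0.5, a, b)        # b = 0.5 * a + b (BLAS-style axpy)

    w = tensor.Tensor((2, 4))
    w.gaussian(0.0, 0.1)
    y = tensor.mult(a, w)         # matrix product: (3x2) * (2x4) -> (3x4)

    v = tensor.Tensor((3, 1))
    v.set_value(1.0)
    a.add_column(v)               # add v to every column of a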


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/7cfdb995
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/7cfdb995
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/7cfdb995

Branch: refs/heads/dev
Commit: 7cfdb995cd32624779330fb336a2771da9a5ba00
Parents: 62c6603
Author: chonho <leech@comp.nus.edu.sg>
Authored: Mon Jun 27 18:01:52 2016 +0800
Committer: chonho <leech@comp.nus.edu.sg>
Committed: Thu Jun 30 15:00:45 2016 +0800

----------------------------------------------------------------------
 include/singa/model/layer.h   |   5 +-
 src/core/tensor/tensor.cc     |  17 +--
 src/python/swig/core_device.i |   2 +-
 src/python/swig/core_tensor.i | 137 ++++++++++++++++++------
 src/python/swig/model_layer.i |  13 ++-
 src/python/tensor.py          | 208 +++++++++++++++++++++++++++++++++----
 6 files changed, 318 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cfdb995/include/singa/model/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h
index ce8007c..938b161 100644
--- a/include/singa/model/layer.h
+++ b/include/singa/model/layer.h
@@ -28,6 +28,7 @@
 
 namespace singa {
 
+typedef vector<size_t> Shape;
 /// The base layer class.
 /// Generally, a layer conducts feature transformation against a set of Tensor
 /// to generate a set of Tensor. Each layer may have some parameters.
@@ -37,14 +38,14 @@ class Layer {
 
   /// Set meta data fields from a string representing a proto message.
   /// 'in_shape' is the shape of the input feature for one sample
-  void Setup(const vector<size_t>& in_shape, const string& proto_str) {
+  void Setup(const Shape& in_shape, const string& proto_str) {
     LayerConf conf;
     conf.ParseFromString(proto_str);
     this->Setup(in_shape, conf);
   }
 
   /// 'in_shapes' is the shape of the input feature for one sample
-  void Setup(const vector<vector<size_t>>& in_shapes, const string& proto_str) {
+  void Setup(const vector<Shape>& in_shapes, const string& proto_str) {
     LayerConf conf;
     conf.ParseFromString(proto_str);
     this->Setup(in_shapes, conf);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cfdb995/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 3501ecd..4972a86 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -680,7 +680,7 @@ void AddColumn(const SType alpha, const SType beta, const Tensor &v,
     AddRow(v, &X);
   } else {
     CHECK_EQ(M->nDim(), 2u);
-    CHECK_EQ(v.nDim(), 1u);
+    // CHECK_EQ(v.nDim(), 1u); (chonho) shape of v is 2-element tuple
     size_t nb_row = M->shape(0), nb_col = M->shape(1);
     CHECK_EQ(nb_row, v.Size());
 
@@ -690,7 +690,7 @@ void AddColumn(const SType alpha, const SType beta, const Tensor &v,
     Mult(alpha, vmat, one, beta, M);
   }
 }
-template <>
+template
 void AddColumn(const float alpha, const float beta, const Tensor &v, Tensor *M);
 
 void AddRow(const Tensor &v, Tensor *M) { AddRow(1, 1, v, M); }
@@ -703,7 +703,7 @@ void AddRow(const SType alpha, const SType beta, const Tensor &v, Tensor *M) {
     AddColumn(v, &X);
   } else {
     CHECK_EQ(M->nDim(), 2u);
-    CHECK_EQ(v.nDim(), 1u);
+    // CHECK_EQ(v.nDim(), 1u); (chonho) shape of v is 2-element tuple
     size_t nb_row = M->shape(0), nb_col = M->shape(1);
     CHECK_EQ(nb_col, v.Size());
 
@@ -804,7 +804,7 @@ void DivRow(const Tensor &v, Tensor *M) {
 void MultColumn(const Tensor &v, Tensor *M) {
   CHECK(!M->transpose()) << "Not supported yet";
   CHECK_EQ(M->nDim(), 2u);
-  CHECK_EQ(v.nDim(), 1u);
+  // CHECK_EQ(v.nDim(), 1u); (chonho) shape of v is 2-element tuple
   CHECK_EQ(v.Size(), M->shape(0));
   CheckDataTypeAndLang(*M, v);
   TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, {
@@ -819,7 +819,7 @@ void MultColumn(const Tensor &v, Tensor *M) {
 void MultRow(const Tensor &v, Tensor *M) {
   CHECK(!M->transpose()) << "Not supported yet";
   CHECK_EQ(M->nDim(), 2u);
-  CHECK_EQ(v.nDim(), 1u);
+  // CHECK_EQ(v.nDim(), 1u); (chonho) shape of v is 2-element tuple
   CHECK_EQ(v.Size(), M->shape(1));
   CheckDataTypeAndLang(*M, v);
   TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, {
@@ -857,7 +857,7 @@ void SumColumns(const Tensor &M, Tensor *v) {
     SumRows(X, v);
   } else {
     CHECK_EQ(M.nDim(), 2u);
-    CHECK_EQ(v->nDim(), 1u);
+    // CHECK_EQ(v->nDim(), 1u); (chonho) shape of v is 2-element tuple
     size_t nb_row = M.shape().at(0), nb_col = M.shape().at(1);
     CHECK_EQ(nb_row, v->Size());
 
@@ -872,7 +872,7 @@ void SumRows(const Tensor &M, Tensor *v) {
     SumColumns(X, v);
   } else {
     CHECK_EQ(M.nDim(), 2u);
-    CHECK_EQ(v->nDim(), 1u);
+    // CHECK_EQ(v->nDim(), 1u); (chonho) shape of v is 2-element tuple
     size_t nb_row = M.shape(0), nb_col = M.shape(1);
     CHECK_EQ(nb_col, v->Size());
 
@@ -929,7 +929,8 @@ void Axpy(const SType alpha, const Tensor &in, Tensor *out) {
     }, {in.block(), out->block()}, {out->block()});
   });
 }
-template void Axpy(const float alpha, const Tensor &in, Tensor *out);
+template
+void Axpy<float>(const float alpha, const Tensor &in, Tensor *out);
 
 Tensor Mult(const Tensor &A, const Tensor &B) {
   Shape s;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cfdb995/src/python/swig/core_device.i
----------------------------------------------------------------------
diff --git a/src/python/swig/core_device.i b/src/python/swig/core_device.i
index ab9abd8..50cee3e 100644
--- a/src/python/swig/core_device.i
+++ b/src/python/swig/core_device.i
@@ -35,7 +35,7 @@ namespace singa{
   class Device {
    public:
     virtual void SetRandSeed(unsigned seed) = 0;
-    Device* host();
+    std::shared_ptr<Device> host();
     int id() const;
   };
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cfdb995/src/python/swig/core_tensor.i
----------------------------------------------------------------------
diff --git a/src/python/swig/core_tensor.i b/src/python/swig/core_tensor.i
index 409ab0c..e30f2ef 100644
--- a/src/python/swig/core_tensor.i
+++ b/src/python/swig/core_tensor.i
@@ -32,6 +32,7 @@
 %array_class(double, doubleArray);
 
 %{
+#define SWIG_FILE_WITH_INIT
 #include "singa/core/tensor.h"
 #include "singa/core/device.h"
 #include "singa/proto/core.pb.h"
@@ -51,6 +52,7 @@ namespace singa{
                         int start = 0, size_t len = 0);
   inline size_t SizeOf(DataType t);
 
+
   class Tensor {
 
    public:
@@ -58,11 +60,11 @@ namespace singa{
     explicit Tensor(const std::vector<size_t> &shape,
                     DataType dtype = kFloat32);
     Tensor(const std::vector<size_t> &shape,
-           singa::Device *dev, DataType dtype = kFloat32);
+           std::shared_ptr<singa::Device> dev, DataType dtype = kFloat32);
     Tensor(const Tensor &from);
 
     //Blob *blob() const;
-    singa::Device *device() const;
+    std::shared_ptr<singa::Device> device() const;
 
     template <typename DType> DType data() const;
     %template(floatData) data<const float*>;
@@ -80,23 +82,24 @@ namespace singa{
     void Reshape(const std::vector<size_t> &shape);
     void ResetLike(const Tensor &t);
     void AsType(DataType type);
-    void ToDevice(singa::Device *dev);
+    void ToDevice(std::shared_ptr<singa::Device> dev);
     void ToHost();
+    float L2();
 
     template <typename SType> void SetValue(const SType x);
     %template(floatSetValue) SetValue<float>;
     // ...
 
-    /* no need to expose this function
     template <typename DType> void CopyDataFromHostPtr(const DType *src,
-                                                       size_t num);
-    */
+                                                       const size_t num,
+                                                       const size_t offset);
+    %template(floatCopyData) CopyDataFromHostPtr<float>;
 
     void CopyData(const Tensor &other);
     Tensor Clone() const;
     Tensor T() const;
 
-    /* python has no assignment operator as c++
+    /* python has no assignment operator
     Tensor &operator=(const Tensor &t); */
     Tensor &operator+=(const Tensor &t);
     Tensor &operator-=(const Tensor &t);
@@ -124,9 +127,17 @@ namespace singa{
     /* TODO(chonho-01) for other types */
     // ...
 
+
+    /*TODO(chonho-08-b)
+    amax
+    amin
+    asum
+    */
+
+
   };
 
-  /* TODO
+  /* TODO(chonho-02)
   inline void CheckDataTypeAndLang(const Tensor &in1, const Tensor &in2);
   */
   void CopyDataToFrom(Tensor *dst, const Tensor &src, size_t num,
@@ -147,22 +158,25 @@ namespace singa{
   Tensor Sum(const Tensor &t, int axis);
   template <typename SType> SType Sum(const Tensor &t);
   %template(floatSum) Sum<float>;
-  /* TODO(chonho-03) not implemented
-  %template(intSum) Sum<int>;
-  %template(charSum) Sum<char>;
-  %template(doubleSum) Sum<double>;
-  */
+  // --- other types
 
-  /* TODO(chonho-04) not implemented
-     need average of all elements ??? */
+  /* TODO(chonho-04)
+     need to implement the average of all elements ??? */
   Tensor Average(const Tensor &t, int axis);
-  Tensor SoftMax(const Tensor &t, int axis = 0);
 
-  /* TODO(chonho-05) not implemented ???
-  Tensor Pow(const Tensor &base, Tensor exp);
-  template <typename DType>
-  Tensor Pow(const Tensor &t, DType x);
-  */
+
+  Tensor Pow(const Tensor &base, const Tensor &exp);
+  void Pow(const Tensor &base, const Tensor &exp, Tensor *out);
+
+  %rename(Pow_f) Pow(const Tensor &in, const float x);
+  template <typename SType>
+  Tensor Pow(const Tensor &in, const SType x);
+  %template(pow_temp) Pow<float>;
+
+  %rename(Pow_f_out) Pow(const Tensor &in, const float x, Tensor *out);
+  template <typename SType>
+  void Pow(const Tensor &in, const SType x, Tensor *out);
+  %template(pow_temp) Pow<float>;
 
 
   /* rename comparison operators */
@@ -206,10 +220,10 @@ namespace singa{
   */
 
 
-  /* rename operators */
+  /* ========== Arithmetic operations ========== */
   %rename(Add_TT) operator+(const Tensor &lhs, const Tensor &rhs);
   %rename(Sub_TT) operator-(const Tensor &lhs, const Tensor &rhs);
-  %rename(Mul_TT) operator*(const Tensor &lhs, const Tensor &rhs);
+  %rename(EltwiseMul_TT) operator*(const Tensor &lhs, const Tensor &rhs);
   %rename(Div_TT) operator/(const Tensor &lhs, const Tensor &rhs);
   Tensor operator+(const Tensor &lhs, const Tensor &rhs);
   Tensor operator-(const Tensor &lhs, const Tensor &rhs);
@@ -228,7 +242,7 @@ namespace singa{
   %template(op) operator-<float>;
   // --- other types
 
-  %rename(Mul_Tf) operator*(const Tensor &t, float x);
+  %rename(EltwiseMul_Tf) operator*(const Tensor &t, float x);
   template <typename DType>
   Tensor operator*(const Tensor &t, DType x);
   %template(op) operator*<float>;
@@ -240,10 +254,6 @@ namespace singa{
   %template(op) operator/<float>;
   // --- other types
 
-  /* TODO(chonho-07)
-  no need to include theses
-  in python, these can be replaced with operators
-
   void Add(const Tensor &lhs, const Tensor &rhs, Tensor *ret);
   void Sub(const Tensor &lhs, const Tensor &rhs, Tensor *ret);
   void EltwiseMult(const Tensor &lhs, const Tensor &rhs, Tensor *ret);
@@ -251,13 +261,82 @@ namespace singa{
 
   template <typename DType>
   void Add(const Tensor &t, DType x, Tensor *ret);
+  %template(Add_Tf_out) Add<float>;
+  // --- other types
+
   template <typename DType>
   void Sub(const Tensor &t, DType x, Tensor *ret);
+  %template(Sub_Tf_out) Sub<float>;
+  // --- other types
+
   template <typename DType>
   void EltwiseMult(const Tensor &t, DType x, Tensor *ret);
+  %template(EltwiseMult_Tf_out) EltwiseMult<float>;
+  // --- other types
+
   template <typename DType>
   void Div(const Tensor &t, DType x, Tensor *ret);
-  */
+  %template(Div_Tf_out) Div<float>;
+  // --- other types
+
+
+  /* ========== Random operations ========== */
+  template <typename SType>
+  void Bernoulli(const SType p, Tensor *out);
+  %template(floatBernoulli) Bernoulli<float>;
+  /* TODO for other types */
+  // ...
+
+  template <typename SType>
+  void Gaussian(const SType mean, const SType std, Tensor *out);
+  %template(floatGaussian) Gaussian<float>;
+  /* TODO for other types */
+  // ...
+
+  template <typename SType>
+  void Uniform(const SType low, const SType high, Tensor *out);
+  %template(floatUniform) Uniform<float>;
+  /* TODO for other types */
+  // ...
+
+  /* ========== Blas operations ========== */
+  template <typename SType>
+  void Axpy(SType alpha, const Tensor &in, Tensor *out);
+  %template(floatAxpy) Axpy<float>;
+  /* TODO for other types */
+  // ...
+
+  Tensor Mult(const Tensor &A, const Tensor &B);
+  void Mult(const Tensor &A, const Tensor &B, Tensor *C);
+  template <typename SType>
+  void Mult(const SType alpha, const Tensor &A, const Tensor &B,
+            const SType beta, Tensor *C);
+  %template(floatMult) Mult<float>;
+
+  void AddColumn(const Tensor &v, Tensor *M);
+  template <typename SType>
+  void AddColumn(const SType alpha, const SType beta, const Tensor &v,
+                 Tensor *M);
+  %template(floatAddColumn) AddColumn<float>;
+
+  void AddRow(const Tensor &v, Tensor *M);
+  template <typename SType>
+  void AddRow(const SType alpha, const SType beta, const Tensor &v,
+              Tensor *M);
+  %template(floatAddRow) AddRow<float>;
+
+  void DivColumn(const Tensor &v, Tensor *M);
+  void DivRow(const Tensor &v, Tensor *M);
+  void MultColumn(const Tensor &v, Tensor *M);
+  void MultRow(const Tensor &v, Tensor *M);
+  void SubColumn(const Tensor &v, Tensor *M);
+  void SubRow(const Tensor &v, Tensor *M);
+
+  void SumColumns(const Tensor &M, Tensor *v);
+  void SumRows(const Tensor &M, Tensor *v);
+
+  Tensor SoftMax(const Tensor &in);
+  void SoftMax(const Tensor &in, Tensor *out);
 
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cfdb995/src/python/swig/model_layer.i
----------------------------------------------------------------------
diff --git a/src/python/swig/model_layer.i b/src/python/swig/model_layer.i
index 3fb4917..15bd05f 100644
--- a/src/python/swig/model_layer.i
+++ b/src/python/swig/model_layer.i
@@ -21,7 +21,7 @@
 
 /*interface file for swig */
 
-%module singa_layer
+%module model_layer
 %include "std_vector.i"
 %include "std_string.i"
 %include "std_pair.i"
@@ -46,12 +46,14 @@ namespace std {
   %template(tvectvecPair) pair<vector<Tensor>, vector<Tensor>>;
 }
 
+
 namespace singa {
 
   class Layer {
     public:
       Layer();
-      void Setup(const std::string& proto_str);
+      void Setup(const std::vector<size_t>&, const string&);
+      void Setup(const std::vector<vector<size_t>>&, const string&);
 
       std::string ToProtoStr() const;
       const std::vector<ParamSpec> param_specs();
@@ -64,8 +66,11 @@ namespace singa {
 
       /* virtual functions */
       virtual const std::string layer_type() const;
-      virtual void Setup(const LayerConf& conf);
-      virtual void ToDevice(Device* device);
+      virtual void Setup(const std::vector<size_t>&,
+                         const LayerConf&);
+      virtual void Setup(const std::vector<std::vector<size_t>>&,
+                         const LayerConf&);
+      virtual void ToDevice(std::shared_ptr<Device> device);
       virtual void AsType(DataType dtype);
       virtual void ToProto(LayerConf* conf) const;
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/7cfdb995/src/python/tensor.py
----------------------------------------------------------------------
diff --git a/src/python/tensor.py b/src/python/tensor.py
index 04e070d..d382aae 100644
--- a/src/python/tensor.py
+++ b/src/python/tensor.py
@@ -28,6 +28,7 @@ to call singa::Tensor and its methods
 import sys
 import os
 import numpy as np
+import ctypes
 
 sys.path.append(os.path.join(os.path.dirname(__file__),
                              '../../build/lib'))
@@ -62,7 +63,7 @@ class Tensor(object):
             self.device = device
             self.dtype = dtype
 
-    def toarray(self):
+    def to_array(self):
         # TODO(chonho) - need to think more efficient way to convert???
         idx = self.singa_tensor.data_type()
         if idx == kFloat32:
@@ -92,6 +93,21 @@ class Tensor(object):
         data = np.array(data, dtype=dt).reshape(self.tuple_shape)
         return data
 
+    def from_array(self, np_array):
+        # TODO(chonho)
+        test = np.array([[1,2],[3,4]], dtype=np.float32)
+        test = test.flatten()
+        self.singa_tensor.floatCopyData(test, 4, 0)
+        '''
+        c_float_ptr = ctypes.POINTER(ctypes.c_float)
+        data_ptr = test.ctypes.data_as(c_float_ptr)
+        self.singa_tensor.floatCopyData(data_ptr, 4, 0)
+        '''
+        '''
+        d = [1.0, 2.0, 3.0, 4.0]
+        self.singa_tensor.floatCopyData(d, 4, 0)
+        '''
+
     def data_type(self):
         return self.singa_tensor.data_type()
 
@@ -131,8 +147,12 @@ class Tensor(object):
     def to_host(self):
         self.singa_tensor.ToHost()
 
+    def nrm2(self):
+        self.singa_tensor.L2()
+
     def set_value(self, x):
-        self.singa_tensor.SetValue(x)
+        if type(x) == float:
+            self.singa_tensor.floatSetValue(x)
 
     def copy_data(self, t):
         self.singa_tensor.CopyData(t.singa_tensor)
@@ -161,6 +181,36 @@ class Tensor(object):
         '''
         return self.clone()
 
+    def bernoulli(self, p):
+        if type(p) == float:
+            singa.floatBernoulli(p, self.singa_tensor)
+
+    def gaussian(self, mean, std):
+        if type(mean) == float:
+            singa.floatGaussian(mean, std, self.singa_tensor)
+
+    def uniform(self, low, high):
+        if type(low) == float:
+            singa.floatUniform(low, high, self.singa_tensor)
+
+    def add_column(self, v):
+        singa.AddColumn(v.singa_tensor, self.singa_tensor)
+
+    def add_row(self, v):
+        singa.AddRow(v.singa_tensor, self.singa_tensor)
+
+    def div_column(self, v):
+        singa.DivColumn(v.singa_tensor, self.singa_tensor)
+
+    def div_row(self, v):
+        singa.DivRow(v.singa_tensor, self.singa_tensor)
+
+    def mult_column(self, v):
+        singa.MultColumn(v.singa_tensor, self.singa_tensor)
+
+    def mult_row(self, v):
+        singa.MultRow(v.singa_tensor, self.singa_tensor)
+
     '''
     python operators (+=, -=, *=, /=) for singa::Tensor unary operators
     '''
@@ -213,10 +263,10 @@ class Tensor(object):
 
     def __mul__(self, rhs):
         if isinstance(rhs, Tensor):
-            return _call_singa_func(singa.Mul_TT,
+            return _call_singa_func(singa.EltwiseMul_TT,
                                     self.singa_tensor, rhs.singa_tensor)
         else:
-            return _call_singa_func(singa.Mul_Tf,
+            return _call_singa_func(singa.EltwiseMul_Tf,
                                     self.singa_tensor, rhs)
 
     def __div__(self, rhs):
@@ -296,16 +346,30 @@ def sum(t, axis=None):
         return _call_singa_func(singa.Sum, t.singa_tensor, axis)
 
 
-def pow(t, x):
-    print 'not implemented yet'
+def pow(t, x, out=None):
+    if out is None:
+        if isinstance(x, Tensor):
+            return _call_singa_func(singa.Pow, t.singa_tensor, x.singa_tensor)
+        else:
+            return _call_singa_func(singa.Pow_f, t.singa_tensor, x)
+    else:
+        if isinstance(x, Tensor):
+            singa.Pow(t.singa_tensor, x.singa_tensor, out.singa_tensor)
+        else:
+            singa.Pow_f_out(t.singa_tensor, x, out.singa_tensor)
+        return out
 
 
 def average(t, axis=0):
     return _call_singa_func(singa.Average, t.singa_tensor, axis)
 
 
-def softmax(t, axis=0):
-    return _call_singa_func(singa.SoftMax, t.singa_tensor, axis)
+def softmax(t, out=None):
+    if out is None:
+        return _call_singa_func(singa.SoftMax, t.singa_tensor)
+    else:
+        singa.SoftMax(t.singa_tensor, out.singa_tensor)
+        return out
 
 
 def lt(t, x):
@@ -324,24 +388,128 @@ def ge(t, x):
     return t >= x
 
 
-def add(lhs, rhs):
-    # call Tensor.__add__()
-    return lhs + rhs
+def add(lhs, rhs, ret=None):
+    if ret is None:
+        # call Tensor.__add__()
+        return lhs + rhs
+    else:
+        if isinstance(rhs, Tensor):
+            singa.Add(lhs.singa_tensor, rhs.singa_tensor, ret.singa_tensor)
+        else:
+            singa.Add_Tf_out(lhs.singa_tensor, rhs, ret.singa_tensor)
+        return ret
+
+def sub(lhs, rhs, ret=None):
+    if ret is None:
+        # call Tensor.__sub__()
+        return lhs - rhs
+    else:
+        if isinstance(rhs, Tensor):
+            singa.Sub(lhs.singa_tensor, rhs.singa_tensor, ret.singa_tensor)
+        else:
+            singa.Sub_Tf_out(lhs.singa_tensor, rhs, ret.singa_tensor)
+        return ret
+
+
+def eltwise_mult(lhs, rhs, ret=None):
+    if ret is None:
+        # call Tensor.__mul__()
+        return lhs * rhs
+    else:
+        if isinstance(rhs, Tensor):
+            singa.EltwiseMult(lhs.singa_tensor, rhs.singa_tensor,
+                              ret.singa_tensor)
+        else:
+            singa.EltwiseMult_Tf_out(lhs.singa_tensor, rhs,
+                                     ret.singa_tensor)
+        return ret
+
+
+def mult(A, B, C=None, alpha=1.0, beta=0.0):
+    if C is None:
+        return _call_singa_func(singa.Mult, A.singa_tensor, B.singa_tensor)
+    else:
+        singa.floatMult(alpha, A.singa_tensor, B.singa_tensor,
+                        beta, C.singa_tensor)
+        return C
+
+
+'''
+TODO(chonho) combined into the above
+                delete later
+def mult(A, B, C=None):
+    if C is None:
+        return _call_singa_func(singa.Mult, A.singa_tensor, B.singa_tensor)
+    else:
+        singa_Mult(A.singa_tensor, B.singa_tensor, C.singa_tensor)
+        return C
+
+def axypbz(alpha, A, B, b, C):
+    singa.floatMult(alpha, A.singa_tensor, B.singa_tensor, b, C.singa_tensor)
+    return C
+'''
+
+
+def div(lhs, rhs, ret=None):
+    if ret is None:
+        # call Tensor.__div__()
+        return lhs / rhs
+    else:
+        if isinstance(rhs, Tensor):
+            singa.Div(lhs.singa_tensor, rhs.singa_tensor, ret.singa_tensor)
+        else:
+            singa.Div_Tf_out(lhs.singa_tensor, rhs, ret.singa_tensor)
+        return ret
+
+
+def axpy(alpha, x, y):
+    if type(alpha) == float:
+        singa.floatAxpy(alpha, x.singa_tensor, y.singa_tensor)
+    return y
+
+
+def bernoulli(p, t):
+    if type(p) == float:
+        singa.floatBernoulli(p, t.singa_tensor)
+    return t
+
+
+def gaussian(mean, std, t):
+    if type(mean) == float:
+        singa.floatGaussian(mean, std, t.singa_tensor)
+    return t
+
+
+def uniform(low, high, t):
+    if type(low) == float:
+        singa.floatUniform(low, high, t.singa_tensor)
+    return t
+
+
+def add_column(alpha, v, beta, M):
+    singa.floatAddColumn(alpha, beta, v.singa_tensor, M.singa_tensor)
+    return M
 
 
-def sub(lhs, rhs):
-    # call Tensor.__sub__()
-    return lhs - rhs
+def add_row(alpha, v, beta, M):
+    singa.floatAddRow(alpha, beta, v.singa_tensor, M.singa_tensor)
+    return M
 
 
-def eltwise_mult(lhs, rhs):
-    # call Tensor.__mul__()
-    return lhs * rhs
+def sum_columns(M):
+    assert M.ndim() == 2, 'M.nDim() is supposed to be 2'
+    nb_col = M.shape(0)
+    ret = Tensor((nb_col, 1))
+    singa.SumColumns(M.singa_tensor, ret.singa_tensor)
+    return ret
 
 
-def div(lhs, rhs):
-    # call Tensor.__div__()
-    return lhs / rhs
+def sum_rows(M):
+    assert M.ndim() == 2, 'M.nDim() is supposed to be 2'
+    nb_row = M.shape(1)
+    ret = Tensor((1, nb_row))
+    singa.SumRows(M.singa_tensor, ret.singa_tensor)
+    return ret
 
 
 ''' private functions, internally used

