singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wang...@apache.org
Subject [2/2] incubator-singa git commit: SINGA-245 Float as the first operand can not multiply with a tensor object
Date Thu, 08 Sep 2016 14:50:18 GMT
SINGA-245 Float as the first operand can not multiply with a tensor object

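Add reverse add/sub/mul/div operators (__radd__, __rsub__, __rmul__, __rdiv__)
so that a float can be the left operand of an operation with a Tensor.
Call the base-class constructor in SoftmaxCrossEntropy and SquaredError (loss.py).
Add unit tests in test_tensor.py.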


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/0ebce1a4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/0ebce1a4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/0ebce1a4

Branch: refs/heads/master
Commit: 0ebce1a44913dc760f3f6398b34fa45b3dcca5e8
Parents: 76cd806
Author: Wei Wang <wangwei.cs@gmail.com>
Authored: Thu Sep 8 22:48:46 2016 +0800
Committer: Wei Wang <wangwei.cs@gmail.com>
Committed: Thu Sep 8 22:48:46 2016 +0800

----------------------------------------------------------------------
 src/python/singa/loss.py   |  3 ++-
 src/python/singa/tensor.py | 24 ++++++++++++++++++++++++
 test/python/test_tensor.py | 29 +++++++++++++++++++++++++++++
 3 files changed, 55 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0ebce1a4/src/python/singa/loss.py
----------------------------------------------------------------------
diff --git a/src/python/singa/loss.py b/src/python/singa/loss.py
index 8b99ad3..526e4d0 100644
--- a/src/python/singa/loss.py
+++ b/src/python/singa/loss.py
@@ -95,6 +95,7 @@ class SoftmaxCrossEntropy(Loss):
     '''
 
     def __init__(self):
+        super(SoftmaxCrossEntropy, self).__init__()  # run Loss.__init__
         self.swig_loss = singa.SoftmaxCrossEntropy()
 
 
@@ -105,7 +106,7 @@ class SquaredError(Loss):
     It is implemented using Python Tensor operations.
     '''
     def __init__(self):
-        super(Loss, self).__init__()
+        super(SquaredError, self).__init__()  # must name this class
         self.err = None
 
     def forward(self, flag, x, y):

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0ebce1a4/src/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/src/python/singa/tensor.py b/src/python/singa/tensor.py
index f6bca43..1024483 100644
--- a/src/python/singa/tensor.py
+++ b/src/python/singa/tensor.py
@@ -372,6 +372,7 @@ class Tensor(object):
 
     '''
     python operators (+, -, *, /, <, <=, >, >=) for singa binary operators
+    https://docs.python.org/2/library/operator.html#mapping-operators-to-functions
     '''
 
     def __add__(self, rhs):
@@ -441,6 +442,29 @@ class Tensor(object):
             return _call_singa_func(singa.GE_Tf, self.singa_tensor, rhs)
 
 
+    def __radd__(self, lhs):  # lhs + self, lhs is a scalar
+        lhs = float(lhs)
+        return _call_singa_func(singa.Add_Tf, self.singa_tensor, lhs)
+
+    def __rsub__(self, lhs):  # lhs - self, lhs is a scalar
+        lhs = float(lhs)
+        ret = _call_singa_func(singa.Sub_Tf, self.singa_tensor, lhs)
+        ret *= -1  # lhs - self == -(self - lhs)
+        return ret
+
+    def __rmul__(self, lhs):  # lhs * self, lhs is a scalar
+        lhs = float(lhs)
+        return _call_singa_func(singa.EltwiseMul_Tf, self.singa_tensor, lhs)
+
+    def __rdiv__(self, lhs):  # lhs / self, lhs is a scalar
+        # fill a same-shaped tensor with lhs, then divide it by self (Div_TT)
+        numerator = Tensor(self.shape, self.device, self.dtype)
+        numerator.set_value(float(lhs))
+        return _call_singa_func(singa.Div_TT, numerator.singa_tensor,
+                                self.singa_tensor)
+
+    __rtruediv__ = __rdiv__  # used by Python 3 and __future__ division
+
 ''' python functions for global functions in Tensor.h
 '''
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0ebce1a4/test/python/test_tensor.py
----------------------------------------------------------------------
diff --git a/test/python/test_tensor.py b/test/python/test_tensor.py
index 2374adc..a1f220b 100644
--- a/test/python/test_tensor.py
+++ b/test/python/test_tensor.py
@@ -133,5 +133,34 @@ class TestTensorMethods(unittest.TestCase):
         self.assertAlmostEqual(tensor.average(x), 1, 3)
 
 
+    def test_radd(self):
+        x = tensor.Tensor((3,))
+        x.set_value(1)
+        y = 1 + x
+        self.assertEqual(tensor.average(y), 2.)
+
+
+    def test_rsub(self):
+        x = tensor.Tensor((3,))
+        x.set_value(1)
+        y = 3 - x  # a swapped x - 3 would give -2
+        self.assertEqual(tensor.average(y), 2.)
+
+
+    def test_rmul(self):
+        x = tensor.Tensor((3,))
+        x.set_value(3)
+        y = 2 * x  # 2 + x would give 5
+        self.assertEqual(tensor.average(y), 6.)
+
+
+    def test_rdiv(self):
+        x = tensor.Tensor((3,))
+        x.set_value(4)
+        y = 2 / x  # x / 2 would give 2, 2 * x would give 8
+        self.assertEqual(tensor.average(y), 0.5)
+
+
+
 if __name__ == '__main__':
     unittest.main()


Mime
View raw message