singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wang...@apache.org
Subject [04/12] incubator-singa git commit: SINGA-347 Create a function that supports einsum 1. test the tensordot function and fix some errors in the function 2. tweak the code to be more readable and fix some errors in the comments
Date Sun, 29 Apr 2018 15:31:13 GMT
SINGA-347 Create a function that supports einsum
1. Test the tensordot function and fix some errors in the function.
2. Tweak the code to be more readable and fix some errors in the comments.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/9a3ce585
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/9a3ce585
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/9a3ce585

Branch: refs/heads/master
Commit: 9a3ce585e468f4f5f17b8ea82658baf9d6ccd2aa
Parents: e27498d
Author: sheyujian <sheyujian@me.com>
Authored: Fri Apr 13 14:18:08 2018 +0800
Committer: sheyujian <sheyujian@me.com>
Committed: Wed Apr 18 13:56:27 2018 +0800

----------------------------------------------------------------------
 python/singa/tensor.py | 85 +++++++++++++++++++++++----------------------
 1 file changed, 43 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9a3ce585/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index b5f5d99..49fa052 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -938,52 +938,53 @@ def mult(A, B, C=None, alpha=1.0, beta=0.0):
         singa.MultWithScale(alpha, A.singa_tensor, B.singa_tensor,
                             beta, C.singa_tensor)
         return C
-def tensordot (A,B,axes=2):
-    
+
+
+def tensordot(A, B, axes=2):
     """Returns the tensor multiplication of two tensors along specified axes.
-        
+
     This is equivalent to compute dot product along the specified axes which
     are treated as one axis by reshaping.
-        
+
     Args:
         a (Singa.Tensor): The first argument.
         b (Singa.Tensor): The second argument.
         axes:
-        - If it is an integer, then ''axes'' represent axes at the last of ''a`'' and
-        the first of ''b'' are used.
-        - If it is a pair of sequences of integers, then these two
-        sequences specify the list of axes for ''a'' and ''b''. The
-        corresponding axes are paired for sum-product.
-        
+            - If it is an integer, then ''axes'' represent axes at the last of ''a`'' and
+              the first of ''b'' are used.
+            - If it is a pair of sequences of integers, then these two
+              sequences specify the list of axes for ''a'' and ''b''. The
+              corresponding axes are paired for sum-product.
+
     Return:
         singa.tensor: The tensor  product of ''a'' and ''b'' along the
         axes specified by ''axes''.
-        """
+    """
     # when axes is an integer, axes_A and axes_B represent axes at the last of ''a`'' and
     # the first of ''b''. For example, when axes is 1, we do the normal multiplication :
     # if A is in shape(3,2,4), B is in shape(4,2,5), it will return a matrix in shape(3,2,2,5)
-    #when axes is 2 and A,B are in the same shape, it will return a matrix in shape(3,5)
-    try:
-        iter(axes)
-    except Exception:
+    # when axes is 2 and A,B are in the same shape, it will return a matrix in shape(3,5)
+    if type(axes) == int:
         axes_A = list(range(-axes, 0))
         axes_B = list(range(0, axes))
+        axes_B = axes_B.reverse()
     else:
         axes_A, axes_B = axes
     # when axes is a pair of sequences of integers.For example, A is in shape(3,2,4),
-    #B is in shape(4,2,5), we set axes as ([1,2],[1,0]), it will return a matrix in shape(3,5)
-    try:
+    # B is in shape(4,2,5), we set axes as ([1,2],[1,0]), it will return a matrix in shape(3,5)
+    if isinstance(axes_A, list):
         na = len(axes_A)
-        axes_A = list(axes_B)
-    except TypeError:
+        axes_A = list(axes_A)
+    else:
         axes_A = [axes_A]
         na = 1
-    try:
-        nb = len(axes_A)
+    if isinstance(axes_B, list):
+        nb = len(axes_B)
         axes_B = list(axes_B)
-    except TypeError:
+    else:
         axes_B = [axes_B]
         nb = 1
+
     # a_shape and b_shape are the shape of tensor A and B, while nda and ndb are the dim of A and B
     a_shape = A.shape
     nda = A.ndim
@@ -994,47 +995,47 @@ def tensordot (A,B,axes=2):
     if na != nb:
         equal = False
     else:
-    # to make the shape match
+        # to make the shape match
         for k in range(na):
-            if a_shape[axes_a[k]] != b_shape[axes_b[k]]:
+            if a_shape[axes_A[k]] != b_shape[axes_B[k]]:
                 equal = False
                 break
-            if axes_a[k] < 0:
-                axes_a[k] += nda
-            if axes_b[k] < 0:
-                axes_b[k] += ndb
+            if axes_A[k] < 0:
+                axes_A[k] += nda
+            if axes_B[k] < 0:
+                axes_B[k] += ndb
     if not equal:
         raise ValueError("shape-mismatch for sum")
     '''start to do the calculation according to the axes'''
-    notin = [k for k in range(nda) if k not in axes_a]
+
+    notin = [k for k in range(nda) if k not in axes_A]
     # nda is the dim of A, and axes_a is the axis for A, notin is the axis which is not in axes_A
-    newaxes_a = notin + axes_a
+    newaxes_a = notin + axes_A
     N2 = 1
-    for axis in axes_a:
+    for axis in axes_A:
         N2 *= a_shape[axis]
     N1 = 1
     for ax in notin:
-        N1 *=a_shape[ax]
+        N1 *= a_shape[ax]
     # newshape_a is the shape to do multiplication.For example, A is in shape(3,2,4),
-    #B is in shape(4,2,5), we set axes as ([1,2],[1,0]), then newshape_a should be (3,8)
-    #olda is the shape that will be shown in the result.
-    newshape_a = (N1,N2)
+    # B is in shape(4,2,5), we set axes as ([1,2],[1,0]), then newshape_a should be (3,5)
+    # olda is the shape that will be shown in the result.
+    newshape_a = (N1, N2)
     olda = [a_shape[axis] for axis in notin]
-    notin = [k for k in range(ndb) if k not in axes_b]
-    newaxes_b = axes_b + notin
+    newaxes_b = axes_B + notin
     N2 = 1
-    for axis in axes_b:
+    for axis in axes_B:
         N2 *= b_shape[axis]
     N1 = 1
-    for ax in notin:
-        N1 *=a_shape[ax]
+    for bx in notin:
+        N1 *= b_shape[bx]
     newshape_b = (N2, N1)
     oldb = [b_shape[axis] for axis in notin]
     # do transpose and reshape to get the 2D matrix to do multiplication
     at = a.transpose(newaxes_a).reshape(newshape_a)
     bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = mult(a, b)
-    #reshape the result
+    res = mult(at, bt)
+    # reshape the result
     return res.reshape(olda + oldb)
 
 def div(lhs, rhs, ret=None):


Mime
View raw message