SINGA-347 Create a function that supports einsum
1. Test the tensordot function and fix some errors in the function
2. Tweak the code to be more readable and fix some errors in the comments
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/9a3ce585
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/9a3ce585
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/9a3ce585
Branch: refs/heads/master
Commit: 9a3ce585e468f4f5f17b8ea82658baf9d6ccd2aa
Parents: e27498d
Author: sheyujian <sheyujian@me.com>
Authored: Fri Apr 13 14:18:08 2018 +0800
Committer: sheyujian <sheyujian@me.com>
Committed: Wed Apr 18 13:56:27 2018 +0800

python/singa/tensor.py  85 +++++++++++++++++++++++
1 file changed, 43 insertions(+), 42 deletions(-)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9a3ce585/python/singa/tensor.py

diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index b5f5d99..49fa052 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -938,52 +938,53 @@ def mult(A, B, C=None, alpha=1.0, beta=0.0):
singa.MultWithScale(alpha, A.singa_tensor, B.singa_tensor,
beta, C.singa_tensor)
return C
def tensordot (A,B,axes=2):

+
+
+def tensordot(A, B, axes=2):
"""Returns the tensor multiplication of two tensors along specified axes.

+
This is equivalent to compute dot product along the specified axes which
are treated as one axis by reshaping.

+
Args:
a (Singa.Tensor): The first argument.
b (Singa.Tensor): The second argument.
axes:
  If it is an integer, then ''axes'' represent axes at the last of ''a`'' and
 the first of ''b'' are used.
  If it is a pair of sequences of integers, then these two
 sequences specify the list of axes for ''a'' and ''b''. The
 corresponding axes are paired for sumproduct.

+  If it is an integer, then ''axes'' represent axes at the last of ''a`'' and
+ the first of ''b'' are used.
+  If it is a pair of sequences of integers, then these two
+ sequences specify the list of axes for ''a'' and ''b''. The
+ corresponding axes are paired for sumproduct.
+
Return:
singa.tensor: The tensor product of ''a'' and ''b'' along the
axes specified by ''axes''.
 """
+ """
# when axes is an integer, axes_A and axes_B represent axes at the last of ''a`'' and
# the first of ''b''. For example, when axes is 1, we do the normal multiplication :
# if A is in shape(3,2,4), B is in shape(4,2,5), it will return a matrix in shape(3,2,2,5)
 #when axes is 2 and A,B are in the same shape, it will return a matrix in shape(3,5)
 try:
 iter(axes)
 except Exception:
+ # when axes is 2 and A,B are in the same shape, it will return a matrix in shape(3,5)
+ if type(axes) == int:
axes_A = list(range(axes, 0))
axes_B = list(range(0, axes))
+ axes_B = axes_B.reverse()
else:
axes_A, axes_B = axes
# when axes is a pair of sequences of integers.For example, A is in shape(3,2,4),
 #B is in shape(4,2,5), we set axes as ([1,2],[1,0]), it will return a matrix in shape(3,5)
 try:
+ # B is in shape(4,2,5), we set axes as ([1,2],[1,0]), it will return a matrix in shape(3,5)
+ if isinstance(axes_A, list):
na = len(axes_A)
 axes_A = list(axes_B)
 except TypeError:
+ axes_A = list(axes_A)
+ else:
axes_A = [axes_A]
na = 1
 try:
 nb = len(axes_A)
+ if isinstance(axes_B, list):
+ nb = len(axes_B)
axes_B = list(axes_B)
 except TypeError:
+ else:
axes_B = [axes_B]
nb = 1
+
# a_shape and b_shape are the shape of tensor A and B, while nda and ndb are the dim
of A and B
a_shape = A.shape
nda = A.ndim
@@ -994,47 +995,47 @@ def tensordot (A,B,axes=2):
if na != nb:
equal = False
else:
 # to make the shape match
+ # to make the shape match
for k in range(na):
 if a_shape[axes_a[k]] != b_shape[axes_b[k]]:
+ if a_shape[axes_A[k]] != b_shape[axes_B[k]]:
equal = False
break
 if axes_a[k] < 0:
 axes_a[k] += nda
 if axes_b[k] < 0:
 axes_b[k] += ndb
+ if axes_A[k] < 0:
+ axes_A[k] += nda
+ if axes_B[k] < 0:
+ axes_B[k] += ndb
if not equal:
raise ValueError("shapemismatch for sum")
'''start to do the calculation according to the axes'''
 notin = [k for k in range(nda) if k not in axes_a]
+
+ notin = [k for k in range(nda) if k not in axes_A]
# nda is the dim of A, and axes_a is the axis for A, notin is the axis which is not in
axes_A
 newaxes_a = notin + axes_a
+ newaxes_a = notin + axes_A
N2 = 1
 for axis in axes_a:
+ for axis in axes_A:
N2 *= a_shape[axis]
N1 = 1
for ax in notin:
 N1 *=a_shape[ax]
+ N1 *= a_shape[ax]
# newshape_a is the shape to do multiplication.For example, A is in shape(3,2,4),
 #B is in shape(4,2,5), we set axes as ([1,2],[1,0]), then newshape_a should be (3,8)
 #olda is the shape that will be shown in the result.
 newshape_a = (N1,N2)
+ # B is in shape(4,2,5), we set axes as ([1,2],[1,0]), then newshape_a should be (3,5)
+ # olda is the shape that will be shown in the result.
+ newshape_a = (N1, N2)
olda = [a_shape[axis] for axis in notin]
 notin = [k for k in range(ndb) if k not in axes_b]
 newaxes_b = axes_b + notin
+ newaxes_b = axes_B + notin
N2 = 1
 for axis in axes_b:
+ for axis in axes_B:
N2 *= b_shape[axis]
N1 = 1
 for ax in notin:
 N1 *=a_shape[ax]
+ for bx in notin:
+ N1 *= b_shape[bx]
newshape_b = (N2, N1)
oldb = [b_shape[axis] for axis in notin]
# do transpose and reshape to get the 2D matrix to do multiplication
at = a.transpose(newaxes_a).reshape(newshape_a)
bt = b.transpose(newaxes_b).reshape(newshape_b)
 res = mult(a, b)
 #reshape the result
+ res = mult(at, bt)
+ # reshape the result
return res.reshape(olda + oldb)
def div(lhs, rhs, ret=None):
