singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wang...@apache.org
Subject [02/12] incubator-singa git commit: SINGA-347 Create a function that supports einsum 1.provide the diag() function which could make the matrix diagonalized along the given axis 2.provide some comments for some existing function 3. to do: provide complete
Date Sun, 29 Apr 2018 15:31:11 GMT
SINGA-347 Create a function that supports einsum
1. Provide the diag() function, which diagonalizes a tensor along the given axis.
2. Add comments to some existing functions.
3. TODO: complete the comments and fix the code. It still needs to be tested in the SINGA environment.

SINGA-347 Create a function that supports einsum
Fix one bug.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/b09aed3e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/b09aed3e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/b09aed3e

Branch: refs/heads/master
Commit: b09aed3e395b3bbf1eb32badaf1f8af19e0799e8
Parents: f595f10
Author: sheyujian <sheyujian@me.com>
Authored: Wed Apr 18 13:41:54 2018 +0800
Committer: sheyujian <sheyujian@me.com>
Committed: Wed Apr 18 13:56:27 2018 +0800

----------------------------------------------------------------------
 python/singa/tensor.py | 128 +++++++++++++++++++++++++++++++++-----------
 1 file changed, 97 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b09aed3e/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index fde1667..e81f3ab 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -947,8 +947,8 @@ def tensordot(A, B, axes=2):
     are treated as one axis by reshaping.
 
     Args:
-        a (Singa.Tensor): The first argument.
-        b (Singa.Tensor): The second argument.
+        A (Singa.Tensor): The first argument.
+        B (Singa.Tensor): The second argument.
         axes:
             - If it is an integer, then ''axes'' represent axes at the last of ''a`'' and
               the first of ''b'' are used.
@@ -961,6 +961,7 @@ def tensordot(A, B, axes=2):
         axes specified by ''axes''.
 
     Thanks to numpy.tensordot.
+    the link is https://github.com/numpy/numpy/blob/v1.14.0/numpy/core/numeric.py#L1123-L1306
     """
     # when axes is an integer, axes_A and axes_B represent axes at the last of ''a`'' and
     # the first of ''b''. For example, when axes is 1, we do the normal multiplication :
@@ -971,16 +972,16 @@ def tensordot(A, B, axes=2):
         axes_B = list(range(0, axes))
         axes_B = axes_B
     else:
-        axes_A, axes_B = axes
+        axes_A,axes_B =axes
     # when axes is a pair of sequences of integers.For example, A is in shape(3,2,4),
-    # B is in shape(4,2,5), we set axes as ([1,2],[1,0]), it will return a matrix in shape(3,5)
-    if isinstance(axes_A, list):
+    #B is in shape(4,2,5), we set axes as ([1,2],[1,0]), it will return a matrix in shape(3,5)
+    if isinstance(axes_A,list):
         na = len(axes_A)
         axes_A = list(axes_A)
     else:
         axes_A = [axes_A]
         na = 1
-    if isinstance(axes_B, list):
+    if isinstance(axes_B,list):
         nb = len(axes_B)
         axes_B = list(axes_B)
     else:
@@ -997,7 +998,7 @@ def tensordot(A, B, axes=2):
     if na != nb:
         equal = False
     else:
-        # to make the shape match
+    # to make the shape match
         for k in range(na):
             if a_shape[axes_A[k]] != b_shape[axes_B[k]]:
                 equal = False
@@ -1018,12 +1019,13 @@ def tensordot(A, B, axes=2):
         N2 *= a_shape[axis]
     N1 = 1
     for ax in notin:
-        N1 *= a_shape[ax]
-    # newshape_a is the shape to do multiplication.For example, A is in shape(3,2,4),
-    # B is in shape(4,2,5), we set axes as ([1,2],[1,0]), then newshape_a should be (3,5)
-    # olda is the shape that will be shown in the result.
-    newshape_a = (N1, N2)
+        N1 *=a_shape[ax]
+    #newshape_a is the shape to do multiplication.For example, A is in shape(3,2,4),
+    #B is in shape(4,2,5), we set axes as ([1,2],[1,0]), then newshape_a should be (3,5)
+    #olda is the shape that will be shown in the result.
+    newshape_a = (N1,N2)
     olda = [a_shape[axis] for axis in notin]
+    notin = [k for k in range(ndb) if k not in axes_B]
     newaxes_b = axes_B + notin
     N2 = 1
     for axis in axes_B:
@@ -1034,42 +1036,69 @@ def tensordot(A, B, axes=2):
     newshape_b = (N2, N1)
     oldb = [b_shape[axis] for axis in notin]
     # do transpose and reshape to get the 2D matrix to do multiplication
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
+    at = A.transpose(newaxes_a).reshape(newshape_a)
+    bt = B.transpose(newaxes_b).reshape(newshape_b)
     res = mult(at, bt)
-    # reshape the result
+    #reshape the result
     return res.reshape(olda + oldb)
 
 
-def einsum(A,B,ops):
-    '''
-    Thanks to nils-werner/sparse.einsum()
+def einsum_(A,B,ops):
+    '''Do the matrix to matrix einsum calculation according to the operands
+
+    Args:
+        A (Singa.Tensor): The first argument.
+        B (Singa.Tensor): The second argument.
+        ops(string):
+            the string specifies the subscripts for summation such as 'ki,kj->kij'
+
+    Returns: Singa.Tensor
+        the output matirx of the einsum calculation
+    Thanks to nils-werner/sparse.einsum(),
+    the link is https://github.com/nils-werner/sparse/commit/449c75d21d3158630bc5be79c690a60cfc002578
     '''
 
     if len(ops) == 0:
         raise ValueError("No input operands")
-
-    nputops, outputops = ops.split('->')
+    # to get the input and output ops
+    inputops, outputops = ops.split('->')
     inputops = inputops.split(',')
 
+    if A.ndim != len(inputops[0]) or B.ndim != len(inputops[1]):
+        raise ValueError("input dim doesn't match operands")
+
     # All indices that are in input AND in output are multiplies
-    multiplies = sorted(list(set(inputops[0]) & set(inputops[1]) & set(outputops)))
     # All indices that are in input BUT NOT in output are sum contractions
+    multiplies = sorted(list(set(inputops[0]) & set(inputops[1]) & set(outputops)))
     sums = sorted(list((set(inputops[0]) & set(inputops[1])) - set(outputops)))
 
     # Map sums and indices to axis integers
-    multiplies = [[inop.find(x) for x in multiplies] for inop in inputops]
+    multiplies = [[inops.find(x) for x in multiplies] for inops in inputops]
+    sums = [[inops.find(x) for x in sums] for inops in inputops]
+
 
-    sums = [[inop.find(x) for x in sums] for inop in inputops]
     # Find output axes in input axes for final transpose
-    # Values very likely lie outside of output tensor shape, so
-    # just map them values to their rank (index in ordered list)
     transpose = [''.join(inputops).find(x) for x in outputops]
+    #to make the transpose match to its rank
     transpose = sorted(range(len(transpose)), key = transpose.__getitem__)
 
-    return tensormult(A,B, sum=sums, multiply=multiplies).transpose(transpose)
+    return tensormult(A,B,sums,multiplies).transpose(transpose)
+
 
-def tensordotmult(A, B, sum=None, multiply=None):
+def tensormult(A, B, sum=None, multiply=None):
+    '''
+    Args:
+        A (Singa.Tensor): The first input tensor from einsum.
+        B (Singa.Tensor): The second input tensor from einsum.
+        sum (list[list[int]]):
+            The axis to do the normal tensordot calculation
+
+        multiply (list[list[int]]):
+            The axis to multiply
+
+    Return: (Singa.Tensor)
+        The output of the tensormult calculation
+    '''
     if sum is None:
         sum = [[], []]
     else:
@@ -1079,14 +1108,15 @@ def tensordotmult(A, B, sum=None, multiply=None):
         multiply = [[], []]
     else:
         multiply = list(multiply)
-
     # For each multiply[0] we are adding one axis, thus we need to increment
     # all following items by one: (0, 1, 2) -> (0, 2, 4)
     idx = multipliessort(multiply[0])
+
     post_multiply = multiply[0]
     for i, v in enumerate(idx):
         post_multiply[v] += i
 
+
     for i in post_multiply:
         A = diag(A,i)
 
@@ -1096,6 +1126,15 @@ def tensordotmult(A, B, sum=None, multiply=None):
     return tensordot(A, B, axes=sum)
 
 def multipliessort(multiplies):
+    '''
+    Returns the indices that would sort an array
+    Args:
+        multiplies(list[int]):
+            the list of the input multiplies
+
+    Return(list[int]):
+        the indices list
+    '''
 
     if multiplies is None:
         multiplies = []
@@ -1106,9 +1145,36 @@ def multipliessort(multiplies):
     idx = [x[0] for x in sort_multiplies]
     return idx
 
-def diag(A,axis = -1):
-    A_diag = clone(A)
-    '''sheyujian todo: to make a tensor's axis to be diagonalization'''
+def diag(A,axis=-1):
+    '''
+    Return the matrix that is diagonalize along the axis
+    Args:
+        A (Singa.Tensor):the input tensor
+        axis (int):
+            the axis which the matrix use to diagonalization
+    Return(Singa.Tensor):
+        the tensor which has been diagonalize along the given axis
+
+
+    It is like using numpy.diag() in the 1D condition, but this function can be used to do high dimension
+    matrix diagonalization
+
+    '''
+    # to get the shape of the diagonalize matirx
+    axis_diag = list(range(out.ndim))[axis]
+    shape = list(A.shape)
+    shape.insert(axis_diag,shape[axis_diag])
+
+    npA = to_numpy(A)
+    npA_diag = np.zeros(product(shape))
+    index = np.argwhere(npA != 0)
+    for i in index:
+        i_diag = list(i)
+        i_diag.insert(axis, i[axis])
+        index_npA = [[i[x]] for x in range(len(i))]
+        index_npA_diag = [[i_diag[x]] for x in range(len(i_diag))]
+        npA_diag[index_npA_diag] = npA[index_npA]
+    A_diag = from_numpy(npA_diag)
     return A_diag
 
 


Mime
View raw message