singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wang...@apache.org
Subject [06/12] incubator-singa git commit: SINGA-347 Create a function that supports einsum assuming have the numpy.repeat, finish the einsum function (using elementwisemult, no need to use tensordot)
Date Sun, 29 Apr 2018 15:31:15 GMT
SINGA-347 Create a function that supports einsum
Assuming numpy.repeat is available, finish the einsum function (using element-wise multiplication; no need
to use tensordot)


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/2ec13649
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/2ec13649
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/2ec13649

Branch: refs/heads/master
Commit: 2ec1364930a7a5b17640151b7cd16ae8fe75e131
Parents: 2ec06ed
Author: sheyujian <sheyujian@me.com>
Authored: Thu Apr 19 16:35:38 2018 +0800
Committer: sheyujian <sheyujian@me.com>
Committed: Thu Apr 19 16:35:38 2018 +0800

----------------------------------------------------------------------
 python/singa/tensor.py | 60 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 59 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2ec13649/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index e6502fc..38acf6d 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -1042,7 +1042,6 @@ def tensordot(A, B, axes=2):
     #reshape the result
     return res.reshape(olda + oldb)
 
-
 def einsum_(A,B,ops):
     '''Do the matrix to matrix einsum calculation according to the operands
 
@@ -1183,6 +1182,65 @@ def diag(A,axis=-1):
     A_diag = from_numpy(npA_diag)
     return A_diag
 
+def einsum2(A,B,ops):
+    '''
+        Do the matrix to matrix einsum calculation according to the operands
+
+    Args:
+        A (Singa.Tensor): The first argument.
+        B (Singa.Tensor): The second argument.
+        ops(string):
+            the string specifies the subscripts for summation such as 'ki,kj->kij'
+
+    Returns: Singa.Tensor
+        the output matrix of the einsum calculation
+    '''
+
+    if len(ops) == 0:
+        raise ValueError("No input operands")
+
+    inputops, outputops = ops.split('->')
+    inputops = inputops.split(',')
+
+    if A.ndim != len(inputops[0]) or B.ndim != len(inputops[1]):
+        raise ValueError("input dim doesn't match operands")
+
+    sums = sorted(list((set(inputops[0]) | set(inputops[1])) - set(outputops)))
+
+    broadcast_A = sorted(list(set(inputops[1]) - set(inputops[0])))
+    broadcast_B = sorted(list(set(inputops[0]) - set(inputops[1])))
+
+
+    outputall = sorted(list(set(inputops[0]) | set(inputops[1])))
+
+    sums = [outputall.index(x) for x in sums]
+    broadcast_idA = [inputops[1].find(x) for x in broadcast_A]
+    broadcast_idB = [inputops[0].find(x) for x in broadcast_B]
+
+    broadcast_a = [B.shape[x] for x in broadcast_idA]
+    broadcast_b = [A.shape[x] for x in broadcast_idB]
+
+    transpose_A = [(list(inputops[0])+broadcast_A) .index(x) for x in outputall]
+    transpose_B = [(list(inputops[1])+broadcast_B) .index(x) for x in outputall]
+
+
+    reshape_A = list(A.shape)+broadcast_a
+    reshape_B = list(B.shape)+broadcast_b
+
+    mult_A = np.repeat(A, np.product(broadcast_a)).reshape(reshape_A).transpose(transpose_A)
+    mult_B = np.repeat(B, np.product(broadcast_b)).reshape(reshape_B).transpose(transpose_B)
+
+    if mult_A.shape != mult_B.shape:
+        raise ValueError("error: matrix dimension mismatch")
+    res = eltwise_mult(mult_A, mult_B)
+
+    sum_R = sorted(sums, reverse=True)
+    for i in sum_R:
+        res = res.sum(axis=i)
+    transpose_res = [sorted(list(outputops)).index(x) for x in list(outputops)]
+
+    return res.transpose(transpose_res)
+
 
 
 def div(lhs, rhs, ret=None):


Mime
View raw message