tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] u99127 commented on a change in pull request #5510: [FRONTEND][TFLite] Fully connected op conversion made in sync with TFLite
Date Mon, 04 May 2020 20:50:42 GMT

u99127 commented on a change in pull request #5510:
URL: https://github.com/apache/incubator-tvm/pull/5510#discussion_r419715441



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -419,6 +419,31 @@ def test_forward_cast():
     _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8)
     _test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64)
 
+#######################################################################
+# Batch Mat Mul
+# ----
+def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False):
+    with tf.Graph().as_default():
+        A = array_ops.placeholder(shape=A_shape, dtype=dtype, name='A')
+        B = array_ops.placeholder(shape=B_shape, dtype=dtype, name='B')
+        result = math_ops.matmul(A, B, adjoint_a=adjoint_a,
+                           adjoint_b=adjoint_b, name='batchmatmul')
+
+        A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
+        B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
+        compare_tflite_with_tvm([A_np, B_np], [A.name, B.name], [A, B], [result])
+
+
+def test_forward_batch_matmul():
+    """ BATCH_MAT_MUL """
+    print("Jai hanuman!!!")
+    _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32')
+    _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32', True, True)
+    _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', True, False)
+    _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True)
+    _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'float32')
+    print("Jai hanuman!!!")

Review comment:
       Unneeded print. Please remove this debug `print` (the same one also appears at the start of the test).

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1329,16 +1329,20 @@ def convert_fully_connected(self, op):
         input_tensor_shape = input_tensor.tensor.ShapeAsNumpy()
         weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy()
 
-        # reshape input tensor from N H W C to N H*W*C
-        input_size_per_batch = 1
-        for s in range(1, len(input_tensor_shape)):
-            input_size_per_batch *= input_tensor_shape[s]
-        assert input_size_per_batch == weight_tensor_shape[1], \
-            "input size and weight size are mismatched"
-        target_shape = tuple((input_tensor_shape[0], input_size_per_batch))
+        # Weight should have only 2 dimensions(TFLite convention)
+        assert len(weight_tensor_shape) == 2, "Weight should be only 2-dim"
+
+        input_size = 1
+        for s in range(0, len(input_tensor_shape)):
+            input_size *= input_tensor_shape[s]
+

Review comment:
       A comment here would be appropriate, or better still, a reference to the TFLite documentation.

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1329,16 +1329,20 @@ def convert_fully_connected(self, op):
         input_tensor_shape = input_tensor.tensor.ShapeAsNumpy()
         weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy()
 
-        # reshape input tensor from N H W C to N H*W*C
-        input_size_per_batch = 1
-        for s in range(1, len(input_tensor_shape)):
-            input_size_per_batch *= input_tensor_shape[s]
-        assert input_size_per_batch == weight_tensor_shape[1], \
-            "input size and weight size are mismatched"
-        target_shape = tuple((input_tensor_shape[0], input_size_per_batch))

Review comment:
       Can you explain in your covering note why this assert is wrong or needs to be removed?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



Mime
View raw message