tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] siju-samuel commented on a change in pull request #5329: [Frontend][TFLite] Add parser support for shape and range
Date Fri, 08 May 2020 05:13:31 GMT

siju-samuel commented on a change in pull request #5329:
URL: https://github.com/apache/incubator-tvm/pull/5329#discussion_r421938105



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -579,6 +582,63 @@ def convert_tanh(self, op):
 
         return out
 
+    def convert_range(self, op):
+        """Convert TFLite Range"""
+        try:
+            from tflite.Operator import Operator

Review comment:
       Remove this `Operator` import; it's already handled.

##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -650,6 +693,82 @@ def test_all_resize():
         _test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False)
 
 
+#######################################################################
+# Range
+# -----
+def _test_range(start, limit, delta):
+    # tflite 1.13 convert method does not accept empty shapes
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+        tf.reset_default_graph()
+        with tf.Graph().as_default():
+            start_scalar, limit_scalar, delta_scalar = \
+                tf.placeholder(dtype=start.dtype, shape=(), name="start"), \
+                tf.placeholder(dtype=limit.dtype, shape=(), name="limit"), \
+                tf.placeholder(dtype=delta.dtype, shape=(), name="delta")
+
+            out = tf.range(start_scalar, limit_scalar, delta_scalar, name="range")
+
+            compare_tflite_with_tvm(
+                [start, limit, delta],
+                ["start", "limit", "delta"],
+                [start_scalar, limit_scalar, delta_scalar],
+                [out],
+                mode="vm",
+                quantized=False
+        )
+
+def _test_range_default():
+    # tflite 1.13 convert method does not accept empty shapes
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+        tf.reset_default_graph()
+        with tf.Graph().as_default():
+
+            inputs = [
+                tf.placeholder(dtype=tf.int32, shape=(), name="p1"),
+                tf.placeholder(dtype=tf.int32, shape=(), name="p2")
+            ]
+            leaves = [
+                tf.range(start = inputs[0], limit = inputs[1]), #use default delta
+                tf.range(start = inputs[1]) #use start as limit with 0 as the first item in the range
+            ]
+
+            compare_tflite_with_tvm(
+                [np.int32(1), np.int32(18)],
+                ["p1", "p2"],
+                inputs,
+                leaves,
+                mode="vm",
+                quantized=False
+        )
+
+def test_forward_range():
+   _test_range(np.int32(1), np.int32(18), np.int32(3))
+   _test_range(np.int32(1), np.int32(18), np.float32(3.1)) # increment is of type float
+   _test_range(np.float32(1.0), np.int32(18), np.int32(3.1)) # start is of type float
+   _test_range_default()
+
+#######################################################################
+# Shape
+# -----
+def test_forward_shape():
+    # tflite 1.13 convert method does not accept empty shapes
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+        tf.reset_default_graph()
+        with tf.Graph().as_default():
+            data = np.array([1, 18, 3], dtype=np.int32)
+            start = tf.placeholder(dtype=tf.int32, shape=[], name="start")
+            limit = tf.placeholder(dtype=tf.int32, shape=[], name="limit")
+            delta = tf.placeholder(dtype=tf.int32, shape=[], name="delta")
+            r = tf.range(start, limit, delta, tf.int32, name="range")
+            out = tf.shape(r, out_type=tf.dtypes.int32)
+            compare_tflite_with_tvm(
+                [x for x in np.nditer(data)],
+                ["start", "limit", "delta"],
+                [start, limit, delta],
+                [out],
+                mode="vm",
+                quantized=False
+            )
 #######################################################################

Review comment:
       Add blank lines here, before the section separator.

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -579,6 +582,63 @@ def convert_tanh(self, op):
 
         return out
 
+    def convert_range(self, op):
+        """Convert TFLite Range"""
+        try:
+            from tflite.Operator import Operator
+            from tflite.TensorType import TensorType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                'TFlite quantized RANGE operator is not supported yet.')
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 3, "input tensors length should be 3"
+
+        start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]
+        expressions = []
+
+        for t in [start, limit, delta]:
+            if self.has_expr(t.tensor_idx):
+                expressions.append(self.get_expr(t.tensor_idx))
+            else:
+                tensor_type = self.get_tensor_type_str(t.tensor.Type())
+                tensor_value = self.get_tensor_value(t)
+                expressions.append(self.exp_tab.new_const(tensor_value, dtype=tensor_type))
+
+        #out type inference

Review comment:
       `#out` -> `# out ...`
   Add a space after the `#` at the start of each comment. Change this in all places.

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -579,6 +582,63 @@ def convert_tanh(self, op):
 
         return out
 
+    def convert_range(self, op):
+        """Convert TFLite Range"""
+        try:
+            from tflite.Operator import Operator
+            from tflite.TensorType import TensorType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                'TFlite quantized RANGE operator is not supported yet.')
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 3, "input tensors length should be 3"
+
+        start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]
+        expressions = []
+
+        for t in [start, limit, delta]:
+            if self.has_expr(t.tensor_idx):
+                expressions.append(self.get_expr(t.tensor_idx))
+            else:
+                tensor_type = self.get_tensor_type_str(t.tensor.Type())
+                tensor_value = self.get_tensor_value(t)
+                expressions.append(self.exp_tab.new_const(tensor_value, dtype=tensor_type))
+
+        #out type inference
+        if delta.tensor.Type() == TensorType.FLOAT32:
+            out_type = self.get_tensor_type_str(delta.tensor.Type())
+        else:
+            out_type = self.get_tensor_type_str(start.tensor.Type())
+
+        #put type here form op
+        out = _op.arange(expressions[0], expressions[1], expressions[2], out_type)
+
+        return out
+
+    def convert_shape(self, op):
+        """Convert TFLite Shape"""
+        try:
+            from tflite.Operator import Operator
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                'TFlite quantized SHAPE operator is not supported yet.')
+

Review comment:
       Is the shape output affected by quantized inputs? Do we need this check?

##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -188,7 +231,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
                 continue
 
             tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device,
-                                       num_output=len(out_names), out_names=out_names)
+                                       num_output=len(out_names), out_names=out_names,mode=mode)

Review comment:
       `out_names,mode` -> add a space after the comma.

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -579,6 +582,63 @@ def convert_tanh(self, op):
 
         return out
 
+    def convert_range(self, op):
+        """Convert TFLite Range"""
+        try:
+            from tflite.Operator import Operator
+            from tflite.TensorType import TensorType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                'TFlite quantized RANGE operator is not supported yet.')
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 3, "input tensors length should be 3"
+
+        start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]
+        expressions = []
+
+        for t in [start, limit, delta]:
+            if self.has_expr(t.tensor_idx):
+                expressions.append(self.get_expr(t.tensor_idx))
+            else:
+                tensor_type = self.get_tensor_type_str(t.tensor.Type())
+                tensor_value = self.get_tensor_value(t)
+                expressions.append(self.exp_tab.new_const(tensor_value, dtype=tensor_type))
+

Review comment:
       Use the `get_tensor_expr` function from recent PRs.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



Mime
View raw message