tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] kevinthesun commented on a change in pull request #5699: [Frontend][TensorFlow] Improve Control Flow and TensorArray
Date Thu, 11 Jun 2020 18:26:35 GMT

kevinthesun commented on a change in pull request #5699:
URL: https://github.com/apache/incubator-tvm/pull/5699#discussion_r438983100



##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -2395,29 +2410,40 @@ def _get_abs_layer_name(node):
 # 1.x.
 _control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']
 
-# A map to record tensor array with fixed rank shape
-_static_tensor_array_map = {}
-
-class RewriteSubgraph(ExprMutator):
-    """
-    A helper class to rewrite expr in while loop function to variable
-
-    Parameters
-    ----------
-    rewrite_map : Dict[expr, expr]
-        A dictionay contains a set of expr to var mapping.
-    """
-    def __init__(self, rewrite_map):
-        ExprMutator.__init__(self)
-        self.rewrite_map = rewrite_map
-
-    def visit(self, expr):
-        if expr in self.rewrite_map:
-            return self.rewrite_map[expr]
-        return super().visit(expr)
+# A map to record tensor array write ops and input ta/tensor indices
+# Value is (index of tensor array, index of written node)
+_tensor_array_write_ops = {
+    "TensorArrayWrite"   : (3, 2),
+    "TensorArrayScatter" : (0, 2),
+    "TensorArraySplit"   : (0, 1),
+}
 
-def rewrite_subgraph(expr, rewrites):
-    return RewriteSubgraph(rewrites).visit(expr)
+def is_tensor_array_constuctor(tf_node):
+    """Check whether is tensor array constructor node."""
+    is_ta = False
+    ta_start = "TensorArrayV"
+    if tf_node.op.startswith(ta_start):
+        try:
+            int(tf_node.op[len(ta_start)])

Review comment:
       I will use `str.isnumeric()` here instead of the `try`/`int()` conversion.

##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -2395,29 +2410,40 @@ def _get_abs_layer_name(node):
 # 1.x.
 _control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']
 
-# A map to record tensor array with fixed rank shape
-_static_tensor_array_map = {}
-
-class RewriteSubgraph(ExprMutator):
-    """
-    A helper class to rewrite expr in while loop function to variable
-
-    Parameters
-    ----------
-    rewrite_map : Dict[expr, expr]
-        A dictionay contains a set of expr to var mapping.
-    """
-    def __init__(self, rewrite_map):
-        ExprMutator.__init__(self)
-        self.rewrite_map = rewrite_map
-
-    def visit(self, expr):
-        if expr in self.rewrite_map:
-            return self.rewrite_map[expr]
-        return super().visit(expr)
+# A map to record tensor array write ops and input ta/tensor indices
+# Value is (index of tensor array, index of written node)
+_tensor_array_write_ops = {
+    "TensorArrayWrite"   : (3, 2),
+    "TensorArrayScatter" : (0, 2),
+    "TensorArraySplit"   : (0, 1),
+}
 
-def rewrite_subgraph(expr, rewrites):
-    return RewriteSubgraph(rewrites).visit(expr)
+def is_tensor_array_constuctor(tf_node):
+    """Check whether is tensor array constructor node."""
+    is_ta = False
+    ta_start = "TensorArrayV"
+    if tf_node.op.startswith(ta_start):
+        try:
+            int(tf_node.op[len(ta_start)])

Review comment:
       I will use `str.isnumeric()` here instead of the `try`/`int()` conversion.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



Mime
View raw message