tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From z...@apache.org
Subject [incubator-tvm] branch master updated: Fix tf parser (#5794)
Date Sat, 13 Jun 2020 03:32:56 GMT
This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 7a41971  Fix tf parser (#5794)
7a41971 is described below

commit 7a419718c121164fc260864014e1d0d81f556949
Author: Yao Wang <kevinthesunwy@gmail.com>
AuthorDate: Fri Jun 12 20:32:46 2020 -0700

    Fix tf parser (#5794)
---
 python/tvm/relay/frontend/tensorflow.py        | 12 ++++--------
 python/tvm/relay/frontend/tensorflow_parser.py | 10 ++++++++--
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 5778b25..af09877 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1322,14 +1322,10 @@ def _shape():
 
 def _fill():
     def _impl(inputs, attr, params, mod):
-        output_shape = attr['_output_shapes'][0]
-        # Output shape must be defined to avoid errors. If any axis is not, we must
-        # try to compute its shape.
-        if output_shape is None or -1 in output_shape:
-            try:
-                output_shape = _expr.Constant(_infer_value(inputs[0], params, mod))
-            except Exception:
-                output_shape = inputs[0]
+        try:
+            output_shape = _infer_value(inputs[0], params, mod).asnumpy().tolist()
+        except Exception:
+            output_shape = inputs[0]
 
         return _op.full(inputs[1], output_shape, attr['T'].name)
     return _impl
diff --git a/python/tvm/relay/frontend/tensorflow_parser.py b/python/tvm/relay/frontend/tensorflow_parser.py
index fdbb876..771aed0 100644
--- a/python/tvm/relay/frontend/tensorflow_parser.py
+++ b/python/tvm/relay/frontend/tensorflow_parser.py
@@ -30,6 +30,10 @@ class TFParser(object):
     model_dir : tensorflow frozen pb file or a directory that contains saved
     model or checkpoints.
 
+    outputs : List of output tensor names (Optional)
+        Optional output node names. These nodes are protected from removal
+        when training nodes are stripped from a saved model.
+
     Examples
     --------
     .. code-block:: python
@@ -38,11 +42,12 @@ class TFParser(object):
         graphdef = parser.parse()
     """
 
-    def __init__(self, model_dir):
+    def __init__(self, model_dir, outputs=None):
         from tensorflow.core.framework import graph_pb2
         self._tmp_dir = util.tempdir()
         self._model_dir = model_dir
         self._graph = graph_pb2.GraphDef()
+        self._outputs = outputs or []
 
     def _set_graph(self, graph):
         """Set Graph"""
@@ -128,7 +133,8 @@ class TFParser(object):
             output_graph_def = graph_pb2.GraphDef()
             with open(output_graph_filename, "rb") as f:
                 output_graph_def.ParseFromString(f.read())
-            output_graph_def = graph_util.remove_training_nodes(output_graph_def)
+            output_graph_def = graph_util.remove_training_nodes(output_graph_def,
+                                                                protected_nodes=self._outputs)
             return output_graph_def
 
     def _load_ckpt(self):


Mime
View raw message