tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: Tf2 test fixups (#5391)
Date Tue, 21 Apr 2020 14:49:48 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 72f2aea  Tf2 test fixups (#5391)
72f2aea is described below

commit 72f2aea2dd219bf55c15b3cf4cfc21491f1f60dd
Author: Ramana Radhakrishnan <ramana.radhakrishnan@arm.com>
AuthorDate: Tue Apr 21 15:49:41 2020 +0100

    Tf2 test fixups (#5391)
    
    * Fix oversight in importing tf.compat.v1 as tf.
    
    * Actually disable test for lstm in TF2.1
    
    Since the testing framework actually uses pytest, the version
    check needs to be moved.
---
 tests/python/frontend/tensorflow/test_bn_dynamic.py | 5 ++++-
 tests/python/frontend/tensorflow/test_forward.py    | 8 ++++----
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/tests/python/frontend/tensorflow/test_bn_dynamic.py b/tests/python/frontend/tensorflow/test_bn_dynamic.py
index 4be838e..a2d6903 100644
--- a/tests/python/frontend/tensorflow/test_bn_dynamic.py
+++ b/tests/python/frontend/tensorflow/test_bn_dynamic.py
@@ -22,7 +22,10 @@ in TensorFlow frontend when mean and variance are not given.
 """
 import tvm
 import numpy as np
-import tensorflow as tf
+try:
+    import tensorflow.compat.v1 as tf
+except ImportError:
+    import tensorflow as tf
 from tvm import relay
 from tensorflow.python.framework import graph_util
 
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index bc884bb..93501f1 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1901,7 +1901,9 @@ def _test_lstm_cell(batch_size, num_hidden, num_layers, forget_bias, dtype):
 
 def test_forward_lstm():
     '''test LSTM block cell'''
-    _test_lstm_cell(1, 2, 1, 0.5, 'float32')
+    if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'):
+        #in 2.0, tf.contrib.rnn.LSTMBlockCell is removed
+        _test_lstm_cell(1, 2, 1, 0.5, 'float32')
 
 
 #######################################################################
@@ -3308,9 +3310,7 @@ if __name__ == '__main__':
     test_forward_ptb()
 
     # RNN
-    if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'):
-        #in 2.0, tf.contrib.rnn.LSTMBlockCell is removed
-        test_forward_lstm()
+    test_forward_lstm()
 
     # Elementwise
     test_forward_ceil()


Mime
View raw message