tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: [FRONTEND][MXNET] Use leaky by default for LeakyReLU (#5192)
Date Wed, 01 Apr 2020 22:49:46 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 2f41a39  [FRONTEND][MXNET] Use leaky by default for LeakyReLU (#5192)
2f41a39 is described below

commit 2f41a39688bf5fe2f18d8481f9ae012fb6a05614
Author: MORITA Kazutaka <morita.kazutaka@gmail.com>
AuthorDate: Thu Apr 2 07:49:37 2020 +0900

    [FRONTEND][MXNET] Use leaky by default for LeakyReLU (#5192)
---
 python/tvm/relay/frontend/mxnet.py          |  2 +-
 tests/python/frontend/mxnet/test_forward.py | 11 ++++++++++-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index b918f9b..5c8e726 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -510,7 +510,7 @@ def _mx_pad(inputs, attrs):
                       pad_mode=pad_mode)
 
 def _mx_leaky_relu(inputs, attrs):
-    act_type = attrs.get_str("act_type")
+    act_type = attrs.get_str("act_type", "leaky")
     if act_type == "leaky":
         return _op.nn.leaky_relu(inputs[0], alpha=attrs.get_float("slope", 0.25))
     if act_type == "prelu":
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 102905a..f015447 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -107,6 +107,14 @@ def test_forward_resnet():
         mx_sym = model_zoo.mx_resnet(18)
         verify_mxnet_frontend_impl(mx_sym)
 
+def test_forward_leaky_relu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
+    mx_sym = mx.sym.LeakyReLU(data)
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+    mx_sym = mx.sym.LeakyReLU(data, act_type='leaky')
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
 def test_forward_elu():
     data = mx.sym.var('data')
     data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
@@ -979,6 +987,7 @@ if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
     test_forward_resnet()
+    test_forward_leaky_relu()
     test_forward_elu()
     test_forward_rrelu()
     test_forward_prelu()
@@ -1030,4 +1039,4 @@ if __name__ == '__main__':
     test_forward_deconvolution()
     test_forward_cond()
     test_forward_make_loss()
-    test_forward_unravel_index()
\ No newline at end of file
+    test_forward_unravel_index()


Mime
View raw message