tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mas...@apache.org
Subject [incubator-tvm] branch master updated: [PYTORCH]Unary Ops (#5378)
Date Mon, 20 Apr 2020 01:49:01 GMT
This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 22db299  [PYTORCH]Unary Ops (#5378)
22db299 is described below

commit 22db299b33f05570db2a5a406bdb37b57198a822
Author: Samuel <siju.samuel@huawei.com>
AuthorDate: Mon Apr 20 07:18:51 2020 +0530

    [PYTORCH]Unary Ops (#5378)
---
 python/tvm/relay/frontend/pytorch.py          |  96 +++++-------------
 tests/python/frontend/pytorch/test_forward.py | 141 ++++++++++++++++----------
 2 files changed, 114 insertions(+), 123 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9da3ecf..0ade8af 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -132,12 +132,16 @@ def _elemwise(name):
         return get_relay_op(name)(data0, data1)
     return _impl
 
-def _abs():
+
+def _unary(name):
     def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.abs(data)
+        input_type = input_types[0]
+        data = _convert_elemwise_input(inputs[0], input_type)
+
+        return get_relay_op(name)(data)
     return _impl
 
+
 def _arange():
     def _impl(inputs, input_types):
         if len(inputs) == 5:
@@ -1254,26 +1258,6 @@ def _pad():
         return _op.nn.pad(data, pad_width, pad_value)
     return _impl
 
-def _sqrt():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.tensor.sqrt(data)
-    return _impl
-
-
-def _rsqrt():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.tensor.rsqrt(data)
-    return _impl
-
-
-def _ceil():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.ceil(data)
-    return _impl
-
 
 def _clamp():
     def _impl(inputs, input_types):
@@ -1284,20 +1268,6 @@ def _clamp():
     return _impl
 
 
-def _floor():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.floor(data)
-    return _impl
-
-
-def _round():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.round(data)
-    return _impl
-
-
 def _to():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1375,17 +1345,6 @@ def _expand_as():
         return inputs[0]
     return _impl
 
-def _neg():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.tensor.negative(data)
-    return _impl
-
-def _tanh():
-    def _impl(inputs, input_types):
-        data = inputs[0]
-        return _op.tensor.tanh(data)
-    return _impl
 
 def _Bool():
     def _impl(inputs, input_types):
@@ -1467,18 +1426,6 @@ def _logical_xor():
     return _impl
 
 
-def _isfinite():
-    def _impl(inputs, input_types):
-        return _op.isfinite(inputs[0])
-    return _impl
-
-
-def _isnan():
-    def _impl(inputs, input_types):
-        return _op.isnan(inputs[0])
-    return _impl
-
-
 def _list_getitem(prelude):
     def _impl(inputs, input_types):
         return prelude.nth(inputs[0], _wrap_const(inputs[1]))
@@ -1601,7 +1548,6 @@ def _get_convert_map(prelude):
         "aten::mul"                             : _elemwise("multiply"),
         "aten::mul_"                            : _elemwise("multiply"),
         "aten::pow"                             : _elemwise("power"),
-        "aten::abs"                             : _abs(),
         "aten::arange"                          : _arange(),
         "aten::div"                             : _elemwise("divide"),
         "aten::div_"                            : _elemwise("divide"),
@@ -1683,12 +1629,26 @@ def _get_convert_map(prelude):
         "aten::argmax"                          : _reduce("argmax"),
         "aten::std"                             : _std(),
         "aten::var"                             : _variance(),
-        "aten::sqrt"                            : _sqrt(),
-        "aten::rsqrt"                           : _rsqrt(),
-        "aten::ceil"                            : _ceil(),
+        "aten::abs"                             : _unary("abs"),
+        "aten::neg"                             : _unary("negative"),
+        "aten::cos"                             : _unary("cos"),
+        "aten::sin"                             : _unary("sin"),
+        "aten::tan"                             : _unary("tan"),
+        "aten::tanh"                            : _unary("tanh"),
+        "aten::atan"                            : _unary("atan"),
+        "aten::log"                             : _unary("log"),
+        "aten::exp"                             : _unary("exp"),
+        "aten::erf"                             : _unary("erf"),
+        "aten::trunc"                           : _unary("trunc"),
+        "aten::sign"                            : _unary("sign"),
+        "aten::sqrt"                            : _unary("sqrt"),
+        "aten::rsqrt"                           : _unary("rsqrt"),
+        "aten::ceil"                            : _unary("ceil"),
+        "aten::floor"                           : _unary("floor"),
+        "aten::round"                           : _unary("round"),
+        "aten::isfinite"                        : _unary("isfinite"),
+        "aten::isnan"                           : _unary("isnan"),
         "aten::clamp"                           : _clamp(),
-        "aten::floor"                           : _floor(),
-        "aten::round"                           : _round(),
         "aten::detach"                          : _identity(),
         "aten::upsample_bilinear2d"             : _upsample("bilinear"),
         "aten::upsample_nearest2d"              : _upsample("nearest_neighbor"),
@@ -1703,12 +1663,8 @@ def _get_convert_map(prelude):
         "aten::logical_xor"                     : _logical_xor(),
         "aten::bitwise_not"                     : _bitwise_not(),
         "aten::bitwise_xor"                     : _bitwise_xor(),
-        "aten::isfinite"                        : _isfinite(),
-        "aten::isnan"                           : _isnan(),
         "aten::Bool"                            : _Bool(),
         "aten::Float"                           : _Float(),
-        "aten::neg"                             : _neg(),
-        "aten::tanh"                            : _tanh(),
         "aten::adaptive_avg_pool3d"             : _adaptive_avg_pool_3d(),
         "aten::adaptive_max_pool3d"             : _adaptive_max_pool_3d(),
         "aten::mm"                              : _matmul(),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index c692c5e..0a0e6bb 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1497,30 +1497,6 @@ def test_forward_isinf():
     verify_model(IsInf1().float().eval(), input_data=input_data)
 
 
-def test_forward_rsqrt():
-    torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
-
-    class Rsqrt1(Module):
-        def forward(self, *args):
-            return torch.rsqrt(args[0])
-
-    input_data = torch.rand(input_shape).float()
-    verify_model(Rsqrt1().float().eval(), input_data=input_data)
-
-
-def test_forward_ceil():
-    torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
-
-    class Ceil1(Module):
-        def forward(self, *args):
-            return torch.ceil(args[0])
-
-    input_data = torch.rand(input_shape).float()
-    verify_model(Ceil1().float().eval(), input_data=input_data)
-
-
 def test_forward_clamp():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -1543,30 +1519,6 @@ def test_forward_clamp():
     verify_model(Clamp3().float().eval(), input_data=input_data)
 
 
-def test_forward_floor():
-    torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
-
-    class Floor1(Module):
-        def forward(self, *args):
-            return torch.floor(args[0])
-
-    input_data = torch.rand(input_shape).float()
-    verify_model(Floor1().float().eval(), input_data=input_data)
-
-
-def test_forward_round():
-    torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
-
-    class Round1(Module):
-        def forward(self, *args):
-            return torch.round(args[0])
-
-    input_data = torch.rand(input_shape).float()
-    verify_model(Round1().float().eval(), input_data=input_data)
-
-
 def test_forward_ones():
     torch.set_grad_enabled(False)
 
@@ -1849,6 +1801,93 @@ def test_forward_logical_xor():
     verify_model(LogicalXor2().float().eval(), input_data=[lhs])
 
 
+def test_forward_unary():
+    torch.set_grad_enabled(False)
+
+    class Sqrt1(Module):
+        def forward(self, *args):
+            return torch.sqrt(args[0])
+
+    class RSqrt1(Module):
+        def forward(self, *args):
+            return torch.rsqrt(args[0])
+
+    class Ceil1(Module):
+        def forward(self, *args):
+            return torch.ceil(args[0])
+
+    class Floor1(Module):
+        def forward(self, *args):
+            return torch.floor(args[0])
+
+    class Round1(Module):
+        def forward(self, *args):
+            return torch.round(args[0])
+
+    class Cos1(Module):
+        def forward(self, *args):
+            return torch.cos(args[0])
+
+    class Sin1(Module):
+        def forward(self, *args):
+            return torch.sin(args[0])
+
+    class Tan1(Module):
+        def forward(self, *args):
+            return torch.tan(args[0])
+
+    class Tanh1(Module):
+        def forward(self, *args):
+            return torch.tanh(args[0])
+
+    class ATanh1(Module):
+        def forward(self, *args):
+            return torch.atan(args[0])
+
+    class Log1(Module):
+        def forward(self, *args):
+            return torch.log(args[0])
+
+    class Exp1(Module):
+        def forward(self, *args):
+            return torch.exp(args[0])
+
+    class Erf1(Module):
+        def forward(self, *args):
+            return torch.erf(args[0])
+
+    class Trunc1(Module):
+        def forward(self, *args):
+            return torch.trunc(args[0])
+
+    class Sign1(Module):
+        def forward(self, *args):
+            return torch.sign(args[0])
+
+    class Neg1(Module):
+        def forward(self, *args):
+            return torch.neg(args[0])
+
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+    verify_model(Sqrt1().float().eval(), input_data=input_data)
+    verify_model(RSqrt1().float().eval(), input_data=input_data)
+    verify_model(Ceil1().float().eval(), input_data=input_data)
+    verify_model(Floor1().float().eval(), input_data=input_data)
+    verify_model(Round1().float().eval(), input_data=input_data)
+    verify_model(Cos1().float().eval(), input_data=input_data)
+    verify_model(Sin1().float().eval(), input_data=input_data)
+    verify_model(Tan1().float().eval(), input_data=input_data)
+    verify_model(Tanh1().float().eval(), input_data=input_data)
+    verify_model(ATanh1().float().eval(), input_data=input_data)
+    verify_model(Log1().float().eval(), input_data=input_data)
+    verify_model(Exp1().float().eval(), input_data=input_data)
+    verify_model(Erf1().float().eval(), input_data=input_data)
+    verify_model(Trunc1().float().eval(), input_data=input_data)
+    verify_model(Sign1().float().eval(), input_data=input_data)
+    verify_model(Neg1().float().eval(), input_data=input_data)
+
+
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
@@ -1907,12 +1946,8 @@ if __name__ == "__main__":
     test_forward_mean()
     test_forward_expand()
     test_forward_pow()
-    test_forward_abs()
-    test_forward_rsqrt()
-    test_forward_ceil()
+    test_forward_unary()
     test_forward_clamp()
-    test_forward_floor()
-    test_forward_round()
     test_forward_logical_not()
     test_forward_bitwise_not()
     test_forward_bitwise_xor()


Mime
View raw message