tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mas...@apache.org
Subject [incubator-tvm] branch master updated: support aten::type_as in the pytorch frontend (#5787)
Date Sat, 13 Jun 2020 04:52:57 GMT
This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 59f5cbe  support aten::type_as in the pytorch frontend (#5787)
59f5cbe is described below

commit 59f5cbe921cf329febcd9d6eff2df94d80f1c523
Author: Rand Xie <randxiexyy29@gmail.com>
AuthorDate: Fri Jun 12 21:52:45 2020 -0700

    support aten::type_as in the pytorch frontend (#5787)
    
    * support aten::type_as in the pytorch frontend
    
    * use _convert_data_type to convert torch type to tvm type and add more types in the type_as test
---
 python/tvm/relay/frontend/pytorch.py          |  9 +++++++
 tests/python/frontend/pytorch/test_forward.py | 37 +++++++++++++++++++++++++++
 2 files changed, 46 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index a9f4a7b..d2451cd 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1645,6 +1645,14 @@ def _list_len(prelude):
     return _impl
 
 
+def _type_as():
+    def _impl(inputs, input_types):
+        assert len(inputs) == 2
+        assert len(input_types) == 2
+        return _op.cast(inputs[0], _convert_data_type(input_types[1]))
+    return _impl
+
+
 def _add(prelude):
     # add_ is overloaded for tensor add and list concat
     def _impl(inputs, input_types):
@@ -1953,6 +1961,7 @@ def _get_convert_map(prelude):
         "aten::stack"                           : _tensor_array_stack(prelude),
         "aten::__getitem__"                     : _list_getitem(prelude),
         "aten::len"                             : _list_len(prelude),
+        "aten::type_as"                         : _type_as(),
     }
     return convert_map
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 86fb409..f8fb57f 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -27,6 +27,7 @@ import torchvision
 
 from tvm import relay
 from tvm.contrib import graph_runtime
+from tvm.contrib.nvcc import have_fp16
 from tvm.relay.testing.config import ctx_list
 
 
@@ -837,6 +838,41 @@ def test_forward_size():
     input_data = torch.rand(input_shape).float()
     verify_model(Size1().float().eval(), input_data=input_data)
 
+
+def test_type_as():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3]
+    
+    def _create_module(dtype):
+        class TypeAs(Module):
+            def forward(self, *args):
+                expected_type_tensor = torch.zeros(1, 3, dtype=dtype)
+                return args[0].type_as(expected_type_tensor)
+        
+        return TypeAs()
+
+    input_data = torch.randn(input_shape).float()
+    verify_model(_create_module(torch.float64), input_data=input_data)
+    verify_model(_create_module(torch.float32), input_data=input_data)
+    verify_model(_create_module(torch.int64), input_data=input_data)
+    verify_model(_create_module(torch.int32), input_data=input_data)
+    verify_model(_create_module(torch.int16), input_data=input_data)
+    verify_model(_create_module(torch.int8), input_data=input_data)
+
+    if torch.cuda.is_available():
+        check_fp16 = False
+        try:
+            # Only check half precision on supported hardwares.
+            if have_fp16(tvm.gpu(0).compute_version):
+                check_fp16 = True
+        except Exception as e:
+            # If GPU is not enabled in TVM, skip the fp16 test.
+            pass
+        
+        if check_fp16:
+            verify_model(_create_module(torch.float16), input_data=input_data)
+
+
 def test_forward_view():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -2575,6 +2611,7 @@ if __name__ == "__main__":
     test_upsample()
     test_forward_upsample3d()
     test_to()
+    test_type_as()
     test_forward_functional_pad()
     test_forward_zero_pad2d()
     test_forward_constant_pad1d()


Mime
View raw message