tvm-commits mailing list archives

From mas...@apache.org
Subject [incubator-tvm] branch master updated: [PYTORCH]Reduce_ops support added (#5308)
Date Mon, 13 Apr 2020 09:50:49 GMT
This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 6805d54  [PYTORCH]Reduce_ops support added (#5308)
6805d54 is described below

commit 6805d54370ea657a304c58d610e5371c4add4bdf
Author: Samuel <siju.samuel@huawei.com>
AuthorDate: Mon Apr 13 15:20:10 2020 +0530

    [PYTORCH]Reduce_ops support added (#5308)
    
    * [PYTORCH]Reduce_ops support added
    
    * Review comments updated
    
    * typo bug in qnn test
---
 python/tvm/relay/frontend/pytorch.py          |  49 +++++++-
 tests/python/frontend/pytorch/qnn_test.py     |   2 +-
 tests/python/frontend/pytorch/test_forward.py | 168 ++++++++++++++++++++++++++
 3 files changed, 217 insertions(+), 2 deletions(-)
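
For context on the _reduce handler in the diff below: it keys off the TorchScript calling convention for reduce ops. A minimal sketch (an illustration, not part of this patch; assumes a PyTorch 1.4-era tracer) that prints the traced graphs, where a dim-reduce such as x.sum(dim=1) appears as aten::sum(data, dim, keepdim, dtype) while a bare x.sum() carries only (data, dtype), which is why the handler tests len(inputs) > 2:

    import torch

    def full_reduce(x):
        return x.sum()        # traces to aten::sum(data, dtype): 2 inputs

    def dim_reduce(x):
        return x.sum(dim=1)   # traces to aten::sum(data, dim, keepdim, dtype): 4 inputs

    print(torch.jit.trace(full_reduce, torch.rand(2, 3)).graph)
    print(torch.jit.trace(dim_reduce, torch.rand(2, 3)).graph)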

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 506f6ba..18868cf 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -934,7 +934,50 @@ def _dropout():
 def _reduce(name):
     def _impl(inputs, input_types):
         data = inputs[0]
-        return get_relay_op(name)(data)
+        axis = None
+        keepdims = False
+
+        if len(inputs) > 2:  # by default torch passes only data, so axis=None, keepdims=False
+            if isinstance(inputs[1], int):
+                axis = int(inputs[1])
+            else:
+                axis = list(_infer_shape(inputs[1]))
+            keepdims = bool(inputs[2])
+
+        return get_relay_op(name)(data, axis=axis, keepdims=keepdims)
+
+    return _impl
+
+def _std():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        axis = list(_infer_shape(inputs[1]))
+        keepdims = bool(inputs[3])
+        unbiased = bool(inputs[2])
+
+        if unbiased:
+            msg = "Currently only supports standard-deviation calculated via the biased "\
+                  "estimator. Pytorch's Bessel's correction is not supported."
+            raise NotImplementedError(msg)
+
+        return _op.reduce.std(data, axis=axis, keepdims=keepdims)
+
+    return _impl
+
+def _variance():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        axis = list(_infer_shape(inputs[1]))
+        keepdims = bool(inputs[3])
+        unbiased = bool(inputs[2])
+
+        if unbiased:
+            msg = "Currently only supports standard-deviation calculated via the biased "\
+                  "estimator. Pytorch's Bessel's correction is not supported."
+            raise NotImplementedError(msg)
+
+        return _op.reduce.variance(data, axis=axis, keepdims=keepdims)
+
     return _impl
 
 def _mean():
@@ -1381,6 +1424,10 @@ def _get_convert_map(prelude):
         "aten::permute"                         : _transpose(prelude),
         "aten::sum"                             : _reduce("sum"),
         "aten::prod"                            : _reduce("prod"),
+        "aten::argmin"                          : _reduce("argmin"),
+        "aten::argmax"                          : _reduce("argmax"),
+        "aten::std"                             : _std(),
+        "aten::var"                             : _variance(),
         "aten::sqrt"                            : _sqrt(),
         'aten::floor'                           : _floor(),
         "aten::detach"                          : _identity(),
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index 82e3393..bf5fa98 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -396,7 +396,7 @@ def test_quantized_imagenet():
         mean_abs_diff = np.mean(np.abs(tvm_result - pt_result))
         num_identical = np.sum(tvm_result == pt_result)
         pt_top3_labels = np.argsort(pt_result)[::-1][:3]
-        tvm_top3_labels = np.argsort(pt_result)[::-1][:3]
+        tvm_top3_labels = np.argsort(tvm_result)[::-1][:3]
 
         print("\nModel name: %s" % model_name)
         print("PyTorch top3 label:", pt_top3_labels)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 8e99285..91e14c6 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1279,6 +1279,168 @@ def test_simple_rnn():
     verify_script_model(RNNLoop().eval(), [(10, 10, 4)])
 
 
+def test_forward_reduce_sum():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class ReduceSum1(Module):
+        def forward(self, *args):
+            return args[0].sum(1)
+
+    class ReduceSum2(Module):
+        def forward(self, *args):
+            return args[0].sum(dim=1, keepdim=False)
+
+    class ReduceSum3(Module):
+        def forward(self, *args):
+            return args[0].sum(dim=2, keepdim=True)
+
+    class ReduceSum4(Module):
+        def forward(self, *args):
+            return args[0].sum(dim=(2,3), keepdim=True)
+
+    class ReduceSum5(Module):
+        def forward(self, *args):
+            return args[0].sum(dim=(2,3), keepdim=False)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(ReduceSum1().float().eval(), input_data=input_data)
+    verify_model(ReduceSum2().float().eval(), input_data=input_data)
+    verify_model(ReduceSum3().float().eval(), input_data=input_data)
+    verify_model(ReduceSum4().float().eval(), input_data=input_data)
+    verify_model(ReduceSum5().float().eval(), input_data=input_data)
+
+
+def test_forward_reduce_prod():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class ReduceProd1(Module):
+        def forward(self, *args):
+            return args[0].prod(1)
+
+    class ReduceProd2(Module):
+        def forward(self, *args):
+            return args[0].prod(dim=1, keepdim=False)
+
+    class ReduceProd3(Module):
+        def forward(self, *args):
+            return args[0].prod(dim=2, keepdim=True)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(ReduceProd1().float().eval(), input_data=input_data)
+    verify_model(ReduceProd2().float().eval(), input_data=input_data)
+    verify_model(ReduceProd3().float().eval(), input_data=input_data)
+
+
+def test_forward_argmin():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class ArgMin1(Module):
+        def forward(self, *args):
+            return args[0].argmin(1)
+
+    class ArgMin2(Module):
+        def forward(self, *args):
+            return args[0].argmin(dim=1, keepdim=False)
+
+    class ArgMin3(Module):
+        def forward(self, *args):
+            return args[0].argmin(dim=2, keepdim=True)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(ArgMin1().float().eval(), input_data=input_data)
+    verify_model(ArgMin2().float().eval(), input_data=input_data)
+    verify_model(ArgMin3().float().eval(), input_data=input_data)
+
+
+def test_forward_argmax():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class ArgMax1(Module):
+        def forward(self, *args):
+            return args[0].argmax(1)
+
+    class ArgMax2(Module):
+        def forward(self, *args):
+            return args[0].argmax(dim=1, keepdim=False)
+
+    class ArgMax3(Module):
+        def forward(self, *args):
+            return args[0].argmax(dim=2, keepdim=True)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(ArgMax1().float().eval(), input_data=input_data)
+    verify_model(ArgMax2().float().eval(), input_data=input_data)
+    verify_model(ArgMax3().float().eval(), input_data=input_data)
+
+
+def test_forward_std():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class Std1(Module):
+        def forward(self, *args):
+            return args[0].std(1, unbiased=False)
+
+    class Std2(Module):
+        def forward(self, *args):
+            return args[0].std(dim=1, keepdim=False, unbiased=False)
+
+    class Std3(Module):
+        def forward(self, *args):
+            return args[0].std(dim=2, keepdim=True, unbiased=False)
+
+    class Std4(Module):
+        def forward(self, *args):
+            return args[0].std(dim=(2,3), keepdim=True, unbiased=False)
+
+    class Std5(Module):
+        def forward(self, *args):
+            return args[0].std(dim=(2,3), keepdim=False, unbiased=False)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Std1().float().eval(), input_data=input_data)
+    verify_model(Std2().float().eval(), input_data=input_data)
+    verify_model(Std3().float().eval(), input_data=input_data)
+    verify_model(Std4().float().eval(), input_data=input_data)
+    verify_model(Std5().float().eval(), input_data=input_data)
+
+
+def test_forward_variance():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class Variance1(Module):
+        def forward(self, *args):
+            return args[0].var(1, unbiased=False)
+
+    class Variance2(Module):
+        def forward(self, *args):
+            return args[0].var(dim=1, keepdim=False, unbiased=False)
+
+    class Variance3(Module):
+        def forward(self, *args):
+            return args[0].var(dim=2, keepdim=True, unbiased=False)
+
+    class Variance4(Module):
+        def forward(self, *args):
+            return args[0].var(dim=(2,3), keepdim=True, unbiased=False)
+
+    class Variance5(Module):
+        def forward(self, *args):
+            return args[0].var(dim=(2,3), keepdim=False, unbiased=False)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Variance1().float().eval(), input_data=input_data)
+    verify_model(Variance2().float().eval(), input_data=input_data)
+    verify_model(Variance3().float().eval(), input_data=input_data)
+    verify_model(Variance4().float().eval(), input_data=input_data)
+    verify_model(Variance5().float().eval(), input_data=input_data)
+
+
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
@@ -1291,6 +1453,12 @@ if __name__ == "__main__":
     test_forward_squeeze()
     test_forward_unsqueeze()
     test_forward_concatenate()
+    test_forward_reduce_sum()
+    test_forward_reduce_prod()
+    test_forward_argmin()
+    test_forward_argmax()
+    test_forward_std()
+    test_forward_variance()
     test_forward_relu()
     test_forward_prelu()
     test_forward_leakyrelu()
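
As the _std/_variance handlers make explicit, unbiased=True is rejected at conversion time. A minimal sketch of the expected failure, under the same from_pytorch assumptions as above:

    import torch
    from tvm import relay

    class UnbiasedVar(torch.nn.Module):
        def forward(self, x):
            return x.var(dim=1, unbiased=True)   # requests Bessel's correction

    shape = [1, 3, 10, 10]
    traced = torch.jit.trace(UnbiasedVar().eval(), torch.rand(shape))
    try:
        relay.frontend.from_pytorch(traced, [("input0", shape)])
    except NotImplementedError as err:
        print(err)   # only the biased estimator is supported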

