tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mas...@apache.org
Subject [incubator-tvm] branch master updated: [PYTORCH]where, addcdiv, addcmul op support (#5383)
Date Fri, 24 Apr 2020 10:49:35 GMT
This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new ba38222  [PYTORCH]where, addcdiv, addcmul op support (#5383)
ba38222 is described below

commit ba38222990feb7c2dbb18bd2e23ae7551d440fd3
Author: Samuel <siju.samuel@huawei.com>
AuthorDate: Fri Apr 24 16:19:26 2020 +0530

    [PYTORCH]where, addcdiv, addcmul op support (#5383)
    
    * [PYTORCH]Where, addcdiv, addcmul op support
    
    * Review comments fixed
---
 python/tvm/relay/frontend/pytorch.py          | 72 +++++++++++++++------------
 tests/python/frontend/pytorch/test_forward.py | 69 +++++++++++++++++++++++++
 2 files changed, 110 insertions(+), 31 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 0ade8af..a8eb9c4 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -279,15 +279,7 @@ def _select():
 def _take():
     def _impl(inputs, input_types):
         data = inputs[0]
-        import torch
-
-        if isinstance(inputs[1], _expr.Var):
-            indices = _op.cast(inputs[1], "int32")
-        elif isinstance(inputs[1], torch.Tensor):
-            indices = _wrap_const(inputs[1].numpy())
-        else:
-            msg = "Data type %s could not be parsed in take operator." % (type(inputs[1]))
-            raise AssertionError(msg)
+        indices = _op.cast(inputs[1], "int32")
 
         return _op.transform.take(data, indices=indices)
     return _impl
@@ -337,6 +329,40 @@ def _repeat_interleave():
         return _op.transform.repeat(data, repeats=repeats, axis=axis)
     return _impl
 
+
+def _addcdiv():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        c = _expr.const(inputs[3])
+        t1 = inputs[1]
+        t2 = inputs[2]
+
+        return data + (c * (t1 / t2))
+    return _impl
+
+
+def _addcmul():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        c = _expr.const(inputs[3])
+        t1 = inputs[1]
+        t2 = inputs[2]
+
+        return data + (c * (t1 * t2))
+    return _impl
+
+
+def _where():
+    def _impl(inputs, input_types):
+        cond = inputs[0]
+        x = inputs[1]
+        y = inputs[2]
+
+        return _op.where(cond, x, y)
+
+    return _impl
+
+
 def _ones():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1382,16 +1408,7 @@ def _bitwise_not():
 def _bitwise_xor():
     def _impl(inputs, input_types):
         lhs = inputs[0]
-
-        import torch
-        if isinstance(inputs[1], _expr.Var):
-            rhs = inputs[1]
-        elif isinstance(inputs[1], torch.Tensor):
-            rhs = _wrap_const(inputs[1].numpy())
-        else:
-            msg = "Data type %s could not be parsed in bitwise_xor operator." % (type(inputs[1]))
-            raise AssertionError(msg)
-
+        rhs = inputs[1]
         lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
         rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")
 
@@ -1410,17 +1427,7 @@ def _logical_not():
 def _logical_xor():
     def _impl(inputs, input_types):
         lhs = _op.cast(inputs[0], "bool")
-
-        import torch
-        if isinstance(inputs[1], _expr.Var):
-            rhs = inputs[1]
-        elif isinstance(inputs[1], torch.Tensor):
-            rhs = _wrap_const(inputs[1].numpy())
-        else:
-            msg = "Data type %s could not be parsed in logical_xor operator." % (type(inputs[1]))
-            raise AssertionError(msg)
-
-        rhs = _op.cast(rhs, "bool")
+        rhs = _op.cast(inputs[1], "bool")
 
         return _op.logical_xor(lhs, rhs)
     return _impl
@@ -1551,6 +1558,8 @@ def _get_convert_map(prelude):
         "aten::arange"                          : _arange(),
         "aten::div"                             : _elemwise("divide"),
         "aten::div_"                            : _elemwise("divide"),
+        "aten::addcdiv"                         : _addcdiv(),
+        "aten::addcmul"                         : _addcmul(),
         "aten::ones"                            : _ones(),
         "aten::ones_like"                       : _ones_like(),
         "aten::zeros"                           : _zeros(),
@@ -1570,6 +1579,7 @@ def _get_convert_map(prelude):
         "aten::split_with_sizes"                : _split_with_sizes(),
         "aten::select"                          : _select(),
         "aten::take"                            : _take(),
+        "aten::where"                           : _where(),
         "aten::topk"                            : _topk(),
         "aten::relu"                            : _relu(),
         "aten::relu_"                           : _relu(),
@@ -1832,7 +1842,7 @@ def _get_constant(node):
             tensor = node.t(attr_name)
             if len(tensor.shape) == 0:  # tensor(0.1)
                 return float(tensor)
-            return tensor
+            return _wrap_const(tensor.numpy())
         elif ty == "DeviceObjType":
             return node.s(attr_name)
         elif ty == "FunctionType":
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 0a0e6bb..91c2661 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1888,6 +1888,72 @@ def test_forward_unary():
     verify_model(Neg1().float().eval(), input_data=input_data)
 
 
+def test_forward_where():
+    torch.set_grad_enabled(False)
+
+    class Where1(Module):
+        def forward(self, *args):
+            y = torch.ones([3, 2])
+            if torch.cuda.is_available():
+                y = y.cuda()
+            return torch.where(args[0] > 0, args[0], y)
+
+    class Where2(Module):
+        def forward(self, *args):
+            return torch.where(args[0] > 0, args[0], args[1])
+
+    x = torch.rand([3, 2]).float()
+    verify_model(Where1().float().eval(), input_data=[x])
+    y = torch.rand([3, 2])
+    verify_model(Where2().float().eval(), input_data=[x, y])
+
+
+def test_forward_addcdiv():
+    torch.set_grad_enabled(False)
+
+    class Addcdiv1(Module):
+        def forward(self, *args):
+            t1 = torch.ones([3, 1])
+            t2 = torch.ones([1, 3])
+            if torch.cuda.is_available():
+                t1 = t1.cuda()
+                t2 = t2.cuda()
+            return torch.addcdiv(args[0], 0.1, t1, t2)
+
+    class Addcdiv2(Module):
+        def forward(self, *args):
+            return torch.addcdiv(args[0], 0.5, args[1], args[2])
+
+    input_data = torch.rand([1, 3]).float()
+    verify_model(Addcdiv1().float().eval(), input_data=input_data)
+    t1 = torch.rand([3, 1]).float()
+    t2 = torch.rand([1, 3]).float()
+    verify_model(Addcdiv2().float().eval(), input_data=[input_data, t1, t2])
+
+
+def test_forward_addcmul():
+    torch.set_grad_enabled(False)
+
+    class Addcmul1(Module):
+        def forward(self, *args):
+            t1 = torch.ones([3, 1])
+            t2 = torch.ones([1, 3])
+            if torch.cuda.is_available():
+                t1 = t1.cuda()
+                t2 = t2.cuda()
+            return torch.addcmul(args[0], 0.1, t1, t2)
+
+    class Addcmul2(Module):
+        def forward(self, *args):
+            return torch.addcmul(args[0], 0.5, args[1], args[2])
+
+    input_data = torch.rand([1, 3]).float()
+    verify_model(Addcmul1().float().eval(), input_data=input_data)
+    t1 = torch.rand([3, 1]).float()
+    t2 = torch.rand([1, 3]).float()
+    verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])
+
+
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
@@ -1933,6 +1999,9 @@ if __name__ == "__main__":
     test_forward_select()
     test_forward_take()
     test_forward_topk()
+    test_forward_where()
+    test_forward_addcdiv()
+    test_forward_addcmul()
     test_forward_clone()
     test_forward_softplus()
     test_forward_softsign()


Mime
View raw message