tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mas...@apache.org
Subject [incubator-tvm] branch master updated: Don't multiply by constant 1 uselessly in dense (#5911)
Date Wed, 24 Jun 2020 11:49:56 GMT
This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 11815b8  Don't multiply by constant 1 uselessly in dense (#5911)
11815b8 is described below

commit 11815b8d8fd9255e2d5ea1fc9ada98222228d462
Author: Thomas Viehmann <tv.code@beamnet.de>
AuthorDate: Wed Jun 24 13:49:43 2020 +0200

    Don't multiply by constant 1 uselessly in dense (#5911)
---
 python/tvm/relay/frontend/pytorch.py          |  4 ++--
 tests/python/frontend/pytorch/test_forward.py | 19 +++++++++++++++++++
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9237303..84b0907 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -995,11 +995,11 @@ def _dense():
         beta = inputs[3]
         alpha = inputs[4]
 
-        if not isinstance(alpha, _expr.Expr):
+        if not isinstance(alpha, _expr.Expr) and alpha != 1:
             alpha = _create_typed_const(alpha, data_type)
             data *= alpha
 
-        if not isinstance(beta, _expr.Expr):
+        if not isinstance(beta, _expr.Expr) and beta != 1:
             beta = _create_typed_const(beta, data_type)
             weight *= beta
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 12d1260..0694fa5 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -33,6 +33,18 @@ from tvm.relay.testing.config import ctx_list
 
 sys.setrecursionlimit(10000)
 
+def list_ops(expr):
+    """Return the distinct operator nodes the ExprVisitor reaches in ``expr``, in first-visit order."""
+    class OpLister(tvm.relay.ExprVisitor):
+        def visit_op(self, expr):
+            # An insertion-ordered dict serves as an ordered set, so each op is recorded only once.
+            self.node_set.setdefault(expr, None)
+            return super().visit_op(expr)
+        def list_nodes(self, expr):
+            self.node_set = {}
+            self.visit(expr)
+            return list(self.node_set)
+    return OpLister().list_nodes(expr)
 
 def assert_shapes_match(tru, est):
     if tru.shape != est.shape:
@@ -1047,6 +1059,13 @@ def test_forward_dense():
     verify_model(Dense1().float().eval(), input_data=input_data)
     verify_model(Dense2().float().eval(), input_data=input_data)
 
+    trace = torch.jit.trace(Dense1(), [input_data])
+    mod, params = relay.frontend.from_pytorch(
+        trace,
+        [('input', input_shape)],
+    )
+    assert not any([op.name == "multiply" for op in list_ops(mod['main'])])
+
 def test_forward_dropout():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]


Mime
View raw message