tvm-commits mailing list archives

From tqc...@apache.org
Subject [incubator-tvm] branch v0.6 updated: [BACKPORT-0.6] Add ConstantNode to IsAtomic (#5831)
Date Wed, 17 Jun 2020 06:15:47 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch v0.6
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/v0.6 by this push:
     new ce56d5b  [BACKPORT-0.6] Add ConstantNode to IsAtomic (#5831)
ce56d5b is described below

commit ce56d5b8d335f78f755640d18dc41222d83b3c1b
Author: Zhi <5145158+zhiics@users.noreply.github.com>
AuthorDate: Tue Jun 16 23:15:36 2020 -0700

    [BACKPORT-0.6] Add ConstantNode to IsAtomic (#5831)
    
    * [Fix] Add ConstantNode to IsAtomic (#5457)
    
    * add constantnode to atomic
    
    * Add ToANormalForm to FoldConstant
    
    * fix test
---
 src/relay/pass/fold_constant.cc               |  1 +
 tests/python/relay/test_pass_fold_constant.py | 19 +++++++++++++++++++
 2 files changed, 20 insertions(+)

diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc
index b034c56..5d3ec03 100644
--- a/src/relay/pass/fold_constant.cc
+++ b/src/relay/pass/fold_constant.cc
@@ -184,6 +184,7 @@ class ConstantFolder : public ExprMutator {
   // Constant evaluate a expression.
   Expr ConstEvaluate(Expr expr) {
     std::vector<transform::Pass> passes = {transform::FuseOps(0),
+                                           transform::ToANormalForm(),
                                            transform::InferType()};
     Function func;
     if (expr.as<FunctionNode>()) {
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index 4752597..5167ff2 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -29,6 +29,25 @@ def run_opt_pass(expr, opt_pass):
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
+def test_concatenate_const():
+    def before():
+        data = tvm.nd.array(np.array([1.0, 2.0, 3.0]))
+        const = relay.const(data)
+        concat = relay.op.concatenate([const, const], axis=0)
+        func = relay.Function([], concat)
+        return func
+
+    def expected():
+        data = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))
+        const = relay.const(data)
+        func = relay.Function([], const)
+        return func
+
+    zz = run_opt_pass(before(), transform.FoldConstant())
+    zexpected = run_opt_pass(expected(), transform.InferType())
+    assert relay.analysis.graph_equal(zz, zexpected)
+
+
 def test_fold_const():
     c_data = np.array([1, 2, 3]).astype("float32")
     t = relay.TensorType([1, 2, 3], "float32")
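
In the diff above, `ConstEvaluate` now runs `transform::ToANormalForm()` between `FuseOps(0)` and `InferType()`. The commit title refers to adding `ConstantNode` to `IsAtomic`. The sketch below is a toy Python model of those two ideas, not TVM code. It converts an expression to A-normal form, where every non-atomic subexpression is let-bound and constants count as atomic alongside variables. It then folds let-bound calls whose operands are all constants. All names (`Const`, `Var`, `Call`, `Let`, `OPS`, `is_atomic`, `to_anf`, `fold_constants`) are hypothetical. Only a 1-D `concatenate` is modeled. This is not Relay's `ToANormalForm` or `FoldConstant` implementation.

```python
import itertools
from dataclasses import dataclass

# Toy IR (hypothetical, not Relay's node classes). Tensors are modeled as
# 1-D tuples of floats so that nodes stay hashable and comparable.


@dataclass(frozen=True)
class Const:
    value: tuple


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


@dataclass(frozen=True)
class Let:
    var: Var
    value: Call
    body: object


# Foldable operators: 1-D concatenation along axis 0 is tuple addition.
OPS = {"concatenate": lambda a, b: a + b}


def is_atomic(expr):
    """Atomic expressions may appear directly as call arguments.

    Toy analogue of the commit title's change: constants are treated as
    atomic, like variables, so they are never let-bound on their own.
    """
    return isinstance(expr, (Var, Const))


def to_anf(expr):
    """Let-bind every non-atomic subexpression (A-normal form)."""
    bindings = []
    counter = itertools.count(1)

    def visit(e):
        if is_atomic(e):
            return e
        if isinstance(e, Call):
            args = tuple(visit(a) for a in e.args)
            var = Var(f"x{next(counter)}")
            bindings.append((var, Call(e.op, args)))
            return var
        raise TypeError(f"unsupported node: {e!r}")

    result = visit(expr)
    # Wrap from the last binding outward so the first binding is outermost.
    for var, value in reversed(bindings):
        result = Let(var, value, result)
    return result


def fold_constants(expr):
    """Fold let-bound calls whose arguments are all constants (ANF input)."""
    env = {}        # Var -> Const for bindings that were folded
    residual = []   # bindings that could not be folded
    while isinstance(expr, Let):
        call = expr.value
        args = tuple(env.get(a, a) for a in call.args)
        fn = OPS.get(call.op)
        if fn is not None and all(isinstance(a, Const) for a in args):
            env[expr.var] = Const(fn(*(a.value for a in args)))
        else:
            residual.append((expr.var, Call(call.op, args)))
        expr = expr.body
    result = env.get(expr, expr)
    for var, value in reversed(residual):
        result = Let(var, value, result)
    return result


if __name__ == "__main__":
    c = Const((1.0, 2.0, 3.0))
    anf = to_anf(Call("concatenate", (c, c)))
    print(fold_constants(anf))
    # → Const(value=(1.0, 2.0, 3.0, 1.0, 2.0, 3.0))
```

Because `is_atomic` accepts `Const`, both constants stay inline as operands of the single `concatenate` binding, and the toy folder replaces that binding with one constant. This mirrors the before/after shapes in `test_concatenate_const`. A call with a non-constant operand, such as `concatenate(y, c)`, is left as a residual binding.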

